##// END OF EJS Templates
Rename callbacks -> events (mostly), fire -> trigger
Thomas Kluyver -
Show More

The requested changes are too big and content was truncated. Show full diff

@@ -1,129 +1,129 b''
1 """Infrastructure for registering and firing callbacks.
1 """Infrastructure for registering and firing callbacks on application events.
2
2
3 Unlike :mod:`IPython.core.hooks`, which lets end users set single functions to
3 Unlike :mod:`IPython.core.hooks`, which lets end users set single functions to
4 be called at specific times, or a collection of alternative methods to try,
4 be called at specific times, or a collection of alternative methods to try,
5 callbacks are designed to be used by extension authors. A number of callbacks
5 callbacks are designed to be used by extension authors. A number of callbacks
6 can be registered for the same event without needing to be aware of one another.
6 can be registered for the same event without needing to be aware of one another.
7
7
8 The functions defined in this module are no-ops indicating the names of available
8 The functions defined in this module are no-ops indicating the names of available
9 events and the arguments which will be passed to them.
9 events and the arguments which will be passed to them.
10 """
10 """
11 from __future__ import print_function
11 from __future__ import print_function
12
12
13 class CallbackManager(object):
13 class EventManager(object):
14 """Manage a collection of events and a sequence of callbacks for each.
14 """Manage a collection of events and a sequence of callbacks for each.
15
15
16 This is attached to :class:`~IPython.core.interactiveshell.InteractiveShell`
16 This is attached to :class:`~IPython.core.interactiveshell.InteractiveShell`
17 instances as a ``callbacks`` attribute.
17 instances as a ``callbacks`` attribute.
18 """
18 """
19 def __init__(self, shell, available_callbacks):
19 def __init__(self, shell, available_events):
20 """Initialise the :class:`CallbackManager`.
20 """Initialise the :class:`CallbackManager`.
21
21
22 Parameters
22 Parameters
23 ----------
23 ----------
24 shell
24 shell
25 The :class:`~IPython.core.interactiveshell.InteractiveShell` instance
25 The :class:`~IPython.core.interactiveshell.InteractiveShell` instance
26 available_callbacks
26 available_callbacks
27 An iterable of names for callback events.
27 An iterable of names for callback events.
28 """
28 """
29 self.shell = shell
29 self.shell = shell
30 self.callbacks = {n:[] for n in available_callbacks}
30 self.callbacks = {n:[] for n in available_events}
31
31
32 def register(self, name, function):
32 def register(self, event, function):
33 """Register a new callback
33 """Register a new callback
34
34
35 Parameters
35 Parameters
36 ----------
36 ----------
37 name : str
37 event : str
38 The event for which to register this callback.
38 The event for which to register this callback.
39 function : callable
39 function : callable
40 A function to be called on the given event. It should take the same
40 A function to be called on the given event. It should take the same
41 parameters as the appropriate callback prototype.
41 parameters as the appropriate callback prototype.
42
42
43 Raises
43 Raises
44 ------
44 ------
45 TypeError
45 TypeError
46 If ``function`` is not callable.
46 If ``function`` is not callable.
47 KeyError
47 KeyError
48 If ``name`` is not one of the known callback events.
48 If ``event`` is not one of the known events.
49 """
49 """
50 if not callable(function):
50 if not callable(function):
51 raise TypeError('Need a callable, got %r' % function)
51 raise TypeError('Need a callable, got %r' % function)
52 self.callbacks[name].append(function)
52 self.callbacks[event].append(function)
53
53
54 def unregister(self, name, function):
54 def unregister(self, event, function):
55 """Remove a callback from the given event."""
55 """Remove a callback from the given event."""
56 self.callbacks[name].remove(function)
56 self.callbacks[event].remove(function)
57
57
58 def reset(self, name):
58 def reset(self, event):
59 """Clear all callbacks for the given event."""
59 """Clear all callbacks for the given event."""
60 self.callbacks[name] = []
60 self.callbacks[event] = []
61
61
62 def reset_all(self):
62 def reset_all(self):
63 """Clear all callbacks for all events."""
63 """Clear all callbacks for all events."""
64 self.callbacks = {n:[] for n in self.callbacks}
64 self.callbacks = {n:[] for n in self.callbacks}
65
65
66 def fire(self, name, *args, **kwargs):
66 def trigger(self, event, *args, **kwargs):
67 """Call callbacks for the event ``name``.
67 """Call callbacks for ``event``.
68
68
69 Any additional arguments are passed to all callbacks registered for this
69 Any additional arguments are passed to all callbacks registered for this
70 event. Exceptions raised by callbacks are caught, and a message printed.
70 event. Exceptions raised by callbacks are caught, and a message printed.
71 """
71 """
72 for func in self.callbacks[name]:
72 for func in self.callbacks[event]:
73 try:
73 try:
74 func(*args, **kwargs)
74 func(*args, **kwargs)
75 except Exception:
75 except Exception:
76 print("Error in callback {} (for {}):".format(func, name))
76 print("Error in callback {} (for {}):".format(func, event))
77 self.shell.showtraceback()
77 self.shell.showtraceback()
78
78
79 # event_name -> prototype mapping
79 # event_name -> prototype mapping
80 available_callbacks = {}
80 available_events = {}
81
81
82 def _collect(callback_proto):
82 def _collect(callback_proto):
83 available_callbacks[callback_proto.__name__] = callback_proto
83 available_events[callback_proto.__name__] = callback_proto
84 return callback_proto
84 return callback_proto
85
85
86 # ------------------------------------------------------------------------------
86 # ------------------------------------------------------------------------------
87 # Callback prototypes
87 # Callback prototypes
88 #
88 #
89 # No-op functions which describe the names of available events and the
89 # No-op functions which describe the names of available events and the
90 # signatures of callbacks for those events.
90 # signatures of callbacks for those events.
91 # ------------------------------------------------------------------------------
91 # ------------------------------------------------------------------------------
92
92
93 @_collect
93 @_collect
94 def pre_execute():
94 def pre_execute():
95 """Fires before code is executed in response to user/frontend action.
95 """Fires before code is executed in response to user/frontend action.
96
96
97 This includes comm and widget messages as well as user code cells."""
97 This includes comm and widget messages as well as user code cells."""
98 pass
98 pass
99
99
100 @_collect
100 @_collect
101 def pre_execute_explicit():
101 def pre_execute_explicit():
102 """Fires before user-entered code runs."""
102 """Fires before user-entered code runs."""
103 pass
103 pass
104
104
105 @_collect
105 @_collect
106 def post_execute():
106 def post_execute():
107 """Fires after code is executed in response to user/frontend action.
107 """Fires after code is executed in response to user/frontend action.
108
108
109 This includes comm and widget messages as well as user code cells."""
109 This includes comm and widget messages as well as user code cells."""
110 pass
110 pass
111
111
112 @_collect
112 @_collect
113 def post_execute_explicit():
113 def post_execute_explicit():
114 """Fires after user-entered code runs."""
114 """Fires after user-entered code runs."""
115 pass
115 pass
116
116
117 @_collect
117 @_collect
118 def shell_initialised(ip):
118 def shell_initialised(ip):
119 """Fires after initialisation of :class:`~IPython.core.interactiveshell.InteractiveShell`.
119 """Fires after initialisation of :class:`~IPython.core.interactiveshell.InteractiveShell`.
120
120
121 This is before extensions and startup scripts are loaded, so it can only be
121 This is before extensions and startup scripts are loaded, so it can only be
122 set by subclassing.
122 set by subclassing.
123
123
124 Parameters
124 Parameters
125 ----------
125 ----------
126 ip : :class:`~IPython.core.interactiveshell.InteractiveShell`
126 ip : :class:`~IPython.core.interactiveshell.InteractiveShell`
127 The newly initialised shell.
127 The newly initialised shell.
128 """
128 """
129 pass
129 pass
1 NO CONTENT: modified file
NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
@@ -1,46 +1,46 b''
1 import unittest
1 import unittest
2 try: # Python 3.3 +
2 try: # Python 3.3 +
3 from unittest.mock import Mock
3 from unittest.mock import Mock
4 except ImportError:
4 except ImportError:
5 from mock import Mock
5 from mock import Mock
6
6
7 from IPython.core import callbacks
7 from IPython.core import events
8 import IPython.testing.tools as tt
8 import IPython.testing.tools as tt
9
9
10 def ping_received():
10 def ping_received():
11 pass
11 pass
12
12
13 class CallbackTests(unittest.TestCase):
13 class CallbackTests(unittest.TestCase):
14 def setUp(self):
14 def setUp(self):
15 self.cbm = callbacks.CallbackManager(get_ipython(), {'ping_received': ping_received})
15 self.em = events.EventManager(get_ipython(), {'ping_received': ping_received})
16
16
17 def test_register_unregister(self):
17 def test_register_unregister(self):
18 cb = Mock()
18 cb = Mock()
19
19
20 self.cbm.register('ping_received', cb)
20 self.em.register('ping_received', cb)
21 self.cbm.fire('ping_received')
21 self.em.trigger('ping_received')
22 self.assertEqual(cb.call_count, 1)
22 self.assertEqual(cb.call_count, 1)
23
23
24 self.cbm.unregister('ping_received', cb)
24 self.em.unregister('ping_received', cb)
25 self.cbm.fire('ping_received')
25 self.em.trigger('ping_received')
26 self.assertEqual(cb.call_count, 1)
26 self.assertEqual(cb.call_count, 1)
27
27
28 def test_reset(self):
28 def test_reset(self):
29 cb = Mock()
29 cb = Mock()
30 self.cbm.register('ping_received', cb)
30 self.em.register('ping_received', cb)
31 self.cbm.reset('ping_received')
31 self.em.reset('ping_received')
32 self.cbm.fire('ping_received')
32 self.em.trigger('ping_received')
33 assert not cb.called
33 assert not cb.called
34
34
35 def test_reset_all(self):
35 def test_reset_all(self):
36 cb = Mock()
36 cb = Mock()
37 self.cbm.register('ping_received', cb)
37 self.em.register('ping_received', cb)
38 self.cbm.reset_all()
38 self.em.reset_all()
39 self.cbm.fire('ping_received')
39 self.em.trigger('ping_received')
40 assert not cb.called
40 assert not cb.called
41
41
42 def test_cb_error(self):
42 def test_cb_error(self):
43 cb = Mock(side_effect=ValueError)
43 cb = Mock(side_effect=ValueError)
44 self.cbm.register('ping_received', cb)
44 self.em.register('ping_received', cb)
45 with tt.AssertPrints("Error in callback"):
45 with tt.AssertPrints("Error in callback"):
46 self.cbm.fire('ping_received') No newline at end of file
46 self.em.trigger('ping_received') No newline at end of file
@@ -1,730 +1,730 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Tests for the key interactiveshell module.
2 """Tests for the key interactiveshell module.
3
3
4 Historically the main classes in interactiveshell have been under-tested. This
4 Historically the main classes in interactiveshell have been under-tested. This
5 module should grow as many single-method tests as possible to trap many of the
5 module should grow as many single-method tests as possible to trap many of the
6 recurring bugs we seem to encounter with high-level interaction.
6 recurring bugs we seem to encounter with high-level interaction.
7
7
8 Authors
8 Authors
9 -------
9 -------
10 * Fernando Perez
10 * Fernando Perez
11 """
11 """
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # Copyright (C) 2011 The IPython Development Team
13 # Copyright (C) 2011 The IPython Development Team
14 #
14 #
15 # Distributed under the terms of the BSD License. The full license is in
15 # Distributed under the terms of the BSD License. The full license is in
16 # the file COPYING, distributed as part of this software.
16 # the file COPYING, distributed as part of this software.
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
20 # Imports
20 # Imports
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22 # stdlib
22 # stdlib
23 import ast
23 import ast
24 import os
24 import os
25 import signal
25 import signal
26 import shutil
26 import shutil
27 import sys
27 import sys
28 import tempfile
28 import tempfile
29 import unittest
29 import unittest
30 try:
30 try:
31 from unittest import mock
31 from unittest import mock
32 except ImportError:
32 except ImportError:
33 import mock
33 import mock
34 from os.path import join
34 from os.path import join
35
35
36 # third-party
36 # third-party
37 import nose.tools as nt
37 import nose.tools as nt
38
38
39 # Our own
39 # Our own
40 from IPython.core.inputtransformer import InputTransformer
40 from IPython.core.inputtransformer import InputTransformer
41 from IPython.testing.decorators import skipif, skip_win32, onlyif_unicode_paths
41 from IPython.testing.decorators import skipif, skip_win32, onlyif_unicode_paths
42 from IPython.testing import tools as tt
42 from IPython.testing import tools as tt
43 from IPython.utils import io
43 from IPython.utils import io
44 from IPython.utils import py3compat
44 from IPython.utils import py3compat
45 from IPython.utils.py3compat import unicode_type, PY3
45 from IPython.utils.py3compat import unicode_type, PY3
46
46
47 if PY3:
47 if PY3:
48 from io import StringIO
48 from io import StringIO
49 else:
49 else:
50 from StringIO import StringIO
50 from StringIO import StringIO
51
51
52 #-----------------------------------------------------------------------------
52 #-----------------------------------------------------------------------------
53 # Globals
53 # Globals
54 #-----------------------------------------------------------------------------
54 #-----------------------------------------------------------------------------
55 # This is used by every single test, no point repeating it ad nauseam
55 # This is used by every single test, no point repeating it ad nauseam
56 ip = get_ipython()
56 ip = get_ipython()
57
57
58 #-----------------------------------------------------------------------------
58 #-----------------------------------------------------------------------------
59 # Tests
59 # Tests
60 #-----------------------------------------------------------------------------
60 #-----------------------------------------------------------------------------
61
61
62 class InteractiveShellTestCase(unittest.TestCase):
62 class InteractiveShellTestCase(unittest.TestCase):
63 def test_naked_string_cells(self):
63 def test_naked_string_cells(self):
64 """Test that cells with only naked strings are fully executed"""
64 """Test that cells with only naked strings are fully executed"""
65 # First, single-line inputs
65 # First, single-line inputs
66 ip.run_cell('"a"\n')
66 ip.run_cell('"a"\n')
67 self.assertEqual(ip.user_ns['_'], 'a')
67 self.assertEqual(ip.user_ns['_'], 'a')
68 # And also multi-line cells
68 # And also multi-line cells
69 ip.run_cell('"""a\nb"""\n')
69 ip.run_cell('"""a\nb"""\n')
70 self.assertEqual(ip.user_ns['_'], 'a\nb')
70 self.assertEqual(ip.user_ns['_'], 'a\nb')
71
71
72 def test_run_empty_cell(self):
72 def test_run_empty_cell(self):
73 """Just make sure we don't get a horrible error with a blank
73 """Just make sure we don't get a horrible error with a blank
74 cell of input. Yes, I did overlook that."""
74 cell of input. Yes, I did overlook that."""
75 old_xc = ip.execution_count
75 old_xc = ip.execution_count
76 ip.run_cell('')
76 ip.run_cell('')
77 self.assertEqual(ip.execution_count, old_xc)
77 self.assertEqual(ip.execution_count, old_xc)
78
78
79 def test_run_cell_multiline(self):
79 def test_run_cell_multiline(self):
80 """Multi-block, multi-line cells must execute correctly.
80 """Multi-block, multi-line cells must execute correctly.
81 """
81 """
82 src = '\n'.join(["x=1",
82 src = '\n'.join(["x=1",
83 "y=2",
83 "y=2",
84 "if 1:",
84 "if 1:",
85 " x += 1",
85 " x += 1",
86 " y += 1",])
86 " y += 1",])
87 ip.run_cell(src)
87 ip.run_cell(src)
88 self.assertEqual(ip.user_ns['x'], 2)
88 self.assertEqual(ip.user_ns['x'], 2)
89 self.assertEqual(ip.user_ns['y'], 3)
89 self.assertEqual(ip.user_ns['y'], 3)
90
90
91 def test_multiline_string_cells(self):
91 def test_multiline_string_cells(self):
92 "Code sprinkled with multiline strings should execute (GH-306)"
92 "Code sprinkled with multiline strings should execute (GH-306)"
93 ip.run_cell('tmp=0')
93 ip.run_cell('tmp=0')
94 self.assertEqual(ip.user_ns['tmp'], 0)
94 self.assertEqual(ip.user_ns['tmp'], 0)
95 ip.run_cell('tmp=1;"""a\nb"""\n')
95 ip.run_cell('tmp=1;"""a\nb"""\n')
96 self.assertEqual(ip.user_ns['tmp'], 1)
96 self.assertEqual(ip.user_ns['tmp'], 1)
97
97
98 def test_dont_cache_with_semicolon(self):
98 def test_dont_cache_with_semicolon(self):
99 "Ending a line with semicolon should not cache the returned object (GH-307)"
99 "Ending a line with semicolon should not cache the returned object (GH-307)"
100 oldlen = len(ip.user_ns['Out'])
100 oldlen = len(ip.user_ns['Out'])
101 a = ip.run_cell('1;', store_history=True)
101 a = ip.run_cell('1;', store_history=True)
102 newlen = len(ip.user_ns['Out'])
102 newlen = len(ip.user_ns['Out'])
103 self.assertEqual(oldlen, newlen)
103 self.assertEqual(oldlen, newlen)
104 #also test the default caching behavior
104 #also test the default caching behavior
105 ip.run_cell('1', store_history=True)
105 ip.run_cell('1', store_history=True)
106 newlen = len(ip.user_ns['Out'])
106 newlen = len(ip.user_ns['Out'])
107 self.assertEqual(oldlen+1, newlen)
107 self.assertEqual(oldlen+1, newlen)
108
108
109 def test_In_variable(self):
109 def test_In_variable(self):
110 "Verify that In variable grows with user input (GH-284)"
110 "Verify that In variable grows with user input (GH-284)"
111 oldlen = len(ip.user_ns['In'])
111 oldlen = len(ip.user_ns['In'])
112 ip.run_cell('1;', store_history=True)
112 ip.run_cell('1;', store_history=True)
113 newlen = len(ip.user_ns['In'])
113 newlen = len(ip.user_ns['In'])
114 self.assertEqual(oldlen+1, newlen)
114 self.assertEqual(oldlen+1, newlen)
115 self.assertEqual(ip.user_ns['In'][-1],'1;')
115 self.assertEqual(ip.user_ns['In'][-1],'1;')
116
116
117 def test_magic_names_in_string(self):
117 def test_magic_names_in_string(self):
118 ip.run_cell('a = """\n%exit\n"""')
118 ip.run_cell('a = """\n%exit\n"""')
119 self.assertEqual(ip.user_ns['a'], '\n%exit\n')
119 self.assertEqual(ip.user_ns['a'], '\n%exit\n')
120
120
121 def test_trailing_newline(self):
121 def test_trailing_newline(self):
122 """test that running !(command) does not raise a SyntaxError"""
122 """test that running !(command) does not raise a SyntaxError"""
123 ip.run_cell('!(true)\n', False)
123 ip.run_cell('!(true)\n', False)
124 ip.run_cell('!(true)\n\n\n', False)
124 ip.run_cell('!(true)\n\n\n', False)
125
125
126 def test_gh_597(self):
126 def test_gh_597(self):
127 """Pretty-printing lists of objects with non-ascii reprs may cause
127 """Pretty-printing lists of objects with non-ascii reprs may cause
128 problems."""
128 problems."""
129 class Spam(object):
129 class Spam(object):
130 def __repr__(self):
130 def __repr__(self):
131 return "\xe9"*50
131 return "\xe9"*50
132 import IPython.core.formatters
132 import IPython.core.formatters
133 f = IPython.core.formatters.PlainTextFormatter()
133 f = IPython.core.formatters.PlainTextFormatter()
134 f([Spam(),Spam()])
134 f([Spam(),Spam()])
135
135
136
136
137 def test_future_flags(self):
137 def test_future_flags(self):
138 """Check that future flags are used for parsing code (gh-777)"""
138 """Check that future flags are used for parsing code (gh-777)"""
139 ip.run_cell('from __future__ import print_function')
139 ip.run_cell('from __future__ import print_function')
140 try:
140 try:
141 ip.run_cell('prfunc_return_val = print(1,2, sep=" ")')
141 ip.run_cell('prfunc_return_val = print(1,2, sep=" ")')
142 assert 'prfunc_return_val' in ip.user_ns
142 assert 'prfunc_return_val' in ip.user_ns
143 finally:
143 finally:
144 # Reset compiler flags so we don't mess up other tests.
144 # Reset compiler flags so we don't mess up other tests.
145 ip.compile.reset_compiler_flags()
145 ip.compile.reset_compiler_flags()
146
146
147 def test_future_unicode(self):
147 def test_future_unicode(self):
148 """Check that unicode_literals is imported from __future__ (gh #786)"""
148 """Check that unicode_literals is imported from __future__ (gh #786)"""
149 try:
149 try:
150 ip.run_cell(u'byte_str = "a"')
150 ip.run_cell(u'byte_str = "a"')
151 assert isinstance(ip.user_ns['byte_str'], str) # string literals are byte strings by default
151 assert isinstance(ip.user_ns['byte_str'], str) # string literals are byte strings by default
152 ip.run_cell('from __future__ import unicode_literals')
152 ip.run_cell('from __future__ import unicode_literals')
153 ip.run_cell(u'unicode_str = "a"')
153 ip.run_cell(u'unicode_str = "a"')
154 assert isinstance(ip.user_ns['unicode_str'], unicode_type) # strings literals are now unicode
154 assert isinstance(ip.user_ns['unicode_str'], unicode_type) # strings literals are now unicode
155 finally:
155 finally:
156 # Reset compiler flags so we don't mess up other tests.
156 # Reset compiler flags so we don't mess up other tests.
157 ip.compile.reset_compiler_flags()
157 ip.compile.reset_compiler_flags()
158
158
159 def test_can_pickle(self):
159 def test_can_pickle(self):
160 "Can we pickle objects defined interactively (GH-29)"
160 "Can we pickle objects defined interactively (GH-29)"
161 ip = get_ipython()
161 ip = get_ipython()
162 ip.reset()
162 ip.reset()
163 ip.run_cell(("class Mylist(list):\n"
163 ip.run_cell(("class Mylist(list):\n"
164 " def __init__(self,x=[]):\n"
164 " def __init__(self,x=[]):\n"
165 " list.__init__(self,x)"))
165 " list.__init__(self,x)"))
166 ip.run_cell("w=Mylist([1,2,3])")
166 ip.run_cell("w=Mylist([1,2,3])")
167
167
168 from pickle import dumps
168 from pickle import dumps
169
169
170 # We need to swap in our main module - this is only necessary
170 # We need to swap in our main module - this is only necessary
171 # inside the test framework, because IPython puts the interactive module
171 # inside the test framework, because IPython puts the interactive module
172 # in place (but the test framework undoes this).
172 # in place (but the test framework undoes this).
173 _main = sys.modules['__main__']
173 _main = sys.modules['__main__']
174 sys.modules['__main__'] = ip.user_module
174 sys.modules['__main__'] = ip.user_module
175 try:
175 try:
176 res = dumps(ip.user_ns["w"])
176 res = dumps(ip.user_ns["w"])
177 finally:
177 finally:
178 sys.modules['__main__'] = _main
178 sys.modules['__main__'] = _main
179 self.assertTrue(isinstance(res, bytes))
179 self.assertTrue(isinstance(res, bytes))
180
180
181 def test_global_ns(self):
181 def test_global_ns(self):
182 "Code in functions must be able to access variables outside them."
182 "Code in functions must be able to access variables outside them."
183 ip = get_ipython()
183 ip = get_ipython()
184 ip.run_cell("a = 10")
184 ip.run_cell("a = 10")
185 ip.run_cell(("def f(x):\n"
185 ip.run_cell(("def f(x):\n"
186 " return x + a"))
186 " return x + a"))
187 ip.run_cell("b = f(12)")
187 ip.run_cell("b = f(12)")
188 self.assertEqual(ip.user_ns["b"], 22)
188 self.assertEqual(ip.user_ns["b"], 22)
189
189
190 def test_bad_custom_tb(self):
190 def test_bad_custom_tb(self):
191 """Check that InteractiveShell is protected from bad custom exception handlers"""
191 """Check that InteractiveShell is protected from bad custom exception handlers"""
192 from IPython.utils import io
192 from IPython.utils import io
193 save_stderr = io.stderr
193 save_stderr = io.stderr
194 try:
194 try:
195 # capture stderr
195 # capture stderr
196 io.stderr = StringIO()
196 io.stderr = StringIO()
197 ip.set_custom_exc((IOError,), lambda etype,value,tb: 1/0)
197 ip.set_custom_exc((IOError,), lambda etype,value,tb: 1/0)
198 self.assertEqual(ip.custom_exceptions, (IOError,))
198 self.assertEqual(ip.custom_exceptions, (IOError,))
199 ip.run_cell(u'raise IOError("foo")')
199 ip.run_cell(u'raise IOError("foo")')
200 self.assertEqual(ip.custom_exceptions, ())
200 self.assertEqual(ip.custom_exceptions, ())
201 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
201 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
202 finally:
202 finally:
203 io.stderr = save_stderr
203 io.stderr = save_stderr
204
204
205 def test_bad_custom_tb_return(self):
205 def test_bad_custom_tb_return(self):
206 """Check that InteractiveShell is protected from bad return types in custom exception handlers"""
206 """Check that InteractiveShell is protected from bad return types in custom exception handlers"""
207 from IPython.utils import io
207 from IPython.utils import io
208 save_stderr = io.stderr
208 save_stderr = io.stderr
209 try:
209 try:
210 # capture stderr
210 # capture stderr
211 io.stderr = StringIO()
211 io.stderr = StringIO()
212 ip.set_custom_exc((NameError,),lambda etype,value,tb, tb_offset=None: 1)
212 ip.set_custom_exc((NameError,),lambda etype,value,tb, tb_offset=None: 1)
213 self.assertEqual(ip.custom_exceptions, (NameError,))
213 self.assertEqual(ip.custom_exceptions, (NameError,))
214 ip.run_cell(u'a=abracadabra')
214 ip.run_cell(u'a=abracadabra')
215 self.assertEqual(ip.custom_exceptions, ())
215 self.assertEqual(ip.custom_exceptions, ())
216 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
216 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
217 finally:
217 finally:
218 io.stderr = save_stderr
218 io.stderr = save_stderr
219
219
220 def test_drop_by_id(self):
220 def test_drop_by_id(self):
221 myvars = {"a":object(), "b":object(), "c": object()}
221 myvars = {"a":object(), "b":object(), "c": object()}
222 ip.push(myvars, interactive=False)
222 ip.push(myvars, interactive=False)
223 for name in myvars:
223 for name in myvars:
224 assert name in ip.user_ns, name
224 assert name in ip.user_ns, name
225 assert name in ip.user_ns_hidden, name
225 assert name in ip.user_ns_hidden, name
226 ip.user_ns['b'] = 12
226 ip.user_ns['b'] = 12
227 ip.drop_by_id(myvars)
227 ip.drop_by_id(myvars)
228 for name in ["a", "c"]:
228 for name in ["a", "c"]:
229 assert name not in ip.user_ns, name
229 assert name not in ip.user_ns, name
230 assert name not in ip.user_ns_hidden, name
230 assert name not in ip.user_ns_hidden, name
231 assert ip.user_ns['b'] == 12
231 assert ip.user_ns['b'] == 12
232 ip.reset()
232 ip.reset()
233
233
234 def test_var_expand(self):
234 def test_var_expand(self):
235 ip.user_ns['f'] = u'Ca\xf1o'
235 ip.user_ns['f'] = u'Ca\xf1o'
236 self.assertEqual(ip.var_expand(u'echo $f'), u'echo Ca\xf1o')
236 self.assertEqual(ip.var_expand(u'echo $f'), u'echo Ca\xf1o')
237 self.assertEqual(ip.var_expand(u'echo {f}'), u'echo Ca\xf1o')
237 self.assertEqual(ip.var_expand(u'echo {f}'), u'echo Ca\xf1o')
238 self.assertEqual(ip.var_expand(u'echo {f[:-1]}'), u'echo Ca\xf1')
238 self.assertEqual(ip.var_expand(u'echo {f[:-1]}'), u'echo Ca\xf1')
239 self.assertEqual(ip.var_expand(u'echo {1*2}'), u'echo 2')
239 self.assertEqual(ip.var_expand(u'echo {1*2}'), u'echo 2')
240
240
241 ip.user_ns['f'] = b'Ca\xc3\xb1o'
241 ip.user_ns['f'] = b'Ca\xc3\xb1o'
242 # This should not raise any exception:
242 # This should not raise any exception:
243 ip.var_expand(u'echo $f')
243 ip.var_expand(u'echo $f')
244
244
245 def test_var_expand_local(self):
245 def test_var_expand_local(self):
246 """Test local variable expansion in !system and %magic calls"""
246 """Test local variable expansion in !system and %magic calls"""
247 # !system
247 # !system
248 ip.run_cell('def test():\n'
248 ip.run_cell('def test():\n'
249 ' lvar = "ttt"\n'
249 ' lvar = "ttt"\n'
250 ' ret = !echo {lvar}\n'
250 ' ret = !echo {lvar}\n'
251 ' return ret[0]\n')
251 ' return ret[0]\n')
252 res = ip.user_ns['test']()
252 res = ip.user_ns['test']()
253 nt.assert_in('ttt', res)
253 nt.assert_in('ttt', res)
254
254
255 # %magic
255 # %magic
256 ip.run_cell('def makemacro():\n'
256 ip.run_cell('def makemacro():\n'
257 ' macroname = "macro_var_expand_locals"\n'
257 ' macroname = "macro_var_expand_locals"\n'
258 ' %macro {macroname} codestr\n')
258 ' %macro {macroname} codestr\n')
259 ip.user_ns['codestr'] = "str(12)"
259 ip.user_ns['codestr'] = "str(12)"
260 ip.run_cell('makemacro()')
260 ip.run_cell('makemacro()')
261 nt.assert_in('macro_var_expand_locals', ip.user_ns)
261 nt.assert_in('macro_var_expand_locals', ip.user_ns)
262
262
263 def test_var_expand_self(self):
263 def test_var_expand_self(self):
264 """Test variable expansion with the name 'self', which was failing.
264 """Test variable expansion with the name 'self', which was failing.
265
265
266 See https://github.com/ipython/ipython/issues/1878#issuecomment-7698218
266 See https://github.com/ipython/ipython/issues/1878#issuecomment-7698218
267 """
267 """
268 ip.run_cell('class cTest:\n'
268 ip.run_cell('class cTest:\n'
269 ' classvar="see me"\n'
269 ' classvar="see me"\n'
270 ' def test(self):\n'
270 ' def test(self):\n'
271 ' res = !echo Variable: {self.classvar}\n'
271 ' res = !echo Variable: {self.classvar}\n'
272 ' return res[0]\n')
272 ' return res[0]\n')
273 nt.assert_in('see me', ip.user_ns['cTest']().test())
273 nt.assert_in('see me', ip.user_ns['cTest']().test())
274
274
275 def test_bad_var_expand(self):
275 def test_bad_var_expand(self):
276 """var_expand on invalid formats shouldn't raise"""
276 """var_expand on invalid formats shouldn't raise"""
277 # SyntaxError
277 # SyntaxError
278 self.assertEqual(ip.var_expand(u"{'a':5}"), u"{'a':5}")
278 self.assertEqual(ip.var_expand(u"{'a':5}"), u"{'a':5}")
279 # NameError
279 # NameError
280 self.assertEqual(ip.var_expand(u"{asdf}"), u"{asdf}")
280 self.assertEqual(ip.var_expand(u"{asdf}"), u"{asdf}")
281 # ZeroDivisionError
281 # ZeroDivisionError
282 self.assertEqual(ip.var_expand(u"{1/0}"), u"{1/0}")
282 self.assertEqual(ip.var_expand(u"{1/0}"), u"{1/0}")
283
283
284 def test_silent_postexec(self):
284 def test_silent_postexec(self):
285 """run_cell(silent=True) doesn't invoke pre/post_execute_explicit callbacks"""
285 """run_cell(silent=True) doesn't invoke pre/post_execute_explicit callbacks"""
286 pre_explicit = mock.Mock()
286 pre_explicit = mock.Mock()
287 pre_always = mock.Mock()
287 pre_always = mock.Mock()
288 post_explicit = mock.Mock()
288 post_explicit = mock.Mock()
289 post_always = mock.Mock()
289 post_always = mock.Mock()
290
290
291 ip.callbacks.register('pre_execute_explicit', pre_explicit)
291 ip.events.register('pre_execute_explicit', pre_explicit)
292 ip.callbacks.register('pre_execute', pre_always)
292 ip.events.register('pre_execute', pre_always)
293 ip.callbacks.register('post_execute_explicit', post_explicit)
293 ip.events.register('post_execute_explicit', post_explicit)
294 ip.callbacks.register('post_execute', post_always)
294 ip.events.register('post_execute', post_always)
295
295
296 try:
296 try:
297 ip.run_cell("1", silent=True)
297 ip.run_cell("1", silent=True)
298 assert pre_always.called
298 assert pre_always.called
299 assert not pre_explicit.called
299 assert not pre_explicit.called
300 assert post_always.called
300 assert post_always.called
301 assert not post_explicit.called
301 assert not post_explicit.called
302 # double-check that non-silent exec did what we expected
302 # double-check that non-silent exec did what we expected
303 # silent to avoid
303 # silent to avoid
304 ip.run_cell("1")
304 ip.run_cell("1")
305 assert pre_explicit.called
305 assert pre_explicit.called
306 assert post_explicit.called
306 assert post_explicit.called
307 finally:
307 finally:
308 # remove post-exec
308 # remove post-exec
309 ip.callbacks.reset_all()
309 ip.events.reset_all()
310
310
311 def test_silent_noadvance(self):
311 def test_silent_noadvance(self):
312 """run_cell(silent=True) doesn't advance execution_count"""
312 """run_cell(silent=True) doesn't advance execution_count"""
313 ec = ip.execution_count
313 ec = ip.execution_count
314 # silent should force store_history=False
314 # silent should force store_history=False
315 ip.run_cell("1", store_history=True, silent=True)
315 ip.run_cell("1", store_history=True, silent=True)
316
316
317 self.assertEqual(ec, ip.execution_count)
317 self.assertEqual(ec, ip.execution_count)
318 # double-check that non-silent exec did what we expected
318 # double-check that non-silent exec did what we expected
319 # silent to avoid
319 # silent to avoid
320 ip.run_cell("1", store_history=True)
320 ip.run_cell("1", store_history=True)
321 self.assertEqual(ec+1, ip.execution_count)
321 self.assertEqual(ec+1, ip.execution_count)
322
322
323 def test_silent_nodisplayhook(self):
323 def test_silent_nodisplayhook(self):
324 """run_cell(silent=True) doesn't trigger displayhook"""
324 """run_cell(silent=True) doesn't trigger displayhook"""
325 d = dict(called=False)
325 d = dict(called=False)
326
326
327 trap = ip.display_trap
327 trap = ip.display_trap
328 save_hook = trap.hook
328 save_hook = trap.hook
329
329
330 def failing_hook(*args, **kwargs):
330 def failing_hook(*args, **kwargs):
331 d['called'] = True
331 d['called'] = True
332
332
333 try:
333 try:
334 trap.hook = failing_hook
334 trap.hook = failing_hook
335 ip.run_cell("1", silent=True)
335 ip.run_cell("1", silent=True)
336 self.assertFalse(d['called'])
336 self.assertFalse(d['called'])
337 # double-check that non-silent exec did what we expected
337 # double-check that non-silent exec did what we expected
338 # silent to avoid
338 # silent to avoid
339 ip.run_cell("1")
339 ip.run_cell("1")
340 self.assertTrue(d['called'])
340 self.assertTrue(d['called'])
341 finally:
341 finally:
342 trap.hook = save_hook
342 trap.hook = save_hook
343
343
344 @skipif(sys.version_info[0] >= 3, "softspace removed in py3")
344 @skipif(sys.version_info[0] >= 3, "softspace removed in py3")
345 def test_print_softspace(self):
345 def test_print_softspace(self):
346 """Verify that softspace is handled correctly when executing multiple
346 """Verify that softspace is handled correctly when executing multiple
347 statements.
347 statements.
348
348
349 In [1]: print 1; print 2
349 In [1]: print 1; print 2
350 1
350 1
351 2
351 2
352
352
353 In [2]: print 1,; print 2
353 In [2]: print 1,; print 2
354 1 2
354 1 2
355 """
355 """
356
356
357 def test_ofind_line_magic(self):
357 def test_ofind_line_magic(self):
358 from IPython.core.magic import register_line_magic
358 from IPython.core.magic import register_line_magic
359
359
360 @register_line_magic
360 @register_line_magic
361 def lmagic(line):
361 def lmagic(line):
362 "A line magic"
362 "A line magic"
363
363
364 # Get info on line magic
364 # Get info on line magic
365 lfind = ip._ofind('lmagic')
365 lfind = ip._ofind('lmagic')
366 info = dict(found=True, isalias=False, ismagic=True,
366 info = dict(found=True, isalias=False, ismagic=True,
367 namespace = 'IPython internal', obj= lmagic.__wrapped__,
367 namespace = 'IPython internal', obj= lmagic.__wrapped__,
368 parent = None)
368 parent = None)
369 nt.assert_equal(lfind, info)
369 nt.assert_equal(lfind, info)
370
370
371 def test_ofind_cell_magic(self):
371 def test_ofind_cell_magic(self):
372 from IPython.core.magic import register_cell_magic
372 from IPython.core.magic import register_cell_magic
373
373
374 @register_cell_magic
374 @register_cell_magic
375 def cmagic(line, cell):
375 def cmagic(line, cell):
376 "A cell magic"
376 "A cell magic"
377
377
378 # Get info on cell magic
378 # Get info on cell magic
379 find = ip._ofind('cmagic')
379 find = ip._ofind('cmagic')
380 info = dict(found=True, isalias=False, ismagic=True,
380 info = dict(found=True, isalias=False, ismagic=True,
381 namespace = 'IPython internal', obj= cmagic.__wrapped__,
381 namespace = 'IPython internal', obj= cmagic.__wrapped__,
382 parent = None)
382 parent = None)
383 nt.assert_equal(find, info)
383 nt.assert_equal(find, info)
384
384
385 def test_custom_exception(self):
385 def test_custom_exception(self):
386 called = []
386 called = []
387 def my_handler(shell, etype, value, tb, tb_offset=None):
387 def my_handler(shell, etype, value, tb, tb_offset=None):
388 called.append(etype)
388 called.append(etype)
389 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
389 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
390
390
391 ip.set_custom_exc((ValueError,), my_handler)
391 ip.set_custom_exc((ValueError,), my_handler)
392 try:
392 try:
393 ip.run_cell("raise ValueError('test')")
393 ip.run_cell("raise ValueError('test')")
394 # Check that this was called, and only once.
394 # Check that this was called, and only once.
395 self.assertEqual(called, [ValueError])
395 self.assertEqual(called, [ValueError])
396 finally:
396 finally:
397 # Reset the custom exception hook
397 # Reset the custom exception hook
398 ip.set_custom_exc((), None)
398 ip.set_custom_exc((), None)
399
399
400 @skipif(sys.version_info[0] >= 3, "no differences with __future__ in py3")
400 @skipif(sys.version_info[0] >= 3, "no differences with __future__ in py3")
401 def test_future_environment(self):
401 def test_future_environment(self):
402 "Can we run code with & without the shell's __future__ imports?"
402 "Can we run code with & without the shell's __future__ imports?"
403 ip.run_cell("from __future__ import division")
403 ip.run_cell("from __future__ import division")
404 ip.run_cell("a = 1/2", shell_futures=True)
404 ip.run_cell("a = 1/2", shell_futures=True)
405 self.assertEqual(ip.user_ns['a'], 0.5)
405 self.assertEqual(ip.user_ns['a'], 0.5)
406 ip.run_cell("b = 1/2", shell_futures=False)
406 ip.run_cell("b = 1/2", shell_futures=False)
407 self.assertEqual(ip.user_ns['b'], 0)
407 self.assertEqual(ip.user_ns['b'], 0)
408
408
409 ip.compile.reset_compiler_flags()
409 ip.compile.reset_compiler_flags()
410 # This shouldn't leak to the shell's compiler
410 # This shouldn't leak to the shell's compiler
411 ip.run_cell("from __future__ import division \nc=1/2", shell_futures=False)
411 ip.run_cell("from __future__ import division \nc=1/2", shell_futures=False)
412 self.assertEqual(ip.user_ns['c'], 0.5)
412 self.assertEqual(ip.user_ns['c'], 0.5)
413 ip.run_cell("d = 1/2", shell_futures=True)
413 ip.run_cell("d = 1/2", shell_futures=True)
414 self.assertEqual(ip.user_ns['d'], 0)
414 self.assertEqual(ip.user_ns['d'], 0)
415
415
416
416
417 class TestSafeExecfileNonAsciiPath(unittest.TestCase):
417 class TestSafeExecfileNonAsciiPath(unittest.TestCase):
418
418
419 @onlyif_unicode_paths
419 @onlyif_unicode_paths
420 def setUp(self):
420 def setUp(self):
421 self.BASETESTDIR = tempfile.mkdtemp()
421 self.BASETESTDIR = tempfile.mkdtemp()
422 self.TESTDIR = join(self.BASETESTDIR, u"Γ₯Àâ")
422 self.TESTDIR = join(self.BASETESTDIR, u"Γ₯Àâ")
423 os.mkdir(self.TESTDIR)
423 os.mkdir(self.TESTDIR)
424 with open(join(self.TESTDIR, u"Γ₯Àâtestscript.py"), "w") as sfile:
424 with open(join(self.TESTDIR, u"Γ₯Àâtestscript.py"), "w") as sfile:
425 sfile.write("pass\n")
425 sfile.write("pass\n")
426 self.oldpath = py3compat.getcwd()
426 self.oldpath = py3compat.getcwd()
427 os.chdir(self.TESTDIR)
427 os.chdir(self.TESTDIR)
428 self.fname = u"Γ₯Àâtestscript.py"
428 self.fname = u"Γ₯Àâtestscript.py"
429
429
430 def tearDown(self):
430 def tearDown(self):
431 os.chdir(self.oldpath)
431 os.chdir(self.oldpath)
432 shutil.rmtree(self.BASETESTDIR)
432 shutil.rmtree(self.BASETESTDIR)
433
433
434 @onlyif_unicode_paths
434 @onlyif_unicode_paths
435 def test_1(self):
435 def test_1(self):
436 """Test safe_execfile with non-ascii path
436 """Test safe_execfile with non-ascii path
437 """
437 """
438 ip.safe_execfile(self.fname, {}, raise_exceptions=True)
438 ip.safe_execfile(self.fname, {}, raise_exceptions=True)
439
439
440 class ExitCodeChecks(tt.TempFileMixin):
440 class ExitCodeChecks(tt.TempFileMixin):
441 def test_exit_code_ok(self):
441 def test_exit_code_ok(self):
442 self.system('exit 0')
442 self.system('exit 0')
443 self.assertEqual(ip.user_ns['_exit_code'], 0)
443 self.assertEqual(ip.user_ns['_exit_code'], 0)
444
444
445 def test_exit_code_error(self):
445 def test_exit_code_error(self):
446 self.system('exit 1')
446 self.system('exit 1')
447 self.assertEqual(ip.user_ns['_exit_code'], 1)
447 self.assertEqual(ip.user_ns['_exit_code'], 1)
448
448
449 @skipif(not hasattr(signal, 'SIGALRM'))
449 @skipif(not hasattr(signal, 'SIGALRM'))
450 def test_exit_code_signal(self):
450 def test_exit_code_signal(self):
451 self.mktmp("import signal, time\n"
451 self.mktmp("import signal, time\n"
452 "signal.setitimer(signal.ITIMER_REAL, 0.1)\n"
452 "signal.setitimer(signal.ITIMER_REAL, 0.1)\n"
453 "time.sleep(1)\n")
453 "time.sleep(1)\n")
454 self.system("%s %s" % (sys.executable, self.fname))
454 self.system("%s %s" % (sys.executable, self.fname))
455 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGALRM)
455 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGALRM)
456
456
457 class TestSystemRaw(unittest.TestCase, ExitCodeChecks):
457 class TestSystemRaw(unittest.TestCase, ExitCodeChecks):
458 system = ip.system_raw
458 system = ip.system_raw
459
459
460 @onlyif_unicode_paths
460 @onlyif_unicode_paths
461 def test_1(self):
461 def test_1(self):
462 """Test system_raw with non-ascii cmd
462 """Test system_raw with non-ascii cmd
463 """
463 """
464 cmd = u'''python -c "'Γ₯Àâ'" '''
464 cmd = u'''python -c "'Γ₯Àâ'" '''
465 ip.system_raw(cmd)
465 ip.system_raw(cmd)
466
466
467 # TODO: Exit codes are currently ignored on Windows.
467 # TODO: Exit codes are currently ignored on Windows.
468 class TestSystemPipedExitCode(unittest.TestCase, ExitCodeChecks):
468 class TestSystemPipedExitCode(unittest.TestCase, ExitCodeChecks):
469 system = ip.system_piped
469 system = ip.system_piped
470
470
471 @skip_win32
471 @skip_win32
472 def test_exit_code_ok(self):
472 def test_exit_code_ok(self):
473 ExitCodeChecks.test_exit_code_ok(self)
473 ExitCodeChecks.test_exit_code_ok(self)
474
474
475 @skip_win32
475 @skip_win32
476 def test_exit_code_error(self):
476 def test_exit_code_error(self):
477 ExitCodeChecks.test_exit_code_error(self)
477 ExitCodeChecks.test_exit_code_error(self)
478
478
479 @skip_win32
479 @skip_win32
480 def test_exit_code_signal(self):
480 def test_exit_code_signal(self):
481 ExitCodeChecks.test_exit_code_signal(self)
481 ExitCodeChecks.test_exit_code_signal(self)
482
482
483 class TestModules(unittest.TestCase, tt.TempFileMixin):
483 class TestModules(unittest.TestCase, tt.TempFileMixin):
484 def test_extraneous_loads(self):
484 def test_extraneous_loads(self):
485 """Test we're not loading modules on startup that we shouldn't.
485 """Test we're not loading modules on startup that we shouldn't.
486 """
486 """
487 self.mktmp("import sys\n"
487 self.mktmp("import sys\n"
488 "print('numpy' in sys.modules)\n"
488 "print('numpy' in sys.modules)\n"
489 "print('IPython.parallel' in sys.modules)\n"
489 "print('IPython.parallel' in sys.modules)\n"
490 "print('IPython.kernel.zmq' in sys.modules)\n"
490 "print('IPython.kernel.zmq' in sys.modules)\n"
491 )
491 )
492 out = "False\nFalse\nFalse\n"
492 out = "False\nFalse\nFalse\n"
493 tt.ipexec_validate(self.fname, out)
493 tt.ipexec_validate(self.fname, out)
494
494
495 class Negator(ast.NodeTransformer):
495 class Negator(ast.NodeTransformer):
496 """Negates all number literals in an AST."""
496 """Negates all number literals in an AST."""
497 def visit_Num(self, node):
497 def visit_Num(self, node):
498 node.n = -node.n
498 node.n = -node.n
499 return node
499 return node
500
500
501 class TestAstTransform(unittest.TestCase):
501 class TestAstTransform(unittest.TestCase):
502 def setUp(self):
502 def setUp(self):
503 self.negator = Negator()
503 self.negator = Negator()
504 ip.ast_transformers.append(self.negator)
504 ip.ast_transformers.append(self.negator)
505
505
506 def tearDown(self):
506 def tearDown(self):
507 ip.ast_transformers.remove(self.negator)
507 ip.ast_transformers.remove(self.negator)
508
508
509 def test_run_cell(self):
509 def test_run_cell(self):
510 with tt.AssertPrints('-34'):
510 with tt.AssertPrints('-34'):
511 ip.run_cell('print (12 + 22)')
511 ip.run_cell('print (12 + 22)')
512
512
513 # A named reference to a number shouldn't be transformed.
513 # A named reference to a number shouldn't be transformed.
514 ip.user_ns['n'] = 55
514 ip.user_ns['n'] = 55
515 with tt.AssertNotPrints('-55'):
515 with tt.AssertNotPrints('-55'):
516 ip.run_cell('print (n)')
516 ip.run_cell('print (n)')
517
517
518 def test_timeit(self):
518 def test_timeit(self):
519 called = set()
519 called = set()
520 def f(x):
520 def f(x):
521 called.add(x)
521 called.add(x)
522 ip.push({'f':f})
522 ip.push({'f':f})
523
523
524 with tt.AssertPrints("best of "):
524 with tt.AssertPrints("best of "):
525 ip.run_line_magic("timeit", "-n1 f(1)")
525 ip.run_line_magic("timeit", "-n1 f(1)")
526 self.assertEqual(called, set([-1]))
526 self.assertEqual(called, set([-1]))
527 called.clear()
527 called.clear()
528
528
529 with tt.AssertPrints("best of "):
529 with tt.AssertPrints("best of "):
530 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
530 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
531 self.assertEqual(called, set([-2, -3]))
531 self.assertEqual(called, set([-2, -3]))
532
532
533 def test_time(self):
533 def test_time(self):
534 called = []
534 called = []
535 def f(x):
535 def f(x):
536 called.append(x)
536 called.append(x)
537 ip.push({'f':f})
537 ip.push({'f':f})
538
538
539 # Test with an expression
539 # Test with an expression
540 with tt.AssertPrints("Wall time: "):
540 with tt.AssertPrints("Wall time: "):
541 ip.run_line_magic("time", "f(5+9)")
541 ip.run_line_magic("time", "f(5+9)")
542 self.assertEqual(called, [-14])
542 self.assertEqual(called, [-14])
543 called[:] = []
543 called[:] = []
544
544
545 # Test with a statement (different code path)
545 # Test with a statement (different code path)
546 with tt.AssertPrints("Wall time: "):
546 with tt.AssertPrints("Wall time: "):
547 ip.run_line_magic("time", "a = f(-3 + -2)")
547 ip.run_line_magic("time", "a = f(-3 + -2)")
548 self.assertEqual(called, [5])
548 self.assertEqual(called, [5])
549
549
550 def test_macro(self):
550 def test_macro(self):
551 ip.push({'a':10})
551 ip.push({'a':10})
552 # The AST transformation makes this do a+=-1
552 # The AST transformation makes this do a+=-1
553 ip.define_macro("amacro", "a+=1\nprint(a)")
553 ip.define_macro("amacro", "a+=1\nprint(a)")
554
554
555 with tt.AssertPrints("9"):
555 with tt.AssertPrints("9"):
556 ip.run_cell("amacro")
556 ip.run_cell("amacro")
557 with tt.AssertPrints("8"):
557 with tt.AssertPrints("8"):
558 ip.run_cell("amacro")
558 ip.run_cell("amacro")
559
559
560 class IntegerWrapper(ast.NodeTransformer):
560 class IntegerWrapper(ast.NodeTransformer):
561 """Wraps all integers in a call to Integer()"""
561 """Wraps all integers in a call to Integer()"""
562 def visit_Num(self, node):
562 def visit_Num(self, node):
563 if isinstance(node.n, int):
563 if isinstance(node.n, int):
564 return ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()),
564 return ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()),
565 args=[node], keywords=[])
565 args=[node], keywords=[])
566 return node
566 return node
567
567
568 class TestAstTransform2(unittest.TestCase):
568 class TestAstTransform2(unittest.TestCase):
569 def setUp(self):
569 def setUp(self):
570 self.intwrapper = IntegerWrapper()
570 self.intwrapper = IntegerWrapper()
571 ip.ast_transformers.append(self.intwrapper)
571 ip.ast_transformers.append(self.intwrapper)
572
572
573 self.calls = []
573 self.calls = []
574 def Integer(*args):
574 def Integer(*args):
575 self.calls.append(args)
575 self.calls.append(args)
576 return args
576 return args
577 ip.push({"Integer": Integer})
577 ip.push({"Integer": Integer})
578
578
579 def tearDown(self):
579 def tearDown(self):
580 ip.ast_transformers.remove(self.intwrapper)
580 ip.ast_transformers.remove(self.intwrapper)
581 del ip.user_ns['Integer']
581 del ip.user_ns['Integer']
582
582
583 def test_run_cell(self):
583 def test_run_cell(self):
584 ip.run_cell("n = 2")
584 ip.run_cell("n = 2")
585 self.assertEqual(self.calls, [(2,)])
585 self.assertEqual(self.calls, [(2,)])
586
586
587 # This shouldn't throw an error
587 # This shouldn't throw an error
588 ip.run_cell("o = 2.0")
588 ip.run_cell("o = 2.0")
589 self.assertEqual(ip.user_ns['o'], 2.0)
589 self.assertEqual(ip.user_ns['o'], 2.0)
590
590
591 def test_timeit(self):
591 def test_timeit(self):
592 called = set()
592 called = set()
593 def f(x):
593 def f(x):
594 called.add(x)
594 called.add(x)
595 ip.push({'f':f})
595 ip.push({'f':f})
596
596
597 with tt.AssertPrints("best of "):
597 with tt.AssertPrints("best of "):
598 ip.run_line_magic("timeit", "-n1 f(1)")
598 ip.run_line_magic("timeit", "-n1 f(1)")
599 self.assertEqual(called, set([(1,)]))
599 self.assertEqual(called, set([(1,)]))
600 called.clear()
600 called.clear()
601
601
602 with tt.AssertPrints("best of "):
602 with tt.AssertPrints("best of "):
603 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
603 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
604 self.assertEqual(called, set([(2,), (3,)]))
604 self.assertEqual(called, set([(2,), (3,)]))
605
605
606 class ErrorTransformer(ast.NodeTransformer):
606 class ErrorTransformer(ast.NodeTransformer):
607 """Throws an error when it sees a number."""
607 """Throws an error when it sees a number."""
608 def visit_Num(self):
608 def visit_Num(self):
609 raise ValueError("test")
609 raise ValueError("test")
610
610
611 class TestAstTransformError(unittest.TestCase):
611 class TestAstTransformError(unittest.TestCase):
612 def test_unregistering(self):
612 def test_unregistering(self):
613 err_transformer = ErrorTransformer()
613 err_transformer = ErrorTransformer()
614 ip.ast_transformers.append(err_transformer)
614 ip.ast_transformers.append(err_transformer)
615
615
616 with tt.AssertPrints("unregister", channel='stderr'):
616 with tt.AssertPrints("unregister", channel='stderr'):
617 ip.run_cell("1 + 2")
617 ip.run_cell("1 + 2")
618
618
619 # This should have been removed.
619 # This should have been removed.
620 nt.assert_not_in(err_transformer, ip.ast_transformers)
620 nt.assert_not_in(err_transformer, ip.ast_transformers)
621
621
622 def test__IPYTHON__():
622 def test__IPYTHON__():
623 # This shouldn't raise a NameError, that's all
623 # This shouldn't raise a NameError, that's all
624 __IPYTHON__
624 __IPYTHON__
625
625
626
626
627 class DummyRepr(object):
627 class DummyRepr(object):
628 def __repr__(self):
628 def __repr__(self):
629 return "DummyRepr"
629 return "DummyRepr"
630
630
631 def _repr_html_(self):
631 def _repr_html_(self):
632 return "<b>dummy</b>"
632 return "<b>dummy</b>"
633
633
634 def _repr_javascript_(self):
634 def _repr_javascript_(self):
635 return "console.log('hi');", {'key': 'value'}
635 return "console.log('hi');", {'key': 'value'}
636
636
637
637
638 def test_user_variables():
638 def test_user_variables():
639 # enable all formatters
639 # enable all formatters
640 ip.display_formatter.active_types = ip.display_formatter.format_types
640 ip.display_formatter.active_types = ip.display_formatter.format_types
641
641
642 ip.user_ns['dummy'] = d = DummyRepr()
642 ip.user_ns['dummy'] = d = DummyRepr()
643 keys = set(['dummy', 'doesnotexist'])
643 keys = set(['dummy', 'doesnotexist'])
644 r = ip.user_variables(keys)
644 r = ip.user_variables(keys)
645
645
646 nt.assert_equal(keys, set(r.keys()))
646 nt.assert_equal(keys, set(r.keys()))
647 dummy = r['dummy']
647 dummy = r['dummy']
648 nt.assert_equal(set(['status', 'data', 'metadata']), set(dummy.keys()))
648 nt.assert_equal(set(['status', 'data', 'metadata']), set(dummy.keys()))
649 nt.assert_equal(dummy['status'], 'ok')
649 nt.assert_equal(dummy['status'], 'ok')
650 data = dummy['data']
650 data = dummy['data']
651 metadata = dummy['metadata']
651 metadata = dummy['metadata']
652 nt.assert_equal(data.get('text/html'), d._repr_html_())
652 nt.assert_equal(data.get('text/html'), d._repr_html_())
653 js, jsmd = d._repr_javascript_()
653 js, jsmd = d._repr_javascript_()
654 nt.assert_equal(data.get('application/javascript'), js)
654 nt.assert_equal(data.get('application/javascript'), js)
655 nt.assert_equal(metadata.get('application/javascript'), jsmd)
655 nt.assert_equal(metadata.get('application/javascript'), jsmd)
656
656
657 dne = r['doesnotexist']
657 dne = r['doesnotexist']
658 nt.assert_equal(dne['status'], 'error')
658 nt.assert_equal(dne['status'], 'error')
659 nt.assert_equal(dne['ename'], 'KeyError')
659 nt.assert_equal(dne['ename'], 'KeyError')
660
660
661 # back to text only
661 # back to text only
662 ip.display_formatter.active_types = ['text/plain']
662 ip.display_formatter.active_types = ['text/plain']
663
663
664 def test_user_expression():
664 def test_user_expression():
665 # enable all formatters
665 # enable all formatters
666 ip.display_formatter.active_types = ip.display_formatter.format_types
666 ip.display_formatter.active_types = ip.display_formatter.format_types
667 query = {
667 query = {
668 'a' : '1 + 2',
668 'a' : '1 + 2',
669 'b' : '1/0',
669 'b' : '1/0',
670 }
670 }
671 r = ip.user_expressions(query)
671 r = ip.user_expressions(query)
672 import pprint
672 import pprint
673 pprint.pprint(r)
673 pprint.pprint(r)
674 nt.assert_equal(set(r.keys()), set(query.keys()))
674 nt.assert_equal(set(r.keys()), set(query.keys()))
675 a = r['a']
675 a = r['a']
676 nt.assert_equal(set(['status', 'data', 'metadata']), set(a.keys()))
676 nt.assert_equal(set(['status', 'data', 'metadata']), set(a.keys()))
677 nt.assert_equal(a['status'], 'ok')
677 nt.assert_equal(a['status'], 'ok')
678 data = a['data']
678 data = a['data']
679 metadata = a['metadata']
679 metadata = a['metadata']
680 nt.assert_equal(data.get('text/plain'), '3')
680 nt.assert_equal(data.get('text/plain'), '3')
681
681
682 b = r['b']
682 b = r['b']
683 nt.assert_equal(b['status'], 'error')
683 nt.assert_equal(b['status'], 'error')
684 nt.assert_equal(b['ename'], 'ZeroDivisionError')
684 nt.assert_equal(b['ename'], 'ZeroDivisionError')
685
685
686 # back to text only
686 # back to text only
687 ip.display_formatter.active_types = ['text/plain']
687 ip.display_formatter.active_types = ['text/plain']
688
688
689
689
690
690
691
691
692
692
693 class TestSyntaxErrorTransformer(unittest.TestCase):
693 class TestSyntaxErrorTransformer(unittest.TestCase):
694 """Check that SyntaxError raised by an input transformer is handled by run_cell()"""
694 """Check that SyntaxError raised by an input transformer is handled by run_cell()"""
695
695
696 class SyntaxErrorTransformer(InputTransformer):
696 class SyntaxErrorTransformer(InputTransformer):
697
697
698 def push(self, line):
698 def push(self, line):
699 pos = line.find('syntaxerror')
699 pos = line.find('syntaxerror')
700 if pos >= 0:
700 if pos >= 0:
701 e = SyntaxError('input contains "syntaxerror"')
701 e = SyntaxError('input contains "syntaxerror"')
702 e.text = line
702 e.text = line
703 e.offset = pos + 1
703 e.offset = pos + 1
704 raise e
704 raise e
705 return line
705 return line
706
706
707 def reset(self):
707 def reset(self):
708 pass
708 pass
709
709
710 def setUp(self):
710 def setUp(self):
711 self.transformer = TestSyntaxErrorTransformer.SyntaxErrorTransformer()
711 self.transformer = TestSyntaxErrorTransformer.SyntaxErrorTransformer()
712 ip.input_splitter.python_line_transforms.append(self.transformer)
712 ip.input_splitter.python_line_transforms.append(self.transformer)
713 ip.input_transformer_manager.python_line_transforms.append(self.transformer)
713 ip.input_transformer_manager.python_line_transforms.append(self.transformer)
714
714
715 def tearDown(self):
715 def tearDown(self):
716 ip.input_splitter.python_line_transforms.remove(self.transformer)
716 ip.input_splitter.python_line_transforms.remove(self.transformer)
717 ip.input_transformer_manager.python_line_transforms.remove(self.transformer)
717 ip.input_transformer_manager.python_line_transforms.remove(self.transformer)
718
718
719 def test_syntaxerror_input_transformer(self):
719 def test_syntaxerror_input_transformer(self):
720 with tt.AssertPrints('1234'):
720 with tt.AssertPrints('1234'):
721 ip.run_cell('1234')
721 ip.run_cell('1234')
722 with tt.AssertPrints('SyntaxError: invalid syntax'):
722 with tt.AssertPrints('SyntaxError: invalid syntax'):
723 ip.run_cell('1 2 3') # plain python syntax error
723 ip.run_cell('1 2 3') # plain python syntax error
724 with tt.AssertPrints('SyntaxError: input contains "syntaxerror"'):
724 with tt.AssertPrints('SyntaxError: input contains "syntaxerror"'):
725 ip.run_cell('2345 # syntaxerror') # input transformer syntax error
725 ip.run_cell('2345 # syntaxerror') # input transformer syntax error
726 with tt.AssertPrints('3456'):
726 with tt.AssertPrints('3456'):
727 ip.run_cell('3456')
727 ip.run_cell('3456')
728
728
729
729
730
730
@@ -1,505 +1,505 b''
1 """IPython extension to reload modules before executing user code.
1 """IPython extension to reload modules before executing user code.
2
2
3 ``autoreload`` reloads modules automatically before entering the execution of
3 ``autoreload`` reloads modules automatically before entering the execution of
4 code typed at the IPython prompt.
4 code typed at the IPython prompt.
5
5
6 This makes for example the following workflow possible:
6 This makes for example the following workflow possible:
7
7
8 .. sourcecode:: ipython
8 .. sourcecode:: ipython
9
9
10 In [1]: %load_ext autoreload
10 In [1]: %load_ext autoreload
11
11
12 In [2]: %autoreload 2
12 In [2]: %autoreload 2
13
13
14 In [3]: from foo import some_function
14 In [3]: from foo import some_function
15
15
16 In [4]: some_function()
16 In [4]: some_function()
17 Out[4]: 42
17 Out[4]: 42
18
18
19 In [5]: # open foo.py in an editor and change some_function to return 43
19 In [5]: # open foo.py in an editor and change some_function to return 43
20
20
21 In [6]: some_function()
21 In [6]: some_function()
22 Out[6]: 43
22 Out[6]: 43
23
23
24 The module was reloaded without reloading it explicitly, and the object
24 The module was reloaded without reloading it explicitly, and the object
25 imported with ``from foo import ...`` was also updated.
25 imported with ``from foo import ...`` was also updated.
26
26
27 Usage
27 Usage
28 =====
28 =====
29
29
30 The following magic commands are provided:
30 The following magic commands are provided:
31
31
32 ``%autoreload``
32 ``%autoreload``
33
33
34 Reload all modules (except those excluded by ``%aimport``)
34 Reload all modules (except those excluded by ``%aimport``)
35 automatically now.
35 automatically now.
36
36
37 ``%autoreload 0``
37 ``%autoreload 0``
38
38
39 Disable automatic reloading.
39 Disable automatic reloading.
40
40
41 ``%autoreload 1``
41 ``%autoreload 1``
42
42
43 Reload all modules imported with ``%aimport`` every time before
43 Reload all modules imported with ``%aimport`` every time before
44 executing the Python code typed.
44 executing the Python code typed.
45
45
46 ``%autoreload 2``
46 ``%autoreload 2``
47
47
48 Reload all modules (except those excluded by ``%aimport``) every
48 Reload all modules (except those excluded by ``%aimport``) every
49 time before executing the Python code typed.
49 time before executing the Python code typed.
50
50
51 ``%aimport``
51 ``%aimport``
52
52
53 List modules which are to be automatically imported or not to be imported.
53 List modules which are to be automatically imported or not to be imported.
54
54
55 ``%aimport foo``
55 ``%aimport foo``
56
56
57 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
57 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
58
58
59 ``%aimport -foo``
59 ``%aimport -foo``
60
60
61 Mark module 'foo' to not be autoreloaded.
61 Mark module 'foo' to not be autoreloaded.
62
62
63 Caveats
63 Caveats
64 =======
64 =======
65
65
66 Reloading Python modules in a reliable way is in general difficult,
66 Reloading Python modules in a reliable way is in general difficult,
67 and unexpected things may occur. ``%autoreload`` tries to work around
67 and unexpected things may occur. ``%autoreload`` tries to work around
68 common pitfalls by replacing function code objects and parts of
68 common pitfalls by replacing function code objects and parts of
69 classes previously in the module with new versions. This makes the
69 classes previously in the module with new versions. This makes the
70 following things to work:
70 following things to work:
71
71
72 - Functions and classes imported via 'from xxx import foo' are upgraded
72 - Functions and classes imported via 'from xxx import foo' are upgraded
73 to new versions when 'xxx' is reloaded.
73 to new versions when 'xxx' is reloaded.
74
74
75 - Methods and properties of classes are upgraded on reload, so that
75 - Methods and properties of classes are upgraded on reload, so that
76 calling 'c.foo()' on an object 'c' created before the reload causes
76 calling 'c.foo()' on an object 'c' created before the reload causes
77 the new code for 'foo' to be executed.
77 the new code for 'foo' to be executed.
78
78
79 Some of the known remaining caveats are:
79 Some of the known remaining caveats are:
80
80
81 - Replacing code objects does not always succeed: changing a @property
81 - Replacing code objects does not always succeed: changing a @property
82 in a class to an ordinary method or a method to a member variable
82 in a class to an ordinary method or a method to a member variable
83 can cause problems (but in old objects only).
83 can cause problems (but in old objects only).
84
84
85 - Functions that are removed (eg. via monkey-patching) from a module
85 - Functions that are removed (eg. via monkey-patching) from a module
86 before it is reloaded are not upgraded.
86 before it is reloaded are not upgraded.
87
87
88 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
88 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
89 """
89 """
90 from __future__ import print_function
90 from __future__ import print_function
91
91
92 skip_doctest = True
92 skip_doctest = True
93
93
94 #-----------------------------------------------------------------------------
94 #-----------------------------------------------------------------------------
95 # Copyright (C) 2000 Thomas Heller
95 # Copyright (C) 2000 Thomas Heller
96 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
96 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
97 # Copyright (C) 2012 The IPython Development Team
97 # Copyright (C) 2012 The IPython Development Team
98 #
98 #
99 # Distributed under the terms of the BSD License. The full license is in
99 # Distributed under the terms of the BSD License. The full license is in
100 # the file COPYING, distributed as part of this software.
100 # the file COPYING, distributed as part of this software.
101 #-----------------------------------------------------------------------------
101 #-----------------------------------------------------------------------------
102 #
102 #
103 # This IPython module is written by Pauli Virtanen, based on the autoreload
103 # This IPython module is written by Pauli Virtanen, based on the autoreload
104 # code by Thomas Heller.
104 # code by Thomas Heller.
105
105
106 #-----------------------------------------------------------------------------
106 #-----------------------------------------------------------------------------
107 # Imports
107 # Imports
108 #-----------------------------------------------------------------------------
108 #-----------------------------------------------------------------------------
109
109
110 import os
110 import os
111 import sys
111 import sys
112 import traceback
112 import traceback
113 import types
113 import types
114 import weakref
114 import weakref
115
115
116 try:
116 try:
117 # Reload is not defined by default in Python3.
117 # Reload is not defined by default in Python3.
118 reload
118 reload
119 except NameError:
119 except NameError:
120 from imp import reload
120 from imp import reload
121
121
122 from IPython.utils import openpy
122 from IPython.utils import openpy
123 from IPython.utils.py3compat import PY3
123 from IPython.utils.py3compat import PY3
124
124
125 #------------------------------------------------------------------------------
125 #------------------------------------------------------------------------------
126 # Autoreload functionality
126 # Autoreload functionality
127 #------------------------------------------------------------------------------
127 #------------------------------------------------------------------------------
128
128
129 class ModuleReloader(object):
129 class ModuleReloader(object):
130 enabled = False
130 enabled = False
131 """Whether this reloader is enabled"""
131 """Whether this reloader is enabled"""
132
132
133 failed = {}
133 failed = {}
134 """Modules that failed to reload: {module: mtime-on-failed-reload, ...}"""
134 """Modules that failed to reload: {module: mtime-on-failed-reload, ...}"""
135
135
136 modules = {}
136 modules = {}
137 """Modules specially marked as autoreloadable."""
137 """Modules specially marked as autoreloadable."""
138
138
139 skip_modules = {}
139 skip_modules = {}
140 """Modules specially marked as not autoreloadable."""
140 """Modules specially marked as not autoreloadable."""
141
141
142 check_all = True
142 check_all = True
143 """Autoreload all modules, not just those listed in 'modules'"""
143 """Autoreload all modules, not just those listed in 'modules'"""
144
144
145 old_objects = {}
145 old_objects = {}
146 """(module-name, name) -> weakref, for replacing old code objects"""
146 """(module-name, name) -> weakref, for replacing old code objects"""
147
147
148 def mark_module_skipped(self, module_name):
148 def mark_module_skipped(self, module_name):
149 """Skip reloading the named module in the future"""
149 """Skip reloading the named module in the future"""
150 try:
150 try:
151 del self.modules[module_name]
151 del self.modules[module_name]
152 except KeyError:
152 except KeyError:
153 pass
153 pass
154 self.skip_modules[module_name] = True
154 self.skip_modules[module_name] = True
155
155
156 def mark_module_reloadable(self, module_name):
156 def mark_module_reloadable(self, module_name):
157 """Reload the named module in the future (if it is imported)"""
157 """Reload the named module in the future (if it is imported)"""
158 try:
158 try:
159 del self.skip_modules[module_name]
159 del self.skip_modules[module_name]
160 except KeyError:
160 except KeyError:
161 pass
161 pass
162 self.modules[module_name] = True
162 self.modules[module_name] = True
163
163
164 def aimport_module(self, module_name):
164 def aimport_module(self, module_name):
165 """Import a module, and mark it reloadable
165 """Import a module, and mark it reloadable
166
166
167 Returns
167 Returns
168 -------
168 -------
169 top_module : module
169 top_module : module
170 The imported module if it is top-level, or the top-level
170 The imported module if it is top-level, or the top-level
171 top_name : module
171 top_name : module
172 Name of top_module
172 Name of top_module
173
173
174 """
174 """
175 self.mark_module_reloadable(module_name)
175 self.mark_module_reloadable(module_name)
176
176
177 __import__(module_name)
177 __import__(module_name)
178 top_name = module_name.split('.')[0]
178 top_name = module_name.split('.')[0]
179 top_module = sys.modules[top_name]
179 top_module = sys.modules[top_name]
180 return top_module, top_name
180 return top_module, top_name
181
181
182 def check(self, check_all=False):
182 def check(self, check_all=False):
183 """Check whether some modules need to be reloaded."""
183 """Check whether some modules need to be reloaded."""
184
184
185 if not self.enabled and not check_all:
185 if not self.enabled and not check_all:
186 return
186 return
187
187
188 if check_all or self.check_all:
188 if check_all or self.check_all:
189 modules = list(sys.modules.keys())
189 modules = list(sys.modules.keys())
190 else:
190 else:
191 modules = list(self.modules.keys())
191 modules = list(self.modules.keys())
192
192
193 for modname in modules:
193 for modname in modules:
194 m = sys.modules.get(modname, None)
194 m = sys.modules.get(modname, None)
195
195
196 if modname in self.skip_modules:
196 if modname in self.skip_modules:
197 continue
197 continue
198
198
199 if not hasattr(m, '__file__'):
199 if not hasattr(m, '__file__'):
200 continue
200 continue
201
201
202 if m.__name__ == '__main__':
202 if m.__name__ == '__main__':
203 # we cannot reload(__main__)
203 # we cannot reload(__main__)
204 continue
204 continue
205
205
206 filename = m.__file__
206 filename = m.__file__
207 path, ext = os.path.splitext(filename)
207 path, ext = os.path.splitext(filename)
208
208
209 if ext.lower() == '.py':
209 if ext.lower() == '.py':
210 pyc_filename = openpy.cache_from_source(filename)
210 pyc_filename = openpy.cache_from_source(filename)
211 py_filename = filename
211 py_filename = filename
212 else:
212 else:
213 pyc_filename = filename
213 pyc_filename = filename
214 try:
214 try:
215 py_filename = openpy.source_from_cache(filename)
215 py_filename = openpy.source_from_cache(filename)
216 except ValueError:
216 except ValueError:
217 continue
217 continue
218
218
219 try:
219 try:
220 pymtime = os.stat(py_filename).st_mtime
220 pymtime = os.stat(py_filename).st_mtime
221 if pymtime <= os.stat(pyc_filename).st_mtime:
221 if pymtime <= os.stat(pyc_filename).st_mtime:
222 continue
222 continue
223 if self.failed.get(py_filename, None) == pymtime:
223 if self.failed.get(py_filename, None) == pymtime:
224 continue
224 continue
225 except OSError:
225 except OSError:
226 continue
226 continue
227
227
228 try:
228 try:
229 superreload(m, reload, self.old_objects)
229 superreload(m, reload, self.old_objects)
230 if py_filename in self.failed:
230 if py_filename in self.failed:
231 del self.failed[py_filename]
231 del self.failed[py_filename]
232 except:
232 except:
233 print("[autoreload of %s failed: %s]" % (
233 print("[autoreload of %s failed: %s]" % (
234 modname, traceback.format_exc(1)), file=sys.stderr)
234 modname, traceback.format_exc(1)), file=sys.stderr)
235 self.failed[py_filename] = pymtime
235 self.failed[py_filename] = pymtime
236
236
237 #------------------------------------------------------------------------------
237 #------------------------------------------------------------------------------
238 # superreload
238 # superreload
239 #------------------------------------------------------------------------------
239 #------------------------------------------------------------------------------
240
240
241 if PY3:
241 if PY3:
242 func_attrs = ['__code__', '__defaults__', '__doc__',
242 func_attrs = ['__code__', '__defaults__', '__doc__',
243 '__closure__', '__globals__', '__dict__']
243 '__closure__', '__globals__', '__dict__']
244 else:
244 else:
245 func_attrs = ['func_code', 'func_defaults', 'func_doc',
245 func_attrs = ['func_code', 'func_defaults', 'func_doc',
246 'func_closure', 'func_globals', 'func_dict']
246 'func_closure', 'func_globals', 'func_dict']
247
247
248
248
249 def update_function(old, new):
249 def update_function(old, new):
250 """Upgrade the code object of a function"""
250 """Upgrade the code object of a function"""
251 for name in func_attrs:
251 for name in func_attrs:
252 try:
252 try:
253 setattr(old, name, getattr(new, name))
253 setattr(old, name, getattr(new, name))
254 except (AttributeError, TypeError):
254 except (AttributeError, TypeError):
255 pass
255 pass
256
256
257
257
258 def update_class(old, new):
258 def update_class(old, new):
259 """Replace stuff in the __dict__ of a class, and upgrade
259 """Replace stuff in the __dict__ of a class, and upgrade
260 method code objects"""
260 method code objects"""
261 for key in list(old.__dict__.keys()):
261 for key in list(old.__dict__.keys()):
262 old_obj = getattr(old, key)
262 old_obj = getattr(old, key)
263
263
264 try:
264 try:
265 new_obj = getattr(new, key)
265 new_obj = getattr(new, key)
266 except AttributeError:
266 except AttributeError:
267 # obsolete attribute: remove it
267 # obsolete attribute: remove it
268 try:
268 try:
269 delattr(old, key)
269 delattr(old, key)
270 except (AttributeError, TypeError):
270 except (AttributeError, TypeError):
271 pass
271 pass
272 continue
272 continue
273
273
274 if update_generic(old_obj, new_obj): continue
274 if update_generic(old_obj, new_obj): continue
275
275
276 try:
276 try:
277 setattr(old, key, getattr(new, key))
277 setattr(old, key, getattr(new, key))
278 except (AttributeError, TypeError):
278 except (AttributeError, TypeError):
279 pass # skip non-writable attributes
279 pass # skip non-writable attributes
280
280
281
281
282 def update_property(old, new):
282 def update_property(old, new):
283 """Replace get/set/del functions of a property"""
283 """Replace get/set/del functions of a property"""
284 update_generic(old.fdel, new.fdel)
284 update_generic(old.fdel, new.fdel)
285 update_generic(old.fget, new.fget)
285 update_generic(old.fget, new.fget)
286 update_generic(old.fset, new.fset)
286 update_generic(old.fset, new.fset)
287
287
288
288
289 def isinstance2(a, b, typ):
289 def isinstance2(a, b, typ):
290 return isinstance(a, typ) and isinstance(b, typ)
290 return isinstance(a, typ) and isinstance(b, typ)
291
291
292
292
293 UPDATE_RULES = [
293 UPDATE_RULES = [
294 (lambda a, b: isinstance2(a, b, type),
294 (lambda a, b: isinstance2(a, b, type),
295 update_class),
295 update_class),
296 (lambda a, b: isinstance2(a, b, types.FunctionType),
296 (lambda a, b: isinstance2(a, b, types.FunctionType),
297 update_function),
297 update_function),
298 (lambda a, b: isinstance2(a, b, property),
298 (lambda a, b: isinstance2(a, b, property),
299 update_property),
299 update_property),
300 ]
300 ]
301
301
302
302
303 if PY3:
303 if PY3:
304 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.MethodType),
304 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.MethodType),
305 lambda a, b: update_function(a.__func__, b.__func__)),
305 lambda a, b: update_function(a.__func__, b.__func__)),
306 ])
306 ])
307 else:
307 else:
308 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.ClassType),
308 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.ClassType),
309 update_class),
309 update_class),
310 (lambda a, b: isinstance2(a, b, types.MethodType),
310 (lambda a, b: isinstance2(a, b, types.MethodType),
311 lambda a, b: update_function(a.__func__, b.__func__)),
311 lambda a, b: update_function(a.__func__, b.__func__)),
312 ])
312 ])
313
313
314
314
315 def update_generic(a, b):
315 def update_generic(a, b):
316 for type_check, update in UPDATE_RULES:
316 for type_check, update in UPDATE_RULES:
317 if type_check(a, b):
317 if type_check(a, b):
318 update(a, b)
318 update(a, b)
319 return True
319 return True
320 return False
320 return False
321
321
322
322
323 class StrongRef(object):
323 class StrongRef(object):
324 def __init__(self, obj):
324 def __init__(self, obj):
325 self.obj = obj
325 self.obj = obj
326 def __call__(self):
326 def __call__(self):
327 return self.obj
327 return self.obj
328
328
329
329
330 def superreload(module, reload=reload, old_objects={}):
330 def superreload(module, reload=reload, old_objects={}):
331 """Enhanced version of the builtin reload function.
331 """Enhanced version of the builtin reload function.
332
332
333 superreload remembers objects previously in the module, and
333 superreload remembers objects previously in the module, and
334
334
335 - upgrades the class dictionary of every old class in the module
335 - upgrades the class dictionary of every old class in the module
336 - upgrades the code object of every old function and method
336 - upgrades the code object of every old function and method
337 - clears the module's namespace before reloading
337 - clears the module's namespace before reloading
338
338
339 """
339 """
340
340
341 # collect old objects in the module
341 # collect old objects in the module
342 for name, obj in list(module.__dict__.items()):
342 for name, obj in list(module.__dict__.items()):
343 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
343 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
344 continue
344 continue
345 key = (module.__name__, name)
345 key = (module.__name__, name)
346 try:
346 try:
347 old_objects.setdefault(key, []).append(weakref.ref(obj))
347 old_objects.setdefault(key, []).append(weakref.ref(obj))
348 except TypeError:
348 except TypeError:
349 # weakref doesn't work for all types;
349 # weakref doesn't work for all types;
350 # create strong references for 'important' cases
350 # create strong references for 'important' cases
351 if not PY3 and isinstance(obj, types.ClassType):
351 if not PY3 and isinstance(obj, types.ClassType):
352 old_objects.setdefault(key, []).append(StrongRef(obj))
352 old_objects.setdefault(key, []).append(StrongRef(obj))
353
353
354 # reload module
354 # reload module
355 try:
355 try:
356 # clear namespace first from old cruft
356 # clear namespace first from old cruft
357 old_dict = module.__dict__.copy()
357 old_dict = module.__dict__.copy()
358 old_name = module.__name__
358 old_name = module.__name__
359 module.__dict__.clear()
359 module.__dict__.clear()
360 module.__dict__['__name__'] = old_name
360 module.__dict__['__name__'] = old_name
361 module.__dict__['__loader__'] = old_dict['__loader__']
361 module.__dict__['__loader__'] = old_dict['__loader__']
362 except (TypeError, AttributeError, KeyError):
362 except (TypeError, AttributeError, KeyError):
363 pass
363 pass
364
364
365 try:
365 try:
366 module = reload(module)
366 module = reload(module)
367 except:
367 except:
368 # restore module dictionary on failed reload
368 # restore module dictionary on failed reload
369 module.__dict__.update(old_dict)
369 module.__dict__.update(old_dict)
370 raise
370 raise
371
371
372 # iterate over all objects and update functions & classes
372 # iterate over all objects and update functions & classes
373 for name, new_obj in list(module.__dict__.items()):
373 for name, new_obj in list(module.__dict__.items()):
374 key = (module.__name__, name)
374 key = (module.__name__, name)
375 if key not in old_objects: continue
375 if key not in old_objects: continue
376
376
377 new_refs = []
377 new_refs = []
378 for old_ref in old_objects[key]:
378 for old_ref in old_objects[key]:
379 old_obj = old_ref()
379 old_obj = old_ref()
380 if old_obj is None: continue
380 if old_obj is None: continue
381 new_refs.append(old_ref)
381 new_refs.append(old_ref)
382 update_generic(old_obj, new_obj)
382 update_generic(old_obj, new_obj)
383
383
384 if new_refs:
384 if new_refs:
385 old_objects[key] = new_refs
385 old_objects[key] = new_refs
386 else:
386 else:
387 del old_objects[key]
387 del old_objects[key]
388
388
389 return module
389 return module
390
390
391 #------------------------------------------------------------------------------
391 #------------------------------------------------------------------------------
392 # IPython connectivity
392 # IPython connectivity
393 #------------------------------------------------------------------------------
393 #------------------------------------------------------------------------------
394
394
395 from IPython.core.magic import Magics, magics_class, line_magic
395 from IPython.core.magic import Magics, magics_class, line_magic
396
396
397 @magics_class
397 @magics_class
398 class AutoreloadMagics(Magics):
398 class AutoreloadMagics(Magics):
399 def __init__(self, *a, **kw):
399 def __init__(self, *a, **kw):
400 super(AutoreloadMagics, self).__init__(*a, **kw)
400 super(AutoreloadMagics, self).__init__(*a, **kw)
401 self._reloader = ModuleReloader()
401 self._reloader = ModuleReloader()
402 self._reloader.check_all = False
402 self._reloader.check_all = False
403
403
404 @line_magic
404 @line_magic
405 def autoreload(self, parameter_s=''):
405 def autoreload(self, parameter_s=''):
406 r"""%autoreload => Reload modules automatically
406 r"""%autoreload => Reload modules automatically
407
407
408 %autoreload
408 %autoreload
409 Reload all modules (except those excluded by %aimport) automatically
409 Reload all modules (except those excluded by %aimport) automatically
410 now.
410 now.
411
411
412 %autoreload 0
412 %autoreload 0
413 Disable automatic reloading.
413 Disable automatic reloading.
414
414
415 %autoreload 1
415 %autoreload 1
416 Reload all modules imported with %aimport every time before executing
416 Reload all modules imported with %aimport every time before executing
417 the Python code typed.
417 the Python code typed.
418
418
419 %autoreload 2
419 %autoreload 2
420 Reload all modules (except those excluded by %aimport) every time
420 Reload all modules (except those excluded by %aimport) every time
421 before executing the Python code typed.
421 before executing the Python code typed.
422
422
423 Reloading Python modules in a reliable way is in general
423 Reloading Python modules in a reliable way is in general
424 difficult, and unexpected things may occur. %autoreload tries to
424 difficult, and unexpected things may occur. %autoreload tries to
425 work around common pitfalls by replacing function code objects and
425 work around common pitfalls by replacing function code objects and
426 parts of classes previously in the module with new versions. This
426 parts of classes previously in the module with new versions. This
427 makes the following things to work:
427 makes the following things to work:
428
428
429 - Functions and classes imported via 'from xxx import foo' are upgraded
429 - Functions and classes imported via 'from xxx import foo' are upgraded
430 to new versions when 'xxx' is reloaded.
430 to new versions when 'xxx' is reloaded.
431
431
432 - Methods and properties of classes are upgraded on reload, so that
432 - Methods and properties of classes are upgraded on reload, so that
433 calling 'c.foo()' on an object 'c' created before the reload causes
433 calling 'c.foo()' on an object 'c' created before the reload causes
434 the new code for 'foo' to be executed.
434 the new code for 'foo' to be executed.
435
435
436 Some of the known remaining caveats are:
436 Some of the known remaining caveats are:
437
437
438 - Replacing code objects does not always succeed: changing a @property
438 - Replacing code objects does not always succeed: changing a @property
439 in a class to an ordinary method or a method to a member variable
439 in a class to an ordinary method or a method to a member variable
440 can cause problems (but in old objects only).
440 can cause problems (but in old objects only).
441
441
442 - Functions that are removed (eg. via monkey-patching) from a module
442 - Functions that are removed (eg. via monkey-patching) from a module
443 before it is reloaded are not upgraded.
443 before it is reloaded are not upgraded.
444
444
445 - C extension modules cannot be reloaded, and so cannot be
445 - C extension modules cannot be reloaded, and so cannot be
446 autoreloaded.
446 autoreloaded.
447
447
448 """
448 """
449 if parameter_s == '':
449 if parameter_s == '':
450 self._reloader.check(True)
450 self._reloader.check(True)
451 elif parameter_s == '0':
451 elif parameter_s == '0':
452 self._reloader.enabled = False
452 self._reloader.enabled = False
453 elif parameter_s == '1':
453 elif parameter_s == '1':
454 self._reloader.check_all = False
454 self._reloader.check_all = False
455 self._reloader.enabled = True
455 self._reloader.enabled = True
456 elif parameter_s == '2':
456 elif parameter_s == '2':
457 self._reloader.check_all = True
457 self._reloader.check_all = True
458 self._reloader.enabled = True
458 self._reloader.enabled = True
459
459
460 @line_magic
460 @line_magic
461 def aimport(self, parameter_s='', stream=None):
461 def aimport(self, parameter_s='', stream=None):
462 """%aimport => Import modules for automatic reloading.
462 """%aimport => Import modules for automatic reloading.
463
463
464 %aimport
464 %aimport
465 List modules to automatically import and not to import.
465 List modules to automatically import and not to import.
466
466
467 %aimport foo
467 %aimport foo
468 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
468 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
469
469
470 %aimport -foo
470 %aimport -foo
471 Mark module 'foo' to not be autoreloaded for %autoreload 1
471 Mark module 'foo' to not be autoreloaded for %autoreload 1
472 """
472 """
473 modname = parameter_s
473 modname = parameter_s
474 if not modname:
474 if not modname:
475 to_reload = sorted(self._reloader.modules.keys())
475 to_reload = sorted(self._reloader.modules.keys())
476 to_skip = sorted(self._reloader.skip_modules.keys())
476 to_skip = sorted(self._reloader.skip_modules.keys())
477 if stream is None:
477 if stream is None:
478 stream = sys.stdout
478 stream = sys.stdout
479 if self._reloader.check_all:
479 if self._reloader.check_all:
480 stream.write("Modules to reload:\nall-except-skipped\n")
480 stream.write("Modules to reload:\nall-except-skipped\n")
481 else:
481 else:
482 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
482 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
483 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
483 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
484 elif modname.startswith('-'):
484 elif modname.startswith('-'):
485 modname = modname[1:]
485 modname = modname[1:]
486 self._reloader.mark_module_skipped(modname)
486 self._reloader.mark_module_skipped(modname)
487 else:
487 else:
488 top_module, top_name = self._reloader.aimport_module(modname)
488 top_module, top_name = self._reloader.aimport_module(modname)
489
489
490 # Inject module to user namespace
490 # Inject module to user namespace
491 self.shell.push({top_name: top_module})
491 self.shell.push({top_name: top_module})
492
492
493 def pre_execute_explicit(self):
493 def pre_execute_explicit(self):
494 if self._reloader.enabled:
494 if self._reloader.enabled:
495 try:
495 try:
496 self._reloader.check()
496 self._reloader.check()
497 except:
497 except:
498 pass
498 pass
499
499
500
500
501 def load_ipython_extension(ip):
501 def load_ipython_extension(ip):
502 """Load the extension in IPython."""
502 """Load the extension in IPython."""
503 auto_reload = AutoreloadMagics(ip)
503 auto_reload = AutoreloadMagics(ip)
504 ip.register_magics(auto_reload)
504 ip.register_magics(auto_reload)
505 ip.callbacks.register('pre_execute_explicit', auto_reload.pre_execute_explicit)
505 ip.events.register('pre_execute_explicit', auto_reload.pre_execute_explicit)
@@ -1,321 +1,321 b''
1 """Tests for autoreload extension.
1 """Tests for autoreload extension.
2 """
2 """
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Copyright (c) 2012 IPython Development Team.
4 # Copyright (c) 2012 IPython Development Team.
5 #
5 #
6 # Distributed under the terms of the Modified BSD License.
6 # Distributed under the terms of the Modified BSD License.
7 #
7 #
8 # The full license is in the file COPYING.txt, distributed with this software.
8 # The full license is in the file COPYING.txt, distributed with this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 import os
15 import os
16 import sys
16 import sys
17 import tempfile
17 import tempfile
18 import shutil
18 import shutil
19 import random
19 import random
20 import time
20 import time
21
21
22 import nose.tools as nt
22 import nose.tools as nt
23 import IPython.testing.tools as tt
23 import IPython.testing.tools as tt
24
24
25 from IPython.extensions.autoreload import AutoreloadMagics
25 from IPython.extensions.autoreload import AutoreloadMagics
26 from IPython.core.callbacks import CallbackManager, pre_execute_explicit
26 from IPython.core.events import EventManager, pre_execute_explicit
27 from IPython.utils.py3compat import PY3
27 from IPython.utils.py3compat import PY3
28
28
29 if PY3:
29 if PY3:
30 from io import StringIO
30 from io import StringIO
31 else:
31 else:
32 from StringIO import StringIO
32 from StringIO import StringIO
33
33
34 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
35 # Test fixture
35 # Test fixture
36 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
37
37
38 noop = lambda *a, **kw: None
38 noop = lambda *a, **kw: None
39
39
40 class FakeShell(object):
40 class FakeShell(object):
41
41
42 def __init__(self):
42 def __init__(self):
43 self.ns = {}
43 self.ns = {}
44 self.callbacks = CallbackManager(self, {'pre_execute_explicit', pre_execute_explicit})
44 self.events = EventManager(self, {'pre_execute_explicit', pre_execute_explicit})
45 self.auto_magics = AutoreloadMagics(shell=self)
45 self.auto_magics = AutoreloadMagics(shell=self)
46 self.callbacks.register('pre_execute_explicit', self.auto_magics.pre_execute_explicit)
46 self.events.register('pre_execute_explicit', self.auto_magics.pre_execute_explicit)
47
47
48 register_magics = set_hook = noop
48 register_magics = set_hook = noop
49
49
50 def run_code(self, code):
50 def run_code(self, code):
51 self.callbacks.fire('pre_execute_explicit')
51 self.events.trigger('pre_execute_explicit')
52 exec(code, self.ns)
52 exec(code, self.ns)
53
53
54 def push(self, items):
54 def push(self, items):
55 self.ns.update(items)
55 self.ns.update(items)
56
56
57 def magic_autoreload(self, parameter):
57 def magic_autoreload(self, parameter):
58 self.auto_magics.autoreload(parameter)
58 self.auto_magics.autoreload(parameter)
59
59
60 def magic_aimport(self, parameter, stream=None):
60 def magic_aimport(self, parameter, stream=None):
61 self.auto_magics.aimport(parameter, stream=stream)
61 self.auto_magics.aimport(parameter, stream=stream)
62
62
63
63
64 class Fixture(object):
64 class Fixture(object):
65 """Fixture for creating test module files"""
65 """Fixture for creating test module files"""
66
66
67 test_dir = None
67 test_dir = None
68 old_sys_path = None
68 old_sys_path = None
69 filename_chars = "abcdefghijklmopqrstuvwxyz0123456789"
69 filename_chars = "abcdefghijklmopqrstuvwxyz0123456789"
70
70
71 def setUp(self):
71 def setUp(self):
72 self.test_dir = tempfile.mkdtemp()
72 self.test_dir = tempfile.mkdtemp()
73 self.old_sys_path = list(sys.path)
73 self.old_sys_path = list(sys.path)
74 sys.path.insert(0, self.test_dir)
74 sys.path.insert(0, self.test_dir)
75 self.shell = FakeShell()
75 self.shell = FakeShell()
76
76
77 def tearDown(self):
77 def tearDown(self):
78 shutil.rmtree(self.test_dir)
78 shutil.rmtree(self.test_dir)
79 sys.path = self.old_sys_path
79 sys.path = self.old_sys_path
80
80
81 self.test_dir = None
81 self.test_dir = None
82 self.old_sys_path = None
82 self.old_sys_path = None
83 self.shell = None
83 self.shell = None
84
84
85 def get_module(self):
85 def get_module(self):
86 module_name = "tmpmod_" + "".join(random.sample(self.filename_chars,20))
86 module_name = "tmpmod_" + "".join(random.sample(self.filename_chars,20))
87 if module_name in sys.modules:
87 if module_name in sys.modules:
88 del sys.modules[module_name]
88 del sys.modules[module_name]
89 file_name = os.path.join(self.test_dir, module_name + ".py")
89 file_name = os.path.join(self.test_dir, module_name + ".py")
90 return module_name, file_name
90 return module_name, file_name
91
91
92 def write_file(self, filename, content):
92 def write_file(self, filename, content):
93 """
93 """
94 Write a file, and force a timestamp difference of at least one second
94 Write a file, and force a timestamp difference of at least one second
95
95
96 Notes
96 Notes
97 -----
97 -----
98 Python's .pyc files record the timestamp of their compilation
98 Python's .pyc files record the timestamp of their compilation
99 with a time resolution of one second.
99 with a time resolution of one second.
100
100
101 Therefore, we need to force a timestamp difference between .py
101 Therefore, we need to force a timestamp difference between .py
102 and .pyc, without having the .py file be timestamped in the
102 and .pyc, without having the .py file be timestamped in the
103 future, and without changing the timestamp of the .pyc file
103 future, and without changing the timestamp of the .pyc file
104 (because that is stored in the file). The only reliable way
104 (because that is stored in the file). The only reliable way
105 to achieve this seems to be to sleep.
105 to achieve this seems to be to sleep.
106 """
106 """
107
107
108 # Sleep one second + eps
108 # Sleep one second + eps
109 time.sleep(1.05)
109 time.sleep(1.05)
110
110
111 # Write
111 # Write
112 f = open(filename, 'w')
112 f = open(filename, 'w')
113 try:
113 try:
114 f.write(content)
114 f.write(content)
115 finally:
115 finally:
116 f.close()
116 f.close()
117
117
118 def new_module(self, code):
118 def new_module(self, code):
119 mod_name, mod_fn = self.get_module()
119 mod_name, mod_fn = self.get_module()
120 f = open(mod_fn, 'w')
120 f = open(mod_fn, 'w')
121 try:
121 try:
122 f.write(code)
122 f.write(code)
123 finally:
123 finally:
124 f.close()
124 f.close()
125 return mod_name, mod_fn
125 return mod_name, mod_fn
126
126
127 #-----------------------------------------------------------------------------
127 #-----------------------------------------------------------------------------
128 # Test automatic reloading
128 # Test automatic reloading
129 #-----------------------------------------------------------------------------
129 #-----------------------------------------------------------------------------
130
130
131 class TestAutoreload(Fixture):
131 class TestAutoreload(Fixture):
132 def _check_smoketest(self, use_aimport=True):
132 def _check_smoketest(self, use_aimport=True):
133 """
133 """
134 Functional test for the automatic reloader using either
134 Functional test for the automatic reloader using either
135 '%autoreload 1' or '%autoreload 2'
135 '%autoreload 1' or '%autoreload 2'
136 """
136 """
137
137
138 mod_name, mod_fn = self.new_module("""
138 mod_name, mod_fn = self.new_module("""
139 x = 9
139 x = 9
140
140
141 z = 123 # this item will be deleted
141 z = 123 # this item will be deleted
142
142
143 def foo(y):
143 def foo(y):
144 return y + 3
144 return y + 3
145
145
146 class Baz(object):
146 class Baz(object):
147 def __init__(self, x):
147 def __init__(self, x):
148 self.x = x
148 self.x = x
149 def bar(self, y):
149 def bar(self, y):
150 return self.x + y
150 return self.x + y
151 @property
151 @property
152 def quux(self):
152 def quux(self):
153 return 42
153 return 42
154 def zzz(self):
154 def zzz(self):
155 '''This method will be deleted below'''
155 '''This method will be deleted below'''
156 return 99
156 return 99
157
157
158 class Bar: # old-style class: weakref doesn't work for it on Python < 2.7
158 class Bar: # old-style class: weakref doesn't work for it on Python < 2.7
159 def foo(self):
159 def foo(self):
160 return 1
160 return 1
161 """)
161 """)
162
162
163 #
163 #
164 # Import module, and mark for reloading
164 # Import module, and mark for reloading
165 #
165 #
166 if use_aimport:
166 if use_aimport:
167 self.shell.magic_autoreload("1")
167 self.shell.magic_autoreload("1")
168 self.shell.magic_aimport(mod_name)
168 self.shell.magic_aimport(mod_name)
169 stream = StringIO()
169 stream = StringIO()
170 self.shell.magic_aimport("", stream=stream)
170 self.shell.magic_aimport("", stream=stream)
171 nt.assert_true(("Modules to reload:\n%s" % mod_name) in
171 nt.assert_true(("Modules to reload:\n%s" % mod_name) in
172 stream.getvalue())
172 stream.getvalue())
173
173
174 nt.assert_raises(
174 nt.assert_raises(
175 ImportError,
175 ImportError,
176 self.shell.magic_aimport, "tmpmod_as318989e89ds")
176 self.shell.magic_aimport, "tmpmod_as318989e89ds")
177 else:
177 else:
178 self.shell.magic_autoreload("2")
178 self.shell.magic_autoreload("2")
179 self.shell.run_code("import %s" % mod_name)
179 self.shell.run_code("import %s" % mod_name)
180 stream = StringIO()
180 stream = StringIO()
181 self.shell.magic_aimport("", stream=stream)
181 self.shell.magic_aimport("", stream=stream)
182 nt.assert_true("Modules to reload:\nall-except-skipped" in
182 nt.assert_true("Modules to reload:\nall-except-skipped" in
183 stream.getvalue())
183 stream.getvalue())
184 nt.assert_in(mod_name, self.shell.ns)
184 nt.assert_in(mod_name, self.shell.ns)
185
185
186 mod = sys.modules[mod_name]
186 mod = sys.modules[mod_name]
187
187
188 #
188 #
189 # Test module contents
189 # Test module contents
190 #
190 #
191 old_foo = mod.foo
191 old_foo = mod.foo
192 old_obj = mod.Baz(9)
192 old_obj = mod.Baz(9)
193 old_obj2 = mod.Bar()
193 old_obj2 = mod.Bar()
194
194
195 def check_module_contents():
195 def check_module_contents():
196 nt.assert_equal(mod.x, 9)
196 nt.assert_equal(mod.x, 9)
197 nt.assert_equal(mod.z, 123)
197 nt.assert_equal(mod.z, 123)
198
198
199 nt.assert_equal(old_foo(0), 3)
199 nt.assert_equal(old_foo(0), 3)
200 nt.assert_equal(mod.foo(0), 3)
200 nt.assert_equal(mod.foo(0), 3)
201
201
202 obj = mod.Baz(9)
202 obj = mod.Baz(9)
203 nt.assert_equal(old_obj.bar(1), 10)
203 nt.assert_equal(old_obj.bar(1), 10)
204 nt.assert_equal(obj.bar(1), 10)
204 nt.assert_equal(obj.bar(1), 10)
205 nt.assert_equal(obj.quux, 42)
205 nt.assert_equal(obj.quux, 42)
206 nt.assert_equal(obj.zzz(), 99)
206 nt.assert_equal(obj.zzz(), 99)
207
207
208 obj2 = mod.Bar()
208 obj2 = mod.Bar()
209 nt.assert_equal(old_obj2.foo(), 1)
209 nt.assert_equal(old_obj2.foo(), 1)
210 nt.assert_equal(obj2.foo(), 1)
210 nt.assert_equal(obj2.foo(), 1)
211
211
212 check_module_contents()
212 check_module_contents()
213
213
214 #
214 #
215 # Simulate a failed reload: no reload should occur and exactly
215 # Simulate a failed reload: no reload should occur and exactly
216 # one error message should be printed
216 # one error message should be printed
217 #
217 #
218 self.write_file(mod_fn, """
218 self.write_file(mod_fn, """
219 a syntax error
219 a syntax error
220 """)
220 """)
221
221
222 with tt.AssertPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
222 with tt.AssertPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
223 self.shell.run_code("pass") # trigger reload
223 self.shell.run_code("pass") # trigger reload
224 with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
224 with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
225 self.shell.run_code("pass") # trigger another reload
225 self.shell.run_code("pass") # trigger another reload
226 check_module_contents()
226 check_module_contents()
227
227
228 #
228 #
229 # Rewrite module (this time reload should succeed)
229 # Rewrite module (this time reload should succeed)
230 #
230 #
231 self.write_file(mod_fn, """
231 self.write_file(mod_fn, """
232 x = 10
232 x = 10
233
233
234 def foo(y):
234 def foo(y):
235 return y + 4
235 return y + 4
236
236
237 class Baz(object):
237 class Baz(object):
238 def __init__(self, x):
238 def __init__(self, x):
239 self.x = x
239 self.x = x
240 def bar(self, y):
240 def bar(self, y):
241 return self.x + y + 1
241 return self.x + y + 1
242 @property
242 @property
243 def quux(self):
243 def quux(self):
244 return 43
244 return 43
245
245
246 class Bar: # old-style class
246 class Bar: # old-style class
247 def foo(self):
247 def foo(self):
248 return 2
248 return 2
249 """)
249 """)
250
250
251 def check_module_contents():
251 def check_module_contents():
252 nt.assert_equal(mod.x, 10)
252 nt.assert_equal(mod.x, 10)
253 nt.assert_false(hasattr(mod, 'z'))
253 nt.assert_false(hasattr(mod, 'z'))
254
254
255 nt.assert_equal(old_foo(0), 4) # superreload magic!
255 nt.assert_equal(old_foo(0), 4) # superreload magic!
256 nt.assert_equal(mod.foo(0), 4)
256 nt.assert_equal(mod.foo(0), 4)
257
257
258 obj = mod.Baz(9)
258 obj = mod.Baz(9)
259 nt.assert_equal(old_obj.bar(1), 11) # superreload magic!
259 nt.assert_equal(old_obj.bar(1), 11) # superreload magic!
260 nt.assert_equal(obj.bar(1), 11)
260 nt.assert_equal(obj.bar(1), 11)
261
261
262 nt.assert_equal(old_obj.quux, 43)
262 nt.assert_equal(old_obj.quux, 43)
263 nt.assert_equal(obj.quux, 43)
263 nt.assert_equal(obj.quux, 43)
264
264
265 nt.assert_false(hasattr(old_obj, 'zzz'))
265 nt.assert_false(hasattr(old_obj, 'zzz'))
266 nt.assert_false(hasattr(obj, 'zzz'))
266 nt.assert_false(hasattr(obj, 'zzz'))
267
267
268 obj2 = mod.Bar()
268 obj2 = mod.Bar()
269 nt.assert_equal(old_obj2.foo(), 2)
269 nt.assert_equal(old_obj2.foo(), 2)
270 nt.assert_equal(obj2.foo(), 2)
270 nt.assert_equal(obj2.foo(), 2)
271
271
272 self.shell.run_code("pass") # trigger reload
272 self.shell.run_code("pass") # trigger reload
273 check_module_contents()
273 check_module_contents()
274
274
275 #
275 #
276 # Another failure case: deleted file (shouldn't reload)
276 # Another failure case: deleted file (shouldn't reload)
277 #
277 #
278 os.unlink(mod_fn)
278 os.unlink(mod_fn)
279
279
280 self.shell.run_code("pass") # trigger reload
280 self.shell.run_code("pass") # trigger reload
281 check_module_contents()
281 check_module_contents()
282
282
283 #
283 #
284 # Disable autoreload and rewrite module: no reload should occur
284 # Disable autoreload and rewrite module: no reload should occur
285 #
285 #
286 if use_aimport:
286 if use_aimport:
287 self.shell.magic_aimport("-" + mod_name)
287 self.shell.magic_aimport("-" + mod_name)
288 stream = StringIO()
288 stream = StringIO()
289 self.shell.magic_aimport("", stream=stream)
289 self.shell.magic_aimport("", stream=stream)
290 nt.assert_true(("Modules to skip:\n%s" % mod_name) in
290 nt.assert_true(("Modules to skip:\n%s" % mod_name) in
291 stream.getvalue())
291 stream.getvalue())
292
292
293 # This should succeed, although no such module exists
293 # This should succeed, although no such module exists
294 self.shell.magic_aimport("-tmpmod_as318989e89ds")
294 self.shell.magic_aimport("-tmpmod_as318989e89ds")
295 else:
295 else:
296 self.shell.magic_autoreload("0")
296 self.shell.magic_autoreload("0")
297
297
298 self.write_file(mod_fn, """
298 self.write_file(mod_fn, """
299 x = -99
299 x = -99
300 """)
300 """)
301
301
302 self.shell.run_code("pass") # trigger reload
302 self.shell.run_code("pass") # trigger reload
303 self.shell.run_code("pass")
303 self.shell.run_code("pass")
304 check_module_contents()
304 check_module_contents()
305
305
306 #
306 #
307 # Re-enable autoreload: reload should now occur
307 # Re-enable autoreload: reload should now occur
308 #
308 #
309 if use_aimport:
309 if use_aimport:
310 self.shell.magic_aimport(mod_name)
310 self.shell.magic_aimport(mod_name)
311 else:
311 else:
312 self.shell.magic_autoreload("")
312 self.shell.magic_autoreload("")
313
313
314 self.shell.run_code("pass") # trigger reload
314 self.shell.run_code("pass") # trigger reload
315 nt.assert_equal(mod.x, -99)
315 nt.assert_equal(mod.x, -99)
316
316
317 def test_smoketest_aimport(self):
317 def test_smoketest_aimport(self):
318 self._check_smoketest(use_aimport=True)
318 self._check_smoketest(use_aimport=True)
319
319
320 def test_smoketest_autoreload(self):
320 def test_smoketest_autoreload(self):
321 self._check_smoketest(use_aimport=False)
321 self._check_smoketest(use_aimport=False)
@@ -1,142 +1,142 b''
1 """Base class for a Comm"""
1 """Base class for a Comm"""
2
2
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Copyright (C) 2013 The IPython Development Team
4 # Copyright (C) 2013 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
7 # the file COPYING, distributed as part of this software.
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9
9
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 import uuid
14 import uuid
15
15
16 from IPython.config import LoggingConfigurable
16 from IPython.config import LoggingConfigurable
17 from IPython.core.getipython import get_ipython
17 from IPython.core.getipython import get_ipython
18
18
19 from IPython.utils.traitlets import Instance, Unicode, Bytes, Bool, Dict, Any
19 from IPython.utils.traitlets import Instance, Unicode, Bytes, Bool, Dict, Any
20
20
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22 # Code
22 # Code
23 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
24
24
25 class Comm(LoggingConfigurable):
25 class Comm(LoggingConfigurable):
26
26
27 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
27 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
28 def _shell_default(self):
28 def _shell_default(self):
29 return get_ipython()
29 return get_ipython()
30
30
31 iopub_socket = Any()
31 iopub_socket = Any()
32 def _iopub_socket_default(self):
32 def _iopub_socket_default(self):
33 return self.shell.kernel.iopub_socket
33 return self.shell.kernel.iopub_socket
34 session = Instance('IPython.kernel.zmq.session.Session')
34 session = Instance('IPython.kernel.zmq.session.Session')
35 def _session_default(self):
35 def _session_default(self):
36 if self.shell is None:
36 if self.shell is None:
37 return
37 return
38 return self.shell.kernel.session
38 return self.shell.kernel.session
39
39
40 target_name = Unicode('comm')
40 target_name = Unicode('comm')
41
41
42 topic = Bytes()
42 topic = Bytes()
43 def _topic_default(self):
43 def _topic_default(self):
44 return ('comm-%s' % self.comm_id).encode('ascii')
44 return ('comm-%s' % self.comm_id).encode('ascii')
45
45
46 _open_data = Dict(help="data dict, if any, to be included in comm_open")
46 _open_data = Dict(help="data dict, if any, to be included in comm_open")
47 _close_data = Dict(help="data dict, if any, to be included in comm_close")
47 _close_data = Dict(help="data dict, if any, to be included in comm_close")
48
48
49 _msg_callback = Any()
49 _msg_callback = Any()
50 _close_callback = Any()
50 _close_callback = Any()
51
51
52 _closed = Bool(False)
52 _closed = Bool(False)
53 comm_id = Unicode()
53 comm_id = Unicode()
54 def _comm_id_default(self):
54 def _comm_id_default(self):
55 return uuid.uuid4().hex
55 return uuid.uuid4().hex
56
56
57 primary = Bool(True, help="Am I the primary or secondary Comm?")
57 primary = Bool(True, help="Am I the primary or secondary Comm?")
58
58
59 def __init__(self, target_name='', data=None, **kwargs):
59 def __init__(self, target_name='', data=None, **kwargs):
60 if target_name:
60 if target_name:
61 kwargs['target_name'] = target_name
61 kwargs['target_name'] = target_name
62 super(Comm, self).__init__(**kwargs)
62 super(Comm, self).__init__(**kwargs)
63 get_ipython().comm_manager.register_comm(self)
63 get_ipython().comm_manager.register_comm(self)
64 if self.primary:
64 if self.primary:
65 # I am primary, open my peer.
65 # I am primary, open my peer.
66 self.open(data)
66 self.open(data)
67
67
68 def _publish_msg(self, msg_type, data=None, metadata=None, **keys):
68 def _publish_msg(self, msg_type, data=None, metadata=None, **keys):
69 """Helper for sending a comm message on IOPub"""
69 """Helper for sending a comm message on IOPub"""
70 data = {} if data is None else data
70 data = {} if data is None else data
71 metadata = {} if metadata is None else metadata
71 metadata = {} if metadata is None else metadata
72 self.session.send(self.iopub_socket, msg_type,
72 self.session.send(self.iopub_socket, msg_type,
73 dict(data=data, comm_id=self.comm_id, **keys),
73 dict(data=data, comm_id=self.comm_id, **keys),
74 metadata=metadata,
74 metadata=metadata,
75 parent=self.shell.get_parent(),
75 parent=self.shell.get_parent(),
76 ident=self.topic,
76 ident=self.topic,
77 )
77 )
78
78
79 def __del__(self):
79 def __del__(self):
80 """trigger close on gc"""
80 """trigger close on gc"""
81 self.close()
81 self.close()
82
82
83 # publishing messages
83 # publishing messages
84
84
85 def open(self, data=None, metadata=None):
85 def open(self, data=None, metadata=None):
86 """Open the frontend-side version of this comm"""
86 """Open the frontend-side version of this comm"""
87 if data is None:
87 if data is None:
88 data = self._open_data
88 data = self._open_data
89 self._publish_msg('comm_open', data, metadata, target_name=self.target_name)
89 self._publish_msg('comm_open', data, metadata, target_name=self.target_name)
90
90
91 def close(self, data=None, metadata=None):
91 def close(self, data=None, metadata=None):
92 """Close the frontend-side version of this comm"""
92 """Close the frontend-side version of this comm"""
93 if self._closed:
93 if self._closed:
94 # only close once
94 # only close once
95 return
95 return
96 if data is None:
96 if data is None:
97 data = self._close_data
97 data = self._close_data
98 self._publish_msg('comm_close', data, metadata)
98 self._publish_msg('comm_close', data, metadata)
99 self._closed = True
99 self._closed = True
100
100
101 def send(self, data=None, metadata=None):
101 def send(self, data=None, metadata=None):
102 """Send a message to the frontend-side version of this comm"""
102 """Send a message to the frontend-side version of this comm"""
103 self._publish_msg('comm_msg', data, metadata)
103 self._publish_msg('comm_msg', data, metadata)
104
104
105 # registering callbacks
105 # registering callbacks
106
106
107 def on_close(self, callback):
107 def on_close(self, callback):
108 """Register a callback for comm_close
108 """Register a callback for comm_close
109
109
110 Will be called with the `data` of the close message.
110 Will be called with the `data` of the close message.
111
111
112 Call `on_close(None)` to disable an existing callback.
112 Call `on_close(None)` to disable an existing callback.
113 """
113 """
114 self._close_callback = callback
114 self._close_callback = callback
115
115
116 def on_msg(self, callback):
116 def on_msg(self, callback):
117 """Register a callback for comm_msg
117 """Register a callback for comm_msg
118
118
119 Will be called with the `data` of any comm_msg messages.
119 Will be called with the `data` of any comm_msg messages.
120
120
121 Call `on_msg(None)` to disable an existing callback.
121 Call `on_msg(None)` to disable an existing callback.
122 """
122 """
123 self._msg_callback = callback
123 self._msg_callback = callback
124
124
125 # handling of incoming messages
125 # handling of incoming messages
126
126
127 def handle_close(self, msg):
127 def handle_close(self, msg):
128 """Handle a comm_close message"""
128 """Handle a comm_close message"""
129 self.log.debug("handle_close[%s](%s)", self.comm_id, msg)
129 self.log.debug("handle_close[%s](%s)", self.comm_id, msg)
130 if self._close_callback:
130 if self._close_callback:
131 self._close_callback(msg)
131 self._close_callback(msg)
132
132
133 def handle_msg(self, msg):
133 def handle_msg(self, msg):
134 """Handle a comm_msg message"""
134 """Handle a comm_msg message"""
135 self.log.debug("handle_msg[%s](%s)", self.comm_id, msg)
135 self.log.debug("handle_msg[%s](%s)", self.comm_id, msg)
136 if self._msg_callback:
136 if self._msg_callback:
137 self.shell.callbacks.fire('pre_execute')
137 self.shell.events.trigger('pre_execute')
138 self._msg_callback(msg)
138 self._msg_callback(msg)
139 self.shell.callbacks.fire('post_execute')
139 self.shell.events.trigger('post_execute')
140
140
141
141
142 __all__ = ['Comm']
142 __all__ = ['Comm']
General Comments 0
You need to be logged in to leave comments. Login now