##// END OF EJS Templates
MAINT: Move `InputRejected` to `IPython.core.error`.
Scott Sanderson -
Show More

The requested changes are too big and content was truncated. Show full diff

@@ -1,53 +1,60 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Global exception classes for IPython.core.
3 Global exception classes for IPython.core.
4
4
5 Authors:
5 Authors:
6
6
7 * Brian Granger
7 * Brian Granger
8 * Fernando Perez
8 * Fernando Perez
9 * Min Ragan-Kelley
9 * Min Ragan-Kelley
10
10
11 Notes
11 Notes
12 -----
12 -----
13 """
13 """
14
14
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 # Copyright (C) 2008 The IPython Development Team
16 # Copyright (C) 2008 The IPython Development Team
17 #
17 #
18 # Distributed under the terms of the BSD License. The full license is in
18 # Distributed under the terms of the BSD License. The full license is in
19 # the file COPYING, distributed as part of this software.
19 # the file COPYING, distributed as part of this software.
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21
21
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23 # Imports
23 # Imports
24 #-----------------------------------------------------------------------------
24 #-----------------------------------------------------------------------------
25
25
26 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
27 # Exception classes
27 # Exception classes
28 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
29
29
30 class IPythonCoreError(Exception):
30 class IPythonCoreError(Exception):
31 pass
31 pass
32
32
33
33
34 class TryNext(IPythonCoreError):
34 class TryNext(IPythonCoreError):
35 """Try next hook exception.
35 """Try next hook exception.
36
36
37 Raise this in your hook function to indicate that the next hook handler
37 Raise this in your hook function to indicate that the next hook handler
38 should be used to handle the operation.
38 should be used to handle the operation.
39 """
39 """
40
40
41 class UsageError(IPythonCoreError):
41 class UsageError(IPythonCoreError):
42 """Error in magic function arguments, etc.
42 """Error in magic function arguments, etc.
43
43
44 Something that probably won't warrant a full traceback, but should
44 Something that probably won't warrant a full traceback, but should
45 nevertheless interrupt a macro / batch file.
45 nevertheless interrupt a macro / batch file.
46 """
46 """
47
47
48 class StdinNotImplementedError(IPythonCoreError, NotImplementedError):
48 class StdinNotImplementedError(IPythonCoreError, NotImplementedError):
49 """raw_input was requested in a context where it is not supported
49 """raw_input was requested in a context where it is not supported
50
50
51 For use in IPython kernels, where only some frontends may support
51 For use in IPython kernels, where only some frontends may support
52 stdin requests.
52 stdin requests.
53 """
53 """
54
55 class InputRejected(Exception):
56 """Input rejected by ast transformer.
57
58 Raise this in your NodeTransformer to indicate that InteractiveShell should
59 not execute the supplied input.
60 """
1 NO CONTENT: modified file
NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
@@ -1,842 +1,842 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Tests for the key interactiveshell module.
2 """Tests for the key interactiveshell module.
3
3
4 Historically the main classes in interactiveshell have been under-tested. This
4 Historically the main classes in interactiveshell have been under-tested. This
5 module should grow as many single-method tests as possible to trap many of the
5 module should grow as many single-method tests as possible to trap many of the
6 recurring bugs we seem to encounter with high-level interaction.
6 recurring bugs we seem to encounter with high-level interaction.
7 """
7 """
8
8
9 # Copyright (c) IPython Development Team.
9 # Copyright (c) IPython Development Team.
10 # Distributed under the terms of the Modified BSD License.
10 # Distributed under the terms of the Modified BSD License.
11
11
12 import ast
12 import ast
13 import os
13 import os
14 import signal
14 import signal
15 import shutil
15 import shutil
16 import sys
16 import sys
17 import tempfile
17 import tempfile
18 import unittest
18 import unittest
19 try:
19 try:
20 from unittest import mock
20 from unittest import mock
21 except ImportError:
21 except ImportError:
22 import mock
22 import mock
23 from os.path import join
23 from os.path import join
24
24
25 import nose.tools as nt
25 import nose.tools as nt
26
26
27 from IPython.core.error import InputRejected
27 from IPython.core.inputtransformer import InputTransformer
28 from IPython.core.inputtransformer import InputTransformer
28 from IPython.lib.security import InputRejected
29 from IPython.testing.decorators import (
29 from IPython.testing.decorators import (
30 skipif, skip_win32, onlyif_unicode_paths, onlyif_cmds_exist,
30 skipif, skip_win32, onlyif_unicode_paths, onlyif_cmds_exist,
31 )
31 )
32 from IPython.testing import tools as tt
32 from IPython.testing import tools as tt
33 from IPython.utils import io
33 from IPython.utils import io
34 from IPython.utils.process import find_cmd
34 from IPython.utils.process import find_cmd
35 from IPython.utils import py3compat
35 from IPython.utils import py3compat
36 from IPython.utils.py3compat import unicode_type, PY3
36 from IPython.utils.py3compat import unicode_type, PY3
37
37
38 if PY3:
38 if PY3:
39 from io import StringIO
39 from io import StringIO
40 else:
40 else:
41 from StringIO import StringIO
41 from StringIO import StringIO
42
42
43 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
44 # Globals
44 # Globals
45 #-----------------------------------------------------------------------------
45 #-----------------------------------------------------------------------------
46 # This is used by every single test, no point repeating it ad nauseam
46 # This is used by every single test, no point repeating it ad nauseam
47 ip = get_ipython()
47 ip = get_ipython()
48
48
49 #-----------------------------------------------------------------------------
49 #-----------------------------------------------------------------------------
50 # Tests
50 # Tests
51 #-----------------------------------------------------------------------------
51 #-----------------------------------------------------------------------------
52
52
53 class InteractiveShellTestCase(unittest.TestCase):
53 class InteractiveShellTestCase(unittest.TestCase):
54 def test_naked_string_cells(self):
54 def test_naked_string_cells(self):
55 """Test that cells with only naked strings are fully executed"""
55 """Test that cells with only naked strings are fully executed"""
56 # First, single-line inputs
56 # First, single-line inputs
57 ip.run_cell('"a"\n')
57 ip.run_cell('"a"\n')
58 self.assertEqual(ip.user_ns['_'], 'a')
58 self.assertEqual(ip.user_ns['_'], 'a')
59 # And also multi-line cells
59 # And also multi-line cells
60 ip.run_cell('"""a\nb"""\n')
60 ip.run_cell('"""a\nb"""\n')
61 self.assertEqual(ip.user_ns['_'], 'a\nb')
61 self.assertEqual(ip.user_ns['_'], 'a\nb')
62
62
63 def test_run_empty_cell(self):
63 def test_run_empty_cell(self):
64 """Just make sure we don't get a horrible error with a blank
64 """Just make sure we don't get a horrible error with a blank
65 cell of input. Yes, I did overlook that."""
65 cell of input. Yes, I did overlook that."""
66 old_xc = ip.execution_count
66 old_xc = ip.execution_count
67 ip.run_cell('')
67 ip.run_cell('')
68 self.assertEqual(ip.execution_count, old_xc)
68 self.assertEqual(ip.execution_count, old_xc)
69
69
70 def test_run_cell_multiline(self):
70 def test_run_cell_multiline(self):
71 """Multi-block, multi-line cells must execute correctly.
71 """Multi-block, multi-line cells must execute correctly.
72 """
72 """
73 src = '\n'.join(["x=1",
73 src = '\n'.join(["x=1",
74 "y=2",
74 "y=2",
75 "if 1:",
75 "if 1:",
76 " x += 1",
76 " x += 1",
77 " y += 1",])
77 " y += 1",])
78 ip.run_cell(src)
78 ip.run_cell(src)
79 self.assertEqual(ip.user_ns['x'], 2)
79 self.assertEqual(ip.user_ns['x'], 2)
80 self.assertEqual(ip.user_ns['y'], 3)
80 self.assertEqual(ip.user_ns['y'], 3)
81
81
82 def test_multiline_string_cells(self):
82 def test_multiline_string_cells(self):
83 "Code sprinkled with multiline strings should execute (GH-306)"
83 "Code sprinkled with multiline strings should execute (GH-306)"
84 ip.run_cell('tmp=0')
84 ip.run_cell('tmp=0')
85 self.assertEqual(ip.user_ns['tmp'], 0)
85 self.assertEqual(ip.user_ns['tmp'], 0)
86 ip.run_cell('tmp=1;"""a\nb"""\n')
86 ip.run_cell('tmp=1;"""a\nb"""\n')
87 self.assertEqual(ip.user_ns['tmp'], 1)
87 self.assertEqual(ip.user_ns['tmp'], 1)
88
88
89 def test_dont_cache_with_semicolon(self):
89 def test_dont_cache_with_semicolon(self):
90 "Ending a line with semicolon should not cache the returned object (GH-307)"
90 "Ending a line with semicolon should not cache the returned object (GH-307)"
91 oldlen = len(ip.user_ns['Out'])
91 oldlen = len(ip.user_ns['Out'])
92 for cell in ['1;', '1;1;']:
92 for cell in ['1;', '1;1;']:
93 ip.run_cell(cell, store_history=True)
93 ip.run_cell(cell, store_history=True)
94 newlen = len(ip.user_ns['Out'])
94 newlen = len(ip.user_ns['Out'])
95 self.assertEqual(oldlen, newlen)
95 self.assertEqual(oldlen, newlen)
96 i = 0
96 i = 0
97 #also test the default caching behavior
97 #also test the default caching behavior
98 for cell in ['1', '1;1']:
98 for cell in ['1', '1;1']:
99 ip.run_cell(cell, store_history=True)
99 ip.run_cell(cell, store_history=True)
100 newlen = len(ip.user_ns['Out'])
100 newlen = len(ip.user_ns['Out'])
101 i += 1
101 i += 1
102 self.assertEqual(oldlen+i, newlen)
102 self.assertEqual(oldlen+i, newlen)
103
103
104 def test_In_variable(self):
104 def test_In_variable(self):
105 "Verify that In variable grows with user input (GH-284)"
105 "Verify that In variable grows with user input (GH-284)"
106 oldlen = len(ip.user_ns['In'])
106 oldlen = len(ip.user_ns['In'])
107 ip.run_cell('1;', store_history=True)
107 ip.run_cell('1;', store_history=True)
108 newlen = len(ip.user_ns['In'])
108 newlen = len(ip.user_ns['In'])
109 self.assertEqual(oldlen+1, newlen)
109 self.assertEqual(oldlen+1, newlen)
110 self.assertEqual(ip.user_ns['In'][-1],'1;')
110 self.assertEqual(ip.user_ns['In'][-1],'1;')
111
111
112 def test_magic_names_in_string(self):
112 def test_magic_names_in_string(self):
113 ip.run_cell('a = """\n%exit\n"""')
113 ip.run_cell('a = """\n%exit\n"""')
114 self.assertEqual(ip.user_ns['a'], '\n%exit\n')
114 self.assertEqual(ip.user_ns['a'], '\n%exit\n')
115
115
116 def test_trailing_newline(self):
116 def test_trailing_newline(self):
117 """test that running !(command) does not raise a SyntaxError"""
117 """test that running !(command) does not raise a SyntaxError"""
118 ip.run_cell('!(true)\n', False)
118 ip.run_cell('!(true)\n', False)
119 ip.run_cell('!(true)\n\n\n', False)
119 ip.run_cell('!(true)\n\n\n', False)
120
120
121 def test_gh_597(self):
121 def test_gh_597(self):
122 """Pretty-printing lists of objects with non-ascii reprs may cause
122 """Pretty-printing lists of objects with non-ascii reprs may cause
123 problems."""
123 problems."""
124 class Spam(object):
124 class Spam(object):
125 def __repr__(self):
125 def __repr__(self):
126 return "\xe9"*50
126 return "\xe9"*50
127 import IPython.core.formatters
127 import IPython.core.formatters
128 f = IPython.core.formatters.PlainTextFormatter()
128 f = IPython.core.formatters.PlainTextFormatter()
129 f([Spam(),Spam()])
129 f([Spam(),Spam()])
130
130
131
131
132 def test_future_flags(self):
132 def test_future_flags(self):
133 """Check that future flags are used for parsing code (gh-777)"""
133 """Check that future flags are used for parsing code (gh-777)"""
134 ip.run_cell('from __future__ import print_function')
134 ip.run_cell('from __future__ import print_function')
135 try:
135 try:
136 ip.run_cell('prfunc_return_val = print(1,2, sep=" ")')
136 ip.run_cell('prfunc_return_val = print(1,2, sep=" ")')
137 assert 'prfunc_return_val' in ip.user_ns
137 assert 'prfunc_return_val' in ip.user_ns
138 finally:
138 finally:
139 # Reset compiler flags so we don't mess up other tests.
139 # Reset compiler flags so we don't mess up other tests.
140 ip.compile.reset_compiler_flags()
140 ip.compile.reset_compiler_flags()
141
141
142 def test_future_unicode(self):
142 def test_future_unicode(self):
143 """Check that unicode_literals is imported from __future__ (gh #786)"""
143 """Check that unicode_literals is imported from __future__ (gh #786)"""
144 try:
144 try:
145 ip.run_cell(u'byte_str = "a"')
145 ip.run_cell(u'byte_str = "a"')
146 assert isinstance(ip.user_ns['byte_str'], str) # string literals are byte strings by default
146 assert isinstance(ip.user_ns['byte_str'], str) # string literals are byte strings by default
147 ip.run_cell('from __future__ import unicode_literals')
147 ip.run_cell('from __future__ import unicode_literals')
148 ip.run_cell(u'unicode_str = "a"')
148 ip.run_cell(u'unicode_str = "a"')
149 assert isinstance(ip.user_ns['unicode_str'], unicode_type) # strings literals are now unicode
149 assert isinstance(ip.user_ns['unicode_str'], unicode_type) # strings literals are now unicode
150 finally:
150 finally:
151 # Reset compiler flags so we don't mess up other tests.
151 # Reset compiler flags so we don't mess up other tests.
152 ip.compile.reset_compiler_flags()
152 ip.compile.reset_compiler_flags()
153
153
154 def test_can_pickle(self):
154 def test_can_pickle(self):
155 "Can we pickle objects defined interactively (GH-29)"
155 "Can we pickle objects defined interactively (GH-29)"
156 ip = get_ipython()
156 ip = get_ipython()
157 ip.reset()
157 ip.reset()
158 ip.run_cell(("class Mylist(list):\n"
158 ip.run_cell(("class Mylist(list):\n"
159 " def __init__(self,x=[]):\n"
159 " def __init__(self,x=[]):\n"
160 " list.__init__(self,x)"))
160 " list.__init__(self,x)"))
161 ip.run_cell("w=Mylist([1,2,3])")
161 ip.run_cell("w=Mylist([1,2,3])")
162
162
163 from pickle import dumps
163 from pickle import dumps
164
164
165 # We need to swap in our main module - this is only necessary
165 # We need to swap in our main module - this is only necessary
166 # inside the test framework, because IPython puts the interactive module
166 # inside the test framework, because IPython puts the interactive module
167 # in place (but the test framework undoes this).
167 # in place (but the test framework undoes this).
168 _main = sys.modules['__main__']
168 _main = sys.modules['__main__']
169 sys.modules['__main__'] = ip.user_module
169 sys.modules['__main__'] = ip.user_module
170 try:
170 try:
171 res = dumps(ip.user_ns["w"])
171 res = dumps(ip.user_ns["w"])
172 finally:
172 finally:
173 sys.modules['__main__'] = _main
173 sys.modules['__main__'] = _main
174 self.assertTrue(isinstance(res, bytes))
174 self.assertTrue(isinstance(res, bytes))
175
175
176 def test_global_ns(self):
176 def test_global_ns(self):
177 "Code in functions must be able to access variables outside them."
177 "Code in functions must be able to access variables outside them."
178 ip = get_ipython()
178 ip = get_ipython()
179 ip.run_cell("a = 10")
179 ip.run_cell("a = 10")
180 ip.run_cell(("def f(x):\n"
180 ip.run_cell(("def f(x):\n"
181 " return x + a"))
181 " return x + a"))
182 ip.run_cell("b = f(12)")
182 ip.run_cell("b = f(12)")
183 self.assertEqual(ip.user_ns["b"], 22)
183 self.assertEqual(ip.user_ns["b"], 22)
184
184
185 def test_bad_custom_tb(self):
185 def test_bad_custom_tb(self):
186 """Check that InteractiveShell is protected from bad custom exception handlers"""
186 """Check that InteractiveShell is protected from bad custom exception handlers"""
187 from IPython.utils import io
187 from IPython.utils import io
188 save_stderr = io.stderr
188 save_stderr = io.stderr
189 try:
189 try:
190 # capture stderr
190 # capture stderr
191 io.stderr = StringIO()
191 io.stderr = StringIO()
192 ip.set_custom_exc((IOError,), lambda etype,value,tb: 1/0)
192 ip.set_custom_exc((IOError,), lambda etype,value,tb: 1/0)
193 self.assertEqual(ip.custom_exceptions, (IOError,))
193 self.assertEqual(ip.custom_exceptions, (IOError,))
194 ip.run_cell(u'raise IOError("foo")')
194 ip.run_cell(u'raise IOError("foo")')
195 self.assertEqual(ip.custom_exceptions, ())
195 self.assertEqual(ip.custom_exceptions, ())
196 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
196 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
197 finally:
197 finally:
198 io.stderr = save_stderr
198 io.stderr = save_stderr
199
199
200 def test_bad_custom_tb_return(self):
200 def test_bad_custom_tb_return(self):
201 """Check that InteractiveShell is protected from bad return types in custom exception handlers"""
201 """Check that InteractiveShell is protected from bad return types in custom exception handlers"""
202 from IPython.utils import io
202 from IPython.utils import io
203 save_stderr = io.stderr
203 save_stderr = io.stderr
204 try:
204 try:
205 # capture stderr
205 # capture stderr
206 io.stderr = StringIO()
206 io.stderr = StringIO()
207 ip.set_custom_exc((NameError,),lambda etype,value,tb, tb_offset=None: 1)
207 ip.set_custom_exc((NameError,),lambda etype,value,tb, tb_offset=None: 1)
208 self.assertEqual(ip.custom_exceptions, (NameError,))
208 self.assertEqual(ip.custom_exceptions, (NameError,))
209 ip.run_cell(u'a=abracadabra')
209 ip.run_cell(u'a=abracadabra')
210 self.assertEqual(ip.custom_exceptions, ())
210 self.assertEqual(ip.custom_exceptions, ())
211 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
211 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
212 finally:
212 finally:
213 io.stderr = save_stderr
213 io.stderr = save_stderr
214
214
215 def test_drop_by_id(self):
215 def test_drop_by_id(self):
216 myvars = {"a":object(), "b":object(), "c": object()}
216 myvars = {"a":object(), "b":object(), "c": object()}
217 ip.push(myvars, interactive=False)
217 ip.push(myvars, interactive=False)
218 for name in myvars:
218 for name in myvars:
219 assert name in ip.user_ns, name
219 assert name in ip.user_ns, name
220 assert name in ip.user_ns_hidden, name
220 assert name in ip.user_ns_hidden, name
221 ip.user_ns['b'] = 12
221 ip.user_ns['b'] = 12
222 ip.drop_by_id(myvars)
222 ip.drop_by_id(myvars)
223 for name in ["a", "c"]:
223 for name in ["a", "c"]:
224 assert name not in ip.user_ns, name
224 assert name not in ip.user_ns, name
225 assert name not in ip.user_ns_hidden, name
225 assert name not in ip.user_ns_hidden, name
226 assert ip.user_ns['b'] == 12
226 assert ip.user_ns['b'] == 12
227 ip.reset()
227 ip.reset()
228
228
229 def test_var_expand(self):
229 def test_var_expand(self):
230 ip.user_ns['f'] = u'Ca\xf1o'
230 ip.user_ns['f'] = u'Ca\xf1o'
231 self.assertEqual(ip.var_expand(u'echo $f'), u'echo Ca\xf1o')
231 self.assertEqual(ip.var_expand(u'echo $f'), u'echo Ca\xf1o')
232 self.assertEqual(ip.var_expand(u'echo {f}'), u'echo Ca\xf1o')
232 self.assertEqual(ip.var_expand(u'echo {f}'), u'echo Ca\xf1o')
233 self.assertEqual(ip.var_expand(u'echo {f[:-1]}'), u'echo Ca\xf1')
233 self.assertEqual(ip.var_expand(u'echo {f[:-1]}'), u'echo Ca\xf1')
234 self.assertEqual(ip.var_expand(u'echo {1*2}'), u'echo 2')
234 self.assertEqual(ip.var_expand(u'echo {1*2}'), u'echo 2')
235
235
236 ip.user_ns['f'] = b'Ca\xc3\xb1o'
236 ip.user_ns['f'] = b'Ca\xc3\xb1o'
237 # This should not raise any exception:
237 # This should not raise any exception:
238 ip.var_expand(u'echo $f')
238 ip.var_expand(u'echo $f')
239
239
240 def test_var_expand_local(self):
240 def test_var_expand_local(self):
241 """Test local variable expansion in !system and %magic calls"""
241 """Test local variable expansion in !system and %magic calls"""
242 # !system
242 # !system
243 ip.run_cell('def test():\n'
243 ip.run_cell('def test():\n'
244 ' lvar = "ttt"\n'
244 ' lvar = "ttt"\n'
245 ' ret = !echo {lvar}\n'
245 ' ret = !echo {lvar}\n'
246 ' return ret[0]\n')
246 ' return ret[0]\n')
247 res = ip.user_ns['test']()
247 res = ip.user_ns['test']()
248 nt.assert_in('ttt', res)
248 nt.assert_in('ttt', res)
249
249
250 # %magic
250 # %magic
251 ip.run_cell('def makemacro():\n'
251 ip.run_cell('def makemacro():\n'
252 ' macroname = "macro_var_expand_locals"\n'
252 ' macroname = "macro_var_expand_locals"\n'
253 ' %macro {macroname} codestr\n')
253 ' %macro {macroname} codestr\n')
254 ip.user_ns['codestr'] = "str(12)"
254 ip.user_ns['codestr'] = "str(12)"
255 ip.run_cell('makemacro()')
255 ip.run_cell('makemacro()')
256 nt.assert_in('macro_var_expand_locals', ip.user_ns)
256 nt.assert_in('macro_var_expand_locals', ip.user_ns)
257
257
258 def test_var_expand_self(self):
258 def test_var_expand_self(self):
259 """Test variable expansion with the name 'self', which was failing.
259 """Test variable expansion with the name 'self', which was failing.
260
260
261 See https://github.com/ipython/ipython/issues/1878#issuecomment-7698218
261 See https://github.com/ipython/ipython/issues/1878#issuecomment-7698218
262 """
262 """
263 ip.run_cell('class cTest:\n'
263 ip.run_cell('class cTest:\n'
264 ' classvar="see me"\n'
264 ' classvar="see me"\n'
265 ' def test(self):\n'
265 ' def test(self):\n'
266 ' res = !echo Variable: {self.classvar}\n'
266 ' res = !echo Variable: {self.classvar}\n'
267 ' return res[0]\n')
267 ' return res[0]\n')
268 nt.assert_in('see me', ip.user_ns['cTest']().test())
268 nt.assert_in('see me', ip.user_ns['cTest']().test())
269
269
270 def test_bad_var_expand(self):
270 def test_bad_var_expand(self):
271 """var_expand on invalid formats shouldn't raise"""
271 """var_expand on invalid formats shouldn't raise"""
272 # SyntaxError
272 # SyntaxError
273 self.assertEqual(ip.var_expand(u"{'a':5}"), u"{'a':5}")
273 self.assertEqual(ip.var_expand(u"{'a':5}"), u"{'a':5}")
274 # NameError
274 # NameError
275 self.assertEqual(ip.var_expand(u"{asdf}"), u"{asdf}")
275 self.assertEqual(ip.var_expand(u"{asdf}"), u"{asdf}")
276 # ZeroDivisionError
276 # ZeroDivisionError
277 self.assertEqual(ip.var_expand(u"{1/0}"), u"{1/0}")
277 self.assertEqual(ip.var_expand(u"{1/0}"), u"{1/0}")
278
278
279 def test_silent_postexec(self):
279 def test_silent_postexec(self):
280 """run_cell(silent=True) doesn't invoke pre/post_run_cell callbacks"""
280 """run_cell(silent=True) doesn't invoke pre/post_run_cell callbacks"""
281 pre_explicit = mock.Mock()
281 pre_explicit = mock.Mock()
282 pre_always = mock.Mock()
282 pre_always = mock.Mock()
283 post_explicit = mock.Mock()
283 post_explicit = mock.Mock()
284 post_always = mock.Mock()
284 post_always = mock.Mock()
285
285
286 ip.events.register('pre_run_cell', pre_explicit)
286 ip.events.register('pre_run_cell', pre_explicit)
287 ip.events.register('pre_execute', pre_always)
287 ip.events.register('pre_execute', pre_always)
288 ip.events.register('post_run_cell', post_explicit)
288 ip.events.register('post_run_cell', post_explicit)
289 ip.events.register('post_execute', post_always)
289 ip.events.register('post_execute', post_always)
290
290
291 try:
291 try:
292 ip.run_cell("1", silent=True)
292 ip.run_cell("1", silent=True)
293 assert pre_always.called
293 assert pre_always.called
294 assert not pre_explicit.called
294 assert not pre_explicit.called
295 assert post_always.called
295 assert post_always.called
296 assert not post_explicit.called
296 assert not post_explicit.called
297 # double-check that non-silent exec did what we expected
297 # double-check that non-silent exec did what we expected
298 # silent to avoid
298 # silent to avoid
299 ip.run_cell("1")
299 ip.run_cell("1")
300 assert pre_explicit.called
300 assert pre_explicit.called
301 assert post_explicit.called
301 assert post_explicit.called
302 finally:
302 finally:
303 # remove post-exec
303 # remove post-exec
304 ip.events.reset_all()
304 ip.events.reset_all()
305
305
306 def test_silent_noadvance(self):
306 def test_silent_noadvance(self):
307 """run_cell(silent=True) doesn't advance execution_count"""
307 """run_cell(silent=True) doesn't advance execution_count"""
308 ec = ip.execution_count
308 ec = ip.execution_count
309 # silent should force store_history=False
309 # silent should force store_history=False
310 ip.run_cell("1", store_history=True, silent=True)
310 ip.run_cell("1", store_history=True, silent=True)
311
311
312 self.assertEqual(ec, ip.execution_count)
312 self.assertEqual(ec, ip.execution_count)
313 # double-check that non-silent exec did what we expected
313 # double-check that non-silent exec did what we expected
314 # silent to avoid
314 # silent to avoid
315 ip.run_cell("1", store_history=True)
315 ip.run_cell("1", store_history=True)
316 self.assertEqual(ec+1, ip.execution_count)
316 self.assertEqual(ec+1, ip.execution_count)
317
317
318 def test_silent_nodisplayhook(self):
318 def test_silent_nodisplayhook(self):
319 """run_cell(silent=True) doesn't trigger displayhook"""
319 """run_cell(silent=True) doesn't trigger displayhook"""
320 d = dict(called=False)
320 d = dict(called=False)
321
321
322 trap = ip.display_trap
322 trap = ip.display_trap
323 save_hook = trap.hook
323 save_hook = trap.hook
324
324
325 def failing_hook(*args, **kwargs):
325 def failing_hook(*args, **kwargs):
326 d['called'] = True
326 d['called'] = True
327
327
328 try:
328 try:
329 trap.hook = failing_hook
329 trap.hook = failing_hook
330 ip.run_cell("1", silent=True)
330 ip.run_cell("1", silent=True)
331 self.assertFalse(d['called'])
331 self.assertFalse(d['called'])
332 # double-check that non-silent exec did what we expected
332 # double-check that non-silent exec did what we expected
333 # silent to avoid
333 # silent to avoid
334 ip.run_cell("1")
334 ip.run_cell("1")
335 self.assertTrue(d['called'])
335 self.assertTrue(d['called'])
336 finally:
336 finally:
337 trap.hook = save_hook
337 trap.hook = save_hook
338
338
339 @skipif(sys.version_info[0] >= 3, "softspace removed in py3")
339 @skipif(sys.version_info[0] >= 3, "softspace removed in py3")
340 def test_print_softspace(self):
340 def test_print_softspace(self):
341 """Verify that softspace is handled correctly when executing multiple
341 """Verify that softspace is handled correctly when executing multiple
342 statements.
342 statements.
343
343
344 In [1]: print 1; print 2
344 In [1]: print 1; print 2
345 1
345 1
346 2
346 2
347
347
348 In [2]: print 1,; print 2
348 In [2]: print 1,; print 2
349 1 2
349 1 2
350 """
350 """
351
351
352 def test_ofind_line_magic(self):
352 def test_ofind_line_magic(self):
353 from IPython.core.magic import register_line_magic
353 from IPython.core.magic import register_line_magic
354
354
355 @register_line_magic
355 @register_line_magic
356 def lmagic(line):
356 def lmagic(line):
357 "A line magic"
357 "A line magic"
358
358
359 # Get info on line magic
359 # Get info on line magic
360 lfind = ip._ofind('lmagic')
360 lfind = ip._ofind('lmagic')
361 info = dict(found=True, isalias=False, ismagic=True,
361 info = dict(found=True, isalias=False, ismagic=True,
362 namespace = 'IPython internal', obj= lmagic.__wrapped__,
362 namespace = 'IPython internal', obj= lmagic.__wrapped__,
363 parent = None)
363 parent = None)
364 nt.assert_equal(lfind, info)
364 nt.assert_equal(lfind, info)
365
365
366 def test_ofind_cell_magic(self):
366 def test_ofind_cell_magic(self):
367 from IPython.core.magic import register_cell_magic
367 from IPython.core.magic import register_cell_magic
368
368
369 @register_cell_magic
369 @register_cell_magic
370 def cmagic(line, cell):
370 def cmagic(line, cell):
371 "A cell magic"
371 "A cell magic"
372
372
373 # Get info on cell magic
373 # Get info on cell magic
374 find = ip._ofind('cmagic')
374 find = ip._ofind('cmagic')
375 info = dict(found=True, isalias=False, ismagic=True,
375 info = dict(found=True, isalias=False, ismagic=True,
376 namespace = 'IPython internal', obj= cmagic.__wrapped__,
376 namespace = 'IPython internal', obj= cmagic.__wrapped__,
377 parent = None)
377 parent = None)
378 nt.assert_equal(find, info)
378 nt.assert_equal(find, info)
379
379
380 def test_ofind_property_with_error(self):
380 def test_ofind_property_with_error(self):
381 class A(object):
381 class A(object):
382 @property
382 @property
383 def foo(self):
383 def foo(self):
384 raise NotImplementedError()
384 raise NotImplementedError()
385 a = A()
385 a = A()
386
386
387 found = ip._ofind('a.foo', [('locals', locals())])
387 found = ip._ofind('a.foo', [('locals', locals())])
388 info = dict(found=True, isalias=False, ismagic=False,
388 info = dict(found=True, isalias=False, ismagic=False,
389 namespace='locals', obj=A.foo, parent=a)
389 namespace='locals', obj=A.foo, parent=a)
390 nt.assert_equal(found, info)
390 nt.assert_equal(found, info)
391
391
392 def test_ofind_multiple_attribute_lookups(self):
392 def test_ofind_multiple_attribute_lookups(self):
393 class A(object):
393 class A(object):
394 @property
394 @property
395 def foo(self):
395 def foo(self):
396 raise NotImplementedError()
396 raise NotImplementedError()
397
397
398 a = A()
398 a = A()
399 a.a = A()
399 a.a = A()
400 a.a.a = A()
400 a.a.a = A()
401
401
402 found = ip._ofind('a.a.a.foo', [('locals', locals())])
402 found = ip._ofind('a.a.a.foo', [('locals', locals())])
403 info = dict(found=True, isalias=False, ismagic=False,
403 info = dict(found=True, isalias=False, ismagic=False,
404 namespace='locals', obj=A.foo, parent=a.a.a)
404 namespace='locals', obj=A.foo, parent=a.a.a)
405 nt.assert_equal(found, info)
405 nt.assert_equal(found, info)
406
406
407 def test_ofind_slotted_attributes(self):
407 def test_ofind_slotted_attributes(self):
408 class A(object):
408 class A(object):
409 __slots__ = ['foo']
409 __slots__ = ['foo']
410 def __init__(self):
410 def __init__(self):
411 self.foo = 'bar'
411 self.foo = 'bar'
412
412
413 a = A()
413 a = A()
414 found = ip._ofind('a.foo', [('locals', locals())])
414 found = ip._ofind('a.foo', [('locals', locals())])
415 info = dict(found=True, isalias=False, ismagic=False,
415 info = dict(found=True, isalias=False, ismagic=False,
416 namespace='locals', obj=a.foo, parent=a)
416 namespace='locals', obj=a.foo, parent=a)
417 nt.assert_equal(found, info)
417 nt.assert_equal(found, info)
418
418
419 found = ip._ofind('a.bar', [('locals', locals())])
419 found = ip._ofind('a.bar', [('locals', locals())])
420 info = dict(found=False, isalias=False, ismagic=False,
420 info = dict(found=False, isalias=False, ismagic=False,
421 namespace=None, obj=None, parent=a)
421 namespace=None, obj=None, parent=a)
422 nt.assert_equal(found, info)
422 nt.assert_equal(found, info)
423
423
424 def test_ofind_prefers_property_to_instance_level_attribute(self):
424 def test_ofind_prefers_property_to_instance_level_attribute(self):
425 class A(object):
425 class A(object):
426 @property
426 @property
427 def foo(self):
427 def foo(self):
428 return 'bar'
428 return 'bar'
429 a = A()
429 a = A()
430 a.__dict__['foo'] = 'baz'
430 a.__dict__['foo'] = 'baz'
431 nt.assert_equal(a.foo, 'bar')
431 nt.assert_equal(a.foo, 'bar')
432 found = ip._ofind('a.foo', [('locals', locals())])
432 found = ip._ofind('a.foo', [('locals', locals())])
433 nt.assert_is(found['obj'], A.foo)
433 nt.assert_is(found['obj'], A.foo)
434
434
435 def test_custom_exception(self):
435 def test_custom_exception(self):
436 called = []
436 called = []
437 def my_handler(shell, etype, value, tb, tb_offset=None):
437 def my_handler(shell, etype, value, tb, tb_offset=None):
438 called.append(etype)
438 called.append(etype)
439 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
439 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
440
440
441 ip.set_custom_exc((ValueError,), my_handler)
441 ip.set_custom_exc((ValueError,), my_handler)
442 try:
442 try:
443 ip.run_cell("raise ValueError('test')")
443 ip.run_cell("raise ValueError('test')")
444 # Check that this was called, and only once.
444 # Check that this was called, and only once.
445 self.assertEqual(called, [ValueError])
445 self.assertEqual(called, [ValueError])
446 finally:
446 finally:
447 # Reset the custom exception hook
447 # Reset the custom exception hook
448 ip.set_custom_exc((), None)
448 ip.set_custom_exc((), None)
449
449
450 @skipif(sys.version_info[0] >= 3, "no differences with __future__ in py3")
450 @skipif(sys.version_info[0] >= 3, "no differences with __future__ in py3")
451 def test_future_environment(self):
451 def test_future_environment(self):
452 "Can we run code with & without the shell's __future__ imports?"
452 "Can we run code with & without the shell's __future__ imports?"
453 ip.run_cell("from __future__ import division")
453 ip.run_cell("from __future__ import division")
454 ip.run_cell("a = 1/2", shell_futures=True)
454 ip.run_cell("a = 1/2", shell_futures=True)
455 self.assertEqual(ip.user_ns['a'], 0.5)
455 self.assertEqual(ip.user_ns['a'], 0.5)
456 ip.run_cell("b = 1/2", shell_futures=False)
456 ip.run_cell("b = 1/2", shell_futures=False)
457 self.assertEqual(ip.user_ns['b'], 0)
457 self.assertEqual(ip.user_ns['b'], 0)
458
458
459 ip.compile.reset_compiler_flags()
459 ip.compile.reset_compiler_flags()
460 # This shouldn't leak to the shell's compiler
460 # This shouldn't leak to the shell's compiler
461 ip.run_cell("from __future__ import division \nc=1/2", shell_futures=False)
461 ip.run_cell("from __future__ import division \nc=1/2", shell_futures=False)
462 self.assertEqual(ip.user_ns['c'], 0.5)
462 self.assertEqual(ip.user_ns['c'], 0.5)
463 ip.run_cell("d = 1/2", shell_futures=True)
463 ip.run_cell("d = 1/2", shell_futures=True)
464 self.assertEqual(ip.user_ns['d'], 0)
464 self.assertEqual(ip.user_ns['d'], 0)
465
465
466 def test_mktempfile(self):
466 def test_mktempfile(self):
467 filename = ip.mktempfile()
467 filename = ip.mktempfile()
468 # Check that we can open the file again on Windows
468 # Check that we can open the file again on Windows
469 with open(filename, 'w') as f:
469 with open(filename, 'w') as f:
470 f.write('abc')
470 f.write('abc')
471
471
472 filename = ip.mktempfile(data='blah')
472 filename = ip.mktempfile(data='blah')
473 with open(filename, 'r') as f:
473 with open(filename, 'r') as f:
474 self.assertEqual(f.read(), 'blah')
474 self.assertEqual(f.read(), 'blah')
475
475
476 def test_new_main_mod(self):
476 def test_new_main_mod(self):
477 # Smoketest to check that this accepts a unicode module name
477 # Smoketest to check that this accepts a unicode module name
478 name = u'jiefmw'
478 name = u'jiefmw'
479 mod = ip.new_main_mod(u'%s.py' % name, name)
479 mod = ip.new_main_mod(u'%s.py' % name, name)
480 self.assertEqual(mod.__name__, name)
480 self.assertEqual(mod.__name__, name)
481
481
482 class TestSafeExecfileNonAsciiPath(unittest.TestCase):
482 class TestSafeExecfileNonAsciiPath(unittest.TestCase):
483
483
484 @onlyif_unicode_paths
484 @onlyif_unicode_paths
485 def setUp(self):
485 def setUp(self):
486 self.BASETESTDIR = tempfile.mkdtemp()
486 self.BASETESTDIR = tempfile.mkdtemp()
487 self.TESTDIR = join(self.BASETESTDIR, u"Γ₯Àâ")
487 self.TESTDIR = join(self.BASETESTDIR, u"Γ₯Àâ")
488 os.mkdir(self.TESTDIR)
488 os.mkdir(self.TESTDIR)
489 with open(join(self.TESTDIR, u"Γ₯Àâtestscript.py"), "w") as sfile:
489 with open(join(self.TESTDIR, u"Γ₯Àâtestscript.py"), "w") as sfile:
490 sfile.write("pass\n")
490 sfile.write("pass\n")
491 self.oldpath = py3compat.getcwd()
491 self.oldpath = py3compat.getcwd()
492 os.chdir(self.TESTDIR)
492 os.chdir(self.TESTDIR)
493 self.fname = u"Γ₯Àâtestscript.py"
493 self.fname = u"Γ₯Àâtestscript.py"
494
494
495 def tearDown(self):
495 def tearDown(self):
496 os.chdir(self.oldpath)
496 os.chdir(self.oldpath)
497 shutil.rmtree(self.BASETESTDIR)
497 shutil.rmtree(self.BASETESTDIR)
498
498
499 @onlyif_unicode_paths
499 @onlyif_unicode_paths
500 def test_1(self):
500 def test_1(self):
501 """Test safe_execfile with non-ascii path
501 """Test safe_execfile with non-ascii path
502 """
502 """
503 ip.safe_execfile(self.fname, {}, raise_exceptions=True)
503 ip.safe_execfile(self.fname, {}, raise_exceptions=True)
504
504
505 class ExitCodeChecks(tt.TempFileMixin):
505 class ExitCodeChecks(tt.TempFileMixin):
506 def test_exit_code_ok(self):
506 def test_exit_code_ok(self):
507 self.system('exit 0')
507 self.system('exit 0')
508 self.assertEqual(ip.user_ns['_exit_code'], 0)
508 self.assertEqual(ip.user_ns['_exit_code'], 0)
509
509
510 def test_exit_code_error(self):
510 def test_exit_code_error(self):
511 self.system('exit 1')
511 self.system('exit 1')
512 self.assertEqual(ip.user_ns['_exit_code'], 1)
512 self.assertEqual(ip.user_ns['_exit_code'], 1)
513
513
514 @skipif(not hasattr(signal, 'SIGALRM'))
514 @skipif(not hasattr(signal, 'SIGALRM'))
515 def test_exit_code_signal(self):
515 def test_exit_code_signal(self):
516 self.mktmp("import signal, time\n"
516 self.mktmp("import signal, time\n"
517 "signal.setitimer(signal.ITIMER_REAL, 0.1)\n"
517 "signal.setitimer(signal.ITIMER_REAL, 0.1)\n"
518 "time.sleep(1)\n")
518 "time.sleep(1)\n")
519 self.system("%s %s" % (sys.executable, self.fname))
519 self.system("%s %s" % (sys.executable, self.fname))
520 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGALRM)
520 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGALRM)
521
521
522 @onlyif_cmds_exist("csh")
522 @onlyif_cmds_exist("csh")
523 def test_exit_code_signal_csh(self):
523 def test_exit_code_signal_csh(self):
524 SHELL = os.environ.get('SHELL', None)
524 SHELL = os.environ.get('SHELL', None)
525 os.environ['SHELL'] = find_cmd("csh")
525 os.environ['SHELL'] = find_cmd("csh")
526 try:
526 try:
527 self.test_exit_code_signal()
527 self.test_exit_code_signal()
528 finally:
528 finally:
529 if SHELL is not None:
529 if SHELL is not None:
530 os.environ['SHELL'] = SHELL
530 os.environ['SHELL'] = SHELL
531 else:
531 else:
532 del os.environ['SHELL']
532 del os.environ['SHELL']
533
533
534 class TestSystemRaw(unittest.TestCase, ExitCodeChecks):
534 class TestSystemRaw(unittest.TestCase, ExitCodeChecks):
535 system = ip.system_raw
535 system = ip.system_raw
536
536
537 @onlyif_unicode_paths
537 @onlyif_unicode_paths
538 def test_1(self):
538 def test_1(self):
539 """Test system_raw with non-ascii cmd
539 """Test system_raw with non-ascii cmd
540 """
540 """
541 cmd = u'''python -c "'Γ₯Àâ'" '''
541 cmd = u'''python -c "'Γ₯Àâ'" '''
542 ip.system_raw(cmd)
542 ip.system_raw(cmd)
543
543
544 # TODO: Exit codes are currently ignored on Windows.
544 # TODO: Exit codes are currently ignored on Windows.
545 class TestSystemPipedExitCode(unittest.TestCase, ExitCodeChecks):
545 class TestSystemPipedExitCode(unittest.TestCase, ExitCodeChecks):
546 system = ip.system_piped
546 system = ip.system_piped
547
547
548 @skip_win32
548 @skip_win32
549 def test_exit_code_ok(self):
549 def test_exit_code_ok(self):
550 ExitCodeChecks.test_exit_code_ok(self)
550 ExitCodeChecks.test_exit_code_ok(self)
551
551
552 @skip_win32
552 @skip_win32
553 def test_exit_code_error(self):
553 def test_exit_code_error(self):
554 ExitCodeChecks.test_exit_code_error(self)
554 ExitCodeChecks.test_exit_code_error(self)
555
555
556 @skip_win32
556 @skip_win32
557 def test_exit_code_signal(self):
557 def test_exit_code_signal(self):
558 ExitCodeChecks.test_exit_code_signal(self)
558 ExitCodeChecks.test_exit_code_signal(self)
559
559
560 class TestModules(unittest.TestCase, tt.TempFileMixin):
560 class TestModules(unittest.TestCase, tt.TempFileMixin):
561 def test_extraneous_loads(self):
561 def test_extraneous_loads(self):
562 """Test we're not loading modules on startup that we shouldn't.
562 """Test we're not loading modules on startup that we shouldn't.
563 """
563 """
564 self.mktmp("import sys\n"
564 self.mktmp("import sys\n"
565 "print('numpy' in sys.modules)\n"
565 "print('numpy' in sys.modules)\n"
566 "print('IPython.parallel' in sys.modules)\n"
566 "print('IPython.parallel' in sys.modules)\n"
567 "print('IPython.kernel.zmq' in sys.modules)\n"
567 "print('IPython.kernel.zmq' in sys.modules)\n"
568 )
568 )
569 out = "False\nFalse\nFalse\n"
569 out = "False\nFalse\nFalse\n"
570 tt.ipexec_validate(self.fname, out)
570 tt.ipexec_validate(self.fname, out)
571
571
572 class Negator(ast.NodeTransformer):
572 class Negator(ast.NodeTransformer):
573 """Negates all number literals in an AST."""
573 """Negates all number literals in an AST."""
574 def visit_Num(self, node):
574 def visit_Num(self, node):
575 node.n = -node.n
575 node.n = -node.n
576 return node
576 return node
577
577
578 class TestAstTransform(unittest.TestCase):
578 class TestAstTransform(unittest.TestCase):
579 def setUp(self):
579 def setUp(self):
580 self.negator = Negator()
580 self.negator = Negator()
581 ip.ast_transformers.append(self.negator)
581 ip.ast_transformers.append(self.negator)
582
582
583 def tearDown(self):
583 def tearDown(self):
584 ip.ast_transformers.remove(self.negator)
584 ip.ast_transformers.remove(self.negator)
585
585
586 def test_run_cell(self):
586 def test_run_cell(self):
587 with tt.AssertPrints('-34'):
587 with tt.AssertPrints('-34'):
588 ip.run_cell('print (12 + 22)')
588 ip.run_cell('print (12 + 22)')
589
589
590 # A named reference to a number shouldn't be transformed.
590 # A named reference to a number shouldn't be transformed.
591 ip.user_ns['n'] = 55
591 ip.user_ns['n'] = 55
592 with tt.AssertNotPrints('-55'):
592 with tt.AssertNotPrints('-55'):
593 ip.run_cell('print (n)')
593 ip.run_cell('print (n)')
594
594
595 def test_timeit(self):
595 def test_timeit(self):
596 called = set()
596 called = set()
597 def f(x):
597 def f(x):
598 called.add(x)
598 called.add(x)
599 ip.push({'f':f})
599 ip.push({'f':f})
600
600
601 with tt.AssertPrints("best of "):
601 with tt.AssertPrints("best of "):
602 ip.run_line_magic("timeit", "-n1 f(1)")
602 ip.run_line_magic("timeit", "-n1 f(1)")
603 self.assertEqual(called, set([-1]))
603 self.assertEqual(called, set([-1]))
604 called.clear()
604 called.clear()
605
605
606 with tt.AssertPrints("best of "):
606 with tt.AssertPrints("best of "):
607 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
607 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
608 self.assertEqual(called, set([-2, -3]))
608 self.assertEqual(called, set([-2, -3]))
609
609
610 def test_time(self):
610 def test_time(self):
611 called = []
611 called = []
612 def f(x):
612 def f(x):
613 called.append(x)
613 called.append(x)
614 ip.push({'f':f})
614 ip.push({'f':f})
615
615
616 # Test with an expression
616 # Test with an expression
617 with tt.AssertPrints("Wall time: "):
617 with tt.AssertPrints("Wall time: "):
618 ip.run_line_magic("time", "f(5+9)")
618 ip.run_line_magic("time", "f(5+9)")
619 self.assertEqual(called, [-14])
619 self.assertEqual(called, [-14])
620 called[:] = []
620 called[:] = []
621
621
622 # Test with a statement (different code path)
622 # Test with a statement (different code path)
623 with tt.AssertPrints("Wall time: "):
623 with tt.AssertPrints("Wall time: "):
624 ip.run_line_magic("time", "a = f(-3 + -2)")
624 ip.run_line_magic("time", "a = f(-3 + -2)")
625 self.assertEqual(called, [5])
625 self.assertEqual(called, [5])
626
626
627 def test_macro(self):
627 def test_macro(self):
628 ip.push({'a':10})
628 ip.push({'a':10})
629 # The AST transformation makes this do a+=-1
629 # The AST transformation makes this do a+=-1
630 ip.define_macro("amacro", "a+=1\nprint(a)")
630 ip.define_macro("amacro", "a+=1\nprint(a)")
631
631
632 with tt.AssertPrints("9"):
632 with tt.AssertPrints("9"):
633 ip.run_cell("amacro")
633 ip.run_cell("amacro")
634 with tt.AssertPrints("8"):
634 with tt.AssertPrints("8"):
635 ip.run_cell("amacro")
635 ip.run_cell("amacro")
636
636
637 class IntegerWrapper(ast.NodeTransformer):
637 class IntegerWrapper(ast.NodeTransformer):
638 """Wraps all integers in a call to Integer()"""
638 """Wraps all integers in a call to Integer()"""
639 def visit_Num(self, node):
639 def visit_Num(self, node):
640 if isinstance(node.n, int):
640 if isinstance(node.n, int):
641 return ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()),
641 return ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()),
642 args=[node], keywords=[])
642 args=[node], keywords=[])
643 return node
643 return node
644
644
645 class TestAstTransform2(unittest.TestCase):
645 class TestAstTransform2(unittest.TestCase):
646 def setUp(self):
646 def setUp(self):
647 self.intwrapper = IntegerWrapper()
647 self.intwrapper = IntegerWrapper()
648 ip.ast_transformers.append(self.intwrapper)
648 ip.ast_transformers.append(self.intwrapper)
649
649
650 self.calls = []
650 self.calls = []
651 def Integer(*args):
651 def Integer(*args):
652 self.calls.append(args)
652 self.calls.append(args)
653 return args
653 return args
654 ip.push({"Integer": Integer})
654 ip.push({"Integer": Integer})
655
655
656 def tearDown(self):
656 def tearDown(self):
657 ip.ast_transformers.remove(self.intwrapper)
657 ip.ast_transformers.remove(self.intwrapper)
658 del ip.user_ns['Integer']
658 del ip.user_ns['Integer']
659
659
660 def test_run_cell(self):
660 def test_run_cell(self):
661 ip.run_cell("n = 2")
661 ip.run_cell("n = 2")
662 self.assertEqual(self.calls, [(2,)])
662 self.assertEqual(self.calls, [(2,)])
663
663
664 # This shouldn't throw an error
664 # This shouldn't throw an error
665 ip.run_cell("o = 2.0")
665 ip.run_cell("o = 2.0")
666 self.assertEqual(ip.user_ns['o'], 2.0)
666 self.assertEqual(ip.user_ns['o'], 2.0)
667
667
668 def test_timeit(self):
668 def test_timeit(self):
669 called = set()
669 called = set()
670 def f(x):
670 def f(x):
671 called.add(x)
671 called.add(x)
672 ip.push({'f':f})
672 ip.push({'f':f})
673
673
674 with tt.AssertPrints("best of "):
674 with tt.AssertPrints("best of "):
675 ip.run_line_magic("timeit", "-n1 f(1)")
675 ip.run_line_magic("timeit", "-n1 f(1)")
676 self.assertEqual(called, set([(1,)]))
676 self.assertEqual(called, set([(1,)]))
677 called.clear()
677 called.clear()
678
678
679 with tt.AssertPrints("best of "):
679 with tt.AssertPrints("best of "):
680 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
680 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
681 self.assertEqual(called, set([(2,), (3,)]))
681 self.assertEqual(called, set([(2,), (3,)]))
682
682
683 class ErrorTransformer(ast.NodeTransformer):
683 class ErrorTransformer(ast.NodeTransformer):
684 """Throws an error when it sees a number."""
684 """Throws an error when it sees a number."""
685 def visit_Num(self, node):
685 def visit_Num(self, node):
686 raise ValueError("test")
686 raise ValueError("test")
687
687
688 class TestAstTransformError(unittest.TestCase):
688 class TestAstTransformError(unittest.TestCase):
689 def test_unregistering(self):
689 def test_unregistering(self):
690 err_transformer = ErrorTransformer()
690 err_transformer = ErrorTransformer()
691 ip.ast_transformers.append(err_transformer)
691 ip.ast_transformers.append(err_transformer)
692
692
693 with tt.AssertPrints("unregister", channel='stderr'):
693 with tt.AssertPrints("unregister", channel='stderr'):
694 ip.run_cell("1 + 2")
694 ip.run_cell("1 + 2")
695
695
696 # This should have been removed.
696 # This should have been removed.
697 nt.assert_not_in(err_transformer, ip.ast_transformers)
697 nt.assert_not_in(err_transformer, ip.ast_transformers)
698
698
699
699
700 class StringRejector(ast.NodeTransformer):
700 class StringRejector(ast.NodeTransformer):
701 """Throws an InputRejected when it sees a string literal.
701 """Throws an InputRejected when it sees a string literal.
702
702
703 Used to verify that NodeTransformers can signal that a piece of code should
703 Used to verify that NodeTransformers can signal that a piece of code should
704 not be executed by throwing an InputRejected.
704 not be executed by throwing an InputRejected.
705 """
705 """
706
706
707 def visit_Str(self, node):
707 def visit_Str(self, node):
708 raise InputRejected("test")
708 raise InputRejected("test")
709
709
710
710
711 class TestAstTransformInputRejection(unittest.TestCase):
711 class TestAstTransformInputRejection(unittest.TestCase):
712
712
713 def setUp(self):
713 def setUp(self):
714 self.transformer = StringRejector()
714 self.transformer = StringRejector()
715 ip.ast_transformers.append(self.transformer)
715 ip.ast_transformers.append(self.transformer)
716
716
717 def tearDown(self):
717 def tearDown(self):
718 ip.ast_transformers.remove(self.transformer)
718 ip.ast_transformers.remove(self.transformer)
719
719
720 def test_input_rejection(self):
720 def test_input_rejection(self):
721 """Check that NodeTransformers can reject input."""
721 """Check that NodeTransformers can reject input."""
722
722
723 expect_exception_tb = tt.AssertPrints("InputRejected: test")
723 expect_exception_tb = tt.AssertPrints("InputRejected: test")
724 expect_no_cell_output = tt.AssertNotPrints("'unsafe'", suppress=False)
724 expect_no_cell_output = tt.AssertNotPrints("'unsafe'", suppress=False)
725
725
726 # Run the same check twice to verify that the transformer is not
726 # Run the same check twice to verify that the transformer is not
727 # disabled after raising.
727 # disabled after raising.
728 with expect_exception_tb, expect_no_cell_output:
728 with expect_exception_tb, expect_no_cell_output:
729 ip.run_cell("'unsafe'")
729 ip.run_cell("'unsafe'")
730
730
731 with expect_exception_tb, expect_no_cell_output:
731 with expect_exception_tb, expect_no_cell_output:
732 ip.run_cell("'unsafe'")
732 ip.run_cell("'unsafe'")
733
733
734 def test__IPYTHON__():
734 def test__IPYTHON__():
735 # This shouldn't raise a NameError, that's all
735 # This shouldn't raise a NameError, that's all
736 __IPYTHON__
736 __IPYTHON__
737
737
738
738
739 class DummyRepr(object):
739 class DummyRepr(object):
740 def __repr__(self):
740 def __repr__(self):
741 return "DummyRepr"
741 return "DummyRepr"
742
742
743 def _repr_html_(self):
743 def _repr_html_(self):
744 return "<b>dummy</b>"
744 return "<b>dummy</b>"
745
745
746 def _repr_javascript_(self):
746 def _repr_javascript_(self):
747 return "console.log('hi');", {'key': 'value'}
747 return "console.log('hi');", {'key': 'value'}
748
748
749
749
750 def test_user_variables():
750 def test_user_variables():
751 # enable all formatters
751 # enable all formatters
752 ip.display_formatter.active_types = ip.display_formatter.format_types
752 ip.display_formatter.active_types = ip.display_formatter.format_types
753
753
754 ip.user_ns['dummy'] = d = DummyRepr()
754 ip.user_ns['dummy'] = d = DummyRepr()
755 keys = set(['dummy', 'doesnotexist'])
755 keys = set(['dummy', 'doesnotexist'])
756 r = ip.user_expressions({ key:key for key in keys})
756 r = ip.user_expressions({ key:key for key in keys})
757
757
758 nt.assert_equal(keys, set(r.keys()))
758 nt.assert_equal(keys, set(r.keys()))
759 dummy = r['dummy']
759 dummy = r['dummy']
760 nt.assert_equal(set(['status', 'data', 'metadata']), set(dummy.keys()))
760 nt.assert_equal(set(['status', 'data', 'metadata']), set(dummy.keys()))
761 nt.assert_equal(dummy['status'], 'ok')
761 nt.assert_equal(dummy['status'], 'ok')
762 data = dummy['data']
762 data = dummy['data']
763 metadata = dummy['metadata']
763 metadata = dummy['metadata']
764 nt.assert_equal(data.get('text/html'), d._repr_html_())
764 nt.assert_equal(data.get('text/html'), d._repr_html_())
765 js, jsmd = d._repr_javascript_()
765 js, jsmd = d._repr_javascript_()
766 nt.assert_equal(data.get('application/javascript'), js)
766 nt.assert_equal(data.get('application/javascript'), js)
767 nt.assert_equal(metadata.get('application/javascript'), jsmd)
767 nt.assert_equal(metadata.get('application/javascript'), jsmd)
768
768
769 dne = r['doesnotexist']
769 dne = r['doesnotexist']
770 nt.assert_equal(dne['status'], 'error')
770 nt.assert_equal(dne['status'], 'error')
771 nt.assert_equal(dne['ename'], 'NameError')
771 nt.assert_equal(dne['ename'], 'NameError')
772
772
773 # back to text only
773 # back to text only
774 ip.display_formatter.active_types = ['text/plain']
774 ip.display_formatter.active_types = ['text/plain']
775
775
776 def test_user_expression():
776 def test_user_expression():
777 # enable all formatters
777 # enable all formatters
778 ip.display_formatter.active_types = ip.display_formatter.format_types
778 ip.display_formatter.active_types = ip.display_formatter.format_types
779 query = {
779 query = {
780 'a' : '1 + 2',
780 'a' : '1 + 2',
781 'b' : '1/0',
781 'b' : '1/0',
782 }
782 }
783 r = ip.user_expressions(query)
783 r = ip.user_expressions(query)
784 import pprint
784 import pprint
785 pprint.pprint(r)
785 pprint.pprint(r)
786 nt.assert_equal(set(r.keys()), set(query.keys()))
786 nt.assert_equal(set(r.keys()), set(query.keys()))
787 a = r['a']
787 a = r['a']
788 nt.assert_equal(set(['status', 'data', 'metadata']), set(a.keys()))
788 nt.assert_equal(set(['status', 'data', 'metadata']), set(a.keys()))
789 nt.assert_equal(a['status'], 'ok')
789 nt.assert_equal(a['status'], 'ok')
790 data = a['data']
790 data = a['data']
791 metadata = a['metadata']
791 metadata = a['metadata']
792 nt.assert_equal(data.get('text/plain'), '3')
792 nt.assert_equal(data.get('text/plain'), '3')
793
793
794 b = r['b']
794 b = r['b']
795 nt.assert_equal(b['status'], 'error')
795 nt.assert_equal(b['status'], 'error')
796 nt.assert_equal(b['ename'], 'ZeroDivisionError')
796 nt.assert_equal(b['ename'], 'ZeroDivisionError')
797
797
798 # back to text only
798 # back to text only
799 ip.display_formatter.active_types = ['text/plain']
799 ip.display_formatter.active_types = ['text/plain']
800
800
801
801
802
802
803
803
804
804
805 class TestSyntaxErrorTransformer(unittest.TestCase):
805 class TestSyntaxErrorTransformer(unittest.TestCase):
806 """Check that SyntaxError raised by an input transformer is handled by run_cell()"""
806 """Check that SyntaxError raised by an input transformer is handled by run_cell()"""
807
807
808 class SyntaxErrorTransformer(InputTransformer):
808 class SyntaxErrorTransformer(InputTransformer):
809
809
810 def push(self, line):
810 def push(self, line):
811 pos = line.find('syntaxerror')
811 pos = line.find('syntaxerror')
812 if pos >= 0:
812 if pos >= 0:
813 e = SyntaxError('input contains "syntaxerror"')
813 e = SyntaxError('input contains "syntaxerror"')
814 e.text = line
814 e.text = line
815 e.offset = pos + 1
815 e.offset = pos + 1
816 raise e
816 raise e
817 return line
817 return line
818
818
819 def reset(self):
819 def reset(self):
820 pass
820 pass
821
821
822 def setUp(self):
822 def setUp(self):
823 self.transformer = TestSyntaxErrorTransformer.SyntaxErrorTransformer()
823 self.transformer = TestSyntaxErrorTransformer.SyntaxErrorTransformer()
824 ip.input_splitter.python_line_transforms.append(self.transformer)
824 ip.input_splitter.python_line_transforms.append(self.transformer)
825 ip.input_transformer_manager.python_line_transforms.append(self.transformer)
825 ip.input_transformer_manager.python_line_transforms.append(self.transformer)
826
826
827 def tearDown(self):
827 def tearDown(self):
828 ip.input_splitter.python_line_transforms.remove(self.transformer)
828 ip.input_splitter.python_line_transforms.remove(self.transformer)
829 ip.input_transformer_manager.python_line_transforms.remove(self.transformer)
829 ip.input_transformer_manager.python_line_transforms.remove(self.transformer)
830
830
831 def test_syntaxerror_input_transformer(self):
831 def test_syntaxerror_input_transformer(self):
832 with tt.AssertPrints('1234'):
832 with tt.AssertPrints('1234'):
833 ip.run_cell('1234')
833 ip.run_cell('1234')
834 with tt.AssertPrints('SyntaxError: invalid syntax'):
834 with tt.AssertPrints('SyntaxError: invalid syntax'):
835 ip.run_cell('1 2 3') # plain python syntax error
835 ip.run_cell('1 2 3') # plain python syntax error
836 with tt.AssertPrints('SyntaxError: input contains "syntaxerror"'):
836 with tt.AssertPrints('SyntaxError: input contains "syntaxerror"'):
837 ip.run_cell('2345 # syntaxerror') # input transformer syntax error
837 ip.run_cell('2345 # syntaxerror') # input transformer syntax error
838 with tt.AssertPrints('3456'):
838 with tt.AssertPrints('3456'):
839 ip.run_cell('3456')
839 ip.run_cell('3456')
840
840
841
841
842
842
@@ -1,126 +1,116 b''
1 """
1 """
2 Password generation for the IPython notebook.
2 Password generation for the IPython notebook.
3 """
3 """
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Imports
5 # Imports
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Stdlib
7 # Stdlib
8 import getpass
8 import getpass
9 import hashlib
9 import hashlib
10 import random
10 import random
11
11
12 # Our own
12 # Our own
13 from IPython.core.error import UsageError
13 from IPython.core.error import UsageError
14 from IPython.testing.skipdoctest import skip_doctest
14 from IPython.testing.skipdoctest import skip_doctest
15 from IPython.utils.py3compat import cast_bytes, str_to_bytes
15 from IPython.utils.py3compat import cast_bytes, str_to_bytes
16
16
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18 # Globals
18 # Globals
19 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
20
20
21 # Length of the salt in nr of hex chars, which implies salt_len * 4
21 # Length of the salt in nr of hex chars, which implies salt_len * 4
22 # bits of randomness.
22 # bits of randomness.
23 salt_len = 12
23 salt_len = 12
24
24
25 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
26 # Functions
26 # Functions
27 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
28
28
29 @skip_doctest
29 @skip_doctest
30 def passwd(passphrase=None, algorithm='sha1'):
30 def passwd(passphrase=None, algorithm='sha1'):
31 """Generate hashed password and salt for use in notebook configuration.
31 """Generate hashed password and salt for use in notebook configuration.
32
32
33 In the notebook configuration, set `c.NotebookApp.password` to
33 In the notebook configuration, set `c.NotebookApp.password` to
34 the generated string.
34 the generated string.
35
35
36 Parameters
36 Parameters
37 ----------
37 ----------
38 passphrase : str
38 passphrase : str
39 Password to hash. If unspecified, the user is asked to input
39 Password to hash. If unspecified, the user is asked to input
40 and verify a password.
40 and verify a password.
41 algorithm : str
41 algorithm : str
42 Hashing algorithm to use (e.g, 'sha1' or any argument supported
42 Hashing algorithm to use (e.g, 'sha1' or any argument supported
43 by :func:`hashlib.new`).
43 by :func:`hashlib.new`).
44
44
45 Returns
45 Returns
46 -------
46 -------
47 hashed_passphrase : str
47 hashed_passphrase : str
48 Hashed password, in the format 'hash_algorithm:salt:passphrase_hash'.
48 Hashed password, in the format 'hash_algorithm:salt:passphrase_hash'.
49
49
50 Examples
50 Examples
51 --------
51 --------
52 >>> passwd('mypassword')
52 >>> passwd('mypassword')
53 'sha1:7cf3:b7d6da294ea9592a9480c8f52e63cd42cfb9dd12'
53 'sha1:7cf3:b7d6da294ea9592a9480c8f52e63cd42cfb9dd12'
54
54
55 """
55 """
56 if passphrase is None:
56 if passphrase is None:
57 for i in range(3):
57 for i in range(3):
58 p0 = getpass.getpass('Enter password: ')
58 p0 = getpass.getpass('Enter password: ')
59 p1 = getpass.getpass('Verify password: ')
59 p1 = getpass.getpass('Verify password: ')
60 if p0 == p1:
60 if p0 == p1:
61 passphrase = p0
61 passphrase = p0
62 break
62 break
63 else:
63 else:
64 print('Passwords do not match.')
64 print('Passwords do not match.')
65 else:
65 else:
66 raise UsageError('No matching passwords found. Giving up.')
66 raise UsageError('No matching passwords found. Giving up.')
67
67
68 h = hashlib.new(algorithm)
68 h = hashlib.new(algorithm)
69 salt = ('%0' + str(salt_len) + 'x') % random.getrandbits(4 * salt_len)
69 salt = ('%0' + str(salt_len) + 'x') % random.getrandbits(4 * salt_len)
70 h.update(cast_bytes(passphrase, 'utf-8') + str_to_bytes(salt, 'ascii'))
70 h.update(cast_bytes(passphrase, 'utf-8') + str_to_bytes(salt, 'ascii'))
71
71
72 return ':'.join((algorithm, salt, h.hexdigest()))
72 return ':'.join((algorithm, salt, h.hexdigest()))
73
73
74
74
75 def passwd_check(hashed_passphrase, passphrase):
75 def passwd_check(hashed_passphrase, passphrase):
76 """Verify that a given passphrase matches its hashed version.
76 """Verify that a given passphrase matches its hashed version.
77
77
78 Parameters
78 Parameters
79 ----------
79 ----------
80 hashed_passphrase : str
80 hashed_passphrase : str
81 Hashed password, in the format returned by `passwd`.
81 Hashed password, in the format returned by `passwd`.
82 passphrase : str
82 passphrase : str
83 Passphrase to validate.
83 Passphrase to validate.
84
84
85 Returns
85 Returns
86 -------
86 -------
87 valid : bool
87 valid : bool
88 True if the passphrase matches the hash.
88 True if the passphrase matches the hash.
89
89
90 Examples
90 Examples
91 --------
91 --------
92 >>> from IPython.lib.security import passwd_check
92 >>> from IPython.lib.security import passwd_check
93 >>> passwd_check('sha1:0e112c3ddfce:a68df677475c2b47b6e86d0467eec97ac5f4b85a',
93 >>> passwd_check('sha1:0e112c3ddfce:a68df677475c2b47b6e86d0467eec97ac5f4b85a',
94 ... 'mypassword')
94 ... 'mypassword')
95 True
95 True
96
96
97 >>> passwd_check('sha1:0e112c3ddfce:a68df677475c2b47b6e86d0467eec97ac5f4b85a',
97 >>> passwd_check('sha1:0e112c3ddfce:a68df677475c2b47b6e86d0467eec97ac5f4b85a',
98 ... 'anotherpassword')
98 ... 'anotherpassword')
99 False
99 False
100 """
100 """
101 try:
101 try:
102 algorithm, salt, pw_digest = hashed_passphrase.split(':', 2)
102 algorithm, salt, pw_digest = hashed_passphrase.split(':', 2)
103 except (ValueError, TypeError):
103 except (ValueError, TypeError):
104 return False
104 return False
105
105
106 try:
106 try:
107 h = hashlib.new(algorithm)
107 h = hashlib.new(algorithm)
108 except ValueError:
108 except ValueError:
109 return False
109 return False
110
110
111 if len(pw_digest) == 0:
111 if len(pw_digest) == 0:
112 return False
112 return False
113
113
114 h.update(cast_bytes(passphrase, 'utf-8') + cast_bytes(salt, 'ascii'))
114 h.update(cast_bytes(passphrase, 'utf-8') + cast_bytes(salt, 'ascii'))
115
115
116 return h.hexdigest() == pw_digest
116 return h.hexdigest() == pw_digest
117
118 #-----------------------------------------------------------------------------
119 # Exception classes
120 #-----------------------------------------------------------------------------
121 class InputRejected(Exception):
122 """Input rejected by ast transformer.
123
124 To be raised by user-supplied ast.NodeTransformers when an input should not
125 be executed.
126 """
General Comments 0
You need to be logged in to leave comments. Login now