##// END OF EJS Templates
Update test for new transformation API
Thomas Kluyver -
Show More
@@ -1,176 +1,170
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Tests for the TerminalInteractiveShell and related pieces."""
2 """Tests for the TerminalInteractiveShell and related pieces."""
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 import sys
6 import sys
7 import unittest
7 import unittest
8
8
9 from IPython.core.inputtransformer import InputTransformer
9 from IPython.core.inputtransformer import InputTransformer
10 from IPython.testing import tools as tt
10 from IPython.testing import tools as tt
11 from IPython.utils.capture import capture_output
11 from IPython.utils.capture import capture_output
12
12
13 from IPython.terminal.ptutils import _elide, _adjust_completion_text_based_on_context
13 from IPython.terminal.ptutils import _elide, _adjust_completion_text_based_on_context
14 import nose.tools as nt
14 import nose.tools as nt
15
15
16 class TestElide(unittest.TestCase):
16 class TestElide(unittest.TestCase):
17
17
18 def test_elide(self):
18 def test_elide(self):
19 _elide('concatenate((a1, a2, ...), axis') # do not raise
19 _elide('concatenate((a1, a2, ...), axis') # do not raise
20 _elide('concatenate((a1, a2, ..), . axis') # do not raise
20 _elide('concatenate((a1, a2, ..), . axis') # do not raise
21 nt.assert_equal(_elide('aaaa.bbbb.ccccc.dddddd.eeeee.fffff.gggggg.hhhhhh'), 'aaaa.b…g.hhhhhh')
21 nt.assert_equal(_elide('aaaa.bbbb.ccccc.dddddd.eeeee.fffff.gggggg.hhhhhh'), 'aaaa.b…g.hhhhhh')
22
22
23
23
24 class TestContextAwareCompletion(unittest.TestCase):
24 class TestContextAwareCompletion(unittest.TestCase):
25
25
26 def test_adjust_completion_text_based_on_context(self):
26 def test_adjust_completion_text_based_on_context(self):
27 # Adjusted case
27 # Adjusted case
28 nt.assert_equal(_adjust_completion_text_based_on_context('arg1=', 'func1(a=)', 7), 'arg1')
28 nt.assert_equal(_adjust_completion_text_based_on_context('arg1=', 'func1(a=)', 7), 'arg1')
29
29
30 # Untouched cases
30 # Untouched cases
31 nt.assert_equal(_adjust_completion_text_based_on_context('arg1=', 'func1(a)', 7), 'arg1=')
31 nt.assert_equal(_adjust_completion_text_based_on_context('arg1=', 'func1(a)', 7), 'arg1=')
32 nt.assert_equal(_adjust_completion_text_based_on_context('arg1=', 'func1(a', 7), 'arg1=')
32 nt.assert_equal(_adjust_completion_text_based_on_context('arg1=', 'func1(a', 7), 'arg1=')
33 nt.assert_equal(_adjust_completion_text_based_on_context('%magic', 'func1(a=)', 7), '%magic')
33 nt.assert_equal(_adjust_completion_text_based_on_context('%magic', 'func1(a=)', 7), '%magic')
34 nt.assert_equal(_adjust_completion_text_based_on_context('func2', 'func1(a=)', 7), 'func2')
34 nt.assert_equal(_adjust_completion_text_based_on_context('func2', 'func1(a=)', 7), 'func2')
35
35
36 # Decorator for interaction loop tests -----------------------------------------
36 # Decorator for interaction loop tests -----------------------------------------
37
37
38 class mock_input_helper(object):
38 class mock_input_helper(object):
39 """Machinery for tests of the main interact loop.
39 """Machinery for tests of the main interact loop.
40
40
41 Used by the mock_input decorator.
41 Used by the mock_input decorator.
42 """
42 """
43 def __init__(self, testgen):
43 def __init__(self, testgen):
44 self.testgen = testgen
44 self.testgen = testgen
45 self.exception = None
45 self.exception = None
46 self.ip = get_ipython()
46 self.ip = get_ipython()
47
47
48 def __enter__(self):
48 def __enter__(self):
49 self.orig_prompt_for_code = self.ip.prompt_for_code
49 self.orig_prompt_for_code = self.ip.prompt_for_code
50 self.ip.prompt_for_code = self.fake_input
50 self.ip.prompt_for_code = self.fake_input
51 return self
51 return self
52
52
53 def __exit__(self, etype, value, tb):
53 def __exit__(self, etype, value, tb):
54 self.ip.prompt_for_code = self.orig_prompt_for_code
54 self.ip.prompt_for_code = self.orig_prompt_for_code
55
55
56 def fake_input(self):
56 def fake_input(self):
57 try:
57 try:
58 return next(self.testgen)
58 return next(self.testgen)
59 except StopIteration:
59 except StopIteration:
60 self.ip.keep_running = False
60 self.ip.keep_running = False
61 return u''
61 return u''
62 except:
62 except:
63 self.exception = sys.exc_info()
63 self.exception = sys.exc_info()
64 self.ip.keep_running = False
64 self.ip.keep_running = False
65 return u''
65 return u''
66
66
67 def mock_input(testfunc):
67 def mock_input(testfunc):
68 """Decorator for tests of the main interact loop.
68 """Decorator for tests of the main interact loop.
69
69
70 Write the test as a generator, yield-ing the input strings, which IPython
70 Write the test as a generator, yield-ing the input strings, which IPython
71 will see as if they were typed in at the prompt.
71 will see as if they were typed in at the prompt.
72 """
72 """
73 def test_method(self):
73 def test_method(self):
74 testgen = testfunc(self)
74 testgen = testfunc(self)
75 with mock_input_helper(testgen) as mih:
75 with mock_input_helper(testgen) as mih:
76 mih.ip.interact()
76 mih.ip.interact()
77
77
78 if mih.exception is not None:
78 if mih.exception is not None:
79 # Re-raise captured exception
79 # Re-raise captured exception
80 etype, value, tb = mih.exception
80 etype, value, tb = mih.exception
81 import traceback
81 import traceback
82 traceback.print_tb(tb, file=sys.stdout)
82 traceback.print_tb(tb, file=sys.stdout)
83 del tb # Avoid reference loop
83 del tb # Avoid reference loop
84 raise value
84 raise value
85
85
86 return test_method
86 return test_method
87
87
88 # Test classes -----------------------------------------------------------------
88 # Test classes -----------------------------------------------------------------
89
89
90 class InteractiveShellTestCase(unittest.TestCase):
90 class InteractiveShellTestCase(unittest.TestCase):
91 def rl_hist_entries(self, rl, n):
91 def rl_hist_entries(self, rl, n):
92 """Get last n readline history entries as a list"""
92 """Get last n readline history entries as a list"""
93 return [rl.get_history_item(rl.get_current_history_length() - x)
93 return [rl.get_history_item(rl.get_current_history_length() - x)
94 for x in range(n - 1, -1, -1)]
94 for x in range(n - 1, -1, -1)]
95
95
96 @mock_input
96 @mock_input
97 def test_inputtransformer_syntaxerror(self):
97 def test_inputtransformer_syntaxerror(self):
98 ip = get_ipython()
98 ip = get_ipython()
99 transformer = SyntaxErrorTransformer()
99 ip.input_transformer_manager.line_transforms.append(syntax_error_transformer)
100 ip.input_splitter.python_line_transforms.append(transformer)
101 ip.input_transformer_manager.python_line_transforms.append(transformer)
102
100
103 try:
101 try:
104 #raise Exception
102 #raise Exception
105 with tt.AssertPrints('4', suppress=False):
103 with tt.AssertPrints('4', suppress=False):
106 yield u'print(2*2)'
104 yield u'print(2*2)'
107
105
108 with tt.AssertPrints('SyntaxError: input contains', suppress=False):
106 with tt.AssertPrints('SyntaxError: input contains', suppress=False):
109 yield u'print(2345) # syntaxerror'
107 yield u'print(2345) # syntaxerror'
110
108
111 with tt.AssertPrints('16', suppress=False):
109 with tt.AssertPrints('16', suppress=False):
112 yield u'print(4*4)'
110 yield u'print(4*4)'
113
111
114 finally:
112 finally:
115 ip.input_splitter.python_line_transforms.remove(transformer)
113 ip.input_transformer_manager.line_transforms.remove(syntax_error_transformer)
116 ip.input_transformer_manager.python_line_transforms.remove(transformer)
117
114
118 def test_plain_text_only(self):
115 def test_plain_text_only(self):
119 ip = get_ipython()
116 ip = get_ipython()
120 formatter = ip.display_formatter
117 formatter = ip.display_formatter
121 assert formatter.active_types == ['text/plain']
118 assert formatter.active_types == ['text/plain']
122 assert not formatter.ipython_display_formatter.enabled
119 assert not formatter.ipython_display_formatter.enabled
123
120
124 class Test(object):
121 class Test(object):
125 def __repr__(self):
122 def __repr__(self):
126 return "<Test %i>" % id(self)
123 return "<Test %i>" % id(self)
127
124
128 def _repr_html_(self):
125 def _repr_html_(self):
129 return '<html>'
126 return '<html>'
130
127
131 # verify that HTML repr isn't computed
128 # verify that HTML repr isn't computed
132 obj = Test()
129 obj = Test()
133 data, _ = formatter.format(obj)
130 data, _ = formatter.format(obj)
134 self.assertEqual(data, {'text/plain': repr(obj)})
131 self.assertEqual(data, {'text/plain': repr(obj)})
135
132
136 class Test2(Test):
133 class Test2(Test):
137 def _ipython_display_(self):
134 def _ipython_display_(self):
138 from IPython.display import display
135 from IPython.display import display
139 display('<custom>')
136 display('<custom>')
140
137
141 # verify that _ipython_display_ shortcut isn't called
138 # verify that _ipython_display_ shortcut isn't called
142 obj = Test2()
139 obj = Test2()
143 with capture_output() as captured:
140 with capture_output() as captured:
144 data, _ = formatter.format(obj)
141 data, _ = formatter.format(obj)
145
142
146 self.assertEqual(data, {'text/plain': repr(obj)})
143 self.assertEqual(data, {'text/plain': repr(obj)})
147 assert captured.stdout == ''
144 assert captured.stdout == ''
148
145
149
146 def syntax_error_transformer(lines):
150
147 """Transformer that throws SyntaxError if 'syntaxerror' is in the code."""
151 class SyntaxErrorTransformer(InputTransformer):
148 for line in lines:
152 def push(self, line):
153 pos = line.find('syntaxerror')
149 pos = line.find('syntaxerror')
154 if pos >= 0:
150 if pos >= 0:
155 e = SyntaxError('input contains "syntaxerror"')
151 e = SyntaxError('input contains "syntaxerror"')
156 e.text = line
152 e.text = line
157 e.offset = pos + 1
153 e.offset = pos + 1
158 raise e
154 raise e
159 return line
155 return lines
160
156
161 def reset(self):
162 pass
163
157
164 class TerminalMagicsTestCase(unittest.TestCase):
158 class TerminalMagicsTestCase(unittest.TestCase):
165 def test_paste_magics_blankline(self):
159 def test_paste_magics_blankline(self):
166 """Test that code with a blank line doesn't get split (gh-3246)."""
160 """Test that code with a blank line doesn't get split (gh-3246)."""
167 ip = get_ipython()
161 ip = get_ipython()
168 s = ('def pasted_func(a):\n'
162 s = ('def pasted_func(a):\n'
169 ' b = a+1\n'
163 ' b = a+1\n'
170 '\n'
164 '\n'
171 ' return b')
165 ' return b')
172
166
173 tm = ip.magics_manager.registry['TerminalMagics']
167 tm = ip.magics_manager.registry['TerminalMagics']
174 tm.store_or_execute(s, name=None)
168 tm.store_or_execute(s, name=None)
175
169
176 self.assertEqual(ip.user_ns['pasted_func'](54), 55)
170 self.assertEqual(ip.user_ns['pasted_func'](54), 55)
General Comments 0
You need to be logged in to leave comments. Login now