##// END OF EJS Templates
Update test for new transformation API
Thomas Kluyver -
Show More
@@ -1,176 +1,170
1 1 # -*- coding: utf-8 -*-
2 2 """Tests for the TerminalInteractiveShell and related pieces."""
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 import sys
7 7 import unittest
8 8
9 9 from IPython.core.inputtransformer import InputTransformer
10 10 from IPython.testing import tools as tt
11 11 from IPython.utils.capture import capture_output
12 12
13 13 from IPython.terminal.ptutils import _elide, _adjust_completion_text_based_on_context
14 14 import nose.tools as nt
15 15
16 16 class TestElide(unittest.TestCase):
17 17
18 18 def test_elide(self):
19 19 _elide('concatenate((a1, a2, ...), axis') # do not raise
20 20 _elide('concatenate((a1, a2, ..), . axis') # do not raise
21 21 nt.assert_equal(_elide('aaaa.bbbb.ccccc.dddddd.eeeee.fffff.gggggg.hhhhhh'), 'aaaa.b…g.hhhhhh')
22 22
23 23
24 24 class TestContextAwareCompletion(unittest.TestCase):
25 25
26 26 def test_adjust_completion_text_based_on_context(self):
27 27 # Adjusted case
28 28 nt.assert_equal(_adjust_completion_text_based_on_context('arg1=', 'func1(a=)', 7), 'arg1')
29 29
30 30 # Untouched cases
31 31 nt.assert_equal(_adjust_completion_text_based_on_context('arg1=', 'func1(a)', 7), 'arg1=')
32 32 nt.assert_equal(_adjust_completion_text_based_on_context('arg1=', 'func1(a', 7), 'arg1=')
33 33 nt.assert_equal(_adjust_completion_text_based_on_context('%magic', 'func1(a=)', 7), '%magic')
34 34 nt.assert_equal(_adjust_completion_text_based_on_context('func2', 'func1(a=)', 7), 'func2')
35 35
36 36 # Decorator for interaction loop tests -----------------------------------------
37 37
38 38 class mock_input_helper(object):
39 39 """Machinery for tests of the main interact loop.
40 40
41 41 Used by the mock_input decorator.
42 42 """
43 43 def __init__(self, testgen):
44 44 self.testgen = testgen
45 45 self.exception = None
46 46 self.ip = get_ipython()
47 47
48 48 def __enter__(self):
49 49 self.orig_prompt_for_code = self.ip.prompt_for_code
50 50 self.ip.prompt_for_code = self.fake_input
51 51 return self
52 52
53 53 def __exit__(self, etype, value, tb):
54 54 self.ip.prompt_for_code = self.orig_prompt_for_code
55 55
56 56 def fake_input(self):
57 57 try:
58 58 return next(self.testgen)
59 59 except StopIteration:
60 60 self.ip.keep_running = False
61 61 return u''
62 62 except:
63 63 self.exception = sys.exc_info()
64 64 self.ip.keep_running = False
65 65 return u''
66 66
67 67 def mock_input(testfunc):
68 68 """Decorator for tests of the main interact loop.
69 69
70 70 Write the test as a generator, yield-ing the input strings, which IPython
71 71 will see as if they were typed in at the prompt.
72 72 """
73 73 def test_method(self):
74 74 testgen = testfunc(self)
75 75 with mock_input_helper(testgen) as mih:
76 76 mih.ip.interact()
77 77
78 78 if mih.exception is not None:
79 79 # Re-raise captured exception
80 80 etype, value, tb = mih.exception
81 81 import traceback
82 82 traceback.print_tb(tb, file=sys.stdout)
83 83 del tb # Avoid reference loop
84 84 raise value
85 85
86 86 return test_method
87 87
88 88 # Test classes -----------------------------------------------------------------
89 89
90 90 class InteractiveShellTestCase(unittest.TestCase):
91 91 def rl_hist_entries(self, rl, n):
92 92 """Get last n readline history entries as a list"""
93 93 return [rl.get_history_item(rl.get_current_history_length() - x)
94 94 for x in range(n - 1, -1, -1)]
95 95
96 96 @mock_input
97 97 def test_inputtransformer_syntaxerror(self):
98 98 ip = get_ipython()
99 transformer = SyntaxErrorTransformer()
100 ip.input_splitter.python_line_transforms.append(transformer)
101 ip.input_transformer_manager.python_line_transforms.append(transformer)
99 ip.input_transformer_manager.line_transforms.append(syntax_error_transformer)
102 100
103 101 try:
104 102 #raise Exception
105 103 with tt.AssertPrints('4', suppress=False):
106 104 yield u'print(2*2)'
107 105
108 106 with tt.AssertPrints('SyntaxError: input contains', suppress=False):
109 107 yield u'print(2345) # syntaxerror'
110 108
111 109 with tt.AssertPrints('16', suppress=False):
112 110 yield u'print(4*4)'
113 111
114 112 finally:
115 ip.input_splitter.python_line_transforms.remove(transformer)
116 ip.input_transformer_manager.python_line_transforms.remove(transformer)
113 ip.input_transformer_manager.line_transforms.remove(syntax_error_transformer)
117 114
118 115 def test_plain_text_only(self):
119 116 ip = get_ipython()
120 117 formatter = ip.display_formatter
121 118 assert formatter.active_types == ['text/plain']
122 119 assert not formatter.ipython_display_formatter.enabled
123 120
124 121 class Test(object):
125 122 def __repr__(self):
126 123 return "<Test %i>" % id(self)
127 124
128 125 def _repr_html_(self):
129 126 return '<html>'
130 127
131 128 # verify that HTML repr isn't computed
132 129 obj = Test()
133 130 data, _ = formatter.format(obj)
134 131 self.assertEqual(data, {'text/plain': repr(obj)})
135 132
136 133 class Test2(Test):
137 134 def _ipython_display_(self):
138 135 from IPython.display import display
139 136 display('<custom>')
140 137
141 138 # verify that _ipython_display_ shortcut isn't called
142 139 obj = Test2()
143 140 with capture_output() as captured:
144 141 data, _ = formatter.format(obj)
145 142
146 143 self.assertEqual(data, {'text/plain': repr(obj)})
147 144 assert captured.stdout == ''
148 145
149
150
151 class SyntaxErrorTransformer(InputTransformer):
152 def push(self, line):
146 def syntax_error_transformer(lines):
147 """Transformer that throws SyntaxError if 'syntaxerror' is in the code."""
148 for line in lines:
153 149 pos = line.find('syntaxerror')
154 150 if pos >= 0:
155 151 e = SyntaxError('input contains "syntaxerror"')
156 152 e.text = line
157 153 e.offset = pos + 1
158 154 raise e
159 return line
155 return lines
160 156
161 def reset(self):
162 pass
163 157
164 158 class TerminalMagicsTestCase(unittest.TestCase):
165 159 def test_paste_magics_blankline(self):
166 160 """Test that code with a blank line doesn't get split (gh-3246)."""
167 161 ip = get_ipython()
168 162 s = ('def pasted_func(a):\n'
169 163 ' b = a+1\n'
170 164 '\n'
171 165 ' return b')
172 166
173 167 tm = ip.magics_manager.registry['TerminalMagics']
174 168 tm.store_or_execute(s, name=None)
175 169
176 170 self.assertEqual(ip.user_ns['pasted_func'](54), 55)
General Comments 0
You need to be logged in to leave comments. Login now