##// END OF EJS Templates
Remove most uses of ParametricTestCase
Thomas Kluyver -
Show More
@@ -1,380 +1,379 b''
1 1 # -*- coding: utf-8 -*-
2 2 """Test Parallel magics
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import re
20 20 import sys
21 21 import time
22 22
23 23 import zmq
24 24 from nose import SkipTest
25 25
26 26 from IPython.testing import decorators as dec
27 from IPython.testing.ipunittest import ParametricTestCase
28 27 from IPython.utils.io import capture_output
29 28
30 29 from IPython import parallel as pmod
31 30 from IPython.parallel import error
32 31 from IPython.parallel import AsyncResult
33 32 from IPython.parallel.util import interactive
34 33
35 34 from IPython.parallel.tests import add_engines
36 35
37 36 from .clienttest import ClusterTestCase, generate_output
38 37
39 38 def setup():
40 39 add_engines(3, total=True)
41 40
42 class TestParallelMagics(ClusterTestCase, ParametricTestCase):
41 class TestParallelMagics(ClusterTestCase):
43 42
44 43 def test_px_blocking(self):
45 44 ip = get_ipython()
46 45 v = self.client[-1:]
47 46 v.activate()
48 47 v.block=True
49 48
50 49 ip.magic('px a=5')
51 50 self.assertEqual(v['a'], [5])
52 51 ip.magic('px a=10')
53 52 self.assertEqual(v['a'], [10])
54 53 # just 'print a' works ~99% of the time, but this ensures that
55 54 # the stdout message has arrived when the result is finished:
56 55 with capture_output() as io:
57 56 ip.magic(
58 57 'px import sys,time;print(a);sys.stdout.flush();time.sleep(0.2)'
59 58 )
60 59 self.assertIn('[stdout:', io.stdout)
61 60 self.assertNotIn('\n\n', io.stdout)
62 61 assert io.stdout.rstrip().endswith('10')
63 62 self.assertRaisesRemote(ZeroDivisionError, ip.magic, 'px 1/0')
64 63
65 64 def _check_generated_stderr(self, stderr, n):
66 65 expected = [
67 66 r'\[stderr:\d+\]',
68 67 '^stderr$',
69 68 '^stderr2$',
70 69 ] * n
71 70
72 71 self.assertNotIn('\n\n', stderr)
73 72 lines = stderr.splitlines()
74 73 self.assertEqual(len(lines), len(expected), stderr)
75 74 for line,expect in zip(lines, expected):
76 75 if isinstance(expect, str):
77 76 expect = [expect]
78 77 for ex in expect:
79 78 assert re.search(ex, line) is not None, "Expected %r in %r" % (ex, line)
80 79
81 80 def test_cellpx_block_args(self):
82 81 """%%px --[no]block flags work"""
83 82 ip = get_ipython()
84 83 v = self.client[-1:]
85 84 v.activate()
86 85 v.block=False
87 86
88 87 for block in (True, False):
89 88 v.block = block
90 89 ip.magic("pxconfig --verbose")
91 90 with capture_output(display=False) as io:
92 91 ip.run_cell_magic("px", "", "1")
93 92 if block:
94 93 assert io.stdout.startswith("Parallel"), io.stdout
95 94 else:
96 95 assert io.stdout.startswith("Async"), io.stdout
97 96
98 97 with capture_output(display=False) as io:
99 98 ip.run_cell_magic("px", "--block", "1")
100 99 assert io.stdout.startswith("Parallel"), io.stdout
101 100
102 101 with capture_output(display=False) as io:
103 102 ip.run_cell_magic("px", "--noblock", "1")
104 103 assert io.stdout.startswith("Async"), io.stdout
105 104
106 105 def test_cellpx_groupby_engine(self):
107 106 """%%px --group-outputs=engine"""
108 107 ip = get_ipython()
109 108 v = self.client[:]
110 109 v.block = True
111 110 v.activate()
112 111
113 112 v['generate_output'] = generate_output
114 113
115 114 with capture_output(display=False) as io:
116 115 ip.run_cell_magic('px', '--group-outputs=engine', 'generate_output()')
117 116
118 117 self.assertNotIn('\n\n', io.stdout)
119 118 lines = io.stdout.splitlines()
120 119 expected = [
121 120 r'\[stdout:\d+\]',
122 121 'stdout',
123 122 'stdout2',
124 123 r'\[output:\d+\]',
125 124 r'IPython\.core\.display\.HTML',
126 125 r'IPython\.core\.display\.Math',
127 126 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math',
128 127 ] * len(v)
129 128
130 129 self.assertEqual(len(lines), len(expected), io.stdout)
131 130 for line,expect in zip(lines, expected):
132 131 if isinstance(expect, str):
133 132 expect = [expect]
134 133 for ex in expect:
135 134 assert re.search(ex, line) is not None, "Expected %r in %r" % (ex, line)
136 135
137 136 self._check_generated_stderr(io.stderr, len(v))
138 137
139 138
140 139 def test_cellpx_groupby_order(self):
141 140 """%%px --group-outputs=order"""
142 141 ip = get_ipython()
143 142 v = self.client[:]
144 143 v.block = True
145 144 v.activate()
146 145
147 146 v['generate_output'] = generate_output
148 147
149 148 with capture_output(display=False) as io:
150 149 ip.run_cell_magic('px', '--group-outputs=order', 'generate_output()')
151 150
152 151 self.assertNotIn('\n\n', io.stdout)
153 152 lines = io.stdout.splitlines()
154 153 expected = []
155 154 expected.extend([
156 155 r'\[stdout:\d+\]',
157 156 'stdout',
158 157 'stdout2',
159 158 ] * len(v))
160 159 expected.extend([
161 160 r'\[output:\d+\]',
162 161 'IPython.core.display.HTML',
163 162 ] * len(v))
164 163 expected.extend([
165 164 r'\[output:\d+\]',
166 165 'IPython.core.display.Math',
167 166 ] * len(v))
168 167 expected.extend([
169 168 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math'
170 169 ] * len(v))
171 170
172 171 self.assertEqual(len(lines), len(expected), io.stdout)
173 172 for line,expect in zip(lines, expected):
174 173 if isinstance(expect, str):
175 174 expect = [expect]
176 175 for ex in expect:
177 176 assert re.search(ex, line) is not None, "Expected %r in %r" % (ex, line)
178 177
179 178 self._check_generated_stderr(io.stderr, len(v))
180 179
181 180 def test_cellpx_groupby_type(self):
182 181 """%%px --group-outputs=type"""
183 182 ip = get_ipython()
184 183 v = self.client[:]
185 184 v.block = True
186 185 v.activate()
187 186
188 187 v['generate_output'] = generate_output
189 188
190 189 with capture_output(display=False) as io:
191 190 ip.run_cell_magic('px', '--group-outputs=type', 'generate_output()')
192 191
193 192 self.assertNotIn('\n\n', io.stdout)
194 193 lines = io.stdout.splitlines()
195 194
196 195 expected = []
197 196 expected.extend([
198 197 r'\[stdout:\d+\]',
199 198 'stdout',
200 199 'stdout2',
201 200 ] * len(v))
202 201 expected.extend([
203 202 r'\[output:\d+\]',
204 203 r'IPython\.core\.display\.HTML',
205 204 r'IPython\.core\.display\.Math',
206 205 ] * len(v))
207 206 expected.extend([
208 207 (r'Out\[\d+:\d+\]', r'IPython\.core\.display\.Math')
209 208 ] * len(v))
210 209
211 210 self.assertEqual(len(lines), len(expected), io.stdout)
212 211 for line,expect in zip(lines, expected):
213 212 if isinstance(expect, str):
214 213 expect = [expect]
215 214 for ex in expect:
216 215 assert re.search(ex, line) is not None, "Expected %r in %r" % (ex, line)
217 216
218 217 self._check_generated_stderr(io.stderr, len(v))
219 218
220 219
221 220 def test_px_nonblocking(self):
222 221 ip = get_ipython()
223 222 v = self.client[-1:]
224 223 v.activate()
225 224 v.block=False
226 225
227 226 ip.magic('px a=5')
228 227 self.assertEqual(v['a'], [5])
229 228 ip.magic('px a=10')
230 229 self.assertEqual(v['a'], [10])
231 230 ip.magic('pxconfig --verbose')
232 231 with capture_output() as io:
233 232 ar = ip.magic('px print (a)')
234 233 self.assertIsInstance(ar, AsyncResult)
235 234 self.assertIn('Async', io.stdout)
236 235 self.assertNotIn('[stdout:', io.stdout)
237 236 self.assertNotIn('\n\n', io.stdout)
238 237
239 238 ar = ip.magic('px 1/0')
240 239 self.assertRaisesRemote(ZeroDivisionError, ar.get)
241 240
242 241 def test_autopx_blocking(self):
243 242 ip = get_ipython()
244 243 v = self.client[-1]
245 244 v.activate()
246 245 v.block=True
247 246
248 247 with capture_output(display=False) as io:
249 248 ip.magic('autopx')
250 249 ip.run_cell('\n'.join(('a=5','b=12345','c=0')))
251 250 ip.run_cell('b*=2')
252 251 ip.run_cell('print (b)')
253 252 ip.run_cell('b')
254 253 ip.run_cell("b/c")
255 254 ip.magic('autopx')
256 255
257 256 output = io.stdout
258 257
259 258 assert output.startswith('%autopx enabled'), output
260 259 assert output.rstrip().endswith('%autopx disabled'), output
261 260 self.assertIn('ZeroDivisionError', output)
262 261 self.assertIn('\nOut[', output)
263 262 self.assertIn(': 24690', output)
264 263 ar = v.get_result(-1)
265 264 self.assertEqual(v['a'], 5)
266 265 self.assertEqual(v['b'], 24690)
267 266 self.assertRaisesRemote(ZeroDivisionError, ar.get)
268 267
269 268 def test_autopx_nonblocking(self):
270 269 ip = get_ipython()
271 270 v = self.client[-1]
272 271 v.activate()
273 272 v.block=False
274 273
275 274 with capture_output() as io:
276 275 ip.magic('autopx')
277 276 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
278 277 ip.run_cell('print (b)')
279 278 ip.run_cell('import time; time.sleep(0.1)')
280 279 ip.run_cell("b/c")
281 280 ip.run_cell('b*=2')
282 281 ip.magic('autopx')
283 282
284 283 output = io.stdout.rstrip()
285 284
286 285 assert output.startswith('%autopx enabled'), output
287 286 assert output.endswith('%autopx disabled'), output
288 287 self.assertNotIn('ZeroDivisionError', output)
289 288 ar = v.get_result(-2)
290 289 self.assertRaisesRemote(ZeroDivisionError, ar.get)
291 290 # prevent TaskAborted on pulls, due to ZeroDivisionError
292 291 time.sleep(0.5)
293 292 self.assertEqual(v['a'], 5)
294 293 # b*=2 will not fire, due to abort
295 294 self.assertEqual(v['b'], 10)
296 295
297 296 def test_result(self):
298 297 ip = get_ipython()
299 298 v = self.client[-1]
300 299 v.activate()
301 300 data = dict(a=111,b=222)
302 301 v.push(data, block=True)
303 302
304 303 for name in ('a', 'b'):
305 304 ip.magic('px ' + name)
306 305 with capture_output(display=False) as io:
307 306 ip.magic('pxresult')
308 307 self.assertIn(str(data[name]), io.stdout)
309 308
310 309 @dec.skipif_not_matplotlib
311 310 def test_px_pylab(self):
312 311 """%pylab works on engines"""
313 312 ip = get_ipython()
314 313 v = self.client[-1]
315 314 v.block = True
316 315 v.activate()
317 316
318 317 with capture_output() as io:
319 318 ip.magic("px %pylab inline")
320 319
321 320 self.assertIn("Populating the interactive namespace from numpy and matplotlib", io.stdout)
322 321
323 322 with capture_output(display=False) as io:
324 323 ip.magic("px plot(rand(100))")
325 324 self.assertIn('Out[', io.stdout)
326 325 self.assertIn('matplotlib.lines', io.stdout)
327 326
328 327 def test_pxconfig(self):
329 328 ip = get_ipython()
330 329 rc = self.client
331 330 v = rc.activate(-1, '_tst')
332 331 self.assertEqual(v.targets, rc.ids[-1])
333 332 ip.magic("%pxconfig_tst -t :")
334 333 self.assertEqual(v.targets, rc.ids)
335 334 ip.magic("%pxconfig_tst -t ::2")
336 335 self.assertEqual(v.targets, rc.ids[::2])
337 336 ip.magic("%pxconfig_tst -t 1::2")
338 337 self.assertEqual(v.targets, rc.ids[1::2])
339 338 ip.magic("%pxconfig_tst -t 1")
340 339 self.assertEqual(v.targets, 1)
341 340 ip.magic("%pxconfig_tst --block")
342 341 self.assertEqual(v.block, True)
343 342 ip.magic("%pxconfig_tst --noblock")
344 343 self.assertEqual(v.block, False)
345 344
346 345 def test_cellpx_targets(self):
347 346 """%%px --targets doesn't change defaults"""
348 347 ip = get_ipython()
349 348 rc = self.client
350 349 view = rc.activate(rc.ids)
351 350 self.assertEqual(view.targets, rc.ids)
352 351 ip.magic('pxconfig --verbose')
353 352 for cell in ("pass", "1/0"):
354 353 with capture_output(display=False) as io:
355 354 try:
356 355 ip.run_cell_magic("px", "--targets all", cell)
357 356 except pmod.RemoteError:
358 357 pass
359 358 self.assertIn('engine(s): all', io.stdout)
360 359 self.assertEqual(view.targets, rc.ids)
361 360
362 361
363 362 def test_cellpx_block(self):
364 363 """%%px --block doesn't change default"""
365 364 ip = get_ipython()
366 365 rc = self.client
367 366 view = rc.activate(rc.ids)
368 367 view.block = False
369 368 self.assertEqual(view.targets, rc.ids)
370 369 ip.magic('pxconfig --verbose')
371 370 for cell in ("pass", "1/0"):
372 371 with capture_output(display=False) as io:
373 372 try:
374 373 ip.run_cell_magic("px", "--block", cell)
375 374 except pmod.RemoteError:
376 375 pass
377 376 self.assertNotIn('Async', io.stdout)
378 377 self.assertEqual(view.block, False)
379 378
380 379
@@ -1,814 +1,813 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import base64
20 20 import sys
21 21 import platform
22 22 import time
23 23 from collections import namedtuple
24 24 from tempfile import mktemp
25 25 from StringIO import StringIO
26 26
27 27 import zmq
28 28 from nose import SkipTest
29 29 from nose.plugins.attrib import attr
30 30
31 31 from IPython.testing import decorators as dec
32 from IPython.testing.ipunittest import ParametricTestCase
33 32 from IPython.utils.io import capture_output
34 33
35 34 from IPython import parallel as pmod
36 35 from IPython.parallel import error
37 36 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
38 37 from IPython.parallel import DirectView
39 38 from IPython.parallel.util import interactive
40 39
41 40 from IPython.parallel.tests import add_engines
42 41
43 42 from .clienttest import ClusterTestCase, crash, wait, skip_without
44 43
45 44 def setup():
46 45 add_engines(3, total=True)
47 46
48 47 point = namedtuple("point", "x y")
49 48
50 class TestView(ClusterTestCase, ParametricTestCase):
49 class TestView(ClusterTestCase):
51 50
52 51 def setUp(self):
53 52 # On Win XP, wait for resource cleanup, else parallel test group fails
54 53 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
55 54 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
56 55 time.sleep(2)
57 56 super(TestView, self).setUp()
58 57
59 58 @attr('crash')
60 59 def test_z_crash_mux(self):
61 60 """test graceful handling of engine death (direct)"""
62 61 # self.add_engines(1)
63 62 eid = self.client.ids[-1]
64 63 ar = self.client[eid].apply_async(crash)
65 64 self.assertRaisesRemote(error.EngineError, ar.get, 10)
66 65 eid = ar.engine_id
67 66 tic = time.time()
68 67 while eid in self.client.ids and time.time()-tic < 5:
69 68 time.sleep(.01)
70 69 self.client.spin()
71 70 self.assertFalse(eid in self.client.ids, "Engine should have died")
72 71
73 72 def test_push_pull(self):
74 73 """test pushing and pulling"""
75 74 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
76 75 t = self.client.ids[-1]
77 76 v = self.client[t]
78 77 push = v.push
79 78 pull = v.pull
80 79 v.block=True
81 80 nengines = len(self.client)
82 81 push({'data':data})
83 82 d = pull('data')
84 83 self.assertEqual(d, data)
85 84 self.client[:].push({'data':data})
86 85 d = self.client[:].pull('data', block=True)
87 86 self.assertEqual(d, nengines*[data])
88 87 ar = push({'data':data}, block=False)
89 88 self.assertTrue(isinstance(ar, AsyncResult))
90 89 r = ar.get()
91 90 ar = self.client[:].pull('data', block=False)
92 91 self.assertTrue(isinstance(ar, AsyncResult))
93 92 r = ar.get()
94 93 self.assertEqual(r, nengines*[data])
95 94 self.client[:].push(dict(a=10,b=20))
96 95 r = self.client[:].pull(('a','b'), block=True)
97 96 self.assertEqual(r, nengines*[[10,20]])
98 97
99 98 def test_push_pull_function(self):
100 99 "test pushing and pulling functions"
101 100 def testf(x):
102 101 return 2.0*x
103 102
104 103 t = self.client.ids[-1]
105 104 v = self.client[t]
106 105 v.block=True
107 106 push = v.push
108 107 pull = v.pull
109 108 execute = v.execute
110 109 push({'testf':testf})
111 110 r = pull('testf')
112 111 self.assertEqual(r(1.0), testf(1.0))
113 112 execute('r = testf(10)')
114 113 r = pull('r')
115 114 self.assertEqual(r, testf(10))
116 115 ar = self.client[:].push({'testf':testf}, block=False)
117 116 ar.get()
118 117 ar = self.client[:].pull('testf', block=False)
119 118 rlist = ar.get()
120 119 for r in rlist:
121 120 self.assertEqual(r(1.0), testf(1.0))
122 121 execute("def g(x): return x*x")
123 122 r = pull(('testf','g'))
124 123 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
125 124
126 125 def test_push_function_globals(self):
127 126 """test that pushed functions have access to globals"""
128 127 @interactive
129 128 def geta():
130 129 return a
131 130 # self.add_engines(1)
132 131 v = self.client[-1]
133 132 v.block=True
134 133 v['f'] = geta
135 134 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
136 135 v.execute('a=5')
137 136 v.execute('b=f()')
138 137 self.assertEqual(v['b'], 5)
139 138
140 139 def test_push_function_defaults(self):
141 140 """test that pushed functions preserve default args"""
142 141 def echo(a=10):
143 142 return a
144 143 v = self.client[-1]
145 144 v.block=True
146 145 v['f'] = echo
147 146 v.execute('b=f()')
148 147 self.assertEqual(v['b'], 10)
149 148
150 149 def test_get_result(self):
151 150 """test getting results from the Hub."""
152 151 c = pmod.Client(profile='iptest')
153 152 # self.add_engines(1)
154 153 t = c.ids[-1]
155 154 v = c[t]
156 155 v2 = self.client[t]
157 156 ar = v.apply_async(wait, 1)
158 157 # give the monitor time to notice the message
159 158 time.sleep(.25)
160 159 ahr = v2.get_result(ar.msg_ids[0])
161 160 self.assertTrue(isinstance(ahr, AsyncHubResult))
162 161 self.assertEqual(ahr.get(), ar.get())
163 162 ar2 = v2.get_result(ar.msg_ids[0])
164 163 self.assertFalse(isinstance(ar2, AsyncHubResult))
165 164 c.spin()
166 165 c.close()
167 166
168 167 def test_run_newline(self):
169 168 """test that run appends newline to files"""
170 169 tmpfile = mktemp()
171 170 with open(tmpfile, 'w') as f:
172 171 f.write("""def g():
173 172 return 5
174 173 """)
175 174 v = self.client[-1]
176 175 v.run(tmpfile, block=True)
177 176 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
178 177
179 178 def test_apply_tracked(self):
180 179 """test tracking for apply"""
181 180 # self.add_engines(1)
182 181 t = self.client.ids[-1]
183 182 v = self.client[t]
184 183 v.block=False
185 184 def echo(n=1024*1024, **kwargs):
186 185 with v.temp_flags(**kwargs):
187 186 return v.apply(lambda x: x, 'x'*n)
188 187 ar = echo(1, track=False)
189 188 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
190 189 self.assertTrue(ar.sent)
191 190 ar = echo(track=True)
192 191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
193 192 self.assertEqual(ar.sent, ar._tracker.done)
194 193 ar._tracker.wait()
195 194 self.assertTrue(ar.sent)
196 195
197 196 def test_push_tracked(self):
198 197 t = self.client.ids[-1]
199 198 ns = dict(x='x'*1024*1024)
200 199 v = self.client[t]
201 200 ar = v.push(ns, block=False, track=False)
202 201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
203 202 self.assertTrue(ar.sent)
204 203
205 204 ar = v.push(ns, block=False, track=True)
206 205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
207 206 ar._tracker.wait()
208 207 self.assertEqual(ar.sent, ar._tracker.done)
209 208 self.assertTrue(ar.sent)
210 209 ar.get()
211 210
212 211 def test_scatter_tracked(self):
213 212 t = self.client.ids
214 213 x='x'*1024*1024
215 214 ar = self.client[t].scatter('x', x, block=False, track=False)
216 215 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
217 216 self.assertTrue(ar.sent)
218 217
219 218 ar = self.client[t].scatter('x', x, block=False, track=True)
220 219 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
221 220 self.assertEqual(ar.sent, ar._tracker.done)
222 221 ar._tracker.wait()
223 222 self.assertTrue(ar.sent)
224 223 ar.get()
225 224
226 225 def test_remote_reference(self):
227 226 v = self.client[-1]
228 227 v['a'] = 123
229 228 ra = pmod.Reference('a')
230 229 b = v.apply_sync(lambda x: x, ra)
231 230 self.assertEqual(b, 123)
232 231
233 232
234 233 def test_scatter_gather(self):
235 234 view = self.client[:]
236 235 seq1 = range(16)
237 236 view.scatter('a', seq1)
238 237 seq2 = view.gather('a', block=True)
239 238 self.assertEqual(seq2, seq1)
240 239 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
241 240
242 241 @skip_without('numpy')
243 242 def test_scatter_gather_numpy(self):
244 243 import numpy
245 244 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
246 245 view = self.client[:]
247 246 a = numpy.arange(64)
248 247 view.scatter('a', a, block=True)
249 248 b = view.gather('a', block=True)
250 249 assert_array_equal(b, a)
251 250
252 251 def test_scatter_gather_lazy(self):
253 252 """scatter/gather with targets='all'"""
254 253 view = self.client.direct_view(targets='all')
255 254 x = range(64)
256 255 view.scatter('x', x)
257 256 gathered = view.gather('x', block=True)
258 257 self.assertEqual(gathered, x)
259 258
260 259
261 260 @dec.known_failure_py3
262 261 @skip_without('numpy')
263 262 def test_push_numpy_nocopy(self):
264 263 import numpy
265 264 view = self.client[:]
266 265 a = numpy.arange(64)
267 266 view['A'] = a
268 267 @interactive
269 268 def check_writeable(x):
270 269 return x.flags.writeable
271 270
272 271 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
273 272 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
274 273
275 274 view.push(dict(B=a))
276 275 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
277 276 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
278 277
279 278 @skip_without('numpy')
280 279 def test_apply_numpy(self):
281 280 """view.apply(f, ndarray)"""
282 281 import numpy
283 282 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
284 283
285 284 A = numpy.random.random((100,100))
286 285 view = self.client[-1]
287 286 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
288 287 B = A.astype(dt)
289 288 C = view.apply_sync(lambda x:x, B)
290 289 assert_array_equal(B,C)
291 290
292 291 @skip_without('numpy')
293 292 def test_push_pull_recarray(self):
294 293 """push/pull recarrays"""
295 294 import numpy
296 295 from numpy.testing.utils import assert_array_equal
297 296
298 297 view = self.client[-1]
299 298
300 299 R = numpy.array([
301 300 (1, 'hi', 0.),
302 301 (2**30, 'there', 2.5),
303 302 (-99999, 'world', -12345.6789),
304 303 ], [('n', int), ('s', '|S10'), ('f', float)])
305 304
306 305 view['RR'] = R
307 306 R2 = view['RR']
308 307
309 308 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
310 309 self.assertEqual(r_dtype, R.dtype)
311 310 self.assertEqual(r_shape, R.shape)
312 311 self.assertEqual(R2.dtype, R.dtype)
313 312 self.assertEqual(R2.shape, R.shape)
314 313 assert_array_equal(R2, R)
315 314
316 315 @skip_without('pandas')
317 316 def test_push_pull_timeseries(self):
318 317 """push/pull pandas.TimeSeries"""
319 318 import pandas
320 319
321 320 ts = pandas.TimeSeries(range(10))
322 321
323 322 view = self.client[-1]
324 323
325 324 view.push(dict(ts=ts), block=True)
326 325 rts = view['ts']
327 326
328 327 self.assertEqual(type(rts), type(ts))
329 328 self.assertTrue((ts == rts).all())
330 329
331 330 def test_map(self):
332 331 view = self.client[:]
333 332 def f(x):
334 333 return x**2
335 334 data = range(16)
336 335 r = view.map_sync(f, data)
337 336 self.assertEqual(r, map(f, data))
338 337
339 338 def test_map_iterable(self):
340 339 """test map on iterables (direct)"""
341 340 view = self.client[:]
342 341 # 101 is prime, so it won't be evenly distributed
343 342 arr = range(101)
344 343 # ensure it will be an iterator, even in Python 3
345 344 it = iter(arr)
346 345 r = view.map_sync(lambda x: x, it)
347 346 self.assertEqual(r, list(arr))
348 347
349 348 @skip_without('numpy')
350 349 def test_map_numpy(self):
351 350 """test map on numpy arrays (direct)"""
352 351 import numpy
353 352 from numpy.testing.utils import assert_array_equal
354 353
355 354 view = self.client[:]
356 355 # 101 is prime, so it won't be evenly distributed
357 356 arr = numpy.arange(101)
358 357 r = view.map_sync(lambda x: x, arr)
359 358 assert_array_equal(r, arr)
360 359
361 360 def test_scatter_gather_nonblocking(self):
362 361 data = range(16)
363 362 view = self.client[:]
364 363 view.scatter('a', data, block=False)
365 364 ar = view.gather('a', block=False)
366 365 self.assertEqual(ar.get(), data)
367 366
368 367 @skip_without('numpy')
369 368 def test_scatter_gather_numpy_nonblocking(self):
370 369 import numpy
371 370 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
372 371 a = numpy.arange(64)
373 372 view = self.client[:]
374 373 ar = view.scatter('a', a, block=False)
375 374 self.assertTrue(isinstance(ar, AsyncResult))
376 375 amr = view.gather('a', block=False)
377 376 self.assertTrue(isinstance(amr, AsyncMapResult))
378 377 assert_array_equal(amr.get(), a)
379 378
380 379 def test_execute(self):
381 380 view = self.client[:]
382 381 # self.client.debug=True
383 382 execute = view.execute
384 383 ar = execute('c=30', block=False)
385 384 self.assertTrue(isinstance(ar, AsyncResult))
386 385 ar = execute('d=[0,1,2]', block=False)
387 386 self.client.wait(ar, 1)
388 387 self.assertEqual(len(ar.get()), len(self.client))
389 388 for c in view['c']:
390 389 self.assertEqual(c, 30)
391 390
392 391 def test_abort(self):
393 392 view = self.client[-1]
394 393 ar = view.execute('import time; time.sleep(1)', block=False)
395 394 ar2 = view.apply_async(lambda : 2)
396 395 ar3 = view.apply_async(lambda : 3)
397 396 view.abort(ar2)
398 397 view.abort(ar3.msg_ids)
399 398 self.assertRaises(error.TaskAborted, ar2.get)
400 399 self.assertRaises(error.TaskAborted, ar3.get)
401 400
402 401 def test_abort_all(self):
403 402 """view.abort() aborts all outstanding tasks"""
404 403 view = self.client[-1]
405 404 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
406 405 view.abort()
407 406 view.wait(timeout=5)
408 407 for ar in ars[5:]:
409 408 self.assertRaises(error.TaskAborted, ar.get)
410 409
411 410 def test_temp_flags(self):
412 411 view = self.client[-1]
413 412 view.block=True
414 413 with view.temp_flags(block=False):
415 414 self.assertFalse(view.block)
416 415 self.assertTrue(view.block)
417 416
418 417 @dec.known_failure_py3
419 418 def test_importer(self):
420 419 view = self.client[-1]
421 420 view.clear(block=True)
422 421 with view.importer:
423 422 import re
424 423
425 424 @interactive
426 425 def findall(pat, s):
427 426 # this globals() step isn't necessary in real code
428 427 # only to prevent a closure in the test
429 428 re = globals()['re']
430 429 return re.findall(pat, s)
431 430
432 431 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
433 432
434 433 def test_unicode_execute(self):
435 434 """test executing unicode strings"""
436 435 v = self.client[-1]
437 436 v.block=True
438 437 if sys.version_info[0] >= 3:
439 438 code="a='é'"
440 439 else:
441 440 code=u"a=u'é'"
442 441 v.execute(code)
443 442 self.assertEqual(v['a'], u'é')
444 443
445 444 def test_unicode_apply_result(self):
446 445 """test unicode apply results"""
447 446 v = self.client[-1]
448 447 r = v.apply_sync(lambda : u'é')
449 448 self.assertEqual(r, u'é')
450 449
451 450 def test_unicode_apply_arg(self):
452 451 """test passing unicode arguments to apply"""
453 452 v = self.client[-1]
454 453
455 454 @interactive
456 455 def check_unicode(a, check):
457 456 assert isinstance(a, unicode), "%r is not unicode"%a
458 457 assert isinstance(check, bytes), "%r is not bytes"%check
459 458 assert a.encode('utf8') == check, "%s != %s"%(a,check)
460 459
461 460 for s in [ u'é', u'ßø®∫',u'asdf' ]:
462 461 try:
463 462 v.apply_sync(check_unicode, s, s.encode('utf8'))
464 463 except error.RemoteError as e:
465 464 if e.ename == 'AssertionError':
466 465 self.fail(e.evalue)
467 466 else:
468 467 raise e
469 468
470 469 def test_map_reference(self):
471 470 """view.map(<Reference>, *seqs) should work"""
472 471 v = self.client[:]
473 472 v.scatter('n', self.client.ids, flatten=True)
474 473 v.execute("f = lambda x,y: x*y")
475 474 rf = pmod.Reference('f')
476 475 nlist = list(range(10))
477 476 mlist = nlist[::-1]
478 477 expected = [ m*n for m,n in zip(mlist, nlist) ]
479 478 result = v.map_sync(rf, mlist, nlist)
480 479 self.assertEqual(result, expected)
481 480
482 481 def test_apply_reference(self):
483 482 """view.apply(<Reference>, *args) should work"""
484 483 v = self.client[:]
485 484 v.scatter('n', self.client.ids, flatten=True)
486 485 v.execute("f = lambda x: n*x")
487 486 rf = pmod.Reference('f')
488 487 result = v.apply_sync(rf, 5)
489 488 expected = [ 5*id for id in self.client.ids ]
490 489 self.assertEqual(result, expected)
491 490
492 491 def test_eval_reference(self):
493 492 v = self.client[self.client.ids[0]]
494 493 v['g'] = range(5)
495 494 rg = pmod.Reference('g[0]')
496 495 echo = lambda x:x
497 496 self.assertEqual(v.apply_sync(echo, rg), 0)
498 497
499 498 def test_reference_nameerror(self):
500 499 v = self.client[self.client.ids[0]]
501 500 r = pmod.Reference('elvis_has_left')
502 501 echo = lambda x:x
503 502 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
504 503
505 504 def test_single_engine_map(self):
506 505 e0 = self.client[self.client.ids[0]]
507 506 r = range(5)
508 507 check = [ -1*i for i in r ]
509 508 result = e0.map_sync(lambda x: -1*x, r)
510 509 self.assertEqual(result, check)
511 510
512 511 def test_len(self):
513 512 """len(view) makes sense"""
514 513 e0 = self.client[self.client.ids[0]]
515 yield self.assertEqual(len(e0), 1)
514 self.assertEqual(len(e0), 1)
516 515 v = self.client[:]
517 yield self.assertEqual(len(v), len(self.client.ids))
516 self.assertEqual(len(v), len(self.client.ids))
518 517 v = self.client.direct_view('all')
519 yield self.assertEqual(len(v), len(self.client.ids))
518 self.assertEqual(len(v), len(self.client.ids))
520 519 v = self.client[:2]
521 yield self.assertEqual(len(v), 2)
520 self.assertEqual(len(v), 2)
522 521 v = self.client[:1]
523 yield self.assertEqual(len(v), 1)
522 self.assertEqual(len(v), 1)
524 523 v = self.client.load_balanced_view()
525 yield self.assertEqual(len(v), len(self.client.ids))
524 self.assertEqual(len(v), len(self.client.ids))
526 525 # parametric tests seem to require manual closing?
527 526 self.client.close()
528 527
529 528
530 529 # begin execute tests
531 530
532 531 def test_execute_reply(self):
533 532 e0 = self.client[self.client.ids[0]]
534 533 e0.block = True
535 534 ar = e0.execute("5", silent=False)
536 535 er = ar.get()
537 536 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
538 537 self.assertEqual(er.pyout['data']['text/plain'], '5')
539 538
540 539 def test_execute_reply_rich(self):
541 540 e0 = self.client[self.client.ids[0]]
542 541 e0.block = True
543 542 e0.execute("from IPython.display import Image, HTML")
544 543 ar = e0.execute("Image(data=b'garbage', format='png', width=10)", silent=False)
545 544 er = ar.get()
546 545 b64data = base64.encodestring(b'garbage').decode('ascii')
547 546 self.assertEqual(er._repr_png_(), (b64data, dict(width=10)))
548 547 ar = e0.execute("HTML('<b>bold</b>')", silent=False)
549 548 er = ar.get()
550 549 self.assertEqual(er._repr_html_(), "<b>bold</b>")
551 550
552 551 def test_execute_reply_stdout(self):
553 552 e0 = self.client[self.client.ids[0]]
554 553 e0.block = True
555 554 ar = e0.execute("print (5)", silent=False)
556 555 er = ar.get()
557 556 self.assertEqual(er.stdout.strip(), '5')
558 557
559 558 def test_execute_pyout(self):
560 559 """execute triggers pyout with silent=False"""
561 560 view = self.client[:]
562 561 ar = view.execute("5", silent=False, block=True)
563 562
564 563 expected = [{'text/plain' : '5'}] * len(view)
565 564 mimes = [ out['data'] for out in ar.pyout ]
566 565 self.assertEqual(mimes, expected)
567 566
568 567 def test_execute_silent(self):
569 568 """execute does not trigger pyout with silent=True"""
570 569 view = self.client[:]
571 570 ar = view.execute("5", block=True)
572 571 expected = [None] * len(view)
573 572 self.assertEqual(ar.pyout, expected)
574 573
575 574 def test_execute_magic(self):
576 575 """execute accepts IPython commands"""
577 576 view = self.client[:]
578 577 view.execute("a = 5")
579 578 ar = view.execute("%whos", block=True)
580 579 # this will raise, if that failed
581 580 ar.get(5)
582 581 for stdout in ar.stdout:
583 582 lines = stdout.splitlines()
584 583 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
585 584 found = False
586 585 for line in lines[2:]:
587 586 split = line.split()
588 587 if split == ['a', 'int', '5']:
589 588 found = True
590 589 break
591 590 self.assertTrue(found, "whos output wrong: %s" % stdout)
592 591
593 592 def test_execute_displaypub(self):
594 593 """execute tracks display_pub output"""
595 594 view = self.client[:]
596 595 view.execute("from IPython.core.display import *")
597 596 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
598 597
599 598 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
600 599 for outputs in ar.outputs:
601 600 mimes = [ out['data'] for out in outputs ]
602 601 self.assertEqual(mimes, expected)
603 602
604 603 def test_apply_displaypub(self):
605 604 """apply tracks display_pub output"""
606 605 view = self.client[:]
607 606 view.execute("from IPython.core.display import *")
608 607
609 608 @interactive
610 609 def publish():
611 610 [ display(i) for i in range(5) ]
612 611
613 612 ar = view.apply_async(publish)
614 613 ar.get(5)
615 614 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
616 615 for outputs in ar.outputs:
617 616 mimes = [ out['data'] for out in outputs ]
618 617 self.assertEqual(mimes, expected)
619 618
620 619 def test_execute_raises(self):
621 620 """exceptions in execute requests raise appropriately"""
622 621 view = self.client[-1]
623 622 ar = view.execute("1/0")
624 623 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
625 624
626 625 def test_remoteerror_render_exception(self):
627 626 """RemoteErrors get nice tracebacks"""
628 627 view = self.client[-1]
629 628 ar = view.execute("1/0")
630 629 ip = get_ipython()
631 630 ip.user_ns['ar'] = ar
632 631 with capture_output() as io:
633 632 ip.run_cell("ar.get(2)")
634 633
635 634 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
636 635
637 636 def test_compositeerror_render_exception(self):
638 637 """CompositeErrors get nice tracebacks"""
639 638 view = self.client[:]
640 639 ar = view.execute("1/0")
641 640 ip = get_ipython()
642 641 ip.user_ns['ar'] = ar
643 642
644 643 with capture_output() as io:
645 644 ip.run_cell("ar.get(2)")
646 645
647 646 count = min(error.CompositeError.tb_limit, len(view))
648 647
649 648 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
650 649 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
651 650 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
652 651
653 652 def test_compositeerror_truncate(self):
654 653 """Truncate CompositeErrors with many exceptions"""
655 654 view = self.client[:]
656 655 msg_ids = []
657 656 for i in range(10):
658 657 ar = view.execute("1/0")
659 658 msg_ids.extend(ar.msg_ids)
660 659
661 660 ar = self.client.get_result(msg_ids)
662 661 try:
663 662 ar.get()
664 663 except error.CompositeError as _e:
665 664 e = _e
666 665 else:
667 666 self.fail("Should have raised CompositeError")
668 667
669 668 lines = e.render_traceback()
670 669 with capture_output() as io:
671 670 e.print_traceback()
672 671
673 672 self.assertTrue("more exceptions" in lines[-1])
674 673 count = e.tb_limit
675 674
676 675 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
677 676 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
678 677 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
679 678
680 679 @dec.skipif_not_matplotlib
681 680 def test_magic_pylab(self):
682 681 """%pylab works on engines"""
683 682 view = self.client[-1]
684 683 ar = view.execute("%pylab inline")
685 684 # at least check if this raised:
686 685 reply = ar.get(5)
687 686 # include imports, in case user config
688 687 ar = view.execute("plot(rand(100))", silent=False)
689 688 reply = ar.get(5)
690 689 self.assertEqual(len(reply.outputs), 1)
691 690 output = reply.outputs[0]
692 691 self.assertTrue("data" in output)
693 692 data = output['data']
694 693 self.assertTrue("image/png" in data)
695 694
696 695 def test_func_default_func(self):
697 696 """interactively defined function as apply func default"""
698 697 def foo():
699 698 return 'foo'
700 699
701 700 def bar(f=foo):
702 701 return f()
703 702
704 703 view = self.client[-1]
705 704 ar = view.apply_async(bar)
706 705 r = ar.get(10)
707 706 self.assertEqual(r, 'foo')
708 707 def test_data_pub_single(self):
709 708 view = self.client[-1]
710 709 ar = view.execute('\n'.join([
711 710 'from IPython.kernel.zmq.datapub import publish_data',
712 711 'for i in range(5):',
713 712 ' publish_data(dict(i=i))'
714 713 ]), block=False)
715 714 self.assertTrue(isinstance(ar.data, dict))
716 715 ar.get(5)
717 716 self.assertEqual(ar.data, dict(i=4))
718 717
719 718 def test_data_pub(self):
720 719 view = self.client[:]
721 720 ar = view.execute('\n'.join([
722 721 'from IPython.kernel.zmq.datapub import publish_data',
723 722 'for i in range(5):',
724 723 ' publish_data(dict(i=i))'
725 724 ]), block=False)
726 725 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
727 726 ar.get(5)
728 727 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
729 728
730 729 def test_can_list_arg(self):
731 730 """args in lists are canned"""
732 731 view = self.client[-1]
733 732 view['a'] = 128
734 733 rA = pmod.Reference('a')
735 734 ar = view.apply_async(lambda x: x, [rA])
736 735 r = ar.get(5)
737 736 self.assertEqual(r, [128])
738 737
739 738 def test_can_dict_arg(self):
740 739 """args in dicts are canned"""
741 740 view = self.client[-1]
742 741 view['a'] = 128
743 742 rA = pmod.Reference('a')
744 743 ar = view.apply_async(lambda x: x, dict(foo=rA))
745 744 r = ar.get(5)
746 745 self.assertEqual(r, dict(foo=128))
747 746
748 747 def test_can_list_kwarg(self):
749 748 """kwargs in lists are canned"""
750 749 view = self.client[-1]
751 750 view['a'] = 128
752 751 rA = pmod.Reference('a')
753 752 ar = view.apply_async(lambda x=5: x, x=[rA])
754 753 r = ar.get(5)
755 754 self.assertEqual(r, [128])
756 755
757 756 def test_can_dict_kwarg(self):
758 757 """kwargs in dicts are canned"""
759 758 view = self.client[-1]
760 759 view['a'] = 128
761 760 rA = pmod.Reference('a')
762 761 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
763 762 r = ar.get(5)
764 763 self.assertEqual(r, dict(foo=128))
765 764
766 765 def test_map_ref(self):
767 766 """view.map works with references"""
768 767 view = self.client[:]
769 768 ranks = sorted(self.client.ids)
770 769 view.scatter('rank', ranks, flatten=True)
771 770 rrank = pmod.Reference('rank')
772 771
773 772 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
774 773 drank = amr.get(5)
775 774 self.assertEqual(drank, [ r*2 for r in ranks ])
776 775
777 776 def test_nested_getitem_setitem(self):
778 777 """get and set with view['a.b']"""
779 778 view = self.client[-1]
780 779 view.execute('\n'.join([
781 780 'class A(object): pass',
782 781 'a = A()',
783 782 'a.b = 128',
784 783 ]), block=True)
785 784 ra = pmod.Reference('a')
786 785
787 786 r = view.apply_sync(lambda x: x.b, ra)
788 787 self.assertEqual(r, 128)
789 788 self.assertEqual(view['a.b'], 128)
790 789
791 790 view['a.b'] = 0
792 791
793 792 r = view.apply_sync(lambda x: x.b, ra)
794 793 self.assertEqual(r, 0)
795 794 self.assertEqual(view['a.b'], 0)
796 795
797 796 def test_return_namedtuple(self):
798 797 def namedtuplify(x, y):
799 798 from IPython.parallel.tests.test_view import point
800 799 return point(x, y)
801 800
802 801 view = self.client[-1]
803 802 p = view.apply_sync(namedtuplify, 1, 2)
804 803 self.assertEqual(p.x, 1)
805 804 self.assertEqual(p.y, 2)
806 805
807 806 def test_apply_namedtuple(self):
808 807 def echoxy(p):
809 808 return p.y, p.x
810 809
811 810 view = self.client[-1]
812 811 tup = view.apply_sync(echoxy, point(1, 2))
813 812 self.assertEqual(tup, (2,1))
814 813
@@ -1,86 +1,86 b''
1 1 # encoding: utf-8
2 2 """Tests for io.py"""
3 3
4 4 #-----------------------------------------------------------------------------
5 5 # Copyright (C) 2008-2011 The IPython Development Team
6 6 #
7 7 # Distributed under the terms of the BSD License. The full license is in
8 8 # the file COPYING, distributed as part of this software.
9 9 #-----------------------------------------------------------------------------
10 10
11 11 #-----------------------------------------------------------------------------
12 12 # Imports
13 13 #-----------------------------------------------------------------------------
14 14 from __future__ import print_function
15 15
16 16 import sys
17 17
18 18 from StringIO import StringIO
19 19 from subprocess import Popen, PIPE
20 import unittest
20 21
21 22 import nose.tools as nt
22 23
23 from IPython.testing.ipunittest import ParametricTestCase
24 24 from IPython.utils.io import Tee, capture_output
25 25 from IPython.utils.py3compat import doctest_refactor_print
26 26
27 27 #-----------------------------------------------------------------------------
28 28 # Tests
29 29 #-----------------------------------------------------------------------------
30 30
31 31
32 32 def test_tee_simple():
33 33 "Very simple check with stdout only"
34 34 chan = StringIO()
35 35 text = 'Hello'
36 36 tee = Tee(chan, channel='stdout')
37 37 print(text, file=chan)
38 38 nt.assert_equal(chan.getvalue(), text+"\n")
39 39
40 40
41 class TeeTestCase(ParametricTestCase):
41 class TeeTestCase(unittest.TestCase):
42 42
43 43 def tchan(self, channel, check='close'):
44 44 trap = StringIO()
45 45 chan = StringIO()
46 46 text = 'Hello'
47 47
48 48 std_ori = getattr(sys, channel)
49 49 setattr(sys, channel, trap)
50 50
51 51 tee = Tee(chan, channel=channel)
52 52 print(text, end='', file=chan)
53 53 setattr(sys, channel, std_ori)
54 54 trap_val = trap.getvalue()
55 55 nt.assert_equal(chan.getvalue(), text)
56 56 if check=='close':
57 57 tee.close()
58 58 else:
59 59 del tee
60 60
61 61 def test(self):
62 62 for chan in ['stdout', 'stderr']:
63 63 for check in ['close', 'del']:
64 yield self.tchan(chan, check)
64 self.tchan(chan, check)
65 65
66 66 def test_io_init():
67 67 """Test that io.stdin/out/err exist at startup"""
68 68 for name in ('stdin', 'stdout', 'stderr'):
69 69 cmd = doctest_refactor_print("from IPython.utils import io;print io.%s.__class__"%name)
70 70 p = Popen([sys.executable, '-c', cmd],
71 71 stdout=PIPE)
72 72 p.wait()
73 73 classname = p.stdout.read().strip().decode('ascii')
74 74 # __class__ is a reference to the class object in Python 3, so we can't
75 75 # just test for string equality.
76 76 assert 'IPython.utils.io.IOStream' in classname, classname
77 77
78 78 def test_capture_output():
79 79 """capture_output() context works"""
80 80
81 81 with capture_output() as io:
82 82 print('hi, stdout')
83 83 print('hi, stderr', file=sys.stderr)
84 84
85 85 nt.assert_equal(io.stdout, 'hi, stdout\n')
86 86 nt.assert_equal(io.stderr, 'hi, stderr\n')
General Comments 0
You need to be logged in to leave comments. Login now