##// END OF EJS Templates
Merge pull request #2347 from minrk/byzero...
Min RK -
r8298:011222a5 merge
parent child Browse files
Show More
@@ -1,703 +1,703 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import platform
21 21 import time
22 22 from tempfile import mktemp
23 23 from StringIO import StringIO
24 24
25 25 import zmq
26 26 from nose import SkipTest
27 27
28 28 from IPython.testing import decorators as dec
29 29 from IPython.testing.ipunittest import ParametricTestCase
30 30 from IPython.utils.io import capture_output
31 31
32 32 from IPython import parallel as pmod
33 33 from IPython.parallel import error
34 34 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
35 35 from IPython.parallel import DirectView
36 36 from IPython.parallel.util import interactive
37 37
38 38 from IPython.parallel.tests import add_engines
39 39
40 40 from .clienttest import ClusterTestCase, crash, wait, skip_without
41 41
42 42 def setup():
43 43 add_engines(3, total=True)
44 44
45 45 class TestView(ClusterTestCase, ParametricTestCase):
46 46
47 47 def setUp(self):
48 48 # On Win XP, wait for resource cleanup, else parallel test group fails
49 49 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
50 50 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
51 51 time.sleep(2)
52 52 super(TestView, self).setUp()
53 53
54 54 def test_z_crash_mux(self):
55 55 """test graceful handling of engine death (direct)"""
56 56 raise SkipTest("crash tests disabled, due to undesirable crash reports")
57 57 # self.add_engines(1)
58 58 eid = self.client.ids[-1]
59 59 ar = self.client[eid].apply_async(crash)
60 60 self.assertRaisesRemote(error.EngineError, ar.get, 10)
61 61 eid = ar.engine_id
62 62 tic = time.time()
63 63 while eid in self.client.ids and time.time()-tic < 5:
64 64 time.sleep(.01)
65 65 self.client.spin()
66 66 self.assertFalse(eid in self.client.ids, "Engine should have died")
67 67
68 68 def test_push_pull(self):
69 69 """test pushing and pulling"""
70 70 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
71 71 t = self.client.ids[-1]
72 72 v = self.client[t]
73 73 push = v.push
74 74 pull = v.pull
75 75 v.block=True
76 76 nengines = len(self.client)
77 77 push({'data':data})
78 78 d = pull('data')
79 79 self.assertEqual(d, data)
80 80 self.client[:].push({'data':data})
81 81 d = self.client[:].pull('data', block=True)
82 82 self.assertEqual(d, nengines*[data])
83 83 ar = push({'data':data}, block=False)
84 84 self.assertTrue(isinstance(ar, AsyncResult))
85 85 r = ar.get()
86 86 ar = self.client[:].pull('data', block=False)
87 87 self.assertTrue(isinstance(ar, AsyncResult))
88 88 r = ar.get()
89 89 self.assertEqual(r, nengines*[data])
90 90 self.client[:].push(dict(a=10,b=20))
91 91 r = self.client[:].pull(('a','b'), block=True)
92 92 self.assertEqual(r, nengines*[[10,20]])
93 93
94 94 def test_push_pull_function(self):
95 95 "test pushing and pulling functions"
96 96 def testf(x):
97 97 return 2.0*x
98 98
99 99 t = self.client.ids[-1]
100 100 v = self.client[t]
101 101 v.block=True
102 102 push = v.push
103 103 pull = v.pull
104 104 execute = v.execute
105 105 push({'testf':testf})
106 106 r = pull('testf')
107 107 self.assertEqual(r(1.0), testf(1.0))
108 108 execute('r = testf(10)')
109 109 r = pull('r')
110 110 self.assertEqual(r, testf(10))
111 111 ar = self.client[:].push({'testf':testf}, block=False)
112 112 ar.get()
113 113 ar = self.client[:].pull('testf', block=False)
114 114 rlist = ar.get()
115 115 for r in rlist:
116 116 self.assertEqual(r(1.0), testf(1.0))
117 117 execute("def g(x): return x*x")
118 118 r = pull(('testf','g'))
119 119 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
120 120
121 121 def test_push_function_globals(self):
122 122 """test that pushed functions have access to globals"""
123 123 @interactive
124 124 def geta():
125 125 return a
126 126 # self.add_engines(1)
127 127 v = self.client[-1]
128 128 v.block=True
129 129 v['f'] = geta
130 130 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
131 131 v.execute('a=5')
132 132 v.execute('b=f()')
133 133 self.assertEqual(v['b'], 5)
134 134
135 135 def test_push_function_defaults(self):
136 136 """test that pushed functions preserve default args"""
137 137 def echo(a=10):
138 138 return a
139 139 v = self.client[-1]
140 140 v.block=True
141 141 v['f'] = echo
142 142 v.execute('b=f()')
143 143 self.assertEqual(v['b'], 10)
144 144
145 145 def test_get_result(self):
146 146 """test getting results from the Hub."""
147 147 c = pmod.Client(profile='iptest')
148 148 # self.add_engines(1)
149 149 t = c.ids[-1]
150 150 v = c[t]
151 151 v2 = self.client[t]
152 152 ar = v.apply_async(wait, 1)
153 153 # give the monitor time to notice the message
154 154 time.sleep(.25)
155 155 ahr = v2.get_result(ar.msg_ids)
156 156 self.assertTrue(isinstance(ahr, AsyncHubResult))
157 157 self.assertEqual(ahr.get(), ar.get())
158 158 ar2 = v2.get_result(ar.msg_ids)
159 159 self.assertFalse(isinstance(ar2, AsyncHubResult))
160 160 c.spin()
161 161 c.close()
162 162
163 163 def test_run_newline(self):
164 164 """test that run appends newline to files"""
165 165 tmpfile = mktemp()
166 166 with open(tmpfile, 'w') as f:
167 167 f.write("""def g():
168 168 return 5
169 169 """)
170 170 v = self.client[-1]
171 171 v.run(tmpfile, block=True)
172 172 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
173 173
174 174 def test_apply_tracked(self):
175 175 """test tracking for apply"""
176 176 # self.add_engines(1)
177 177 t = self.client.ids[-1]
178 178 v = self.client[t]
179 179 v.block=False
180 180 def echo(n=1024*1024, **kwargs):
181 181 with v.temp_flags(**kwargs):
182 182 return v.apply(lambda x: x, 'x'*n)
183 183 ar = echo(1, track=False)
184 184 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
185 185 self.assertTrue(ar.sent)
186 186 ar = echo(track=True)
187 187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 188 self.assertEqual(ar.sent, ar._tracker.done)
189 189 ar._tracker.wait()
190 190 self.assertTrue(ar.sent)
191 191
192 192 def test_push_tracked(self):
193 193 t = self.client.ids[-1]
194 194 ns = dict(x='x'*1024*1024)
195 195 v = self.client[t]
196 196 ar = v.push(ns, block=False, track=False)
197 197 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
198 198 self.assertTrue(ar.sent)
199 199
200 200 ar = v.push(ns, block=False, track=True)
201 201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 202 ar._tracker.wait()
203 203 self.assertEqual(ar.sent, ar._tracker.done)
204 204 self.assertTrue(ar.sent)
205 205 ar.get()
206 206
207 207 def test_scatter_tracked(self):
208 208 t = self.client.ids
209 209 x='x'*1024*1024
210 210 ar = self.client[t].scatter('x', x, block=False, track=False)
211 211 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
212 212 self.assertTrue(ar.sent)
213 213
214 214 ar = self.client[t].scatter('x', x, block=False, track=True)
215 215 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
216 216 self.assertEqual(ar.sent, ar._tracker.done)
217 217 ar._tracker.wait()
218 218 self.assertTrue(ar.sent)
219 219 ar.get()
220 220
221 221 def test_remote_reference(self):
222 222 v = self.client[-1]
223 223 v['a'] = 123
224 224 ra = pmod.Reference('a')
225 225 b = v.apply_sync(lambda x: x, ra)
226 226 self.assertEqual(b, 123)
227 227
228 228
229 229 def test_scatter_gather(self):
230 230 view = self.client[:]
231 231 seq1 = range(16)
232 232 view.scatter('a', seq1)
233 233 seq2 = view.gather('a', block=True)
234 234 self.assertEqual(seq2, seq1)
235 235 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
236 236
237 237 @skip_without('numpy')
238 238 def test_scatter_gather_numpy(self):
239 239 import numpy
240 240 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
241 241 view = self.client[:]
242 242 a = numpy.arange(64)
243 243 view.scatter('a', a, block=True)
244 244 b = view.gather('a', block=True)
245 245 assert_array_equal(b, a)
246 246
247 247 def test_scatter_gather_lazy(self):
248 248 """scatter/gather with targets='all'"""
249 249 view = self.client.direct_view(targets='all')
250 250 x = range(64)
251 251 view.scatter('x', x)
252 252 gathered = view.gather('x', block=True)
253 253 self.assertEqual(gathered, x)
254 254
255 255
256 256 @dec.known_failure_py3
257 257 @skip_without('numpy')
258 258 def test_push_numpy_nocopy(self):
259 259 import numpy
260 260 view = self.client[:]
261 261 a = numpy.arange(64)
262 262 view['A'] = a
263 263 @interactive
264 264 def check_writeable(x):
265 265 return x.flags.writeable
266 266
267 267 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
268 268 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
269 269
270 270 view.push(dict(B=a))
271 271 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
272 272 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
273 273
274 274 @skip_without('numpy')
275 275 def test_apply_numpy(self):
276 276 """view.apply(f, ndarray)"""
277 277 import numpy
278 278 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
279 279
280 280 A = numpy.random.random((100,100))
281 281 view = self.client[-1]
282 282 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
283 283 B = A.astype(dt)
284 284 C = view.apply_sync(lambda x:x, B)
285 285 assert_array_equal(B,C)
286 286
287 287 @skip_without('numpy')
288 288 def test_push_pull_recarray(self):
289 289 """push/pull recarrays"""
290 290 import numpy
291 291 from numpy.testing.utils import assert_array_equal
292 292
293 293 view = self.client[-1]
294 294
295 295 R = numpy.array([
296 296 (1, 'hi', 0.),
297 297 (2**30, 'there', 2.5),
298 298 (-99999, 'world', -12345.6789),
299 299 ], [('n', int), ('s', '|S10'), ('f', float)])
300 300
301 301 view['RR'] = R
302 302 R2 = view['RR']
303 303
304 304 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
305 305 self.assertEqual(r_dtype, R.dtype)
306 306 self.assertEqual(r_shape, R.shape)
307 307 self.assertEqual(R2.dtype, R.dtype)
308 308 self.assertEqual(R2.shape, R.shape)
309 309 assert_array_equal(R2, R)
310 310
311 311 def test_map(self):
312 312 view = self.client[:]
313 313 def f(x):
314 314 return x**2
315 315 data = range(16)
316 316 r = view.map_sync(f, data)
317 317 self.assertEqual(r, map(f, data))
318 318
319 319 def test_map_iterable(self):
320 320 """test map on iterables (direct)"""
321 321 view = self.client[:]
322 322 # 101 is prime, so it won't be evenly distributed
323 323 arr = range(101)
324 324 # ensure it will be an iterator, even in Python 3
325 325 it = iter(arr)
326 326 r = view.map_sync(lambda x:x, arr)
327 327 self.assertEqual(r, list(arr))
328 328
329 329 def test_scatter_gather_nonblocking(self):
330 330 data = range(16)
331 331 view = self.client[:]
332 332 view.scatter('a', data, block=False)
333 333 ar = view.gather('a', block=False)
334 334 self.assertEqual(ar.get(), data)
335 335
336 336 @skip_without('numpy')
337 337 def test_scatter_gather_numpy_nonblocking(self):
338 338 import numpy
339 339 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
340 340 a = numpy.arange(64)
341 341 view = self.client[:]
342 342 ar = view.scatter('a', a, block=False)
343 343 self.assertTrue(isinstance(ar, AsyncResult))
344 344 amr = view.gather('a', block=False)
345 345 self.assertTrue(isinstance(amr, AsyncMapResult))
346 346 assert_array_equal(amr.get(), a)
347 347
348 348 def test_execute(self):
349 349 view = self.client[:]
350 350 # self.client.debug=True
351 351 execute = view.execute
352 352 ar = execute('c=30', block=False)
353 353 self.assertTrue(isinstance(ar, AsyncResult))
354 354 ar = execute('d=[0,1,2]', block=False)
355 355 self.client.wait(ar, 1)
356 356 self.assertEqual(len(ar.get()), len(self.client))
357 357 for c in view['c']:
358 358 self.assertEqual(c, 30)
359 359
360 360 def test_abort(self):
361 361 view = self.client[-1]
362 362 ar = view.execute('import time; time.sleep(1)', block=False)
363 363 ar2 = view.apply_async(lambda : 2)
364 364 ar3 = view.apply_async(lambda : 3)
365 365 view.abort(ar2)
366 366 view.abort(ar3.msg_ids)
367 367 self.assertRaises(error.TaskAborted, ar2.get)
368 368 self.assertRaises(error.TaskAborted, ar3.get)
369 369
370 370 def test_abort_all(self):
371 371 """view.abort() aborts all outstanding tasks"""
372 372 view = self.client[-1]
373 373 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
374 374 view.abort()
375 375 view.wait(timeout=5)
376 376 for ar in ars[5:]:
377 377 self.assertRaises(error.TaskAborted, ar.get)
378 378
379 379 def test_temp_flags(self):
380 380 view = self.client[-1]
381 381 view.block=True
382 382 with view.temp_flags(block=False):
383 383 self.assertFalse(view.block)
384 384 self.assertTrue(view.block)
385 385
386 386 @dec.known_failure_py3
387 387 def test_importer(self):
388 388 view = self.client[-1]
389 389 view.clear(block=True)
390 390 with view.importer:
391 391 import re
392 392
393 393 @interactive
394 394 def findall(pat, s):
395 395 # this globals() step isn't necessary in real code
396 396 # only to prevent a closure in the test
397 397 re = globals()['re']
398 398 return re.findall(pat, s)
399 399
400 400 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
401 401
402 402 def test_unicode_execute(self):
403 403 """test executing unicode strings"""
404 404 v = self.client[-1]
405 405 v.block=True
406 406 if sys.version_info[0] >= 3:
407 407 code="a='é'"
408 408 else:
409 409 code=u"a=u'é'"
410 410 v.execute(code)
411 411 self.assertEqual(v['a'], u'é')
412 412
413 413 def test_unicode_apply_result(self):
414 414 """test unicode apply results"""
415 415 v = self.client[-1]
416 416 r = v.apply_sync(lambda : u'é')
417 417 self.assertEqual(r, u'é')
418 418
419 419 def test_unicode_apply_arg(self):
420 420 """test passing unicode arguments to apply"""
421 421 v = self.client[-1]
422 422
423 423 @interactive
424 424 def check_unicode(a, check):
425 425 assert isinstance(a, unicode), "%r is not unicode"%a
426 426 assert isinstance(check, bytes), "%r is not bytes"%check
427 427 assert a.encode('utf8') == check, "%s != %s"%(a,check)
428 428
429 429 for s in [ u'é', u'ßø®∫',u'asdf' ]:
430 430 try:
431 431 v.apply_sync(check_unicode, s, s.encode('utf8'))
432 432 except error.RemoteError as e:
433 433 if e.ename == 'AssertionError':
434 434 self.fail(e.evalue)
435 435 else:
436 436 raise e
437 437
438 438 def test_map_reference(self):
439 439 """view.map(<Reference>, *seqs) should work"""
440 440 v = self.client[:]
441 441 v.scatter('n', self.client.ids, flatten=True)
442 442 v.execute("f = lambda x,y: x*y")
443 443 rf = pmod.Reference('f')
444 444 nlist = list(range(10))
445 445 mlist = nlist[::-1]
446 446 expected = [ m*n for m,n in zip(mlist, nlist) ]
447 447 result = v.map_sync(rf, mlist, nlist)
448 448 self.assertEqual(result, expected)
449 449
450 450 def test_apply_reference(self):
451 451 """view.apply(<Reference>, *args) should work"""
452 452 v = self.client[:]
453 453 v.scatter('n', self.client.ids, flatten=True)
454 454 v.execute("f = lambda x: n*x")
455 455 rf = pmod.Reference('f')
456 456 result = v.apply_sync(rf, 5)
457 457 expected = [ 5*id for id in self.client.ids ]
458 458 self.assertEqual(result, expected)
459 459
460 460 def test_eval_reference(self):
461 461 v = self.client[self.client.ids[0]]
462 462 v['g'] = range(5)
463 463 rg = pmod.Reference('g[0]')
464 464 echo = lambda x:x
465 465 self.assertEqual(v.apply_sync(echo, rg), 0)
466 466
467 467 def test_reference_nameerror(self):
468 468 v = self.client[self.client.ids[0]]
469 469 r = pmod.Reference('elvis_has_left')
470 470 echo = lambda x:x
471 471 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
472 472
473 473 def test_single_engine_map(self):
474 474 e0 = self.client[self.client.ids[0]]
475 475 r = range(5)
476 476 check = [ -1*i for i in r ]
477 477 result = e0.map_sync(lambda x: -1*x, r)
478 478 self.assertEqual(result, check)
479 479
480 480 def test_len(self):
481 481 """len(view) makes sense"""
482 482 e0 = self.client[self.client.ids[0]]
483 483 yield self.assertEqual(len(e0), 1)
484 484 v = self.client[:]
485 485 yield self.assertEqual(len(v), len(self.client.ids))
486 486 v = self.client.direct_view('all')
487 487 yield self.assertEqual(len(v), len(self.client.ids))
488 488 v = self.client[:2]
489 489 yield self.assertEqual(len(v), 2)
490 490 v = self.client[:1]
491 491 yield self.assertEqual(len(v), 1)
492 492 v = self.client.load_balanced_view()
493 493 yield self.assertEqual(len(v), len(self.client.ids))
494 494 # parametric tests seem to require manual closing?
495 495 self.client.close()
496 496
497 497
498 498 # begin execute tests
499 499
500 500 def test_execute_reply(self):
501 501 e0 = self.client[self.client.ids[0]]
502 502 e0.block = True
503 503 ar = e0.execute("5", silent=False)
504 504 er = ar.get()
505 505 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
506 506 self.assertEqual(er.pyout['data']['text/plain'], '5')
507 507
508 508 def test_execute_reply_stdout(self):
509 509 e0 = self.client[self.client.ids[0]]
510 510 e0.block = True
511 511 ar = e0.execute("print (5)", silent=False)
512 512 er = ar.get()
513 513 self.assertEqual(er.stdout.strip(), '5')
514 514
515 515 def test_execute_pyout(self):
516 516 """execute triggers pyout with silent=False"""
517 517 view = self.client[:]
518 518 ar = view.execute("5", silent=False, block=True)
519 519
520 520 expected = [{'text/plain' : '5'}] * len(view)
521 521 mimes = [ out['data'] for out in ar.pyout ]
522 522 self.assertEqual(mimes, expected)
523 523
524 524 def test_execute_silent(self):
525 525 """execute does not trigger pyout with silent=True"""
526 526 view = self.client[:]
527 527 ar = view.execute("5", block=True)
528 528 expected = [None] * len(view)
529 529 self.assertEqual(ar.pyout, expected)
530 530
531 531 def test_execute_magic(self):
532 532 """execute accepts IPython commands"""
533 533 view = self.client[:]
534 534 view.execute("a = 5")
535 535 ar = view.execute("%whos", block=True)
536 536 # this will raise, if that failed
537 537 ar.get(5)
538 538 for stdout in ar.stdout:
539 539 lines = stdout.splitlines()
540 540 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
541 541 found = False
542 542 for line in lines[2:]:
543 543 split = line.split()
544 544 if split == ['a', 'int', '5']:
545 545 found = True
546 546 break
547 547 self.assertTrue(found, "whos output wrong: %s" % stdout)
548 548
549 549 def test_execute_displaypub(self):
550 550 """execute tracks display_pub output"""
551 551 view = self.client[:]
552 552 view.execute("from IPython.core.display import *")
553 553 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
554 554
555 555 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
556 556 for outputs in ar.outputs:
557 557 mimes = [ out['data'] for out in outputs ]
558 558 self.assertEqual(mimes, expected)
559 559
560 560 def test_apply_displaypub(self):
561 561 """apply tracks display_pub output"""
562 562 view = self.client[:]
563 563 view.execute("from IPython.core.display import *")
564 564
565 565 @interactive
566 566 def publish():
567 567 [ display(i) for i in range(5) ]
568 568
569 569 ar = view.apply_async(publish)
570 570 ar.get(5)
571 571 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
572 572 for outputs in ar.outputs:
573 573 mimes = [ out['data'] for out in outputs ]
574 574 self.assertEqual(mimes, expected)
575 575
576 576 def test_execute_raises(self):
577 577 """exceptions in execute requests raise appropriately"""
578 578 view = self.client[-1]
579 579 ar = view.execute("1/0")
580 580 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
581 581
582 582 def test_remoteerror_render_exception(self):
583 583 """RemoteErrors get nice tracebacks"""
584 584 view = self.client[-1]
585 585 ar = view.execute("1/0")
586 586 ip = get_ipython()
587 587 ip.user_ns['ar'] = ar
588 588 with capture_output() as io:
589 589 ip.run_cell("ar.get(2)")
590 590
591 591 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
592 592
593 593 def test_compositeerror_render_exception(self):
594 594 """CompositeErrors get nice tracebacks"""
595 595 view = self.client[:]
596 596 ar = view.execute("1/0")
597 597 ip = get_ipython()
598 598 ip.user_ns['ar'] = ar
599 599 with capture_output() as io:
600 600 ip.run_cell("ar.get(2)")
601 601
602 602 self.assertEqual(io.stdout.count('ZeroDivisionError'), len(view) * 2, io.stdout)
603 self.assertEqual(io.stdout.count('integer division'), len(view), io.stdout)
603 self.assertEqual(io.stdout.count('by zero'), len(view), io.stdout)
604 604 self.assertEqual(io.stdout.count(':execute'), len(view), io.stdout)
605 605
606 606 @dec.skipif_not_matplotlib
607 607 def test_magic_pylab(self):
608 608 """%pylab works on engines"""
609 609 view = self.client[-1]
610 610 ar = view.execute("%pylab inline")
611 611 # at least check if this raised:
612 612 reply = ar.get(5)
613 613 # include imports, in case user config
614 614 ar = view.execute("plot(rand(100))", silent=False)
615 615 reply = ar.get(5)
616 616 self.assertEqual(len(reply.outputs), 1)
617 617 output = reply.outputs[0]
618 618 self.assertTrue("data" in output)
619 619 data = output['data']
620 620 self.assertTrue("image/png" in data)
621 621
622 622 def test_func_default_func(self):
623 623 """interactively defined function as apply func default"""
624 624 def foo():
625 625 return 'foo'
626 626
627 627 def bar(f=foo):
628 628 return f()
629 629
630 630 view = self.client[-1]
631 631 ar = view.apply_async(bar)
632 632 r = ar.get(10)
633 633 self.assertEqual(r, 'foo')
634 634 def test_data_pub_single(self):
635 635 view = self.client[-1]
636 636 ar = view.execute('\n'.join([
637 637 'from IPython.zmq.datapub import publish_data',
638 638 'for i in range(5):',
639 639 ' publish_data(dict(i=i))'
640 640 ]), block=False)
641 641 self.assertTrue(isinstance(ar.data, dict))
642 642 ar.get(5)
643 643 self.assertEqual(ar.data, dict(i=4))
644 644
645 645 def test_data_pub(self):
646 646 view = self.client[:]
647 647 ar = view.execute('\n'.join([
648 648 'from IPython.zmq.datapub import publish_data',
649 649 'for i in range(5):',
650 650 ' publish_data(dict(i=i))'
651 651 ]), block=False)
652 652 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
653 653 ar.get(5)
654 654 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
655 655
656 656 def test_can_list_arg(self):
657 657 """args in lists are canned"""
658 658 view = self.client[-1]
659 659 view['a'] = 128
660 660 rA = pmod.Reference('a')
661 661 ar = view.apply_async(lambda x: x, [rA])
662 662 r = ar.get(5)
663 663 self.assertEqual(r, [128])
664 664
665 665 def test_can_dict_arg(self):
666 666 """args in dicts are canned"""
667 667 view = self.client[-1]
668 668 view['a'] = 128
669 669 rA = pmod.Reference('a')
670 670 ar = view.apply_async(lambda x: x, dict(foo=rA))
671 671 r = ar.get(5)
672 672 self.assertEqual(r, dict(foo=128))
673 673
674 674 def test_can_list_kwarg(self):
675 675 """kwargs in lists are canned"""
676 676 view = self.client[-1]
677 677 view['a'] = 128
678 678 rA = pmod.Reference('a')
679 679 ar = view.apply_async(lambda x=5: x, x=[rA])
680 680 r = ar.get(5)
681 681 self.assertEqual(r, [128])
682 682
683 683 def test_can_dict_kwarg(self):
684 684 """kwargs in dicts are canned"""
685 685 view = self.client[-1]
686 686 view['a'] = 128
687 687 rA = pmod.Reference('a')
688 688 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
689 689 r = ar.get(5)
690 690 self.assertEqual(r, dict(foo=128))
691 691
692 692 def test_map_ref(self):
693 693 """view.map works with references"""
694 694 view = self.client[:]
695 695 ranks = sorted(self.client.ids)
696 696 view.scatter('rank', ranks, flatten=True)
697 697 rrank = pmod.Reference('rank')
698 698
699 699 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
700 700 drank = amr.get(5)
701 701 self.assertEqual(drank, [ r*2 for r in ranks ])
702 702
703 703
General Comments 0
You need to be logged in to leave comments. Login now