##// END OF EJS Templates
Fix test_map_iterable()
Samuel Ainsworth -
Show More
@@ -1,801 +1,801 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import platform
21 21 import time
22 22 from collections import namedtuple
23 23 from tempfile import mktemp
24 24 from StringIO import StringIO
25 25
26 26 import zmq
27 27 from nose import SkipTest
28 28 from nose.plugins.attrib import attr
29 29
30 30 from IPython.testing import decorators as dec
31 31 from IPython.testing.ipunittest import ParametricTestCase
32 32 from IPython.utils.io import capture_output
33 33
34 34 from IPython import parallel as pmod
35 35 from IPython.parallel import error
36 36 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
37 37 from IPython.parallel import DirectView
38 38 from IPython.parallel.util import interactive
39 39
40 40 from IPython.parallel.tests import add_engines
41 41
42 42 from .clienttest import ClusterTestCase, crash, wait, skip_without
43 43
44 44 def setup():
45 45 add_engines(3, total=True)
46 46
47 47 point = namedtuple("point", "x y")
48 48
49 49 class TestView(ClusterTestCase, ParametricTestCase):
50 50
51 51 def setUp(self):
52 52 # On Win XP, wait for resource cleanup, else parallel test group fails
53 53 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
54 54 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
55 55 time.sleep(2)
56 56 super(TestView, self).setUp()
57 57
58 58 @attr('crash')
59 59 def test_z_crash_mux(self):
60 60 """test graceful handling of engine death (direct)"""
61 61 # self.add_engines(1)
62 62 eid = self.client.ids[-1]
63 63 ar = self.client[eid].apply_async(crash)
64 64 self.assertRaisesRemote(error.EngineError, ar.get, 10)
65 65 eid = ar.engine_id
66 66 tic = time.time()
67 67 while eid in self.client.ids and time.time()-tic < 5:
68 68 time.sleep(.01)
69 69 self.client.spin()
70 70 self.assertFalse(eid in self.client.ids, "Engine should have died")
71 71
72 72 def test_push_pull(self):
73 73 """test pushing and pulling"""
74 74 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
75 75 t = self.client.ids[-1]
76 76 v = self.client[t]
77 77 push = v.push
78 78 pull = v.pull
79 79 v.block=True
80 80 nengines = len(self.client)
81 81 push({'data':data})
82 82 d = pull('data')
83 83 self.assertEqual(d, data)
84 84 self.client[:].push({'data':data})
85 85 d = self.client[:].pull('data', block=True)
86 86 self.assertEqual(d, nengines*[data])
87 87 ar = push({'data':data}, block=False)
88 88 self.assertTrue(isinstance(ar, AsyncResult))
89 89 r = ar.get()
90 90 ar = self.client[:].pull('data', block=False)
91 91 self.assertTrue(isinstance(ar, AsyncResult))
92 92 r = ar.get()
93 93 self.assertEqual(r, nengines*[data])
94 94 self.client[:].push(dict(a=10,b=20))
95 95 r = self.client[:].pull(('a','b'), block=True)
96 96 self.assertEqual(r, nengines*[[10,20]])
97 97
98 98 def test_push_pull_function(self):
99 99 "test pushing and pulling functions"
100 100 def testf(x):
101 101 return 2.0*x
102 102
103 103 t = self.client.ids[-1]
104 104 v = self.client[t]
105 105 v.block=True
106 106 push = v.push
107 107 pull = v.pull
108 108 execute = v.execute
109 109 push({'testf':testf})
110 110 r = pull('testf')
111 111 self.assertEqual(r(1.0), testf(1.0))
112 112 execute('r = testf(10)')
113 113 r = pull('r')
114 114 self.assertEqual(r, testf(10))
115 115 ar = self.client[:].push({'testf':testf}, block=False)
116 116 ar.get()
117 117 ar = self.client[:].pull('testf', block=False)
118 118 rlist = ar.get()
119 119 for r in rlist:
120 120 self.assertEqual(r(1.0), testf(1.0))
121 121 execute("def g(x): return x*x")
122 122 r = pull(('testf','g'))
123 123 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
124 124
125 125 def test_push_function_globals(self):
126 126 """test that pushed functions have access to globals"""
127 127 @interactive
128 128 def geta():
129 129 return a
130 130 # self.add_engines(1)
131 131 v = self.client[-1]
132 132 v.block=True
133 133 v['f'] = geta
134 134 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
135 135 v.execute('a=5')
136 136 v.execute('b=f()')
137 137 self.assertEqual(v['b'], 5)
138 138
139 139 def test_push_function_defaults(self):
140 140 """test that pushed functions preserve default args"""
141 141 def echo(a=10):
142 142 return a
143 143 v = self.client[-1]
144 144 v.block=True
145 145 v['f'] = echo
146 146 v.execute('b=f()')
147 147 self.assertEqual(v['b'], 10)
148 148
149 149 def test_get_result(self):
150 150 """test getting results from the Hub."""
151 151 c = pmod.Client(profile='iptest')
152 152 # self.add_engines(1)
153 153 t = c.ids[-1]
154 154 v = c[t]
155 155 v2 = self.client[t]
156 156 ar = v.apply_async(wait, 1)
157 157 # give the monitor time to notice the message
158 158 time.sleep(.25)
159 159 ahr = v2.get_result(ar.msg_ids[0])
160 160 self.assertTrue(isinstance(ahr, AsyncHubResult))
161 161 self.assertEqual(ahr.get(), ar.get())
162 162 ar2 = v2.get_result(ar.msg_ids[0])
163 163 self.assertFalse(isinstance(ar2, AsyncHubResult))
164 164 c.spin()
165 165 c.close()
166 166
167 167 def test_run_newline(self):
168 168 """test that run appends newline to files"""
169 169 tmpfile = mktemp()
170 170 with open(tmpfile, 'w') as f:
171 171 f.write("""def g():
172 172 return 5
173 173 """)
174 174 v = self.client[-1]
175 175 v.run(tmpfile, block=True)
176 176 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
177 177
178 178 def test_apply_tracked(self):
179 179 """test tracking for apply"""
180 180 # self.add_engines(1)
181 181 t = self.client.ids[-1]
182 182 v = self.client[t]
183 183 v.block=False
184 184 def echo(n=1024*1024, **kwargs):
185 185 with v.temp_flags(**kwargs):
186 186 return v.apply(lambda x: x, 'x'*n)
187 187 ar = echo(1, track=False)
188 188 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
189 189 self.assertTrue(ar.sent)
190 190 ar = echo(track=True)
191 191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 192 self.assertEqual(ar.sent, ar._tracker.done)
193 193 ar._tracker.wait()
194 194 self.assertTrue(ar.sent)
195 195
196 196 def test_push_tracked(self):
197 197 t = self.client.ids[-1]
198 198 ns = dict(x='x'*1024*1024)
199 199 v = self.client[t]
200 200 ar = v.push(ns, block=False, track=False)
201 201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 202 self.assertTrue(ar.sent)
203 203
204 204 ar = v.push(ns, block=False, track=True)
205 205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 206 ar._tracker.wait()
207 207 self.assertEqual(ar.sent, ar._tracker.done)
208 208 self.assertTrue(ar.sent)
209 209 ar.get()
210 210
211 211 def test_scatter_tracked(self):
212 212 t = self.client.ids
213 213 x='x'*1024*1024
214 214 ar = self.client[t].scatter('x', x, block=False, track=False)
215 215 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
216 216 self.assertTrue(ar.sent)
217 217
218 218 ar = self.client[t].scatter('x', x, block=False, track=True)
219 219 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
220 220 self.assertEqual(ar.sent, ar._tracker.done)
221 221 ar._tracker.wait()
222 222 self.assertTrue(ar.sent)
223 223 ar.get()
224 224
225 225 def test_remote_reference(self):
226 226 v = self.client[-1]
227 227 v['a'] = 123
228 228 ra = pmod.Reference('a')
229 229 b = v.apply_sync(lambda x: x, ra)
230 230 self.assertEqual(b, 123)
231 231
232 232
233 233 def test_scatter_gather(self):
234 234 view = self.client[:]
235 235 seq1 = range(16)
236 236 view.scatter('a', seq1)
237 237 seq2 = view.gather('a', block=True)
238 238 self.assertEqual(seq2, seq1)
239 239 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
240 240
241 241 @skip_without('numpy')
242 242 def test_scatter_gather_numpy(self):
243 243 import numpy
244 244 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
245 245 view = self.client[:]
246 246 a = numpy.arange(64)
247 247 view.scatter('a', a, block=True)
248 248 b = view.gather('a', block=True)
249 249 assert_array_equal(b, a)
250 250
251 251 def test_scatter_gather_lazy(self):
252 252 """scatter/gather with targets='all'"""
253 253 view = self.client.direct_view(targets='all')
254 254 x = range(64)
255 255 view.scatter('x', x)
256 256 gathered = view.gather('x', block=True)
257 257 self.assertEqual(gathered, x)
258 258
259 259
260 260 @dec.known_failure_py3
261 261 @skip_without('numpy')
262 262 def test_push_numpy_nocopy(self):
263 263 import numpy
264 264 view = self.client[:]
265 265 a = numpy.arange(64)
266 266 view['A'] = a
267 267 @interactive
268 268 def check_writeable(x):
269 269 return x.flags.writeable
270 270
271 271 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
272 272 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
273 273
274 274 view.push(dict(B=a))
275 275 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
276 276 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
277 277
278 278 @skip_without('numpy')
279 279 def test_apply_numpy(self):
280 280 """view.apply(f, ndarray)"""
281 281 import numpy
282 282 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
283 283
284 284 A = numpy.random.random((100,100))
285 285 view = self.client[-1]
286 286 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
287 287 B = A.astype(dt)
288 288 C = view.apply_sync(lambda x:x, B)
289 289 assert_array_equal(B,C)
290 290
291 291 @skip_without('numpy')
292 292 def test_push_pull_recarray(self):
293 293 """push/pull recarrays"""
294 294 import numpy
295 295 from numpy.testing.utils import assert_array_equal
296 296
297 297 view = self.client[-1]
298 298
299 299 R = numpy.array([
300 300 (1, 'hi', 0.),
301 301 (2**30, 'there', 2.5),
302 302 (-99999, 'world', -12345.6789),
303 303 ], [('n', int), ('s', '|S10'), ('f', float)])
304 304
305 305 view['RR'] = R
306 306 R2 = view['RR']
307 307
308 308 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
309 309 self.assertEqual(r_dtype, R.dtype)
310 310 self.assertEqual(r_shape, R.shape)
311 311 self.assertEqual(R2.dtype, R.dtype)
312 312 self.assertEqual(R2.shape, R.shape)
313 313 assert_array_equal(R2, R)
314 314
315 315 @skip_without('pandas')
316 316 def test_push_pull_timeseries(self):
317 317 """push/pull pandas.TimeSeries"""
318 318 import pandas
319 319
320 320 ts = pandas.TimeSeries(range(10))
321 321
322 322 view = self.client[-1]
323 323
324 324 view.push(dict(ts=ts), block=True)
325 325 rts = view['ts']
326 326
327 327 self.assertEqual(type(rts), type(ts))
328 328 self.assertTrue((ts == rts).all())
329 329
330 330 def test_map(self):
331 331 view = self.client[:]
332 332 def f(x):
333 333 return x**2
334 334 data = range(16)
335 335 r = view.map_sync(f, data)
336 336 self.assertEqual(r, map(f, data))
337 337
338 338 def test_map_iterable(self):
339 339 """test map on iterables (direct)"""
340 340 view = self.client[:]
341 341 # 101 is prime, so it won't be evenly distributed
342 342 arr = range(101)
343 343 # ensure it will be an iterator, even in Python 3
344 344 it = iter(arr)
345 r = view.map_sync(lambda x:x, arr)
346 self.assertEqual(r, list(arr))
345 r = view.map_sync(lambda x: x, it)
346 self.assertEqual(r, list(it))
347 347
348 348 @skip_without('numpy')
349 349 def test_map_numpy(self):
350 350 """test map on numpy arrays (direct)"""
351 351 import numpy
352 352 from numpy.testing.utils import assert_array_equal
353 353
354 354 view = self.client[:]
355 355 # 101 is prime, so it won't be evenly distributed
356 356 arr = numpy.arange(101)
357 357 r = view.map_sync(lambda x: x, arr)
358 358 assert_array_equal(r, arr)
359 359
360 360 def test_scatter_gather_nonblocking(self):
361 361 data = range(16)
362 362 view = self.client[:]
363 363 view.scatter('a', data, block=False)
364 364 ar = view.gather('a', block=False)
365 365 self.assertEqual(ar.get(), data)
366 366
367 367 @skip_without('numpy')
368 368 def test_scatter_gather_numpy_nonblocking(self):
369 369 import numpy
370 370 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
371 371 a = numpy.arange(64)
372 372 view = self.client[:]
373 373 ar = view.scatter('a', a, block=False)
374 374 self.assertTrue(isinstance(ar, AsyncResult))
375 375 amr = view.gather('a', block=False)
376 376 self.assertTrue(isinstance(amr, AsyncMapResult))
377 377 assert_array_equal(amr.get(), a)
378 378
379 379 def test_execute(self):
380 380 view = self.client[:]
381 381 # self.client.debug=True
382 382 execute = view.execute
383 383 ar = execute('c=30', block=False)
384 384 self.assertTrue(isinstance(ar, AsyncResult))
385 385 ar = execute('d=[0,1,2]', block=False)
386 386 self.client.wait(ar, 1)
387 387 self.assertEqual(len(ar.get()), len(self.client))
388 388 for c in view['c']:
389 389 self.assertEqual(c, 30)
390 390
391 391 def test_abort(self):
392 392 view = self.client[-1]
393 393 ar = view.execute('import time; time.sleep(1)', block=False)
394 394 ar2 = view.apply_async(lambda : 2)
395 395 ar3 = view.apply_async(lambda : 3)
396 396 view.abort(ar2)
397 397 view.abort(ar3.msg_ids)
398 398 self.assertRaises(error.TaskAborted, ar2.get)
399 399 self.assertRaises(error.TaskAborted, ar3.get)
400 400
401 401 def test_abort_all(self):
402 402 """view.abort() aborts all outstanding tasks"""
403 403 view = self.client[-1]
404 404 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
405 405 view.abort()
406 406 view.wait(timeout=5)
407 407 for ar in ars[5:]:
408 408 self.assertRaises(error.TaskAborted, ar.get)
409 409
410 410 def test_temp_flags(self):
411 411 view = self.client[-1]
412 412 view.block=True
413 413 with view.temp_flags(block=False):
414 414 self.assertFalse(view.block)
415 415 self.assertTrue(view.block)
416 416
417 417 @dec.known_failure_py3
418 418 def test_importer(self):
419 419 view = self.client[-1]
420 420 view.clear(block=True)
421 421 with view.importer:
422 422 import re
423 423
424 424 @interactive
425 425 def findall(pat, s):
426 426 # this globals() step isn't necessary in real code
427 427 # only to prevent a closure in the test
428 428 re = globals()['re']
429 429 return re.findall(pat, s)
430 430
431 431 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
432 432
433 433 def test_unicode_execute(self):
434 434 """test executing unicode strings"""
435 435 v = self.client[-1]
436 436 v.block=True
437 437 if sys.version_info[0] >= 3:
438 438 code="a='é'"
439 439 else:
440 440 code=u"a=u'é'"
441 441 v.execute(code)
442 442 self.assertEqual(v['a'], u'é')
443 443
444 444 def test_unicode_apply_result(self):
445 445 """test unicode apply results"""
446 446 v = self.client[-1]
447 447 r = v.apply_sync(lambda : u'é')
448 448 self.assertEqual(r, u'é')
449 449
450 450 def test_unicode_apply_arg(self):
451 451 """test passing unicode arguments to apply"""
452 452 v = self.client[-1]
453 453
454 454 @interactive
455 455 def check_unicode(a, check):
456 456 assert isinstance(a, unicode), "%r is not unicode"%a
457 457 assert isinstance(check, bytes), "%r is not bytes"%check
458 458 assert a.encode('utf8') == check, "%s != %s"%(a,check)
459 459
460 460 for s in [ u'é', u'ßø®∫',u'asdf' ]:
461 461 try:
462 462 v.apply_sync(check_unicode, s, s.encode('utf8'))
463 463 except error.RemoteError as e:
464 464 if e.ename == 'AssertionError':
465 465 self.fail(e.evalue)
466 466 else:
467 467 raise e
468 468
469 469 def test_map_reference(self):
470 470 """view.map(<Reference>, *seqs) should work"""
471 471 v = self.client[:]
472 472 v.scatter('n', self.client.ids, flatten=True)
473 473 v.execute("f = lambda x,y: x*y")
474 474 rf = pmod.Reference('f')
475 475 nlist = list(range(10))
476 476 mlist = nlist[::-1]
477 477 expected = [ m*n for m,n in zip(mlist, nlist) ]
478 478 result = v.map_sync(rf, mlist, nlist)
479 479 self.assertEqual(result, expected)
480 480
481 481 def test_apply_reference(self):
482 482 """view.apply(<Reference>, *args) should work"""
483 483 v = self.client[:]
484 484 v.scatter('n', self.client.ids, flatten=True)
485 485 v.execute("f = lambda x: n*x")
486 486 rf = pmod.Reference('f')
487 487 result = v.apply_sync(rf, 5)
488 488 expected = [ 5*id for id in self.client.ids ]
489 489 self.assertEqual(result, expected)
490 490
491 491 def test_eval_reference(self):
492 492 v = self.client[self.client.ids[0]]
493 493 v['g'] = range(5)
494 494 rg = pmod.Reference('g[0]')
495 495 echo = lambda x:x
496 496 self.assertEqual(v.apply_sync(echo, rg), 0)
497 497
498 498 def test_reference_nameerror(self):
499 499 v = self.client[self.client.ids[0]]
500 500 r = pmod.Reference('elvis_has_left')
501 501 echo = lambda x:x
502 502 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
503 503
504 504 def test_single_engine_map(self):
505 505 e0 = self.client[self.client.ids[0]]
506 506 r = range(5)
507 507 check = [ -1*i for i in r ]
508 508 result = e0.map_sync(lambda x: -1*x, r)
509 509 self.assertEqual(result, check)
510 510
511 511 def test_len(self):
512 512 """len(view) makes sense"""
513 513 e0 = self.client[self.client.ids[0]]
514 514 yield self.assertEqual(len(e0), 1)
515 515 v = self.client[:]
516 516 yield self.assertEqual(len(v), len(self.client.ids))
517 517 v = self.client.direct_view('all')
518 518 yield self.assertEqual(len(v), len(self.client.ids))
519 519 v = self.client[:2]
520 520 yield self.assertEqual(len(v), 2)
521 521 v = self.client[:1]
522 522 yield self.assertEqual(len(v), 1)
523 523 v = self.client.load_balanced_view()
524 524 yield self.assertEqual(len(v), len(self.client.ids))
525 525 # parametric tests seem to require manual closing?
526 526 self.client.close()
527 527
528 528
529 529 # begin execute tests
530 530
531 531 def test_execute_reply(self):
532 532 e0 = self.client[self.client.ids[0]]
533 533 e0.block = True
534 534 ar = e0.execute("5", silent=False)
535 535 er = ar.get()
536 536 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
537 537 self.assertEqual(er.pyout['data']['text/plain'], '5')
538 538
539 539 def test_execute_reply_stdout(self):
540 540 e0 = self.client[self.client.ids[0]]
541 541 e0.block = True
542 542 ar = e0.execute("print (5)", silent=False)
543 543 er = ar.get()
544 544 self.assertEqual(er.stdout.strip(), '5')
545 545
546 546 def test_execute_pyout(self):
547 547 """execute triggers pyout with silent=False"""
548 548 view = self.client[:]
549 549 ar = view.execute("5", silent=False, block=True)
550 550
551 551 expected = [{'text/plain' : '5'}] * len(view)
552 552 mimes = [ out['data'] for out in ar.pyout ]
553 553 self.assertEqual(mimes, expected)
554 554
555 555 def test_execute_silent(self):
556 556 """execute does not trigger pyout with silent=True"""
557 557 view = self.client[:]
558 558 ar = view.execute("5", block=True)
559 559 expected = [None] * len(view)
560 560 self.assertEqual(ar.pyout, expected)
561 561
562 562 def test_execute_magic(self):
563 563 """execute accepts IPython commands"""
564 564 view = self.client[:]
565 565 view.execute("a = 5")
566 566 ar = view.execute("%whos", block=True)
567 567 # this will raise, if that failed
568 568 ar.get(5)
569 569 for stdout in ar.stdout:
570 570 lines = stdout.splitlines()
571 571 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
572 572 found = False
573 573 for line in lines[2:]:
574 574 split = line.split()
575 575 if split == ['a', 'int', '5']:
576 576 found = True
577 577 break
578 578 self.assertTrue(found, "whos output wrong: %s" % stdout)
579 579
580 580 def test_execute_displaypub(self):
581 581 """execute tracks display_pub output"""
582 582 view = self.client[:]
583 583 view.execute("from IPython.core.display import *")
584 584 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
585 585
586 586 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
587 587 for outputs in ar.outputs:
588 588 mimes = [ out['data'] for out in outputs ]
589 589 self.assertEqual(mimes, expected)
590 590
591 591 def test_apply_displaypub(self):
592 592 """apply tracks display_pub output"""
593 593 view = self.client[:]
594 594 view.execute("from IPython.core.display import *")
595 595
596 596 @interactive
597 597 def publish():
598 598 [ display(i) for i in range(5) ]
599 599
600 600 ar = view.apply_async(publish)
601 601 ar.get(5)
602 602 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
603 603 for outputs in ar.outputs:
604 604 mimes = [ out['data'] for out in outputs ]
605 605 self.assertEqual(mimes, expected)
606 606
607 607 def test_execute_raises(self):
608 608 """exceptions in execute requests raise appropriately"""
609 609 view = self.client[-1]
610 610 ar = view.execute("1/0")
611 611 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
612 612
613 613 def test_remoteerror_render_exception(self):
614 614 """RemoteErrors get nice tracebacks"""
615 615 view = self.client[-1]
616 616 ar = view.execute("1/0")
617 617 ip = get_ipython()
618 618 ip.user_ns['ar'] = ar
619 619 with capture_output() as io:
620 620 ip.run_cell("ar.get(2)")
621 621
622 622 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
623 623
624 624 def test_compositeerror_render_exception(self):
625 625 """CompositeErrors get nice tracebacks"""
626 626 view = self.client[:]
627 627 ar = view.execute("1/0")
628 628 ip = get_ipython()
629 629 ip.user_ns['ar'] = ar
630 630
631 631 with capture_output() as io:
632 632 ip.run_cell("ar.get(2)")
633 633
634 634 count = min(error.CompositeError.tb_limit, len(view))
635 635
636 636 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
637 637 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
638 638 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
639 639
640 640 def test_compositeerror_truncate(self):
641 641 """Truncate CompositeErrors with many exceptions"""
642 642 view = self.client[:]
643 643 msg_ids = []
644 644 for i in range(10):
645 645 ar = view.execute("1/0")
646 646 msg_ids.extend(ar.msg_ids)
647 647
648 648 ar = self.client.get_result(msg_ids)
649 649 try:
650 650 ar.get()
651 651 except error.CompositeError as _e:
652 652 e = _e
653 653 else:
654 654 self.fail("Should have raised CompositeError")
655 655
656 656 lines = e.render_traceback()
657 657 with capture_output() as io:
658 658 e.print_traceback()
659 659
660 660 self.assertTrue("more exceptions" in lines[-1])
661 661 count = e.tb_limit
662 662
663 663 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
664 664 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
665 665 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
666 666
667 667 @dec.skipif_not_matplotlib
668 668 def test_magic_pylab(self):
669 669 """%pylab works on engines"""
670 670 view = self.client[-1]
671 671 ar = view.execute("%pylab inline")
672 672 # at least check if this raised:
673 673 reply = ar.get(5)
674 674 # include imports, in case user config
675 675 ar = view.execute("plot(rand(100))", silent=False)
676 676 reply = ar.get(5)
677 677 self.assertEqual(len(reply.outputs), 1)
678 678 output = reply.outputs[0]
679 679 self.assertTrue("data" in output)
680 680 data = output['data']
681 681 self.assertTrue("image/png" in data)
682 682
683 683 def test_func_default_func(self):
684 684 """interactively defined function as apply func default"""
685 685 def foo():
686 686 return 'foo'
687 687
688 688 def bar(f=foo):
689 689 return f()
690 690
691 691 view = self.client[-1]
692 692 ar = view.apply_async(bar)
693 693 r = ar.get(10)
694 694 self.assertEqual(r, 'foo')
695 695 def test_data_pub_single(self):
696 696 view = self.client[-1]
697 697 ar = view.execute('\n'.join([
698 698 'from IPython.kernel.zmq.datapub import publish_data',
699 699 'for i in range(5):',
700 700 ' publish_data(dict(i=i))'
701 701 ]), block=False)
702 702 self.assertTrue(isinstance(ar.data, dict))
703 703 ar.get(5)
704 704 self.assertEqual(ar.data, dict(i=4))
705 705
706 706 def test_data_pub(self):
707 707 view = self.client[:]
708 708 ar = view.execute('\n'.join([
709 709 'from IPython.kernel.zmq.datapub import publish_data',
710 710 'for i in range(5):',
711 711 ' publish_data(dict(i=i))'
712 712 ]), block=False)
713 713 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
714 714 ar.get(5)
715 715 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
716 716
717 717 def test_can_list_arg(self):
718 718 """args in lists are canned"""
719 719 view = self.client[-1]
720 720 view['a'] = 128
721 721 rA = pmod.Reference('a')
722 722 ar = view.apply_async(lambda x: x, [rA])
723 723 r = ar.get(5)
724 724 self.assertEqual(r, [128])
725 725
726 726 def test_can_dict_arg(self):
727 727 """args in dicts are canned"""
728 728 view = self.client[-1]
729 729 view['a'] = 128
730 730 rA = pmod.Reference('a')
731 731 ar = view.apply_async(lambda x: x, dict(foo=rA))
732 732 r = ar.get(5)
733 733 self.assertEqual(r, dict(foo=128))
734 734
735 735 def test_can_list_kwarg(self):
736 736 """kwargs in lists are canned"""
737 737 view = self.client[-1]
738 738 view['a'] = 128
739 739 rA = pmod.Reference('a')
740 740 ar = view.apply_async(lambda x=5: x, x=[rA])
741 741 r = ar.get(5)
742 742 self.assertEqual(r, [128])
743 743
744 744 def test_can_dict_kwarg(self):
745 745 """kwargs in dicts are canned"""
746 746 view = self.client[-1]
747 747 view['a'] = 128
748 748 rA = pmod.Reference('a')
749 749 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
750 750 r = ar.get(5)
751 751 self.assertEqual(r, dict(foo=128))
752 752
753 753 def test_map_ref(self):
754 754 """view.map works with references"""
755 755 view = self.client[:]
756 756 ranks = sorted(self.client.ids)
757 757 view.scatter('rank', ranks, flatten=True)
758 758 rrank = pmod.Reference('rank')
759 759
760 760 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
761 761 drank = amr.get(5)
762 762 self.assertEqual(drank, [ r*2 for r in ranks ])
763 763
764 764 def test_nested_getitem_setitem(self):
765 765 """get and set with view['a.b']"""
766 766 view = self.client[-1]
767 767 view.execute('\n'.join([
768 768 'class A(object): pass',
769 769 'a = A()',
770 770 'a.b = 128',
771 771 ]), block=True)
772 772 ra = pmod.Reference('a')
773 773
774 774 r = view.apply_sync(lambda x: x.b, ra)
775 775 self.assertEqual(r, 128)
776 776 self.assertEqual(view['a.b'], 128)
777 777
778 778 view['a.b'] = 0
779 779
780 780 r = view.apply_sync(lambda x: x.b, ra)
781 781 self.assertEqual(r, 0)
782 782 self.assertEqual(view['a.b'], 0)
783 783
784 784 def test_return_namedtuple(self):
785 785 def namedtuplify(x, y):
786 786 from IPython.parallel.tests.test_view import point
787 787 return point(x, y)
788 788
789 789 view = self.client[-1]
790 790 p = view.apply_sync(namedtuplify, 1, 2)
791 791 self.assertEqual(p.x, 1)
792 792 self.assertEqual(p.y, 2)
793 793
794 794 def test_apply_namedtuple(self):
795 795 def echoxy(p):
796 796 return p.y, p.x
797 797
798 798 view = self.client[-1]
799 799 tup = view.apply_sync(echoxy, point(1, 2))
800 800 self.assertEqual(tup, (2,1))
801 801
General Comments 0
You need to be logged in to leave comments. Login now