##// END OF EJS Templates
explicit dtype for str in recarray test...
MinRK -
Show More
@@ -1,589 +1,589 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import time
21 21 from tempfile import mktemp
22 22 from StringIO import StringIO
23 23
24 24 import zmq
25 25 from nose import SkipTest
26 26
27 27 from IPython.testing import decorators as dec
28 28 from IPython.testing.ipunittest import ParametricTestCase
29 29
30 30 from IPython import parallel as pmod
31 31 from IPython.parallel import error
32 32 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
33 33 from IPython.parallel import DirectView
34 34 from IPython.parallel.util import interactive
35 35
36 36 from IPython.parallel.tests import add_engines
37 37
38 38 from .clienttest import ClusterTestCase, crash, wait, skip_without
39 39
40 40 def setup():
41 41 add_engines(3, total=True)
42 42
43 43 class TestView(ClusterTestCase, ParametricTestCase):
44 44
45 45 def test_z_crash_mux(self):
46 46 """test graceful handling of engine death (direct)"""
47 47 raise SkipTest("crash tests disabled, due to undesirable crash reports")
48 48 # self.add_engines(1)
49 49 eid = self.client.ids[-1]
50 50 ar = self.client[eid].apply_async(crash)
51 51 self.assertRaisesRemote(error.EngineError, ar.get, 10)
52 52 eid = ar.engine_id
53 53 tic = time.time()
54 54 while eid in self.client.ids and time.time()-tic < 5:
55 55 time.sleep(.01)
56 56 self.client.spin()
57 57 self.assertFalse(eid in self.client.ids, "Engine should have died")
58 58
59 59 def test_push_pull(self):
60 60 """test pushing and pulling"""
61 61 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
62 62 t = self.client.ids[-1]
63 63 v = self.client[t]
64 64 push = v.push
65 65 pull = v.pull
66 66 v.block=True
67 67 nengines = len(self.client)
68 68 push({'data':data})
69 69 d = pull('data')
70 70 self.assertEquals(d, data)
71 71 self.client[:].push({'data':data})
72 72 d = self.client[:].pull('data', block=True)
73 73 self.assertEquals(d, nengines*[data])
74 74 ar = push({'data':data}, block=False)
75 75 self.assertTrue(isinstance(ar, AsyncResult))
76 76 r = ar.get()
77 77 ar = self.client[:].pull('data', block=False)
78 78 self.assertTrue(isinstance(ar, AsyncResult))
79 79 r = ar.get()
80 80 self.assertEquals(r, nengines*[data])
81 81 self.client[:].push(dict(a=10,b=20))
82 82 r = self.client[:].pull(('a','b'), block=True)
83 83 self.assertEquals(r, nengines*[[10,20]])
84 84
85 85 def test_push_pull_function(self):
86 86 "test pushing and pulling functions"
87 87 def testf(x):
88 88 return 2.0*x
89 89
90 90 t = self.client.ids[-1]
91 91 v = self.client[t]
92 92 v.block=True
93 93 push = v.push
94 94 pull = v.pull
95 95 execute = v.execute
96 96 push({'testf':testf})
97 97 r = pull('testf')
98 98 self.assertEqual(r(1.0), testf(1.0))
99 99 execute('r = testf(10)')
100 100 r = pull('r')
101 101 self.assertEquals(r, testf(10))
102 102 ar = self.client[:].push({'testf':testf}, block=False)
103 103 ar.get()
104 104 ar = self.client[:].pull('testf', block=False)
105 105 rlist = ar.get()
106 106 for r in rlist:
107 107 self.assertEqual(r(1.0), testf(1.0))
108 108 execute("def g(x): return x*x")
109 109 r = pull(('testf','g'))
110 110 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
111 111
112 112 def test_push_function_globals(self):
113 113 """test that pushed functions have access to globals"""
114 114 @interactive
115 115 def geta():
116 116 return a
117 117 # self.add_engines(1)
118 118 v = self.client[-1]
119 119 v.block=True
120 120 v['f'] = geta
121 121 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
122 122 v.execute('a=5')
123 123 v.execute('b=f()')
124 124 self.assertEquals(v['b'], 5)
125 125
126 126 def test_push_function_defaults(self):
127 127 """test that pushed functions preserve default args"""
128 128 def echo(a=10):
129 129 return a
130 130 v = self.client[-1]
131 131 v.block=True
132 132 v['f'] = echo
133 133 v.execute('b=f()')
134 134 self.assertEquals(v['b'], 10)
135 135
136 136 def test_get_result(self):
137 137 """test getting results from the Hub."""
138 138 c = pmod.Client(profile='iptest')
139 139 # self.add_engines(1)
140 140 t = c.ids[-1]
141 141 v = c[t]
142 142 v2 = self.client[t]
143 143 ar = v.apply_async(wait, 1)
144 144 # give the monitor time to notice the message
145 145 time.sleep(.25)
146 146 ahr = v2.get_result(ar.msg_ids)
147 147 self.assertTrue(isinstance(ahr, AsyncHubResult))
148 148 self.assertEquals(ahr.get(), ar.get())
149 149 ar2 = v2.get_result(ar.msg_ids)
150 150 self.assertFalse(isinstance(ar2, AsyncHubResult))
151 151 c.spin()
152 152 c.close()
153 153
154 154 def test_run_newline(self):
155 155 """test that run appends newline to files"""
156 156 tmpfile = mktemp()
157 157 with open(tmpfile, 'w') as f:
158 158 f.write("""def g():
159 159 return 5
160 160 """)
161 161 v = self.client[-1]
162 162 v.run(tmpfile, block=True)
163 163 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
164 164
165 165 def test_apply_tracked(self):
166 166 """test tracking for apply"""
167 167 # self.add_engines(1)
168 168 t = self.client.ids[-1]
169 169 v = self.client[t]
170 170 v.block=False
171 171 def echo(n=1024*1024, **kwargs):
172 172 with v.temp_flags(**kwargs):
173 173 return v.apply(lambda x: x, 'x'*n)
174 174 ar = echo(1, track=False)
175 175 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
176 176 self.assertTrue(ar.sent)
177 177 ar = echo(track=True)
178 178 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
179 179 self.assertEquals(ar.sent, ar._tracker.done)
180 180 ar._tracker.wait()
181 181 self.assertTrue(ar.sent)
182 182
183 183 def test_push_tracked(self):
184 184 t = self.client.ids[-1]
185 185 ns = dict(x='x'*1024*1024)
186 186 v = self.client[t]
187 187 ar = v.push(ns, block=False, track=False)
188 188 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
189 189 self.assertTrue(ar.sent)
190 190
191 191 ar = v.push(ns, block=False, track=True)
192 192 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
193 193 ar._tracker.wait()
194 194 self.assertEquals(ar.sent, ar._tracker.done)
195 195 self.assertTrue(ar.sent)
196 196 ar.get()
197 197
198 198 def test_scatter_tracked(self):
199 199 t = self.client.ids
200 200 x='x'*1024*1024
201 201 ar = self.client[t].scatter('x', x, block=False, track=False)
202 202 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
203 203 self.assertTrue(ar.sent)
204 204
205 205 ar = self.client[t].scatter('x', x, block=False, track=True)
206 206 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
207 207 self.assertEquals(ar.sent, ar._tracker.done)
208 208 ar._tracker.wait()
209 209 self.assertTrue(ar.sent)
210 210 ar.get()
211 211
212 212 def test_remote_reference(self):
213 213 v = self.client[-1]
214 214 v['a'] = 123
215 215 ra = pmod.Reference('a')
216 216 b = v.apply_sync(lambda x: x, ra)
217 217 self.assertEquals(b, 123)
218 218
219 219
220 220 def test_scatter_gather(self):
221 221 view = self.client[:]
222 222 seq1 = range(16)
223 223 view.scatter('a', seq1)
224 224 seq2 = view.gather('a', block=True)
225 225 self.assertEquals(seq2, seq1)
226 226 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
227 227
228 228 @skip_without('numpy')
229 229 def test_scatter_gather_numpy(self):
230 230 import numpy
231 231 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
232 232 view = self.client[:]
233 233 a = numpy.arange(64)
234 234 view.scatter('a', a)
235 235 b = view.gather('a', block=True)
236 236 assert_array_equal(b, a)
237 237
238 238 def test_scatter_gather_lazy(self):
239 239 """scatter/gather with targets='all'"""
240 240 view = self.client.direct_view(targets='all')
241 241 x = range(64)
242 242 view.scatter('x', x)
243 243 gathered = view.gather('x', block=True)
244 244 self.assertEquals(gathered, x)
245 245
246 246
247 247 @dec.known_failure_py3
248 248 @skip_without('numpy')
249 249 def test_push_numpy_nocopy(self):
250 250 import numpy
251 251 view = self.client[:]
252 252 a = numpy.arange(64)
253 253 view['A'] = a
254 254 @interactive
255 255 def check_writeable(x):
256 256 return x.flags.writeable
257 257
258 258 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
259 259 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
260 260
261 261 view.push(dict(B=a))
262 262 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
263 263 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
264 264
265 265 @skip_without('numpy')
266 266 def test_apply_numpy(self):
267 267 """view.apply(f, ndarray)"""
268 268 import numpy
269 269 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
270 270
271 271 A = numpy.random.random((100,100))
272 272 view = self.client[-1]
273 273 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
274 274 B = A.astype(dt)
275 275 C = view.apply_sync(lambda x:x, B)
276 276 assert_array_equal(B,C)
277 277
278 278 @skip_without('numpy')
279 279 def test_push_pull_recarray(self):
280 280 """push/pull recarrays"""
281 281 import numpy
282 282 from numpy.testing.utils import assert_array_equal
283 283
284 284 view = self.client[-1]
285 285
286 286 R = numpy.array([
287 287 (1, 'hi', 0.),
288 288 (2**30, 'there', 2.5),
289 289 (-99999, 'world', -12345.6789),
290 ], [('n', int), ('s', str), ('f', float)])
290 ], [('n', int), ('s', '|S10'), ('f', float)])
291 291
292 292 view['RR'] = R
293 293 R2 = view['RR']
294 294
295 295 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
296 296 self.assertEquals(r_dtype, R.dtype)
297 297 self.assertEquals(r_shape, R.shape)
298 298 self.assertEquals(R2.dtype, R.dtype)
299 299 self.assertEquals(R2.shape, R.shape)
300 300 assert_array_equal(R2, R)
301 301
302 302 def test_map(self):
303 303 view = self.client[:]
304 304 def f(x):
305 305 return x**2
306 306 data = range(16)
307 307 r = view.map_sync(f, data)
308 308 self.assertEquals(r, map(f, data))
309 309
310 310 def test_map_iterable(self):
311 311 """test map on iterables (direct)"""
312 312 view = self.client[:]
313 313 # 101 is prime, so it won't be evenly distributed
314 314 arr = range(101)
315 315 # ensure it will be an iterator, even in Python 3
316 316 it = iter(arr)
317 317 r = view.map_sync(lambda x:x, arr)
318 318 self.assertEquals(r, list(arr))
319 319
320 320 def test_scatterGatherNonblocking(self):
321 321 data = range(16)
322 322 view = self.client[:]
323 323 view.scatter('a', data, block=False)
324 324 ar = view.gather('a', block=False)
325 325 self.assertEquals(ar.get(), data)
326 326
327 327 @skip_without('numpy')
328 328 def test_scatter_gather_numpy_nonblocking(self):
329 329 import numpy
330 330 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
331 331 a = numpy.arange(64)
332 332 view = self.client[:]
333 333 ar = view.scatter('a', a, block=False)
334 334 self.assertTrue(isinstance(ar, AsyncResult))
335 335 amr = view.gather('a', block=False)
336 336 self.assertTrue(isinstance(amr, AsyncMapResult))
337 337 assert_array_equal(amr.get(), a)
338 338
339 339 def test_execute(self):
340 340 view = self.client[:]
341 341 # self.client.debug=True
342 342 execute = view.execute
343 343 ar = execute('c=30', block=False)
344 344 self.assertTrue(isinstance(ar, AsyncResult))
345 345 ar = execute('d=[0,1,2]', block=False)
346 346 self.client.wait(ar, 1)
347 347 self.assertEquals(len(ar.get()), len(self.client))
348 348 for c in view['c']:
349 349 self.assertEquals(c, 30)
350 350
351 351 def test_abort(self):
352 352 view = self.client[-1]
353 353 ar = view.execute('import time; time.sleep(1)', block=False)
354 354 ar2 = view.apply_async(lambda : 2)
355 355 ar3 = view.apply_async(lambda : 3)
356 356 view.abort(ar2)
357 357 view.abort(ar3.msg_ids)
358 358 self.assertRaises(error.TaskAborted, ar2.get)
359 359 self.assertRaises(error.TaskAborted, ar3.get)
360 360
361 361 def test_abort_all(self):
362 362 """view.abort() aborts all outstanding tasks"""
363 363 view = self.client[-1]
364 364 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
365 365 view.abort()
366 366 view.wait(timeout=5)
367 367 for ar in ars[5:]:
368 368 self.assertRaises(error.TaskAborted, ar.get)
369 369
370 370 def test_temp_flags(self):
371 371 view = self.client[-1]
372 372 view.block=True
373 373 with view.temp_flags(block=False):
374 374 self.assertFalse(view.block)
375 375 self.assertTrue(view.block)
376 376
377 377 @dec.known_failure_py3
378 378 def test_importer(self):
379 379 view = self.client[-1]
380 380 view.clear(block=True)
381 381 with view.importer:
382 382 import re
383 383
384 384 @interactive
385 385 def findall(pat, s):
386 386 # this globals() step isn't necessary in real code
387 387 # only to prevent a closure in the test
388 388 re = globals()['re']
389 389 return re.findall(pat, s)
390 390
391 391 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
392 392
393 393 def test_unicode_execute(self):
394 394 """test executing unicode strings"""
395 395 v = self.client[-1]
396 396 v.block=True
397 397 if sys.version_info[0] >= 3:
398 398 code="a='é'"
399 399 else:
400 400 code=u"a=u'é'"
401 401 v.execute(code)
402 402 self.assertEquals(v['a'], u'é')
403 403
404 404 def test_unicode_apply_result(self):
405 405 """test unicode apply results"""
406 406 v = self.client[-1]
407 407 r = v.apply_sync(lambda : u'é')
408 408 self.assertEquals(r, u'é')
409 409
410 410 def test_unicode_apply_arg(self):
411 411 """test passing unicode arguments to apply"""
412 412 v = self.client[-1]
413 413
414 414 @interactive
415 415 def check_unicode(a, check):
416 416 assert isinstance(a, unicode), "%r is not unicode"%a
417 417 assert isinstance(check, bytes), "%r is not bytes"%check
418 418 assert a.encode('utf8') == check, "%s != %s"%(a,check)
419 419
420 420 for s in [ u'é', u'ßø®∫',u'asdf' ]:
421 421 try:
422 422 v.apply_sync(check_unicode, s, s.encode('utf8'))
423 423 except error.RemoteError as e:
424 424 if e.ename == 'AssertionError':
425 425 self.fail(e.evalue)
426 426 else:
427 427 raise e
428 428
429 429 def test_map_reference(self):
430 430 """view.map(<Reference>, *seqs) should work"""
431 431 v = self.client[:]
432 432 v.scatter('n', self.client.ids, flatten=True)
433 433 v.execute("f = lambda x,y: x*y")
434 434 rf = pmod.Reference('f')
435 435 nlist = list(range(10))
436 436 mlist = nlist[::-1]
437 437 expected = [ m*n for m,n in zip(mlist, nlist) ]
438 438 result = v.map_sync(rf, mlist, nlist)
439 439 self.assertEquals(result, expected)
440 440
441 441 def test_apply_reference(self):
442 442 """view.apply(<Reference>, *args) should work"""
443 443 v = self.client[:]
444 444 v.scatter('n', self.client.ids, flatten=True)
445 445 v.execute("f = lambda x: n*x")
446 446 rf = pmod.Reference('f')
447 447 result = v.apply_sync(rf, 5)
448 448 expected = [ 5*id for id in self.client.ids ]
449 449 self.assertEquals(result, expected)
450 450
451 451 def test_eval_reference(self):
452 452 v = self.client[self.client.ids[0]]
453 453 v['g'] = range(5)
454 454 rg = pmod.Reference('g[0]')
455 455 echo = lambda x:x
456 456 self.assertEquals(v.apply_sync(echo, rg), 0)
457 457
458 458 def test_reference_nameerror(self):
459 459 v = self.client[self.client.ids[0]]
460 460 r = pmod.Reference('elvis_has_left')
461 461 echo = lambda x:x
462 462 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
463 463
464 464 def test_single_engine_map(self):
465 465 e0 = self.client[self.client.ids[0]]
466 466 r = range(5)
467 467 check = [ -1*i for i in r ]
468 468 result = e0.map_sync(lambda x: -1*x, r)
469 469 self.assertEquals(result, check)
470 470
471 471 def test_len(self):
472 472 """len(view) makes sense"""
473 473 e0 = self.client[self.client.ids[0]]
474 474 yield self.assertEquals(len(e0), 1)
475 475 v = self.client[:]
476 476 yield self.assertEquals(len(v), len(self.client.ids))
477 477 v = self.client.direct_view('all')
478 478 yield self.assertEquals(len(v), len(self.client.ids))
479 479 v = self.client[:2]
480 480 yield self.assertEquals(len(v), 2)
481 481 v = self.client[:1]
482 482 yield self.assertEquals(len(v), 1)
483 483 v = self.client.load_balanced_view()
484 484 yield self.assertEquals(len(v), len(self.client.ids))
485 485 # parametric tests seem to require manual closing?
486 486 self.client.close()
487 487
488 488
489 489 # begin execute tests
490 490
491 491 def test_execute_reply(self):
492 492 e0 = self.client[self.client.ids[0]]
493 493 e0.block = True
494 494 ar = e0.execute("5", silent=False)
495 495 er = ar.get()
496 496 self.assertEquals(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
497 497 self.assertEquals(er.pyout['data']['text/plain'], '5')
498 498
499 499 def test_execute_reply_stdout(self):
500 500 e0 = self.client[self.client.ids[0]]
501 501 e0.block = True
502 502 ar = e0.execute("print (5)", silent=False)
503 503 er = ar.get()
504 504 self.assertEquals(er.stdout.strip(), '5')
505 505
506 506 def test_execute_pyout(self):
507 507 """execute triggers pyout with silent=False"""
508 508 view = self.client[:]
509 509 ar = view.execute("5", silent=False, block=True)
510 510
511 511 expected = [{'text/plain' : '5'}] * len(view)
512 512 mimes = [ out['data'] for out in ar.pyout ]
513 513 self.assertEquals(mimes, expected)
514 514
515 515 def test_execute_silent(self):
516 516 """execute does not trigger pyout with silent=True"""
517 517 view = self.client[:]
518 518 ar = view.execute("5", block=True)
519 519 expected = [None] * len(view)
520 520 self.assertEquals(ar.pyout, expected)
521 521
522 522 def test_execute_magic(self):
523 523 """execute accepts IPython commands"""
524 524 view = self.client[:]
525 525 view.execute("a = 5")
526 526 ar = view.execute("%whos", block=True)
527 527 # this will raise, if that failed
528 528 ar.get(5)
529 529 for stdout in ar.stdout:
530 530 lines = stdout.splitlines()
531 531 self.assertEquals(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
532 532 found = False
533 533 for line in lines[2:]:
534 534 split = line.split()
535 535 if split == ['a', 'int', '5']:
536 536 found = True
537 537 break
538 538 self.assertTrue(found, "whos output wrong: %s" % stdout)
539 539
540 540 def test_execute_displaypub(self):
541 541 """execute tracks display_pub output"""
542 542 view = self.client[:]
543 543 view.execute("from IPython.core.display import *")
544 544 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
545 545
546 546 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
547 547 for outputs in ar.outputs:
548 548 mimes = [ out['data'] for out in outputs ]
549 549 self.assertEquals(mimes, expected)
550 550
551 551 def test_apply_displaypub(self):
552 552 """apply tracks display_pub output"""
553 553 view = self.client[:]
554 554 view.execute("from IPython.core.display import *")
555 555
556 556 @interactive
557 557 def publish():
558 558 [ display(i) for i in range(5) ]
559 559
560 560 ar = view.apply_async(publish)
561 561 ar.get(5)
562 562 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
563 563 for outputs in ar.outputs:
564 564 mimes = [ out['data'] for out in outputs ]
565 565 self.assertEquals(mimes, expected)
566 566
567 567 def test_execute_raises(self):
568 568 """exceptions in execute requests raise appropriately"""
569 569 view = self.client[-1]
570 570 ar = view.execute("1/0")
571 571 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
572 572
573 573 @dec.skipif_not_matplotlib
574 574 def test_magic_pylab(self):
575 575 """%pylab works on engines"""
576 576 view = self.client[-1]
577 577 ar = view.execute("%pylab inline")
578 578 # at least check if this raised:
579 579 reply = ar.get(5)
580 580 # include imports, in case user config
581 581 ar = view.execute("plot(rand(100))", silent=False)
582 582 reply = ar.get(5)
583 583 self.assertEquals(len(reply.outputs), 1)
584 584 output = reply.outputs[0]
585 585 self.assertTrue("data" in output)
586 586 data = output['data']
587 587 self.assertTrue("image/png" in data)
588 588
589 589
General Comments 0
You need to be logged in to leave comments. Login now