pickle arrays with dtype=object...
MinRK
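A minimal sketch (not part of the commit) of why object-dtype arrays need the pickle fallback this change introduces: the zero-copy buffer path ships raw array memory, but an object array only stores pointers into the sending process, so those bytes are useless on the receiving engine. Plain pickle serializes the elements themselves:

    import pickle
    import numpy

    a = numpy.array([dict(a=5)])        # dtype=object: elements are Python objects
    # a raw buffer/memoryview of `a` would only carry pointer values,
    # so arrays like this are pickled instead:
    b = pickle.loads(pickle.dumps(a, -1))
    assert b[0] == dict(a=5)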
@@ -1,835 +1,850 @@
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import base64
20 20 import sys
21 21 import platform
22 22 import time
23 23 from collections import namedtuple
24 24 from tempfile import mktemp
25 25
26 26 import zmq
27 27 from nose.plugins.attrib import attr
28 28
29 29 from IPython.testing import decorators as dec
30 30 from IPython.utils.io import capture_output
31 31 from IPython.utils.py3compat import unicode_type
32 32
33 33 from IPython import parallel as pmod
34 34 from IPython.parallel import error
35 35 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
36 36 from IPython.parallel.util import interactive
37 37
38 38 from IPython.parallel.tests import add_engines
39 39
40 40 from .clienttest import ClusterTestCase, crash, wait, skip_without
41 41
42 42 def setup():
43 43 add_engines(3, total=True)
44 44
45 45 point = namedtuple("point", "x y")
46 46
47 47 class TestView(ClusterTestCase):
48 48
49 49 def setUp(self):
50 50 # On Win XP, wait for resource cleanup, else parallel test group fails
51 51 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
52 52 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
53 53 time.sleep(2)
54 54 super(TestView, self).setUp()
55 55
56 56 @attr('crash')
57 57 def test_z_crash_mux(self):
58 58 """test graceful handling of engine death (direct)"""
59 59 # self.add_engines(1)
60 60 eid = self.client.ids[-1]
61 61 ar = self.client[eid].apply_async(crash)
62 62 self.assertRaisesRemote(error.EngineError, ar.get, 10)
63 63 eid = ar.engine_id
64 64 tic = time.time()
65 65 while eid in self.client.ids and time.time()-tic < 5:
66 66 time.sleep(.01)
67 67 self.client.spin()
68 68 self.assertFalse(eid in self.client.ids, "Engine should have died")
69 69
70 70 def test_push_pull(self):
71 71 """test pushing and pulling"""
72 72 data = dict(a=10, b=1.05, c=list(range(10)), d={'e':(1,2),'f':'hi'})
73 73 t = self.client.ids[-1]
74 74 v = self.client[t]
75 75 push = v.push
76 76 pull = v.pull
77 77 v.block=True
78 78 nengines = len(self.client)
79 79 push({'data':data})
80 80 d = pull('data')
81 81 self.assertEqual(d, data)
82 82 self.client[:].push({'data':data})
83 83 d = self.client[:].pull('data', block=True)
84 84 self.assertEqual(d, nengines*[data])
85 85 ar = push({'data':data}, block=False)
86 86 self.assertTrue(isinstance(ar, AsyncResult))
87 87 r = ar.get()
88 88 ar = self.client[:].pull('data', block=False)
89 89 self.assertTrue(isinstance(ar, AsyncResult))
90 90 r = ar.get()
91 91 self.assertEqual(r, nengines*[data])
92 92 self.client[:].push(dict(a=10,b=20))
93 93 r = self.client[:].pull(('a','b'), block=True)
94 94 self.assertEqual(r, nengines*[[10,20]])
95 95
96 96 def test_push_pull_function(self):
97 97 "test pushing and pulling functions"
98 98 def testf(x):
99 99 return 2.0*x
100 100
101 101 t = self.client.ids[-1]
102 102 v = self.client[t]
103 103 v.block=True
104 104 push = v.push
105 105 pull = v.pull
106 106 execute = v.execute
107 107 push({'testf':testf})
108 108 r = pull('testf')
109 109 self.assertEqual(r(1.0), testf(1.0))
110 110 execute('r = testf(10)')
111 111 r = pull('r')
112 112 self.assertEqual(r, testf(10))
113 113 ar = self.client[:].push({'testf':testf}, block=False)
114 114 ar.get()
115 115 ar = self.client[:].pull('testf', block=False)
116 116 rlist = ar.get()
117 117 for r in rlist:
118 118 self.assertEqual(r(1.0), testf(1.0))
119 119 execute("def g(x): return x*x")
120 120 r = pull(('testf','g'))
121 121 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
122 122
123 123 def test_push_function_globals(self):
124 124 """test that pushed functions have access to globals"""
125 125 @interactive
126 126 def geta():
127 127 return a
128 128 # self.add_engines(1)
129 129 v = self.client[-1]
130 130 v.block=True
131 131 v['f'] = geta
132 132 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
133 133 v.execute('a=5')
134 134 v.execute('b=f()')
135 135 self.assertEqual(v['b'], 5)
136 136
137 137 def test_push_function_defaults(self):
138 138 """test that pushed functions preserve default args"""
139 139 def echo(a=10):
140 140 return a
141 141 v = self.client[-1]
142 142 v.block=True
143 143 v['f'] = echo
144 144 v.execute('b=f()')
145 145 self.assertEqual(v['b'], 10)
146 146
147 147 def test_get_result(self):
148 148 """test getting results from the Hub."""
149 149 c = pmod.Client(profile='iptest')
150 150 # self.add_engines(1)
151 151 t = c.ids[-1]
152 152 v = c[t]
153 153 v2 = self.client[t]
154 154 ar = v.apply_async(wait, 1)
155 155 # give the monitor time to notice the message
156 156 time.sleep(.25)
157 157 ahr = v2.get_result(ar.msg_ids[0])
158 158 self.assertTrue(isinstance(ahr, AsyncHubResult))
159 159 self.assertEqual(ahr.get(), ar.get())
160 160 ar2 = v2.get_result(ar.msg_ids[0])
161 161 self.assertFalse(isinstance(ar2, AsyncHubResult))
162 162 c.spin()
163 163 c.close()
164 164
165 165 def test_run_newline(self):
166 166 """test that run appends newline to files"""
167 167 tmpfile = mktemp()
168 168 with open(tmpfile, 'w') as f:
169 169 f.write("""def g():
170 170 return 5
171 171 """)
172 172 v = self.client[-1]
173 173 v.run(tmpfile, block=True)
174 174 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
175 175
176 176 def test_apply_tracked(self):
177 177 """test tracking for apply"""
178 178 # self.add_engines(1)
179 179 t = self.client.ids[-1]
180 180 v = self.client[t]
181 181 v.block=False
182 182 def echo(n=1024*1024, **kwargs):
183 183 with v.temp_flags(**kwargs):
184 184 return v.apply(lambda x: x, 'x'*n)
185 185 ar = echo(1, track=False)
186 186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 187 self.assertTrue(ar.sent)
188 188 ar = echo(track=True)
189 189 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
190 190 self.assertEqual(ar.sent, ar._tracker.done)
191 191 ar._tracker.wait()
192 192 self.assertTrue(ar.sent)
193 193
194 194 def test_push_tracked(self):
195 195 t = self.client.ids[-1]
196 196 ns = dict(x='x'*1024*1024)
197 197 v = self.client[t]
198 198 ar = v.push(ns, block=False, track=False)
199 199 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
200 200 self.assertTrue(ar.sent)
201 201
202 202 ar = v.push(ns, block=False, track=True)
203 203 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
204 204 ar._tracker.wait()
205 205 self.assertEqual(ar.sent, ar._tracker.done)
206 206 self.assertTrue(ar.sent)
207 207 ar.get()
208 208
209 209 def test_scatter_tracked(self):
210 210 t = self.client.ids
211 211 x='x'*1024*1024
212 212 ar = self.client[t].scatter('x', x, block=False, track=False)
213 213 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
214 214 self.assertTrue(ar.sent)
215 215
216 216 ar = self.client[t].scatter('x', x, block=False, track=True)
217 217 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
218 218 self.assertEqual(ar.sent, ar._tracker.done)
219 219 ar._tracker.wait()
220 220 self.assertTrue(ar.sent)
221 221 ar.get()
222 222
223 223 def test_remote_reference(self):
224 224 v = self.client[-1]
225 225 v['a'] = 123
226 226 ra = pmod.Reference('a')
227 227 b = v.apply_sync(lambda x: x, ra)
228 228 self.assertEqual(b, 123)
229 229
230 230
231 231 def test_scatter_gather(self):
232 232 view = self.client[:]
233 233 seq1 = list(range(16))
234 234 view.scatter('a', seq1)
235 235 seq2 = view.gather('a', block=True)
236 236 self.assertEqual(seq2, seq1)
237 237 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
238 238
239 239 @skip_without('numpy')
240 240 def test_scatter_gather_numpy(self):
241 241 import numpy
242 242 from numpy.testing.utils import assert_array_equal
243 243 view = self.client[:]
244 244 a = numpy.arange(64)
245 245 view.scatter('a', a, block=True)
246 246 b = view.gather('a', block=True)
247 247 assert_array_equal(b, a)
248 248
249 249 def test_scatter_gather_lazy(self):
250 250 """scatter/gather with targets='all'"""
251 251 view = self.client.direct_view(targets='all')
252 252 x = list(range(64))
253 253 view.scatter('x', x)
254 254 gathered = view.gather('x', block=True)
255 255 self.assertEqual(gathered, x)
256 256
257 257
258 258 @dec.known_failure_py3
259 259 @skip_without('numpy')
260 260 def test_push_numpy_nocopy(self):
261 261 import numpy
262 262 view = self.client[:]
263 263 a = numpy.arange(64)
264 264 view['A'] = a
265 265 @interactive
266 266 def check_writeable(x):
267 267 return x.flags.writeable
268 268
269 269 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
270 270 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
271 271
272 272 view.push(dict(B=a))
273 273 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
274 274 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
275 275
276 276 @skip_without('numpy')
277 277 def test_apply_numpy(self):
278 278 """view.apply(f, ndarray)"""
279 279 import numpy
280 280 from numpy.testing.utils import assert_array_equal
281 281
282 282 A = numpy.random.random((100,100))
283 283 view = self.client[-1]
284 284 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
285 285 B = A.astype(dt)
286 286 C = view.apply_sync(lambda x:x, B)
287 287 assert_array_equal(B,C)
288 288
289 289 @skip_without('numpy')
290 def test_apply_numpy_object_dtype(self):
291 """view.apply(f, ndarray) with dtype=object"""
292 import numpy
293 from numpy.testing.utils import assert_array_equal
294 view = self.client[-1]
295
296 A = numpy.array([dict(a=5)])
297 B = view.apply_sync(lambda x:x, A)
298 assert_array_equal(A,B)
299
300 A = numpy.array([(0, dict(b=10))], dtype=[('i', int), ('o', object)])
301 B = view.apply_sync(lambda x:x, A)
302 assert_array_equal(A,B)
303
304 @skip_without('numpy')
290 305 def test_push_pull_recarray(self):
291 306 """push/pull recarrays"""
292 307 import numpy
293 308 from numpy.testing.utils import assert_array_equal
294 309
295 310 view = self.client[-1]
296 311
297 312 R = numpy.array([
298 313 (1, 'hi', 0.),
299 314 (2**30, 'there', 2.5),
300 315 (-99999, 'world', -12345.6789),
301 316 ], [('n', int), ('s', '|S10'), ('f', float)])
302 317
303 318 view['RR'] = R
304 319 R2 = view['RR']
305 320
306 321 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
307 322 self.assertEqual(r_dtype, R.dtype)
308 323 self.assertEqual(r_shape, R.shape)
309 324 self.assertEqual(R2.dtype, R.dtype)
310 325 self.assertEqual(R2.shape, R.shape)
311 326 assert_array_equal(R2, R)
312 327
313 328 @skip_without('pandas')
314 329 def test_push_pull_timeseries(self):
315 330 """push/pull pandas.TimeSeries"""
316 331 import pandas
317 332
318 333 ts = pandas.TimeSeries(list(range(10)))
319 334
320 335 view = self.client[-1]
321 336
322 337 view.push(dict(ts=ts), block=True)
323 338 rts = view['ts']
324 339
325 340 self.assertEqual(type(rts), type(ts))
326 341 self.assertTrue((ts == rts).all())
327 342
328 343 def test_map(self):
329 344 view = self.client[:]
330 345 def f(x):
331 346 return x**2
332 347 data = list(range(16))
333 348 r = view.map_sync(f, data)
334 349 self.assertEqual(r, list(map(f, data)))
335 350
336 351 def test_map_iterable(self):
337 352 """test map on iterables (direct)"""
338 353 view = self.client[:]
339 354 # 101 is prime, so it won't be evenly distributed
340 355 arr = range(101)
341 356 # ensure it will be an iterator, even in Python 3
342 357 it = iter(arr)
343 358 r = view.map_sync(lambda x: x, it)
344 359 self.assertEqual(r, list(arr))
345 360
346 361 @skip_without('numpy')
347 362 def test_map_numpy(self):
348 363 """test map on numpy arrays (direct)"""
349 364 import numpy
350 365 from numpy.testing.utils import assert_array_equal
351 366
352 367 view = self.client[:]
353 368 # 101 is prime, so it won't be evenly distributed
354 369 arr = numpy.arange(101)
355 370 r = view.map_sync(lambda x: x, arr)
356 371 assert_array_equal(r, arr)
357 372
358 373 def test_scatter_gather_nonblocking(self):
359 374 data = list(range(16))
360 375 view = self.client[:]
361 376 view.scatter('a', data, block=False)
362 377 ar = view.gather('a', block=False)
363 378 self.assertEqual(ar.get(), data)
364 379
365 380 @skip_without('numpy')
366 381 def test_scatter_gather_numpy_nonblocking(self):
367 382 import numpy
368 383 from numpy.testing.utils import assert_array_equal
369 384 a = numpy.arange(64)
370 385 view = self.client[:]
371 386 ar = view.scatter('a', a, block=False)
372 387 self.assertTrue(isinstance(ar, AsyncResult))
373 388 amr = view.gather('a', block=False)
374 389 self.assertTrue(isinstance(amr, AsyncMapResult))
375 390 assert_array_equal(amr.get(), a)
376 391
377 392 def test_execute(self):
378 393 view = self.client[:]
379 394 # self.client.debug=True
380 395 execute = view.execute
381 396 ar = execute('c=30', block=False)
382 397 self.assertTrue(isinstance(ar, AsyncResult))
383 398 ar = execute('d=[0,1,2]', block=False)
384 399 self.client.wait(ar, 1)
385 400 self.assertEqual(len(ar.get()), len(self.client))
386 401 for c in view['c']:
387 402 self.assertEqual(c, 30)
388 403
389 404 def test_abort(self):
390 405 view = self.client[-1]
391 406 ar = view.execute('import time; time.sleep(1)', block=False)
392 407 ar2 = view.apply_async(lambda : 2)
393 408 ar3 = view.apply_async(lambda : 3)
394 409 view.abort(ar2)
395 410 view.abort(ar3.msg_ids)
396 411 self.assertRaises(error.TaskAborted, ar2.get)
397 412 self.assertRaises(error.TaskAborted, ar3.get)
398 413
399 414 def test_abort_all(self):
400 415 """view.abort() aborts all outstanding tasks"""
401 416 view = self.client[-1]
402 417 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
403 418 view.abort()
404 419 view.wait(timeout=5)
405 420 for ar in ars[5:]:
406 421 self.assertRaises(error.TaskAborted, ar.get)
407 422
408 423 def test_temp_flags(self):
409 424 view = self.client[-1]
410 425 view.block=True
411 426 with view.temp_flags(block=False):
412 427 self.assertFalse(view.block)
413 428 self.assertTrue(view.block)
414 429
415 430 @dec.known_failure_py3
416 431 def test_importer(self):
417 432 view = self.client[-1]
418 433 view.clear(block=True)
419 434 with view.importer:
420 435 import re
421 436
422 437 @interactive
423 438 def findall(pat, s):
424 439 # this globals() step isn't necessary in real code
425 440 # only to prevent a closure in the test
426 441 re = globals()['re']
427 442 return re.findall(pat, s)
428 443
429 444 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
430 445
431 446 def test_unicode_execute(self):
432 447 """test executing unicode strings"""
433 448 v = self.client[-1]
434 449 v.block=True
435 450 if sys.version_info[0] >= 3:
436 451 code="a='é'"
437 452 else:
438 453 code=u"a=u'é'"
439 454 v.execute(code)
440 455 self.assertEqual(v['a'], u'é')
441 456
442 457 def test_unicode_apply_result(self):
443 458 """test unicode apply results"""
444 459 v = self.client[-1]
445 460 r = v.apply_sync(lambda : u'é')
446 461 self.assertEqual(r, u'é')
447 462
448 463 def test_unicode_apply_arg(self):
449 464 """test passing unicode arguments to apply"""
450 465 v = self.client[-1]
451 466
452 467 @interactive
453 468 def check_unicode(a, check):
454 469 assert not isinstance(a, bytes), "%r is bytes, not unicode"%a
455 470 assert isinstance(check, bytes), "%r is not bytes"%check
456 471 assert a.encode('utf8') == check, "%s != %s"%(a,check)
457 472
458 473 for s in [ u'é', u'ßø®∫',u'asdf' ]:
459 474 try:
460 475 v.apply_sync(check_unicode, s, s.encode('utf8'))
461 476 except error.RemoteError as e:
462 477 if e.ename == 'AssertionError':
463 478 self.fail(e.evalue)
464 479 else:
465 480 raise e
466 481
467 482 def test_map_reference(self):
468 483 """view.map(<Reference>, *seqs) should work"""
469 484 v = self.client[:]
470 485 v.scatter('n', self.client.ids, flatten=True)
471 486 v.execute("f = lambda x,y: x*y")
472 487 rf = pmod.Reference('f')
473 488 nlist = list(range(10))
474 489 mlist = nlist[::-1]
475 490 expected = [ m*n for m,n in zip(mlist, nlist) ]
476 491 result = v.map_sync(rf, mlist, nlist)
477 492 self.assertEqual(result, expected)
478 493
479 494 def test_apply_reference(self):
480 495 """view.apply(<Reference>, *args) should work"""
481 496 v = self.client[:]
482 497 v.scatter('n', self.client.ids, flatten=True)
483 498 v.execute("f = lambda x: n*x")
484 499 rf = pmod.Reference('f')
485 500 result = v.apply_sync(rf, 5)
486 501 expected = [ 5*id for id in self.client.ids ]
487 502 self.assertEqual(result, expected)
488 503
489 504 def test_eval_reference(self):
490 505 v = self.client[self.client.ids[0]]
491 506 v['g'] = list(range(5))
492 507 rg = pmod.Reference('g[0]')
493 508 echo = lambda x:x
494 509 self.assertEqual(v.apply_sync(echo, rg), 0)
495 510
496 511 def test_reference_nameerror(self):
497 512 v = self.client[self.client.ids[0]]
498 513 r = pmod.Reference('elvis_has_left')
499 514 echo = lambda x:x
500 515 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
501 516
502 517 def test_single_engine_map(self):
503 518 e0 = self.client[self.client.ids[0]]
504 519 r = list(range(5))
505 520 check = [ -1*i for i in r ]
506 521 result = e0.map_sync(lambda x: -1*x, r)
507 522 self.assertEqual(result, check)
508 523
509 524 def test_len(self):
510 525 """len(view) makes sense"""
511 526 e0 = self.client[self.client.ids[0]]
512 527 self.assertEqual(len(e0), 1)
513 528 v = self.client[:]
514 529 self.assertEqual(len(v), len(self.client.ids))
515 530 v = self.client.direct_view('all')
516 531 self.assertEqual(len(v), len(self.client.ids))
517 532 v = self.client[:2]
518 533 self.assertEqual(len(v), 2)
519 534 v = self.client[:1]
520 535 self.assertEqual(len(v), 1)
521 536 v = self.client.load_balanced_view()
522 537 self.assertEqual(len(v), len(self.client.ids))
523 538
524 539
525 540 # begin execute tests
526 541
527 542 def test_execute_reply(self):
528 543 e0 = self.client[self.client.ids[0]]
529 544 e0.block = True
530 545 ar = e0.execute("5", silent=False)
531 546 er = ar.get()
532 547 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
533 548 self.assertEqual(er.pyout['data']['text/plain'], '5')
534 549
535 550 def test_execute_reply_rich(self):
536 551 e0 = self.client[self.client.ids[0]]
537 552 e0.block = True
538 553 e0.execute("from IPython.display import Image, HTML")
539 554 ar = e0.execute("Image(data=b'garbage', format='png', width=10)", silent=False)
540 555 er = ar.get()
541 556 b64data = base64.encodestring(b'garbage').decode('ascii')
542 557 self.assertEqual(er._repr_png_(), (b64data, dict(width=10)))
543 558 ar = e0.execute("HTML('<b>bold</b>')", silent=False)
544 559 er = ar.get()
545 560 self.assertEqual(er._repr_html_(), "<b>bold</b>")
546 561
547 562 def test_execute_reply_stdout(self):
548 563 e0 = self.client[self.client.ids[0]]
549 564 e0.block = True
550 565 ar = e0.execute("print (5)", silent=False)
551 566 er = ar.get()
552 567 self.assertEqual(er.stdout.strip(), '5')
553 568
554 569 def test_execute_pyout(self):
555 570 """execute triggers pyout with silent=False"""
556 571 view = self.client[:]
557 572 ar = view.execute("5", silent=False, block=True)
558 573
559 574 expected = [{'text/plain' : '5'}] * len(view)
560 575 mimes = [ out['data'] for out in ar.pyout ]
561 576 self.assertEqual(mimes, expected)
562 577
563 578 def test_execute_silent(self):
564 579 """execute does not trigger pyout with silent=True"""
565 580 view = self.client[:]
566 581 ar = view.execute("5", block=True)
567 582 expected = [None] * len(view)
568 583 self.assertEqual(ar.pyout, expected)
569 584
570 585 def test_execute_magic(self):
571 586 """execute accepts IPython commands"""
572 587 view = self.client[:]
573 588 view.execute("a = 5")
574 589 ar = view.execute("%whos", block=True)
575 590 # this will raise, if that failed
576 591 ar.get(5)
577 592 for stdout in ar.stdout:
578 593 lines = stdout.splitlines()
579 594 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
580 595 found = False
581 596 for line in lines[2:]:
582 597 split = line.split()
583 598 if split == ['a', 'int', '5']:
584 599 found = True
585 600 break
586 601 self.assertTrue(found, "whos output wrong: %s" % stdout)
587 602
588 603 def test_execute_displaypub(self):
589 604 """execute tracks display_pub output"""
590 605 view = self.client[:]
591 606 view.execute("from IPython.core.display import *")
592 607 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
593 608
594 609 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
595 610 for outputs in ar.outputs:
596 611 mimes = [ out['data'] for out in outputs ]
597 612 self.assertEqual(mimes, expected)
598 613
599 614 def test_apply_displaypub(self):
600 615 """apply tracks display_pub output"""
601 616 view = self.client[:]
602 617 view.execute("from IPython.core.display import *")
603 618
604 619 @interactive
605 620 def publish():
606 621 [ display(i) for i in range(5) ]
607 622
608 623 ar = view.apply_async(publish)
609 624 ar.get(5)
610 625 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
611 626 for outputs in ar.outputs:
612 627 mimes = [ out['data'] for out in outputs ]
613 628 self.assertEqual(mimes, expected)
614 629
615 630 def test_execute_raises(self):
616 631 """exceptions in execute requests raise appropriately"""
617 632 view = self.client[-1]
618 633 ar = view.execute("1/0")
619 634 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
620 635
621 636 def test_remoteerror_render_exception(self):
622 637 """RemoteErrors get nice tracebacks"""
623 638 view = self.client[-1]
624 639 ar = view.execute("1/0")
625 640 ip = get_ipython()
626 641 ip.user_ns['ar'] = ar
627 642 with capture_output() as io:
628 643 ip.run_cell("ar.get(2)")
629 644
630 645 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
631 646
632 647 def test_compositeerror_render_exception(self):
633 648 """CompositeErrors get nice tracebacks"""
634 649 view = self.client[:]
635 650 ar = view.execute("1/0")
636 651 ip = get_ipython()
637 652 ip.user_ns['ar'] = ar
638 653
639 654 with capture_output() as io:
640 655 ip.run_cell("ar.get(2)")
641 656
642 657 count = min(error.CompositeError.tb_limit, len(view))
643 658
644 659 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
645 660 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
646 661 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
647 662
648 663 def test_compositeerror_truncate(self):
649 664 """Truncate CompositeErrors with many exceptions"""
650 665 view = self.client[:]
651 666 msg_ids = []
652 667 for i in range(10):
653 668 ar = view.execute("1/0")
654 669 msg_ids.extend(ar.msg_ids)
655 670
656 671 ar = self.client.get_result(msg_ids)
657 672 try:
658 673 ar.get()
659 674 except error.CompositeError as _e:
660 675 e = _e
661 676 else:
662 677 self.fail("Should have raised CompositeError")
663 678
664 679 lines = e.render_traceback()
665 680 with capture_output() as io:
666 681 e.print_traceback()
667 682
668 683 self.assertTrue("more exceptions" in lines[-1])
669 684 count = e.tb_limit
670 685
671 686 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
672 687 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
673 688 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
674 689
675 690 @dec.skipif_not_matplotlib
676 691 def test_magic_pylab(self):
677 692 """%pylab works on engines"""
678 693 view = self.client[-1]
679 694 ar = view.execute("%pylab inline")
680 695 # at least check if this raised:
681 696 reply = ar.get(5)
682 697 # include imports, in case user config
683 698 ar = view.execute("plot(rand(100))", silent=False)
684 699 reply = ar.get(5)
685 700 self.assertEqual(len(reply.outputs), 1)
686 701 output = reply.outputs[0]
687 702 self.assertTrue("data" in output)
688 703 data = output['data']
689 704 self.assertTrue("image/png" in data)
690 705
691 706 def test_func_default_func(self):
692 707 """interactively defined function as apply func default"""
693 708 def foo():
694 709 return 'foo'
695 710
696 711 def bar(f=foo):
697 712 return f()
698 713
699 714 view = self.client[-1]
700 715 ar = view.apply_async(bar)
701 716 r = ar.get(10)
702 717 self.assertEqual(r, 'foo')
703 718 def test_data_pub_single(self):
704 719 view = self.client[-1]
705 720 ar = view.execute('\n'.join([
706 721 'from IPython.kernel.zmq.datapub import publish_data',
707 722 'for i in range(5):',
708 723 ' publish_data(dict(i=i))'
709 724 ]), block=False)
710 725 self.assertTrue(isinstance(ar.data, dict))
711 726 ar.get(5)
712 727 self.assertEqual(ar.data, dict(i=4))
713 728
714 729 def test_data_pub(self):
715 730 view = self.client[:]
716 731 ar = view.execute('\n'.join([
717 732 'from IPython.kernel.zmq.datapub import publish_data',
718 733 'for i in range(5):',
719 734 ' publish_data(dict(i=i))'
720 735 ]), block=False)
721 736 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
722 737 ar.get(5)
723 738 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
724 739
725 740 def test_can_list_arg(self):
726 741 """args in lists are canned"""
727 742 view = self.client[-1]
728 743 view['a'] = 128
729 744 rA = pmod.Reference('a')
730 745 ar = view.apply_async(lambda x: x, [rA])
731 746 r = ar.get(5)
732 747 self.assertEqual(r, [128])
733 748
734 749 def test_can_dict_arg(self):
735 750 """args in dicts are canned"""
736 751 view = self.client[-1]
737 752 view['a'] = 128
738 753 rA = pmod.Reference('a')
739 754 ar = view.apply_async(lambda x: x, dict(foo=rA))
740 755 r = ar.get(5)
741 756 self.assertEqual(r, dict(foo=128))
742 757
743 758 def test_can_list_kwarg(self):
744 759 """kwargs in lists are canned"""
745 760 view = self.client[-1]
746 761 view['a'] = 128
747 762 rA = pmod.Reference('a')
748 763 ar = view.apply_async(lambda x=5: x, x=[rA])
749 764 r = ar.get(5)
750 765 self.assertEqual(r, [128])
751 766
752 767 def test_can_dict_kwarg(self):
753 768 """kwargs in dicts are canned"""
754 769 view = self.client[-1]
755 770 view['a'] = 128
756 771 rA = pmod.Reference('a')
757 772 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
758 773 r = ar.get(5)
759 774 self.assertEqual(r, dict(foo=128))
760 775
761 776 def test_map_ref(self):
762 777 """view.map works with references"""
763 778 view = self.client[:]
764 779 ranks = sorted(self.client.ids)
765 780 view.scatter('rank', ranks, flatten=True)
766 781 rrank = pmod.Reference('rank')
767 782
768 783 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
769 784 drank = amr.get(5)
770 785 self.assertEqual(drank, [ r*2 for r in ranks ])
771 786
772 787 def test_nested_getitem_setitem(self):
773 788 """get and set with view['a.b']"""
774 789 view = self.client[-1]
775 790 view.execute('\n'.join([
776 791 'class A(object): pass',
777 792 'a = A()',
778 793 'a.b = 128',
779 794 ]), block=True)
780 795 ra = pmod.Reference('a')
781 796
782 797 r = view.apply_sync(lambda x: x.b, ra)
783 798 self.assertEqual(r, 128)
784 799 self.assertEqual(view['a.b'], 128)
785 800
786 801 view['a.b'] = 0
787 802
788 803 r = view.apply_sync(lambda x: x.b, ra)
789 804 self.assertEqual(r, 0)
790 805 self.assertEqual(view['a.b'], 0)
791 806
792 807 def test_return_namedtuple(self):
793 808 def namedtuplify(x, y):
794 809 from IPython.parallel.tests.test_view import point
795 810 return point(x, y)
796 811
797 812 view = self.client[-1]
798 813 p = view.apply_sync(namedtuplify, 1, 2)
799 814 self.assertEqual(p.x, 1)
800 815 self.assertEqual(p.y, 2)
801 816
802 817 def test_apply_namedtuple(self):
803 818 def echoxy(p):
804 819 return p.y, p.x
805 820
806 821 view = self.client[-1]
807 822 tup = view.apply_sync(echoxy, point(1, 2))
808 823 self.assertEqual(tup, (2,1))
809 824
810 825 def test_sync_imports(self):
811 826 view = self.client[-1]
812 827 with capture_output() as io:
813 828 with view.sync_imports():
814 829 import IPython
815 830 self.assertIn("IPython", io.stdout)
816 831
817 832 @interactive
818 833 def find_ipython():
819 834 return 'IPython' in globals()
820 835
821 836 assert view.apply_sync(find_ipython)
822 837
823 838 def test_sync_imports_quiet(self):
824 839 view = self.client[-1]
825 840 with capture_output() as io:
826 841 with view.sync_imports(quiet=True):
827 842 import IPython
828 843 self.assertEqual(io.stdout, '')
829 844
830 845 @interactive
831 846 def find_ipython():
832 847 return 'IPython' in globals()
833 848
834 849 assert view.apply_sync(find_ipython)
835 850
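The new test_apply_numpy_object_dtype test above round-trips object-dtype and structured arrays through apply_sync. A hedged usage sketch outside the test suite: it assumes a running cluster (e.g. started with `ipcluster start -n 2`) and the default profile rather than the 'iptest' profile the tests use.

    from IPython import parallel
    import numpy

    rc = parallel.Client()             # connect to the running controller (assumes default profile)
    view = rc[-1]                      # single-engine DirectView, as in the test

    A = numpy.array([dict(a=5)])       # dtype=object now survives apply/push
    B = view.apply_sync(lambda x: x, A)
    numpy.testing.assert_array_equal(A, B)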
@@ -1,382 +1,390 @@
1 1 # encoding: utf-8
2 2
3 3 """Pickle related utilities. Perhaps this should be called 'can'."""
4 4
5 5 __docformat__ = "restructuredtext en"
6 6
7 7 #-------------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-------------------------------------------------------------------------------
13 13
14 14 #-------------------------------------------------------------------------------
15 15 # Imports
16 16 #-------------------------------------------------------------------------------
17 17
18 18 import copy
19 19 import logging
20 20 import sys
21 21 from types import FunctionType
22 22
23 23 try:
24 24 import cPickle as pickle
25 25 except ImportError:
26 26 import pickle
27 27
28 28 from . import codeutil # This registers a hook when it's imported
29 29 from . import py3compat
30 30 from .importstring import import_item
31 31 from .py3compat import string_types, iteritems
32 32
33 33 from IPython.config import Application
34 34
35 35 if py3compat.PY3:
36 36 buffer = memoryview
37 37 class_type = type
38 38 else:
39 39 from types import ClassType
40 40 class_type = (type, ClassType)
41 41
42 42 #-------------------------------------------------------------------------------
43 43 # Functions
44 44 #-------------------------------------------------------------------------------
45 45
46 46
47 47 def use_dill():
48 48 """use dill to expand serialization support
49 49
50 50 adds support for object methods and closures to serialization.
51 51 """
52 52 # import dill causes most of the magic
53 53 import dill
54 54
55 55 # dill doesn't work with cPickle,
56 56 # tell the two relevant modules to use plain pickle
57 57
58 58 global pickle
59 59 pickle = dill
60 60
61 61 try:
62 62 from IPython.kernel.zmq import serialize
63 63 except ImportError:
64 64 pass
65 65 else:
66 66 serialize.pickle = dill
67 67
68 68 # disable special function handling, let dill take care of it
69 69 can_map.pop(FunctionType, None)
70 70
71 71
72 72 #-------------------------------------------------------------------------------
73 73 # Classes
74 74 #-------------------------------------------------------------------------------
75 75
76 76
77 77 class CannedObject(object):
78 78 def __init__(self, obj, keys=[], hook=None):
79 79 """can an object for safe pickling
80 80
81 81 Parameters
82 82 ==========
83 83
84 84 obj:
85 85 The object to be canned
86 86 keys: list (optional)
87 87 list of attribute names that will be explicitly canned / uncanned
88 88 hook: callable (optional)
89 89 An optional extra callable,
90 90 which can do additional processing of the uncanned object.
91 91
92 92 large data may be offloaded into the buffers list,
93 93 used for zero-copy transfers.
94 94 """
95 95 self.keys = keys
96 96 self.obj = copy.copy(obj)
97 97 self.hook = can(hook)
98 98 for key in keys:
99 99 setattr(self.obj, key, can(getattr(obj, key)))
100 100
101 101 self.buffers = []
102 102
103 103 def get_object(self, g=None):
104 104 if g is None:
105 105 g = {}
106 106 obj = self.obj
107 107 for key in self.keys:
108 108 setattr(obj, key, uncan(getattr(obj, key), g))
109 109
110 110 if self.hook:
111 111 self.hook = uncan(self.hook, g)
112 112 self.hook(obj, g)
113 113 return self.obj
114 114
115 115
116 116 class Reference(CannedObject):
117 117 """object for wrapping a remote reference by name."""
118 118 def __init__(self, name):
119 119 if not isinstance(name, string_types):
120 120 raise TypeError("illegal name: %r"%name)
121 121 self.name = name
122 122 self.buffers = []
123 123
124 124 def __repr__(self):
125 125 return "<Reference: %r>"%self.name
126 126
127 127 def get_object(self, g=None):
128 128 if g is None:
129 129 g = {}
130 130
131 131 return eval(self.name, g)
132 132
133 133
134 134 class CannedFunction(CannedObject):
135 135
136 136 def __init__(self, f):
137 137 self._check_type(f)
138 138 self.code = f.__code__
139 139 if f.__defaults__:
140 140 self.defaults = [ can(fd) for fd in f.__defaults__ ]
141 141 else:
142 142 self.defaults = None
143 143 self.module = f.__module__ or '__main__'
144 144 self.__name__ = f.__name__
145 145 self.buffers = []
146 146
147 147 def _check_type(self, obj):
148 148 assert isinstance(obj, FunctionType), "Not a function type"
149 149
150 150 def get_object(self, g=None):
151 151 # try to load function back into its module:
152 152 if not self.module.startswith('__'):
153 153 __import__(self.module)
154 154 g = sys.modules[self.module].__dict__
155 155
156 156 if g is None:
157 157 g = {}
158 158 if self.defaults:
159 159 defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
160 160 else:
161 161 defaults = None
162 162 newFunc = FunctionType(self.code, g, self.__name__, defaults)
163 163 return newFunc
164 164
165 165 class CannedClass(CannedObject):
166 166
167 167 def __init__(self, cls):
168 168 self._check_type(cls)
169 169 self.name = cls.__name__
170 170 self.old_style = not isinstance(cls, type)
171 171 self._canned_dict = {}
172 172 for k,v in cls.__dict__.items():
173 173 if k not in ('__weakref__', '__dict__'):
174 174 self._canned_dict[k] = can(v)
175 175 if self.old_style:
176 176 mro = []
177 177 else:
178 178 mro = cls.mro()
179 179
180 180 self.parents = [ can(c) for c in mro[1:] ]
181 181 self.buffers = []
182 182
183 183 def _check_type(self, obj):
184 184 assert isinstance(obj, class_type), "Not a class type"
185 185
186 186 def get_object(self, g=None):
187 187 parents = tuple(uncan(p, g) for p in self.parents)
188 188 return type(self.name, parents, uncan_dict(self._canned_dict, g=g))
189 189
190 190 class CannedArray(CannedObject):
191 191 def __init__(self, obj):
192 192 from numpy import ascontiguousarray
193 193 self.shape = obj.shape
194 194 self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
195 self.pickled = False
195 196 if sum(obj.shape) == 0:
197 self.pickled = True
198 elif obj.dtype == 'O':
199 # can't handle object dtype with buffer approach
200 self.pickled = True
201 elif obj.dtype.fields and any(dt == 'O' for dt,sz in obj.dtype.fields.values()):
202 self.pickled = True
203 if self.pickled:
196 204 # just pickle it
197 205 self.buffers = [pickle.dumps(obj, -1)]
198 206 else:
199 207 # ensure contiguous
200 208 obj = ascontiguousarray(obj, dtype=None)
201 209 self.buffers = [buffer(obj)]
202 210
203 211 def get_object(self, g=None):
204 212 from numpy import frombuffer
205 213 data = self.buffers[0]
206 if sum(self.shape) == 0:
214 if self.pickled:
207 215 # no shape, we just pickled it
208 216 return pickle.loads(data)
209 217 else:
210 218 return frombuffer(data, dtype=self.dtype).reshape(self.shape)
211 219
212 220
213 221 class CannedBytes(CannedObject):
214 222 wrap = bytes
215 223 def __init__(self, obj):
216 224 self.buffers = [obj]
217 225
218 226 def get_object(self, g=None):
219 227 data = self.buffers[0]
220 228 return self.wrap(data)
221 229
222 230 def CannedBuffer(CannedBytes):
223 231 wrap = buffer
224 232
225 233 #-------------------------------------------------------------------------------
226 234 # Functions
227 235 #-------------------------------------------------------------------------------
228 236
229 237 def _logger():
230 238 """get the logger for the current Application
231 239
232 240 the root logger will be used if no Application is running
233 241 """
234 242 if Application.initialized():
235 243 logger = Application.instance().log
236 244 else:
237 245 logger = logging.getLogger()
238 246 if not logger.handlers:
239 247 logging.basicConfig()
240 248
241 249 return logger
242 250
243 251 def _import_mapping(mapping, original=None):
244 252 """import any string-keys in a type mapping
245 253
246 254 """
247 255 log = _logger()
248 256 log.debug("Importing canning map")
249 257 for key,value in list(mapping.items()):
250 258 if isinstance(key, string_types):
251 259 try:
252 260 cls = import_item(key)
253 261 except Exception:
254 262 if original and key not in original:
255 263 # only message on user-added classes
256 264 log.error("canning class not importable: %r", key, exc_info=True)
257 265 mapping.pop(key)
258 266 else:
259 267 mapping[cls] = mapping.pop(key)
260 268
261 269 def istype(obj, check):
262 270 """like isinstance(obj, check), but strict
263 271
264 272 This won't catch subclasses.
265 273 """
266 274 if isinstance(check, tuple):
267 275 for cls in check:
268 276 if type(obj) is cls:
269 277 return True
270 278 return False
271 279 else:
272 280 return type(obj) is check
273 281
274 282 def can(obj):
275 283 """prepare an object for pickling"""
276 284
277 285 import_needed = False
278 286
279 287 for cls,canner in iteritems(can_map):
280 288 if isinstance(cls, string_types):
281 289 import_needed = True
282 290 break
283 291 elif istype(obj, cls):
284 292 return canner(obj)
285 293
286 294 if import_needed:
287 295 # perform can_map imports, then try again
288 296 # this will usually only happen once
289 297 _import_mapping(can_map, _original_can_map)
290 298 return can(obj)
291 299
292 300 return obj
293 301
294 302 def can_class(obj):
295 303 if isinstance(obj, class_type) and obj.__module__ == '__main__':
296 304 return CannedClass(obj)
297 305 else:
298 306 return obj
299 307
300 308 def can_dict(obj):
301 309 """can the *values* of a dict"""
302 310 if istype(obj, dict):
303 311 newobj = {}
304 312 for k, v in iteritems(obj):
305 313 newobj[k] = can(v)
306 314 return newobj
307 315 else:
308 316 return obj
309 317
310 318 sequence_types = (list, tuple, set)
311 319
312 320 def can_sequence(obj):
313 321 """can the elements of a sequence"""
314 322 if istype(obj, sequence_types):
315 323 t = type(obj)
316 324 return t([can(i) for i in obj])
317 325 else:
318 326 return obj
319 327
320 328 def uncan(obj, g=None):
321 329 """invert canning"""
322 330
323 331 import_needed = False
324 332 for cls,uncanner in iteritems(uncan_map):
325 333 if isinstance(cls, string_types):
326 334 import_needed = True
327 335 break
328 336 elif isinstance(obj, cls):
329 337 return uncanner(obj, g)
330 338
331 339 if import_needed:
332 340 # perform uncan_map imports, then try again
333 341 # this will usually only happen once
334 342 _import_mapping(uncan_map, _original_uncan_map)
335 343 return uncan(obj, g)
336 344
337 345 return obj
338 346
339 347 def uncan_dict(obj, g=None):
340 348 if istype(obj, dict):
341 349 newobj = {}
342 350 for k, v in iteritems(obj):
343 351 newobj[k] = uncan(v,g)
344 352 return newobj
345 353 else:
346 354 return obj
347 355
348 356 def uncan_sequence(obj, g=None):
349 357 if istype(obj, sequence_types):
350 358 t = type(obj)
351 359 return t([uncan(i,g) for i in obj])
352 360 else:
353 361 return obj
354 362
355 363 def _uncan_dependent_hook(dep, g=None):
356 364 dep.check_dependency()
357 365
358 366 def can_dependent(obj):
359 367 return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook)
360 368
361 369 #-------------------------------------------------------------------------------
362 370 # API dictionaries
363 371 #-------------------------------------------------------------------------------
364 372
365 373 # These dicts can be extended for custom serialization of new objects
366 374
367 375 can_map = {
368 376 'IPython.parallel.dependent' : can_dependent,
369 377 'numpy.ndarray' : CannedArray,
370 378 FunctionType : CannedFunction,
371 379 bytes : CannedBytes,
372 380 buffer : CannedBuffer,
373 381 class_type : can_class,
374 382 }
375 383
376 384 uncan_map = {
377 385 CannedObject : lambda obj, g: obj.get_object(g),
378 386 }
379 387
380 388 # for use in _import_mapping:
381 389 _original_can_map = can_map.copy()
382 390 _original_uncan_map = uncan_map.copy()
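A self-contained sketch that mirrors (rather than imports) the dtype checks the commit adds to CannedArray.__init__, deciding between the zero-copy buffer path and the pickle fallback; the helper name must_pickle is ours, not part of the module:

    import numpy

    def must_pickle(a):
        # same three conditions as the new CannedArray.__init__
        if sum(a.shape) == 0:
            return True                   # empty shape: just pickle it
        if a.dtype == 'O':
            return True                   # object dtype: buffers can't carry Python objects
        if a.dtype.fields and any(dt == 'O' for dt, sz in a.dtype.fields.values()):
            return True                   # structured dtype with an object field
        return False

    assert not must_pickle(numpy.arange(4))                     # numeric data: zero-copy buffer
    assert must_pickle(numpy.array([dict(a=5)]))                # object dtype: pickled
    assert must_pickle(numpy.array([(0, {})], dtype=[('i', int), ('o', object)]))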