##// END OF EJS Templates
Parallel: Support get/set of nested objects in view (e.g. dv['a.b'])
Bradley M. Froehle -
Show More
@@ -1,678 +1,696
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import platform
21 21 import time
22 22 from tempfile import mktemp
23 23 from StringIO import StringIO
24 24
25 25 import zmq
26 26 from nose import SkipTest
27 27
28 28 from IPython.testing import decorators as dec
29 29 from IPython.testing.ipunittest import ParametricTestCase
30 30
31 31 from IPython import parallel as pmod
32 32 from IPython.parallel import error
33 33 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
34 34 from IPython.parallel import DirectView
35 35 from IPython.parallel.util import interactive
36 36
37 37 from IPython.parallel.tests import add_engines
38 38
39 39 from .clienttest import ClusterTestCase, crash, wait, skip_without
40 40
41 41 def setup():
42 42 add_engines(3, total=True)
43 43
44 44 class TestView(ClusterTestCase, ParametricTestCase):
45 45
46 46 def setUp(self):
47 47 # On Win XP, wait for resource cleanup, else parallel test group fails
48 48 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
49 49 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
50 50 time.sleep(2)
51 51 super(TestView, self).setUp()
52 52
53 53 def test_z_crash_mux(self):
54 54 """test graceful handling of engine death (direct)"""
55 55 raise SkipTest("crash tests disabled, due to undesirable crash reports")
56 56 # self.add_engines(1)
57 57 eid = self.client.ids[-1]
58 58 ar = self.client[eid].apply_async(crash)
59 59 self.assertRaisesRemote(error.EngineError, ar.get, 10)
60 60 eid = ar.engine_id
61 61 tic = time.time()
62 62 while eid in self.client.ids and time.time()-tic < 5:
63 63 time.sleep(.01)
64 64 self.client.spin()
65 65 self.assertFalse(eid in self.client.ids, "Engine should have died")
66 66
67 67 def test_push_pull(self):
68 68 """test pushing and pulling"""
69 69 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
70 70 t = self.client.ids[-1]
71 71 v = self.client[t]
72 72 push = v.push
73 73 pull = v.pull
74 74 v.block=True
75 75 nengines = len(self.client)
76 76 push({'data':data})
77 77 d = pull('data')
78 78 self.assertEqual(d, data)
79 79 self.client[:].push({'data':data})
80 80 d = self.client[:].pull('data', block=True)
81 81 self.assertEqual(d, nengines*[data])
82 82 ar = push({'data':data}, block=False)
83 83 self.assertTrue(isinstance(ar, AsyncResult))
84 84 r = ar.get()
85 85 ar = self.client[:].pull('data', block=False)
86 86 self.assertTrue(isinstance(ar, AsyncResult))
87 87 r = ar.get()
88 88 self.assertEqual(r, nengines*[data])
89 89 self.client[:].push(dict(a=10,b=20))
90 90 r = self.client[:].pull(('a','b'), block=True)
91 91 self.assertEqual(r, nengines*[[10,20]])
92 92
93 93 def test_push_pull_function(self):
94 94 "test pushing and pulling functions"
95 95 def testf(x):
96 96 return 2.0*x
97 97
98 98 t = self.client.ids[-1]
99 99 v = self.client[t]
100 100 v.block=True
101 101 push = v.push
102 102 pull = v.pull
103 103 execute = v.execute
104 104 push({'testf':testf})
105 105 r = pull('testf')
106 106 self.assertEqual(r(1.0), testf(1.0))
107 107 execute('r = testf(10)')
108 108 r = pull('r')
109 109 self.assertEqual(r, testf(10))
110 110 ar = self.client[:].push({'testf':testf}, block=False)
111 111 ar.get()
112 112 ar = self.client[:].pull('testf', block=False)
113 113 rlist = ar.get()
114 114 for r in rlist:
115 115 self.assertEqual(r(1.0), testf(1.0))
116 116 execute("def g(x): return x*x")
117 117 r = pull(('testf','g'))
118 118 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
119 119
120 120 def test_push_function_globals(self):
121 121 """test that pushed functions have access to globals"""
122 122 @interactive
123 123 def geta():
124 124 return a
125 125 # self.add_engines(1)
126 126 v = self.client[-1]
127 127 v.block=True
128 128 v['f'] = geta
129 129 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
130 130 v.execute('a=5')
131 131 v.execute('b=f()')
132 132 self.assertEqual(v['b'], 5)
133 133
134 134 def test_push_function_defaults(self):
135 135 """test that pushed functions preserve default args"""
136 136 def echo(a=10):
137 137 return a
138 138 v = self.client[-1]
139 139 v.block=True
140 140 v['f'] = echo
141 141 v.execute('b=f()')
142 142 self.assertEqual(v['b'], 10)
143 143
144 144 def test_get_result(self):
145 145 """test getting results from the Hub."""
146 146 c = pmod.Client(profile='iptest')
147 147 # self.add_engines(1)
148 148 t = c.ids[-1]
149 149 v = c[t]
150 150 v2 = self.client[t]
151 151 ar = v.apply_async(wait, 1)
152 152 # give the monitor time to notice the message
153 153 time.sleep(.25)
154 154 ahr = v2.get_result(ar.msg_ids)
155 155 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 156 self.assertEqual(ahr.get(), ar.get())
157 157 ar2 = v2.get_result(ar.msg_ids)
158 158 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 159 c.spin()
160 160 c.close()
161 161
162 162 def test_run_newline(self):
163 163 """test that run appends newline to files"""
164 164 tmpfile = mktemp()
165 165 with open(tmpfile, 'w') as f:
166 166 f.write("""def g():
167 167 return 5
168 168 """)
169 169 v = self.client[-1]
170 170 v.run(tmpfile, block=True)
171 171 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
172 172
173 173 def test_apply_tracked(self):
174 174 """test tracking for apply"""
175 175 # self.add_engines(1)
176 176 t = self.client.ids[-1]
177 177 v = self.client[t]
178 178 v.block=False
179 179 def echo(n=1024*1024, **kwargs):
180 180 with v.temp_flags(**kwargs):
181 181 return v.apply(lambda x: x, 'x'*n)
182 182 ar = echo(1, track=False)
183 183 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
184 184 self.assertTrue(ar.sent)
185 185 ar = echo(track=True)
186 186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 187 self.assertEqual(ar.sent, ar._tracker.done)
188 188 ar._tracker.wait()
189 189 self.assertTrue(ar.sent)
190 190
191 191 def test_push_tracked(self):
192 192 t = self.client.ids[-1]
193 193 ns = dict(x='x'*1024*1024)
194 194 v = self.client[t]
195 195 ar = v.push(ns, block=False, track=False)
196 196 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
197 197 self.assertTrue(ar.sent)
198 198
199 199 ar = v.push(ns, block=False, track=True)
200 200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 201 ar._tracker.wait()
202 202 self.assertEqual(ar.sent, ar._tracker.done)
203 203 self.assertTrue(ar.sent)
204 204 ar.get()
205 205
206 206 def test_scatter_tracked(self):
207 207 t = self.client.ids
208 208 x='x'*1024*1024
209 209 ar = self.client[t].scatter('x', x, block=False, track=False)
210 210 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
211 211 self.assertTrue(ar.sent)
212 212
213 213 ar = self.client[t].scatter('x', x, block=False, track=True)
214 214 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
215 215 self.assertEqual(ar.sent, ar._tracker.done)
216 216 ar._tracker.wait()
217 217 self.assertTrue(ar.sent)
218 218 ar.get()
219 219
220 220 def test_remote_reference(self):
221 221 v = self.client[-1]
222 222 v['a'] = 123
223 223 ra = pmod.Reference('a')
224 224 b = v.apply_sync(lambda x: x, ra)
225 225 self.assertEqual(b, 123)
226 226
227 227
228 228 def test_scatter_gather(self):
229 229 view = self.client[:]
230 230 seq1 = range(16)
231 231 view.scatter('a', seq1)
232 232 seq2 = view.gather('a', block=True)
233 233 self.assertEqual(seq2, seq1)
234 234 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
235 235
236 236 @skip_without('numpy')
237 237 def test_scatter_gather_numpy(self):
238 238 import numpy
239 239 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
240 240 view = self.client[:]
241 241 a = numpy.arange(64)
242 242 view.scatter('a', a, block=True)
243 243 b = view.gather('a', block=True)
244 244 assert_array_equal(b, a)
245 245
246 246 def test_scatter_gather_lazy(self):
247 247 """scatter/gather with targets='all'"""
248 248 view = self.client.direct_view(targets='all')
249 249 x = range(64)
250 250 view.scatter('x', x)
251 251 gathered = view.gather('x', block=True)
252 252 self.assertEqual(gathered, x)
253 253
254 254
255 255 @dec.known_failure_py3
256 256 @skip_without('numpy')
257 257 def test_push_numpy_nocopy(self):
258 258 import numpy
259 259 view = self.client[:]
260 260 a = numpy.arange(64)
261 261 view['A'] = a
262 262 @interactive
263 263 def check_writeable(x):
264 264 return x.flags.writeable
265 265
266 266 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
267 267 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
268 268
269 269 view.push(dict(B=a))
270 270 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
271 271 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
272 272
273 273 @skip_without('numpy')
274 274 def test_apply_numpy(self):
275 275 """view.apply(f, ndarray)"""
276 276 import numpy
277 277 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
278 278
279 279 A = numpy.random.random((100,100))
280 280 view = self.client[-1]
281 281 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
282 282 B = A.astype(dt)
283 283 C = view.apply_sync(lambda x:x, B)
284 284 assert_array_equal(B,C)
285 285
286 286 @skip_without('numpy')
287 287 def test_push_pull_recarray(self):
288 288 """push/pull recarrays"""
289 289 import numpy
290 290 from numpy.testing.utils import assert_array_equal
291 291
292 292 view = self.client[-1]
293 293
294 294 R = numpy.array([
295 295 (1, 'hi', 0.),
296 296 (2**30, 'there', 2.5),
297 297 (-99999, 'world', -12345.6789),
298 298 ], [('n', int), ('s', '|S10'), ('f', float)])
299 299
300 300 view['RR'] = R
301 301 R2 = view['RR']
302 302
303 303 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
304 304 self.assertEqual(r_dtype, R.dtype)
305 305 self.assertEqual(r_shape, R.shape)
306 306 self.assertEqual(R2.dtype, R.dtype)
307 307 self.assertEqual(R2.shape, R.shape)
308 308 assert_array_equal(R2, R)
309 309
310 310 def test_map(self):
311 311 view = self.client[:]
312 312 def f(x):
313 313 return x**2
314 314 data = range(16)
315 315 r = view.map_sync(f, data)
316 316 self.assertEqual(r, map(f, data))
317 317
318 318 def test_map_iterable(self):
319 319 """test map on iterables (direct)"""
320 320 view = self.client[:]
321 321 # 101 is prime, so it won't be evenly distributed
322 322 arr = range(101)
323 323 # ensure it will be an iterator, even in Python 3
324 324 it = iter(arr)
325 325 r = view.map_sync(lambda x:x, arr)
326 326 self.assertEqual(r, list(arr))
327 327
328 328 def test_scatter_gather_nonblocking(self):
329 329 data = range(16)
330 330 view = self.client[:]
331 331 view.scatter('a', data, block=False)
332 332 ar = view.gather('a', block=False)
333 333 self.assertEqual(ar.get(), data)
334 334
335 335 @skip_without('numpy')
336 336 def test_scatter_gather_numpy_nonblocking(self):
337 337 import numpy
338 338 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
339 339 a = numpy.arange(64)
340 340 view = self.client[:]
341 341 ar = view.scatter('a', a, block=False)
342 342 self.assertTrue(isinstance(ar, AsyncResult))
343 343 amr = view.gather('a', block=False)
344 344 self.assertTrue(isinstance(amr, AsyncMapResult))
345 345 assert_array_equal(amr.get(), a)
346 346
347 347 def test_execute(self):
348 348 view = self.client[:]
349 349 # self.client.debug=True
350 350 execute = view.execute
351 351 ar = execute('c=30', block=False)
352 352 self.assertTrue(isinstance(ar, AsyncResult))
353 353 ar = execute('d=[0,1,2]', block=False)
354 354 self.client.wait(ar, 1)
355 355 self.assertEqual(len(ar.get()), len(self.client))
356 356 for c in view['c']:
357 357 self.assertEqual(c, 30)
358 358
359 359 def test_abort(self):
360 360 view = self.client[-1]
361 361 ar = view.execute('import time; time.sleep(1)', block=False)
362 362 ar2 = view.apply_async(lambda : 2)
363 363 ar3 = view.apply_async(lambda : 3)
364 364 view.abort(ar2)
365 365 view.abort(ar3.msg_ids)
366 366 self.assertRaises(error.TaskAborted, ar2.get)
367 367 self.assertRaises(error.TaskAborted, ar3.get)
368 368
369 369 def test_abort_all(self):
370 370 """view.abort() aborts all outstanding tasks"""
371 371 view = self.client[-1]
372 372 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
373 373 view.abort()
374 374 view.wait(timeout=5)
375 375 for ar in ars[5:]:
376 376 self.assertRaises(error.TaskAborted, ar.get)
377 377
378 378 def test_temp_flags(self):
379 379 view = self.client[-1]
380 380 view.block=True
381 381 with view.temp_flags(block=False):
382 382 self.assertFalse(view.block)
383 383 self.assertTrue(view.block)
384 384
385 385 @dec.known_failure_py3
386 386 def test_importer(self):
387 387 view = self.client[-1]
388 388 view.clear(block=True)
389 389 with view.importer:
390 390 import re
391 391
392 392 @interactive
393 393 def findall(pat, s):
394 394 # this globals() step isn't necessary in real code
395 395 # only to prevent a closure in the test
396 396 re = globals()['re']
397 397 return re.findall(pat, s)
398 398
399 399 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
400 400
401 401 def test_unicode_execute(self):
402 402 """test executing unicode strings"""
403 403 v = self.client[-1]
404 404 v.block=True
405 405 if sys.version_info[0] >= 3:
406 406 code="a='é'"
407 407 else:
408 408 code=u"a=u'é'"
409 409 v.execute(code)
410 410 self.assertEqual(v['a'], u'é')
411 411
412 412 def test_unicode_apply_result(self):
413 413 """test unicode apply results"""
414 414 v = self.client[-1]
415 415 r = v.apply_sync(lambda : u'é')
416 416 self.assertEqual(r, u'é')
417 417
418 418 def test_unicode_apply_arg(self):
419 419 """test passing unicode arguments to apply"""
420 420 v = self.client[-1]
421 421
422 422 @interactive
423 423 def check_unicode(a, check):
424 424 assert isinstance(a, unicode), "%r is not unicode"%a
425 425 assert isinstance(check, bytes), "%r is not bytes"%check
426 426 assert a.encode('utf8') == check, "%s != %s"%(a,check)
427 427
428 428 for s in [ u'é', u'ßø®∫',u'asdf' ]:
429 429 try:
430 430 v.apply_sync(check_unicode, s, s.encode('utf8'))
431 431 except error.RemoteError as e:
432 432 if e.ename == 'AssertionError':
433 433 self.fail(e.evalue)
434 434 else:
435 435 raise e
436 436
437 437 def test_map_reference(self):
438 438 """view.map(<Reference>, *seqs) should work"""
439 439 v = self.client[:]
440 440 v.scatter('n', self.client.ids, flatten=True)
441 441 v.execute("f = lambda x,y: x*y")
442 442 rf = pmod.Reference('f')
443 443 nlist = list(range(10))
444 444 mlist = nlist[::-1]
445 445 expected = [ m*n for m,n in zip(mlist, nlist) ]
446 446 result = v.map_sync(rf, mlist, nlist)
447 447 self.assertEqual(result, expected)
448 448
449 449 def test_apply_reference(self):
450 450 """view.apply(<Reference>, *args) should work"""
451 451 v = self.client[:]
452 452 v.scatter('n', self.client.ids, flatten=True)
453 453 v.execute("f = lambda x: n*x")
454 454 rf = pmod.Reference('f')
455 455 result = v.apply_sync(rf, 5)
456 456 expected = [ 5*id for id in self.client.ids ]
457 457 self.assertEqual(result, expected)
458 458
459 459 def test_eval_reference(self):
460 460 v = self.client[self.client.ids[0]]
461 461 v['g'] = range(5)
462 462 rg = pmod.Reference('g[0]')
463 463 echo = lambda x:x
464 464 self.assertEqual(v.apply_sync(echo, rg), 0)
465 465
466 466 def test_reference_nameerror(self):
467 467 v = self.client[self.client.ids[0]]
468 468 r = pmod.Reference('elvis_has_left')
469 469 echo = lambda x:x
470 470 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
471 471
472 472 def test_single_engine_map(self):
473 473 e0 = self.client[self.client.ids[0]]
474 474 r = range(5)
475 475 check = [ -1*i for i in r ]
476 476 result = e0.map_sync(lambda x: -1*x, r)
477 477 self.assertEqual(result, check)
478 478
479 479 def test_len(self):
480 480 """len(view) makes sense"""
481 481 e0 = self.client[self.client.ids[0]]
482 482 yield self.assertEqual(len(e0), 1)
483 483 v = self.client[:]
484 484 yield self.assertEqual(len(v), len(self.client.ids))
485 485 v = self.client.direct_view('all')
486 486 yield self.assertEqual(len(v), len(self.client.ids))
487 487 v = self.client[:2]
488 488 yield self.assertEqual(len(v), 2)
489 489 v = self.client[:1]
490 490 yield self.assertEqual(len(v), 1)
491 491 v = self.client.load_balanced_view()
492 492 yield self.assertEqual(len(v), len(self.client.ids))
493 493 # parametric tests seem to require manual closing?
494 494 self.client.close()
495 495
496 496
497 497 # begin execute tests
498 498
499 499 def test_execute_reply(self):
500 500 e0 = self.client[self.client.ids[0]]
501 501 e0.block = True
502 502 ar = e0.execute("5", silent=False)
503 503 er = ar.get()
504 504 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
505 505 self.assertEqual(er.pyout['data']['text/plain'], '5')
506 506
507 507 def test_execute_reply_stdout(self):
508 508 e0 = self.client[self.client.ids[0]]
509 509 e0.block = True
510 510 ar = e0.execute("print (5)", silent=False)
511 511 er = ar.get()
512 512 self.assertEqual(er.stdout.strip(), '5')
513 513
514 514 def test_execute_pyout(self):
515 515 """execute triggers pyout with silent=False"""
516 516 view = self.client[:]
517 517 ar = view.execute("5", silent=False, block=True)
518 518
519 519 expected = [{'text/plain' : '5'}] * len(view)
520 520 mimes = [ out['data'] for out in ar.pyout ]
521 521 self.assertEqual(mimes, expected)
522 522
523 523 def test_execute_silent(self):
524 524 """execute does not trigger pyout with silent=True"""
525 525 view = self.client[:]
526 526 ar = view.execute("5", block=True)
527 527 expected = [None] * len(view)
528 528 self.assertEqual(ar.pyout, expected)
529 529
530 530 def test_execute_magic(self):
531 531 """execute accepts IPython commands"""
532 532 view = self.client[:]
533 533 view.execute("a = 5")
534 534 ar = view.execute("%whos", block=True)
535 535 # this will raise, if that failed
536 536 ar.get(5)
537 537 for stdout in ar.stdout:
538 538 lines = stdout.splitlines()
539 539 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
540 540 found = False
541 541 for line in lines[2:]:
542 542 split = line.split()
543 543 if split == ['a', 'int', '5']:
544 544 found = True
545 545 break
546 546 self.assertTrue(found, "whos output wrong: %s" % stdout)
547 547
548 548 def test_execute_displaypub(self):
549 549 """execute tracks display_pub output"""
550 550 view = self.client[:]
551 551 view.execute("from IPython.core.display import *")
552 552 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
553 553
554 554 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
555 555 for outputs in ar.outputs:
556 556 mimes = [ out['data'] for out in outputs ]
557 557 self.assertEqual(mimes, expected)
558 558
559 559 def test_apply_displaypub(self):
560 560 """apply tracks display_pub output"""
561 561 view = self.client[:]
562 562 view.execute("from IPython.core.display import *")
563 563
564 564 @interactive
565 565 def publish():
566 566 [ display(i) for i in range(5) ]
567 567
568 568 ar = view.apply_async(publish)
569 569 ar.get(5)
570 570 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
571 571 for outputs in ar.outputs:
572 572 mimes = [ out['data'] for out in outputs ]
573 573 self.assertEqual(mimes, expected)
574 574
575 575 def test_execute_raises(self):
576 576 """exceptions in execute requests raise appropriately"""
577 577 view = self.client[-1]
578 578 ar = view.execute("1/0")
579 579 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
580 580
581 581 @dec.skipif_not_matplotlib
582 582 def test_magic_pylab(self):
583 583 """%pylab works on engines"""
584 584 view = self.client[-1]
585 585 ar = view.execute("%pylab inline")
586 586 # at least check if this raised:
587 587 reply = ar.get(5)
588 588 # include imports, in case user config
589 589 ar = view.execute("plot(rand(100))", silent=False)
590 590 reply = ar.get(5)
591 591 self.assertEqual(len(reply.outputs), 1)
592 592 output = reply.outputs[0]
593 593 self.assertTrue("data" in output)
594 594 data = output['data']
595 595 self.assertTrue("image/png" in data)
596 596
597 597 def test_func_default_func(self):
598 598 """interactively defined function as apply func default"""
599 599 def foo():
600 600 return 'foo'
601 601
602 602 def bar(f=foo):
603 603 return f()
604 604
605 605 view = self.client[-1]
606 606 ar = view.apply_async(bar)
607 607 r = ar.get(10)
608 608 self.assertEqual(r, 'foo')
609 609 def test_data_pub_single(self):
610 610 view = self.client[-1]
611 611 ar = view.execute('\n'.join([
612 612 'from IPython.zmq.datapub import publish_data',
613 613 'for i in range(5):',
614 614 ' publish_data(dict(i=i))'
615 615 ]), block=False)
616 616 self.assertTrue(isinstance(ar.data, dict))
617 617 ar.get(5)
618 618 self.assertEqual(ar.data, dict(i=4))
619 619
620 620 def test_data_pub(self):
621 621 view = self.client[:]
622 622 ar = view.execute('\n'.join([
623 623 'from IPython.zmq.datapub import publish_data',
624 624 'for i in range(5):',
625 625 ' publish_data(dict(i=i))'
626 626 ]), block=False)
627 627 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
628 628 ar.get(5)
629 629 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
630 630
631 631 def test_can_list_arg(self):
632 632 """args in lists are canned"""
633 633 view = self.client[-1]
634 634 view['a'] = 128
635 635 rA = pmod.Reference('a')
636 636 ar = view.apply_async(lambda x: x, [rA])
637 637 r = ar.get(5)
638 638 self.assertEqual(r, [128])
639 639
640 640 def test_can_dict_arg(self):
641 641 """args in dicts are canned"""
642 642 view = self.client[-1]
643 643 view['a'] = 128
644 644 rA = pmod.Reference('a')
645 645 ar = view.apply_async(lambda x: x, dict(foo=rA))
646 646 r = ar.get(5)
647 647 self.assertEqual(r, dict(foo=128))
648 648
649 649 def test_can_list_kwarg(self):
650 650 """kwargs in lists are canned"""
651 651 view = self.client[-1]
652 652 view['a'] = 128
653 653 rA = pmod.Reference('a')
654 654 ar = view.apply_async(lambda x=5: x, x=[rA])
655 655 r = ar.get(5)
656 656 self.assertEqual(r, [128])
657 657
658 658 def test_can_dict_kwarg(self):
659 659 """kwargs in dicts are canned"""
660 660 view = self.client[-1]
661 661 view['a'] = 128
662 662 rA = pmod.Reference('a')
663 663 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
664 664 r = ar.get(5)
665 665 self.assertEqual(r, dict(foo=128))
666 666
667 667 def test_map_ref(self):
668 668 """view.map works with references"""
669 669 view = self.client[:]
670 670 ranks = sorted(self.client.ids)
671 671 view.scatter('rank', ranks, flatten=True)
672 672 rrank = pmod.Reference('rank')
673 673
674 674 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
675 675 drank = amr.get(5)
676 676 self.assertEqual(drank, [ r*2 for r in ranks ])
677 677
678 def test_nested_getitem_setitem(self):
679 """get and set with view['a.b']"""
680 view = self.client[-1]
681 view.execute('\n'.join([
682 'class A(object): pass',
683 'a = A()',
684 'a.b = 128',
685 ]), block=True)
686 ra = pmod.Reference('a')
687
688 r = view.apply_sync(lambda x: x.b, ra)
689 self.assertEqual(r, 128)
690 self.assertEqual(view['a.b'], 128)
691
692 view['a.b'] = 0
678 693
694 r = view.apply_sync(lambda x: x.b, ra)
695 self.assertEqual(r, 0)
696 self.assertEqual(view['a.b'], 0)
@@ -1,352 +1,355
1 1 """some generic utilities for dealing with classes, urls, and serialization
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 # Standard library imports.
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23 import socket
24 24 import sys
25 25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 26 try:
27 27 from signal import SIGKILL
28 28 except ImportError:
29 29 SIGKILL=None
30 30
31 31 try:
32 32 import cPickle
33 33 pickle = cPickle
34 34 except:
35 35 cPickle = None
36 36 import pickle
37 37
38 38 # System library imports
39 39 import zmq
40 40 from zmq.log import handlers
41 41
42 42 from IPython.external.decorator import decorator
43 43
44 44 # IPython imports
45 45 from IPython.config.application import Application
46 46 from IPython.zmq.log import EnginePUBHandler
47 47 from IPython.zmq.serialize import (
48 48 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
49 49 )
50 50
51 51 #-----------------------------------------------------------------------------
52 52 # Classes
53 53 #-----------------------------------------------------------------------------
54 54
55 55 class Namespace(dict):
56 56 """Subclass of dict for attribute access to keys."""
57 57
58 58 def __getattr__(self, key):
59 59 """getattr aliased to getitem"""
60 60 if key in self.iterkeys():
61 61 return self[key]
62 62 else:
63 63 raise NameError(key)
64 64
65 65 def __setattr__(self, key, value):
66 66 """setattr aliased to setitem, with strict"""
67 67 if hasattr(dict, key):
68 68 raise KeyError("Cannot override dict keys %r"%key)
69 69 self[key] = value
70 70
71 71
72 72 class ReverseDict(dict):
73 73 """simple double-keyed subset of dict methods."""
74 74
75 75 def __init__(self, *args, **kwargs):
76 76 dict.__init__(self, *args, **kwargs)
77 77 self._reverse = dict()
78 78 for key, value in self.iteritems():
79 79 self._reverse[value] = key
80 80
81 81 def __getitem__(self, key):
82 82 try:
83 83 return dict.__getitem__(self, key)
84 84 except KeyError:
85 85 return self._reverse[key]
86 86
87 87 def __setitem__(self, key, value):
88 88 if key in self._reverse:
89 89 raise KeyError("Can't have key %r on both sides!"%key)
90 90 dict.__setitem__(self, key, value)
91 91 self._reverse[value] = key
92 92
93 93 def pop(self, key):
94 94 value = dict.pop(self, key)
95 95 self._reverse.pop(value)
96 96 return value
97 97
98 98 def get(self, key, default=None):
99 99 try:
100 100 return self[key]
101 101 except KeyError:
102 102 return default
103 103
104 104 #-----------------------------------------------------------------------------
105 105 # Functions
106 106 #-----------------------------------------------------------------------------
107 107
108 108 @decorator
109 109 def log_errors(f, self, *args, **kwargs):
110 110 """decorator to log unhandled exceptions raised in a method.
111 111
112 112 For use wrapping on_recv callbacks, so that exceptions
113 113 do not cause the stream to be closed.
114 114 """
115 115 try:
116 116 return f(self, *args, **kwargs)
117 117 except Exception:
118 118 self.log.error("Uncaught exception in %r" % f, exc_info=True)
119 119
120 120
121 121 def is_url(url):
122 122 """boolean check for whether a string is a zmq url"""
123 123 if '://' not in url:
124 124 return False
125 125 proto, addr = url.split('://', 1)
126 126 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
127 127 return False
128 128 return True
129 129
130 130 def validate_url(url):
131 131 """validate a url for zeromq"""
132 132 if not isinstance(url, basestring):
133 133 raise TypeError("url must be a string, not %r"%type(url))
134 134 url = url.lower()
135 135
136 136 proto_addr = url.split('://')
137 137 assert len(proto_addr) == 2, 'Invalid url: %r'%url
138 138 proto, addr = proto_addr
139 139 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
140 140
141 141 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
142 142 # author: Remi Sabourin
143 143 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
144 144
145 145 if proto == 'tcp':
146 146 lis = addr.split(':')
147 147 assert len(lis) == 2, 'Invalid url: %r'%url
148 148 addr,s_port = lis
149 149 try:
150 150 port = int(s_port)
151 151 except ValueError:
152 152 raise AssertionError("Invalid port %r in url: %r"%(port, url))
153 153
154 154 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
155 155
156 156 else:
157 157 # only validate tcp urls currently
158 158 pass
159 159
160 160 return True
161 161
162 162
163 163 def validate_url_container(container):
164 164 """validate a potentially nested collection of urls."""
165 165 if isinstance(container, basestring):
166 166 url = container
167 167 return validate_url(url)
168 168 elif isinstance(container, dict):
169 169 container = container.itervalues()
170 170
171 171 for element in container:
172 172 validate_url_container(element)
173 173
174 174
175 175 def split_url(url):
176 176 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
177 177 proto_addr = url.split('://')
178 178 assert len(proto_addr) == 2, 'Invalid url: %r'%url
179 179 proto, addr = proto_addr
180 180 lis = addr.split(':')
181 181 assert len(lis) == 2, 'Invalid url: %r'%url
182 182 addr,s_port = lis
183 183 return proto,addr,s_port
184 184
185 185 def disambiguate_ip_address(ip, location=None):
186 186 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
187 187 ones, based on the location (default interpretation of location is localhost)."""
188 188 if ip in ('0.0.0.0', '*'):
189 189 try:
190 190 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
191 191 except (socket.gaierror, IndexError):
192 192 # couldn't identify this machine, assume localhost
193 193 external_ips = []
194 194 if location is None or location in external_ips or not external_ips:
195 195 # If location is unspecified or cannot be determined, assume local
196 196 ip='127.0.0.1'
197 197 elif location:
198 198 return location
199 199 return ip
200 200
201 201 def disambiguate_url(url, location=None):
202 202 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
203 203 ones, based on the location (default interpretation is localhost).
204 204
205 205 This is for zeromq urls, such as tcp://*:10101."""
206 206 try:
207 207 proto,ip,port = split_url(url)
208 208 except AssertionError:
209 209 # probably not tcp url; could be ipc, etc.
210 210 return url
211 211
212 212 ip = disambiguate_ip_address(ip,location)
213 213
214 214 return "%s://%s:%s"%(proto,ip,port)
215 215
216 216
217 217 #--------------------------------------------------------------------------
218 218 # helpers for implementing old MEC API via view.apply
219 219 #--------------------------------------------------------------------------
220 220
221 221 def interactive(f):
222 222 """decorator for making functions appear as interactively defined.
223 223 This results in the function being linked to the user_ns as globals()
224 224 instead of the module globals().
225 225 """
226 226 f.__module__ = '__main__'
227 227 return f
228 228
229 229 @interactive
230 230 def _push(**ns):
231 231 """helper method for implementing `client.push` via `client.apply`"""
232 globals().update(ns)
232 user_ns = globals()
233 tmp = '_IP_PUSH_TMP_'
234 while tmp in user_ns:
235 tmp = tmp + '_'
236 try:
237 for name, value in ns.iteritems():
238 user_ns[tmp] = value
239 exec "%s = %s" % (name, tmp) in user_ns
240 finally:
241 user_ns.pop(tmp, None)
233 242
234 243 @interactive
235 244 def _pull(keys):
236 245 """helper method for implementing `client.pull` via `client.apply`"""
237 user_ns = globals()
238 246 if isinstance(keys, (list,tuple, set)):
239 for key in keys:
240 if key not in user_ns:
241 raise NameError("name '%s' is not defined"%key)
242 return map(user_ns.get, keys)
247 return map(lambda key: eval(key, globals()), keys)
243 248 else:
244 if keys not in user_ns:
245 raise NameError("name '%s' is not defined"%keys)
246 return user_ns.get(keys)
249 return eval(keys, globals())
247 250
248 251 @interactive
249 252 def _execute(code):
250 253 """helper method for implementing `client.execute` via `client.apply`"""
251 254 exec code in globals()
252 255
253 256 #--------------------------------------------------------------------------
254 257 # extra process management utilities
255 258 #--------------------------------------------------------------------------
256 259
257 260 _random_ports = set()
258 261
259 262 def select_random_ports(n):
260 263 """Selects and return n random ports that are available."""
261 264 ports = []
262 265 for i in xrange(n):
263 266 sock = socket.socket()
264 267 sock.bind(('', 0))
265 268 while sock.getsockname()[1] in _random_ports:
266 269 sock.close()
267 270 sock = socket.socket()
268 271 sock.bind(('', 0))
269 272 ports.append(sock)
270 273 for i, sock in enumerate(ports):
271 274 port = sock.getsockname()[1]
272 275 sock.close()
273 276 ports[i] = port
274 277 _random_ports.add(port)
275 278 return ports
276 279
277 280 def signal_children(children):
278 281 """Relay interupt/term signals to children, for more solid process cleanup."""
279 282 def terminate_children(sig, frame):
280 283 log = Application.instance().log
281 284 log.critical("Got signal %i, terminating children..."%sig)
282 285 for child in children:
283 286 child.terminate()
284 287
285 288 sys.exit(sig != SIGINT)
286 289 # sys.exit(sig)
287 290 for sig in (SIGINT, SIGABRT, SIGTERM):
288 291 signal(sig, terminate_children)
289 292
290 293 def generate_exec_key(keyfile):
291 294 import uuid
292 295 newkey = str(uuid.uuid4())
293 296 with open(keyfile, 'w') as f:
294 297 # f.write('ipython-key ')
295 298 f.write(newkey+'\n')
296 299 # set user-only RW permissions (0600)
297 300 # this will have no effect on Windows
298 301 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
299 302
300 303
301 304 def integer_loglevel(loglevel):
302 305 try:
303 306 loglevel = int(loglevel)
304 307 except ValueError:
305 308 if isinstance(loglevel, str):
306 309 loglevel = getattr(logging, loglevel)
307 310 return loglevel
308 311
309 312 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
310 313 logger = logging.getLogger(logname)
311 314 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
312 315 # don't add a second PUBHandler
313 316 return
314 317 loglevel = integer_loglevel(loglevel)
315 318 lsock = context.socket(zmq.PUB)
316 319 lsock.connect(iface)
317 320 handler = handlers.PUBHandler(lsock)
318 321 handler.setLevel(loglevel)
319 322 handler.root_topic = root
320 323 logger.addHandler(handler)
321 324 logger.setLevel(loglevel)
322 325
323 326 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
324 327 logger = logging.getLogger()
325 328 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
326 329 # don't add a second PUBHandler
327 330 return
328 331 loglevel = integer_loglevel(loglevel)
329 332 lsock = context.socket(zmq.PUB)
330 333 lsock.connect(iface)
331 334 handler = EnginePUBHandler(engine, lsock)
332 335 handler.setLevel(loglevel)
333 336 logger.addHandler(handler)
334 337 logger.setLevel(loglevel)
335 338 return logger
336 339
337 340 def local_logger(logname, loglevel=logging.DEBUG):
338 341 loglevel = integer_loglevel(loglevel)
339 342 logger = logging.getLogger(logname)
340 343 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
341 344 # don't add a second StreamHandler
342 345 return
343 346 handler = logging.StreamHandler()
344 347 handler.setLevel(loglevel)
345 348 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
346 349 datefmt="%Y-%m-%d %H:%M:%S")
347 350 handler.setFormatter(formatter)
348 351
349 352 logger.addHandler(handler)
350 353 logger.setLevel(loglevel)
351 354 return logger
352 355
General Comments 0
You need to be logged in to leave comments. Login now