##// END OF EJS Templates
Merge pull request #2327 from bfroehle/remote_push_pull_nested...
Min RK -
r8363:0af1e9d0 merge
parent child Browse files
Show More
@@ -1,703 +1,721
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import platform
21 21 import time
22 22 from tempfile import mktemp
23 23 from StringIO import StringIO
24 24
25 25 import zmq
26 26 from nose import SkipTest
27 27
28 28 from IPython.testing import decorators as dec
29 29 from IPython.testing.ipunittest import ParametricTestCase
30 30 from IPython.utils.io import capture_output
31 31
32 32 from IPython import parallel as pmod
33 33 from IPython.parallel import error
34 34 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
35 35 from IPython.parallel import DirectView
36 36 from IPython.parallel.util import interactive
37 37
38 38 from IPython.parallel.tests import add_engines
39 39
40 40 from .clienttest import ClusterTestCase, crash, wait, skip_without
41 41
42 42 def setup():
43 43 add_engines(3, total=True)
44 44
45 45 class TestView(ClusterTestCase, ParametricTestCase):
46 46
47 47 def setUp(self):
48 48 # On Win XP, wait for resource cleanup, else parallel test group fails
49 49 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
50 50 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
51 51 time.sleep(2)
52 52 super(TestView, self).setUp()
53 53
54 54 def test_z_crash_mux(self):
55 55 """test graceful handling of engine death (direct)"""
56 56 raise SkipTest("crash tests disabled, due to undesirable crash reports")
57 57 # self.add_engines(1)
58 58 eid = self.client.ids[-1]
59 59 ar = self.client[eid].apply_async(crash)
60 60 self.assertRaisesRemote(error.EngineError, ar.get, 10)
61 61 eid = ar.engine_id
62 62 tic = time.time()
63 63 while eid in self.client.ids and time.time()-tic < 5:
64 64 time.sleep(.01)
65 65 self.client.spin()
66 66 self.assertFalse(eid in self.client.ids, "Engine should have died")
67 67
68 68 def test_push_pull(self):
69 69 """test pushing and pulling"""
70 70 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
71 71 t = self.client.ids[-1]
72 72 v = self.client[t]
73 73 push = v.push
74 74 pull = v.pull
75 75 v.block=True
76 76 nengines = len(self.client)
77 77 push({'data':data})
78 78 d = pull('data')
79 79 self.assertEqual(d, data)
80 80 self.client[:].push({'data':data})
81 81 d = self.client[:].pull('data', block=True)
82 82 self.assertEqual(d, nengines*[data])
83 83 ar = push({'data':data}, block=False)
84 84 self.assertTrue(isinstance(ar, AsyncResult))
85 85 r = ar.get()
86 86 ar = self.client[:].pull('data', block=False)
87 87 self.assertTrue(isinstance(ar, AsyncResult))
88 88 r = ar.get()
89 89 self.assertEqual(r, nengines*[data])
90 90 self.client[:].push(dict(a=10,b=20))
91 91 r = self.client[:].pull(('a','b'), block=True)
92 92 self.assertEqual(r, nengines*[[10,20]])
93 93
94 94 def test_push_pull_function(self):
95 95 "test pushing and pulling functions"
96 96 def testf(x):
97 97 return 2.0*x
98 98
99 99 t = self.client.ids[-1]
100 100 v = self.client[t]
101 101 v.block=True
102 102 push = v.push
103 103 pull = v.pull
104 104 execute = v.execute
105 105 push({'testf':testf})
106 106 r = pull('testf')
107 107 self.assertEqual(r(1.0), testf(1.0))
108 108 execute('r = testf(10)')
109 109 r = pull('r')
110 110 self.assertEqual(r, testf(10))
111 111 ar = self.client[:].push({'testf':testf}, block=False)
112 112 ar.get()
113 113 ar = self.client[:].pull('testf', block=False)
114 114 rlist = ar.get()
115 115 for r in rlist:
116 116 self.assertEqual(r(1.0), testf(1.0))
117 117 execute("def g(x): return x*x")
118 118 r = pull(('testf','g'))
119 119 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
120 120
121 121 def test_push_function_globals(self):
122 122 """test that pushed functions have access to globals"""
123 123 @interactive
124 124 def geta():
125 125 return a
126 126 # self.add_engines(1)
127 127 v = self.client[-1]
128 128 v.block=True
129 129 v['f'] = geta
130 130 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
131 131 v.execute('a=5')
132 132 v.execute('b=f()')
133 133 self.assertEqual(v['b'], 5)
134 134
135 135 def test_push_function_defaults(self):
136 136 """test that pushed functions preserve default args"""
137 137 def echo(a=10):
138 138 return a
139 139 v = self.client[-1]
140 140 v.block=True
141 141 v['f'] = echo
142 142 v.execute('b=f()')
143 143 self.assertEqual(v['b'], 10)
144 144
145 145 def test_get_result(self):
146 146 """test getting results from the Hub."""
147 147 c = pmod.Client(profile='iptest')
148 148 # self.add_engines(1)
149 149 t = c.ids[-1]
150 150 v = c[t]
151 151 v2 = self.client[t]
152 152 ar = v.apply_async(wait, 1)
153 153 # give the monitor time to notice the message
154 154 time.sleep(.25)
155 155 ahr = v2.get_result(ar.msg_ids)
156 156 self.assertTrue(isinstance(ahr, AsyncHubResult))
157 157 self.assertEqual(ahr.get(), ar.get())
158 158 ar2 = v2.get_result(ar.msg_ids)
159 159 self.assertFalse(isinstance(ar2, AsyncHubResult))
160 160 c.spin()
161 161 c.close()
162 162
163 163 def test_run_newline(self):
164 164 """test that run appends newline to files"""
165 165 tmpfile = mktemp()
166 166 with open(tmpfile, 'w') as f:
167 167 f.write("""def g():
168 168 return 5
169 169 """)
170 170 v = self.client[-1]
171 171 v.run(tmpfile, block=True)
172 172 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
173 173
174 174 def test_apply_tracked(self):
175 175 """test tracking for apply"""
176 176 # self.add_engines(1)
177 177 t = self.client.ids[-1]
178 178 v = self.client[t]
179 179 v.block=False
180 180 def echo(n=1024*1024, **kwargs):
181 181 with v.temp_flags(**kwargs):
182 182 return v.apply(lambda x: x, 'x'*n)
183 183 ar = echo(1, track=False)
184 184 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
185 185 self.assertTrue(ar.sent)
186 186 ar = echo(track=True)
187 187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 188 self.assertEqual(ar.sent, ar._tracker.done)
189 189 ar._tracker.wait()
190 190 self.assertTrue(ar.sent)
191 191
192 192 def test_push_tracked(self):
193 193 t = self.client.ids[-1]
194 194 ns = dict(x='x'*1024*1024)
195 195 v = self.client[t]
196 196 ar = v.push(ns, block=False, track=False)
197 197 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
198 198 self.assertTrue(ar.sent)
199 199
200 200 ar = v.push(ns, block=False, track=True)
201 201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 202 ar._tracker.wait()
203 203 self.assertEqual(ar.sent, ar._tracker.done)
204 204 self.assertTrue(ar.sent)
205 205 ar.get()
206 206
207 207 def test_scatter_tracked(self):
208 208 t = self.client.ids
209 209 x='x'*1024*1024
210 210 ar = self.client[t].scatter('x', x, block=False, track=False)
211 211 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
212 212 self.assertTrue(ar.sent)
213 213
214 214 ar = self.client[t].scatter('x', x, block=False, track=True)
215 215 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
216 216 self.assertEqual(ar.sent, ar._tracker.done)
217 217 ar._tracker.wait()
218 218 self.assertTrue(ar.sent)
219 219 ar.get()
220 220
221 221 def test_remote_reference(self):
222 222 v = self.client[-1]
223 223 v['a'] = 123
224 224 ra = pmod.Reference('a')
225 225 b = v.apply_sync(lambda x: x, ra)
226 226 self.assertEqual(b, 123)
227 227
228 228
229 229 def test_scatter_gather(self):
230 230 view = self.client[:]
231 231 seq1 = range(16)
232 232 view.scatter('a', seq1)
233 233 seq2 = view.gather('a', block=True)
234 234 self.assertEqual(seq2, seq1)
235 235 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
236 236
237 237 @skip_without('numpy')
238 238 def test_scatter_gather_numpy(self):
239 239 import numpy
240 240 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
241 241 view = self.client[:]
242 242 a = numpy.arange(64)
243 243 view.scatter('a', a, block=True)
244 244 b = view.gather('a', block=True)
245 245 assert_array_equal(b, a)
246 246
247 247 def test_scatter_gather_lazy(self):
248 248 """scatter/gather with targets='all'"""
249 249 view = self.client.direct_view(targets='all')
250 250 x = range(64)
251 251 view.scatter('x', x)
252 252 gathered = view.gather('x', block=True)
253 253 self.assertEqual(gathered, x)
254 254
255 255
256 256 @dec.known_failure_py3
257 257 @skip_without('numpy')
258 258 def test_push_numpy_nocopy(self):
259 259 import numpy
260 260 view = self.client[:]
261 261 a = numpy.arange(64)
262 262 view['A'] = a
263 263 @interactive
264 264 def check_writeable(x):
265 265 return x.flags.writeable
266 266
267 267 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
268 268 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
269 269
270 270 view.push(dict(B=a))
271 271 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
272 272 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
273 273
274 274 @skip_without('numpy')
275 275 def test_apply_numpy(self):
276 276 """view.apply(f, ndarray)"""
277 277 import numpy
278 278 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
279 279
280 280 A = numpy.random.random((100,100))
281 281 view = self.client[-1]
282 282 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
283 283 B = A.astype(dt)
284 284 C = view.apply_sync(lambda x:x, B)
285 285 assert_array_equal(B,C)
286 286
287 287 @skip_without('numpy')
288 288 def test_push_pull_recarray(self):
289 289 """push/pull recarrays"""
290 290 import numpy
291 291 from numpy.testing.utils import assert_array_equal
292 292
293 293 view = self.client[-1]
294 294
295 295 R = numpy.array([
296 296 (1, 'hi', 0.),
297 297 (2**30, 'there', 2.5),
298 298 (-99999, 'world', -12345.6789),
299 299 ], [('n', int), ('s', '|S10'), ('f', float)])
300 300
301 301 view['RR'] = R
302 302 R2 = view['RR']
303 303
304 304 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
305 305 self.assertEqual(r_dtype, R.dtype)
306 306 self.assertEqual(r_shape, R.shape)
307 307 self.assertEqual(R2.dtype, R.dtype)
308 308 self.assertEqual(R2.shape, R.shape)
309 309 assert_array_equal(R2, R)
310 310
311 311 def test_map(self):
312 312 view = self.client[:]
313 313 def f(x):
314 314 return x**2
315 315 data = range(16)
316 316 r = view.map_sync(f, data)
317 317 self.assertEqual(r, map(f, data))
318 318
319 319 def test_map_iterable(self):
320 320 """test map on iterables (direct)"""
321 321 view = self.client[:]
322 322 # 101 is prime, so it won't be evenly distributed
323 323 arr = range(101)
324 324 # ensure it will be an iterator, even in Python 3
325 325 it = iter(arr)
326 326 r = view.map_sync(lambda x:x, arr)
327 327 self.assertEqual(r, list(arr))
328 328
329 329 def test_scatter_gather_nonblocking(self):
330 330 data = range(16)
331 331 view = self.client[:]
332 332 view.scatter('a', data, block=False)
333 333 ar = view.gather('a', block=False)
334 334 self.assertEqual(ar.get(), data)
335 335
336 336 @skip_without('numpy')
337 337 def test_scatter_gather_numpy_nonblocking(self):
338 338 import numpy
339 339 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
340 340 a = numpy.arange(64)
341 341 view = self.client[:]
342 342 ar = view.scatter('a', a, block=False)
343 343 self.assertTrue(isinstance(ar, AsyncResult))
344 344 amr = view.gather('a', block=False)
345 345 self.assertTrue(isinstance(amr, AsyncMapResult))
346 346 assert_array_equal(amr.get(), a)
347 347
348 348 def test_execute(self):
349 349 view = self.client[:]
350 350 # self.client.debug=True
351 351 execute = view.execute
352 352 ar = execute('c=30', block=False)
353 353 self.assertTrue(isinstance(ar, AsyncResult))
354 354 ar = execute('d=[0,1,2]', block=False)
355 355 self.client.wait(ar, 1)
356 356 self.assertEqual(len(ar.get()), len(self.client))
357 357 for c in view['c']:
358 358 self.assertEqual(c, 30)
359 359
360 360 def test_abort(self):
361 361 view = self.client[-1]
362 362 ar = view.execute('import time; time.sleep(1)', block=False)
363 363 ar2 = view.apply_async(lambda : 2)
364 364 ar3 = view.apply_async(lambda : 3)
365 365 view.abort(ar2)
366 366 view.abort(ar3.msg_ids)
367 367 self.assertRaises(error.TaskAborted, ar2.get)
368 368 self.assertRaises(error.TaskAborted, ar3.get)
369 369
370 370 def test_abort_all(self):
371 371 """view.abort() aborts all outstanding tasks"""
372 372 view = self.client[-1]
373 373 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
374 374 view.abort()
375 375 view.wait(timeout=5)
376 376 for ar in ars[5:]:
377 377 self.assertRaises(error.TaskAborted, ar.get)
378 378
379 379 def test_temp_flags(self):
380 380 view = self.client[-1]
381 381 view.block=True
382 382 with view.temp_flags(block=False):
383 383 self.assertFalse(view.block)
384 384 self.assertTrue(view.block)
385 385
386 386 @dec.known_failure_py3
387 387 def test_importer(self):
388 388 view = self.client[-1]
389 389 view.clear(block=True)
390 390 with view.importer:
391 391 import re
392 392
393 393 @interactive
394 394 def findall(pat, s):
395 395 # this globals() step isn't necessary in real code
396 396 # only to prevent a closure in the test
397 397 re = globals()['re']
398 398 return re.findall(pat, s)
399 399
400 400 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
401 401
402 402 def test_unicode_execute(self):
403 403 """test executing unicode strings"""
404 404 v = self.client[-1]
405 405 v.block=True
406 406 if sys.version_info[0] >= 3:
407 407 code="a='é'"
408 408 else:
409 409 code=u"a=u'é'"
410 410 v.execute(code)
411 411 self.assertEqual(v['a'], u'é')
412 412
413 413 def test_unicode_apply_result(self):
414 414 """test unicode apply results"""
415 415 v = self.client[-1]
416 416 r = v.apply_sync(lambda : u'é')
417 417 self.assertEqual(r, u'é')
418 418
419 419 def test_unicode_apply_arg(self):
420 420 """test passing unicode arguments to apply"""
421 421 v = self.client[-1]
422 422
423 423 @interactive
424 424 def check_unicode(a, check):
425 425 assert isinstance(a, unicode), "%r is not unicode"%a
426 426 assert isinstance(check, bytes), "%r is not bytes"%check
427 427 assert a.encode('utf8') == check, "%s != %s"%(a,check)
428 428
429 429 for s in [ u'é', u'ßø®∫',u'asdf' ]:
430 430 try:
431 431 v.apply_sync(check_unicode, s, s.encode('utf8'))
432 432 except error.RemoteError as e:
433 433 if e.ename == 'AssertionError':
434 434 self.fail(e.evalue)
435 435 else:
436 436 raise e
437 437
438 438 def test_map_reference(self):
439 439 """view.map(<Reference>, *seqs) should work"""
440 440 v = self.client[:]
441 441 v.scatter('n', self.client.ids, flatten=True)
442 442 v.execute("f = lambda x,y: x*y")
443 443 rf = pmod.Reference('f')
444 444 nlist = list(range(10))
445 445 mlist = nlist[::-1]
446 446 expected = [ m*n for m,n in zip(mlist, nlist) ]
447 447 result = v.map_sync(rf, mlist, nlist)
448 448 self.assertEqual(result, expected)
449 449
450 450 def test_apply_reference(self):
451 451 """view.apply(<Reference>, *args) should work"""
452 452 v = self.client[:]
453 453 v.scatter('n', self.client.ids, flatten=True)
454 454 v.execute("f = lambda x: n*x")
455 455 rf = pmod.Reference('f')
456 456 result = v.apply_sync(rf, 5)
457 457 expected = [ 5*id for id in self.client.ids ]
458 458 self.assertEqual(result, expected)
459 459
460 460 def test_eval_reference(self):
461 461 v = self.client[self.client.ids[0]]
462 462 v['g'] = range(5)
463 463 rg = pmod.Reference('g[0]')
464 464 echo = lambda x:x
465 465 self.assertEqual(v.apply_sync(echo, rg), 0)
466 466
467 467 def test_reference_nameerror(self):
468 468 v = self.client[self.client.ids[0]]
469 469 r = pmod.Reference('elvis_has_left')
470 470 echo = lambda x:x
471 471 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
472 472
473 473 def test_single_engine_map(self):
474 474 e0 = self.client[self.client.ids[0]]
475 475 r = range(5)
476 476 check = [ -1*i for i in r ]
477 477 result = e0.map_sync(lambda x: -1*x, r)
478 478 self.assertEqual(result, check)
479 479
480 480 def test_len(self):
481 481 """len(view) makes sense"""
482 482 e0 = self.client[self.client.ids[0]]
483 483 yield self.assertEqual(len(e0), 1)
484 484 v = self.client[:]
485 485 yield self.assertEqual(len(v), len(self.client.ids))
486 486 v = self.client.direct_view('all')
487 487 yield self.assertEqual(len(v), len(self.client.ids))
488 488 v = self.client[:2]
489 489 yield self.assertEqual(len(v), 2)
490 490 v = self.client[:1]
491 491 yield self.assertEqual(len(v), 1)
492 492 v = self.client.load_balanced_view()
493 493 yield self.assertEqual(len(v), len(self.client.ids))
494 494 # parametric tests seem to require manual closing?
495 495 self.client.close()
496 496
497 497
498 498 # begin execute tests
499 499
500 500 def test_execute_reply(self):
501 501 e0 = self.client[self.client.ids[0]]
502 502 e0.block = True
503 503 ar = e0.execute("5", silent=False)
504 504 er = ar.get()
505 505 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
506 506 self.assertEqual(er.pyout['data']['text/plain'], '5')
507 507
508 508 def test_execute_reply_stdout(self):
509 509 e0 = self.client[self.client.ids[0]]
510 510 e0.block = True
511 511 ar = e0.execute("print (5)", silent=False)
512 512 er = ar.get()
513 513 self.assertEqual(er.stdout.strip(), '5')
514 514
515 515 def test_execute_pyout(self):
516 516 """execute triggers pyout with silent=False"""
517 517 view = self.client[:]
518 518 ar = view.execute("5", silent=False, block=True)
519 519
520 520 expected = [{'text/plain' : '5'}] * len(view)
521 521 mimes = [ out['data'] for out in ar.pyout ]
522 522 self.assertEqual(mimes, expected)
523 523
524 524 def test_execute_silent(self):
525 525 """execute does not trigger pyout with silent=True"""
526 526 view = self.client[:]
527 527 ar = view.execute("5", block=True)
528 528 expected = [None] * len(view)
529 529 self.assertEqual(ar.pyout, expected)
530 530
531 531 def test_execute_magic(self):
532 532 """execute accepts IPython commands"""
533 533 view = self.client[:]
534 534 view.execute("a = 5")
535 535 ar = view.execute("%whos", block=True)
536 536 # this will raise, if that failed
537 537 ar.get(5)
538 538 for stdout in ar.stdout:
539 539 lines = stdout.splitlines()
540 540 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
541 541 found = False
542 542 for line in lines[2:]:
543 543 split = line.split()
544 544 if split == ['a', 'int', '5']:
545 545 found = True
546 546 break
547 547 self.assertTrue(found, "whos output wrong: %s" % stdout)
548 548
549 549 def test_execute_displaypub(self):
550 550 """execute tracks display_pub output"""
551 551 view = self.client[:]
552 552 view.execute("from IPython.core.display import *")
553 553 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
554 554
555 555 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
556 556 for outputs in ar.outputs:
557 557 mimes = [ out['data'] for out in outputs ]
558 558 self.assertEqual(mimes, expected)
559 559
560 560 def test_apply_displaypub(self):
561 561 """apply tracks display_pub output"""
562 562 view = self.client[:]
563 563 view.execute("from IPython.core.display import *")
564 564
565 565 @interactive
566 566 def publish():
567 567 [ display(i) for i in range(5) ]
568 568
569 569 ar = view.apply_async(publish)
570 570 ar.get(5)
571 571 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
572 572 for outputs in ar.outputs:
573 573 mimes = [ out['data'] for out in outputs ]
574 574 self.assertEqual(mimes, expected)
575 575
576 576 def test_execute_raises(self):
577 577 """exceptions in execute requests raise appropriately"""
578 578 view = self.client[-1]
579 579 ar = view.execute("1/0")
580 580 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
581 581
582 582 def test_remoteerror_render_exception(self):
583 583 """RemoteErrors get nice tracebacks"""
584 584 view = self.client[-1]
585 585 ar = view.execute("1/0")
586 586 ip = get_ipython()
587 587 ip.user_ns['ar'] = ar
588 588 with capture_output() as io:
589 589 ip.run_cell("ar.get(2)")
590 590
591 591 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
592 592
593 593 def test_compositeerror_render_exception(self):
594 594 """CompositeErrors get nice tracebacks"""
595 595 view = self.client[:]
596 596 ar = view.execute("1/0")
597 597 ip = get_ipython()
598 598 ip.user_ns['ar'] = ar
599 599 with capture_output() as io:
600 600 ip.run_cell("ar.get(2)")
601 601
602 602 self.assertEqual(io.stdout.count('ZeroDivisionError'), len(view) * 2, io.stdout)
603 603 self.assertEqual(io.stdout.count('by zero'), len(view), io.stdout)
604 604 self.assertEqual(io.stdout.count(':execute'), len(view), io.stdout)
605 605
606 606 @dec.skipif_not_matplotlib
607 607 def test_magic_pylab(self):
608 608 """%pylab works on engines"""
609 609 view = self.client[-1]
610 610 ar = view.execute("%pylab inline")
611 611 # at least check if this raised:
612 612 reply = ar.get(5)
613 613 # include imports, in case user config
614 614 ar = view.execute("plot(rand(100))", silent=False)
615 615 reply = ar.get(5)
616 616 self.assertEqual(len(reply.outputs), 1)
617 617 output = reply.outputs[0]
618 618 self.assertTrue("data" in output)
619 619 data = output['data']
620 620 self.assertTrue("image/png" in data)
621 621
622 622 def test_func_default_func(self):
623 623 """interactively defined function as apply func default"""
624 624 def foo():
625 625 return 'foo'
626 626
627 627 def bar(f=foo):
628 628 return f()
629 629
630 630 view = self.client[-1]
631 631 ar = view.apply_async(bar)
632 632 r = ar.get(10)
633 633 self.assertEqual(r, 'foo')
634 634 def test_data_pub_single(self):
635 635 view = self.client[-1]
636 636 ar = view.execute('\n'.join([
637 637 'from IPython.zmq.datapub import publish_data',
638 638 'for i in range(5):',
639 639 ' publish_data(dict(i=i))'
640 640 ]), block=False)
641 641 self.assertTrue(isinstance(ar.data, dict))
642 642 ar.get(5)
643 643 self.assertEqual(ar.data, dict(i=4))
644 644
645 645 def test_data_pub(self):
646 646 view = self.client[:]
647 647 ar = view.execute('\n'.join([
648 648 'from IPython.zmq.datapub import publish_data',
649 649 'for i in range(5):',
650 650 ' publish_data(dict(i=i))'
651 651 ]), block=False)
652 652 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
653 653 ar.get(5)
654 654 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
655 655
656 656 def test_can_list_arg(self):
657 657 """args in lists are canned"""
658 658 view = self.client[-1]
659 659 view['a'] = 128
660 660 rA = pmod.Reference('a')
661 661 ar = view.apply_async(lambda x: x, [rA])
662 662 r = ar.get(5)
663 663 self.assertEqual(r, [128])
664 664
665 665 def test_can_dict_arg(self):
666 666 """args in dicts are canned"""
667 667 view = self.client[-1]
668 668 view['a'] = 128
669 669 rA = pmod.Reference('a')
670 670 ar = view.apply_async(lambda x: x, dict(foo=rA))
671 671 r = ar.get(5)
672 672 self.assertEqual(r, dict(foo=128))
673 673
674 674 def test_can_list_kwarg(self):
675 675 """kwargs in lists are canned"""
676 676 view = self.client[-1]
677 677 view['a'] = 128
678 678 rA = pmod.Reference('a')
679 679 ar = view.apply_async(lambda x=5: x, x=[rA])
680 680 r = ar.get(5)
681 681 self.assertEqual(r, [128])
682 682
683 683 def test_can_dict_kwarg(self):
684 684 """kwargs in dicts are canned"""
685 685 view = self.client[-1]
686 686 view['a'] = 128
687 687 rA = pmod.Reference('a')
688 688 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
689 689 r = ar.get(5)
690 690 self.assertEqual(r, dict(foo=128))
691 691
692 692 def test_map_ref(self):
693 693 """view.map works with references"""
694 694 view = self.client[:]
695 695 ranks = sorted(self.client.ids)
696 696 view.scatter('rank', ranks, flatten=True)
697 697 rrank = pmod.Reference('rank')
698 698
699 699 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
700 700 drank = amr.get(5)
701 701 self.assertEqual(drank, [ r*2 for r in ranks ])
702 702
703 def test_nested_getitem_setitem(self):
704 """get and set with view['a.b']"""
705 view = self.client[-1]
706 view.execute('\n'.join([
707 'class A(object): pass',
708 'a = A()',
709 'a.b = 128',
710 ]), block=True)
711 ra = pmod.Reference('a')
712
713 r = view.apply_sync(lambda x: x.b, ra)
714 self.assertEqual(r, 128)
715 self.assertEqual(view['a.b'], 128)
716
717 view['a.b'] = 0
703 718
719 r = view.apply_sync(lambda x: x.b, ra)
720 self.assertEqual(r, 0)
721 self.assertEqual(view['a.b'], 0)
@@ -1,352 +1,355
1 1 """some generic utilities for dealing with classes, urls, and serialization
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 # Standard library imports.
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23 import socket
24 24 import sys
25 25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 26 try:
27 27 from signal import SIGKILL
28 28 except ImportError:
29 29 SIGKILL=None
30 30
31 31 try:
32 32 import cPickle
33 33 pickle = cPickle
34 34 except:
35 35 cPickle = None
36 36 import pickle
37 37
38 38 # System library imports
39 39 import zmq
40 40 from zmq.log import handlers
41 41
42 42 from IPython.external.decorator import decorator
43 43
44 44 # IPython imports
45 45 from IPython.config.application import Application
46 46 from IPython.zmq.log import EnginePUBHandler
47 47 from IPython.zmq.serialize import (
48 48 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
49 49 )
50 50
51 51 #-----------------------------------------------------------------------------
52 52 # Classes
53 53 #-----------------------------------------------------------------------------
54 54
55 55 class Namespace(dict):
56 56 """Subclass of dict for attribute access to keys."""
57 57
58 58 def __getattr__(self, key):
59 59 """getattr aliased to getitem"""
60 60 if key in self.iterkeys():
61 61 return self[key]
62 62 else:
63 63 raise NameError(key)
64 64
65 65 def __setattr__(self, key, value):
66 66 """setattr aliased to setitem, with strict"""
67 67 if hasattr(dict, key):
68 68 raise KeyError("Cannot override dict keys %r"%key)
69 69 self[key] = value
70 70
71 71
72 72 class ReverseDict(dict):
73 73 """simple double-keyed subset of dict methods."""
74 74
75 75 def __init__(self, *args, **kwargs):
76 76 dict.__init__(self, *args, **kwargs)
77 77 self._reverse = dict()
78 78 for key, value in self.iteritems():
79 79 self._reverse[value] = key
80 80
81 81 def __getitem__(self, key):
82 82 try:
83 83 return dict.__getitem__(self, key)
84 84 except KeyError:
85 85 return self._reverse[key]
86 86
87 87 def __setitem__(self, key, value):
88 88 if key in self._reverse:
89 89 raise KeyError("Can't have key %r on both sides!"%key)
90 90 dict.__setitem__(self, key, value)
91 91 self._reverse[value] = key
92 92
93 93 def pop(self, key):
94 94 value = dict.pop(self, key)
95 95 self._reverse.pop(value)
96 96 return value
97 97
98 98 def get(self, key, default=None):
99 99 try:
100 100 return self[key]
101 101 except KeyError:
102 102 return default
103 103
104 104 #-----------------------------------------------------------------------------
105 105 # Functions
106 106 #-----------------------------------------------------------------------------
107 107
108 108 @decorator
109 109 def log_errors(f, self, *args, **kwargs):
110 110 """decorator to log unhandled exceptions raised in a method.
111 111
112 112 For use wrapping on_recv callbacks, so that exceptions
113 113 do not cause the stream to be closed.
114 114 """
115 115 try:
116 116 return f(self, *args, **kwargs)
117 117 except Exception:
118 118 self.log.error("Uncaught exception in %r" % f, exc_info=True)
119 119
120 120
121 121 def is_url(url):
122 122 """boolean check for whether a string is a zmq url"""
123 123 if '://' not in url:
124 124 return False
125 125 proto, addr = url.split('://', 1)
126 126 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
127 127 return False
128 128 return True
129 129
130 130 def validate_url(url):
131 131 """validate a url for zeromq"""
132 132 if not isinstance(url, basestring):
133 133 raise TypeError("url must be a string, not %r"%type(url))
134 134 url = url.lower()
135 135
136 136 proto_addr = url.split('://')
137 137 assert len(proto_addr) == 2, 'Invalid url: %r'%url
138 138 proto, addr = proto_addr
139 139 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
140 140
141 141 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
142 142 # author: Remi Sabourin
143 143 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
144 144
145 145 if proto == 'tcp':
146 146 lis = addr.split(':')
147 147 assert len(lis) == 2, 'Invalid url: %r'%url
148 148 addr,s_port = lis
149 149 try:
150 150 port = int(s_port)
151 151 except ValueError:
152 152 raise AssertionError("Invalid port %r in url: %r"%(port, url))
153 153
154 154 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
155 155
156 156 else:
157 157 # only validate tcp urls currently
158 158 pass
159 159
160 160 return True
161 161
162 162
163 163 def validate_url_container(container):
164 164 """validate a potentially nested collection of urls."""
165 165 if isinstance(container, basestring):
166 166 url = container
167 167 return validate_url(url)
168 168 elif isinstance(container, dict):
169 169 container = container.itervalues()
170 170
171 171 for element in container:
172 172 validate_url_container(element)
173 173
174 174
175 175 def split_url(url):
176 176 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
177 177 proto_addr = url.split('://')
178 178 assert len(proto_addr) == 2, 'Invalid url: %r'%url
179 179 proto, addr = proto_addr
180 180 lis = addr.split(':')
181 181 assert len(lis) == 2, 'Invalid url: %r'%url
182 182 addr,s_port = lis
183 183 return proto,addr,s_port
184 184
185 185 def disambiguate_ip_address(ip, location=None):
186 186 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
187 187 ones, based on the location (default interpretation of location is localhost)."""
188 188 if ip in ('0.0.0.0', '*'):
189 189 try:
190 190 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
191 191 except (socket.gaierror, IndexError):
192 192 # couldn't identify this machine, assume localhost
193 193 external_ips = []
194 194 if location is None or location in external_ips or not external_ips:
195 195 # If location is unspecified or cannot be determined, assume local
196 196 ip='127.0.0.1'
197 197 elif location:
198 198 return location
199 199 return ip
200 200
201 201 def disambiguate_url(url, location=None):
202 202 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
203 203 ones, based on the location (default interpretation is localhost).
204 204
205 205 This is for zeromq urls, such as tcp://*:10101."""
206 206 try:
207 207 proto,ip,port = split_url(url)
208 208 except AssertionError:
209 209 # probably not tcp url; could be ipc, etc.
210 210 return url
211 211
212 212 ip = disambiguate_ip_address(ip,location)
213 213
214 214 return "%s://%s:%s"%(proto,ip,port)
215 215
216 216
217 217 #--------------------------------------------------------------------------
218 218 # helpers for implementing old MEC API via view.apply
219 219 #--------------------------------------------------------------------------
220 220
221 221 def interactive(f):
222 222 """decorator for making functions appear as interactively defined.
223 223 This results in the function being linked to the user_ns as globals()
224 224 instead of the module globals().
225 225 """
226 226 f.__module__ = '__main__'
227 227 return f
228 228
229 229 @interactive
230 230 def _push(**ns):
231 231 """helper method for implementing `client.push` via `client.apply`"""
232 globals().update(ns)
232 user_ns = globals()
233 tmp = '_IP_PUSH_TMP_'
234 while tmp in user_ns:
235 tmp = tmp + '_'
236 try:
237 for name, value in ns.iteritems():
238 user_ns[tmp] = value
239 exec "%s = %s" % (name, tmp) in user_ns
240 finally:
241 user_ns.pop(tmp, None)
233 242
234 243 @interactive
235 244 def _pull(keys):
236 245 """helper method for implementing `client.pull` via `client.apply`"""
237 user_ns = globals()
238 246 if isinstance(keys, (list,tuple, set)):
239 for key in keys:
240 if key not in user_ns:
241 raise NameError("name '%s' is not defined"%key)
242 return map(user_ns.get, keys)
247 return map(lambda key: eval(key, globals()), keys)
243 248 else:
244 if keys not in user_ns:
245 raise NameError("name '%s' is not defined"%keys)
246 return user_ns.get(keys)
249 return eval(keys, globals())
247 250
248 251 @interactive
249 252 def _execute(code):
250 253 """helper method for implementing `client.execute` via `client.apply`"""
251 254 exec code in globals()
252 255
253 256 #--------------------------------------------------------------------------
254 257 # extra process management utilities
255 258 #--------------------------------------------------------------------------
256 259
257 260 _random_ports = set()
258 261
259 262 def select_random_ports(n):
260 263 """Selects and return n random ports that are available."""
261 264 ports = []
262 265 for i in xrange(n):
263 266 sock = socket.socket()
264 267 sock.bind(('', 0))
265 268 while sock.getsockname()[1] in _random_ports:
266 269 sock.close()
267 270 sock = socket.socket()
268 271 sock.bind(('', 0))
269 272 ports.append(sock)
270 273 for i, sock in enumerate(ports):
271 274 port = sock.getsockname()[1]
272 275 sock.close()
273 276 ports[i] = port
274 277 _random_ports.add(port)
275 278 return ports
276 279
277 280 def signal_children(children):
278 281 """Relay interupt/term signals to children, for more solid process cleanup."""
279 282 def terminate_children(sig, frame):
280 283 log = Application.instance().log
281 284 log.critical("Got signal %i, terminating children..."%sig)
282 285 for child in children:
283 286 child.terminate()
284 287
285 288 sys.exit(sig != SIGINT)
286 289 # sys.exit(sig)
287 290 for sig in (SIGINT, SIGABRT, SIGTERM):
288 291 signal(sig, terminate_children)
289 292
290 293 def generate_exec_key(keyfile):
291 294 import uuid
292 295 newkey = str(uuid.uuid4())
293 296 with open(keyfile, 'w') as f:
294 297 # f.write('ipython-key ')
295 298 f.write(newkey+'\n')
296 299 # set user-only RW permissions (0600)
297 300 # this will have no effect on Windows
298 301 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
299 302
300 303
301 304 def integer_loglevel(loglevel):
302 305 try:
303 306 loglevel = int(loglevel)
304 307 except ValueError:
305 308 if isinstance(loglevel, str):
306 309 loglevel = getattr(logging, loglevel)
307 310 return loglevel
308 311
309 312 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
310 313 logger = logging.getLogger(logname)
311 314 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
312 315 # don't add a second PUBHandler
313 316 return
314 317 loglevel = integer_loglevel(loglevel)
315 318 lsock = context.socket(zmq.PUB)
316 319 lsock.connect(iface)
317 320 handler = handlers.PUBHandler(lsock)
318 321 handler.setLevel(loglevel)
319 322 handler.root_topic = root
320 323 logger.addHandler(handler)
321 324 logger.setLevel(loglevel)
322 325
323 326 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
324 327 logger = logging.getLogger()
325 328 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
326 329 # don't add a second PUBHandler
327 330 return
328 331 loglevel = integer_loglevel(loglevel)
329 332 lsock = context.socket(zmq.PUB)
330 333 lsock.connect(iface)
331 334 handler = EnginePUBHandler(engine, lsock)
332 335 handler.setLevel(loglevel)
333 336 logger.addHandler(handler)
334 337 logger.setLevel(loglevel)
335 338 return logger
336 339
337 340 def local_logger(logname, loglevel=logging.DEBUG):
338 341 loglevel = integer_loglevel(loglevel)
339 342 logger = logging.getLogger(logname)
340 343 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
341 344 # don't add a second StreamHandler
342 345 return
343 346 handler = logging.StreamHandler()
344 347 handler.setLevel(loglevel)
345 348 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
346 349 datefmt="%Y-%m-%d %H:%M:%S")
347 350 handler.setFormatter(formatter)
348 351
349 352 logger.addHandler(handler)
350 353 logger.setLevel(loglevel)
351 354 return logger
352 355
General Comments 0
You need to be logged in to leave comments. Login now