##// END OF EJS Templates
fix dangling `buffer` in IPython.parallel.util...
MinRK -
Show More
@@ -1,493 +1,506
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import time
21 21 from tempfile import mktemp
22 22 from StringIO import StringIO
23 23
24 24 import zmq
25 25 from nose import SkipTest
26 26
27 27 from IPython.testing import decorators as dec
28 28
29 29 from IPython import parallel as pmod
30 30 from IPython.parallel import error
31 31 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
32 32 from IPython.parallel import DirectView
33 33 from IPython.parallel.util import interactive
34 34
35 35 from IPython.parallel.tests import add_engines
36 36
37 37 from .clienttest import ClusterTestCase, crash, wait, skip_without
38 38
39 39 def setup():
40 40 add_engines(3)
41 41
42 42 class TestView(ClusterTestCase):
43 43
44 44 def test_z_crash_mux(self):
45 45 """test graceful handling of engine death (direct)"""
46 46 raise SkipTest("crash tests disabled, due to undesirable crash reports")
47 47 # self.add_engines(1)
48 48 eid = self.client.ids[-1]
49 49 ar = self.client[eid].apply_async(crash)
50 50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
51 51 eid = ar.engine_id
52 52 tic = time.time()
53 53 while eid in self.client.ids and time.time()-tic < 5:
54 54 time.sleep(.01)
55 55 self.client.spin()
56 56 self.assertFalse(eid in self.client.ids, "Engine should have died")
57 57
58 58 def test_push_pull(self):
59 59 """test pushing and pulling"""
60 60 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
61 61 t = self.client.ids[-1]
62 62 v = self.client[t]
63 63 push = v.push
64 64 pull = v.pull
65 65 v.block=True
66 66 nengines = len(self.client)
67 67 push({'data':data})
68 68 d = pull('data')
69 69 self.assertEquals(d, data)
70 70 self.client[:].push({'data':data})
71 71 d = self.client[:].pull('data', block=True)
72 72 self.assertEquals(d, nengines*[data])
73 73 ar = push({'data':data}, block=False)
74 74 self.assertTrue(isinstance(ar, AsyncResult))
75 75 r = ar.get()
76 76 ar = self.client[:].pull('data', block=False)
77 77 self.assertTrue(isinstance(ar, AsyncResult))
78 78 r = ar.get()
79 79 self.assertEquals(r, nengines*[data])
80 80 self.client[:].push(dict(a=10,b=20))
81 81 r = self.client[:].pull(('a','b'), block=True)
82 82 self.assertEquals(r, nengines*[[10,20]])
83 83
84 84 def test_push_pull_function(self):
85 85 "test pushing and pulling functions"
86 86 def testf(x):
87 87 return 2.0*x
88 88
89 89 t = self.client.ids[-1]
90 90 v = self.client[t]
91 91 v.block=True
92 92 push = v.push
93 93 pull = v.pull
94 94 execute = v.execute
95 95 push({'testf':testf})
96 96 r = pull('testf')
97 97 self.assertEqual(r(1.0), testf(1.0))
98 98 execute('r = testf(10)')
99 99 r = pull('r')
100 100 self.assertEquals(r, testf(10))
101 101 ar = self.client[:].push({'testf':testf}, block=False)
102 102 ar.get()
103 103 ar = self.client[:].pull('testf', block=False)
104 104 rlist = ar.get()
105 105 for r in rlist:
106 106 self.assertEqual(r(1.0), testf(1.0))
107 107 execute("def g(x): return x*x")
108 108 r = pull(('testf','g'))
109 109 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
110 110
111 111 def test_push_function_globals(self):
112 112 """test that pushed functions have access to globals"""
113 113 @interactive
114 114 def geta():
115 115 return a
116 116 # self.add_engines(1)
117 117 v = self.client[-1]
118 118 v.block=True
119 119 v['f'] = geta
120 120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
121 121 v.execute('a=5')
122 122 v.execute('b=f()')
123 123 self.assertEquals(v['b'], 5)
124 124
125 125 def test_push_function_defaults(self):
126 126 """test that pushed functions preserve default args"""
127 127 def echo(a=10):
128 128 return a
129 129 v = self.client[-1]
130 130 v.block=True
131 131 v['f'] = echo
132 132 v.execute('b=f()')
133 133 self.assertEquals(v['b'], 10)
134 134
135 135 def test_get_result(self):
136 136 """test getting results from the Hub."""
137 137 c = pmod.Client(profile='iptest')
138 138 # self.add_engines(1)
139 139 t = c.ids[-1]
140 140 v = c[t]
141 141 v2 = self.client[t]
142 142 ar = v.apply_async(wait, 1)
143 143 # give the monitor time to notice the message
144 144 time.sleep(.25)
145 145 ahr = v2.get_result(ar.msg_ids)
146 146 self.assertTrue(isinstance(ahr, AsyncHubResult))
147 147 self.assertEquals(ahr.get(), ar.get())
148 148 ar2 = v2.get_result(ar.msg_ids)
149 149 self.assertFalse(isinstance(ar2, AsyncHubResult))
150 150 c.spin()
151 151 c.close()
152 152
153 153 def test_run_newline(self):
154 154 """test that run appends newline to files"""
155 155 tmpfile = mktemp()
156 156 with open(tmpfile, 'w') as f:
157 157 f.write("""def g():
158 158 return 5
159 159 """)
160 160 v = self.client[-1]
161 161 v.run(tmpfile, block=True)
162 162 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
163 163
164 164 def test_apply_tracked(self):
165 165 """test tracking for apply"""
166 166 # self.add_engines(1)
167 167 t = self.client.ids[-1]
168 168 v = self.client[t]
169 169 v.block=False
170 170 def echo(n=1024*1024, **kwargs):
171 171 with v.temp_flags(**kwargs):
172 172 return v.apply(lambda x: x, 'x'*n)
173 173 ar = echo(1, track=False)
174 174 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
175 175 self.assertTrue(ar.sent)
176 176 ar = echo(track=True)
177 177 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
178 178 self.assertEquals(ar.sent, ar._tracker.done)
179 179 ar._tracker.wait()
180 180 self.assertTrue(ar.sent)
181 181
182 182 def test_push_tracked(self):
183 183 t = self.client.ids[-1]
184 184 ns = dict(x='x'*1024*1024)
185 185 v = self.client[t]
186 186 ar = v.push(ns, block=False, track=False)
187 187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 188 self.assertTrue(ar.sent)
189 189
190 190 ar = v.push(ns, block=False, track=True)
191 191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 192 ar._tracker.wait()
193 193 self.assertEquals(ar.sent, ar._tracker.done)
194 194 self.assertTrue(ar.sent)
195 195 ar.get()
196 196
197 197 def test_scatter_tracked(self):
198 198 t = self.client.ids
199 199 x='x'*1024*1024
200 200 ar = self.client[t].scatter('x', x, block=False, track=False)
201 201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 202 self.assertTrue(ar.sent)
203 203
204 204 ar = self.client[t].scatter('x', x, block=False, track=True)
205 205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 206 self.assertEquals(ar.sent, ar._tracker.done)
207 207 ar._tracker.wait()
208 208 self.assertTrue(ar.sent)
209 209 ar.get()
210 210
211 211 def test_remote_reference(self):
212 212 v = self.client[-1]
213 213 v['a'] = 123
214 214 ra = pmod.Reference('a')
215 215 b = v.apply_sync(lambda x: x, ra)
216 216 self.assertEquals(b, 123)
217 217
218 218
219 219 def test_scatter_gather(self):
220 220 view = self.client[:]
221 221 seq1 = range(16)
222 222 view.scatter('a', seq1)
223 223 seq2 = view.gather('a', block=True)
224 224 self.assertEquals(seq2, seq1)
225 225 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
226 226
227 227 @skip_without('numpy')
228 228 def test_scatter_gather_numpy(self):
229 229 import numpy
230 230 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
231 231 view = self.client[:]
232 232 a = numpy.arange(64)
233 233 view.scatter('a', a)
234 234 b = view.gather('a', block=True)
235 235 assert_array_equal(b, a)
236 236
237 @skip_without('numpy')
238 def test_apply_numpy(self):
239 """view.apply(f, ndarray)"""
240 import numpy
241 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
242
243 A = numpy.random.random((100,100))
244 view = self.client[-1]
245 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
246 B = A.astype(dt)
247 C = view.apply_sync(lambda x:x, B)
248 assert_array_equal(B,C)
249
237 250 def test_map(self):
238 251 view = self.client[:]
239 252 def f(x):
240 253 return x**2
241 254 data = range(16)
242 255 r = view.map_sync(f, data)
243 256 self.assertEquals(r, map(f, data))
244 257
245 258 def test_map_iterable(self):
246 259 """test map on iterables (direct)"""
247 260 view = self.client[:]
248 261 # 101 is prime, so it won't be evenly distributed
249 262 arr = range(101)
250 263 # ensure it will be an iterator, even in Python 3
251 264 it = iter(arr)
252 265 r = view.map_sync(lambda x:x, arr)
253 266 self.assertEquals(r, list(arr))
254 267
255 268 def test_scatterGatherNonblocking(self):
256 269 data = range(16)
257 270 view = self.client[:]
258 271 view.scatter('a', data, block=False)
259 272 ar = view.gather('a', block=False)
260 273 self.assertEquals(ar.get(), data)
261 274
262 275 @skip_without('numpy')
263 276 def test_scatter_gather_numpy_nonblocking(self):
264 277 import numpy
265 278 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
266 279 a = numpy.arange(64)
267 280 view = self.client[:]
268 281 ar = view.scatter('a', a, block=False)
269 282 self.assertTrue(isinstance(ar, AsyncResult))
270 283 amr = view.gather('a', block=False)
271 284 self.assertTrue(isinstance(amr, AsyncMapResult))
272 285 assert_array_equal(amr.get(), a)
273 286
274 287 def test_execute(self):
275 288 view = self.client[:]
276 289 # self.client.debug=True
277 290 execute = view.execute
278 291 ar = execute('c=30', block=False)
279 292 self.assertTrue(isinstance(ar, AsyncResult))
280 293 ar = execute('d=[0,1,2]', block=False)
281 294 self.client.wait(ar, 1)
282 295 self.assertEquals(len(ar.get()), len(self.client))
283 296 for c in view['c']:
284 297 self.assertEquals(c, 30)
285 298
286 299 def test_abort(self):
287 300 view = self.client[-1]
288 301 ar = view.execute('import time; time.sleep(1)', block=False)
289 302 ar2 = view.apply_async(lambda : 2)
290 303 ar3 = view.apply_async(lambda : 3)
291 304 view.abort(ar2)
292 305 view.abort(ar3.msg_ids)
293 306 self.assertRaises(error.TaskAborted, ar2.get)
294 307 self.assertRaises(error.TaskAborted, ar3.get)
295 308
296 309 def test_abort_all(self):
297 310 """view.abort() aborts all outstanding tasks"""
298 311 view = self.client[-1]
299 312 ars = [ view.apply_async(time.sleep, 1) for i in range(10) ]
300 313 view.abort()
301 314 view.wait(timeout=5)
302 315 for ar in ars[5:]:
303 316 self.assertRaises(error.TaskAborted, ar.get)
304 317
305 318 def test_temp_flags(self):
306 319 view = self.client[-1]
307 320 view.block=True
308 321 with view.temp_flags(block=False):
309 322 self.assertFalse(view.block)
310 323 self.assertTrue(view.block)
311 324
312 325 @dec.known_failure_py3
313 326 def test_importer(self):
314 327 view = self.client[-1]
315 328 view.clear(block=True)
316 329 with view.importer:
317 330 import re
318 331
319 332 @interactive
320 333 def findall(pat, s):
321 334 # this globals() step isn't necessary in real code
322 335 # only to prevent a closure in the test
323 336 re = globals()['re']
324 337 return re.findall(pat, s)
325 338
326 339 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
327 340
328 341 # parallel magic tests
329 342
330 343 def test_magic_px_blocking(self):
331 344 ip = get_ipython()
332 345 v = self.client[-1]
333 346 v.activate()
334 347 v.block=True
335 348
336 349 ip.magic_px('a=5')
337 350 self.assertEquals(v['a'], 5)
338 351 ip.magic_px('a=10')
339 352 self.assertEquals(v['a'], 10)
340 353 sio = StringIO()
341 354 savestdout = sys.stdout
342 355 sys.stdout = sio
343 356 # just 'print a' worst ~99% of the time, but this ensures that
344 357 # the stdout message has arrived when the result is finished:
345 358 ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
346 359 sys.stdout = savestdout
347 360 buf = sio.getvalue()
348 361 self.assertTrue('[stdout:' in buf, buf)
349 362 self.assertTrue(buf.rstrip().endswith('10'))
350 363 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
351 364
352 365 def test_magic_px_nonblocking(self):
353 366 ip = get_ipython()
354 367 v = self.client[-1]
355 368 v.activate()
356 369 v.block=False
357 370
358 371 ip.magic_px('a=5')
359 372 self.assertEquals(v['a'], 5)
360 373 ip.magic_px('a=10')
361 374 self.assertEquals(v['a'], 10)
362 375 sio = StringIO()
363 376 savestdout = sys.stdout
364 377 sys.stdout = sio
365 378 ip.magic_px('print a')
366 379 sys.stdout = savestdout
367 380 buf = sio.getvalue()
368 381 self.assertFalse('[stdout:%i]'%v.targets in buf)
369 382 ip.magic_px('1/0')
370 383 ar = v.get_result(-1)
371 384 self.assertRaisesRemote(ZeroDivisionError, ar.get)
372 385
373 386 def test_magic_autopx_blocking(self):
374 387 ip = get_ipython()
375 388 v = self.client[-1]
376 389 v.activate()
377 390 v.block=True
378 391
379 392 sio = StringIO()
380 393 savestdout = sys.stdout
381 394 sys.stdout = sio
382 395 ip.magic_autopx()
383 396 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
384 397 ip.run_cell('print b')
385 398 ip.run_cell("b/c")
386 399 ip.run_code(compile('b*=2', '', 'single'))
387 400 ip.magic_autopx()
388 401 sys.stdout = savestdout
389 402 output = sio.getvalue().strip()
390 403 self.assertTrue(output.startswith('%autopx enabled'))
391 404 self.assertTrue(output.endswith('%autopx disabled'))
392 405 self.assertTrue('RemoteError: ZeroDivisionError' in output)
393 406 ar = v.get_result(-2)
394 407 self.assertEquals(v['a'], 5)
395 408 self.assertEquals(v['b'], 20)
396 409 self.assertRaisesRemote(ZeroDivisionError, ar.get)
397 410
398 411 def test_magic_autopx_nonblocking(self):
399 412 ip = get_ipython()
400 413 v = self.client[-1]
401 414 v.activate()
402 415 v.block=False
403 416
404 417 sio = StringIO()
405 418 savestdout = sys.stdout
406 419 sys.stdout = sio
407 420 ip.magic_autopx()
408 421 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
409 422 ip.run_cell('print b')
410 423 ip.run_cell("b/c")
411 424 ip.run_code(compile('b*=2', '', 'single'))
412 425 ip.magic_autopx()
413 426 sys.stdout = savestdout
414 427 output = sio.getvalue().strip()
415 428 self.assertTrue(output.startswith('%autopx enabled'))
416 429 self.assertTrue(output.endswith('%autopx disabled'))
417 430 self.assertFalse('ZeroDivisionError' in output)
418 431 ar = v.get_result(-2)
419 432 self.assertEquals(v['a'], 5)
420 433 self.assertEquals(v['b'], 20)
421 434 self.assertRaisesRemote(ZeroDivisionError, ar.get)
422 435
423 436 def test_magic_result(self):
424 437 ip = get_ipython()
425 438 v = self.client[-1]
426 439 v.activate()
427 440 v['a'] = 111
428 441 ra = v['a']
429 442
430 443 ar = ip.magic_result()
431 444 self.assertEquals(ar.msg_ids, [v.history[-1]])
432 445 self.assertEquals(ar.get(), 111)
433 446 ar = ip.magic_result('-2')
434 447 self.assertEquals(ar.msg_ids, [v.history[-2]])
435 448
436 449 def test_unicode_execute(self):
437 450 """test executing unicode strings"""
438 451 v = self.client[-1]
439 452 v.block=True
440 453 if sys.version_info[0] >= 3:
441 454 code="a='é'"
442 455 else:
443 456 code=u"a=u'é'"
444 457 v.execute(code)
445 458 self.assertEquals(v['a'], u'é')
446 459
447 460 def test_unicode_apply_result(self):
448 461 """test unicode apply results"""
449 462 v = self.client[-1]
450 463 r = v.apply_sync(lambda : u'é')
451 464 self.assertEquals(r, u'é')
452 465
453 466 def test_unicode_apply_arg(self):
454 467 """test passing unicode arguments to apply"""
455 468 v = self.client[-1]
456 469
457 470 @interactive
458 471 def check_unicode(a, check):
459 472 assert isinstance(a, unicode), "%r is not unicode"%a
460 473 assert isinstance(check, bytes), "%r is not bytes"%check
461 474 assert a.encode('utf8') == check, "%s != %s"%(a,check)
462 475
463 476 for s in [ u'é', u'ßø®∫',u'asdf' ]:
464 477 try:
465 478 v.apply_sync(check_unicode, s, s.encode('utf8'))
466 479 except error.RemoteError as e:
467 480 if e.ename == 'AssertionError':
468 481 self.fail(e.evalue)
469 482 else:
470 483 raise e
471 484
472 485 def test_map_reference(self):
473 486 """view.map(<Reference>, *seqs) should work"""
474 487 v = self.client[:]
475 488 v.scatter('n', self.client.ids, flatten=True)
476 489 v.execute("f = lambda x,y: x*y")
477 490 rf = pmod.Reference('f')
478 491 nlist = list(range(10))
479 492 mlist = nlist[::-1]
480 493 expected = [ m*n for m,n in zip(mlist, nlist) ]
481 494 result = v.map_sync(rf, mlist, nlist)
482 495 self.assertEquals(result, expected)
483 496
484 497 def test_apply_reference(self):
485 498 """view.apply(<Reference>, *args) should work"""
486 499 v = self.client[:]
487 500 v.scatter('n', self.client.ids, flatten=True)
488 501 v.execute("f = lambda x: n*x")
489 502 rf = pmod.Reference('f')
490 503 result = v.apply_sync(rf, 5)
491 504 expected = [ 5*id for id in self.client.ids ]
492 505 self.assertEquals(result, expected)
493 506
@@ -1,476 +1,480
1 1 """some generic utilities for dealing with classes, urls, and serialization
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 # Standard library imports.
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23 import socket
24 24 import sys
25 25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 26 try:
27 27 from signal import SIGKILL
28 28 except ImportError:
29 29 SIGKILL=None
30 30
31 31 try:
32 32 import cPickle
33 33 pickle = cPickle
34 34 except:
35 35 cPickle = None
36 36 import pickle
37 37
38 38 # System library imports
39 39 import zmq
40 40 from zmq.log import handlers
41 41
42 42 # IPython imports
43 43 from IPython.config.application import Application
44 from IPython.utils import py3compat
44 45 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
45 46 from IPython.utils.newserialized import serialize, unserialize
46 47 from IPython.zmq.log import EnginePUBHandler
47 48
49 if py3compat.PY3:
50 buffer = memoryview
51
48 52 #-----------------------------------------------------------------------------
49 53 # Classes
50 54 #-----------------------------------------------------------------------------
51 55
52 56 class Namespace(dict):
53 57 """Subclass of dict for attribute access to keys."""
54 58
55 59 def __getattr__(self, key):
56 60 """getattr aliased to getitem"""
57 61 if key in self.iterkeys():
58 62 return self[key]
59 63 else:
60 64 raise NameError(key)
61 65
62 66 def __setattr__(self, key, value):
63 67 """setattr aliased to setitem, with strict"""
64 68 if hasattr(dict, key):
65 69 raise KeyError("Cannot override dict keys %r"%key)
66 70 self[key] = value
67 71
68 72
69 73 class ReverseDict(dict):
70 74 """simple double-keyed subset of dict methods."""
71 75
72 76 def __init__(self, *args, **kwargs):
73 77 dict.__init__(self, *args, **kwargs)
74 78 self._reverse = dict()
75 79 for key, value in self.iteritems():
76 80 self._reverse[value] = key
77 81
78 82 def __getitem__(self, key):
79 83 try:
80 84 return dict.__getitem__(self, key)
81 85 except KeyError:
82 86 return self._reverse[key]
83 87
84 88 def __setitem__(self, key, value):
85 89 if key in self._reverse:
86 90 raise KeyError("Can't have key %r on both sides!"%key)
87 91 dict.__setitem__(self, key, value)
88 92 self._reverse[value] = key
89 93
90 94 def pop(self, key):
91 95 value = dict.pop(self, key)
92 96 self._reverse.pop(value)
93 97 return value
94 98
95 99 def get(self, key, default=None):
96 100 try:
97 101 return self[key]
98 102 except KeyError:
99 103 return default
100 104
101 105 #-----------------------------------------------------------------------------
102 106 # Functions
103 107 #-----------------------------------------------------------------------------
104 108
105 109 def asbytes(s):
106 110 """ensure that an object is ascii bytes"""
107 111 if isinstance(s, unicode):
108 112 s = s.encode('ascii')
109 113 return s
110 114
111 115 def is_url(url):
112 116 """boolean check for whether a string is a zmq url"""
113 117 if '://' not in url:
114 118 return False
115 119 proto, addr = url.split('://', 1)
116 120 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
117 121 return False
118 122 return True
119 123
120 124 def validate_url(url):
121 125 """validate a url for zeromq"""
122 126 if not isinstance(url, basestring):
123 127 raise TypeError("url must be a string, not %r"%type(url))
124 128 url = url.lower()
125 129
126 130 proto_addr = url.split('://')
127 131 assert len(proto_addr) == 2, 'Invalid url: %r'%url
128 132 proto, addr = proto_addr
129 133 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
130 134
131 135 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
132 136 # author: Remi Sabourin
133 137 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
134 138
135 139 if proto == 'tcp':
136 140 lis = addr.split(':')
137 141 assert len(lis) == 2, 'Invalid url: %r'%url
138 142 addr,s_port = lis
139 143 try:
140 144 port = int(s_port)
141 145 except ValueError:
142 146 raise AssertionError("Invalid port %r in url: %r"%(port, url))
143 147
144 148 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
145 149
146 150 else:
147 151 # only validate tcp urls currently
148 152 pass
149 153
150 154 return True
151 155
152 156
153 157 def validate_url_container(container):
154 158 """validate a potentially nested collection of urls."""
155 159 if isinstance(container, basestring):
156 160 url = container
157 161 return validate_url(url)
158 162 elif isinstance(container, dict):
159 163 container = container.itervalues()
160 164
161 165 for element in container:
162 166 validate_url_container(element)
163 167
164 168
165 169 def split_url(url):
166 170 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
167 171 proto_addr = url.split('://')
168 172 assert len(proto_addr) == 2, 'Invalid url: %r'%url
169 173 proto, addr = proto_addr
170 174 lis = addr.split(':')
171 175 assert len(lis) == 2, 'Invalid url: %r'%url
172 176 addr,s_port = lis
173 177 return proto,addr,s_port
174 178
175 179 def disambiguate_ip_address(ip, location=None):
176 180 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
177 181 ones, based on the location (default interpretation of location is localhost)."""
178 182 if ip in ('0.0.0.0', '*'):
179 183 try:
180 184 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
181 185 except (socket.gaierror, IndexError):
182 186 # couldn't identify this machine, assume localhost
183 187 external_ips = []
184 188 if location is None or location in external_ips or not external_ips:
185 189 # If location is unspecified or cannot be determined, assume local
186 190 ip='127.0.0.1'
187 191 elif location:
188 192 return location
189 193 return ip
190 194
191 195 def disambiguate_url(url, location=None):
192 196 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
193 197 ones, based on the location (default interpretation is localhost).
194 198
195 199 This is for zeromq urls, such as tcp://*:10101."""
196 200 try:
197 201 proto,ip,port = split_url(url)
198 202 except AssertionError:
199 203 # probably not tcp url; could be ipc, etc.
200 204 return url
201 205
202 206 ip = disambiguate_ip_address(ip,location)
203 207
204 208 return "%s://%s:%s"%(proto,ip,port)
205 209
206 210 def serialize_object(obj, threshold=64e-6):
207 211 """Serialize an object into a list of sendable buffers.
208 212
209 213 Parameters
210 214 ----------
211 215
212 216 obj : object
213 217 The object to be serialized
214 218 threshold : float
215 219 The threshold for not double-pickling the content.
216 220
217 221
218 222 Returns
219 223 -------
220 224 ('pmd', [bufs]) :
221 225 where pmd is the pickled metadata wrapper,
222 226 bufs is a list of data buffers
223 227 """
224 228 databuffers = []
225 229 if isinstance(obj, (list, tuple)):
226 230 clist = canSequence(obj)
227 231 slist = map(serialize, clist)
228 232 for s in slist:
229 233 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
230 234 databuffers.append(s.getData())
231 235 s.data = None
232 236 return pickle.dumps(slist,-1), databuffers
233 237 elif isinstance(obj, dict):
234 238 sobj = {}
235 239 for k in sorted(obj.iterkeys()):
236 240 s = serialize(can(obj[k]))
237 241 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
238 242 databuffers.append(s.getData())
239 243 s.data = None
240 244 sobj[k] = s
241 245 return pickle.dumps(sobj,-1),databuffers
242 246 else:
243 247 s = serialize(can(obj))
244 248 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
245 249 databuffers.append(s.getData())
246 250 s.data = None
247 251 return pickle.dumps(s,-1),databuffers
248 252
249 253
250 254 def unserialize_object(bufs):
251 255 """reconstruct an object serialized by serialize_object from data buffers."""
252 256 bufs = list(bufs)
253 257 sobj = pickle.loads(bufs.pop(0))
254 258 if isinstance(sobj, (list, tuple)):
255 259 for s in sobj:
256 260 if s.data is None:
257 261 s.data = bufs.pop(0)
258 262 return uncanSequence(map(unserialize, sobj)), bufs
259 263 elif isinstance(sobj, dict):
260 264 newobj = {}
261 265 for k in sorted(sobj.iterkeys()):
262 266 s = sobj[k]
263 267 if s.data is None:
264 268 s.data = bufs.pop(0)
265 269 newobj[k] = uncan(unserialize(s))
266 270 return newobj, bufs
267 271 else:
268 272 if sobj.data is None:
269 273 sobj.data = bufs.pop(0)
270 274 return uncan(unserialize(sobj)), bufs
271 275
272 276 def pack_apply_message(f, args, kwargs, threshold=64e-6):
273 277 """pack up a function, args, and kwargs to be sent over the wire
274 278 as a series of buffers. Any object whose data is larger than `threshold`
275 279 will not have their data copied (currently only numpy arrays support zero-copy)"""
276 280 msg = [pickle.dumps(can(f),-1)]
277 281 databuffers = [] # for large objects
278 282 sargs, bufs = serialize_object(args,threshold)
279 283 msg.append(sargs)
280 284 databuffers.extend(bufs)
281 285 skwargs, bufs = serialize_object(kwargs,threshold)
282 286 msg.append(skwargs)
283 287 databuffers.extend(bufs)
284 288 msg.extend(databuffers)
285 289 return msg
286 290
287 291 def unpack_apply_message(bufs, g=None, copy=True):
288 292 """unpack f,args,kwargs from buffers packed by pack_apply_message()
289 293 Returns: original f,args,kwargs"""
290 294 bufs = list(bufs) # allow us to pop
291 295 assert len(bufs) >= 3, "not enough buffers!"
292 296 if not copy:
293 297 for i in range(3):
294 298 bufs[i] = bufs[i].bytes
295 299 cf = pickle.loads(bufs.pop(0))
296 300 sargs = list(pickle.loads(bufs.pop(0)))
297 301 skwargs = dict(pickle.loads(bufs.pop(0)))
298 302 # print sargs, skwargs
299 303 f = uncan(cf, g)
300 304 for sa in sargs:
301 305 if sa.data is None:
302 306 m = bufs.pop(0)
303 307 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
304 308 # always use a buffer, until memoryviews get sorted out
305 309 sa.data = buffer(m)
306 310 # disable memoryview support
307 311 # if copy:
308 312 # sa.data = buffer(m)
309 313 # else:
310 314 # sa.data = m.buffer
311 315 else:
312 316 if copy:
313 317 sa.data = m
314 318 else:
315 319 sa.data = m.bytes
316 320
317 321 args = uncanSequence(map(unserialize, sargs), g)
318 322 kwargs = {}
319 323 for k in sorted(skwargs.iterkeys()):
320 324 sa = skwargs[k]
321 325 if sa.data is None:
322 326 m = bufs.pop(0)
323 327 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
324 328 # always use a buffer, until memoryviews get sorted out
325 329 sa.data = buffer(m)
326 330 # disable memoryview support
327 331 # if copy:
328 332 # sa.data = buffer(m)
329 333 # else:
330 334 # sa.data = m.buffer
331 335 else:
332 336 if copy:
333 337 sa.data = m
334 338 else:
335 339 sa.data = m.bytes
336 340
337 341 kwargs[k] = uncan(unserialize(sa), g)
338 342
339 343 return f,args,kwargs
340 344
341 345 #--------------------------------------------------------------------------
342 346 # helpers for implementing old MEC API via view.apply
343 347 #--------------------------------------------------------------------------
344 348
345 349 def interactive(f):
346 350 """decorator for making functions appear as interactively defined.
347 351 This results in the function being linked to the user_ns as globals()
348 352 instead of the module globals().
349 353 """
350 354 f.__module__ = '__main__'
351 355 return f
352 356
353 357 @interactive
354 358 def _push(ns):
355 359 """helper method for implementing `client.push` via `client.apply`"""
356 360 globals().update(ns)
357 361
358 362 @interactive
359 363 def _pull(keys):
360 364 """helper method for implementing `client.pull` via `client.apply`"""
361 365 user_ns = globals()
362 366 if isinstance(keys, (list,tuple, set)):
363 367 for key in keys:
364 368 if not user_ns.has_key(key):
365 369 raise NameError("name '%s' is not defined"%key)
366 370 return map(user_ns.get, keys)
367 371 else:
368 372 if not user_ns.has_key(keys):
369 373 raise NameError("name '%s' is not defined"%keys)
370 374 return user_ns.get(keys)
371 375
372 376 @interactive
373 377 def _execute(code):
374 378 """helper method for implementing `client.execute` via `client.apply`"""
375 379 exec code in globals()
376 380
377 381 #--------------------------------------------------------------------------
378 382 # extra process management utilities
379 383 #--------------------------------------------------------------------------
380 384
381 385 _random_ports = set()
382 386
383 387 def select_random_ports(n):
384 388 """Selects and return n random ports that are available."""
385 389 ports = []
386 390 for i in xrange(n):
387 391 sock = socket.socket()
388 392 sock.bind(('', 0))
389 393 while sock.getsockname()[1] in _random_ports:
390 394 sock.close()
391 395 sock = socket.socket()
392 396 sock.bind(('', 0))
393 397 ports.append(sock)
394 398 for i, sock in enumerate(ports):
395 399 port = sock.getsockname()[1]
396 400 sock.close()
397 401 ports[i] = port
398 402 _random_ports.add(port)
399 403 return ports
400 404
401 405 def signal_children(children):
402 406 """Relay interupt/term signals to children, for more solid process cleanup."""
403 407 def terminate_children(sig, frame):
404 408 log = Application.instance().log
405 409 log.critical("Got signal %i, terminating children..."%sig)
406 410 for child in children:
407 411 child.terminate()
408 412
409 413 sys.exit(sig != SIGINT)
410 414 # sys.exit(sig)
411 415 for sig in (SIGINT, SIGABRT, SIGTERM):
412 416 signal(sig, terminate_children)
413 417
414 418 def generate_exec_key(keyfile):
415 419 import uuid
416 420 newkey = str(uuid.uuid4())
417 421 with open(keyfile, 'w') as f:
418 422 # f.write('ipython-key ')
419 423 f.write(newkey+'\n')
420 424 # set user-only RW permissions (0600)
421 425 # this will have no effect on Windows
422 426 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
423 427
424 428
425 429 def integer_loglevel(loglevel):
426 430 try:
427 431 loglevel = int(loglevel)
428 432 except ValueError:
429 433 if isinstance(loglevel, str):
430 434 loglevel = getattr(logging, loglevel)
431 435 return loglevel
432 436
433 437 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
434 438 logger = logging.getLogger(logname)
435 439 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
436 440 # don't add a second PUBHandler
437 441 return
438 442 loglevel = integer_loglevel(loglevel)
439 443 lsock = context.socket(zmq.PUB)
440 444 lsock.connect(iface)
441 445 handler = handlers.PUBHandler(lsock)
442 446 handler.setLevel(loglevel)
443 447 handler.root_topic = root
444 448 logger.addHandler(handler)
445 449 logger.setLevel(loglevel)
446 450
447 451 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
448 452 logger = logging.getLogger()
449 453 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
450 454 # don't add a second PUBHandler
451 455 return
452 456 loglevel = integer_loglevel(loglevel)
453 457 lsock = context.socket(zmq.PUB)
454 458 lsock.connect(iface)
455 459 handler = EnginePUBHandler(engine, lsock)
456 460 handler.setLevel(loglevel)
457 461 logger.addHandler(handler)
458 462 logger.setLevel(loglevel)
459 463 return logger
460 464
461 465 def local_logger(logname, loglevel=logging.DEBUG):
462 466 loglevel = integer_loglevel(loglevel)
463 467 logger = logging.getLogger(logname)
464 468 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
465 469 # don't add a second StreamHandler
466 470 return
467 471 handler = logging.StreamHandler()
468 472 handler.setLevel(loglevel)
469 473 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
470 474 datefmt="%Y-%m-%d %H:%M:%S")
471 475 handler.setFormatter(formatter)
472 476
473 477 logger.addHandler(handler)
474 478 logger.setLevel(loglevel)
475 479 return logger
476 480
General Comments 0
You need to be logged in to leave comments. Login now