##// END OF EJS Templates
fix dangling `buffer` in IPython.parallel.util...
MinRK -
Show More
@@ -1,493 +1,506
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import time
20 import time
21 from tempfile import mktemp
21 from tempfile import mktemp
22 from StringIO import StringIO
22 from StringIO import StringIO
23
23
24 import zmq
24 import zmq
25 from nose import SkipTest
25 from nose import SkipTest
26
26
27 from IPython.testing import decorators as dec
27 from IPython.testing import decorators as dec
28
28
29 from IPython import parallel as pmod
29 from IPython import parallel as pmod
30 from IPython.parallel import error
30 from IPython.parallel import error
31 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
31 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
32 from IPython.parallel import DirectView
32 from IPython.parallel import DirectView
33 from IPython.parallel.util import interactive
33 from IPython.parallel.util import interactive
34
34
35 from IPython.parallel.tests import add_engines
35 from IPython.parallel.tests import add_engines
36
36
37 from .clienttest import ClusterTestCase, crash, wait, skip_without
37 from .clienttest import ClusterTestCase, crash, wait, skip_without
38
38
39 def setup():
39 def setup():
40 add_engines(3)
40 add_engines(3)
41
41
42 class TestView(ClusterTestCase):
42 class TestView(ClusterTestCase):
43
43
44 def test_z_crash_mux(self):
44 def test_z_crash_mux(self):
45 """test graceful handling of engine death (direct)"""
45 """test graceful handling of engine death (direct)"""
46 raise SkipTest("crash tests disabled, due to undesirable crash reports")
46 raise SkipTest("crash tests disabled, due to undesirable crash reports")
47 # self.add_engines(1)
47 # self.add_engines(1)
48 eid = self.client.ids[-1]
48 eid = self.client.ids[-1]
49 ar = self.client[eid].apply_async(crash)
49 ar = self.client[eid].apply_async(crash)
50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
51 eid = ar.engine_id
51 eid = ar.engine_id
52 tic = time.time()
52 tic = time.time()
53 while eid in self.client.ids and time.time()-tic < 5:
53 while eid in self.client.ids and time.time()-tic < 5:
54 time.sleep(.01)
54 time.sleep(.01)
55 self.client.spin()
55 self.client.spin()
56 self.assertFalse(eid in self.client.ids, "Engine should have died")
56 self.assertFalse(eid in self.client.ids, "Engine should have died")
57
57
58 def test_push_pull(self):
58 def test_push_pull(self):
59 """test pushing and pulling"""
59 """test pushing and pulling"""
60 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
60 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
61 t = self.client.ids[-1]
61 t = self.client.ids[-1]
62 v = self.client[t]
62 v = self.client[t]
63 push = v.push
63 push = v.push
64 pull = v.pull
64 pull = v.pull
65 v.block=True
65 v.block=True
66 nengines = len(self.client)
66 nengines = len(self.client)
67 push({'data':data})
67 push({'data':data})
68 d = pull('data')
68 d = pull('data')
69 self.assertEquals(d, data)
69 self.assertEquals(d, data)
70 self.client[:].push({'data':data})
70 self.client[:].push({'data':data})
71 d = self.client[:].pull('data', block=True)
71 d = self.client[:].pull('data', block=True)
72 self.assertEquals(d, nengines*[data])
72 self.assertEquals(d, nengines*[data])
73 ar = push({'data':data}, block=False)
73 ar = push({'data':data}, block=False)
74 self.assertTrue(isinstance(ar, AsyncResult))
74 self.assertTrue(isinstance(ar, AsyncResult))
75 r = ar.get()
75 r = ar.get()
76 ar = self.client[:].pull('data', block=False)
76 ar = self.client[:].pull('data', block=False)
77 self.assertTrue(isinstance(ar, AsyncResult))
77 self.assertTrue(isinstance(ar, AsyncResult))
78 r = ar.get()
78 r = ar.get()
79 self.assertEquals(r, nengines*[data])
79 self.assertEquals(r, nengines*[data])
80 self.client[:].push(dict(a=10,b=20))
80 self.client[:].push(dict(a=10,b=20))
81 r = self.client[:].pull(('a','b'), block=True)
81 r = self.client[:].pull(('a','b'), block=True)
82 self.assertEquals(r, nengines*[[10,20]])
82 self.assertEquals(r, nengines*[[10,20]])
83
83
84 def test_push_pull_function(self):
84 def test_push_pull_function(self):
85 "test pushing and pulling functions"
85 "test pushing and pulling functions"
86 def testf(x):
86 def testf(x):
87 return 2.0*x
87 return 2.0*x
88
88
89 t = self.client.ids[-1]
89 t = self.client.ids[-1]
90 v = self.client[t]
90 v = self.client[t]
91 v.block=True
91 v.block=True
92 push = v.push
92 push = v.push
93 pull = v.pull
93 pull = v.pull
94 execute = v.execute
94 execute = v.execute
95 push({'testf':testf})
95 push({'testf':testf})
96 r = pull('testf')
96 r = pull('testf')
97 self.assertEqual(r(1.0), testf(1.0))
97 self.assertEqual(r(1.0), testf(1.0))
98 execute('r = testf(10)')
98 execute('r = testf(10)')
99 r = pull('r')
99 r = pull('r')
100 self.assertEquals(r, testf(10))
100 self.assertEquals(r, testf(10))
101 ar = self.client[:].push({'testf':testf}, block=False)
101 ar = self.client[:].push({'testf':testf}, block=False)
102 ar.get()
102 ar.get()
103 ar = self.client[:].pull('testf', block=False)
103 ar = self.client[:].pull('testf', block=False)
104 rlist = ar.get()
104 rlist = ar.get()
105 for r in rlist:
105 for r in rlist:
106 self.assertEqual(r(1.0), testf(1.0))
106 self.assertEqual(r(1.0), testf(1.0))
107 execute("def g(x): return x*x")
107 execute("def g(x): return x*x")
108 r = pull(('testf','g'))
108 r = pull(('testf','g'))
109 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
109 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
110
110
111 def test_push_function_globals(self):
111 def test_push_function_globals(self):
112 """test that pushed functions have access to globals"""
112 """test that pushed functions have access to globals"""
113 @interactive
113 @interactive
114 def geta():
114 def geta():
115 return a
115 return a
116 # self.add_engines(1)
116 # self.add_engines(1)
117 v = self.client[-1]
117 v = self.client[-1]
118 v.block=True
118 v.block=True
119 v['f'] = geta
119 v['f'] = geta
120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
121 v.execute('a=5')
121 v.execute('a=5')
122 v.execute('b=f()')
122 v.execute('b=f()')
123 self.assertEquals(v['b'], 5)
123 self.assertEquals(v['b'], 5)
124
124
125 def test_push_function_defaults(self):
125 def test_push_function_defaults(self):
126 """test that pushed functions preserve default args"""
126 """test that pushed functions preserve default args"""
127 def echo(a=10):
127 def echo(a=10):
128 return a
128 return a
129 v = self.client[-1]
129 v = self.client[-1]
130 v.block=True
130 v.block=True
131 v['f'] = echo
131 v['f'] = echo
132 v.execute('b=f()')
132 v.execute('b=f()')
133 self.assertEquals(v['b'], 10)
133 self.assertEquals(v['b'], 10)
134
134
135 def test_get_result(self):
135 def test_get_result(self):
136 """test getting results from the Hub."""
136 """test getting results from the Hub."""
137 c = pmod.Client(profile='iptest')
137 c = pmod.Client(profile='iptest')
138 # self.add_engines(1)
138 # self.add_engines(1)
139 t = c.ids[-1]
139 t = c.ids[-1]
140 v = c[t]
140 v = c[t]
141 v2 = self.client[t]
141 v2 = self.client[t]
142 ar = v.apply_async(wait, 1)
142 ar = v.apply_async(wait, 1)
143 # give the monitor time to notice the message
143 # give the monitor time to notice the message
144 time.sleep(.25)
144 time.sleep(.25)
145 ahr = v2.get_result(ar.msg_ids)
145 ahr = v2.get_result(ar.msg_ids)
146 self.assertTrue(isinstance(ahr, AsyncHubResult))
146 self.assertTrue(isinstance(ahr, AsyncHubResult))
147 self.assertEquals(ahr.get(), ar.get())
147 self.assertEquals(ahr.get(), ar.get())
148 ar2 = v2.get_result(ar.msg_ids)
148 ar2 = v2.get_result(ar.msg_ids)
149 self.assertFalse(isinstance(ar2, AsyncHubResult))
149 self.assertFalse(isinstance(ar2, AsyncHubResult))
150 c.spin()
150 c.spin()
151 c.close()
151 c.close()
152
152
153 def test_run_newline(self):
153 def test_run_newline(self):
154 """test that run appends newline to files"""
154 """test that run appends newline to files"""
155 tmpfile = mktemp()
155 tmpfile = mktemp()
156 with open(tmpfile, 'w') as f:
156 with open(tmpfile, 'w') as f:
157 f.write("""def g():
157 f.write("""def g():
158 return 5
158 return 5
159 """)
159 """)
160 v = self.client[-1]
160 v = self.client[-1]
161 v.run(tmpfile, block=True)
161 v.run(tmpfile, block=True)
162 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
162 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
163
163
164 def test_apply_tracked(self):
164 def test_apply_tracked(self):
165 """test tracking for apply"""
165 """test tracking for apply"""
166 # self.add_engines(1)
166 # self.add_engines(1)
167 t = self.client.ids[-1]
167 t = self.client.ids[-1]
168 v = self.client[t]
168 v = self.client[t]
169 v.block=False
169 v.block=False
170 def echo(n=1024*1024, **kwargs):
170 def echo(n=1024*1024, **kwargs):
171 with v.temp_flags(**kwargs):
171 with v.temp_flags(**kwargs):
172 return v.apply(lambda x: x, 'x'*n)
172 return v.apply(lambda x: x, 'x'*n)
173 ar = echo(1, track=False)
173 ar = echo(1, track=False)
174 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
174 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
175 self.assertTrue(ar.sent)
175 self.assertTrue(ar.sent)
176 ar = echo(track=True)
176 ar = echo(track=True)
177 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
177 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
178 self.assertEquals(ar.sent, ar._tracker.done)
178 self.assertEquals(ar.sent, ar._tracker.done)
179 ar._tracker.wait()
179 ar._tracker.wait()
180 self.assertTrue(ar.sent)
180 self.assertTrue(ar.sent)
181
181
182 def test_push_tracked(self):
182 def test_push_tracked(self):
183 t = self.client.ids[-1]
183 t = self.client.ids[-1]
184 ns = dict(x='x'*1024*1024)
184 ns = dict(x='x'*1024*1024)
185 v = self.client[t]
185 v = self.client[t]
186 ar = v.push(ns, block=False, track=False)
186 ar = v.push(ns, block=False, track=False)
187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 self.assertTrue(ar.sent)
188 self.assertTrue(ar.sent)
189
189
190 ar = v.push(ns, block=False, track=True)
190 ar = v.push(ns, block=False, track=True)
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 ar._tracker.wait()
192 ar._tracker.wait()
193 self.assertEquals(ar.sent, ar._tracker.done)
193 self.assertEquals(ar.sent, ar._tracker.done)
194 self.assertTrue(ar.sent)
194 self.assertTrue(ar.sent)
195 ar.get()
195 ar.get()
196
196
197 def test_scatter_tracked(self):
197 def test_scatter_tracked(self):
198 t = self.client.ids
198 t = self.client.ids
199 x='x'*1024*1024
199 x='x'*1024*1024
200 ar = self.client[t].scatter('x', x, block=False, track=False)
200 ar = self.client[t].scatter('x', x, block=False, track=False)
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 self.assertTrue(ar.sent)
202 self.assertTrue(ar.sent)
203
203
204 ar = self.client[t].scatter('x', x, block=False, track=True)
204 ar = self.client[t].scatter('x', x, block=False, track=True)
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 self.assertEquals(ar.sent, ar._tracker.done)
206 self.assertEquals(ar.sent, ar._tracker.done)
207 ar._tracker.wait()
207 ar._tracker.wait()
208 self.assertTrue(ar.sent)
208 self.assertTrue(ar.sent)
209 ar.get()
209 ar.get()
210
210
211 def test_remote_reference(self):
211 def test_remote_reference(self):
212 v = self.client[-1]
212 v = self.client[-1]
213 v['a'] = 123
213 v['a'] = 123
214 ra = pmod.Reference('a')
214 ra = pmod.Reference('a')
215 b = v.apply_sync(lambda x: x, ra)
215 b = v.apply_sync(lambda x: x, ra)
216 self.assertEquals(b, 123)
216 self.assertEquals(b, 123)
217
217
218
218
219 def test_scatter_gather(self):
219 def test_scatter_gather(self):
220 view = self.client[:]
220 view = self.client[:]
221 seq1 = range(16)
221 seq1 = range(16)
222 view.scatter('a', seq1)
222 view.scatter('a', seq1)
223 seq2 = view.gather('a', block=True)
223 seq2 = view.gather('a', block=True)
224 self.assertEquals(seq2, seq1)
224 self.assertEquals(seq2, seq1)
225 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
225 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
226
226
227 @skip_without('numpy')
227 @skip_without('numpy')
228 def test_scatter_gather_numpy(self):
228 def test_scatter_gather_numpy(self):
229 import numpy
229 import numpy
230 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
230 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
231 view = self.client[:]
231 view = self.client[:]
232 a = numpy.arange(64)
232 a = numpy.arange(64)
233 view.scatter('a', a)
233 view.scatter('a', a)
234 b = view.gather('a', block=True)
234 b = view.gather('a', block=True)
235 assert_array_equal(b, a)
235 assert_array_equal(b, a)
236
236
237 @skip_without('numpy')
238 def test_apply_numpy(self):
239 """view.apply(f, ndarray)"""
240 import numpy
241 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
242
243 A = numpy.random.random((100,100))
244 view = self.client[-1]
245 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
246 B = A.astype(dt)
247 C = view.apply_sync(lambda x:x, B)
248 assert_array_equal(B,C)
249
237 def test_map(self):
250 def test_map(self):
238 view = self.client[:]
251 view = self.client[:]
239 def f(x):
252 def f(x):
240 return x**2
253 return x**2
241 data = range(16)
254 data = range(16)
242 r = view.map_sync(f, data)
255 r = view.map_sync(f, data)
243 self.assertEquals(r, map(f, data))
256 self.assertEquals(r, map(f, data))
244
257
245 def test_map_iterable(self):
258 def test_map_iterable(self):
246 """test map on iterables (direct)"""
259 """test map on iterables (direct)"""
247 view = self.client[:]
260 view = self.client[:]
248 # 101 is prime, so it won't be evenly distributed
261 # 101 is prime, so it won't be evenly distributed
249 arr = range(101)
262 arr = range(101)
250 # ensure it will be an iterator, even in Python 3
263 # ensure it will be an iterator, even in Python 3
251 it = iter(arr)
264 it = iter(arr)
252 r = view.map_sync(lambda x:x, arr)
265 r = view.map_sync(lambda x:x, arr)
253 self.assertEquals(r, list(arr))
266 self.assertEquals(r, list(arr))
254
267
255 def test_scatterGatherNonblocking(self):
268 def test_scatterGatherNonblocking(self):
256 data = range(16)
269 data = range(16)
257 view = self.client[:]
270 view = self.client[:]
258 view.scatter('a', data, block=False)
271 view.scatter('a', data, block=False)
259 ar = view.gather('a', block=False)
272 ar = view.gather('a', block=False)
260 self.assertEquals(ar.get(), data)
273 self.assertEquals(ar.get(), data)
261
274
262 @skip_without('numpy')
275 @skip_without('numpy')
263 def test_scatter_gather_numpy_nonblocking(self):
276 def test_scatter_gather_numpy_nonblocking(self):
264 import numpy
277 import numpy
265 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
278 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
266 a = numpy.arange(64)
279 a = numpy.arange(64)
267 view = self.client[:]
280 view = self.client[:]
268 ar = view.scatter('a', a, block=False)
281 ar = view.scatter('a', a, block=False)
269 self.assertTrue(isinstance(ar, AsyncResult))
282 self.assertTrue(isinstance(ar, AsyncResult))
270 amr = view.gather('a', block=False)
283 amr = view.gather('a', block=False)
271 self.assertTrue(isinstance(amr, AsyncMapResult))
284 self.assertTrue(isinstance(amr, AsyncMapResult))
272 assert_array_equal(amr.get(), a)
285 assert_array_equal(amr.get(), a)
273
286
274 def test_execute(self):
287 def test_execute(self):
275 view = self.client[:]
288 view = self.client[:]
276 # self.client.debug=True
289 # self.client.debug=True
277 execute = view.execute
290 execute = view.execute
278 ar = execute('c=30', block=False)
291 ar = execute('c=30', block=False)
279 self.assertTrue(isinstance(ar, AsyncResult))
292 self.assertTrue(isinstance(ar, AsyncResult))
280 ar = execute('d=[0,1,2]', block=False)
293 ar = execute('d=[0,1,2]', block=False)
281 self.client.wait(ar, 1)
294 self.client.wait(ar, 1)
282 self.assertEquals(len(ar.get()), len(self.client))
295 self.assertEquals(len(ar.get()), len(self.client))
283 for c in view['c']:
296 for c in view['c']:
284 self.assertEquals(c, 30)
297 self.assertEquals(c, 30)
285
298
286 def test_abort(self):
299 def test_abort(self):
287 view = self.client[-1]
300 view = self.client[-1]
288 ar = view.execute('import time; time.sleep(1)', block=False)
301 ar = view.execute('import time; time.sleep(1)', block=False)
289 ar2 = view.apply_async(lambda : 2)
302 ar2 = view.apply_async(lambda : 2)
290 ar3 = view.apply_async(lambda : 3)
303 ar3 = view.apply_async(lambda : 3)
291 view.abort(ar2)
304 view.abort(ar2)
292 view.abort(ar3.msg_ids)
305 view.abort(ar3.msg_ids)
293 self.assertRaises(error.TaskAborted, ar2.get)
306 self.assertRaises(error.TaskAborted, ar2.get)
294 self.assertRaises(error.TaskAborted, ar3.get)
307 self.assertRaises(error.TaskAborted, ar3.get)
295
308
296 def test_abort_all(self):
309 def test_abort_all(self):
297 """view.abort() aborts all outstanding tasks"""
310 """view.abort() aborts all outstanding tasks"""
298 view = self.client[-1]
311 view = self.client[-1]
299 ars = [ view.apply_async(time.sleep, 1) for i in range(10) ]
312 ars = [ view.apply_async(time.sleep, 1) for i in range(10) ]
300 view.abort()
313 view.abort()
301 view.wait(timeout=5)
314 view.wait(timeout=5)
302 for ar in ars[5:]:
315 for ar in ars[5:]:
303 self.assertRaises(error.TaskAborted, ar.get)
316 self.assertRaises(error.TaskAborted, ar.get)
304
317
305 def test_temp_flags(self):
318 def test_temp_flags(self):
306 view = self.client[-1]
319 view = self.client[-1]
307 view.block=True
320 view.block=True
308 with view.temp_flags(block=False):
321 with view.temp_flags(block=False):
309 self.assertFalse(view.block)
322 self.assertFalse(view.block)
310 self.assertTrue(view.block)
323 self.assertTrue(view.block)
311
324
312 @dec.known_failure_py3
325 @dec.known_failure_py3
313 def test_importer(self):
326 def test_importer(self):
314 view = self.client[-1]
327 view = self.client[-1]
315 view.clear(block=True)
328 view.clear(block=True)
316 with view.importer:
329 with view.importer:
317 import re
330 import re
318
331
319 @interactive
332 @interactive
320 def findall(pat, s):
333 def findall(pat, s):
321 # this globals() step isn't necessary in real code
334 # this globals() step isn't necessary in real code
322 # only to prevent a closure in the test
335 # only to prevent a closure in the test
323 re = globals()['re']
336 re = globals()['re']
324 return re.findall(pat, s)
337 return re.findall(pat, s)
325
338
326 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
339 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
327
340
328 # parallel magic tests
341 # parallel magic tests
329
342
330 def test_magic_px_blocking(self):
343 def test_magic_px_blocking(self):
331 ip = get_ipython()
344 ip = get_ipython()
332 v = self.client[-1]
345 v = self.client[-1]
333 v.activate()
346 v.activate()
334 v.block=True
347 v.block=True
335
348
336 ip.magic_px('a=5')
349 ip.magic_px('a=5')
337 self.assertEquals(v['a'], 5)
350 self.assertEquals(v['a'], 5)
338 ip.magic_px('a=10')
351 ip.magic_px('a=10')
339 self.assertEquals(v['a'], 10)
352 self.assertEquals(v['a'], 10)
340 sio = StringIO()
353 sio = StringIO()
341 savestdout = sys.stdout
354 savestdout = sys.stdout
342 sys.stdout = sio
355 sys.stdout = sio
343 # just 'print a' worst ~99% of the time, but this ensures that
356 # just 'print a' worst ~99% of the time, but this ensures that
344 # the stdout message has arrived when the result is finished:
357 # the stdout message has arrived when the result is finished:
345 ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
358 ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
346 sys.stdout = savestdout
359 sys.stdout = savestdout
347 buf = sio.getvalue()
360 buf = sio.getvalue()
348 self.assertTrue('[stdout:' in buf, buf)
361 self.assertTrue('[stdout:' in buf, buf)
349 self.assertTrue(buf.rstrip().endswith('10'))
362 self.assertTrue(buf.rstrip().endswith('10'))
350 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
363 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
351
364
352 def test_magic_px_nonblocking(self):
365 def test_magic_px_nonblocking(self):
353 ip = get_ipython()
366 ip = get_ipython()
354 v = self.client[-1]
367 v = self.client[-1]
355 v.activate()
368 v.activate()
356 v.block=False
369 v.block=False
357
370
358 ip.magic_px('a=5')
371 ip.magic_px('a=5')
359 self.assertEquals(v['a'], 5)
372 self.assertEquals(v['a'], 5)
360 ip.magic_px('a=10')
373 ip.magic_px('a=10')
361 self.assertEquals(v['a'], 10)
374 self.assertEquals(v['a'], 10)
362 sio = StringIO()
375 sio = StringIO()
363 savestdout = sys.stdout
376 savestdout = sys.stdout
364 sys.stdout = sio
377 sys.stdout = sio
365 ip.magic_px('print a')
378 ip.magic_px('print a')
366 sys.stdout = savestdout
379 sys.stdout = savestdout
367 buf = sio.getvalue()
380 buf = sio.getvalue()
368 self.assertFalse('[stdout:%i]'%v.targets in buf)
381 self.assertFalse('[stdout:%i]'%v.targets in buf)
369 ip.magic_px('1/0')
382 ip.magic_px('1/0')
370 ar = v.get_result(-1)
383 ar = v.get_result(-1)
371 self.assertRaisesRemote(ZeroDivisionError, ar.get)
384 self.assertRaisesRemote(ZeroDivisionError, ar.get)
372
385
373 def test_magic_autopx_blocking(self):
386 def test_magic_autopx_blocking(self):
374 ip = get_ipython()
387 ip = get_ipython()
375 v = self.client[-1]
388 v = self.client[-1]
376 v.activate()
389 v.activate()
377 v.block=True
390 v.block=True
378
391
379 sio = StringIO()
392 sio = StringIO()
380 savestdout = sys.stdout
393 savestdout = sys.stdout
381 sys.stdout = sio
394 sys.stdout = sio
382 ip.magic_autopx()
395 ip.magic_autopx()
383 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
396 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
384 ip.run_cell('print b')
397 ip.run_cell('print b')
385 ip.run_cell("b/c")
398 ip.run_cell("b/c")
386 ip.run_code(compile('b*=2', '', 'single'))
399 ip.run_code(compile('b*=2', '', 'single'))
387 ip.magic_autopx()
400 ip.magic_autopx()
388 sys.stdout = savestdout
401 sys.stdout = savestdout
389 output = sio.getvalue().strip()
402 output = sio.getvalue().strip()
390 self.assertTrue(output.startswith('%autopx enabled'))
403 self.assertTrue(output.startswith('%autopx enabled'))
391 self.assertTrue(output.endswith('%autopx disabled'))
404 self.assertTrue(output.endswith('%autopx disabled'))
392 self.assertTrue('RemoteError: ZeroDivisionError' in output)
405 self.assertTrue('RemoteError: ZeroDivisionError' in output)
393 ar = v.get_result(-2)
406 ar = v.get_result(-2)
394 self.assertEquals(v['a'], 5)
407 self.assertEquals(v['a'], 5)
395 self.assertEquals(v['b'], 20)
408 self.assertEquals(v['b'], 20)
396 self.assertRaisesRemote(ZeroDivisionError, ar.get)
409 self.assertRaisesRemote(ZeroDivisionError, ar.get)
397
410
398 def test_magic_autopx_nonblocking(self):
411 def test_magic_autopx_nonblocking(self):
399 ip = get_ipython()
412 ip = get_ipython()
400 v = self.client[-1]
413 v = self.client[-1]
401 v.activate()
414 v.activate()
402 v.block=False
415 v.block=False
403
416
404 sio = StringIO()
417 sio = StringIO()
405 savestdout = sys.stdout
418 savestdout = sys.stdout
406 sys.stdout = sio
419 sys.stdout = sio
407 ip.magic_autopx()
420 ip.magic_autopx()
408 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
421 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
409 ip.run_cell('print b')
422 ip.run_cell('print b')
410 ip.run_cell("b/c")
423 ip.run_cell("b/c")
411 ip.run_code(compile('b*=2', '', 'single'))
424 ip.run_code(compile('b*=2', '', 'single'))
412 ip.magic_autopx()
425 ip.magic_autopx()
413 sys.stdout = savestdout
426 sys.stdout = savestdout
414 output = sio.getvalue().strip()
427 output = sio.getvalue().strip()
415 self.assertTrue(output.startswith('%autopx enabled'))
428 self.assertTrue(output.startswith('%autopx enabled'))
416 self.assertTrue(output.endswith('%autopx disabled'))
429 self.assertTrue(output.endswith('%autopx disabled'))
417 self.assertFalse('ZeroDivisionError' in output)
430 self.assertFalse('ZeroDivisionError' in output)
418 ar = v.get_result(-2)
431 ar = v.get_result(-2)
419 self.assertEquals(v['a'], 5)
432 self.assertEquals(v['a'], 5)
420 self.assertEquals(v['b'], 20)
433 self.assertEquals(v['b'], 20)
421 self.assertRaisesRemote(ZeroDivisionError, ar.get)
434 self.assertRaisesRemote(ZeroDivisionError, ar.get)
422
435
423 def test_magic_result(self):
436 def test_magic_result(self):
424 ip = get_ipython()
437 ip = get_ipython()
425 v = self.client[-1]
438 v = self.client[-1]
426 v.activate()
439 v.activate()
427 v['a'] = 111
440 v['a'] = 111
428 ra = v['a']
441 ra = v['a']
429
442
430 ar = ip.magic_result()
443 ar = ip.magic_result()
431 self.assertEquals(ar.msg_ids, [v.history[-1]])
444 self.assertEquals(ar.msg_ids, [v.history[-1]])
432 self.assertEquals(ar.get(), 111)
445 self.assertEquals(ar.get(), 111)
433 ar = ip.magic_result('-2')
446 ar = ip.magic_result('-2')
434 self.assertEquals(ar.msg_ids, [v.history[-2]])
447 self.assertEquals(ar.msg_ids, [v.history[-2]])
435
448
436 def test_unicode_execute(self):
449 def test_unicode_execute(self):
437 """test executing unicode strings"""
450 """test executing unicode strings"""
438 v = self.client[-1]
451 v = self.client[-1]
439 v.block=True
452 v.block=True
440 if sys.version_info[0] >= 3:
453 if sys.version_info[0] >= 3:
441 code="a='é'"
454 code="a='é'"
442 else:
455 else:
443 code=u"a=u'é'"
456 code=u"a=u'é'"
444 v.execute(code)
457 v.execute(code)
445 self.assertEquals(v['a'], u'é')
458 self.assertEquals(v['a'], u'é')
446
459
447 def test_unicode_apply_result(self):
460 def test_unicode_apply_result(self):
448 """test unicode apply results"""
461 """test unicode apply results"""
449 v = self.client[-1]
462 v = self.client[-1]
450 r = v.apply_sync(lambda : u'é')
463 r = v.apply_sync(lambda : u'é')
451 self.assertEquals(r, u'é')
464 self.assertEquals(r, u'é')
452
465
453 def test_unicode_apply_arg(self):
466 def test_unicode_apply_arg(self):
454 """test passing unicode arguments to apply"""
467 """test passing unicode arguments to apply"""
455 v = self.client[-1]
468 v = self.client[-1]
456
469
457 @interactive
470 @interactive
458 def check_unicode(a, check):
471 def check_unicode(a, check):
459 assert isinstance(a, unicode), "%r is not unicode"%a
472 assert isinstance(a, unicode), "%r is not unicode"%a
460 assert isinstance(check, bytes), "%r is not bytes"%check
473 assert isinstance(check, bytes), "%r is not bytes"%check
461 assert a.encode('utf8') == check, "%s != %s"%(a,check)
474 assert a.encode('utf8') == check, "%s != %s"%(a,check)
462
475
463 for s in [ u'é', u'ßø®∫',u'asdf' ]:
476 for s in [ u'é', u'ßø®∫',u'asdf' ]:
464 try:
477 try:
465 v.apply_sync(check_unicode, s, s.encode('utf8'))
478 v.apply_sync(check_unicode, s, s.encode('utf8'))
466 except error.RemoteError as e:
479 except error.RemoteError as e:
467 if e.ename == 'AssertionError':
480 if e.ename == 'AssertionError':
468 self.fail(e.evalue)
481 self.fail(e.evalue)
469 else:
482 else:
470 raise e
483 raise e
471
484
472 def test_map_reference(self):
485 def test_map_reference(self):
473 """view.map(<Reference>, *seqs) should work"""
486 """view.map(<Reference>, *seqs) should work"""
474 v = self.client[:]
487 v = self.client[:]
475 v.scatter('n', self.client.ids, flatten=True)
488 v.scatter('n', self.client.ids, flatten=True)
476 v.execute("f = lambda x,y: x*y")
489 v.execute("f = lambda x,y: x*y")
477 rf = pmod.Reference('f')
490 rf = pmod.Reference('f')
478 nlist = list(range(10))
491 nlist = list(range(10))
479 mlist = nlist[::-1]
492 mlist = nlist[::-1]
480 expected = [ m*n for m,n in zip(mlist, nlist) ]
493 expected = [ m*n for m,n in zip(mlist, nlist) ]
481 result = v.map_sync(rf, mlist, nlist)
494 result = v.map_sync(rf, mlist, nlist)
482 self.assertEquals(result, expected)
495 self.assertEquals(result, expected)
483
496
484 def test_apply_reference(self):
497 def test_apply_reference(self):
485 """view.apply(<Reference>, *args) should work"""
498 """view.apply(<Reference>, *args) should work"""
486 v = self.client[:]
499 v = self.client[:]
487 v.scatter('n', self.client.ids, flatten=True)
500 v.scatter('n', self.client.ids, flatten=True)
488 v.execute("f = lambda x: n*x")
501 v.execute("f = lambda x: n*x")
489 rf = pmod.Reference('f')
502 rf = pmod.Reference('f')
490 result = v.apply_sync(rf, 5)
503 result = v.apply_sync(rf, 5)
491 expected = [ 5*id for id in self.client.ids ]
504 expected = [ 5*id for id in self.client.ids ]
492 self.assertEquals(result, expected)
505 self.assertEquals(result, expected)
493
506
@@ -1,476 +1,480
1 """some generic utilities for dealing with classes, urls, and serialization
1 """some generic utilities for dealing with classes, urls, and serialization
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 # Standard library imports.
18 # Standard library imports.
19 import logging
19 import logging
20 import os
20 import os
21 import re
21 import re
22 import stat
22 import stat
23 import socket
23 import socket
24 import sys
24 import sys
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 try:
26 try:
27 from signal import SIGKILL
27 from signal import SIGKILL
28 except ImportError:
28 except ImportError:
29 SIGKILL=None
29 SIGKILL=None
30
30
31 try:
31 try:
32 import cPickle
32 import cPickle
33 pickle = cPickle
33 pickle = cPickle
34 except:
34 except:
35 cPickle = None
35 cPickle = None
36 import pickle
36 import pickle
37
37
38 # System library imports
38 # System library imports
39 import zmq
39 import zmq
40 from zmq.log import handlers
40 from zmq.log import handlers
41
41
42 # IPython imports
42 # IPython imports
43 from IPython.config.application import Application
43 from IPython.config.application import Application
44 from IPython.utils import py3compat
44 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
45 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
45 from IPython.utils.newserialized import serialize, unserialize
46 from IPython.utils.newserialized import serialize, unserialize
46 from IPython.zmq.log import EnginePUBHandler
47 from IPython.zmq.log import EnginePUBHandler
47
48
49 if py3compat.PY3:
50 buffer = memoryview
51
48 #-----------------------------------------------------------------------------
52 #-----------------------------------------------------------------------------
49 # Classes
53 # Classes
50 #-----------------------------------------------------------------------------
54 #-----------------------------------------------------------------------------
51
55
52 class Namespace(dict):
56 class Namespace(dict):
53 """Subclass of dict for attribute access to keys."""
57 """Subclass of dict for attribute access to keys."""
54
58
55 def __getattr__(self, key):
59 def __getattr__(self, key):
56 """getattr aliased to getitem"""
60 """getattr aliased to getitem"""
57 if key in self.iterkeys():
61 if key in self.iterkeys():
58 return self[key]
62 return self[key]
59 else:
63 else:
60 raise NameError(key)
64 raise NameError(key)
61
65
62 def __setattr__(self, key, value):
66 def __setattr__(self, key, value):
63 """setattr aliased to setitem, with strict"""
67 """setattr aliased to setitem, with strict"""
64 if hasattr(dict, key):
68 if hasattr(dict, key):
65 raise KeyError("Cannot override dict keys %r"%key)
69 raise KeyError("Cannot override dict keys %r"%key)
66 self[key] = value
70 self[key] = value
67
71
68
72
69 class ReverseDict(dict):
73 class ReverseDict(dict):
70 """simple double-keyed subset of dict methods."""
74 """simple double-keyed subset of dict methods."""
71
75
72 def __init__(self, *args, **kwargs):
76 def __init__(self, *args, **kwargs):
73 dict.__init__(self, *args, **kwargs)
77 dict.__init__(self, *args, **kwargs)
74 self._reverse = dict()
78 self._reverse = dict()
75 for key, value in self.iteritems():
79 for key, value in self.iteritems():
76 self._reverse[value] = key
80 self._reverse[value] = key
77
81
78 def __getitem__(self, key):
82 def __getitem__(self, key):
79 try:
83 try:
80 return dict.__getitem__(self, key)
84 return dict.__getitem__(self, key)
81 except KeyError:
85 except KeyError:
82 return self._reverse[key]
86 return self._reverse[key]
83
87
84 def __setitem__(self, key, value):
88 def __setitem__(self, key, value):
85 if key in self._reverse:
89 if key in self._reverse:
86 raise KeyError("Can't have key %r on both sides!"%key)
90 raise KeyError("Can't have key %r on both sides!"%key)
87 dict.__setitem__(self, key, value)
91 dict.__setitem__(self, key, value)
88 self._reverse[value] = key
92 self._reverse[value] = key
89
93
90 def pop(self, key):
94 def pop(self, key):
91 value = dict.pop(self, key)
95 value = dict.pop(self, key)
92 self._reverse.pop(value)
96 self._reverse.pop(value)
93 return value
97 return value
94
98
95 def get(self, key, default=None):
99 def get(self, key, default=None):
96 try:
100 try:
97 return self[key]
101 return self[key]
98 except KeyError:
102 except KeyError:
99 return default
103 return default
100
104
101 #-----------------------------------------------------------------------------
105 #-----------------------------------------------------------------------------
102 # Functions
106 # Functions
103 #-----------------------------------------------------------------------------
107 #-----------------------------------------------------------------------------
104
108
105 def asbytes(s):
109 def asbytes(s):
106 """ensure that an object is ascii bytes"""
110 """ensure that an object is ascii bytes"""
107 if isinstance(s, unicode):
111 if isinstance(s, unicode):
108 s = s.encode('ascii')
112 s = s.encode('ascii')
109 return s
113 return s
110
114
111 def is_url(url):
115 def is_url(url):
112 """boolean check for whether a string is a zmq url"""
116 """boolean check for whether a string is a zmq url"""
113 if '://' not in url:
117 if '://' not in url:
114 return False
118 return False
115 proto, addr = url.split('://', 1)
119 proto, addr = url.split('://', 1)
116 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
120 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
117 return False
121 return False
118 return True
122 return True
119
123
120 def validate_url(url):
124 def validate_url(url):
121 """validate a url for zeromq"""
125 """validate a url for zeromq"""
122 if not isinstance(url, basestring):
126 if not isinstance(url, basestring):
123 raise TypeError("url must be a string, not %r"%type(url))
127 raise TypeError("url must be a string, not %r"%type(url))
124 url = url.lower()
128 url = url.lower()
125
129
126 proto_addr = url.split('://')
130 proto_addr = url.split('://')
127 assert len(proto_addr) == 2, 'Invalid url: %r'%url
131 assert len(proto_addr) == 2, 'Invalid url: %r'%url
128 proto, addr = proto_addr
132 proto, addr = proto_addr
129 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
133 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
130
134
131 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
135 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
132 # author: Remi Sabourin
136 # author: Remi Sabourin
133 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
137 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
134
138
135 if proto == 'tcp':
139 if proto == 'tcp':
136 lis = addr.split(':')
140 lis = addr.split(':')
137 assert len(lis) == 2, 'Invalid url: %r'%url
141 assert len(lis) == 2, 'Invalid url: %r'%url
138 addr,s_port = lis
142 addr,s_port = lis
139 try:
143 try:
140 port = int(s_port)
144 port = int(s_port)
141 except ValueError:
145 except ValueError:
142 raise AssertionError("Invalid port %r in url: %r"%(port, url))
146 raise AssertionError("Invalid port %r in url: %r"%(port, url))
143
147
144 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
148 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
145
149
146 else:
150 else:
147 # only validate tcp urls currently
151 # only validate tcp urls currently
148 pass
152 pass
149
153
150 return True
154 return True
151
155
152
156
153 def validate_url_container(container):
157 def validate_url_container(container):
154 """validate a potentially nested collection of urls."""
158 """validate a potentially nested collection of urls."""
155 if isinstance(container, basestring):
159 if isinstance(container, basestring):
156 url = container
160 url = container
157 return validate_url(url)
161 return validate_url(url)
158 elif isinstance(container, dict):
162 elif isinstance(container, dict):
159 container = container.itervalues()
163 container = container.itervalues()
160
164
161 for element in container:
165 for element in container:
162 validate_url_container(element)
166 validate_url_container(element)
163
167
164
168
165 def split_url(url):
169 def split_url(url):
166 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
170 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
167 proto_addr = url.split('://')
171 proto_addr = url.split('://')
168 assert len(proto_addr) == 2, 'Invalid url: %r'%url
172 assert len(proto_addr) == 2, 'Invalid url: %r'%url
169 proto, addr = proto_addr
173 proto, addr = proto_addr
170 lis = addr.split(':')
174 lis = addr.split(':')
171 assert len(lis) == 2, 'Invalid url: %r'%url
175 assert len(lis) == 2, 'Invalid url: %r'%url
172 addr,s_port = lis
176 addr,s_port = lis
173 return proto,addr,s_port
177 return proto,addr,s_port
174
178
175 def disambiguate_ip_address(ip, location=None):
179 def disambiguate_ip_address(ip, location=None):
176 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
180 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
177 ones, based on the location (default interpretation of location is localhost)."""
181 ones, based on the location (default interpretation of location is localhost)."""
178 if ip in ('0.0.0.0', '*'):
182 if ip in ('0.0.0.0', '*'):
179 try:
183 try:
180 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
184 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
181 except (socket.gaierror, IndexError):
185 except (socket.gaierror, IndexError):
182 # couldn't identify this machine, assume localhost
186 # couldn't identify this machine, assume localhost
183 external_ips = []
187 external_ips = []
184 if location is None or location in external_ips or not external_ips:
188 if location is None or location in external_ips or not external_ips:
185 # If location is unspecified or cannot be determined, assume local
189 # If location is unspecified or cannot be determined, assume local
186 ip='127.0.0.1'
190 ip='127.0.0.1'
187 elif location:
191 elif location:
188 return location
192 return location
189 return ip
193 return ip
190
194
191 def disambiguate_url(url, location=None):
195 def disambiguate_url(url, location=None):
192 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
196 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
193 ones, based on the location (default interpretation is localhost).
197 ones, based on the location (default interpretation is localhost).
194
198
195 This is for zeromq urls, such as tcp://*:10101."""
199 This is for zeromq urls, such as tcp://*:10101."""
196 try:
200 try:
197 proto,ip,port = split_url(url)
201 proto,ip,port = split_url(url)
198 except AssertionError:
202 except AssertionError:
199 # probably not tcp url; could be ipc, etc.
203 # probably not tcp url; could be ipc, etc.
200 return url
204 return url
201
205
202 ip = disambiguate_ip_address(ip,location)
206 ip = disambiguate_ip_address(ip,location)
203
207
204 return "%s://%s:%s"%(proto,ip,port)
208 return "%s://%s:%s"%(proto,ip,port)
205
209
206 def serialize_object(obj, threshold=64e-6):
210 def serialize_object(obj, threshold=64e-6):
207 """Serialize an object into a list of sendable buffers.
211 """Serialize an object into a list of sendable buffers.
208
212
209 Parameters
213 Parameters
210 ----------
214 ----------
211
215
212 obj : object
216 obj : object
213 The object to be serialized
217 The object to be serialized
214 threshold : float
218 threshold : float
215 The threshold for not double-pickling the content.
219 The threshold for not double-pickling the content.
216
220
217
221
218 Returns
222 Returns
219 -------
223 -------
220 ('pmd', [bufs]) :
224 ('pmd', [bufs]) :
221 where pmd is the pickled metadata wrapper,
225 where pmd is the pickled metadata wrapper,
222 bufs is a list of data buffers
226 bufs is a list of data buffers
223 """
227 """
224 databuffers = []
228 databuffers = []
225 if isinstance(obj, (list, tuple)):
229 if isinstance(obj, (list, tuple)):
226 clist = canSequence(obj)
230 clist = canSequence(obj)
227 slist = map(serialize, clist)
231 slist = map(serialize, clist)
228 for s in slist:
232 for s in slist:
229 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
233 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
230 databuffers.append(s.getData())
234 databuffers.append(s.getData())
231 s.data = None
235 s.data = None
232 return pickle.dumps(slist,-1), databuffers
236 return pickle.dumps(slist,-1), databuffers
233 elif isinstance(obj, dict):
237 elif isinstance(obj, dict):
234 sobj = {}
238 sobj = {}
235 for k in sorted(obj.iterkeys()):
239 for k in sorted(obj.iterkeys()):
236 s = serialize(can(obj[k]))
240 s = serialize(can(obj[k]))
237 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
241 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
238 databuffers.append(s.getData())
242 databuffers.append(s.getData())
239 s.data = None
243 s.data = None
240 sobj[k] = s
244 sobj[k] = s
241 return pickle.dumps(sobj,-1),databuffers
245 return pickle.dumps(sobj,-1),databuffers
242 else:
246 else:
243 s = serialize(can(obj))
247 s = serialize(can(obj))
244 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
248 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
245 databuffers.append(s.getData())
249 databuffers.append(s.getData())
246 s.data = None
250 s.data = None
247 return pickle.dumps(s,-1),databuffers
251 return pickle.dumps(s,-1),databuffers
248
252
249
253
250 def unserialize_object(bufs):
254 def unserialize_object(bufs):
251 """reconstruct an object serialized by serialize_object from data buffers."""
255 """reconstruct an object serialized by serialize_object from data buffers."""
252 bufs = list(bufs)
256 bufs = list(bufs)
253 sobj = pickle.loads(bufs.pop(0))
257 sobj = pickle.loads(bufs.pop(0))
254 if isinstance(sobj, (list, tuple)):
258 if isinstance(sobj, (list, tuple)):
255 for s in sobj:
259 for s in sobj:
256 if s.data is None:
260 if s.data is None:
257 s.data = bufs.pop(0)
261 s.data = bufs.pop(0)
258 return uncanSequence(map(unserialize, sobj)), bufs
262 return uncanSequence(map(unserialize, sobj)), bufs
259 elif isinstance(sobj, dict):
263 elif isinstance(sobj, dict):
260 newobj = {}
264 newobj = {}
261 for k in sorted(sobj.iterkeys()):
265 for k in sorted(sobj.iterkeys()):
262 s = sobj[k]
266 s = sobj[k]
263 if s.data is None:
267 if s.data is None:
264 s.data = bufs.pop(0)
268 s.data = bufs.pop(0)
265 newobj[k] = uncan(unserialize(s))
269 newobj[k] = uncan(unserialize(s))
266 return newobj, bufs
270 return newobj, bufs
267 else:
271 else:
268 if sobj.data is None:
272 if sobj.data is None:
269 sobj.data = bufs.pop(0)
273 sobj.data = bufs.pop(0)
270 return uncan(unserialize(sobj)), bufs
274 return uncan(unserialize(sobj)), bufs
271
275
272 def pack_apply_message(f, args, kwargs, threshold=64e-6):
276 def pack_apply_message(f, args, kwargs, threshold=64e-6):
273 """pack up a function, args, and kwargs to be sent over the wire
277 """pack up a function, args, and kwargs to be sent over the wire
274 as a series of buffers. Any object whose data is larger than `threshold`
278 as a series of buffers. Any object whose data is larger than `threshold`
275 will not have their data copied (currently only numpy arrays support zero-copy)"""
279 will not have their data copied (currently only numpy arrays support zero-copy)"""
276 msg = [pickle.dumps(can(f),-1)]
280 msg = [pickle.dumps(can(f),-1)]
277 databuffers = [] # for large objects
281 databuffers = [] # for large objects
278 sargs, bufs = serialize_object(args,threshold)
282 sargs, bufs = serialize_object(args,threshold)
279 msg.append(sargs)
283 msg.append(sargs)
280 databuffers.extend(bufs)
284 databuffers.extend(bufs)
281 skwargs, bufs = serialize_object(kwargs,threshold)
285 skwargs, bufs = serialize_object(kwargs,threshold)
282 msg.append(skwargs)
286 msg.append(skwargs)
283 databuffers.extend(bufs)
287 databuffers.extend(bufs)
284 msg.extend(databuffers)
288 msg.extend(databuffers)
285 return msg
289 return msg
286
290
287 def unpack_apply_message(bufs, g=None, copy=True):
291 def unpack_apply_message(bufs, g=None, copy=True):
288 """unpack f,args,kwargs from buffers packed by pack_apply_message()
292 """unpack f,args,kwargs from buffers packed by pack_apply_message()
289 Returns: original f,args,kwargs"""
293 Returns: original f,args,kwargs"""
290 bufs = list(bufs) # allow us to pop
294 bufs = list(bufs) # allow us to pop
291 assert len(bufs) >= 3, "not enough buffers!"
295 assert len(bufs) >= 3, "not enough buffers!"
292 if not copy:
296 if not copy:
293 for i in range(3):
297 for i in range(3):
294 bufs[i] = bufs[i].bytes
298 bufs[i] = bufs[i].bytes
295 cf = pickle.loads(bufs.pop(0))
299 cf = pickle.loads(bufs.pop(0))
296 sargs = list(pickle.loads(bufs.pop(0)))
300 sargs = list(pickle.loads(bufs.pop(0)))
297 skwargs = dict(pickle.loads(bufs.pop(0)))
301 skwargs = dict(pickle.loads(bufs.pop(0)))
298 # print sargs, skwargs
302 # print sargs, skwargs
299 f = uncan(cf, g)
303 f = uncan(cf, g)
300 for sa in sargs:
304 for sa in sargs:
301 if sa.data is None:
305 if sa.data is None:
302 m = bufs.pop(0)
306 m = bufs.pop(0)
303 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
307 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
304 # always use a buffer, until memoryviews get sorted out
308 # always use a buffer, until memoryviews get sorted out
305 sa.data = buffer(m)
309 sa.data = buffer(m)
306 # disable memoryview support
310 # disable memoryview support
307 # if copy:
311 # if copy:
308 # sa.data = buffer(m)
312 # sa.data = buffer(m)
309 # else:
313 # else:
310 # sa.data = m.buffer
314 # sa.data = m.buffer
311 else:
315 else:
312 if copy:
316 if copy:
313 sa.data = m
317 sa.data = m
314 else:
318 else:
315 sa.data = m.bytes
319 sa.data = m.bytes
316
320
317 args = uncanSequence(map(unserialize, sargs), g)
321 args = uncanSequence(map(unserialize, sargs), g)
318 kwargs = {}
322 kwargs = {}
319 for k in sorted(skwargs.iterkeys()):
323 for k in sorted(skwargs.iterkeys()):
320 sa = skwargs[k]
324 sa = skwargs[k]
321 if sa.data is None:
325 if sa.data is None:
322 m = bufs.pop(0)
326 m = bufs.pop(0)
323 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
327 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
324 # always use a buffer, until memoryviews get sorted out
328 # always use a buffer, until memoryviews get sorted out
325 sa.data = buffer(m)
329 sa.data = buffer(m)
326 # disable memoryview support
330 # disable memoryview support
327 # if copy:
331 # if copy:
328 # sa.data = buffer(m)
332 # sa.data = buffer(m)
329 # else:
333 # else:
330 # sa.data = m.buffer
334 # sa.data = m.buffer
331 else:
335 else:
332 if copy:
336 if copy:
333 sa.data = m
337 sa.data = m
334 else:
338 else:
335 sa.data = m.bytes
339 sa.data = m.bytes
336
340
337 kwargs[k] = uncan(unserialize(sa), g)
341 kwargs[k] = uncan(unserialize(sa), g)
338
342
339 return f,args,kwargs
343 return f,args,kwargs
340
344
341 #--------------------------------------------------------------------------
345 #--------------------------------------------------------------------------
342 # helpers for implementing old MEC API via view.apply
346 # helpers for implementing old MEC API via view.apply
343 #--------------------------------------------------------------------------
347 #--------------------------------------------------------------------------
344
348
345 def interactive(f):
349 def interactive(f):
346 """decorator for making functions appear as interactively defined.
350 """decorator for making functions appear as interactively defined.
347 This results in the function being linked to the user_ns as globals()
351 This results in the function being linked to the user_ns as globals()
348 instead of the module globals().
352 instead of the module globals().
349 """
353 """
350 f.__module__ = '__main__'
354 f.__module__ = '__main__'
351 return f
355 return f
352
356
353 @interactive
357 @interactive
354 def _push(ns):
358 def _push(ns):
355 """helper method for implementing `client.push` via `client.apply`"""
359 """helper method for implementing `client.push` via `client.apply`"""
356 globals().update(ns)
360 globals().update(ns)
357
361
358 @interactive
362 @interactive
359 def _pull(keys):
363 def _pull(keys):
360 """helper method for implementing `client.pull` via `client.apply`"""
364 """helper method for implementing `client.pull` via `client.apply`"""
361 user_ns = globals()
365 user_ns = globals()
362 if isinstance(keys, (list,tuple, set)):
366 if isinstance(keys, (list,tuple, set)):
363 for key in keys:
367 for key in keys:
364 if not user_ns.has_key(key):
368 if not user_ns.has_key(key):
365 raise NameError("name '%s' is not defined"%key)
369 raise NameError("name '%s' is not defined"%key)
366 return map(user_ns.get, keys)
370 return map(user_ns.get, keys)
367 else:
371 else:
368 if not user_ns.has_key(keys):
372 if not user_ns.has_key(keys):
369 raise NameError("name '%s' is not defined"%keys)
373 raise NameError("name '%s' is not defined"%keys)
370 return user_ns.get(keys)
374 return user_ns.get(keys)
371
375
372 @interactive
376 @interactive
373 def _execute(code):
377 def _execute(code):
374 """helper method for implementing `client.execute` via `client.apply`"""
378 """helper method for implementing `client.execute` via `client.apply`"""
375 exec code in globals()
379 exec code in globals()
376
380
377 #--------------------------------------------------------------------------
381 #--------------------------------------------------------------------------
378 # extra process management utilities
382 # extra process management utilities
379 #--------------------------------------------------------------------------
383 #--------------------------------------------------------------------------
380
384
381 _random_ports = set()
385 _random_ports = set()
382
386
383 def select_random_ports(n):
387 def select_random_ports(n):
384 """Selects and return n random ports that are available."""
388 """Selects and return n random ports that are available."""
385 ports = []
389 ports = []
386 for i in xrange(n):
390 for i in xrange(n):
387 sock = socket.socket()
391 sock = socket.socket()
388 sock.bind(('', 0))
392 sock.bind(('', 0))
389 while sock.getsockname()[1] in _random_ports:
393 while sock.getsockname()[1] in _random_ports:
390 sock.close()
394 sock.close()
391 sock = socket.socket()
395 sock = socket.socket()
392 sock.bind(('', 0))
396 sock.bind(('', 0))
393 ports.append(sock)
397 ports.append(sock)
394 for i, sock in enumerate(ports):
398 for i, sock in enumerate(ports):
395 port = sock.getsockname()[1]
399 port = sock.getsockname()[1]
396 sock.close()
400 sock.close()
397 ports[i] = port
401 ports[i] = port
398 _random_ports.add(port)
402 _random_ports.add(port)
399 return ports
403 return ports
400
404
401 def signal_children(children):
405 def signal_children(children):
402 """Relay interupt/term signals to children, for more solid process cleanup."""
406 """Relay interupt/term signals to children, for more solid process cleanup."""
403 def terminate_children(sig, frame):
407 def terminate_children(sig, frame):
404 log = Application.instance().log
408 log = Application.instance().log
405 log.critical("Got signal %i, terminating children..."%sig)
409 log.critical("Got signal %i, terminating children..."%sig)
406 for child in children:
410 for child in children:
407 child.terminate()
411 child.terminate()
408
412
409 sys.exit(sig != SIGINT)
413 sys.exit(sig != SIGINT)
410 # sys.exit(sig)
414 # sys.exit(sig)
411 for sig in (SIGINT, SIGABRT, SIGTERM):
415 for sig in (SIGINT, SIGABRT, SIGTERM):
412 signal(sig, terminate_children)
416 signal(sig, terminate_children)
413
417
414 def generate_exec_key(keyfile):
418 def generate_exec_key(keyfile):
415 import uuid
419 import uuid
416 newkey = str(uuid.uuid4())
420 newkey = str(uuid.uuid4())
417 with open(keyfile, 'w') as f:
421 with open(keyfile, 'w') as f:
418 # f.write('ipython-key ')
422 # f.write('ipython-key ')
419 f.write(newkey+'\n')
423 f.write(newkey+'\n')
420 # set user-only RW permissions (0600)
424 # set user-only RW permissions (0600)
421 # this will have no effect on Windows
425 # this will have no effect on Windows
422 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
426 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
423
427
424
428
425 def integer_loglevel(loglevel):
429 def integer_loglevel(loglevel):
426 try:
430 try:
427 loglevel = int(loglevel)
431 loglevel = int(loglevel)
428 except ValueError:
432 except ValueError:
429 if isinstance(loglevel, str):
433 if isinstance(loglevel, str):
430 loglevel = getattr(logging, loglevel)
434 loglevel = getattr(logging, loglevel)
431 return loglevel
435 return loglevel
432
436
433 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
437 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
434 logger = logging.getLogger(logname)
438 logger = logging.getLogger(logname)
435 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
439 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
436 # don't add a second PUBHandler
440 # don't add a second PUBHandler
437 return
441 return
438 loglevel = integer_loglevel(loglevel)
442 loglevel = integer_loglevel(loglevel)
439 lsock = context.socket(zmq.PUB)
443 lsock = context.socket(zmq.PUB)
440 lsock.connect(iface)
444 lsock.connect(iface)
441 handler = handlers.PUBHandler(lsock)
445 handler = handlers.PUBHandler(lsock)
442 handler.setLevel(loglevel)
446 handler.setLevel(loglevel)
443 handler.root_topic = root
447 handler.root_topic = root
444 logger.addHandler(handler)
448 logger.addHandler(handler)
445 logger.setLevel(loglevel)
449 logger.setLevel(loglevel)
446
450
447 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
451 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
448 logger = logging.getLogger()
452 logger = logging.getLogger()
449 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
453 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
450 # don't add a second PUBHandler
454 # don't add a second PUBHandler
451 return
455 return
452 loglevel = integer_loglevel(loglevel)
456 loglevel = integer_loglevel(loglevel)
453 lsock = context.socket(zmq.PUB)
457 lsock = context.socket(zmq.PUB)
454 lsock.connect(iface)
458 lsock.connect(iface)
455 handler = EnginePUBHandler(engine, lsock)
459 handler = EnginePUBHandler(engine, lsock)
456 handler.setLevel(loglevel)
460 handler.setLevel(loglevel)
457 logger.addHandler(handler)
461 logger.addHandler(handler)
458 logger.setLevel(loglevel)
462 logger.setLevel(loglevel)
459 return logger
463 return logger
460
464
461 def local_logger(logname, loglevel=logging.DEBUG):
465 def local_logger(logname, loglevel=logging.DEBUG):
462 loglevel = integer_loglevel(loglevel)
466 loglevel = integer_loglevel(loglevel)
463 logger = logging.getLogger(logname)
467 logger = logging.getLogger(logname)
464 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
468 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
465 # don't add a second StreamHandler
469 # don't add a second StreamHandler
466 return
470 return
467 handler = logging.StreamHandler()
471 handler = logging.StreamHandler()
468 handler.setLevel(loglevel)
472 handler.setLevel(loglevel)
469 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
473 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
470 datefmt="%Y-%m-%d %H:%M:%S")
474 datefmt="%Y-%m-%d %H:%M:%S")
471 handler.setFormatter(formatter)
475 handler.setFormatter(formatter)
472
476
473 logger.addHandler(handler)
477 logger.addHandler(handler)
474 logger.setLevel(loglevel)
478 logger.setLevel(loglevel)
475 return logger
479 return logger
476
480
General Comments 0
You need to be logged in to leave comments. Login now