##// END OF EJS Templates
Merge pull request #2327 from bfroehle/remote_push_pull_nested...
Min RK -
r8363:0af1e9d0 merge
parent child Browse files
Show More
@@ -1,703 +1,721 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import platform
20 import platform
21 import time
21 import time
22 from tempfile import mktemp
22 from tempfile import mktemp
23 from StringIO import StringIO
23 from StringIO import StringIO
24
24
25 import zmq
25 import zmq
26 from nose import SkipTest
26 from nose import SkipTest
27
27
28 from IPython.testing import decorators as dec
28 from IPython.testing import decorators as dec
29 from IPython.testing.ipunittest import ParametricTestCase
29 from IPython.testing.ipunittest import ParametricTestCase
30 from IPython.utils.io import capture_output
30 from IPython.utils.io import capture_output
31
31
32 from IPython import parallel as pmod
32 from IPython import parallel as pmod
33 from IPython.parallel import error
33 from IPython.parallel import error
34 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
34 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
35 from IPython.parallel import DirectView
35 from IPython.parallel import DirectView
36 from IPython.parallel.util import interactive
36 from IPython.parallel.util import interactive
37
37
38 from IPython.parallel.tests import add_engines
38 from IPython.parallel.tests import add_engines
39
39
40 from .clienttest import ClusterTestCase, crash, wait, skip_without
40 from .clienttest import ClusterTestCase, crash, wait, skip_without
41
41
42 def setup():
42 def setup():
43 add_engines(3, total=True)
43 add_engines(3, total=True)
44
44
45 class TestView(ClusterTestCase, ParametricTestCase):
45 class TestView(ClusterTestCase, ParametricTestCase):
46
46
47 def setUp(self):
47 def setUp(self):
48 # On Win XP, wait for resource cleanup, else parallel test group fails
48 # On Win XP, wait for resource cleanup, else parallel test group fails
49 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
49 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
50 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
50 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
51 time.sleep(2)
51 time.sleep(2)
52 super(TestView, self).setUp()
52 super(TestView, self).setUp()
53
53
54 def test_z_crash_mux(self):
54 def test_z_crash_mux(self):
55 """test graceful handling of engine death (direct)"""
55 """test graceful handling of engine death (direct)"""
56 raise SkipTest("crash tests disabled, due to undesirable crash reports")
56 raise SkipTest("crash tests disabled, due to undesirable crash reports")
57 # self.add_engines(1)
57 # self.add_engines(1)
58 eid = self.client.ids[-1]
58 eid = self.client.ids[-1]
59 ar = self.client[eid].apply_async(crash)
59 ar = self.client[eid].apply_async(crash)
60 self.assertRaisesRemote(error.EngineError, ar.get, 10)
60 self.assertRaisesRemote(error.EngineError, ar.get, 10)
61 eid = ar.engine_id
61 eid = ar.engine_id
62 tic = time.time()
62 tic = time.time()
63 while eid in self.client.ids and time.time()-tic < 5:
63 while eid in self.client.ids and time.time()-tic < 5:
64 time.sleep(.01)
64 time.sleep(.01)
65 self.client.spin()
65 self.client.spin()
66 self.assertFalse(eid in self.client.ids, "Engine should have died")
66 self.assertFalse(eid in self.client.ids, "Engine should have died")
67
67
68 def test_push_pull(self):
68 def test_push_pull(self):
69 """test pushing and pulling"""
69 """test pushing and pulling"""
70 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
70 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
71 t = self.client.ids[-1]
71 t = self.client.ids[-1]
72 v = self.client[t]
72 v = self.client[t]
73 push = v.push
73 push = v.push
74 pull = v.pull
74 pull = v.pull
75 v.block=True
75 v.block=True
76 nengines = len(self.client)
76 nengines = len(self.client)
77 push({'data':data})
77 push({'data':data})
78 d = pull('data')
78 d = pull('data')
79 self.assertEqual(d, data)
79 self.assertEqual(d, data)
80 self.client[:].push({'data':data})
80 self.client[:].push({'data':data})
81 d = self.client[:].pull('data', block=True)
81 d = self.client[:].pull('data', block=True)
82 self.assertEqual(d, nengines*[data])
82 self.assertEqual(d, nengines*[data])
83 ar = push({'data':data}, block=False)
83 ar = push({'data':data}, block=False)
84 self.assertTrue(isinstance(ar, AsyncResult))
84 self.assertTrue(isinstance(ar, AsyncResult))
85 r = ar.get()
85 r = ar.get()
86 ar = self.client[:].pull('data', block=False)
86 ar = self.client[:].pull('data', block=False)
87 self.assertTrue(isinstance(ar, AsyncResult))
87 self.assertTrue(isinstance(ar, AsyncResult))
88 r = ar.get()
88 r = ar.get()
89 self.assertEqual(r, nengines*[data])
89 self.assertEqual(r, nengines*[data])
90 self.client[:].push(dict(a=10,b=20))
90 self.client[:].push(dict(a=10,b=20))
91 r = self.client[:].pull(('a','b'), block=True)
91 r = self.client[:].pull(('a','b'), block=True)
92 self.assertEqual(r, nengines*[[10,20]])
92 self.assertEqual(r, nengines*[[10,20]])
93
93
94 def test_push_pull_function(self):
94 def test_push_pull_function(self):
95 "test pushing and pulling functions"
95 "test pushing and pulling functions"
96 def testf(x):
96 def testf(x):
97 return 2.0*x
97 return 2.0*x
98
98
99 t = self.client.ids[-1]
99 t = self.client.ids[-1]
100 v = self.client[t]
100 v = self.client[t]
101 v.block=True
101 v.block=True
102 push = v.push
102 push = v.push
103 pull = v.pull
103 pull = v.pull
104 execute = v.execute
104 execute = v.execute
105 push({'testf':testf})
105 push({'testf':testf})
106 r = pull('testf')
106 r = pull('testf')
107 self.assertEqual(r(1.0), testf(1.0))
107 self.assertEqual(r(1.0), testf(1.0))
108 execute('r = testf(10)')
108 execute('r = testf(10)')
109 r = pull('r')
109 r = pull('r')
110 self.assertEqual(r, testf(10))
110 self.assertEqual(r, testf(10))
111 ar = self.client[:].push({'testf':testf}, block=False)
111 ar = self.client[:].push({'testf':testf}, block=False)
112 ar.get()
112 ar.get()
113 ar = self.client[:].pull('testf', block=False)
113 ar = self.client[:].pull('testf', block=False)
114 rlist = ar.get()
114 rlist = ar.get()
115 for r in rlist:
115 for r in rlist:
116 self.assertEqual(r(1.0), testf(1.0))
116 self.assertEqual(r(1.0), testf(1.0))
117 execute("def g(x): return x*x")
117 execute("def g(x): return x*x")
118 r = pull(('testf','g'))
118 r = pull(('testf','g'))
119 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
119 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
120
120
121 def test_push_function_globals(self):
121 def test_push_function_globals(self):
122 """test that pushed functions have access to globals"""
122 """test that pushed functions have access to globals"""
123 @interactive
123 @interactive
124 def geta():
124 def geta():
125 return a
125 return a
126 # self.add_engines(1)
126 # self.add_engines(1)
127 v = self.client[-1]
127 v = self.client[-1]
128 v.block=True
128 v.block=True
129 v['f'] = geta
129 v['f'] = geta
130 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
130 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
131 v.execute('a=5')
131 v.execute('a=5')
132 v.execute('b=f()')
132 v.execute('b=f()')
133 self.assertEqual(v['b'], 5)
133 self.assertEqual(v['b'], 5)
134
134
135 def test_push_function_defaults(self):
135 def test_push_function_defaults(self):
136 """test that pushed functions preserve default args"""
136 """test that pushed functions preserve default args"""
137 def echo(a=10):
137 def echo(a=10):
138 return a
138 return a
139 v = self.client[-1]
139 v = self.client[-1]
140 v.block=True
140 v.block=True
141 v['f'] = echo
141 v['f'] = echo
142 v.execute('b=f()')
142 v.execute('b=f()')
143 self.assertEqual(v['b'], 10)
143 self.assertEqual(v['b'], 10)
144
144
145 def test_get_result(self):
145 def test_get_result(self):
146 """test getting results from the Hub."""
146 """test getting results from the Hub."""
147 c = pmod.Client(profile='iptest')
147 c = pmod.Client(profile='iptest')
148 # self.add_engines(1)
148 # self.add_engines(1)
149 t = c.ids[-1]
149 t = c.ids[-1]
150 v = c[t]
150 v = c[t]
151 v2 = self.client[t]
151 v2 = self.client[t]
152 ar = v.apply_async(wait, 1)
152 ar = v.apply_async(wait, 1)
153 # give the monitor time to notice the message
153 # give the monitor time to notice the message
154 time.sleep(.25)
154 time.sleep(.25)
155 ahr = v2.get_result(ar.msg_ids)
155 ahr = v2.get_result(ar.msg_ids)
156 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 self.assertTrue(isinstance(ahr, AsyncHubResult))
157 self.assertEqual(ahr.get(), ar.get())
157 self.assertEqual(ahr.get(), ar.get())
158 ar2 = v2.get_result(ar.msg_ids)
158 ar2 = v2.get_result(ar.msg_ids)
159 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 self.assertFalse(isinstance(ar2, AsyncHubResult))
160 c.spin()
160 c.spin()
161 c.close()
161 c.close()
162
162
163 def test_run_newline(self):
163 def test_run_newline(self):
164 """test that run appends newline to files"""
164 """test that run appends newline to files"""
165 tmpfile = mktemp()
165 tmpfile = mktemp()
166 with open(tmpfile, 'w') as f:
166 with open(tmpfile, 'w') as f:
167 f.write("""def g():
167 f.write("""def g():
168 return 5
168 return 5
169 """)
169 """)
170 v = self.client[-1]
170 v = self.client[-1]
171 v.run(tmpfile, block=True)
171 v.run(tmpfile, block=True)
172 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
172 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
173
173
174 def test_apply_tracked(self):
174 def test_apply_tracked(self):
175 """test tracking for apply"""
175 """test tracking for apply"""
176 # self.add_engines(1)
176 # self.add_engines(1)
177 t = self.client.ids[-1]
177 t = self.client.ids[-1]
178 v = self.client[t]
178 v = self.client[t]
179 v.block=False
179 v.block=False
180 def echo(n=1024*1024, **kwargs):
180 def echo(n=1024*1024, **kwargs):
181 with v.temp_flags(**kwargs):
181 with v.temp_flags(**kwargs):
182 return v.apply(lambda x: x, 'x'*n)
182 return v.apply(lambda x: x, 'x'*n)
183 ar = echo(1, track=False)
183 ar = echo(1, track=False)
184 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
184 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
185 self.assertTrue(ar.sent)
185 self.assertTrue(ar.sent)
186 ar = echo(track=True)
186 ar = echo(track=True)
187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 self.assertEqual(ar.sent, ar._tracker.done)
188 self.assertEqual(ar.sent, ar._tracker.done)
189 ar._tracker.wait()
189 ar._tracker.wait()
190 self.assertTrue(ar.sent)
190 self.assertTrue(ar.sent)
191
191
192 def test_push_tracked(self):
192 def test_push_tracked(self):
193 t = self.client.ids[-1]
193 t = self.client.ids[-1]
194 ns = dict(x='x'*1024*1024)
194 ns = dict(x='x'*1024*1024)
195 v = self.client[t]
195 v = self.client[t]
196 ar = v.push(ns, block=False, track=False)
196 ar = v.push(ns, block=False, track=False)
197 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
197 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
198 self.assertTrue(ar.sent)
198 self.assertTrue(ar.sent)
199
199
200 ar = v.push(ns, block=False, track=True)
200 ar = v.push(ns, block=False, track=True)
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 ar._tracker.wait()
202 ar._tracker.wait()
203 self.assertEqual(ar.sent, ar._tracker.done)
203 self.assertEqual(ar.sent, ar._tracker.done)
204 self.assertTrue(ar.sent)
204 self.assertTrue(ar.sent)
205 ar.get()
205 ar.get()
206
206
207 def test_scatter_tracked(self):
207 def test_scatter_tracked(self):
208 t = self.client.ids
208 t = self.client.ids
209 x='x'*1024*1024
209 x='x'*1024*1024
210 ar = self.client[t].scatter('x', x, block=False, track=False)
210 ar = self.client[t].scatter('x', x, block=False, track=False)
211 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
211 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
212 self.assertTrue(ar.sent)
212 self.assertTrue(ar.sent)
213
213
214 ar = self.client[t].scatter('x', x, block=False, track=True)
214 ar = self.client[t].scatter('x', x, block=False, track=True)
215 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
215 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
216 self.assertEqual(ar.sent, ar._tracker.done)
216 self.assertEqual(ar.sent, ar._tracker.done)
217 ar._tracker.wait()
217 ar._tracker.wait()
218 self.assertTrue(ar.sent)
218 self.assertTrue(ar.sent)
219 ar.get()
219 ar.get()
220
220
221 def test_remote_reference(self):
221 def test_remote_reference(self):
222 v = self.client[-1]
222 v = self.client[-1]
223 v['a'] = 123
223 v['a'] = 123
224 ra = pmod.Reference('a')
224 ra = pmod.Reference('a')
225 b = v.apply_sync(lambda x: x, ra)
225 b = v.apply_sync(lambda x: x, ra)
226 self.assertEqual(b, 123)
226 self.assertEqual(b, 123)
227
227
228
228
229 def test_scatter_gather(self):
229 def test_scatter_gather(self):
230 view = self.client[:]
230 view = self.client[:]
231 seq1 = range(16)
231 seq1 = range(16)
232 view.scatter('a', seq1)
232 view.scatter('a', seq1)
233 seq2 = view.gather('a', block=True)
233 seq2 = view.gather('a', block=True)
234 self.assertEqual(seq2, seq1)
234 self.assertEqual(seq2, seq1)
235 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
235 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
236
236
237 @skip_without('numpy')
237 @skip_without('numpy')
238 def test_scatter_gather_numpy(self):
238 def test_scatter_gather_numpy(self):
239 import numpy
239 import numpy
240 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
240 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
241 view = self.client[:]
241 view = self.client[:]
242 a = numpy.arange(64)
242 a = numpy.arange(64)
243 view.scatter('a', a, block=True)
243 view.scatter('a', a, block=True)
244 b = view.gather('a', block=True)
244 b = view.gather('a', block=True)
245 assert_array_equal(b, a)
245 assert_array_equal(b, a)
246
246
247 def test_scatter_gather_lazy(self):
247 def test_scatter_gather_lazy(self):
248 """scatter/gather with targets='all'"""
248 """scatter/gather with targets='all'"""
249 view = self.client.direct_view(targets='all')
249 view = self.client.direct_view(targets='all')
250 x = range(64)
250 x = range(64)
251 view.scatter('x', x)
251 view.scatter('x', x)
252 gathered = view.gather('x', block=True)
252 gathered = view.gather('x', block=True)
253 self.assertEqual(gathered, x)
253 self.assertEqual(gathered, x)
254
254
255
255
256 @dec.known_failure_py3
256 @dec.known_failure_py3
257 @skip_without('numpy')
257 @skip_without('numpy')
258 def test_push_numpy_nocopy(self):
258 def test_push_numpy_nocopy(self):
259 import numpy
259 import numpy
260 view = self.client[:]
260 view = self.client[:]
261 a = numpy.arange(64)
261 a = numpy.arange(64)
262 view['A'] = a
262 view['A'] = a
263 @interactive
263 @interactive
264 def check_writeable(x):
264 def check_writeable(x):
265 return x.flags.writeable
265 return x.flags.writeable
266
266
267 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
267 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
268 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
268 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
269
269
270 view.push(dict(B=a))
270 view.push(dict(B=a))
271 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
271 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
272 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
272 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
273
273
274 @skip_without('numpy')
274 @skip_without('numpy')
275 def test_apply_numpy(self):
275 def test_apply_numpy(self):
276 """view.apply(f, ndarray)"""
276 """view.apply(f, ndarray)"""
277 import numpy
277 import numpy
278 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
278 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
279
279
280 A = numpy.random.random((100,100))
280 A = numpy.random.random((100,100))
281 view = self.client[-1]
281 view = self.client[-1]
282 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
282 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
283 B = A.astype(dt)
283 B = A.astype(dt)
284 C = view.apply_sync(lambda x:x, B)
284 C = view.apply_sync(lambda x:x, B)
285 assert_array_equal(B,C)
285 assert_array_equal(B,C)
286
286
287 @skip_without('numpy')
287 @skip_without('numpy')
288 def test_push_pull_recarray(self):
288 def test_push_pull_recarray(self):
289 """push/pull recarrays"""
289 """push/pull recarrays"""
290 import numpy
290 import numpy
291 from numpy.testing.utils import assert_array_equal
291 from numpy.testing.utils import assert_array_equal
292
292
293 view = self.client[-1]
293 view = self.client[-1]
294
294
295 R = numpy.array([
295 R = numpy.array([
296 (1, 'hi', 0.),
296 (1, 'hi', 0.),
297 (2**30, 'there', 2.5),
297 (2**30, 'there', 2.5),
298 (-99999, 'world', -12345.6789),
298 (-99999, 'world', -12345.6789),
299 ], [('n', int), ('s', '|S10'), ('f', float)])
299 ], [('n', int), ('s', '|S10'), ('f', float)])
300
300
301 view['RR'] = R
301 view['RR'] = R
302 R2 = view['RR']
302 R2 = view['RR']
303
303
304 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
304 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
305 self.assertEqual(r_dtype, R.dtype)
305 self.assertEqual(r_dtype, R.dtype)
306 self.assertEqual(r_shape, R.shape)
306 self.assertEqual(r_shape, R.shape)
307 self.assertEqual(R2.dtype, R.dtype)
307 self.assertEqual(R2.dtype, R.dtype)
308 self.assertEqual(R2.shape, R.shape)
308 self.assertEqual(R2.shape, R.shape)
309 assert_array_equal(R2, R)
309 assert_array_equal(R2, R)
310
310
311 def test_map(self):
311 def test_map(self):
312 view = self.client[:]
312 view = self.client[:]
313 def f(x):
313 def f(x):
314 return x**2
314 return x**2
315 data = range(16)
315 data = range(16)
316 r = view.map_sync(f, data)
316 r = view.map_sync(f, data)
317 self.assertEqual(r, map(f, data))
317 self.assertEqual(r, map(f, data))
318
318
319 def test_map_iterable(self):
319 def test_map_iterable(self):
320 """test map on iterables (direct)"""
320 """test map on iterables (direct)"""
321 view = self.client[:]
321 view = self.client[:]
322 # 101 is prime, so it won't be evenly distributed
322 # 101 is prime, so it won't be evenly distributed
323 arr = range(101)
323 arr = range(101)
324 # ensure it will be an iterator, even in Python 3
324 # ensure it will be an iterator, even in Python 3
325 it = iter(arr)
325 it = iter(arr)
326 r = view.map_sync(lambda x:x, arr)
326 r = view.map_sync(lambda x:x, arr)
327 self.assertEqual(r, list(arr))
327 self.assertEqual(r, list(arr))
328
328
329 def test_scatter_gather_nonblocking(self):
329 def test_scatter_gather_nonblocking(self):
330 data = range(16)
330 data = range(16)
331 view = self.client[:]
331 view = self.client[:]
332 view.scatter('a', data, block=False)
332 view.scatter('a', data, block=False)
333 ar = view.gather('a', block=False)
333 ar = view.gather('a', block=False)
334 self.assertEqual(ar.get(), data)
334 self.assertEqual(ar.get(), data)
335
335
336 @skip_without('numpy')
336 @skip_without('numpy')
337 def test_scatter_gather_numpy_nonblocking(self):
337 def test_scatter_gather_numpy_nonblocking(self):
338 import numpy
338 import numpy
339 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
339 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
340 a = numpy.arange(64)
340 a = numpy.arange(64)
341 view = self.client[:]
341 view = self.client[:]
342 ar = view.scatter('a', a, block=False)
342 ar = view.scatter('a', a, block=False)
343 self.assertTrue(isinstance(ar, AsyncResult))
343 self.assertTrue(isinstance(ar, AsyncResult))
344 amr = view.gather('a', block=False)
344 amr = view.gather('a', block=False)
345 self.assertTrue(isinstance(amr, AsyncMapResult))
345 self.assertTrue(isinstance(amr, AsyncMapResult))
346 assert_array_equal(amr.get(), a)
346 assert_array_equal(amr.get(), a)
347
347
348 def test_execute(self):
348 def test_execute(self):
349 view = self.client[:]
349 view = self.client[:]
350 # self.client.debug=True
350 # self.client.debug=True
351 execute = view.execute
351 execute = view.execute
352 ar = execute('c=30', block=False)
352 ar = execute('c=30', block=False)
353 self.assertTrue(isinstance(ar, AsyncResult))
353 self.assertTrue(isinstance(ar, AsyncResult))
354 ar = execute('d=[0,1,2]', block=False)
354 ar = execute('d=[0,1,2]', block=False)
355 self.client.wait(ar, 1)
355 self.client.wait(ar, 1)
356 self.assertEqual(len(ar.get()), len(self.client))
356 self.assertEqual(len(ar.get()), len(self.client))
357 for c in view['c']:
357 for c in view['c']:
358 self.assertEqual(c, 30)
358 self.assertEqual(c, 30)
359
359
360 def test_abort(self):
360 def test_abort(self):
361 view = self.client[-1]
361 view = self.client[-1]
362 ar = view.execute('import time; time.sleep(1)', block=False)
362 ar = view.execute('import time; time.sleep(1)', block=False)
363 ar2 = view.apply_async(lambda : 2)
363 ar2 = view.apply_async(lambda : 2)
364 ar3 = view.apply_async(lambda : 3)
364 ar3 = view.apply_async(lambda : 3)
365 view.abort(ar2)
365 view.abort(ar2)
366 view.abort(ar3.msg_ids)
366 view.abort(ar3.msg_ids)
367 self.assertRaises(error.TaskAborted, ar2.get)
367 self.assertRaises(error.TaskAborted, ar2.get)
368 self.assertRaises(error.TaskAborted, ar3.get)
368 self.assertRaises(error.TaskAborted, ar3.get)
369
369
370 def test_abort_all(self):
370 def test_abort_all(self):
371 """view.abort() aborts all outstanding tasks"""
371 """view.abort() aborts all outstanding tasks"""
372 view = self.client[-1]
372 view = self.client[-1]
373 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
373 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
374 view.abort()
374 view.abort()
375 view.wait(timeout=5)
375 view.wait(timeout=5)
376 for ar in ars[5:]:
376 for ar in ars[5:]:
377 self.assertRaises(error.TaskAborted, ar.get)
377 self.assertRaises(error.TaskAborted, ar.get)
378
378
379 def test_temp_flags(self):
379 def test_temp_flags(self):
380 view = self.client[-1]
380 view = self.client[-1]
381 view.block=True
381 view.block=True
382 with view.temp_flags(block=False):
382 with view.temp_flags(block=False):
383 self.assertFalse(view.block)
383 self.assertFalse(view.block)
384 self.assertTrue(view.block)
384 self.assertTrue(view.block)
385
385
386 @dec.known_failure_py3
386 @dec.known_failure_py3
387 def test_importer(self):
387 def test_importer(self):
388 view = self.client[-1]
388 view = self.client[-1]
389 view.clear(block=True)
389 view.clear(block=True)
390 with view.importer:
390 with view.importer:
391 import re
391 import re
392
392
393 @interactive
393 @interactive
394 def findall(pat, s):
394 def findall(pat, s):
395 # this globals() step isn't necessary in real code
395 # this globals() step isn't necessary in real code
396 # only to prevent a closure in the test
396 # only to prevent a closure in the test
397 re = globals()['re']
397 re = globals()['re']
398 return re.findall(pat, s)
398 return re.findall(pat, s)
399
399
400 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
400 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
401
401
402 def test_unicode_execute(self):
402 def test_unicode_execute(self):
403 """test executing unicode strings"""
403 """test executing unicode strings"""
404 v = self.client[-1]
404 v = self.client[-1]
405 v.block=True
405 v.block=True
406 if sys.version_info[0] >= 3:
406 if sys.version_info[0] >= 3:
407 code="a='é'"
407 code="a='é'"
408 else:
408 else:
409 code=u"a=u'é'"
409 code=u"a=u'é'"
410 v.execute(code)
410 v.execute(code)
411 self.assertEqual(v['a'], u'é')
411 self.assertEqual(v['a'], u'é')
412
412
413 def test_unicode_apply_result(self):
413 def test_unicode_apply_result(self):
414 """test unicode apply results"""
414 """test unicode apply results"""
415 v = self.client[-1]
415 v = self.client[-1]
416 r = v.apply_sync(lambda : u'é')
416 r = v.apply_sync(lambda : u'é')
417 self.assertEqual(r, u'é')
417 self.assertEqual(r, u'é')
418
418
419 def test_unicode_apply_arg(self):
419 def test_unicode_apply_arg(self):
420 """test passing unicode arguments to apply"""
420 """test passing unicode arguments to apply"""
421 v = self.client[-1]
421 v = self.client[-1]
422
422
423 @interactive
423 @interactive
424 def check_unicode(a, check):
424 def check_unicode(a, check):
425 assert isinstance(a, unicode), "%r is not unicode"%a
425 assert isinstance(a, unicode), "%r is not unicode"%a
426 assert isinstance(check, bytes), "%r is not bytes"%check
426 assert isinstance(check, bytes), "%r is not bytes"%check
427 assert a.encode('utf8') == check, "%s != %s"%(a,check)
427 assert a.encode('utf8') == check, "%s != %s"%(a,check)
428
428
429 for s in [ u'é', u'ßø®∫',u'asdf' ]:
429 for s in [ u'é', u'ßø®∫',u'asdf' ]:
430 try:
430 try:
431 v.apply_sync(check_unicode, s, s.encode('utf8'))
431 v.apply_sync(check_unicode, s, s.encode('utf8'))
432 except error.RemoteError as e:
432 except error.RemoteError as e:
433 if e.ename == 'AssertionError':
433 if e.ename == 'AssertionError':
434 self.fail(e.evalue)
434 self.fail(e.evalue)
435 else:
435 else:
436 raise e
436 raise e
437
437
438 def test_map_reference(self):
438 def test_map_reference(self):
439 """view.map(<Reference>, *seqs) should work"""
439 """view.map(<Reference>, *seqs) should work"""
440 v = self.client[:]
440 v = self.client[:]
441 v.scatter('n', self.client.ids, flatten=True)
441 v.scatter('n', self.client.ids, flatten=True)
442 v.execute("f = lambda x,y: x*y")
442 v.execute("f = lambda x,y: x*y")
443 rf = pmod.Reference('f')
443 rf = pmod.Reference('f')
444 nlist = list(range(10))
444 nlist = list(range(10))
445 mlist = nlist[::-1]
445 mlist = nlist[::-1]
446 expected = [ m*n for m,n in zip(mlist, nlist) ]
446 expected = [ m*n for m,n in zip(mlist, nlist) ]
447 result = v.map_sync(rf, mlist, nlist)
447 result = v.map_sync(rf, mlist, nlist)
448 self.assertEqual(result, expected)
448 self.assertEqual(result, expected)
449
449
450 def test_apply_reference(self):
450 def test_apply_reference(self):
451 """view.apply(<Reference>, *args) should work"""
451 """view.apply(<Reference>, *args) should work"""
452 v = self.client[:]
452 v = self.client[:]
453 v.scatter('n', self.client.ids, flatten=True)
453 v.scatter('n', self.client.ids, flatten=True)
454 v.execute("f = lambda x: n*x")
454 v.execute("f = lambda x: n*x")
455 rf = pmod.Reference('f')
455 rf = pmod.Reference('f')
456 result = v.apply_sync(rf, 5)
456 result = v.apply_sync(rf, 5)
457 expected = [ 5*id for id in self.client.ids ]
457 expected = [ 5*id for id in self.client.ids ]
458 self.assertEqual(result, expected)
458 self.assertEqual(result, expected)
459
459
460 def test_eval_reference(self):
460 def test_eval_reference(self):
461 v = self.client[self.client.ids[0]]
461 v = self.client[self.client.ids[0]]
462 v['g'] = range(5)
462 v['g'] = range(5)
463 rg = pmod.Reference('g[0]')
463 rg = pmod.Reference('g[0]')
464 echo = lambda x:x
464 echo = lambda x:x
465 self.assertEqual(v.apply_sync(echo, rg), 0)
465 self.assertEqual(v.apply_sync(echo, rg), 0)
466
466
467 def test_reference_nameerror(self):
467 def test_reference_nameerror(self):
468 v = self.client[self.client.ids[0]]
468 v = self.client[self.client.ids[0]]
469 r = pmod.Reference('elvis_has_left')
469 r = pmod.Reference('elvis_has_left')
470 echo = lambda x:x
470 echo = lambda x:x
471 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
471 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
472
472
473 def test_single_engine_map(self):
473 def test_single_engine_map(self):
474 e0 = self.client[self.client.ids[0]]
474 e0 = self.client[self.client.ids[0]]
475 r = range(5)
475 r = range(5)
476 check = [ -1*i for i in r ]
476 check = [ -1*i for i in r ]
477 result = e0.map_sync(lambda x: -1*x, r)
477 result = e0.map_sync(lambda x: -1*x, r)
478 self.assertEqual(result, check)
478 self.assertEqual(result, check)
479
479
480 def test_len(self):
480 def test_len(self):
481 """len(view) makes sense"""
481 """len(view) makes sense"""
482 e0 = self.client[self.client.ids[0]]
482 e0 = self.client[self.client.ids[0]]
483 yield self.assertEqual(len(e0), 1)
483 yield self.assertEqual(len(e0), 1)
484 v = self.client[:]
484 v = self.client[:]
485 yield self.assertEqual(len(v), len(self.client.ids))
485 yield self.assertEqual(len(v), len(self.client.ids))
486 v = self.client.direct_view('all')
486 v = self.client.direct_view('all')
487 yield self.assertEqual(len(v), len(self.client.ids))
487 yield self.assertEqual(len(v), len(self.client.ids))
488 v = self.client[:2]
488 v = self.client[:2]
489 yield self.assertEqual(len(v), 2)
489 yield self.assertEqual(len(v), 2)
490 v = self.client[:1]
490 v = self.client[:1]
491 yield self.assertEqual(len(v), 1)
491 yield self.assertEqual(len(v), 1)
492 v = self.client.load_balanced_view()
492 v = self.client.load_balanced_view()
493 yield self.assertEqual(len(v), len(self.client.ids))
493 yield self.assertEqual(len(v), len(self.client.ids))
494 # parametric tests seem to require manual closing?
494 # parametric tests seem to require manual closing?
495 self.client.close()
495 self.client.close()
496
496
497
497
498 # begin execute tests
498 # begin execute tests
499
499
500 def test_execute_reply(self):
500 def test_execute_reply(self):
501 e0 = self.client[self.client.ids[0]]
501 e0 = self.client[self.client.ids[0]]
502 e0.block = True
502 e0.block = True
503 ar = e0.execute("5", silent=False)
503 ar = e0.execute("5", silent=False)
504 er = ar.get()
504 er = ar.get()
505 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
505 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
506 self.assertEqual(er.pyout['data']['text/plain'], '5')
506 self.assertEqual(er.pyout['data']['text/plain'], '5')
507
507
508 def test_execute_reply_stdout(self):
508 def test_execute_reply_stdout(self):
509 e0 = self.client[self.client.ids[0]]
509 e0 = self.client[self.client.ids[0]]
510 e0.block = True
510 e0.block = True
511 ar = e0.execute("print (5)", silent=False)
511 ar = e0.execute("print (5)", silent=False)
512 er = ar.get()
512 er = ar.get()
513 self.assertEqual(er.stdout.strip(), '5')
513 self.assertEqual(er.stdout.strip(), '5')
514
514
515 def test_execute_pyout(self):
515 def test_execute_pyout(self):
516 """execute triggers pyout with silent=False"""
516 """execute triggers pyout with silent=False"""
517 view = self.client[:]
517 view = self.client[:]
518 ar = view.execute("5", silent=False, block=True)
518 ar = view.execute("5", silent=False, block=True)
519
519
520 expected = [{'text/plain' : '5'}] * len(view)
520 expected = [{'text/plain' : '5'}] * len(view)
521 mimes = [ out['data'] for out in ar.pyout ]
521 mimes = [ out['data'] for out in ar.pyout ]
522 self.assertEqual(mimes, expected)
522 self.assertEqual(mimes, expected)
523
523
524 def test_execute_silent(self):
524 def test_execute_silent(self):
525 """execute does not trigger pyout with silent=True"""
525 """execute does not trigger pyout with silent=True"""
526 view = self.client[:]
526 view = self.client[:]
527 ar = view.execute("5", block=True)
527 ar = view.execute("5", block=True)
528 expected = [None] * len(view)
528 expected = [None] * len(view)
529 self.assertEqual(ar.pyout, expected)
529 self.assertEqual(ar.pyout, expected)
530
530
531 def test_execute_magic(self):
531 def test_execute_magic(self):
532 """execute accepts IPython commands"""
532 """execute accepts IPython commands"""
533 view = self.client[:]
533 view = self.client[:]
534 view.execute("a = 5")
534 view.execute("a = 5")
535 ar = view.execute("%whos", block=True)
535 ar = view.execute("%whos", block=True)
536 # this will raise, if that failed
536 # this will raise, if that failed
537 ar.get(5)
537 ar.get(5)
538 for stdout in ar.stdout:
538 for stdout in ar.stdout:
539 lines = stdout.splitlines()
539 lines = stdout.splitlines()
540 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
540 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
541 found = False
541 found = False
542 for line in lines[2:]:
542 for line in lines[2:]:
543 split = line.split()
543 split = line.split()
544 if split == ['a', 'int', '5']:
544 if split == ['a', 'int', '5']:
545 found = True
545 found = True
546 break
546 break
547 self.assertTrue(found, "whos output wrong: %s" % stdout)
547 self.assertTrue(found, "whos output wrong: %s" % stdout)
548
548
549 def test_execute_displaypub(self):
549 def test_execute_displaypub(self):
550 """execute tracks display_pub output"""
550 """execute tracks display_pub output"""
551 view = self.client[:]
551 view = self.client[:]
552 view.execute("from IPython.core.display import *")
552 view.execute("from IPython.core.display import *")
553 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
553 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
554
554
555 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
555 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
556 for outputs in ar.outputs:
556 for outputs in ar.outputs:
557 mimes = [ out['data'] for out in outputs ]
557 mimes = [ out['data'] for out in outputs ]
558 self.assertEqual(mimes, expected)
558 self.assertEqual(mimes, expected)
559
559
560 def test_apply_displaypub(self):
560 def test_apply_displaypub(self):
561 """apply tracks display_pub output"""
561 """apply tracks display_pub output"""
562 view = self.client[:]
562 view = self.client[:]
563 view.execute("from IPython.core.display import *")
563 view.execute("from IPython.core.display import *")
564
564
565 @interactive
565 @interactive
566 def publish():
566 def publish():
567 [ display(i) for i in range(5) ]
567 [ display(i) for i in range(5) ]
568
568
569 ar = view.apply_async(publish)
569 ar = view.apply_async(publish)
570 ar.get(5)
570 ar.get(5)
571 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
571 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
572 for outputs in ar.outputs:
572 for outputs in ar.outputs:
573 mimes = [ out['data'] for out in outputs ]
573 mimes = [ out['data'] for out in outputs ]
574 self.assertEqual(mimes, expected)
574 self.assertEqual(mimes, expected)
575
575
576 def test_execute_raises(self):
576 def test_execute_raises(self):
577 """exceptions in execute requests raise appropriately"""
577 """exceptions in execute requests raise appropriately"""
578 view = self.client[-1]
578 view = self.client[-1]
579 ar = view.execute("1/0")
579 ar = view.execute("1/0")
580 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
580 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
581
581
582 def test_remoteerror_render_exception(self):
582 def test_remoteerror_render_exception(self):
583 """RemoteErrors get nice tracebacks"""
583 """RemoteErrors get nice tracebacks"""
584 view = self.client[-1]
584 view = self.client[-1]
585 ar = view.execute("1/0")
585 ar = view.execute("1/0")
586 ip = get_ipython()
586 ip = get_ipython()
587 ip.user_ns['ar'] = ar
587 ip.user_ns['ar'] = ar
588 with capture_output() as io:
588 with capture_output() as io:
589 ip.run_cell("ar.get(2)")
589 ip.run_cell("ar.get(2)")
590
590
591 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
591 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
592
592
593 def test_compositeerror_render_exception(self):
593 def test_compositeerror_render_exception(self):
594 """CompositeErrors get nice tracebacks"""
594 """CompositeErrors get nice tracebacks"""
595 view = self.client[:]
595 view = self.client[:]
596 ar = view.execute("1/0")
596 ar = view.execute("1/0")
597 ip = get_ipython()
597 ip = get_ipython()
598 ip.user_ns['ar'] = ar
598 ip.user_ns['ar'] = ar
599 with capture_output() as io:
599 with capture_output() as io:
600 ip.run_cell("ar.get(2)")
600 ip.run_cell("ar.get(2)")
601
601
602 self.assertEqual(io.stdout.count('ZeroDivisionError'), len(view) * 2, io.stdout)
602 self.assertEqual(io.stdout.count('ZeroDivisionError'), len(view) * 2, io.stdout)
603 self.assertEqual(io.stdout.count('by zero'), len(view), io.stdout)
603 self.assertEqual(io.stdout.count('by zero'), len(view), io.stdout)
604 self.assertEqual(io.stdout.count(':execute'), len(view), io.stdout)
604 self.assertEqual(io.stdout.count(':execute'), len(view), io.stdout)
605
605
606 @dec.skipif_not_matplotlib
606 @dec.skipif_not_matplotlib
607 def test_magic_pylab(self):
607 def test_magic_pylab(self):
608 """%pylab works on engines"""
608 """%pylab works on engines"""
609 view = self.client[-1]
609 view = self.client[-1]
610 ar = view.execute("%pylab inline")
610 ar = view.execute("%pylab inline")
611 # at least check if this raised:
611 # at least check if this raised:
612 reply = ar.get(5)
612 reply = ar.get(5)
613 # include imports, in case user config
613 # include imports, in case user config
614 ar = view.execute("plot(rand(100))", silent=False)
614 ar = view.execute("plot(rand(100))", silent=False)
615 reply = ar.get(5)
615 reply = ar.get(5)
616 self.assertEqual(len(reply.outputs), 1)
616 self.assertEqual(len(reply.outputs), 1)
617 output = reply.outputs[0]
617 output = reply.outputs[0]
618 self.assertTrue("data" in output)
618 self.assertTrue("data" in output)
619 data = output['data']
619 data = output['data']
620 self.assertTrue("image/png" in data)
620 self.assertTrue("image/png" in data)
621
621
622 def test_func_default_func(self):
622 def test_func_default_func(self):
623 """interactively defined function as apply func default"""
623 """interactively defined function as apply func default"""
624 def foo():
624 def foo():
625 return 'foo'
625 return 'foo'
626
626
627 def bar(f=foo):
627 def bar(f=foo):
628 return f()
628 return f()
629
629
630 view = self.client[-1]
630 view = self.client[-1]
631 ar = view.apply_async(bar)
631 ar = view.apply_async(bar)
632 r = ar.get(10)
632 r = ar.get(10)
633 self.assertEqual(r, 'foo')
633 self.assertEqual(r, 'foo')
634 def test_data_pub_single(self):
634 def test_data_pub_single(self):
635 view = self.client[-1]
635 view = self.client[-1]
636 ar = view.execute('\n'.join([
636 ar = view.execute('\n'.join([
637 'from IPython.zmq.datapub import publish_data',
637 'from IPython.zmq.datapub import publish_data',
638 'for i in range(5):',
638 'for i in range(5):',
639 ' publish_data(dict(i=i))'
639 ' publish_data(dict(i=i))'
640 ]), block=False)
640 ]), block=False)
641 self.assertTrue(isinstance(ar.data, dict))
641 self.assertTrue(isinstance(ar.data, dict))
642 ar.get(5)
642 ar.get(5)
643 self.assertEqual(ar.data, dict(i=4))
643 self.assertEqual(ar.data, dict(i=4))
644
644
645 def test_data_pub(self):
645 def test_data_pub(self):
646 view = self.client[:]
646 view = self.client[:]
647 ar = view.execute('\n'.join([
647 ar = view.execute('\n'.join([
648 'from IPython.zmq.datapub import publish_data',
648 'from IPython.zmq.datapub import publish_data',
649 'for i in range(5):',
649 'for i in range(5):',
650 ' publish_data(dict(i=i))'
650 ' publish_data(dict(i=i))'
651 ]), block=False)
651 ]), block=False)
652 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
652 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
653 ar.get(5)
653 ar.get(5)
654 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
654 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
655
655
656 def test_can_list_arg(self):
656 def test_can_list_arg(self):
657 """args in lists are canned"""
657 """args in lists are canned"""
658 view = self.client[-1]
658 view = self.client[-1]
659 view['a'] = 128
659 view['a'] = 128
660 rA = pmod.Reference('a')
660 rA = pmod.Reference('a')
661 ar = view.apply_async(lambda x: x, [rA])
661 ar = view.apply_async(lambda x: x, [rA])
662 r = ar.get(5)
662 r = ar.get(5)
663 self.assertEqual(r, [128])
663 self.assertEqual(r, [128])
664
664
665 def test_can_dict_arg(self):
665 def test_can_dict_arg(self):
666 """args in dicts are canned"""
666 """args in dicts are canned"""
667 view = self.client[-1]
667 view = self.client[-1]
668 view['a'] = 128
668 view['a'] = 128
669 rA = pmod.Reference('a')
669 rA = pmod.Reference('a')
670 ar = view.apply_async(lambda x: x, dict(foo=rA))
670 ar = view.apply_async(lambda x: x, dict(foo=rA))
671 r = ar.get(5)
671 r = ar.get(5)
672 self.assertEqual(r, dict(foo=128))
672 self.assertEqual(r, dict(foo=128))
673
673
674 def test_can_list_kwarg(self):
674 def test_can_list_kwarg(self):
675 """kwargs in lists are canned"""
675 """kwargs in lists are canned"""
676 view = self.client[-1]
676 view = self.client[-1]
677 view['a'] = 128
677 view['a'] = 128
678 rA = pmod.Reference('a')
678 rA = pmod.Reference('a')
679 ar = view.apply_async(lambda x=5: x, x=[rA])
679 ar = view.apply_async(lambda x=5: x, x=[rA])
680 r = ar.get(5)
680 r = ar.get(5)
681 self.assertEqual(r, [128])
681 self.assertEqual(r, [128])
682
682
683 def test_can_dict_kwarg(self):
683 def test_can_dict_kwarg(self):
684 """kwargs in dicts are canned"""
684 """kwargs in dicts are canned"""
685 view = self.client[-1]
685 view = self.client[-1]
686 view['a'] = 128
686 view['a'] = 128
687 rA = pmod.Reference('a')
687 rA = pmod.Reference('a')
688 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
688 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
689 r = ar.get(5)
689 r = ar.get(5)
690 self.assertEqual(r, dict(foo=128))
690 self.assertEqual(r, dict(foo=128))
691
691
692 def test_map_ref(self):
692 def test_map_ref(self):
693 """view.map works with references"""
693 """view.map works with references"""
694 view = self.client[:]
694 view = self.client[:]
695 ranks = sorted(self.client.ids)
695 ranks = sorted(self.client.ids)
696 view.scatter('rank', ranks, flatten=True)
696 view.scatter('rank', ranks, flatten=True)
697 rrank = pmod.Reference('rank')
697 rrank = pmod.Reference('rank')
698
698
699 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
699 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
700 drank = amr.get(5)
700 drank = amr.get(5)
701 self.assertEqual(drank, [ r*2 for r in ranks ])
701 self.assertEqual(drank, [ r*2 for r in ranks ])
702
702
703 def test_nested_getitem_setitem(self):
704 """get and set with view['a.b']"""
705 view = self.client[-1]
706 view.execute('\n'.join([
707 'class A(object): pass',
708 'a = A()',
709 'a.b = 128',
710 ]), block=True)
711 ra = pmod.Reference('a')
712
713 r = view.apply_sync(lambda x: x.b, ra)
714 self.assertEqual(r, 128)
715 self.assertEqual(view['a.b'], 128)
716
717 view['a.b'] = 0
703
718
719 r = view.apply_sync(lambda x: x.b, ra)
720 self.assertEqual(r, 0)
721 self.assertEqual(view['a.b'], 0)
@@ -1,352 +1,355 b''
1 """some generic utilities for dealing with classes, urls, and serialization
1 """some generic utilities for dealing with classes, urls, and serialization
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 # Standard library imports.
18 # Standard library imports.
19 import logging
19 import logging
20 import os
20 import os
21 import re
21 import re
22 import stat
22 import stat
23 import socket
23 import socket
24 import sys
24 import sys
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 try:
26 try:
27 from signal import SIGKILL
27 from signal import SIGKILL
28 except ImportError:
28 except ImportError:
29 SIGKILL=None
29 SIGKILL=None
30
30
31 try:
31 try:
32 import cPickle
32 import cPickle
33 pickle = cPickle
33 pickle = cPickle
34 except:
34 except:
35 cPickle = None
35 cPickle = None
36 import pickle
36 import pickle
37
37
38 # System library imports
38 # System library imports
39 import zmq
39 import zmq
40 from zmq.log import handlers
40 from zmq.log import handlers
41
41
42 from IPython.external.decorator import decorator
42 from IPython.external.decorator import decorator
43
43
44 # IPython imports
44 # IPython imports
45 from IPython.config.application import Application
45 from IPython.config.application import Application
46 from IPython.zmq.log import EnginePUBHandler
46 from IPython.zmq.log import EnginePUBHandler
47 from IPython.zmq.serialize import (
47 from IPython.zmq.serialize import (
48 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
48 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
49 )
49 )
50
50
51 #-----------------------------------------------------------------------------
51 #-----------------------------------------------------------------------------
52 # Classes
52 # Classes
53 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
54
54
55 class Namespace(dict):
55 class Namespace(dict):
56 """Subclass of dict for attribute access to keys."""
56 """Subclass of dict for attribute access to keys."""
57
57
58 def __getattr__(self, key):
58 def __getattr__(self, key):
59 """getattr aliased to getitem"""
59 """getattr aliased to getitem"""
60 if key in self.iterkeys():
60 if key in self.iterkeys():
61 return self[key]
61 return self[key]
62 else:
62 else:
63 raise NameError(key)
63 raise NameError(key)
64
64
65 def __setattr__(self, key, value):
65 def __setattr__(self, key, value):
66 """setattr aliased to setitem, with strict"""
66 """setattr aliased to setitem, with strict"""
67 if hasattr(dict, key):
67 if hasattr(dict, key):
68 raise KeyError("Cannot override dict keys %r"%key)
68 raise KeyError("Cannot override dict keys %r"%key)
69 self[key] = value
69 self[key] = value
70
70
71
71
72 class ReverseDict(dict):
72 class ReverseDict(dict):
73 """simple double-keyed subset of dict methods."""
73 """simple double-keyed subset of dict methods."""
74
74
75 def __init__(self, *args, **kwargs):
75 def __init__(self, *args, **kwargs):
76 dict.__init__(self, *args, **kwargs)
76 dict.__init__(self, *args, **kwargs)
77 self._reverse = dict()
77 self._reverse = dict()
78 for key, value in self.iteritems():
78 for key, value in self.iteritems():
79 self._reverse[value] = key
79 self._reverse[value] = key
80
80
81 def __getitem__(self, key):
81 def __getitem__(self, key):
82 try:
82 try:
83 return dict.__getitem__(self, key)
83 return dict.__getitem__(self, key)
84 except KeyError:
84 except KeyError:
85 return self._reverse[key]
85 return self._reverse[key]
86
86
87 def __setitem__(self, key, value):
87 def __setitem__(self, key, value):
88 if key in self._reverse:
88 if key in self._reverse:
89 raise KeyError("Can't have key %r on both sides!"%key)
89 raise KeyError("Can't have key %r on both sides!"%key)
90 dict.__setitem__(self, key, value)
90 dict.__setitem__(self, key, value)
91 self._reverse[value] = key
91 self._reverse[value] = key
92
92
93 def pop(self, key):
93 def pop(self, key):
94 value = dict.pop(self, key)
94 value = dict.pop(self, key)
95 self._reverse.pop(value)
95 self._reverse.pop(value)
96 return value
96 return value
97
97
98 def get(self, key, default=None):
98 def get(self, key, default=None):
99 try:
99 try:
100 return self[key]
100 return self[key]
101 except KeyError:
101 except KeyError:
102 return default
102 return default
103
103
104 #-----------------------------------------------------------------------------
104 #-----------------------------------------------------------------------------
105 # Functions
105 # Functions
106 #-----------------------------------------------------------------------------
106 #-----------------------------------------------------------------------------
107
107
108 @decorator
108 @decorator
109 def log_errors(f, self, *args, **kwargs):
109 def log_errors(f, self, *args, **kwargs):
110 """decorator to log unhandled exceptions raised in a method.
110 """decorator to log unhandled exceptions raised in a method.
111
111
112 For use wrapping on_recv callbacks, so that exceptions
112 For use wrapping on_recv callbacks, so that exceptions
113 do not cause the stream to be closed.
113 do not cause the stream to be closed.
114 """
114 """
115 try:
115 try:
116 return f(self, *args, **kwargs)
116 return f(self, *args, **kwargs)
117 except Exception:
117 except Exception:
118 self.log.error("Uncaught exception in %r" % f, exc_info=True)
118 self.log.error("Uncaught exception in %r" % f, exc_info=True)
119
119
120
120
121 def is_url(url):
121 def is_url(url):
122 """boolean check for whether a string is a zmq url"""
122 """boolean check for whether a string is a zmq url"""
123 if '://' not in url:
123 if '://' not in url:
124 return False
124 return False
125 proto, addr = url.split('://', 1)
125 proto, addr = url.split('://', 1)
126 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
126 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
127 return False
127 return False
128 return True
128 return True
129
129
130 def validate_url(url):
130 def validate_url(url):
131 """validate a url for zeromq"""
131 """validate a url for zeromq"""
132 if not isinstance(url, basestring):
132 if not isinstance(url, basestring):
133 raise TypeError("url must be a string, not %r"%type(url))
133 raise TypeError("url must be a string, not %r"%type(url))
134 url = url.lower()
134 url = url.lower()
135
135
136 proto_addr = url.split('://')
136 proto_addr = url.split('://')
137 assert len(proto_addr) == 2, 'Invalid url: %r'%url
137 assert len(proto_addr) == 2, 'Invalid url: %r'%url
138 proto, addr = proto_addr
138 proto, addr = proto_addr
139 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
139 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
140
140
141 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
141 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
142 # author: Remi Sabourin
142 # author: Remi Sabourin
143 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
143 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
144
144
145 if proto == 'tcp':
145 if proto == 'tcp':
146 lis = addr.split(':')
146 lis = addr.split(':')
147 assert len(lis) == 2, 'Invalid url: %r'%url
147 assert len(lis) == 2, 'Invalid url: %r'%url
148 addr,s_port = lis
148 addr,s_port = lis
149 try:
149 try:
150 port = int(s_port)
150 port = int(s_port)
151 except ValueError:
151 except ValueError:
152 raise AssertionError("Invalid port %r in url: %r"%(port, url))
152 raise AssertionError("Invalid port %r in url: %r"%(port, url))
153
153
154 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
154 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
155
155
156 else:
156 else:
157 # only validate tcp urls currently
157 # only validate tcp urls currently
158 pass
158 pass
159
159
160 return True
160 return True
161
161
162
162
163 def validate_url_container(container):
163 def validate_url_container(container):
164 """validate a potentially nested collection of urls."""
164 """validate a potentially nested collection of urls."""
165 if isinstance(container, basestring):
165 if isinstance(container, basestring):
166 url = container
166 url = container
167 return validate_url(url)
167 return validate_url(url)
168 elif isinstance(container, dict):
168 elif isinstance(container, dict):
169 container = container.itervalues()
169 container = container.itervalues()
170
170
171 for element in container:
171 for element in container:
172 validate_url_container(element)
172 validate_url_container(element)
173
173
174
174
175 def split_url(url):
175 def split_url(url):
176 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
176 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
177 proto_addr = url.split('://')
177 proto_addr = url.split('://')
178 assert len(proto_addr) == 2, 'Invalid url: %r'%url
178 assert len(proto_addr) == 2, 'Invalid url: %r'%url
179 proto, addr = proto_addr
179 proto, addr = proto_addr
180 lis = addr.split(':')
180 lis = addr.split(':')
181 assert len(lis) == 2, 'Invalid url: %r'%url
181 assert len(lis) == 2, 'Invalid url: %r'%url
182 addr,s_port = lis
182 addr,s_port = lis
183 return proto,addr,s_port
183 return proto,addr,s_port
184
184
185 def disambiguate_ip_address(ip, location=None):
185 def disambiguate_ip_address(ip, location=None):
186 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
186 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
187 ones, based on the location (default interpretation of location is localhost)."""
187 ones, based on the location (default interpretation of location is localhost)."""
188 if ip in ('0.0.0.0', '*'):
188 if ip in ('0.0.0.0', '*'):
189 try:
189 try:
190 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
190 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
191 except (socket.gaierror, IndexError):
191 except (socket.gaierror, IndexError):
192 # couldn't identify this machine, assume localhost
192 # couldn't identify this machine, assume localhost
193 external_ips = []
193 external_ips = []
194 if location is None or location in external_ips or not external_ips:
194 if location is None or location in external_ips or not external_ips:
195 # If location is unspecified or cannot be determined, assume local
195 # If location is unspecified or cannot be determined, assume local
196 ip='127.0.0.1'
196 ip='127.0.0.1'
197 elif location:
197 elif location:
198 return location
198 return location
199 return ip
199 return ip
200
200
201 def disambiguate_url(url, location=None):
201 def disambiguate_url(url, location=None):
202 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
202 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
203 ones, based on the location (default interpretation is localhost).
203 ones, based on the location (default interpretation is localhost).
204
204
205 This is for zeromq urls, such as tcp://*:10101."""
205 This is for zeromq urls, such as tcp://*:10101."""
206 try:
206 try:
207 proto,ip,port = split_url(url)
207 proto,ip,port = split_url(url)
208 except AssertionError:
208 except AssertionError:
209 # probably not tcp url; could be ipc, etc.
209 # probably not tcp url; could be ipc, etc.
210 return url
210 return url
211
211
212 ip = disambiguate_ip_address(ip,location)
212 ip = disambiguate_ip_address(ip,location)
213
213
214 return "%s://%s:%s"%(proto,ip,port)
214 return "%s://%s:%s"%(proto,ip,port)
215
215
216
216
217 #--------------------------------------------------------------------------
217 #--------------------------------------------------------------------------
218 # helpers for implementing old MEC API via view.apply
218 # helpers for implementing old MEC API via view.apply
219 #--------------------------------------------------------------------------
219 #--------------------------------------------------------------------------
220
220
221 def interactive(f):
221 def interactive(f):
222 """decorator for making functions appear as interactively defined.
222 """decorator for making functions appear as interactively defined.
223 This results in the function being linked to the user_ns as globals()
223 This results in the function being linked to the user_ns as globals()
224 instead of the module globals().
224 instead of the module globals().
225 """
225 """
226 f.__module__ = '__main__'
226 f.__module__ = '__main__'
227 return f
227 return f
228
228
229 @interactive
229 @interactive
230 def _push(**ns):
230 def _push(**ns):
231 """helper method for implementing `client.push` via `client.apply`"""
231 """helper method for implementing `client.push` via `client.apply`"""
232 globals().update(ns)
232 user_ns = globals()
233 tmp = '_IP_PUSH_TMP_'
234 while tmp in user_ns:
235 tmp = tmp + '_'
236 try:
237 for name, value in ns.iteritems():
238 user_ns[tmp] = value
239 exec "%s = %s" % (name, tmp) in user_ns
240 finally:
241 user_ns.pop(tmp, None)
233
242
234 @interactive
243 @interactive
235 def _pull(keys):
244 def _pull(keys):
236 """helper method for implementing `client.pull` via `client.apply`"""
245 """helper method for implementing `client.pull` via `client.apply`"""
237 user_ns = globals()
238 if isinstance(keys, (list,tuple, set)):
246 if isinstance(keys, (list,tuple, set)):
239 for key in keys:
247 return map(lambda key: eval(key, globals()), keys)
240 if key not in user_ns:
241 raise NameError("name '%s' is not defined"%key)
242 return map(user_ns.get, keys)
243 else:
248 else:
244 if keys not in user_ns:
249 return eval(keys, globals())
245 raise NameError("name '%s' is not defined"%keys)
246 return user_ns.get(keys)
247
250
248 @interactive
251 @interactive
249 def _execute(code):
252 def _execute(code):
250 """helper method for implementing `client.execute` via `client.apply`"""
253 """helper method for implementing `client.execute` via `client.apply`"""
251 exec code in globals()
254 exec code in globals()
252
255
253 #--------------------------------------------------------------------------
256 #--------------------------------------------------------------------------
254 # extra process management utilities
257 # extra process management utilities
255 #--------------------------------------------------------------------------
258 #--------------------------------------------------------------------------
256
259
257 _random_ports = set()
260 _random_ports = set()
258
261
259 def select_random_ports(n):
262 def select_random_ports(n):
260 """Selects and return n random ports that are available."""
263 """Selects and return n random ports that are available."""
261 ports = []
264 ports = []
262 for i in xrange(n):
265 for i in xrange(n):
263 sock = socket.socket()
266 sock = socket.socket()
264 sock.bind(('', 0))
267 sock.bind(('', 0))
265 while sock.getsockname()[1] in _random_ports:
268 while sock.getsockname()[1] in _random_ports:
266 sock.close()
269 sock.close()
267 sock = socket.socket()
270 sock = socket.socket()
268 sock.bind(('', 0))
271 sock.bind(('', 0))
269 ports.append(sock)
272 ports.append(sock)
270 for i, sock in enumerate(ports):
273 for i, sock in enumerate(ports):
271 port = sock.getsockname()[1]
274 port = sock.getsockname()[1]
272 sock.close()
275 sock.close()
273 ports[i] = port
276 ports[i] = port
274 _random_ports.add(port)
277 _random_ports.add(port)
275 return ports
278 return ports
276
279
277 def signal_children(children):
280 def signal_children(children):
278 """Relay interupt/term signals to children, for more solid process cleanup."""
281 """Relay interupt/term signals to children, for more solid process cleanup."""
279 def terminate_children(sig, frame):
282 def terminate_children(sig, frame):
280 log = Application.instance().log
283 log = Application.instance().log
281 log.critical("Got signal %i, terminating children..."%sig)
284 log.critical("Got signal %i, terminating children..."%sig)
282 for child in children:
285 for child in children:
283 child.terminate()
286 child.terminate()
284
287
285 sys.exit(sig != SIGINT)
288 sys.exit(sig != SIGINT)
286 # sys.exit(sig)
289 # sys.exit(sig)
287 for sig in (SIGINT, SIGABRT, SIGTERM):
290 for sig in (SIGINT, SIGABRT, SIGTERM):
288 signal(sig, terminate_children)
291 signal(sig, terminate_children)
289
292
290 def generate_exec_key(keyfile):
293 def generate_exec_key(keyfile):
291 import uuid
294 import uuid
292 newkey = str(uuid.uuid4())
295 newkey = str(uuid.uuid4())
293 with open(keyfile, 'w') as f:
296 with open(keyfile, 'w') as f:
294 # f.write('ipython-key ')
297 # f.write('ipython-key ')
295 f.write(newkey+'\n')
298 f.write(newkey+'\n')
296 # set user-only RW permissions (0600)
299 # set user-only RW permissions (0600)
297 # this will have no effect on Windows
300 # this will have no effect on Windows
298 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
301 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
299
302
300
303
301 def integer_loglevel(loglevel):
304 def integer_loglevel(loglevel):
302 try:
305 try:
303 loglevel = int(loglevel)
306 loglevel = int(loglevel)
304 except ValueError:
307 except ValueError:
305 if isinstance(loglevel, str):
308 if isinstance(loglevel, str):
306 loglevel = getattr(logging, loglevel)
309 loglevel = getattr(logging, loglevel)
307 return loglevel
310 return loglevel
308
311
309 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
312 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
310 logger = logging.getLogger(logname)
313 logger = logging.getLogger(logname)
311 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
314 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
312 # don't add a second PUBHandler
315 # don't add a second PUBHandler
313 return
316 return
314 loglevel = integer_loglevel(loglevel)
317 loglevel = integer_loglevel(loglevel)
315 lsock = context.socket(zmq.PUB)
318 lsock = context.socket(zmq.PUB)
316 lsock.connect(iface)
319 lsock.connect(iface)
317 handler = handlers.PUBHandler(lsock)
320 handler = handlers.PUBHandler(lsock)
318 handler.setLevel(loglevel)
321 handler.setLevel(loglevel)
319 handler.root_topic = root
322 handler.root_topic = root
320 logger.addHandler(handler)
323 logger.addHandler(handler)
321 logger.setLevel(loglevel)
324 logger.setLevel(loglevel)
322
325
323 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
326 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
324 logger = logging.getLogger()
327 logger = logging.getLogger()
325 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
328 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
326 # don't add a second PUBHandler
329 # don't add a second PUBHandler
327 return
330 return
328 loglevel = integer_loglevel(loglevel)
331 loglevel = integer_loglevel(loglevel)
329 lsock = context.socket(zmq.PUB)
332 lsock = context.socket(zmq.PUB)
330 lsock.connect(iface)
333 lsock.connect(iface)
331 handler = EnginePUBHandler(engine, lsock)
334 handler = EnginePUBHandler(engine, lsock)
332 handler.setLevel(loglevel)
335 handler.setLevel(loglevel)
333 logger.addHandler(handler)
336 logger.addHandler(handler)
334 logger.setLevel(loglevel)
337 logger.setLevel(loglevel)
335 return logger
338 return logger
336
339
337 def local_logger(logname, loglevel=logging.DEBUG):
340 def local_logger(logname, loglevel=logging.DEBUG):
338 loglevel = integer_loglevel(loglevel)
341 loglevel = integer_loglevel(loglevel)
339 logger = logging.getLogger(logname)
342 logger = logging.getLogger(logname)
340 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
343 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
341 # don't add a second StreamHandler
344 # don't add a second StreamHandler
342 return
345 return
343 handler = logging.StreamHandler()
346 handler = logging.StreamHandler()
344 handler.setLevel(loglevel)
347 handler.setLevel(loglevel)
345 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
348 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
346 datefmt="%Y-%m-%d %H:%M:%S")
349 datefmt="%Y-%m-%d %H:%M:%S")
347 handler.setFormatter(formatter)
350 handler.setFormatter(formatter)
348
351
349 logger.addHandler(handler)
352 logger.addHandler(handler)
350 logger.setLevel(loglevel)
353 logger.setLevel(loglevel)
351 return logger
354 return logger
352
355
General Comments 0
You need to be logged in to leave comments. Login now