##// END OF EJS Templates
Parallel: Support get/set of nested objects in view (e.g. dv['a.b'])
Bradley M. Froehle -
Show More
@@ -1,678 +1,696 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import platform
20 import platform
21 import time
21 import time
22 from tempfile import mktemp
22 from tempfile import mktemp
23 from StringIO import StringIO
23 from StringIO import StringIO
24
24
25 import zmq
25 import zmq
26 from nose import SkipTest
26 from nose import SkipTest
27
27
28 from IPython.testing import decorators as dec
28 from IPython.testing import decorators as dec
29 from IPython.testing.ipunittest import ParametricTestCase
29 from IPython.testing.ipunittest import ParametricTestCase
30
30
31 from IPython import parallel as pmod
31 from IPython import parallel as pmod
32 from IPython.parallel import error
32 from IPython.parallel import error
33 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
33 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
34 from IPython.parallel import DirectView
34 from IPython.parallel import DirectView
35 from IPython.parallel.util import interactive
35 from IPython.parallel.util import interactive
36
36
37 from IPython.parallel.tests import add_engines
37 from IPython.parallel.tests import add_engines
38
38
39 from .clienttest import ClusterTestCase, crash, wait, skip_without
39 from .clienttest import ClusterTestCase, crash, wait, skip_without
40
40
41 def setup():
41 def setup():
42 add_engines(3, total=True)
42 add_engines(3, total=True)
43
43
44 class TestView(ClusterTestCase, ParametricTestCase):
44 class TestView(ClusterTestCase, ParametricTestCase):
45
45
46 def setUp(self):
46 def setUp(self):
47 # On Win XP, wait for resource cleanup, else parallel test group fails
47 # On Win XP, wait for resource cleanup, else parallel test group fails
48 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
48 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
49 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
49 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
50 time.sleep(2)
50 time.sleep(2)
51 super(TestView, self).setUp()
51 super(TestView, self).setUp()
52
52
53 def test_z_crash_mux(self):
53 def test_z_crash_mux(self):
54 """test graceful handling of engine death (direct)"""
54 """test graceful handling of engine death (direct)"""
55 raise SkipTest("crash tests disabled, due to undesirable crash reports")
55 raise SkipTest("crash tests disabled, due to undesirable crash reports")
56 # self.add_engines(1)
56 # self.add_engines(1)
57 eid = self.client.ids[-1]
57 eid = self.client.ids[-1]
58 ar = self.client[eid].apply_async(crash)
58 ar = self.client[eid].apply_async(crash)
59 self.assertRaisesRemote(error.EngineError, ar.get, 10)
59 self.assertRaisesRemote(error.EngineError, ar.get, 10)
60 eid = ar.engine_id
60 eid = ar.engine_id
61 tic = time.time()
61 tic = time.time()
62 while eid in self.client.ids and time.time()-tic < 5:
62 while eid in self.client.ids and time.time()-tic < 5:
63 time.sleep(.01)
63 time.sleep(.01)
64 self.client.spin()
64 self.client.spin()
65 self.assertFalse(eid in self.client.ids, "Engine should have died")
65 self.assertFalse(eid in self.client.ids, "Engine should have died")
66
66
67 def test_push_pull(self):
67 def test_push_pull(self):
68 """test pushing and pulling"""
68 """test pushing and pulling"""
69 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
69 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
70 t = self.client.ids[-1]
70 t = self.client.ids[-1]
71 v = self.client[t]
71 v = self.client[t]
72 push = v.push
72 push = v.push
73 pull = v.pull
73 pull = v.pull
74 v.block=True
74 v.block=True
75 nengines = len(self.client)
75 nengines = len(self.client)
76 push({'data':data})
76 push({'data':data})
77 d = pull('data')
77 d = pull('data')
78 self.assertEqual(d, data)
78 self.assertEqual(d, data)
79 self.client[:].push({'data':data})
79 self.client[:].push({'data':data})
80 d = self.client[:].pull('data', block=True)
80 d = self.client[:].pull('data', block=True)
81 self.assertEqual(d, nengines*[data])
81 self.assertEqual(d, nengines*[data])
82 ar = push({'data':data}, block=False)
82 ar = push({'data':data}, block=False)
83 self.assertTrue(isinstance(ar, AsyncResult))
83 self.assertTrue(isinstance(ar, AsyncResult))
84 r = ar.get()
84 r = ar.get()
85 ar = self.client[:].pull('data', block=False)
85 ar = self.client[:].pull('data', block=False)
86 self.assertTrue(isinstance(ar, AsyncResult))
86 self.assertTrue(isinstance(ar, AsyncResult))
87 r = ar.get()
87 r = ar.get()
88 self.assertEqual(r, nengines*[data])
88 self.assertEqual(r, nengines*[data])
89 self.client[:].push(dict(a=10,b=20))
89 self.client[:].push(dict(a=10,b=20))
90 r = self.client[:].pull(('a','b'), block=True)
90 r = self.client[:].pull(('a','b'), block=True)
91 self.assertEqual(r, nengines*[[10,20]])
91 self.assertEqual(r, nengines*[[10,20]])
92
92
93 def test_push_pull_function(self):
93 def test_push_pull_function(self):
94 "test pushing and pulling functions"
94 "test pushing and pulling functions"
95 def testf(x):
95 def testf(x):
96 return 2.0*x
96 return 2.0*x
97
97
98 t = self.client.ids[-1]
98 t = self.client.ids[-1]
99 v = self.client[t]
99 v = self.client[t]
100 v.block=True
100 v.block=True
101 push = v.push
101 push = v.push
102 pull = v.pull
102 pull = v.pull
103 execute = v.execute
103 execute = v.execute
104 push({'testf':testf})
104 push({'testf':testf})
105 r = pull('testf')
105 r = pull('testf')
106 self.assertEqual(r(1.0), testf(1.0))
106 self.assertEqual(r(1.0), testf(1.0))
107 execute('r = testf(10)')
107 execute('r = testf(10)')
108 r = pull('r')
108 r = pull('r')
109 self.assertEqual(r, testf(10))
109 self.assertEqual(r, testf(10))
110 ar = self.client[:].push({'testf':testf}, block=False)
110 ar = self.client[:].push({'testf':testf}, block=False)
111 ar.get()
111 ar.get()
112 ar = self.client[:].pull('testf', block=False)
112 ar = self.client[:].pull('testf', block=False)
113 rlist = ar.get()
113 rlist = ar.get()
114 for r in rlist:
114 for r in rlist:
115 self.assertEqual(r(1.0), testf(1.0))
115 self.assertEqual(r(1.0), testf(1.0))
116 execute("def g(x): return x*x")
116 execute("def g(x): return x*x")
117 r = pull(('testf','g'))
117 r = pull(('testf','g'))
118 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
118 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
119
119
120 def test_push_function_globals(self):
120 def test_push_function_globals(self):
121 """test that pushed functions have access to globals"""
121 """test that pushed functions have access to globals"""
122 @interactive
122 @interactive
123 def geta():
123 def geta():
124 return a
124 return a
125 # self.add_engines(1)
125 # self.add_engines(1)
126 v = self.client[-1]
126 v = self.client[-1]
127 v.block=True
127 v.block=True
128 v['f'] = geta
128 v['f'] = geta
129 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
129 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
130 v.execute('a=5')
130 v.execute('a=5')
131 v.execute('b=f()')
131 v.execute('b=f()')
132 self.assertEqual(v['b'], 5)
132 self.assertEqual(v['b'], 5)
133
133
134 def test_push_function_defaults(self):
134 def test_push_function_defaults(self):
135 """test that pushed functions preserve default args"""
135 """test that pushed functions preserve default args"""
136 def echo(a=10):
136 def echo(a=10):
137 return a
137 return a
138 v = self.client[-1]
138 v = self.client[-1]
139 v.block=True
139 v.block=True
140 v['f'] = echo
140 v['f'] = echo
141 v.execute('b=f()')
141 v.execute('b=f()')
142 self.assertEqual(v['b'], 10)
142 self.assertEqual(v['b'], 10)
143
143
144 def test_get_result(self):
144 def test_get_result(self):
145 """test getting results from the Hub."""
145 """test getting results from the Hub."""
146 c = pmod.Client(profile='iptest')
146 c = pmod.Client(profile='iptest')
147 # self.add_engines(1)
147 # self.add_engines(1)
148 t = c.ids[-1]
148 t = c.ids[-1]
149 v = c[t]
149 v = c[t]
150 v2 = self.client[t]
150 v2 = self.client[t]
151 ar = v.apply_async(wait, 1)
151 ar = v.apply_async(wait, 1)
152 # give the monitor time to notice the message
152 # give the monitor time to notice the message
153 time.sleep(.25)
153 time.sleep(.25)
154 ahr = v2.get_result(ar.msg_ids)
154 ahr = v2.get_result(ar.msg_ids)
155 self.assertTrue(isinstance(ahr, AsyncHubResult))
155 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 self.assertEqual(ahr.get(), ar.get())
156 self.assertEqual(ahr.get(), ar.get())
157 ar2 = v2.get_result(ar.msg_ids)
157 ar2 = v2.get_result(ar.msg_ids)
158 self.assertFalse(isinstance(ar2, AsyncHubResult))
158 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 c.spin()
159 c.spin()
160 c.close()
160 c.close()
161
161
162 def test_run_newline(self):
162 def test_run_newline(self):
163 """test that run appends newline to files"""
163 """test that run appends newline to files"""
164 tmpfile = mktemp()
164 tmpfile = mktemp()
165 with open(tmpfile, 'w') as f:
165 with open(tmpfile, 'w') as f:
166 f.write("""def g():
166 f.write("""def g():
167 return 5
167 return 5
168 """)
168 """)
169 v = self.client[-1]
169 v = self.client[-1]
170 v.run(tmpfile, block=True)
170 v.run(tmpfile, block=True)
171 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
171 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
172
172
173 def test_apply_tracked(self):
173 def test_apply_tracked(self):
174 """test tracking for apply"""
174 """test tracking for apply"""
175 # self.add_engines(1)
175 # self.add_engines(1)
176 t = self.client.ids[-1]
176 t = self.client.ids[-1]
177 v = self.client[t]
177 v = self.client[t]
178 v.block=False
178 v.block=False
179 def echo(n=1024*1024, **kwargs):
179 def echo(n=1024*1024, **kwargs):
180 with v.temp_flags(**kwargs):
180 with v.temp_flags(**kwargs):
181 return v.apply(lambda x: x, 'x'*n)
181 return v.apply(lambda x: x, 'x'*n)
182 ar = echo(1, track=False)
182 ar = echo(1, track=False)
183 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
183 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
184 self.assertTrue(ar.sent)
184 self.assertTrue(ar.sent)
185 ar = echo(track=True)
185 ar = echo(track=True)
186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertEqual(ar.sent, ar._tracker.done)
187 self.assertEqual(ar.sent, ar._tracker.done)
188 ar._tracker.wait()
188 ar._tracker.wait()
189 self.assertTrue(ar.sent)
189 self.assertTrue(ar.sent)
190
190
191 def test_push_tracked(self):
191 def test_push_tracked(self):
192 t = self.client.ids[-1]
192 t = self.client.ids[-1]
193 ns = dict(x='x'*1024*1024)
193 ns = dict(x='x'*1024*1024)
194 v = self.client[t]
194 v = self.client[t]
195 ar = v.push(ns, block=False, track=False)
195 ar = v.push(ns, block=False, track=False)
196 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
196 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
197 self.assertTrue(ar.sent)
197 self.assertTrue(ar.sent)
198
198
199 ar = v.push(ns, block=False, track=True)
199 ar = v.push(ns, block=False, track=True)
200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 ar._tracker.wait()
201 ar._tracker.wait()
202 self.assertEqual(ar.sent, ar._tracker.done)
202 self.assertEqual(ar.sent, ar._tracker.done)
203 self.assertTrue(ar.sent)
203 self.assertTrue(ar.sent)
204 ar.get()
204 ar.get()
205
205
206 def test_scatter_tracked(self):
206 def test_scatter_tracked(self):
207 t = self.client.ids
207 t = self.client.ids
208 x='x'*1024*1024
208 x='x'*1024*1024
209 ar = self.client[t].scatter('x', x, block=False, track=False)
209 ar = self.client[t].scatter('x', x, block=False, track=False)
210 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
210 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
211 self.assertTrue(ar.sent)
211 self.assertTrue(ar.sent)
212
212
213 ar = self.client[t].scatter('x', x, block=False, track=True)
213 ar = self.client[t].scatter('x', x, block=False, track=True)
214 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
214 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
215 self.assertEqual(ar.sent, ar._tracker.done)
215 self.assertEqual(ar.sent, ar._tracker.done)
216 ar._tracker.wait()
216 ar._tracker.wait()
217 self.assertTrue(ar.sent)
217 self.assertTrue(ar.sent)
218 ar.get()
218 ar.get()
219
219
220 def test_remote_reference(self):
220 def test_remote_reference(self):
221 v = self.client[-1]
221 v = self.client[-1]
222 v['a'] = 123
222 v['a'] = 123
223 ra = pmod.Reference('a')
223 ra = pmod.Reference('a')
224 b = v.apply_sync(lambda x: x, ra)
224 b = v.apply_sync(lambda x: x, ra)
225 self.assertEqual(b, 123)
225 self.assertEqual(b, 123)
226
226
227
227
228 def test_scatter_gather(self):
228 def test_scatter_gather(self):
229 view = self.client[:]
229 view = self.client[:]
230 seq1 = range(16)
230 seq1 = range(16)
231 view.scatter('a', seq1)
231 view.scatter('a', seq1)
232 seq2 = view.gather('a', block=True)
232 seq2 = view.gather('a', block=True)
233 self.assertEqual(seq2, seq1)
233 self.assertEqual(seq2, seq1)
234 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
234 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
235
235
236 @skip_without('numpy')
236 @skip_without('numpy')
237 def test_scatter_gather_numpy(self):
237 def test_scatter_gather_numpy(self):
238 import numpy
238 import numpy
239 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
239 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
240 view = self.client[:]
240 view = self.client[:]
241 a = numpy.arange(64)
241 a = numpy.arange(64)
242 view.scatter('a', a, block=True)
242 view.scatter('a', a, block=True)
243 b = view.gather('a', block=True)
243 b = view.gather('a', block=True)
244 assert_array_equal(b, a)
244 assert_array_equal(b, a)
245
245
246 def test_scatter_gather_lazy(self):
246 def test_scatter_gather_lazy(self):
247 """scatter/gather with targets='all'"""
247 """scatter/gather with targets='all'"""
248 view = self.client.direct_view(targets='all')
248 view = self.client.direct_view(targets='all')
249 x = range(64)
249 x = range(64)
250 view.scatter('x', x)
250 view.scatter('x', x)
251 gathered = view.gather('x', block=True)
251 gathered = view.gather('x', block=True)
252 self.assertEqual(gathered, x)
252 self.assertEqual(gathered, x)
253
253
254
254
255 @dec.known_failure_py3
255 @dec.known_failure_py3
256 @skip_without('numpy')
256 @skip_without('numpy')
257 def test_push_numpy_nocopy(self):
257 def test_push_numpy_nocopy(self):
258 import numpy
258 import numpy
259 view = self.client[:]
259 view = self.client[:]
260 a = numpy.arange(64)
260 a = numpy.arange(64)
261 view['A'] = a
261 view['A'] = a
262 @interactive
262 @interactive
263 def check_writeable(x):
263 def check_writeable(x):
264 return x.flags.writeable
264 return x.flags.writeable
265
265
266 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
266 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
267 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
267 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
268
268
269 view.push(dict(B=a))
269 view.push(dict(B=a))
270 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
270 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
271 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
271 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
272
272
273 @skip_without('numpy')
273 @skip_without('numpy')
274 def test_apply_numpy(self):
274 def test_apply_numpy(self):
275 """view.apply(f, ndarray)"""
275 """view.apply(f, ndarray)"""
276 import numpy
276 import numpy
277 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
277 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
278
278
279 A = numpy.random.random((100,100))
279 A = numpy.random.random((100,100))
280 view = self.client[-1]
280 view = self.client[-1]
281 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
281 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
282 B = A.astype(dt)
282 B = A.astype(dt)
283 C = view.apply_sync(lambda x:x, B)
283 C = view.apply_sync(lambda x:x, B)
284 assert_array_equal(B,C)
284 assert_array_equal(B,C)
285
285
286 @skip_without('numpy')
286 @skip_without('numpy')
287 def test_push_pull_recarray(self):
287 def test_push_pull_recarray(self):
288 """push/pull recarrays"""
288 """push/pull recarrays"""
289 import numpy
289 import numpy
290 from numpy.testing.utils import assert_array_equal
290 from numpy.testing.utils import assert_array_equal
291
291
292 view = self.client[-1]
292 view = self.client[-1]
293
293
294 R = numpy.array([
294 R = numpy.array([
295 (1, 'hi', 0.),
295 (1, 'hi', 0.),
296 (2**30, 'there', 2.5),
296 (2**30, 'there', 2.5),
297 (-99999, 'world', -12345.6789),
297 (-99999, 'world', -12345.6789),
298 ], [('n', int), ('s', '|S10'), ('f', float)])
298 ], [('n', int), ('s', '|S10'), ('f', float)])
299
299
300 view['RR'] = R
300 view['RR'] = R
301 R2 = view['RR']
301 R2 = view['RR']
302
302
303 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
303 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
304 self.assertEqual(r_dtype, R.dtype)
304 self.assertEqual(r_dtype, R.dtype)
305 self.assertEqual(r_shape, R.shape)
305 self.assertEqual(r_shape, R.shape)
306 self.assertEqual(R2.dtype, R.dtype)
306 self.assertEqual(R2.dtype, R.dtype)
307 self.assertEqual(R2.shape, R.shape)
307 self.assertEqual(R2.shape, R.shape)
308 assert_array_equal(R2, R)
308 assert_array_equal(R2, R)
309
309
310 def test_map(self):
310 def test_map(self):
311 view = self.client[:]
311 view = self.client[:]
312 def f(x):
312 def f(x):
313 return x**2
313 return x**2
314 data = range(16)
314 data = range(16)
315 r = view.map_sync(f, data)
315 r = view.map_sync(f, data)
316 self.assertEqual(r, map(f, data))
316 self.assertEqual(r, map(f, data))
317
317
318 def test_map_iterable(self):
318 def test_map_iterable(self):
319 """test map on iterables (direct)"""
319 """test map on iterables (direct)"""
320 view = self.client[:]
320 view = self.client[:]
321 # 101 is prime, so it won't be evenly distributed
321 # 101 is prime, so it won't be evenly distributed
322 arr = range(101)
322 arr = range(101)
323 # ensure it will be an iterator, even in Python 3
323 # ensure it will be an iterator, even in Python 3
324 it = iter(arr)
324 it = iter(arr)
325 r = view.map_sync(lambda x:x, arr)
325 r = view.map_sync(lambda x:x, arr)
326 self.assertEqual(r, list(arr))
326 self.assertEqual(r, list(arr))
327
327
328 def test_scatter_gather_nonblocking(self):
328 def test_scatter_gather_nonblocking(self):
329 data = range(16)
329 data = range(16)
330 view = self.client[:]
330 view = self.client[:]
331 view.scatter('a', data, block=False)
331 view.scatter('a', data, block=False)
332 ar = view.gather('a', block=False)
332 ar = view.gather('a', block=False)
333 self.assertEqual(ar.get(), data)
333 self.assertEqual(ar.get(), data)
334
334
335 @skip_without('numpy')
335 @skip_without('numpy')
336 def test_scatter_gather_numpy_nonblocking(self):
336 def test_scatter_gather_numpy_nonblocking(self):
337 import numpy
337 import numpy
338 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
338 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
339 a = numpy.arange(64)
339 a = numpy.arange(64)
340 view = self.client[:]
340 view = self.client[:]
341 ar = view.scatter('a', a, block=False)
341 ar = view.scatter('a', a, block=False)
342 self.assertTrue(isinstance(ar, AsyncResult))
342 self.assertTrue(isinstance(ar, AsyncResult))
343 amr = view.gather('a', block=False)
343 amr = view.gather('a', block=False)
344 self.assertTrue(isinstance(amr, AsyncMapResult))
344 self.assertTrue(isinstance(amr, AsyncMapResult))
345 assert_array_equal(amr.get(), a)
345 assert_array_equal(amr.get(), a)
346
346
347 def test_execute(self):
347 def test_execute(self):
348 view = self.client[:]
348 view = self.client[:]
349 # self.client.debug=True
349 # self.client.debug=True
350 execute = view.execute
350 execute = view.execute
351 ar = execute('c=30', block=False)
351 ar = execute('c=30', block=False)
352 self.assertTrue(isinstance(ar, AsyncResult))
352 self.assertTrue(isinstance(ar, AsyncResult))
353 ar = execute('d=[0,1,2]', block=False)
353 ar = execute('d=[0,1,2]', block=False)
354 self.client.wait(ar, 1)
354 self.client.wait(ar, 1)
355 self.assertEqual(len(ar.get()), len(self.client))
355 self.assertEqual(len(ar.get()), len(self.client))
356 for c in view['c']:
356 for c in view['c']:
357 self.assertEqual(c, 30)
357 self.assertEqual(c, 30)
358
358
359 def test_abort(self):
359 def test_abort(self):
360 view = self.client[-1]
360 view = self.client[-1]
361 ar = view.execute('import time; time.sleep(1)', block=False)
361 ar = view.execute('import time; time.sleep(1)', block=False)
362 ar2 = view.apply_async(lambda : 2)
362 ar2 = view.apply_async(lambda : 2)
363 ar3 = view.apply_async(lambda : 3)
363 ar3 = view.apply_async(lambda : 3)
364 view.abort(ar2)
364 view.abort(ar2)
365 view.abort(ar3.msg_ids)
365 view.abort(ar3.msg_ids)
366 self.assertRaises(error.TaskAborted, ar2.get)
366 self.assertRaises(error.TaskAborted, ar2.get)
367 self.assertRaises(error.TaskAborted, ar3.get)
367 self.assertRaises(error.TaskAborted, ar3.get)
368
368
369 def test_abort_all(self):
369 def test_abort_all(self):
370 """view.abort() aborts all outstanding tasks"""
370 """view.abort() aborts all outstanding tasks"""
371 view = self.client[-1]
371 view = self.client[-1]
372 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
372 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
373 view.abort()
373 view.abort()
374 view.wait(timeout=5)
374 view.wait(timeout=5)
375 for ar in ars[5:]:
375 for ar in ars[5:]:
376 self.assertRaises(error.TaskAborted, ar.get)
376 self.assertRaises(error.TaskAborted, ar.get)
377
377
378 def test_temp_flags(self):
378 def test_temp_flags(self):
379 view = self.client[-1]
379 view = self.client[-1]
380 view.block=True
380 view.block=True
381 with view.temp_flags(block=False):
381 with view.temp_flags(block=False):
382 self.assertFalse(view.block)
382 self.assertFalse(view.block)
383 self.assertTrue(view.block)
383 self.assertTrue(view.block)
384
384
385 @dec.known_failure_py3
385 @dec.known_failure_py3
386 def test_importer(self):
386 def test_importer(self):
387 view = self.client[-1]
387 view = self.client[-1]
388 view.clear(block=True)
388 view.clear(block=True)
389 with view.importer:
389 with view.importer:
390 import re
390 import re
391
391
392 @interactive
392 @interactive
393 def findall(pat, s):
393 def findall(pat, s):
394 # this globals() step isn't necessary in real code
394 # this globals() step isn't necessary in real code
395 # only to prevent a closure in the test
395 # only to prevent a closure in the test
396 re = globals()['re']
396 re = globals()['re']
397 return re.findall(pat, s)
397 return re.findall(pat, s)
398
398
399 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
399 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
400
400
401 def test_unicode_execute(self):
401 def test_unicode_execute(self):
402 """test executing unicode strings"""
402 """test executing unicode strings"""
403 v = self.client[-1]
403 v = self.client[-1]
404 v.block=True
404 v.block=True
405 if sys.version_info[0] >= 3:
405 if sys.version_info[0] >= 3:
406 code="a='é'"
406 code="a='é'"
407 else:
407 else:
408 code=u"a=u'é'"
408 code=u"a=u'é'"
409 v.execute(code)
409 v.execute(code)
410 self.assertEqual(v['a'], u'é')
410 self.assertEqual(v['a'], u'é')
411
411
412 def test_unicode_apply_result(self):
412 def test_unicode_apply_result(self):
413 """test unicode apply results"""
413 """test unicode apply results"""
414 v = self.client[-1]
414 v = self.client[-1]
415 r = v.apply_sync(lambda : u'é')
415 r = v.apply_sync(lambda : u'é')
416 self.assertEqual(r, u'é')
416 self.assertEqual(r, u'é')
417
417
418 def test_unicode_apply_arg(self):
418 def test_unicode_apply_arg(self):
419 """test passing unicode arguments to apply"""
419 """test passing unicode arguments to apply"""
420 v = self.client[-1]
420 v = self.client[-1]
421
421
422 @interactive
422 @interactive
423 def check_unicode(a, check):
423 def check_unicode(a, check):
424 assert isinstance(a, unicode), "%r is not unicode"%a
424 assert isinstance(a, unicode), "%r is not unicode"%a
425 assert isinstance(check, bytes), "%r is not bytes"%check
425 assert isinstance(check, bytes), "%r is not bytes"%check
426 assert a.encode('utf8') == check, "%s != %s"%(a,check)
426 assert a.encode('utf8') == check, "%s != %s"%(a,check)
427
427
428 for s in [ u'é', u'ßø®∫',u'asdf' ]:
428 for s in [ u'é', u'ßø®∫',u'asdf' ]:
429 try:
429 try:
430 v.apply_sync(check_unicode, s, s.encode('utf8'))
430 v.apply_sync(check_unicode, s, s.encode('utf8'))
431 except error.RemoteError as e:
431 except error.RemoteError as e:
432 if e.ename == 'AssertionError':
432 if e.ename == 'AssertionError':
433 self.fail(e.evalue)
433 self.fail(e.evalue)
434 else:
434 else:
435 raise e
435 raise e
436
436
437 def test_map_reference(self):
437 def test_map_reference(self):
438 """view.map(<Reference>, *seqs) should work"""
438 """view.map(<Reference>, *seqs) should work"""
439 v = self.client[:]
439 v = self.client[:]
440 v.scatter('n', self.client.ids, flatten=True)
440 v.scatter('n', self.client.ids, flatten=True)
441 v.execute("f = lambda x,y: x*y")
441 v.execute("f = lambda x,y: x*y")
442 rf = pmod.Reference('f')
442 rf = pmod.Reference('f')
443 nlist = list(range(10))
443 nlist = list(range(10))
444 mlist = nlist[::-1]
444 mlist = nlist[::-1]
445 expected = [ m*n for m,n in zip(mlist, nlist) ]
445 expected = [ m*n for m,n in zip(mlist, nlist) ]
446 result = v.map_sync(rf, mlist, nlist)
446 result = v.map_sync(rf, mlist, nlist)
447 self.assertEqual(result, expected)
447 self.assertEqual(result, expected)
448
448
449 def test_apply_reference(self):
449 def test_apply_reference(self):
450 """view.apply(<Reference>, *args) should work"""
450 """view.apply(<Reference>, *args) should work"""
451 v = self.client[:]
451 v = self.client[:]
452 v.scatter('n', self.client.ids, flatten=True)
452 v.scatter('n', self.client.ids, flatten=True)
453 v.execute("f = lambda x: n*x")
453 v.execute("f = lambda x: n*x")
454 rf = pmod.Reference('f')
454 rf = pmod.Reference('f')
455 result = v.apply_sync(rf, 5)
455 result = v.apply_sync(rf, 5)
456 expected = [ 5*id for id in self.client.ids ]
456 expected = [ 5*id for id in self.client.ids ]
457 self.assertEqual(result, expected)
457 self.assertEqual(result, expected)
458
458
459 def test_eval_reference(self):
459 def test_eval_reference(self):
460 v = self.client[self.client.ids[0]]
460 v = self.client[self.client.ids[0]]
461 v['g'] = range(5)
461 v['g'] = range(5)
462 rg = pmod.Reference('g[0]')
462 rg = pmod.Reference('g[0]')
463 echo = lambda x:x
463 echo = lambda x:x
464 self.assertEqual(v.apply_sync(echo, rg), 0)
464 self.assertEqual(v.apply_sync(echo, rg), 0)
465
465
466 def test_reference_nameerror(self):
466 def test_reference_nameerror(self):
467 v = self.client[self.client.ids[0]]
467 v = self.client[self.client.ids[0]]
468 r = pmod.Reference('elvis_has_left')
468 r = pmod.Reference('elvis_has_left')
469 echo = lambda x:x
469 echo = lambda x:x
470 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
470 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
471
471
472 def test_single_engine_map(self):
472 def test_single_engine_map(self):
473 e0 = self.client[self.client.ids[0]]
473 e0 = self.client[self.client.ids[0]]
474 r = range(5)
474 r = range(5)
475 check = [ -1*i for i in r ]
475 check = [ -1*i for i in r ]
476 result = e0.map_sync(lambda x: -1*x, r)
476 result = e0.map_sync(lambda x: -1*x, r)
477 self.assertEqual(result, check)
477 self.assertEqual(result, check)
478
478
479 def test_len(self):
479 def test_len(self):
480 """len(view) makes sense"""
480 """len(view) makes sense"""
481 e0 = self.client[self.client.ids[0]]
481 e0 = self.client[self.client.ids[0]]
482 yield self.assertEqual(len(e0), 1)
482 yield self.assertEqual(len(e0), 1)
483 v = self.client[:]
483 v = self.client[:]
484 yield self.assertEqual(len(v), len(self.client.ids))
484 yield self.assertEqual(len(v), len(self.client.ids))
485 v = self.client.direct_view('all')
485 v = self.client.direct_view('all')
486 yield self.assertEqual(len(v), len(self.client.ids))
486 yield self.assertEqual(len(v), len(self.client.ids))
487 v = self.client[:2]
487 v = self.client[:2]
488 yield self.assertEqual(len(v), 2)
488 yield self.assertEqual(len(v), 2)
489 v = self.client[:1]
489 v = self.client[:1]
490 yield self.assertEqual(len(v), 1)
490 yield self.assertEqual(len(v), 1)
491 v = self.client.load_balanced_view()
491 v = self.client.load_balanced_view()
492 yield self.assertEqual(len(v), len(self.client.ids))
492 yield self.assertEqual(len(v), len(self.client.ids))
493 # parametric tests seem to require manual closing?
493 # parametric tests seem to require manual closing?
494 self.client.close()
494 self.client.close()
495
495
496
496
497 # begin execute tests
497 # begin execute tests
498
498
499 def test_execute_reply(self):
499 def test_execute_reply(self):
500 e0 = self.client[self.client.ids[0]]
500 e0 = self.client[self.client.ids[0]]
501 e0.block = True
501 e0.block = True
502 ar = e0.execute("5", silent=False)
502 ar = e0.execute("5", silent=False)
503 er = ar.get()
503 er = ar.get()
504 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
504 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
505 self.assertEqual(er.pyout['data']['text/plain'], '5')
505 self.assertEqual(er.pyout['data']['text/plain'], '5')
506
506
507 def test_execute_reply_stdout(self):
507 def test_execute_reply_stdout(self):
508 e0 = self.client[self.client.ids[0]]
508 e0 = self.client[self.client.ids[0]]
509 e0.block = True
509 e0.block = True
510 ar = e0.execute("print (5)", silent=False)
510 ar = e0.execute("print (5)", silent=False)
511 er = ar.get()
511 er = ar.get()
512 self.assertEqual(er.stdout.strip(), '5')
512 self.assertEqual(er.stdout.strip(), '5')
513
513
514 def test_execute_pyout(self):
514 def test_execute_pyout(self):
515 """execute triggers pyout with silent=False"""
515 """execute triggers pyout with silent=False"""
516 view = self.client[:]
516 view = self.client[:]
517 ar = view.execute("5", silent=False, block=True)
517 ar = view.execute("5", silent=False, block=True)
518
518
519 expected = [{'text/plain' : '5'}] * len(view)
519 expected = [{'text/plain' : '5'}] * len(view)
520 mimes = [ out['data'] for out in ar.pyout ]
520 mimes = [ out['data'] for out in ar.pyout ]
521 self.assertEqual(mimes, expected)
521 self.assertEqual(mimes, expected)
522
522
523 def test_execute_silent(self):
523 def test_execute_silent(self):
524 """execute does not trigger pyout with silent=True"""
524 """execute does not trigger pyout with silent=True"""
525 view = self.client[:]
525 view = self.client[:]
526 ar = view.execute("5", block=True)
526 ar = view.execute("5", block=True)
527 expected = [None] * len(view)
527 expected = [None] * len(view)
528 self.assertEqual(ar.pyout, expected)
528 self.assertEqual(ar.pyout, expected)
529
529
530 def test_execute_magic(self):
530 def test_execute_magic(self):
531 """execute accepts IPython commands"""
531 """execute accepts IPython commands"""
532 view = self.client[:]
532 view = self.client[:]
533 view.execute("a = 5")
533 view.execute("a = 5")
534 ar = view.execute("%whos", block=True)
534 ar = view.execute("%whos", block=True)
535 # this will raise, if that failed
535 # this will raise, if that failed
536 ar.get(5)
536 ar.get(5)
537 for stdout in ar.stdout:
537 for stdout in ar.stdout:
538 lines = stdout.splitlines()
538 lines = stdout.splitlines()
539 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
539 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
540 found = False
540 found = False
541 for line in lines[2:]:
541 for line in lines[2:]:
542 split = line.split()
542 split = line.split()
543 if split == ['a', 'int', '5']:
543 if split == ['a', 'int', '5']:
544 found = True
544 found = True
545 break
545 break
546 self.assertTrue(found, "whos output wrong: %s" % stdout)
546 self.assertTrue(found, "whos output wrong: %s" % stdout)
547
547
548 def test_execute_displaypub(self):
548 def test_execute_displaypub(self):
549 """execute tracks display_pub output"""
549 """execute tracks display_pub output"""
550 view = self.client[:]
550 view = self.client[:]
551 view.execute("from IPython.core.display import *")
551 view.execute("from IPython.core.display import *")
552 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
552 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
553
553
554 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
554 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
555 for outputs in ar.outputs:
555 for outputs in ar.outputs:
556 mimes = [ out['data'] for out in outputs ]
556 mimes = [ out['data'] for out in outputs ]
557 self.assertEqual(mimes, expected)
557 self.assertEqual(mimes, expected)
558
558
559 def test_apply_displaypub(self):
559 def test_apply_displaypub(self):
560 """apply tracks display_pub output"""
560 """apply tracks display_pub output"""
561 view = self.client[:]
561 view = self.client[:]
562 view.execute("from IPython.core.display import *")
562 view.execute("from IPython.core.display import *")
563
563
564 @interactive
564 @interactive
565 def publish():
565 def publish():
566 [ display(i) for i in range(5) ]
566 [ display(i) for i in range(5) ]
567
567
568 ar = view.apply_async(publish)
568 ar = view.apply_async(publish)
569 ar.get(5)
569 ar.get(5)
570 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
570 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
571 for outputs in ar.outputs:
571 for outputs in ar.outputs:
572 mimes = [ out['data'] for out in outputs ]
572 mimes = [ out['data'] for out in outputs ]
573 self.assertEqual(mimes, expected)
573 self.assertEqual(mimes, expected)
574
574
575 def test_execute_raises(self):
575 def test_execute_raises(self):
576 """exceptions in execute requests raise appropriately"""
576 """exceptions in execute requests raise appropriately"""
577 view = self.client[-1]
577 view = self.client[-1]
578 ar = view.execute("1/0")
578 ar = view.execute("1/0")
579 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
579 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
580
580
581 @dec.skipif_not_matplotlib
581 @dec.skipif_not_matplotlib
582 def test_magic_pylab(self):
582 def test_magic_pylab(self):
583 """%pylab works on engines"""
583 """%pylab works on engines"""
584 view = self.client[-1]
584 view = self.client[-1]
585 ar = view.execute("%pylab inline")
585 ar = view.execute("%pylab inline")
586 # at least check if this raised:
586 # at least check if this raised:
587 reply = ar.get(5)
587 reply = ar.get(5)
588 # include imports, in case user config
588 # include imports, in case user config
589 ar = view.execute("plot(rand(100))", silent=False)
589 ar = view.execute("plot(rand(100))", silent=False)
590 reply = ar.get(5)
590 reply = ar.get(5)
591 self.assertEqual(len(reply.outputs), 1)
591 self.assertEqual(len(reply.outputs), 1)
592 output = reply.outputs[0]
592 output = reply.outputs[0]
593 self.assertTrue("data" in output)
593 self.assertTrue("data" in output)
594 data = output['data']
594 data = output['data']
595 self.assertTrue("image/png" in data)
595 self.assertTrue("image/png" in data)
596
596
597 def test_func_default_func(self):
597 def test_func_default_func(self):
598 """interactively defined function as apply func default"""
598 """interactively defined function as apply func default"""
599 def foo():
599 def foo():
600 return 'foo'
600 return 'foo'
601
601
602 def bar(f=foo):
602 def bar(f=foo):
603 return f()
603 return f()
604
604
605 view = self.client[-1]
605 view = self.client[-1]
606 ar = view.apply_async(bar)
606 ar = view.apply_async(bar)
607 r = ar.get(10)
607 r = ar.get(10)
608 self.assertEqual(r, 'foo')
608 self.assertEqual(r, 'foo')
609 def test_data_pub_single(self):
609 def test_data_pub_single(self):
610 view = self.client[-1]
610 view = self.client[-1]
611 ar = view.execute('\n'.join([
611 ar = view.execute('\n'.join([
612 'from IPython.zmq.datapub import publish_data',
612 'from IPython.zmq.datapub import publish_data',
613 'for i in range(5):',
613 'for i in range(5):',
614 ' publish_data(dict(i=i))'
614 ' publish_data(dict(i=i))'
615 ]), block=False)
615 ]), block=False)
616 self.assertTrue(isinstance(ar.data, dict))
616 self.assertTrue(isinstance(ar.data, dict))
617 ar.get(5)
617 ar.get(5)
618 self.assertEqual(ar.data, dict(i=4))
618 self.assertEqual(ar.data, dict(i=4))
619
619
620 def test_data_pub(self):
620 def test_data_pub(self):
621 view = self.client[:]
621 view = self.client[:]
622 ar = view.execute('\n'.join([
622 ar = view.execute('\n'.join([
623 'from IPython.zmq.datapub import publish_data',
623 'from IPython.zmq.datapub import publish_data',
624 'for i in range(5):',
624 'for i in range(5):',
625 ' publish_data(dict(i=i))'
625 ' publish_data(dict(i=i))'
626 ]), block=False)
626 ]), block=False)
627 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
627 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
628 ar.get(5)
628 ar.get(5)
629 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
629 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
630
630
631 def test_can_list_arg(self):
631 def test_can_list_arg(self):
632 """args in lists are canned"""
632 """args in lists are canned"""
633 view = self.client[-1]
633 view = self.client[-1]
634 view['a'] = 128
634 view['a'] = 128
635 rA = pmod.Reference('a')
635 rA = pmod.Reference('a')
636 ar = view.apply_async(lambda x: x, [rA])
636 ar = view.apply_async(lambda x: x, [rA])
637 r = ar.get(5)
637 r = ar.get(5)
638 self.assertEqual(r, [128])
638 self.assertEqual(r, [128])
639
639
640 def test_can_dict_arg(self):
640 def test_can_dict_arg(self):
641 """args in dicts are canned"""
641 """args in dicts are canned"""
642 view = self.client[-1]
642 view = self.client[-1]
643 view['a'] = 128
643 view['a'] = 128
644 rA = pmod.Reference('a')
644 rA = pmod.Reference('a')
645 ar = view.apply_async(lambda x: x, dict(foo=rA))
645 ar = view.apply_async(lambda x: x, dict(foo=rA))
646 r = ar.get(5)
646 r = ar.get(5)
647 self.assertEqual(r, dict(foo=128))
647 self.assertEqual(r, dict(foo=128))
648
648
649 def test_can_list_kwarg(self):
649 def test_can_list_kwarg(self):
650 """kwargs in lists are canned"""
650 """kwargs in lists are canned"""
651 view = self.client[-1]
651 view = self.client[-1]
652 view['a'] = 128
652 view['a'] = 128
653 rA = pmod.Reference('a')
653 rA = pmod.Reference('a')
654 ar = view.apply_async(lambda x=5: x, x=[rA])
654 ar = view.apply_async(lambda x=5: x, x=[rA])
655 r = ar.get(5)
655 r = ar.get(5)
656 self.assertEqual(r, [128])
656 self.assertEqual(r, [128])
657
657
658 def test_can_dict_kwarg(self):
658 def test_can_dict_kwarg(self):
659 """kwargs in dicts are canned"""
659 """kwargs in dicts are canned"""
660 view = self.client[-1]
660 view = self.client[-1]
661 view['a'] = 128
661 view['a'] = 128
662 rA = pmod.Reference('a')
662 rA = pmod.Reference('a')
663 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
663 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
664 r = ar.get(5)
664 r = ar.get(5)
665 self.assertEqual(r, dict(foo=128))
665 self.assertEqual(r, dict(foo=128))
666
666
667 def test_map_ref(self):
667 def test_map_ref(self):
668 """view.map works with references"""
668 """view.map works with references"""
669 view = self.client[:]
669 view = self.client[:]
670 ranks = sorted(self.client.ids)
670 ranks = sorted(self.client.ids)
671 view.scatter('rank', ranks, flatten=True)
671 view.scatter('rank', ranks, flatten=True)
672 rrank = pmod.Reference('rank')
672 rrank = pmod.Reference('rank')
673
673
674 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
674 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
675 drank = amr.get(5)
675 drank = amr.get(5)
676 self.assertEqual(drank, [ r*2 for r in ranks ])
676 self.assertEqual(drank, [ r*2 for r in ranks ])
677
677
678 def test_nested_getitem_setitem(self):
679 """get and set with view['a.b']"""
680 view = self.client[-1]
681 view.execute('\n'.join([
682 'class A(object): pass',
683 'a = A()',
684 'a.b = 128',
685 ]), block=True)
686 ra = pmod.Reference('a')
687
688 r = view.apply_sync(lambda x: x.b, ra)
689 self.assertEqual(r, 128)
690 self.assertEqual(view['a.b'], 128)
691
692 view['a.b'] = 0
678
693
694 r = view.apply_sync(lambda x: x.b, ra)
695 self.assertEqual(r, 0)
696 self.assertEqual(view['a.b'], 0)
@@ -1,352 +1,355 b''
1 """some generic utilities for dealing with classes, urls, and serialization
1 """some generic utilities for dealing with classes, urls, and serialization
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 # Standard library imports.
18 # Standard library imports.
19 import logging
19 import logging
20 import os
20 import os
21 import re
21 import re
22 import stat
22 import stat
23 import socket
23 import socket
24 import sys
24 import sys
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 try:
26 try:
27 from signal import SIGKILL
27 from signal import SIGKILL
28 except ImportError:
28 except ImportError:
29 SIGKILL=None
29 SIGKILL=None
30
30
31 try:
31 try:
32 import cPickle
32 import cPickle
33 pickle = cPickle
33 pickle = cPickle
34 except:
34 except:
35 cPickle = None
35 cPickle = None
36 import pickle
36 import pickle
37
37
38 # System library imports
38 # System library imports
39 import zmq
39 import zmq
40 from zmq.log import handlers
40 from zmq.log import handlers
41
41
42 from IPython.external.decorator import decorator
42 from IPython.external.decorator import decorator
43
43
44 # IPython imports
44 # IPython imports
45 from IPython.config.application import Application
45 from IPython.config.application import Application
46 from IPython.zmq.log import EnginePUBHandler
46 from IPython.zmq.log import EnginePUBHandler
47 from IPython.zmq.serialize import (
47 from IPython.zmq.serialize import (
48 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
48 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
49 )
49 )
50
50
51 #-----------------------------------------------------------------------------
51 #-----------------------------------------------------------------------------
52 # Classes
52 # Classes
53 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
54
54
55 class Namespace(dict):
55 class Namespace(dict):
56 """Subclass of dict for attribute access to keys."""
56 """Subclass of dict for attribute access to keys."""
57
57
58 def __getattr__(self, key):
58 def __getattr__(self, key):
59 """getattr aliased to getitem"""
59 """getattr aliased to getitem"""
60 if key in self.iterkeys():
60 if key in self.iterkeys():
61 return self[key]
61 return self[key]
62 else:
62 else:
63 raise NameError(key)
63 raise NameError(key)
64
64
65 def __setattr__(self, key, value):
65 def __setattr__(self, key, value):
66 """setattr aliased to setitem, with strict"""
66 """setattr aliased to setitem, with strict"""
67 if hasattr(dict, key):
67 if hasattr(dict, key):
68 raise KeyError("Cannot override dict keys %r"%key)
68 raise KeyError("Cannot override dict keys %r"%key)
69 self[key] = value
69 self[key] = value
70
70
71
71
72 class ReverseDict(dict):
72 class ReverseDict(dict):
73 """simple double-keyed subset of dict methods."""
73 """simple double-keyed subset of dict methods."""
74
74
75 def __init__(self, *args, **kwargs):
75 def __init__(self, *args, **kwargs):
76 dict.__init__(self, *args, **kwargs)
76 dict.__init__(self, *args, **kwargs)
77 self._reverse = dict()
77 self._reverse = dict()
78 for key, value in self.iteritems():
78 for key, value in self.iteritems():
79 self._reverse[value] = key
79 self._reverse[value] = key
80
80
81 def __getitem__(self, key):
81 def __getitem__(self, key):
82 try:
82 try:
83 return dict.__getitem__(self, key)
83 return dict.__getitem__(self, key)
84 except KeyError:
84 except KeyError:
85 return self._reverse[key]
85 return self._reverse[key]
86
86
87 def __setitem__(self, key, value):
87 def __setitem__(self, key, value):
88 if key in self._reverse:
88 if key in self._reverse:
89 raise KeyError("Can't have key %r on both sides!"%key)
89 raise KeyError("Can't have key %r on both sides!"%key)
90 dict.__setitem__(self, key, value)
90 dict.__setitem__(self, key, value)
91 self._reverse[value] = key
91 self._reverse[value] = key
92
92
93 def pop(self, key):
93 def pop(self, key):
94 value = dict.pop(self, key)
94 value = dict.pop(self, key)
95 self._reverse.pop(value)
95 self._reverse.pop(value)
96 return value
96 return value
97
97
98 def get(self, key, default=None):
98 def get(self, key, default=None):
99 try:
99 try:
100 return self[key]
100 return self[key]
101 except KeyError:
101 except KeyError:
102 return default
102 return default
103
103
104 #-----------------------------------------------------------------------------
104 #-----------------------------------------------------------------------------
105 # Functions
105 # Functions
106 #-----------------------------------------------------------------------------
106 #-----------------------------------------------------------------------------
107
107
108 @decorator
108 @decorator
109 def log_errors(f, self, *args, **kwargs):
109 def log_errors(f, self, *args, **kwargs):
110 """decorator to log unhandled exceptions raised in a method.
110 """decorator to log unhandled exceptions raised in a method.
111
111
112 For use wrapping on_recv callbacks, so that exceptions
112 For use wrapping on_recv callbacks, so that exceptions
113 do not cause the stream to be closed.
113 do not cause the stream to be closed.
114 """
114 """
115 try:
115 try:
116 return f(self, *args, **kwargs)
116 return f(self, *args, **kwargs)
117 except Exception:
117 except Exception:
118 self.log.error("Uncaught exception in %r" % f, exc_info=True)
118 self.log.error("Uncaught exception in %r" % f, exc_info=True)
119
119
120
120
121 def is_url(url):
121 def is_url(url):
122 """boolean check for whether a string is a zmq url"""
122 """boolean check for whether a string is a zmq url"""
123 if '://' not in url:
123 if '://' not in url:
124 return False
124 return False
125 proto, addr = url.split('://', 1)
125 proto, addr = url.split('://', 1)
126 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
126 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
127 return False
127 return False
128 return True
128 return True
129
129
130 def validate_url(url):
130 def validate_url(url):
131 """validate a url for zeromq"""
131 """validate a url for zeromq"""
132 if not isinstance(url, basestring):
132 if not isinstance(url, basestring):
133 raise TypeError("url must be a string, not %r"%type(url))
133 raise TypeError("url must be a string, not %r"%type(url))
134 url = url.lower()
134 url = url.lower()
135
135
136 proto_addr = url.split('://')
136 proto_addr = url.split('://')
137 assert len(proto_addr) == 2, 'Invalid url: %r'%url
137 assert len(proto_addr) == 2, 'Invalid url: %r'%url
138 proto, addr = proto_addr
138 proto, addr = proto_addr
139 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
139 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
140
140
141 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
141 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
142 # author: Remi Sabourin
142 # author: Remi Sabourin
143 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
143 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
144
144
145 if proto == 'tcp':
145 if proto == 'tcp':
146 lis = addr.split(':')
146 lis = addr.split(':')
147 assert len(lis) == 2, 'Invalid url: %r'%url
147 assert len(lis) == 2, 'Invalid url: %r'%url
148 addr,s_port = lis
148 addr,s_port = lis
149 try:
149 try:
150 port = int(s_port)
150 port = int(s_port)
151 except ValueError:
151 except ValueError:
152 raise AssertionError("Invalid port %r in url: %r"%(port, url))
152 raise AssertionError("Invalid port %r in url: %r"%(port, url))
153
153
154 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
154 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
155
155
156 else:
156 else:
157 # only validate tcp urls currently
157 # only validate tcp urls currently
158 pass
158 pass
159
159
160 return True
160 return True
161
161
162
162
163 def validate_url_container(container):
163 def validate_url_container(container):
164 """validate a potentially nested collection of urls."""
164 """validate a potentially nested collection of urls."""
165 if isinstance(container, basestring):
165 if isinstance(container, basestring):
166 url = container
166 url = container
167 return validate_url(url)
167 return validate_url(url)
168 elif isinstance(container, dict):
168 elif isinstance(container, dict):
169 container = container.itervalues()
169 container = container.itervalues()
170
170
171 for element in container:
171 for element in container:
172 validate_url_container(element)
172 validate_url_container(element)
173
173
174
174
175 def split_url(url):
175 def split_url(url):
176 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
176 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
177 proto_addr = url.split('://')
177 proto_addr = url.split('://')
178 assert len(proto_addr) == 2, 'Invalid url: %r'%url
178 assert len(proto_addr) == 2, 'Invalid url: %r'%url
179 proto, addr = proto_addr
179 proto, addr = proto_addr
180 lis = addr.split(':')
180 lis = addr.split(':')
181 assert len(lis) == 2, 'Invalid url: %r'%url
181 assert len(lis) == 2, 'Invalid url: %r'%url
182 addr,s_port = lis
182 addr,s_port = lis
183 return proto,addr,s_port
183 return proto,addr,s_port
184
184
185 def disambiguate_ip_address(ip, location=None):
185 def disambiguate_ip_address(ip, location=None):
186 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
186 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
187 ones, based on the location (default interpretation of location is localhost)."""
187 ones, based on the location (default interpretation of location is localhost)."""
188 if ip in ('0.0.0.0', '*'):
188 if ip in ('0.0.0.0', '*'):
189 try:
189 try:
190 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
190 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
191 except (socket.gaierror, IndexError):
191 except (socket.gaierror, IndexError):
192 # couldn't identify this machine, assume localhost
192 # couldn't identify this machine, assume localhost
193 external_ips = []
193 external_ips = []
194 if location is None or location in external_ips or not external_ips:
194 if location is None or location in external_ips or not external_ips:
195 # If location is unspecified or cannot be determined, assume local
195 # If location is unspecified or cannot be determined, assume local
196 ip='127.0.0.1'
196 ip='127.0.0.1'
197 elif location:
197 elif location:
198 return location
198 return location
199 return ip
199 return ip
200
200
201 def disambiguate_url(url, location=None):
201 def disambiguate_url(url, location=None):
202 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
202 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
203 ones, based on the location (default interpretation is localhost).
203 ones, based on the location (default interpretation is localhost).
204
204
205 This is for zeromq urls, such as tcp://*:10101."""
205 This is for zeromq urls, such as tcp://*:10101."""
206 try:
206 try:
207 proto,ip,port = split_url(url)
207 proto,ip,port = split_url(url)
208 except AssertionError:
208 except AssertionError:
209 # probably not tcp url; could be ipc, etc.
209 # probably not tcp url; could be ipc, etc.
210 return url
210 return url
211
211
212 ip = disambiguate_ip_address(ip,location)
212 ip = disambiguate_ip_address(ip,location)
213
213
214 return "%s://%s:%s"%(proto,ip,port)
214 return "%s://%s:%s"%(proto,ip,port)
215
215
216
216
217 #--------------------------------------------------------------------------
217 #--------------------------------------------------------------------------
218 # helpers for implementing old MEC API via view.apply
218 # helpers for implementing old MEC API via view.apply
219 #--------------------------------------------------------------------------
219 #--------------------------------------------------------------------------
220
220
221 def interactive(f):
221 def interactive(f):
222 """decorator for making functions appear as interactively defined.
222 """decorator for making functions appear as interactively defined.
223 This results in the function being linked to the user_ns as globals()
223 This results in the function being linked to the user_ns as globals()
224 instead of the module globals().
224 instead of the module globals().
225 """
225 """
226 f.__module__ = '__main__'
226 f.__module__ = '__main__'
227 return f
227 return f
228
228
229 @interactive
229 @interactive
230 def _push(**ns):
230 def _push(**ns):
231 """helper method for implementing `client.push` via `client.apply`"""
231 """helper method for implementing `client.push` via `client.apply`"""
232 globals().update(ns)
232 user_ns = globals()
233 tmp = '_IP_PUSH_TMP_'
234 while tmp in user_ns:
235 tmp = tmp + '_'
236 try:
237 for name, value in ns.iteritems():
238 user_ns[tmp] = value
239 exec "%s = %s" % (name, tmp) in user_ns
240 finally:
241 user_ns.pop(tmp, None)
233
242
234 @interactive
243 @interactive
235 def _pull(keys):
244 def _pull(keys):
236 """helper method for implementing `client.pull` via `client.apply`"""
245 """helper method for implementing `client.pull` via `client.apply`"""
237 user_ns = globals()
238 if isinstance(keys, (list,tuple, set)):
246 if isinstance(keys, (list,tuple, set)):
239 for key in keys:
247 return map(lambda key: eval(key, globals()), keys)
240 if key not in user_ns:
241 raise NameError("name '%s' is not defined"%key)
242 return map(user_ns.get, keys)
243 else:
248 else:
244 if keys not in user_ns:
249 return eval(keys, globals())
245 raise NameError("name '%s' is not defined"%keys)
246 return user_ns.get(keys)
247
250
248 @interactive
251 @interactive
249 def _execute(code):
252 def _execute(code):
250 """helper method for implementing `client.execute` via `client.apply`"""
253 """helper method for implementing `client.execute` via `client.apply`"""
251 exec code in globals()
254 exec code in globals()
252
255
253 #--------------------------------------------------------------------------
256 #--------------------------------------------------------------------------
254 # extra process management utilities
257 # extra process management utilities
255 #--------------------------------------------------------------------------
258 #--------------------------------------------------------------------------
256
259
257 _random_ports = set()
260 _random_ports = set()
258
261
259 def select_random_ports(n):
262 def select_random_ports(n):
260 """Selects and return n random ports that are available."""
263 """Selects and return n random ports that are available."""
261 ports = []
264 ports = []
262 for i in xrange(n):
265 for i in xrange(n):
263 sock = socket.socket()
266 sock = socket.socket()
264 sock.bind(('', 0))
267 sock.bind(('', 0))
265 while sock.getsockname()[1] in _random_ports:
268 while sock.getsockname()[1] in _random_ports:
266 sock.close()
269 sock.close()
267 sock = socket.socket()
270 sock = socket.socket()
268 sock.bind(('', 0))
271 sock.bind(('', 0))
269 ports.append(sock)
272 ports.append(sock)
270 for i, sock in enumerate(ports):
273 for i, sock in enumerate(ports):
271 port = sock.getsockname()[1]
274 port = sock.getsockname()[1]
272 sock.close()
275 sock.close()
273 ports[i] = port
276 ports[i] = port
274 _random_ports.add(port)
277 _random_ports.add(port)
275 return ports
278 return ports
276
279
277 def signal_children(children):
280 def signal_children(children):
278 """Relay interupt/term signals to children, for more solid process cleanup."""
281 """Relay interupt/term signals to children, for more solid process cleanup."""
279 def terminate_children(sig, frame):
282 def terminate_children(sig, frame):
280 log = Application.instance().log
283 log = Application.instance().log
281 log.critical("Got signal %i, terminating children..."%sig)
284 log.critical("Got signal %i, terminating children..."%sig)
282 for child in children:
285 for child in children:
283 child.terminate()
286 child.terminate()
284
287
285 sys.exit(sig != SIGINT)
288 sys.exit(sig != SIGINT)
286 # sys.exit(sig)
289 # sys.exit(sig)
287 for sig in (SIGINT, SIGABRT, SIGTERM):
290 for sig in (SIGINT, SIGABRT, SIGTERM):
288 signal(sig, terminate_children)
291 signal(sig, terminate_children)
289
292
290 def generate_exec_key(keyfile):
293 def generate_exec_key(keyfile):
291 import uuid
294 import uuid
292 newkey = str(uuid.uuid4())
295 newkey = str(uuid.uuid4())
293 with open(keyfile, 'w') as f:
296 with open(keyfile, 'w') as f:
294 # f.write('ipython-key ')
297 # f.write('ipython-key ')
295 f.write(newkey+'\n')
298 f.write(newkey+'\n')
296 # set user-only RW permissions (0600)
299 # set user-only RW permissions (0600)
297 # this will have no effect on Windows
300 # this will have no effect on Windows
298 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
301 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
299
302
300
303
301 def integer_loglevel(loglevel):
304 def integer_loglevel(loglevel):
302 try:
305 try:
303 loglevel = int(loglevel)
306 loglevel = int(loglevel)
304 except ValueError:
307 except ValueError:
305 if isinstance(loglevel, str):
308 if isinstance(loglevel, str):
306 loglevel = getattr(logging, loglevel)
309 loglevel = getattr(logging, loglevel)
307 return loglevel
310 return loglevel
308
311
309 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
312 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
310 logger = logging.getLogger(logname)
313 logger = logging.getLogger(logname)
311 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
314 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
312 # don't add a second PUBHandler
315 # don't add a second PUBHandler
313 return
316 return
314 loglevel = integer_loglevel(loglevel)
317 loglevel = integer_loglevel(loglevel)
315 lsock = context.socket(zmq.PUB)
318 lsock = context.socket(zmq.PUB)
316 lsock.connect(iface)
319 lsock.connect(iface)
317 handler = handlers.PUBHandler(lsock)
320 handler = handlers.PUBHandler(lsock)
318 handler.setLevel(loglevel)
321 handler.setLevel(loglevel)
319 handler.root_topic = root
322 handler.root_topic = root
320 logger.addHandler(handler)
323 logger.addHandler(handler)
321 logger.setLevel(loglevel)
324 logger.setLevel(loglevel)
322
325
323 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
326 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
324 logger = logging.getLogger()
327 logger = logging.getLogger()
325 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
328 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
326 # don't add a second PUBHandler
329 # don't add a second PUBHandler
327 return
330 return
328 loglevel = integer_loglevel(loglevel)
331 loglevel = integer_loglevel(loglevel)
329 lsock = context.socket(zmq.PUB)
332 lsock = context.socket(zmq.PUB)
330 lsock.connect(iface)
333 lsock.connect(iface)
331 handler = EnginePUBHandler(engine, lsock)
334 handler = EnginePUBHandler(engine, lsock)
332 handler.setLevel(loglevel)
335 handler.setLevel(loglevel)
333 logger.addHandler(handler)
336 logger.addHandler(handler)
334 logger.setLevel(loglevel)
337 logger.setLevel(loglevel)
335 return logger
338 return logger
336
339
337 def local_logger(logname, loglevel=logging.DEBUG):
340 def local_logger(logname, loglevel=logging.DEBUG):
338 loglevel = integer_loglevel(loglevel)
341 loglevel = integer_loglevel(loglevel)
339 logger = logging.getLogger(logname)
342 logger = logging.getLogger(logname)
340 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
343 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
341 # don't add a second StreamHandler
344 # don't add a second StreamHandler
342 return
345 return
343 handler = logging.StreamHandler()
346 handler = logging.StreamHandler()
344 handler.setLevel(loglevel)
347 handler.setLevel(loglevel)
345 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
348 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
346 datefmt="%Y-%m-%d %H:%M:%S")
349 datefmt="%Y-%m-%d %H:%M:%S")
347 handler.setFormatter(formatter)
350 handler.setFormatter(formatter)
348
351
349 logger.addHandler(handler)
352 logger.addHandler(handler)
350 logger.setLevel(loglevel)
353 logger.setLevel(loglevel)
351 return logger
354 return logger
352
355
General Comments 0
You need to be logged in to leave comments. Login now