##// END OF EJS Templates
pickle arrays with dtype=object...
MinRK -
Show More
@@ -1,835 +1,850 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import base64
19 import base64
20 import sys
20 import sys
21 import platform
21 import platform
22 import time
22 import time
23 from collections import namedtuple
23 from collections import namedtuple
24 from tempfile import mktemp
24 from tempfile import mktemp
25
25
26 import zmq
26 import zmq
27 from nose.plugins.attrib import attr
27 from nose.plugins.attrib import attr
28
28
29 from IPython.testing import decorators as dec
29 from IPython.testing import decorators as dec
30 from IPython.utils.io import capture_output
30 from IPython.utils.io import capture_output
31 from IPython.utils.py3compat import unicode_type
31 from IPython.utils.py3compat import unicode_type
32
32
33 from IPython import parallel as pmod
33 from IPython import parallel as pmod
34 from IPython.parallel import error
34 from IPython.parallel import error
35 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
35 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
36 from IPython.parallel.util import interactive
36 from IPython.parallel.util import interactive
37
37
38 from IPython.parallel.tests import add_engines
38 from IPython.parallel.tests import add_engines
39
39
40 from .clienttest import ClusterTestCase, crash, wait, skip_without
40 from .clienttest import ClusterTestCase, crash, wait, skip_without
41
41
42 def setup():
42 def setup():
43 add_engines(3, total=True)
43 add_engines(3, total=True)
44
44
45 point = namedtuple("point", "x y")
45 point = namedtuple("point", "x y")
46
46
47 class TestView(ClusterTestCase):
47 class TestView(ClusterTestCase):
48
48
49 def setUp(self):
49 def setUp(self):
50 # On Win XP, wait for resource cleanup, else parallel test group fails
50 # On Win XP, wait for resource cleanup, else parallel test group fails
51 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
51 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
52 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
52 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
53 time.sleep(2)
53 time.sleep(2)
54 super(TestView, self).setUp()
54 super(TestView, self).setUp()
55
55
56 @attr('crash')
56 @attr('crash')
57 def test_z_crash_mux(self):
57 def test_z_crash_mux(self):
58 """test graceful handling of engine death (direct)"""
58 """test graceful handling of engine death (direct)"""
59 # self.add_engines(1)
59 # self.add_engines(1)
60 eid = self.client.ids[-1]
60 eid = self.client.ids[-1]
61 ar = self.client[eid].apply_async(crash)
61 ar = self.client[eid].apply_async(crash)
62 self.assertRaisesRemote(error.EngineError, ar.get, 10)
62 self.assertRaisesRemote(error.EngineError, ar.get, 10)
63 eid = ar.engine_id
63 eid = ar.engine_id
64 tic = time.time()
64 tic = time.time()
65 while eid in self.client.ids and time.time()-tic < 5:
65 while eid in self.client.ids and time.time()-tic < 5:
66 time.sleep(.01)
66 time.sleep(.01)
67 self.client.spin()
67 self.client.spin()
68 self.assertFalse(eid in self.client.ids, "Engine should have died")
68 self.assertFalse(eid in self.client.ids, "Engine should have died")
69
69
70 def test_push_pull(self):
70 def test_push_pull(self):
71 """test pushing and pulling"""
71 """test pushing and pulling"""
72 data = dict(a=10, b=1.05, c=list(range(10)), d={'e':(1,2),'f':'hi'})
72 data = dict(a=10, b=1.05, c=list(range(10)), d={'e':(1,2),'f':'hi'})
73 t = self.client.ids[-1]
73 t = self.client.ids[-1]
74 v = self.client[t]
74 v = self.client[t]
75 push = v.push
75 push = v.push
76 pull = v.pull
76 pull = v.pull
77 v.block=True
77 v.block=True
78 nengines = len(self.client)
78 nengines = len(self.client)
79 push({'data':data})
79 push({'data':data})
80 d = pull('data')
80 d = pull('data')
81 self.assertEqual(d, data)
81 self.assertEqual(d, data)
82 self.client[:].push({'data':data})
82 self.client[:].push({'data':data})
83 d = self.client[:].pull('data', block=True)
83 d = self.client[:].pull('data', block=True)
84 self.assertEqual(d, nengines*[data])
84 self.assertEqual(d, nengines*[data])
85 ar = push({'data':data}, block=False)
85 ar = push({'data':data}, block=False)
86 self.assertTrue(isinstance(ar, AsyncResult))
86 self.assertTrue(isinstance(ar, AsyncResult))
87 r = ar.get()
87 r = ar.get()
88 ar = self.client[:].pull('data', block=False)
88 ar = self.client[:].pull('data', block=False)
89 self.assertTrue(isinstance(ar, AsyncResult))
89 self.assertTrue(isinstance(ar, AsyncResult))
90 r = ar.get()
90 r = ar.get()
91 self.assertEqual(r, nengines*[data])
91 self.assertEqual(r, nengines*[data])
92 self.client[:].push(dict(a=10,b=20))
92 self.client[:].push(dict(a=10,b=20))
93 r = self.client[:].pull(('a','b'), block=True)
93 r = self.client[:].pull(('a','b'), block=True)
94 self.assertEqual(r, nengines*[[10,20]])
94 self.assertEqual(r, nengines*[[10,20]])
95
95
96 def test_push_pull_function(self):
96 def test_push_pull_function(self):
97 "test pushing and pulling functions"
97 "test pushing and pulling functions"
98 def testf(x):
98 def testf(x):
99 return 2.0*x
99 return 2.0*x
100
100
101 t = self.client.ids[-1]
101 t = self.client.ids[-1]
102 v = self.client[t]
102 v = self.client[t]
103 v.block=True
103 v.block=True
104 push = v.push
104 push = v.push
105 pull = v.pull
105 pull = v.pull
106 execute = v.execute
106 execute = v.execute
107 push({'testf':testf})
107 push({'testf':testf})
108 r = pull('testf')
108 r = pull('testf')
109 self.assertEqual(r(1.0), testf(1.0))
109 self.assertEqual(r(1.0), testf(1.0))
110 execute('r = testf(10)')
110 execute('r = testf(10)')
111 r = pull('r')
111 r = pull('r')
112 self.assertEqual(r, testf(10))
112 self.assertEqual(r, testf(10))
113 ar = self.client[:].push({'testf':testf}, block=False)
113 ar = self.client[:].push({'testf':testf}, block=False)
114 ar.get()
114 ar.get()
115 ar = self.client[:].pull('testf', block=False)
115 ar = self.client[:].pull('testf', block=False)
116 rlist = ar.get()
116 rlist = ar.get()
117 for r in rlist:
117 for r in rlist:
118 self.assertEqual(r(1.0), testf(1.0))
118 self.assertEqual(r(1.0), testf(1.0))
119 execute("def g(x): return x*x")
119 execute("def g(x): return x*x")
120 r = pull(('testf','g'))
120 r = pull(('testf','g'))
121 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
121 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
122
122
123 def test_push_function_globals(self):
123 def test_push_function_globals(self):
124 """test that pushed functions have access to globals"""
124 """test that pushed functions have access to globals"""
125 @interactive
125 @interactive
126 def geta():
126 def geta():
127 return a
127 return a
128 # self.add_engines(1)
128 # self.add_engines(1)
129 v = self.client[-1]
129 v = self.client[-1]
130 v.block=True
130 v.block=True
131 v['f'] = geta
131 v['f'] = geta
132 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
132 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
133 v.execute('a=5')
133 v.execute('a=5')
134 v.execute('b=f()')
134 v.execute('b=f()')
135 self.assertEqual(v['b'], 5)
135 self.assertEqual(v['b'], 5)
136
136
137 def test_push_function_defaults(self):
137 def test_push_function_defaults(self):
138 """test that pushed functions preserve default args"""
138 """test that pushed functions preserve default args"""
139 def echo(a=10):
139 def echo(a=10):
140 return a
140 return a
141 v = self.client[-1]
141 v = self.client[-1]
142 v.block=True
142 v.block=True
143 v['f'] = echo
143 v['f'] = echo
144 v.execute('b=f()')
144 v.execute('b=f()')
145 self.assertEqual(v['b'], 10)
145 self.assertEqual(v['b'], 10)
146
146
147 def test_get_result(self):
147 def test_get_result(self):
148 """test getting results from the Hub."""
148 """test getting results from the Hub."""
149 c = pmod.Client(profile='iptest')
149 c = pmod.Client(profile='iptest')
150 # self.add_engines(1)
150 # self.add_engines(1)
151 t = c.ids[-1]
151 t = c.ids[-1]
152 v = c[t]
152 v = c[t]
153 v2 = self.client[t]
153 v2 = self.client[t]
154 ar = v.apply_async(wait, 1)
154 ar = v.apply_async(wait, 1)
155 # give the monitor time to notice the message
155 # give the monitor time to notice the message
156 time.sleep(.25)
156 time.sleep(.25)
157 ahr = v2.get_result(ar.msg_ids[0])
157 ahr = v2.get_result(ar.msg_ids[0])
158 self.assertTrue(isinstance(ahr, AsyncHubResult))
158 self.assertTrue(isinstance(ahr, AsyncHubResult))
159 self.assertEqual(ahr.get(), ar.get())
159 self.assertEqual(ahr.get(), ar.get())
160 ar2 = v2.get_result(ar.msg_ids[0])
160 ar2 = v2.get_result(ar.msg_ids[0])
161 self.assertFalse(isinstance(ar2, AsyncHubResult))
161 self.assertFalse(isinstance(ar2, AsyncHubResult))
162 c.spin()
162 c.spin()
163 c.close()
163 c.close()
164
164
165 def test_run_newline(self):
165 def test_run_newline(self):
166 """test that run appends newline to files"""
166 """test that run appends newline to files"""
167 tmpfile = mktemp()
167 tmpfile = mktemp()
168 with open(tmpfile, 'w') as f:
168 with open(tmpfile, 'w') as f:
169 f.write("""def g():
169 f.write("""def g():
170 return 5
170 return 5
171 """)
171 """)
172 v = self.client[-1]
172 v = self.client[-1]
173 v.run(tmpfile, block=True)
173 v.run(tmpfile, block=True)
174 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
174 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
175
175
176 def test_apply_tracked(self):
176 def test_apply_tracked(self):
177 """test tracking for apply"""
177 """test tracking for apply"""
178 # self.add_engines(1)
178 # self.add_engines(1)
179 t = self.client.ids[-1]
179 t = self.client.ids[-1]
180 v = self.client[t]
180 v = self.client[t]
181 v.block=False
181 v.block=False
182 def echo(n=1024*1024, **kwargs):
182 def echo(n=1024*1024, **kwargs):
183 with v.temp_flags(**kwargs):
183 with v.temp_flags(**kwargs):
184 return v.apply(lambda x: x, 'x'*n)
184 return v.apply(lambda x: x, 'x'*n)
185 ar = echo(1, track=False)
185 ar = echo(1, track=False)
186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertTrue(ar.sent)
187 self.assertTrue(ar.sent)
188 ar = echo(track=True)
188 ar = echo(track=True)
189 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
189 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
190 self.assertEqual(ar.sent, ar._tracker.done)
190 self.assertEqual(ar.sent, ar._tracker.done)
191 ar._tracker.wait()
191 ar._tracker.wait()
192 self.assertTrue(ar.sent)
192 self.assertTrue(ar.sent)
193
193
194 def test_push_tracked(self):
194 def test_push_tracked(self):
195 t = self.client.ids[-1]
195 t = self.client.ids[-1]
196 ns = dict(x='x'*1024*1024)
196 ns = dict(x='x'*1024*1024)
197 v = self.client[t]
197 v = self.client[t]
198 ar = v.push(ns, block=False, track=False)
198 ar = v.push(ns, block=False, track=False)
199 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
199 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
200 self.assertTrue(ar.sent)
200 self.assertTrue(ar.sent)
201
201
202 ar = v.push(ns, block=False, track=True)
202 ar = v.push(ns, block=False, track=True)
203 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
203 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
204 ar._tracker.wait()
204 ar._tracker.wait()
205 self.assertEqual(ar.sent, ar._tracker.done)
205 self.assertEqual(ar.sent, ar._tracker.done)
206 self.assertTrue(ar.sent)
206 self.assertTrue(ar.sent)
207 ar.get()
207 ar.get()
208
208
209 def test_scatter_tracked(self):
209 def test_scatter_tracked(self):
210 t = self.client.ids
210 t = self.client.ids
211 x='x'*1024*1024
211 x='x'*1024*1024
212 ar = self.client[t].scatter('x', x, block=False, track=False)
212 ar = self.client[t].scatter('x', x, block=False, track=False)
213 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
213 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
214 self.assertTrue(ar.sent)
214 self.assertTrue(ar.sent)
215
215
216 ar = self.client[t].scatter('x', x, block=False, track=True)
216 ar = self.client[t].scatter('x', x, block=False, track=True)
217 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
217 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
218 self.assertEqual(ar.sent, ar._tracker.done)
218 self.assertEqual(ar.sent, ar._tracker.done)
219 ar._tracker.wait()
219 ar._tracker.wait()
220 self.assertTrue(ar.sent)
220 self.assertTrue(ar.sent)
221 ar.get()
221 ar.get()
222
222
223 def test_remote_reference(self):
223 def test_remote_reference(self):
224 v = self.client[-1]
224 v = self.client[-1]
225 v['a'] = 123
225 v['a'] = 123
226 ra = pmod.Reference('a')
226 ra = pmod.Reference('a')
227 b = v.apply_sync(lambda x: x, ra)
227 b = v.apply_sync(lambda x: x, ra)
228 self.assertEqual(b, 123)
228 self.assertEqual(b, 123)
229
229
230
230
231 def test_scatter_gather(self):
231 def test_scatter_gather(self):
232 view = self.client[:]
232 view = self.client[:]
233 seq1 = list(range(16))
233 seq1 = list(range(16))
234 view.scatter('a', seq1)
234 view.scatter('a', seq1)
235 seq2 = view.gather('a', block=True)
235 seq2 = view.gather('a', block=True)
236 self.assertEqual(seq2, seq1)
236 self.assertEqual(seq2, seq1)
237 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
237 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
238
238
239 @skip_without('numpy')
239 @skip_without('numpy')
240 def test_scatter_gather_numpy(self):
240 def test_scatter_gather_numpy(self):
241 import numpy
241 import numpy
242 from numpy.testing.utils import assert_array_equal
242 from numpy.testing.utils import assert_array_equal
243 view = self.client[:]
243 view = self.client[:]
244 a = numpy.arange(64)
244 a = numpy.arange(64)
245 view.scatter('a', a, block=True)
245 view.scatter('a', a, block=True)
246 b = view.gather('a', block=True)
246 b = view.gather('a', block=True)
247 assert_array_equal(b, a)
247 assert_array_equal(b, a)
248
248
249 def test_scatter_gather_lazy(self):
249 def test_scatter_gather_lazy(self):
250 """scatter/gather with targets='all'"""
250 """scatter/gather with targets='all'"""
251 view = self.client.direct_view(targets='all')
251 view = self.client.direct_view(targets='all')
252 x = list(range(64))
252 x = list(range(64))
253 view.scatter('x', x)
253 view.scatter('x', x)
254 gathered = view.gather('x', block=True)
254 gathered = view.gather('x', block=True)
255 self.assertEqual(gathered, x)
255 self.assertEqual(gathered, x)
256
256
257
257
258 @dec.known_failure_py3
258 @dec.known_failure_py3
259 @skip_without('numpy')
259 @skip_without('numpy')
260 def test_push_numpy_nocopy(self):
260 def test_push_numpy_nocopy(self):
261 import numpy
261 import numpy
262 view = self.client[:]
262 view = self.client[:]
263 a = numpy.arange(64)
263 a = numpy.arange(64)
264 view['A'] = a
264 view['A'] = a
265 @interactive
265 @interactive
266 def check_writeable(x):
266 def check_writeable(x):
267 return x.flags.writeable
267 return x.flags.writeable
268
268
269 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
269 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
270 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
270 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
271
271
272 view.push(dict(B=a))
272 view.push(dict(B=a))
273 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
273 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
274 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
274 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
275
275
276 @skip_without('numpy')
276 @skip_without('numpy')
277 def test_apply_numpy(self):
277 def test_apply_numpy(self):
278 """view.apply(f, ndarray)"""
278 """view.apply(f, ndarray)"""
279 import numpy
279 import numpy
280 from numpy.testing.utils import assert_array_equal
280 from numpy.testing.utils import assert_array_equal
281
281
282 A = numpy.random.random((100,100))
282 A = numpy.random.random((100,100))
283 view = self.client[-1]
283 view = self.client[-1]
284 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
284 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
285 B = A.astype(dt)
285 B = A.astype(dt)
286 C = view.apply_sync(lambda x:x, B)
286 C = view.apply_sync(lambda x:x, B)
287 assert_array_equal(B,C)
287 assert_array_equal(B,C)
288
288
289 @skip_without('numpy')
289 @skip_without('numpy')
290 def test_apply_numpy_object_dtype(self):
291 """view.apply(f, ndarray) with dtype=object"""
292 import numpy
293 from numpy.testing.utils import assert_array_equal
294 view = self.client[-1]
295
296 A = numpy.array([dict(a=5)])
297 B = view.apply_sync(lambda x:x, A)
298 assert_array_equal(A,B)
299
300 A = numpy.array([(0, dict(b=10))], dtype=[('i', int), ('o', object)])
301 B = view.apply_sync(lambda x:x, A)
302 assert_array_equal(A,B)
303
304 @skip_without('numpy')
290 def test_push_pull_recarray(self):
305 def test_push_pull_recarray(self):
291 """push/pull recarrays"""
306 """push/pull recarrays"""
292 import numpy
307 import numpy
293 from numpy.testing.utils import assert_array_equal
308 from numpy.testing.utils import assert_array_equal
294
309
295 view = self.client[-1]
310 view = self.client[-1]
296
311
297 R = numpy.array([
312 R = numpy.array([
298 (1, 'hi', 0.),
313 (1, 'hi', 0.),
299 (2**30, 'there', 2.5),
314 (2**30, 'there', 2.5),
300 (-99999, 'world', -12345.6789),
315 (-99999, 'world', -12345.6789),
301 ], [('n', int), ('s', '|S10'), ('f', float)])
316 ], [('n', int), ('s', '|S10'), ('f', float)])
302
317
303 view['RR'] = R
318 view['RR'] = R
304 R2 = view['RR']
319 R2 = view['RR']
305
320
306 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
321 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
307 self.assertEqual(r_dtype, R.dtype)
322 self.assertEqual(r_dtype, R.dtype)
308 self.assertEqual(r_shape, R.shape)
323 self.assertEqual(r_shape, R.shape)
309 self.assertEqual(R2.dtype, R.dtype)
324 self.assertEqual(R2.dtype, R.dtype)
310 self.assertEqual(R2.shape, R.shape)
325 self.assertEqual(R2.shape, R.shape)
311 assert_array_equal(R2, R)
326 assert_array_equal(R2, R)
312
327
313 @skip_without('pandas')
328 @skip_without('pandas')
314 def test_push_pull_timeseries(self):
329 def test_push_pull_timeseries(self):
315 """push/pull pandas.TimeSeries"""
330 """push/pull pandas.TimeSeries"""
316 import pandas
331 import pandas
317
332
318 ts = pandas.TimeSeries(list(range(10)))
333 ts = pandas.TimeSeries(list(range(10)))
319
334
320 view = self.client[-1]
335 view = self.client[-1]
321
336
322 view.push(dict(ts=ts), block=True)
337 view.push(dict(ts=ts), block=True)
323 rts = view['ts']
338 rts = view['ts']
324
339
325 self.assertEqual(type(rts), type(ts))
340 self.assertEqual(type(rts), type(ts))
326 self.assertTrue((ts == rts).all())
341 self.assertTrue((ts == rts).all())
327
342
328 def test_map(self):
343 def test_map(self):
329 view = self.client[:]
344 view = self.client[:]
330 def f(x):
345 def f(x):
331 return x**2
346 return x**2
332 data = list(range(16))
347 data = list(range(16))
333 r = view.map_sync(f, data)
348 r = view.map_sync(f, data)
334 self.assertEqual(r, list(map(f, data)))
349 self.assertEqual(r, list(map(f, data)))
335
350
336 def test_map_iterable(self):
351 def test_map_iterable(self):
337 """test map on iterables (direct)"""
352 """test map on iterables (direct)"""
338 view = self.client[:]
353 view = self.client[:]
339 # 101 is prime, so it won't be evenly distributed
354 # 101 is prime, so it won't be evenly distributed
340 arr = range(101)
355 arr = range(101)
341 # ensure it will be an iterator, even in Python 3
356 # ensure it will be an iterator, even in Python 3
342 it = iter(arr)
357 it = iter(arr)
343 r = view.map_sync(lambda x: x, it)
358 r = view.map_sync(lambda x: x, it)
344 self.assertEqual(r, list(arr))
359 self.assertEqual(r, list(arr))
345
360
346 @skip_without('numpy')
361 @skip_without('numpy')
347 def test_map_numpy(self):
362 def test_map_numpy(self):
348 """test map on numpy arrays (direct)"""
363 """test map on numpy arrays (direct)"""
349 import numpy
364 import numpy
350 from numpy.testing.utils import assert_array_equal
365 from numpy.testing.utils import assert_array_equal
351
366
352 view = self.client[:]
367 view = self.client[:]
353 # 101 is prime, so it won't be evenly distributed
368 # 101 is prime, so it won't be evenly distributed
354 arr = numpy.arange(101)
369 arr = numpy.arange(101)
355 r = view.map_sync(lambda x: x, arr)
370 r = view.map_sync(lambda x: x, arr)
356 assert_array_equal(r, arr)
371 assert_array_equal(r, arr)
357
372
358 def test_scatter_gather_nonblocking(self):
373 def test_scatter_gather_nonblocking(self):
359 data = list(range(16))
374 data = list(range(16))
360 view = self.client[:]
375 view = self.client[:]
361 view.scatter('a', data, block=False)
376 view.scatter('a', data, block=False)
362 ar = view.gather('a', block=False)
377 ar = view.gather('a', block=False)
363 self.assertEqual(ar.get(), data)
378 self.assertEqual(ar.get(), data)
364
379
365 @skip_without('numpy')
380 @skip_without('numpy')
366 def test_scatter_gather_numpy_nonblocking(self):
381 def test_scatter_gather_numpy_nonblocking(self):
367 import numpy
382 import numpy
368 from numpy.testing.utils import assert_array_equal
383 from numpy.testing.utils import assert_array_equal
369 a = numpy.arange(64)
384 a = numpy.arange(64)
370 view = self.client[:]
385 view = self.client[:]
371 ar = view.scatter('a', a, block=False)
386 ar = view.scatter('a', a, block=False)
372 self.assertTrue(isinstance(ar, AsyncResult))
387 self.assertTrue(isinstance(ar, AsyncResult))
373 amr = view.gather('a', block=False)
388 amr = view.gather('a', block=False)
374 self.assertTrue(isinstance(amr, AsyncMapResult))
389 self.assertTrue(isinstance(amr, AsyncMapResult))
375 assert_array_equal(amr.get(), a)
390 assert_array_equal(amr.get(), a)
376
391
377 def test_execute(self):
392 def test_execute(self):
378 view = self.client[:]
393 view = self.client[:]
379 # self.client.debug=True
394 # self.client.debug=True
380 execute = view.execute
395 execute = view.execute
381 ar = execute('c=30', block=False)
396 ar = execute('c=30', block=False)
382 self.assertTrue(isinstance(ar, AsyncResult))
397 self.assertTrue(isinstance(ar, AsyncResult))
383 ar = execute('d=[0,1,2]', block=False)
398 ar = execute('d=[0,1,2]', block=False)
384 self.client.wait(ar, 1)
399 self.client.wait(ar, 1)
385 self.assertEqual(len(ar.get()), len(self.client))
400 self.assertEqual(len(ar.get()), len(self.client))
386 for c in view['c']:
401 for c in view['c']:
387 self.assertEqual(c, 30)
402 self.assertEqual(c, 30)
388
403
389 def test_abort(self):
404 def test_abort(self):
390 view = self.client[-1]
405 view = self.client[-1]
391 ar = view.execute('import time; time.sleep(1)', block=False)
406 ar = view.execute('import time; time.sleep(1)', block=False)
392 ar2 = view.apply_async(lambda : 2)
407 ar2 = view.apply_async(lambda : 2)
393 ar3 = view.apply_async(lambda : 3)
408 ar3 = view.apply_async(lambda : 3)
394 view.abort(ar2)
409 view.abort(ar2)
395 view.abort(ar3.msg_ids)
410 view.abort(ar3.msg_ids)
396 self.assertRaises(error.TaskAborted, ar2.get)
411 self.assertRaises(error.TaskAborted, ar2.get)
397 self.assertRaises(error.TaskAborted, ar3.get)
412 self.assertRaises(error.TaskAborted, ar3.get)
398
413
399 def test_abort_all(self):
414 def test_abort_all(self):
400 """view.abort() aborts all outstanding tasks"""
415 """view.abort() aborts all outstanding tasks"""
401 view = self.client[-1]
416 view = self.client[-1]
402 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
417 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
403 view.abort()
418 view.abort()
404 view.wait(timeout=5)
419 view.wait(timeout=5)
405 for ar in ars[5:]:
420 for ar in ars[5:]:
406 self.assertRaises(error.TaskAborted, ar.get)
421 self.assertRaises(error.TaskAborted, ar.get)
407
422
408 def test_temp_flags(self):
423 def test_temp_flags(self):
409 view = self.client[-1]
424 view = self.client[-1]
410 view.block=True
425 view.block=True
411 with view.temp_flags(block=False):
426 with view.temp_flags(block=False):
412 self.assertFalse(view.block)
427 self.assertFalse(view.block)
413 self.assertTrue(view.block)
428 self.assertTrue(view.block)
414
429
415 @dec.known_failure_py3
430 @dec.known_failure_py3
416 def test_importer(self):
431 def test_importer(self):
417 view = self.client[-1]
432 view = self.client[-1]
418 view.clear(block=True)
433 view.clear(block=True)
419 with view.importer:
434 with view.importer:
420 import re
435 import re
421
436
422 @interactive
437 @interactive
423 def findall(pat, s):
438 def findall(pat, s):
424 # this globals() step isn't necessary in real code
439 # this globals() step isn't necessary in real code
425 # only to prevent a closure in the test
440 # only to prevent a closure in the test
426 re = globals()['re']
441 re = globals()['re']
427 return re.findall(pat, s)
442 return re.findall(pat, s)
428
443
429 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
444 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
430
445
431 def test_unicode_execute(self):
446 def test_unicode_execute(self):
432 """test executing unicode strings"""
447 """test executing unicode strings"""
433 v = self.client[-1]
448 v = self.client[-1]
434 v.block=True
449 v.block=True
435 if sys.version_info[0] >= 3:
450 if sys.version_info[0] >= 3:
436 code="a='é'"
451 code="a='é'"
437 else:
452 else:
438 code=u"a=u'é'"
453 code=u"a=u'é'"
439 v.execute(code)
454 v.execute(code)
440 self.assertEqual(v['a'], u'é')
455 self.assertEqual(v['a'], u'é')
441
456
442 def test_unicode_apply_result(self):
457 def test_unicode_apply_result(self):
443 """test unicode apply results"""
458 """test unicode apply results"""
444 v = self.client[-1]
459 v = self.client[-1]
445 r = v.apply_sync(lambda : u'é')
460 r = v.apply_sync(lambda : u'é')
446 self.assertEqual(r, u'é')
461 self.assertEqual(r, u'é')
447
462
448 def test_unicode_apply_arg(self):
463 def test_unicode_apply_arg(self):
449 """test passing unicode arguments to apply"""
464 """test passing unicode arguments to apply"""
450 v = self.client[-1]
465 v = self.client[-1]
451
466
452 @interactive
467 @interactive
453 def check_unicode(a, check):
468 def check_unicode(a, check):
454 assert not isinstance(a, bytes), "%r is bytes, not unicode"%a
469 assert not isinstance(a, bytes), "%r is bytes, not unicode"%a
455 assert isinstance(check, bytes), "%r is not bytes"%check
470 assert isinstance(check, bytes), "%r is not bytes"%check
456 assert a.encode('utf8') == check, "%s != %s"%(a,check)
471 assert a.encode('utf8') == check, "%s != %s"%(a,check)
457
472
458 for s in [ u'é', u'ßø®∫',u'asdf' ]:
473 for s in [ u'é', u'ßø®∫',u'asdf' ]:
459 try:
474 try:
460 v.apply_sync(check_unicode, s, s.encode('utf8'))
475 v.apply_sync(check_unicode, s, s.encode('utf8'))
461 except error.RemoteError as e:
476 except error.RemoteError as e:
462 if e.ename == 'AssertionError':
477 if e.ename == 'AssertionError':
463 self.fail(e.evalue)
478 self.fail(e.evalue)
464 else:
479 else:
465 raise e
480 raise e
466
481
467 def test_map_reference(self):
482 def test_map_reference(self):
468 """view.map(<Reference>, *seqs) should work"""
483 """view.map(<Reference>, *seqs) should work"""
469 v = self.client[:]
484 v = self.client[:]
470 v.scatter('n', self.client.ids, flatten=True)
485 v.scatter('n', self.client.ids, flatten=True)
471 v.execute("f = lambda x,y: x*y")
486 v.execute("f = lambda x,y: x*y")
472 rf = pmod.Reference('f')
487 rf = pmod.Reference('f')
473 nlist = list(range(10))
488 nlist = list(range(10))
474 mlist = nlist[::-1]
489 mlist = nlist[::-1]
475 expected = [ m*n for m,n in zip(mlist, nlist) ]
490 expected = [ m*n for m,n in zip(mlist, nlist) ]
476 result = v.map_sync(rf, mlist, nlist)
491 result = v.map_sync(rf, mlist, nlist)
477 self.assertEqual(result, expected)
492 self.assertEqual(result, expected)
478
493
479 def test_apply_reference(self):
494 def test_apply_reference(self):
480 """view.apply(<Reference>, *args) should work"""
495 """view.apply(<Reference>, *args) should work"""
481 v = self.client[:]
496 v = self.client[:]
482 v.scatter('n', self.client.ids, flatten=True)
497 v.scatter('n', self.client.ids, flatten=True)
483 v.execute("f = lambda x: n*x")
498 v.execute("f = lambda x: n*x")
484 rf = pmod.Reference('f')
499 rf = pmod.Reference('f')
485 result = v.apply_sync(rf, 5)
500 result = v.apply_sync(rf, 5)
486 expected = [ 5*id for id in self.client.ids ]
501 expected = [ 5*id for id in self.client.ids ]
487 self.assertEqual(result, expected)
502 self.assertEqual(result, expected)
488
503
489 def test_eval_reference(self):
504 def test_eval_reference(self):
490 v = self.client[self.client.ids[0]]
505 v = self.client[self.client.ids[0]]
491 v['g'] = list(range(5))
506 v['g'] = list(range(5))
492 rg = pmod.Reference('g[0]')
507 rg = pmod.Reference('g[0]')
493 echo = lambda x:x
508 echo = lambda x:x
494 self.assertEqual(v.apply_sync(echo, rg), 0)
509 self.assertEqual(v.apply_sync(echo, rg), 0)
495
510
496 def test_reference_nameerror(self):
511 def test_reference_nameerror(self):
497 v = self.client[self.client.ids[0]]
512 v = self.client[self.client.ids[0]]
498 r = pmod.Reference('elvis_has_left')
513 r = pmod.Reference('elvis_has_left')
499 echo = lambda x:x
514 echo = lambda x:x
500 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
515 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
501
516
502 def test_single_engine_map(self):
517 def test_single_engine_map(self):
503 e0 = self.client[self.client.ids[0]]
518 e0 = self.client[self.client.ids[0]]
504 r = list(range(5))
519 r = list(range(5))
505 check = [ -1*i for i in r ]
520 check = [ -1*i for i in r ]
506 result = e0.map_sync(lambda x: -1*x, r)
521 result = e0.map_sync(lambda x: -1*x, r)
507 self.assertEqual(result, check)
522 self.assertEqual(result, check)
508
523
509 def test_len(self):
524 def test_len(self):
510 """len(view) makes sense"""
525 """len(view) makes sense"""
511 e0 = self.client[self.client.ids[0]]
526 e0 = self.client[self.client.ids[0]]
512 self.assertEqual(len(e0), 1)
527 self.assertEqual(len(e0), 1)
513 v = self.client[:]
528 v = self.client[:]
514 self.assertEqual(len(v), len(self.client.ids))
529 self.assertEqual(len(v), len(self.client.ids))
515 v = self.client.direct_view('all')
530 v = self.client.direct_view('all')
516 self.assertEqual(len(v), len(self.client.ids))
531 self.assertEqual(len(v), len(self.client.ids))
517 v = self.client[:2]
532 v = self.client[:2]
518 self.assertEqual(len(v), 2)
533 self.assertEqual(len(v), 2)
519 v = self.client[:1]
534 v = self.client[:1]
520 self.assertEqual(len(v), 1)
535 self.assertEqual(len(v), 1)
521 v = self.client.load_balanced_view()
536 v = self.client.load_balanced_view()
522 self.assertEqual(len(v), len(self.client.ids))
537 self.assertEqual(len(v), len(self.client.ids))
523
538
524
539
525 # begin execute tests
540 # begin execute tests
526
541
527 def test_execute_reply(self):
542 def test_execute_reply(self):
528 e0 = self.client[self.client.ids[0]]
543 e0 = self.client[self.client.ids[0]]
529 e0.block = True
544 e0.block = True
530 ar = e0.execute("5", silent=False)
545 ar = e0.execute("5", silent=False)
531 er = ar.get()
546 er = ar.get()
532 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
547 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
533 self.assertEqual(er.pyout['data']['text/plain'], '5')
548 self.assertEqual(er.pyout['data']['text/plain'], '5')
534
549
535 def test_execute_reply_rich(self):
550 def test_execute_reply_rich(self):
536 e0 = self.client[self.client.ids[0]]
551 e0 = self.client[self.client.ids[0]]
537 e0.block = True
552 e0.block = True
538 e0.execute("from IPython.display import Image, HTML")
553 e0.execute("from IPython.display import Image, HTML")
539 ar = e0.execute("Image(data=b'garbage', format='png', width=10)", silent=False)
554 ar = e0.execute("Image(data=b'garbage', format='png', width=10)", silent=False)
540 er = ar.get()
555 er = ar.get()
541 b64data = base64.encodestring(b'garbage').decode('ascii')
556 b64data = base64.encodestring(b'garbage').decode('ascii')
542 self.assertEqual(er._repr_png_(), (b64data, dict(width=10)))
557 self.assertEqual(er._repr_png_(), (b64data, dict(width=10)))
543 ar = e0.execute("HTML('<b>bold</b>')", silent=False)
558 ar = e0.execute("HTML('<b>bold</b>')", silent=False)
544 er = ar.get()
559 er = ar.get()
545 self.assertEqual(er._repr_html_(), "<b>bold</b>")
560 self.assertEqual(er._repr_html_(), "<b>bold</b>")
546
561
547 def test_execute_reply_stdout(self):
562 def test_execute_reply_stdout(self):
548 e0 = self.client[self.client.ids[0]]
563 e0 = self.client[self.client.ids[0]]
549 e0.block = True
564 e0.block = True
550 ar = e0.execute("print (5)", silent=False)
565 ar = e0.execute("print (5)", silent=False)
551 er = ar.get()
566 er = ar.get()
552 self.assertEqual(er.stdout.strip(), '5')
567 self.assertEqual(er.stdout.strip(), '5')
553
568
554 def test_execute_pyout(self):
569 def test_execute_pyout(self):
555 """execute triggers pyout with silent=False"""
570 """execute triggers pyout with silent=False"""
556 view = self.client[:]
571 view = self.client[:]
557 ar = view.execute("5", silent=False, block=True)
572 ar = view.execute("5", silent=False, block=True)
558
573
559 expected = [{'text/plain' : '5'}] * len(view)
574 expected = [{'text/plain' : '5'}] * len(view)
560 mimes = [ out['data'] for out in ar.pyout ]
575 mimes = [ out['data'] for out in ar.pyout ]
561 self.assertEqual(mimes, expected)
576 self.assertEqual(mimes, expected)
562
577
563 def test_execute_silent(self):
578 def test_execute_silent(self):
564 """execute does not trigger pyout with silent=True"""
579 """execute does not trigger pyout with silent=True"""
565 view = self.client[:]
580 view = self.client[:]
566 ar = view.execute("5", block=True)
581 ar = view.execute("5", block=True)
567 expected = [None] * len(view)
582 expected = [None] * len(view)
568 self.assertEqual(ar.pyout, expected)
583 self.assertEqual(ar.pyout, expected)
569
584
570 def test_execute_magic(self):
585 def test_execute_magic(self):
571 """execute accepts IPython commands"""
586 """execute accepts IPython commands"""
572 view = self.client[:]
587 view = self.client[:]
573 view.execute("a = 5")
588 view.execute("a = 5")
574 ar = view.execute("%whos", block=True)
589 ar = view.execute("%whos", block=True)
575 # this will raise, if that failed
590 # this will raise, if that failed
576 ar.get(5)
591 ar.get(5)
577 for stdout in ar.stdout:
592 for stdout in ar.stdout:
578 lines = stdout.splitlines()
593 lines = stdout.splitlines()
579 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
594 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
580 found = False
595 found = False
581 for line in lines[2:]:
596 for line in lines[2:]:
582 split = line.split()
597 split = line.split()
583 if split == ['a', 'int', '5']:
598 if split == ['a', 'int', '5']:
584 found = True
599 found = True
585 break
600 break
586 self.assertTrue(found, "whos output wrong: %s" % stdout)
601 self.assertTrue(found, "whos output wrong: %s" % stdout)
587
602
588 def test_execute_displaypub(self):
603 def test_execute_displaypub(self):
589 """execute tracks display_pub output"""
604 """execute tracks display_pub output"""
590 view = self.client[:]
605 view = self.client[:]
591 view.execute("from IPython.core.display import *")
606 view.execute("from IPython.core.display import *")
592 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
607 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
593
608
594 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
609 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
595 for outputs in ar.outputs:
610 for outputs in ar.outputs:
596 mimes = [ out['data'] for out in outputs ]
611 mimes = [ out['data'] for out in outputs ]
597 self.assertEqual(mimes, expected)
612 self.assertEqual(mimes, expected)
598
613
599 def test_apply_displaypub(self):
614 def test_apply_displaypub(self):
600 """apply tracks display_pub output"""
615 """apply tracks display_pub output"""
601 view = self.client[:]
616 view = self.client[:]
602 view.execute("from IPython.core.display import *")
617 view.execute("from IPython.core.display import *")
603
618
604 @interactive
619 @interactive
605 def publish():
620 def publish():
606 [ display(i) for i in range(5) ]
621 [ display(i) for i in range(5) ]
607
622
608 ar = view.apply_async(publish)
623 ar = view.apply_async(publish)
609 ar.get(5)
624 ar.get(5)
610 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
625 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
611 for outputs in ar.outputs:
626 for outputs in ar.outputs:
612 mimes = [ out['data'] for out in outputs ]
627 mimes = [ out['data'] for out in outputs ]
613 self.assertEqual(mimes, expected)
628 self.assertEqual(mimes, expected)
614
629
615 def test_execute_raises(self):
630 def test_execute_raises(self):
616 """exceptions in execute requests raise appropriately"""
631 """exceptions in execute requests raise appropriately"""
617 view = self.client[-1]
632 view = self.client[-1]
618 ar = view.execute("1/0")
633 ar = view.execute("1/0")
619 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
634 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
620
635
621 def test_remoteerror_render_exception(self):
636 def test_remoteerror_render_exception(self):
622 """RemoteErrors get nice tracebacks"""
637 """RemoteErrors get nice tracebacks"""
623 view = self.client[-1]
638 view = self.client[-1]
624 ar = view.execute("1/0")
639 ar = view.execute("1/0")
625 ip = get_ipython()
640 ip = get_ipython()
626 ip.user_ns['ar'] = ar
641 ip.user_ns['ar'] = ar
627 with capture_output() as io:
642 with capture_output() as io:
628 ip.run_cell("ar.get(2)")
643 ip.run_cell("ar.get(2)")
629
644
630 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
645 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
631
646
632 def test_compositeerror_render_exception(self):
647 def test_compositeerror_render_exception(self):
633 """CompositeErrors get nice tracebacks"""
648 """CompositeErrors get nice tracebacks"""
634 view = self.client[:]
649 view = self.client[:]
635 ar = view.execute("1/0")
650 ar = view.execute("1/0")
636 ip = get_ipython()
651 ip = get_ipython()
637 ip.user_ns['ar'] = ar
652 ip.user_ns['ar'] = ar
638
653
639 with capture_output() as io:
654 with capture_output() as io:
640 ip.run_cell("ar.get(2)")
655 ip.run_cell("ar.get(2)")
641
656
642 count = min(error.CompositeError.tb_limit, len(view))
657 count = min(error.CompositeError.tb_limit, len(view))
643
658
644 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
659 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
645 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
660 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
646 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
661 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
647
662
648 def test_compositeerror_truncate(self):
663 def test_compositeerror_truncate(self):
649 """Truncate CompositeErrors with many exceptions"""
664 """Truncate CompositeErrors with many exceptions"""
650 view = self.client[:]
665 view = self.client[:]
651 msg_ids = []
666 msg_ids = []
652 for i in range(10):
667 for i in range(10):
653 ar = view.execute("1/0")
668 ar = view.execute("1/0")
654 msg_ids.extend(ar.msg_ids)
669 msg_ids.extend(ar.msg_ids)
655
670
656 ar = self.client.get_result(msg_ids)
671 ar = self.client.get_result(msg_ids)
657 try:
672 try:
658 ar.get()
673 ar.get()
659 except error.CompositeError as _e:
674 except error.CompositeError as _e:
660 e = _e
675 e = _e
661 else:
676 else:
662 self.fail("Should have raised CompositeError")
677 self.fail("Should have raised CompositeError")
663
678
664 lines = e.render_traceback()
679 lines = e.render_traceback()
665 with capture_output() as io:
680 with capture_output() as io:
666 e.print_traceback()
681 e.print_traceback()
667
682
668 self.assertTrue("more exceptions" in lines[-1])
683 self.assertTrue("more exceptions" in lines[-1])
669 count = e.tb_limit
684 count = e.tb_limit
670
685
671 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
686 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
672 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
687 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
673 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
688 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
674
689
675 @dec.skipif_not_matplotlib
690 @dec.skipif_not_matplotlib
676 def test_magic_pylab(self):
691 def test_magic_pylab(self):
677 """%pylab works on engines"""
692 """%pylab works on engines"""
678 view = self.client[-1]
693 view = self.client[-1]
679 ar = view.execute("%pylab inline")
694 ar = view.execute("%pylab inline")
680 # at least check if this raised:
695 # at least check if this raised:
681 reply = ar.get(5)
696 reply = ar.get(5)
682 # include imports, in case user config
697 # include imports, in case user config
683 ar = view.execute("plot(rand(100))", silent=False)
698 ar = view.execute("plot(rand(100))", silent=False)
684 reply = ar.get(5)
699 reply = ar.get(5)
685 self.assertEqual(len(reply.outputs), 1)
700 self.assertEqual(len(reply.outputs), 1)
686 output = reply.outputs[0]
701 output = reply.outputs[0]
687 self.assertTrue("data" in output)
702 self.assertTrue("data" in output)
688 data = output['data']
703 data = output['data']
689 self.assertTrue("image/png" in data)
704 self.assertTrue("image/png" in data)
690
705
691 def test_func_default_func(self):
706 def test_func_default_func(self):
692 """interactively defined function as apply func default"""
707 """interactively defined function as apply func default"""
693 def foo():
708 def foo():
694 return 'foo'
709 return 'foo'
695
710
696 def bar(f=foo):
711 def bar(f=foo):
697 return f()
712 return f()
698
713
699 view = self.client[-1]
714 view = self.client[-1]
700 ar = view.apply_async(bar)
715 ar = view.apply_async(bar)
701 r = ar.get(10)
716 r = ar.get(10)
702 self.assertEqual(r, 'foo')
717 self.assertEqual(r, 'foo')
703 def test_data_pub_single(self):
718 def test_data_pub_single(self):
704 view = self.client[-1]
719 view = self.client[-1]
705 ar = view.execute('\n'.join([
720 ar = view.execute('\n'.join([
706 'from IPython.kernel.zmq.datapub import publish_data',
721 'from IPython.kernel.zmq.datapub import publish_data',
707 'for i in range(5):',
722 'for i in range(5):',
708 ' publish_data(dict(i=i))'
723 ' publish_data(dict(i=i))'
709 ]), block=False)
724 ]), block=False)
710 self.assertTrue(isinstance(ar.data, dict))
725 self.assertTrue(isinstance(ar.data, dict))
711 ar.get(5)
726 ar.get(5)
712 self.assertEqual(ar.data, dict(i=4))
727 self.assertEqual(ar.data, dict(i=4))
713
728
714 def test_data_pub(self):
729 def test_data_pub(self):
715 view = self.client[:]
730 view = self.client[:]
716 ar = view.execute('\n'.join([
731 ar = view.execute('\n'.join([
717 'from IPython.kernel.zmq.datapub import publish_data',
732 'from IPython.kernel.zmq.datapub import publish_data',
718 'for i in range(5):',
733 'for i in range(5):',
719 ' publish_data(dict(i=i))'
734 ' publish_data(dict(i=i))'
720 ]), block=False)
735 ]), block=False)
721 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
736 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
722 ar.get(5)
737 ar.get(5)
723 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
738 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
724
739
725 def test_can_list_arg(self):
740 def test_can_list_arg(self):
726 """args in lists are canned"""
741 """args in lists are canned"""
727 view = self.client[-1]
742 view = self.client[-1]
728 view['a'] = 128
743 view['a'] = 128
729 rA = pmod.Reference('a')
744 rA = pmod.Reference('a')
730 ar = view.apply_async(lambda x: x, [rA])
745 ar = view.apply_async(lambda x: x, [rA])
731 r = ar.get(5)
746 r = ar.get(5)
732 self.assertEqual(r, [128])
747 self.assertEqual(r, [128])
733
748
734 def test_can_dict_arg(self):
749 def test_can_dict_arg(self):
735 """args in dicts are canned"""
750 """args in dicts are canned"""
736 view = self.client[-1]
751 view = self.client[-1]
737 view['a'] = 128
752 view['a'] = 128
738 rA = pmod.Reference('a')
753 rA = pmod.Reference('a')
739 ar = view.apply_async(lambda x: x, dict(foo=rA))
754 ar = view.apply_async(lambda x: x, dict(foo=rA))
740 r = ar.get(5)
755 r = ar.get(5)
741 self.assertEqual(r, dict(foo=128))
756 self.assertEqual(r, dict(foo=128))
742
757
743 def test_can_list_kwarg(self):
758 def test_can_list_kwarg(self):
744 """kwargs in lists are canned"""
759 """kwargs in lists are canned"""
745 view = self.client[-1]
760 view = self.client[-1]
746 view['a'] = 128
761 view['a'] = 128
747 rA = pmod.Reference('a')
762 rA = pmod.Reference('a')
748 ar = view.apply_async(lambda x=5: x, x=[rA])
763 ar = view.apply_async(lambda x=5: x, x=[rA])
749 r = ar.get(5)
764 r = ar.get(5)
750 self.assertEqual(r, [128])
765 self.assertEqual(r, [128])
751
766
752 def test_can_dict_kwarg(self):
767 def test_can_dict_kwarg(self):
753 """kwargs in dicts are canned"""
768 """kwargs in dicts are canned"""
754 view = self.client[-1]
769 view = self.client[-1]
755 view['a'] = 128
770 view['a'] = 128
756 rA = pmod.Reference('a')
771 rA = pmod.Reference('a')
757 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
772 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
758 r = ar.get(5)
773 r = ar.get(5)
759 self.assertEqual(r, dict(foo=128))
774 self.assertEqual(r, dict(foo=128))
760
775
761 def test_map_ref(self):
776 def test_map_ref(self):
762 """view.map works with references"""
777 """view.map works with references"""
763 view = self.client[:]
778 view = self.client[:]
764 ranks = sorted(self.client.ids)
779 ranks = sorted(self.client.ids)
765 view.scatter('rank', ranks, flatten=True)
780 view.scatter('rank', ranks, flatten=True)
766 rrank = pmod.Reference('rank')
781 rrank = pmod.Reference('rank')
767
782
768 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
783 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
769 drank = amr.get(5)
784 drank = amr.get(5)
770 self.assertEqual(drank, [ r*2 for r in ranks ])
785 self.assertEqual(drank, [ r*2 for r in ranks ])
771
786
772 def test_nested_getitem_setitem(self):
787 def test_nested_getitem_setitem(self):
773 """get and set with view['a.b']"""
788 """get and set with view['a.b']"""
774 view = self.client[-1]
789 view = self.client[-1]
775 view.execute('\n'.join([
790 view.execute('\n'.join([
776 'class A(object): pass',
791 'class A(object): pass',
777 'a = A()',
792 'a = A()',
778 'a.b = 128',
793 'a.b = 128',
779 ]), block=True)
794 ]), block=True)
780 ra = pmod.Reference('a')
795 ra = pmod.Reference('a')
781
796
782 r = view.apply_sync(lambda x: x.b, ra)
797 r = view.apply_sync(lambda x: x.b, ra)
783 self.assertEqual(r, 128)
798 self.assertEqual(r, 128)
784 self.assertEqual(view['a.b'], 128)
799 self.assertEqual(view['a.b'], 128)
785
800
786 view['a.b'] = 0
801 view['a.b'] = 0
787
802
788 r = view.apply_sync(lambda x: x.b, ra)
803 r = view.apply_sync(lambda x: x.b, ra)
789 self.assertEqual(r, 0)
804 self.assertEqual(r, 0)
790 self.assertEqual(view['a.b'], 0)
805 self.assertEqual(view['a.b'], 0)
791
806
792 def test_return_namedtuple(self):
807 def test_return_namedtuple(self):
793 def namedtuplify(x, y):
808 def namedtuplify(x, y):
794 from IPython.parallel.tests.test_view import point
809 from IPython.parallel.tests.test_view import point
795 return point(x, y)
810 return point(x, y)
796
811
797 view = self.client[-1]
812 view = self.client[-1]
798 p = view.apply_sync(namedtuplify, 1, 2)
813 p = view.apply_sync(namedtuplify, 1, 2)
799 self.assertEqual(p.x, 1)
814 self.assertEqual(p.x, 1)
800 self.assertEqual(p.y, 2)
815 self.assertEqual(p.y, 2)
801
816
802 def test_apply_namedtuple(self):
817 def test_apply_namedtuple(self):
803 def echoxy(p):
818 def echoxy(p):
804 return p.y, p.x
819 return p.y, p.x
805
820
806 view = self.client[-1]
821 view = self.client[-1]
807 tup = view.apply_sync(echoxy, point(1, 2))
822 tup = view.apply_sync(echoxy, point(1, 2))
808 self.assertEqual(tup, (2,1))
823 self.assertEqual(tup, (2,1))
809
824
810 def test_sync_imports(self):
825 def test_sync_imports(self):
811 view = self.client[-1]
826 view = self.client[-1]
812 with capture_output() as io:
827 with capture_output() as io:
813 with view.sync_imports():
828 with view.sync_imports():
814 import IPython
829 import IPython
815 self.assertIn("IPython", io.stdout)
830 self.assertIn("IPython", io.stdout)
816
831
817 @interactive
832 @interactive
818 def find_ipython():
833 def find_ipython():
819 return 'IPython' in globals()
834 return 'IPython' in globals()
820
835
821 assert view.apply_sync(find_ipython)
836 assert view.apply_sync(find_ipython)
822
837
823 def test_sync_imports_quiet(self):
838 def test_sync_imports_quiet(self):
824 view = self.client[-1]
839 view = self.client[-1]
825 with capture_output() as io:
840 with capture_output() as io:
826 with view.sync_imports(quiet=True):
841 with view.sync_imports(quiet=True):
827 import IPython
842 import IPython
828 self.assertEqual(io.stdout, '')
843 self.assertEqual(io.stdout, '')
829
844
830 @interactive
845 @interactive
831 def find_ipython():
846 def find_ipython():
832 return 'IPython' in globals()
847 return 'IPython' in globals()
833
848
834 assert view.apply_sync(find_ipython)
849 assert view.apply_sync(find_ipython)
835
850
@@ -1,382 +1,390 b''
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """Pickle related utilities. Perhaps this should be called 'can'."""
3 """Pickle related utilities. Perhaps this should be called 'can'."""
4
4
5 __docformat__ = "restructuredtext en"
5 __docformat__ = "restructuredtext en"
6
6
7 #-------------------------------------------------------------------------------
7 #-------------------------------------------------------------------------------
8 # Copyright (C) 2008-2011 The IPython Development Team
8 # Copyright (C) 2008-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14 #-------------------------------------------------------------------------------
14 #-------------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-------------------------------------------------------------------------------
16 #-------------------------------------------------------------------------------
17
17
18 import copy
18 import copy
19 import logging
19 import logging
20 import sys
20 import sys
21 from types import FunctionType
21 from types import FunctionType
22
22
23 try:
23 try:
24 import cPickle as pickle
24 import cPickle as pickle
25 except ImportError:
25 except ImportError:
26 import pickle
26 import pickle
27
27
28 from . import codeutil # This registers a hook when it's imported
28 from . import codeutil # This registers a hook when it's imported
29 from . import py3compat
29 from . import py3compat
30 from .importstring import import_item
30 from .importstring import import_item
31 from .py3compat import string_types, iteritems
31 from .py3compat import string_types, iteritems
32
32
33 from IPython.config import Application
33 from IPython.config import Application
34
34
35 if py3compat.PY3:
35 if py3compat.PY3:
36 buffer = memoryview
36 buffer = memoryview
37 class_type = type
37 class_type = type
38 else:
38 else:
39 from types import ClassType
39 from types import ClassType
40 class_type = (type, ClassType)
40 class_type = (type, ClassType)
41
41
42 #-------------------------------------------------------------------------------
42 #-------------------------------------------------------------------------------
43 # Functions
43 # Functions
44 #-------------------------------------------------------------------------------
44 #-------------------------------------------------------------------------------
45
45
46
46
47 def use_dill():
47 def use_dill():
48 """use dill to expand serialization support
48 """use dill to expand serialization support
49
49
50 adds support for object methods and closures to serialization.
50 adds support for object methods and closures to serialization.
51 """
51 """
52 # import dill causes most of the magic
52 # import dill causes most of the magic
53 import dill
53 import dill
54
54
55 # dill doesn't work with cPickle,
55 # dill doesn't work with cPickle,
56 # tell the two relevant modules to use plain pickle
56 # tell the two relevant modules to use plain pickle
57
57
58 global pickle
58 global pickle
59 pickle = dill
59 pickle = dill
60
60
61 try:
61 try:
62 from IPython.kernel.zmq import serialize
62 from IPython.kernel.zmq import serialize
63 except ImportError:
63 except ImportError:
64 pass
64 pass
65 else:
65 else:
66 serialize.pickle = dill
66 serialize.pickle = dill
67
67
68 # disable special function handling, let dill take care of it
68 # disable special function handling, let dill take care of it
69 can_map.pop(FunctionType, None)
69 can_map.pop(FunctionType, None)
70
70
71
71
72 #-------------------------------------------------------------------------------
72 #-------------------------------------------------------------------------------
73 # Classes
73 # Classes
74 #-------------------------------------------------------------------------------
74 #-------------------------------------------------------------------------------
75
75
76
76
77 class CannedObject(object):
77 class CannedObject(object):
78 def __init__(self, obj, keys=[], hook=None):
78 def __init__(self, obj, keys=[], hook=None):
79 """can an object for safe pickling
79 """can an object for safe pickling
80
80
81 Parameters
81 Parameters
82 ==========
82 ==========
83
83
84 obj:
84 obj:
85 The object to be canned
85 The object to be canned
86 keys: list (optional)
86 keys: list (optional)
87 list of attribute names that will be explicitly canned / uncanned
87 list of attribute names that will be explicitly canned / uncanned
88 hook: callable (optional)
88 hook: callable (optional)
89 An optional extra callable,
89 An optional extra callable,
90 which can do additional processing of the uncanned object.
90 which can do additional processing of the uncanned object.
91
91
92 large data may be offloaded into the buffers list,
92 large data may be offloaded into the buffers list,
93 used for zero-copy transfers.
93 used for zero-copy transfers.
94 """
94 """
95 self.keys = keys
95 self.keys = keys
96 self.obj = copy.copy(obj)
96 self.obj = copy.copy(obj)
97 self.hook = can(hook)
97 self.hook = can(hook)
98 for key in keys:
98 for key in keys:
99 setattr(self.obj, key, can(getattr(obj, key)))
99 setattr(self.obj, key, can(getattr(obj, key)))
100
100
101 self.buffers = []
101 self.buffers = []
102
102
103 def get_object(self, g=None):
103 def get_object(self, g=None):
104 if g is None:
104 if g is None:
105 g = {}
105 g = {}
106 obj = self.obj
106 obj = self.obj
107 for key in self.keys:
107 for key in self.keys:
108 setattr(obj, key, uncan(getattr(obj, key), g))
108 setattr(obj, key, uncan(getattr(obj, key), g))
109
109
110 if self.hook:
110 if self.hook:
111 self.hook = uncan(self.hook, g)
111 self.hook = uncan(self.hook, g)
112 self.hook(obj, g)
112 self.hook(obj, g)
113 return self.obj
113 return self.obj
114
114
115
115
116 class Reference(CannedObject):
116 class Reference(CannedObject):
117 """object for wrapping a remote reference by name."""
117 """object for wrapping a remote reference by name."""
118 def __init__(self, name):
118 def __init__(self, name):
119 if not isinstance(name, string_types):
119 if not isinstance(name, string_types):
120 raise TypeError("illegal name: %r"%name)
120 raise TypeError("illegal name: %r"%name)
121 self.name = name
121 self.name = name
122 self.buffers = []
122 self.buffers = []
123
123
124 def __repr__(self):
124 def __repr__(self):
125 return "<Reference: %r>"%self.name
125 return "<Reference: %r>"%self.name
126
126
127 def get_object(self, g=None):
127 def get_object(self, g=None):
128 if g is None:
128 if g is None:
129 g = {}
129 g = {}
130
130
131 return eval(self.name, g)
131 return eval(self.name, g)
132
132
133
133
134 class CannedFunction(CannedObject):
134 class CannedFunction(CannedObject):
135
135
136 def __init__(self, f):
136 def __init__(self, f):
137 self._check_type(f)
137 self._check_type(f)
138 self.code = f.__code__
138 self.code = f.__code__
139 if f.__defaults__:
139 if f.__defaults__:
140 self.defaults = [ can(fd) for fd in f.__defaults__ ]
140 self.defaults = [ can(fd) for fd in f.__defaults__ ]
141 else:
141 else:
142 self.defaults = None
142 self.defaults = None
143 self.module = f.__module__ or '__main__'
143 self.module = f.__module__ or '__main__'
144 self.__name__ = f.__name__
144 self.__name__ = f.__name__
145 self.buffers = []
145 self.buffers = []
146
146
147 def _check_type(self, obj):
147 def _check_type(self, obj):
148 assert isinstance(obj, FunctionType), "Not a function type"
148 assert isinstance(obj, FunctionType), "Not a function type"
149
149
150 def get_object(self, g=None):
150 def get_object(self, g=None):
151 # try to load function back into its module:
151 # try to load function back into its module:
152 if not self.module.startswith('__'):
152 if not self.module.startswith('__'):
153 __import__(self.module)
153 __import__(self.module)
154 g = sys.modules[self.module].__dict__
154 g = sys.modules[self.module].__dict__
155
155
156 if g is None:
156 if g is None:
157 g = {}
157 g = {}
158 if self.defaults:
158 if self.defaults:
159 defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
159 defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
160 else:
160 else:
161 defaults = None
161 defaults = None
162 newFunc = FunctionType(self.code, g, self.__name__, defaults)
162 newFunc = FunctionType(self.code, g, self.__name__, defaults)
163 return newFunc
163 return newFunc
164
164
165 class CannedClass(CannedObject):
165 class CannedClass(CannedObject):
166
166
167 def __init__(self, cls):
167 def __init__(self, cls):
168 self._check_type(cls)
168 self._check_type(cls)
169 self.name = cls.__name__
169 self.name = cls.__name__
170 self.old_style = not isinstance(cls, type)
170 self.old_style = not isinstance(cls, type)
171 self._canned_dict = {}
171 self._canned_dict = {}
172 for k,v in cls.__dict__.items():
172 for k,v in cls.__dict__.items():
173 if k not in ('__weakref__', '__dict__'):
173 if k not in ('__weakref__', '__dict__'):
174 self._canned_dict[k] = can(v)
174 self._canned_dict[k] = can(v)
175 if self.old_style:
175 if self.old_style:
176 mro = []
176 mro = []
177 else:
177 else:
178 mro = cls.mro()
178 mro = cls.mro()
179
179
180 self.parents = [ can(c) for c in mro[1:] ]
180 self.parents = [ can(c) for c in mro[1:] ]
181 self.buffers = []
181 self.buffers = []
182
182
183 def _check_type(self, obj):
183 def _check_type(self, obj):
184 assert isinstance(obj, class_type), "Not a class type"
184 assert isinstance(obj, class_type), "Not a class type"
185
185
186 def get_object(self, g=None):
186 def get_object(self, g=None):
187 parents = tuple(uncan(p, g) for p in self.parents)
187 parents = tuple(uncan(p, g) for p in self.parents)
188 return type(self.name, parents, uncan_dict(self._canned_dict, g=g))
188 return type(self.name, parents, uncan_dict(self._canned_dict, g=g))
189
189
190 class CannedArray(CannedObject):
190 class CannedArray(CannedObject):
191 def __init__(self, obj):
191 def __init__(self, obj):
192 from numpy import ascontiguousarray
192 from numpy import ascontiguousarray
193 self.shape = obj.shape
193 self.shape = obj.shape
194 self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
194 self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
195 self.pickled = False
195 if sum(obj.shape) == 0:
196 if sum(obj.shape) == 0:
197 self.pickled = True
198 elif obj.dtype == 'O':
199 # can't handle object dtype with buffer approach
200 self.pickled = True
201 elif obj.dtype.fields and any(dt == 'O' for dt,sz in obj.dtype.fields.values()):
202 self.pickled = True
203 if self.pickled:
196 # just pickle it
204 # just pickle it
197 self.buffers = [pickle.dumps(obj, -1)]
205 self.buffers = [pickle.dumps(obj, -1)]
198 else:
206 else:
199 # ensure contiguous
207 # ensure contiguous
200 obj = ascontiguousarray(obj, dtype=None)
208 obj = ascontiguousarray(obj, dtype=None)
201 self.buffers = [buffer(obj)]
209 self.buffers = [buffer(obj)]
202
210
203 def get_object(self, g=None):
211 def get_object(self, g=None):
204 from numpy import frombuffer
212 from numpy import frombuffer
205 data = self.buffers[0]
213 data = self.buffers[0]
206 if sum(self.shape) == 0:
214 if self.pickled:
207 # no shape, we just pickled it
215 # no shape, we just pickled it
208 return pickle.loads(data)
216 return pickle.loads(data)
209 else:
217 else:
210 return frombuffer(data, dtype=self.dtype).reshape(self.shape)
218 return frombuffer(data, dtype=self.dtype).reshape(self.shape)
211
219
212
220
213 class CannedBytes(CannedObject):
221 class CannedBytes(CannedObject):
214 wrap = bytes
222 wrap = bytes
215 def __init__(self, obj):
223 def __init__(self, obj):
216 self.buffers = [obj]
224 self.buffers = [obj]
217
225
218 def get_object(self, g=None):
226 def get_object(self, g=None):
219 data = self.buffers[0]
227 data = self.buffers[0]
220 return self.wrap(data)
228 return self.wrap(data)
221
229
222 def CannedBuffer(CannedBytes):
230 def CannedBuffer(CannedBytes):
223 wrap = buffer
231 wrap = buffer
224
232
225 #-------------------------------------------------------------------------------
233 #-------------------------------------------------------------------------------
226 # Functions
234 # Functions
227 #-------------------------------------------------------------------------------
235 #-------------------------------------------------------------------------------
228
236
229 def _logger():
237 def _logger():
230 """get the logger for the current Application
238 """get the logger for the current Application
231
239
232 the root logger will be used if no Application is running
240 the root logger will be used if no Application is running
233 """
241 """
234 if Application.initialized():
242 if Application.initialized():
235 logger = Application.instance().log
243 logger = Application.instance().log
236 else:
244 else:
237 logger = logging.getLogger()
245 logger = logging.getLogger()
238 if not logger.handlers:
246 if not logger.handlers:
239 logging.basicConfig()
247 logging.basicConfig()
240
248
241 return logger
249 return logger
242
250
243 def _import_mapping(mapping, original=None):
251 def _import_mapping(mapping, original=None):
244 """import any string-keys in a type mapping
252 """import any string-keys in a type mapping
245
253
246 """
254 """
247 log = _logger()
255 log = _logger()
248 log.debug("Importing canning map")
256 log.debug("Importing canning map")
249 for key,value in list(mapping.items()):
257 for key,value in list(mapping.items()):
250 if isinstance(key, string_types):
258 if isinstance(key, string_types):
251 try:
259 try:
252 cls = import_item(key)
260 cls = import_item(key)
253 except Exception:
261 except Exception:
254 if original and key not in original:
262 if original and key not in original:
255 # only message on user-added classes
263 # only message on user-added classes
256 log.error("canning class not importable: %r", key, exc_info=True)
264 log.error("canning class not importable: %r", key, exc_info=True)
257 mapping.pop(key)
265 mapping.pop(key)
258 else:
266 else:
259 mapping[cls] = mapping.pop(key)
267 mapping[cls] = mapping.pop(key)
260
268
261 def istype(obj, check):
269 def istype(obj, check):
262 """like isinstance(obj, check), but strict
270 """like isinstance(obj, check), but strict
263
271
264 This won't catch subclasses.
272 This won't catch subclasses.
265 """
273 """
266 if isinstance(check, tuple):
274 if isinstance(check, tuple):
267 for cls in check:
275 for cls in check:
268 if type(obj) is cls:
276 if type(obj) is cls:
269 return True
277 return True
270 return False
278 return False
271 else:
279 else:
272 return type(obj) is check
280 return type(obj) is check
273
281
274 def can(obj):
282 def can(obj):
275 """prepare an object for pickling"""
283 """prepare an object for pickling"""
276
284
277 import_needed = False
285 import_needed = False
278
286
279 for cls,canner in iteritems(can_map):
287 for cls,canner in iteritems(can_map):
280 if isinstance(cls, string_types):
288 if isinstance(cls, string_types):
281 import_needed = True
289 import_needed = True
282 break
290 break
283 elif istype(obj, cls):
291 elif istype(obj, cls):
284 return canner(obj)
292 return canner(obj)
285
293
286 if import_needed:
294 if import_needed:
287 # perform can_map imports, then try again
295 # perform can_map imports, then try again
288 # this will usually only happen once
296 # this will usually only happen once
289 _import_mapping(can_map, _original_can_map)
297 _import_mapping(can_map, _original_can_map)
290 return can(obj)
298 return can(obj)
291
299
292 return obj
300 return obj
293
301
294 def can_class(obj):
302 def can_class(obj):
295 if isinstance(obj, class_type) and obj.__module__ == '__main__':
303 if isinstance(obj, class_type) and obj.__module__ == '__main__':
296 return CannedClass(obj)
304 return CannedClass(obj)
297 else:
305 else:
298 return obj
306 return obj
299
307
300 def can_dict(obj):
308 def can_dict(obj):
301 """can the *values* of a dict"""
309 """can the *values* of a dict"""
302 if istype(obj, dict):
310 if istype(obj, dict):
303 newobj = {}
311 newobj = {}
304 for k, v in iteritems(obj):
312 for k, v in iteritems(obj):
305 newobj[k] = can(v)
313 newobj[k] = can(v)
306 return newobj
314 return newobj
307 else:
315 else:
308 return obj
316 return obj
309
317
310 sequence_types = (list, tuple, set)
318 sequence_types = (list, tuple, set)
311
319
312 def can_sequence(obj):
320 def can_sequence(obj):
313 """can the elements of a sequence"""
321 """can the elements of a sequence"""
314 if istype(obj, sequence_types):
322 if istype(obj, sequence_types):
315 t = type(obj)
323 t = type(obj)
316 return t([can(i) for i in obj])
324 return t([can(i) for i in obj])
317 else:
325 else:
318 return obj
326 return obj
319
327
320 def uncan(obj, g=None):
328 def uncan(obj, g=None):
321 """invert canning"""
329 """invert canning"""
322
330
323 import_needed = False
331 import_needed = False
324 for cls,uncanner in iteritems(uncan_map):
332 for cls,uncanner in iteritems(uncan_map):
325 if isinstance(cls, string_types):
333 if isinstance(cls, string_types):
326 import_needed = True
334 import_needed = True
327 break
335 break
328 elif isinstance(obj, cls):
336 elif isinstance(obj, cls):
329 return uncanner(obj, g)
337 return uncanner(obj, g)
330
338
331 if import_needed:
339 if import_needed:
332 # perform uncan_map imports, then try again
340 # perform uncan_map imports, then try again
333 # this will usually only happen once
341 # this will usually only happen once
334 _import_mapping(uncan_map, _original_uncan_map)
342 _import_mapping(uncan_map, _original_uncan_map)
335 return uncan(obj, g)
343 return uncan(obj, g)
336
344
337 return obj
345 return obj
338
346
339 def uncan_dict(obj, g=None):
347 def uncan_dict(obj, g=None):
340 if istype(obj, dict):
348 if istype(obj, dict):
341 newobj = {}
349 newobj = {}
342 for k, v in iteritems(obj):
350 for k, v in iteritems(obj):
343 newobj[k] = uncan(v,g)
351 newobj[k] = uncan(v,g)
344 return newobj
352 return newobj
345 else:
353 else:
346 return obj
354 return obj
347
355
348 def uncan_sequence(obj, g=None):
356 def uncan_sequence(obj, g=None):
349 if istype(obj, sequence_types):
357 if istype(obj, sequence_types):
350 t = type(obj)
358 t = type(obj)
351 return t([uncan(i,g) for i in obj])
359 return t([uncan(i,g) for i in obj])
352 else:
360 else:
353 return obj
361 return obj
354
362
355 def _uncan_dependent_hook(dep, g=None):
363 def _uncan_dependent_hook(dep, g=None):
356 dep.check_dependency()
364 dep.check_dependency()
357
365
358 def can_dependent(obj):
366 def can_dependent(obj):
359 return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook)
367 return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook)
360
368
361 #-------------------------------------------------------------------------------
369 #-------------------------------------------------------------------------------
362 # API dictionaries
370 # API dictionaries
363 #-------------------------------------------------------------------------------
371 #-------------------------------------------------------------------------------
364
372
365 # These dicts can be extended for custom serialization of new objects
373 # These dicts can be extended for custom serialization of new objects
366
374
367 can_map = {
375 can_map = {
368 'IPython.parallel.dependent' : can_dependent,
376 'IPython.parallel.dependent' : can_dependent,
369 'numpy.ndarray' : CannedArray,
377 'numpy.ndarray' : CannedArray,
370 FunctionType : CannedFunction,
378 FunctionType : CannedFunction,
371 bytes : CannedBytes,
379 bytes : CannedBytes,
372 buffer : CannedBuffer,
380 buffer : CannedBuffer,
373 class_type : can_class,
381 class_type : can_class,
374 }
382 }
375
383
376 uncan_map = {
384 uncan_map = {
377 CannedObject : lambda obj, g: obj.get_object(g),
385 CannedObject : lambda obj, g: obj.get_object(g),
378 }
386 }
379
387
380 # for use in _import_mapping:
388 # for use in _import_mapping:
381 _original_can_map = can_map.copy()
389 _original_can_map = can_map.copy()
382 _original_uncan_map = uncan_map.copy()
390 _original_uncan_map = uncan_map.copy()
General Comments 0
You need to be logged in to leave comments. Login now