##// END OF EJS Templates
Fix test_map_iterable() again
Samuel Ainsworth -
Show More
@@ -1,801 +1,801 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import platform
20 import platform
21 import time
21 import time
22 from collections import namedtuple
22 from collections import namedtuple
23 from tempfile import mktemp
23 from tempfile import mktemp
24 from StringIO import StringIO
24 from StringIO import StringIO
25
25
26 import zmq
26 import zmq
27 from nose import SkipTest
27 from nose import SkipTest
28 from nose.plugins.attrib import attr
28 from nose.plugins.attrib import attr
29
29
30 from IPython.testing import decorators as dec
30 from IPython.testing import decorators as dec
31 from IPython.testing.ipunittest import ParametricTestCase
31 from IPython.testing.ipunittest import ParametricTestCase
32 from IPython.utils.io import capture_output
32 from IPython.utils.io import capture_output
33
33
34 from IPython import parallel as pmod
34 from IPython import parallel as pmod
35 from IPython.parallel import error
35 from IPython.parallel import error
36 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
36 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
37 from IPython.parallel import DirectView
37 from IPython.parallel import DirectView
38 from IPython.parallel.util import interactive
38 from IPython.parallel.util import interactive
39
39
40 from IPython.parallel.tests import add_engines
40 from IPython.parallel.tests import add_engines
41
41
42 from .clienttest import ClusterTestCase, crash, wait, skip_without
42 from .clienttest import ClusterTestCase, crash, wait, skip_without
43
43
44 def setup():
44 def setup():
45 add_engines(3, total=True)
45 add_engines(3, total=True)
46
46
47 point = namedtuple("point", "x y")
47 point = namedtuple("point", "x y")
48
48
49 class TestView(ClusterTestCase, ParametricTestCase):
49 class TestView(ClusterTestCase, ParametricTestCase):
50
50
51 def setUp(self):
51 def setUp(self):
52 # On Win XP, wait for resource cleanup, else parallel test group fails
52 # On Win XP, wait for resource cleanup, else parallel test group fails
53 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
53 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
54 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
54 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
55 time.sleep(2)
55 time.sleep(2)
56 super(TestView, self).setUp()
56 super(TestView, self).setUp()
57
57
58 @attr('crash')
58 @attr('crash')
59 def test_z_crash_mux(self):
59 def test_z_crash_mux(self):
60 """test graceful handling of engine death (direct)"""
60 """test graceful handling of engine death (direct)"""
61 # self.add_engines(1)
61 # self.add_engines(1)
62 eid = self.client.ids[-1]
62 eid = self.client.ids[-1]
63 ar = self.client[eid].apply_async(crash)
63 ar = self.client[eid].apply_async(crash)
64 self.assertRaisesRemote(error.EngineError, ar.get, 10)
64 self.assertRaisesRemote(error.EngineError, ar.get, 10)
65 eid = ar.engine_id
65 eid = ar.engine_id
66 tic = time.time()
66 tic = time.time()
67 while eid in self.client.ids and time.time()-tic < 5:
67 while eid in self.client.ids and time.time()-tic < 5:
68 time.sleep(.01)
68 time.sleep(.01)
69 self.client.spin()
69 self.client.spin()
70 self.assertFalse(eid in self.client.ids, "Engine should have died")
70 self.assertFalse(eid in self.client.ids, "Engine should have died")
71
71
72 def test_push_pull(self):
72 def test_push_pull(self):
73 """test pushing and pulling"""
73 """test pushing and pulling"""
74 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
74 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
75 t = self.client.ids[-1]
75 t = self.client.ids[-1]
76 v = self.client[t]
76 v = self.client[t]
77 push = v.push
77 push = v.push
78 pull = v.pull
78 pull = v.pull
79 v.block=True
79 v.block=True
80 nengines = len(self.client)
80 nengines = len(self.client)
81 push({'data':data})
81 push({'data':data})
82 d = pull('data')
82 d = pull('data')
83 self.assertEqual(d, data)
83 self.assertEqual(d, data)
84 self.client[:].push({'data':data})
84 self.client[:].push({'data':data})
85 d = self.client[:].pull('data', block=True)
85 d = self.client[:].pull('data', block=True)
86 self.assertEqual(d, nengines*[data])
86 self.assertEqual(d, nengines*[data])
87 ar = push({'data':data}, block=False)
87 ar = push({'data':data}, block=False)
88 self.assertTrue(isinstance(ar, AsyncResult))
88 self.assertTrue(isinstance(ar, AsyncResult))
89 r = ar.get()
89 r = ar.get()
90 ar = self.client[:].pull('data', block=False)
90 ar = self.client[:].pull('data', block=False)
91 self.assertTrue(isinstance(ar, AsyncResult))
91 self.assertTrue(isinstance(ar, AsyncResult))
92 r = ar.get()
92 r = ar.get()
93 self.assertEqual(r, nengines*[data])
93 self.assertEqual(r, nengines*[data])
94 self.client[:].push(dict(a=10,b=20))
94 self.client[:].push(dict(a=10,b=20))
95 r = self.client[:].pull(('a','b'), block=True)
95 r = self.client[:].pull(('a','b'), block=True)
96 self.assertEqual(r, nengines*[[10,20]])
96 self.assertEqual(r, nengines*[[10,20]])
97
97
98 def test_push_pull_function(self):
98 def test_push_pull_function(self):
99 "test pushing and pulling functions"
99 "test pushing and pulling functions"
100 def testf(x):
100 def testf(x):
101 return 2.0*x
101 return 2.0*x
102
102
103 t = self.client.ids[-1]
103 t = self.client.ids[-1]
104 v = self.client[t]
104 v = self.client[t]
105 v.block=True
105 v.block=True
106 push = v.push
106 push = v.push
107 pull = v.pull
107 pull = v.pull
108 execute = v.execute
108 execute = v.execute
109 push({'testf':testf})
109 push({'testf':testf})
110 r = pull('testf')
110 r = pull('testf')
111 self.assertEqual(r(1.0), testf(1.0))
111 self.assertEqual(r(1.0), testf(1.0))
112 execute('r = testf(10)')
112 execute('r = testf(10)')
113 r = pull('r')
113 r = pull('r')
114 self.assertEqual(r, testf(10))
114 self.assertEqual(r, testf(10))
115 ar = self.client[:].push({'testf':testf}, block=False)
115 ar = self.client[:].push({'testf':testf}, block=False)
116 ar.get()
116 ar.get()
117 ar = self.client[:].pull('testf', block=False)
117 ar = self.client[:].pull('testf', block=False)
118 rlist = ar.get()
118 rlist = ar.get()
119 for r in rlist:
119 for r in rlist:
120 self.assertEqual(r(1.0), testf(1.0))
120 self.assertEqual(r(1.0), testf(1.0))
121 execute("def g(x): return x*x")
121 execute("def g(x): return x*x")
122 r = pull(('testf','g'))
122 r = pull(('testf','g'))
123 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
123 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
124
124
125 def test_push_function_globals(self):
125 def test_push_function_globals(self):
126 """test that pushed functions have access to globals"""
126 """test that pushed functions have access to globals"""
127 @interactive
127 @interactive
128 def geta():
128 def geta():
129 return a
129 return a
130 # self.add_engines(1)
130 # self.add_engines(1)
131 v = self.client[-1]
131 v = self.client[-1]
132 v.block=True
132 v.block=True
133 v['f'] = geta
133 v['f'] = geta
134 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
134 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
135 v.execute('a=5')
135 v.execute('a=5')
136 v.execute('b=f()')
136 v.execute('b=f()')
137 self.assertEqual(v['b'], 5)
137 self.assertEqual(v['b'], 5)
138
138
139 def test_push_function_defaults(self):
139 def test_push_function_defaults(self):
140 """test that pushed functions preserve default args"""
140 """test that pushed functions preserve default args"""
141 def echo(a=10):
141 def echo(a=10):
142 return a
142 return a
143 v = self.client[-1]
143 v = self.client[-1]
144 v.block=True
144 v.block=True
145 v['f'] = echo
145 v['f'] = echo
146 v.execute('b=f()')
146 v.execute('b=f()')
147 self.assertEqual(v['b'], 10)
147 self.assertEqual(v['b'], 10)
148
148
149 def test_get_result(self):
149 def test_get_result(self):
150 """test getting results from the Hub."""
150 """test getting results from the Hub."""
151 c = pmod.Client(profile='iptest')
151 c = pmod.Client(profile='iptest')
152 # self.add_engines(1)
152 # self.add_engines(1)
153 t = c.ids[-1]
153 t = c.ids[-1]
154 v = c[t]
154 v = c[t]
155 v2 = self.client[t]
155 v2 = self.client[t]
156 ar = v.apply_async(wait, 1)
156 ar = v.apply_async(wait, 1)
157 # give the monitor time to notice the message
157 # give the monitor time to notice the message
158 time.sleep(.25)
158 time.sleep(.25)
159 ahr = v2.get_result(ar.msg_ids[0])
159 ahr = v2.get_result(ar.msg_ids[0])
160 self.assertTrue(isinstance(ahr, AsyncHubResult))
160 self.assertTrue(isinstance(ahr, AsyncHubResult))
161 self.assertEqual(ahr.get(), ar.get())
161 self.assertEqual(ahr.get(), ar.get())
162 ar2 = v2.get_result(ar.msg_ids[0])
162 ar2 = v2.get_result(ar.msg_ids[0])
163 self.assertFalse(isinstance(ar2, AsyncHubResult))
163 self.assertFalse(isinstance(ar2, AsyncHubResult))
164 c.spin()
164 c.spin()
165 c.close()
165 c.close()
166
166
167 def test_run_newline(self):
167 def test_run_newline(self):
168 """test that run appends newline to files"""
168 """test that run appends newline to files"""
169 tmpfile = mktemp()
169 tmpfile = mktemp()
170 with open(tmpfile, 'w') as f:
170 with open(tmpfile, 'w') as f:
171 f.write("""def g():
171 f.write("""def g():
172 return 5
172 return 5
173 """)
173 """)
174 v = self.client[-1]
174 v = self.client[-1]
175 v.run(tmpfile, block=True)
175 v.run(tmpfile, block=True)
176 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
176 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
177
177
178 def test_apply_tracked(self):
178 def test_apply_tracked(self):
179 """test tracking for apply"""
179 """test tracking for apply"""
180 # self.add_engines(1)
180 # self.add_engines(1)
181 t = self.client.ids[-1]
181 t = self.client.ids[-1]
182 v = self.client[t]
182 v = self.client[t]
183 v.block=False
183 v.block=False
184 def echo(n=1024*1024, **kwargs):
184 def echo(n=1024*1024, **kwargs):
185 with v.temp_flags(**kwargs):
185 with v.temp_flags(**kwargs):
186 return v.apply(lambda x: x, 'x'*n)
186 return v.apply(lambda x: x, 'x'*n)
187 ar = echo(1, track=False)
187 ar = echo(1, track=False)
188 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
189 self.assertTrue(ar.sent)
189 self.assertTrue(ar.sent)
190 ar = echo(track=True)
190 ar = echo(track=True)
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 self.assertEqual(ar.sent, ar._tracker.done)
192 self.assertEqual(ar.sent, ar._tracker.done)
193 ar._tracker.wait()
193 ar._tracker.wait()
194 self.assertTrue(ar.sent)
194 self.assertTrue(ar.sent)
195
195
196 def test_push_tracked(self):
196 def test_push_tracked(self):
197 t = self.client.ids[-1]
197 t = self.client.ids[-1]
198 ns = dict(x='x'*1024*1024)
198 ns = dict(x='x'*1024*1024)
199 v = self.client[t]
199 v = self.client[t]
200 ar = v.push(ns, block=False, track=False)
200 ar = v.push(ns, block=False, track=False)
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 self.assertTrue(ar.sent)
202 self.assertTrue(ar.sent)
203
203
204 ar = v.push(ns, block=False, track=True)
204 ar = v.push(ns, block=False, track=True)
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 ar._tracker.wait()
206 ar._tracker.wait()
207 self.assertEqual(ar.sent, ar._tracker.done)
207 self.assertEqual(ar.sent, ar._tracker.done)
208 self.assertTrue(ar.sent)
208 self.assertTrue(ar.sent)
209 ar.get()
209 ar.get()
210
210
211 def test_scatter_tracked(self):
211 def test_scatter_tracked(self):
212 t = self.client.ids
212 t = self.client.ids
213 x='x'*1024*1024
213 x='x'*1024*1024
214 ar = self.client[t].scatter('x', x, block=False, track=False)
214 ar = self.client[t].scatter('x', x, block=False, track=False)
215 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
215 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
216 self.assertTrue(ar.sent)
216 self.assertTrue(ar.sent)
217
217
218 ar = self.client[t].scatter('x', x, block=False, track=True)
218 ar = self.client[t].scatter('x', x, block=False, track=True)
219 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
219 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
220 self.assertEqual(ar.sent, ar._tracker.done)
220 self.assertEqual(ar.sent, ar._tracker.done)
221 ar._tracker.wait()
221 ar._tracker.wait()
222 self.assertTrue(ar.sent)
222 self.assertTrue(ar.sent)
223 ar.get()
223 ar.get()
224
224
225 def test_remote_reference(self):
225 def test_remote_reference(self):
226 v = self.client[-1]
226 v = self.client[-1]
227 v['a'] = 123
227 v['a'] = 123
228 ra = pmod.Reference('a')
228 ra = pmod.Reference('a')
229 b = v.apply_sync(lambda x: x, ra)
229 b = v.apply_sync(lambda x: x, ra)
230 self.assertEqual(b, 123)
230 self.assertEqual(b, 123)
231
231
232
232
233 def test_scatter_gather(self):
233 def test_scatter_gather(self):
234 view = self.client[:]
234 view = self.client[:]
235 seq1 = range(16)
235 seq1 = range(16)
236 view.scatter('a', seq1)
236 view.scatter('a', seq1)
237 seq2 = view.gather('a', block=True)
237 seq2 = view.gather('a', block=True)
238 self.assertEqual(seq2, seq1)
238 self.assertEqual(seq2, seq1)
239 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
239 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
240
240
241 @skip_without('numpy')
241 @skip_without('numpy')
242 def test_scatter_gather_numpy(self):
242 def test_scatter_gather_numpy(self):
243 import numpy
243 import numpy
244 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
244 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
245 view = self.client[:]
245 view = self.client[:]
246 a = numpy.arange(64)
246 a = numpy.arange(64)
247 view.scatter('a', a, block=True)
247 view.scatter('a', a, block=True)
248 b = view.gather('a', block=True)
248 b = view.gather('a', block=True)
249 assert_array_equal(b, a)
249 assert_array_equal(b, a)
250
250
251 def test_scatter_gather_lazy(self):
251 def test_scatter_gather_lazy(self):
252 """scatter/gather with targets='all'"""
252 """scatter/gather with targets='all'"""
253 view = self.client.direct_view(targets='all')
253 view = self.client.direct_view(targets='all')
254 x = range(64)
254 x = range(64)
255 view.scatter('x', x)
255 view.scatter('x', x)
256 gathered = view.gather('x', block=True)
256 gathered = view.gather('x', block=True)
257 self.assertEqual(gathered, x)
257 self.assertEqual(gathered, x)
258
258
259
259
260 @dec.known_failure_py3
260 @dec.known_failure_py3
261 @skip_without('numpy')
261 @skip_without('numpy')
262 def test_push_numpy_nocopy(self):
262 def test_push_numpy_nocopy(self):
263 import numpy
263 import numpy
264 view = self.client[:]
264 view = self.client[:]
265 a = numpy.arange(64)
265 a = numpy.arange(64)
266 view['A'] = a
266 view['A'] = a
267 @interactive
267 @interactive
268 def check_writeable(x):
268 def check_writeable(x):
269 return x.flags.writeable
269 return x.flags.writeable
270
270
271 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
271 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
272 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
272 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
273
273
274 view.push(dict(B=a))
274 view.push(dict(B=a))
275 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
275 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
276 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
276 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
277
277
278 @skip_without('numpy')
278 @skip_without('numpy')
279 def test_apply_numpy(self):
279 def test_apply_numpy(self):
280 """view.apply(f, ndarray)"""
280 """view.apply(f, ndarray)"""
281 import numpy
281 import numpy
282 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
282 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
283
283
284 A = numpy.random.random((100,100))
284 A = numpy.random.random((100,100))
285 view = self.client[-1]
285 view = self.client[-1]
286 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
286 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
287 B = A.astype(dt)
287 B = A.astype(dt)
288 C = view.apply_sync(lambda x:x, B)
288 C = view.apply_sync(lambda x:x, B)
289 assert_array_equal(B,C)
289 assert_array_equal(B,C)
290
290
291 @skip_without('numpy')
291 @skip_without('numpy')
292 def test_push_pull_recarray(self):
292 def test_push_pull_recarray(self):
293 """push/pull recarrays"""
293 """push/pull recarrays"""
294 import numpy
294 import numpy
295 from numpy.testing.utils import assert_array_equal
295 from numpy.testing.utils import assert_array_equal
296
296
297 view = self.client[-1]
297 view = self.client[-1]
298
298
299 R = numpy.array([
299 R = numpy.array([
300 (1, 'hi', 0.),
300 (1, 'hi', 0.),
301 (2**30, 'there', 2.5),
301 (2**30, 'there', 2.5),
302 (-99999, 'world', -12345.6789),
302 (-99999, 'world', -12345.6789),
303 ], [('n', int), ('s', '|S10'), ('f', float)])
303 ], [('n', int), ('s', '|S10'), ('f', float)])
304
304
305 view['RR'] = R
305 view['RR'] = R
306 R2 = view['RR']
306 R2 = view['RR']
307
307
308 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
308 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
309 self.assertEqual(r_dtype, R.dtype)
309 self.assertEqual(r_dtype, R.dtype)
310 self.assertEqual(r_shape, R.shape)
310 self.assertEqual(r_shape, R.shape)
311 self.assertEqual(R2.dtype, R.dtype)
311 self.assertEqual(R2.dtype, R.dtype)
312 self.assertEqual(R2.shape, R.shape)
312 self.assertEqual(R2.shape, R.shape)
313 assert_array_equal(R2, R)
313 assert_array_equal(R2, R)
314
314
315 @skip_without('pandas')
315 @skip_without('pandas')
316 def test_push_pull_timeseries(self):
316 def test_push_pull_timeseries(self):
317 """push/pull pandas.TimeSeries"""
317 """push/pull pandas.TimeSeries"""
318 import pandas
318 import pandas
319
319
320 ts = pandas.TimeSeries(range(10))
320 ts = pandas.TimeSeries(range(10))
321
321
322 view = self.client[-1]
322 view = self.client[-1]
323
323
324 view.push(dict(ts=ts), block=True)
324 view.push(dict(ts=ts), block=True)
325 rts = view['ts']
325 rts = view['ts']
326
326
327 self.assertEqual(type(rts), type(ts))
327 self.assertEqual(type(rts), type(ts))
328 self.assertTrue((ts == rts).all())
328 self.assertTrue((ts == rts).all())
329
329
330 def test_map(self):
330 def test_map(self):
331 view = self.client[:]
331 view = self.client[:]
332 def f(x):
332 def f(x):
333 return x**2
333 return x**2
334 data = range(16)
334 data = range(16)
335 r = view.map_sync(f, data)
335 r = view.map_sync(f, data)
336 self.assertEqual(r, map(f, data))
336 self.assertEqual(r, map(f, data))
337
337
338 def test_map_iterable(self):
338 def test_map_iterable(self):
339 """test map on iterables (direct)"""
339 """test map on iterables (direct)"""
340 view = self.client[:]
340 view = self.client[:]
341 # 101 is prime, so it won't be evenly distributed
341 # 101 is prime, so it won't be evenly distributed
342 arr = range(101)
342 arr = range(101)
343 # ensure it will be an iterator, even in Python 3
343 # ensure it will be an iterator, even in Python 3
344 it = iter(arr)
344 it = iter(arr)
345 r = view.map_sync(lambda x: x, it)
345 r = view.map_sync(lambda x: x, it)
346 self.assertEqual(r, list(it))
346 self.assertEqual(r, list(arr))
347
347
348 @skip_without('numpy')
348 @skip_without('numpy')
349 def test_map_numpy(self):
349 def test_map_numpy(self):
350 """test map on numpy arrays (direct)"""
350 """test map on numpy arrays (direct)"""
351 import numpy
351 import numpy
352 from numpy.testing.utils import assert_array_equal
352 from numpy.testing.utils import assert_array_equal
353
353
354 view = self.client[:]
354 view = self.client[:]
355 # 101 is prime, so it won't be evenly distributed
355 # 101 is prime, so it won't be evenly distributed
356 arr = numpy.arange(101)
356 arr = numpy.arange(101)
357 r = view.map_sync(lambda x: x, arr)
357 r = view.map_sync(lambda x: x, arr)
358 assert_array_equal(r, arr)
358 assert_array_equal(r, arr)
359
359
360 def test_scatter_gather_nonblocking(self):
360 def test_scatter_gather_nonblocking(self):
361 data = range(16)
361 data = range(16)
362 view = self.client[:]
362 view = self.client[:]
363 view.scatter('a', data, block=False)
363 view.scatter('a', data, block=False)
364 ar = view.gather('a', block=False)
364 ar = view.gather('a', block=False)
365 self.assertEqual(ar.get(), data)
365 self.assertEqual(ar.get(), data)
366
366
367 @skip_without('numpy')
367 @skip_without('numpy')
368 def test_scatter_gather_numpy_nonblocking(self):
368 def test_scatter_gather_numpy_nonblocking(self):
369 import numpy
369 import numpy
370 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
370 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
371 a = numpy.arange(64)
371 a = numpy.arange(64)
372 view = self.client[:]
372 view = self.client[:]
373 ar = view.scatter('a', a, block=False)
373 ar = view.scatter('a', a, block=False)
374 self.assertTrue(isinstance(ar, AsyncResult))
374 self.assertTrue(isinstance(ar, AsyncResult))
375 amr = view.gather('a', block=False)
375 amr = view.gather('a', block=False)
376 self.assertTrue(isinstance(amr, AsyncMapResult))
376 self.assertTrue(isinstance(amr, AsyncMapResult))
377 assert_array_equal(amr.get(), a)
377 assert_array_equal(amr.get(), a)
378
378
379 def test_execute(self):
379 def test_execute(self):
380 view = self.client[:]
380 view = self.client[:]
381 # self.client.debug=True
381 # self.client.debug=True
382 execute = view.execute
382 execute = view.execute
383 ar = execute('c=30', block=False)
383 ar = execute('c=30', block=False)
384 self.assertTrue(isinstance(ar, AsyncResult))
384 self.assertTrue(isinstance(ar, AsyncResult))
385 ar = execute('d=[0,1,2]', block=False)
385 ar = execute('d=[0,1,2]', block=False)
386 self.client.wait(ar, 1)
386 self.client.wait(ar, 1)
387 self.assertEqual(len(ar.get()), len(self.client))
387 self.assertEqual(len(ar.get()), len(self.client))
388 for c in view['c']:
388 for c in view['c']:
389 self.assertEqual(c, 30)
389 self.assertEqual(c, 30)
390
390
391 def test_abort(self):
391 def test_abort(self):
392 view = self.client[-1]
392 view = self.client[-1]
393 ar = view.execute('import time; time.sleep(1)', block=False)
393 ar = view.execute('import time; time.sleep(1)', block=False)
394 ar2 = view.apply_async(lambda : 2)
394 ar2 = view.apply_async(lambda : 2)
395 ar3 = view.apply_async(lambda : 3)
395 ar3 = view.apply_async(lambda : 3)
396 view.abort(ar2)
396 view.abort(ar2)
397 view.abort(ar3.msg_ids)
397 view.abort(ar3.msg_ids)
398 self.assertRaises(error.TaskAborted, ar2.get)
398 self.assertRaises(error.TaskAborted, ar2.get)
399 self.assertRaises(error.TaskAborted, ar3.get)
399 self.assertRaises(error.TaskAborted, ar3.get)
400
400
401 def test_abort_all(self):
401 def test_abort_all(self):
402 """view.abort() aborts all outstanding tasks"""
402 """view.abort() aborts all outstanding tasks"""
403 view = self.client[-1]
403 view = self.client[-1]
404 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
404 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
405 view.abort()
405 view.abort()
406 view.wait(timeout=5)
406 view.wait(timeout=5)
407 for ar in ars[5:]:
407 for ar in ars[5:]:
408 self.assertRaises(error.TaskAborted, ar.get)
408 self.assertRaises(error.TaskAborted, ar.get)
409
409
410 def test_temp_flags(self):
410 def test_temp_flags(self):
411 view = self.client[-1]
411 view = self.client[-1]
412 view.block=True
412 view.block=True
413 with view.temp_flags(block=False):
413 with view.temp_flags(block=False):
414 self.assertFalse(view.block)
414 self.assertFalse(view.block)
415 self.assertTrue(view.block)
415 self.assertTrue(view.block)
416
416
417 @dec.known_failure_py3
417 @dec.known_failure_py3
418 def test_importer(self):
418 def test_importer(self):
419 view = self.client[-1]
419 view = self.client[-1]
420 view.clear(block=True)
420 view.clear(block=True)
421 with view.importer:
421 with view.importer:
422 import re
422 import re
423
423
424 @interactive
424 @interactive
425 def findall(pat, s):
425 def findall(pat, s):
426 # this globals() step isn't necessary in real code
426 # this globals() step isn't necessary in real code
427 # only to prevent a closure in the test
427 # only to prevent a closure in the test
428 re = globals()['re']
428 re = globals()['re']
429 return re.findall(pat, s)
429 return re.findall(pat, s)
430
430
431 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
431 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
432
432
433 def test_unicode_execute(self):
433 def test_unicode_execute(self):
434 """test executing unicode strings"""
434 """test executing unicode strings"""
435 v = self.client[-1]
435 v = self.client[-1]
436 v.block=True
436 v.block=True
437 if sys.version_info[0] >= 3:
437 if sys.version_info[0] >= 3:
438 code="a='é'"
438 code="a='é'"
439 else:
439 else:
440 code=u"a=u'é'"
440 code=u"a=u'é'"
441 v.execute(code)
441 v.execute(code)
442 self.assertEqual(v['a'], u'é')
442 self.assertEqual(v['a'], u'é')
443
443
444 def test_unicode_apply_result(self):
444 def test_unicode_apply_result(self):
445 """test unicode apply results"""
445 """test unicode apply results"""
446 v = self.client[-1]
446 v = self.client[-1]
447 r = v.apply_sync(lambda : u'é')
447 r = v.apply_sync(lambda : u'é')
448 self.assertEqual(r, u'é')
448 self.assertEqual(r, u'é')
449
449
450 def test_unicode_apply_arg(self):
450 def test_unicode_apply_arg(self):
451 """test passing unicode arguments to apply"""
451 """test passing unicode arguments to apply"""
452 v = self.client[-1]
452 v = self.client[-1]
453
453
454 @interactive
454 @interactive
455 def check_unicode(a, check):
455 def check_unicode(a, check):
456 assert isinstance(a, unicode), "%r is not unicode"%a
456 assert isinstance(a, unicode), "%r is not unicode"%a
457 assert isinstance(check, bytes), "%r is not bytes"%check
457 assert isinstance(check, bytes), "%r is not bytes"%check
458 assert a.encode('utf8') == check, "%s != %s"%(a,check)
458 assert a.encode('utf8') == check, "%s != %s"%(a,check)
459
459
460 for s in [ u'é', u'ßø®∫',u'asdf' ]:
460 for s in [ u'é', u'ßø®∫',u'asdf' ]:
461 try:
461 try:
462 v.apply_sync(check_unicode, s, s.encode('utf8'))
462 v.apply_sync(check_unicode, s, s.encode('utf8'))
463 except error.RemoteError as e:
463 except error.RemoteError as e:
464 if e.ename == 'AssertionError':
464 if e.ename == 'AssertionError':
465 self.fail(e.evalue)
465 self.fail(e.evalue)
466 else:
466 else:
467 raise e
467 raise e
468
468
469 def test_map_reference(self):
469 def test_map_reference(self):
470 """view.map(<Reference>, *seqs) should work"""
470 """view.map(<Reference>, *seqs) should work"""
471 v = self.client[:]
471 v = self.client[:]
472 v.scatter('n', self.client.ids, flatten=True)
472 v.scatter('n', self.client.ids, flatten=True)
473 v.execute("f = lambda x,y: x*y")
473 v.execute("f = lambda x,y: x*y")
474 rf = pmod.Reference('f')
474 rf = pmod.Reference('f')
475 nlist = list(range(10))
475 nlist = list(range(10))
476 mlist = nlist[::-1]
476 mlist = nlist[::-1]
477 expected = [ m*n for m,n in zip(mlist, nlist) ]
477 expected = [ m*n for m,n in zip(mlist, nlist) ]
478 result = v.map_sync(rf, mlist, nlist)
478 result = v.map_sync(rf, mlist, nlist)
479 self.assertEqual(result, expected)
479 self.assertEqual(result, expected)
480
480
481 def test_apply_reference(self):
481 def test_apply_reference(self):
482 """view.apply(<Reference>, *args) should work"""
482 """view.apply(<Reference>, *args) should work"""
483 v = self.client[:]
483 v = self.client[:]
484 v.scatter('n', self.client.ids, flatten=True)
484 v.scatter('n', self.client.ids, flatten=True)
485 v.execute("f = lambda x: n*x")
485 v.execute("f = lambda x: n*x")
486 rf = pmod.Reference('f')
486 rf = pmod.Reference('f')
487 result = v.apply_sync(rf, 5)
487 result = v.apply_sync(rf, 5)
488 expected = [ 5*id for id in self.client.ids ]
488 expected = [ 5*id for id in self.client.ids ]
489 self.assertEqual(result, expected)
489 self.assertEqual(result, expected)
490
490
491 def test_eval_reference(self):
491 def test_eval_reference(self):
492 v = self.client[self.client.ids[0]]
492 v = self.client[self.client.ids[0]]
493 v['g'] = range(5)
493 v['g'] = range(5)
494 rg = pmod.Reference('g[0]')
494 rg = pmod.Reference('g[0]')
495 echo = lambda x:x
495 echo = lambda x:x
496 self.assertEqual(v.apply_sync(echo, rg), 0)
496 self.assertEqual(v.apply_sync(echo, rg), 0)
497
497
498 def test_reference_nameerror(self):
498 def test_reference_nameerror(self):
499 v = self.client[self.client.ids[0]]
499 v = self.client[self.client.ids[0]]
500 r = pmod.Reference('elvis_has_left')
500 r = pmod.Reference('elvis_has_left')
501 echo = lambda x:x
501 echo = lambda x:x
502 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
502 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
503
503
504 def test_single_engine_map(self):
504 def test_single_engine_map(self):
505 e0 = self.client[self.client.ids[0]]
505 e0 = self.client[self.client.ids[0]]
506 r = range(5)
506 r = range(5)
507 check = [ -1*i for i in r ]
507 check = [ -1*i for i in r ]
508 result = e0.map_sync(lambda x: -1*x, r)
508 result = e0.map_sync(lambda x: -1*x, r)
509 self.assertEqual(result, check)
509 self.assertEqual(result, check)
510
510
511 def test_len(self):
511 def test_len(self):
512 """len(view) makes sense"""
512 """len(view) makes sense"""
513 e0 = self.client[self.client.ids[0]]
513 e0 = self.client[self.client.ids[0]]
514 yield self.assertEqual(len(e0), 1)
514 yield self.assertEqual(len(e0), 1)
515 v = self.client[:]
515 v = self.client[:]
516 yield self.assertEqual(len(v), len(self.client.ids))
516 yield self.assertEqual(len(v), len(self.client.ids))
517 v = self.client.direct_view('all')
517 v = self.client.direct_view('all')
518 yield self.assertEqual(len(v), len(self.client.ids))
518 yield self.assertEqual(len(v), len(self.client.ids))
519 v = self.client[:2]
519 v = self.client[:2]
520 yield self.assertEqual(len(v), 2)
520 yield self.assertEqual(len(v), 2)
521 v = self.client[:1]
521 v = self.client[:1]
522 yield self.assertEqual(len(v), 1)
522 yield self.assertEqual(len(v), 1)
523 v = self.client.load_balanced_view()
523 v = self.client.load_balanced_view()
524 yield self.assertEqual(len(v), len(self.client.ids))
524 yield self.assertEqual(len(v), len(self.client.ids))
525 # parametric tests seem to require manual closing?
525 # parametric tests seem to require manual closing?
526 self.client.close()
526 self.client.close()
527
527
528
528
529 # begin execute tests
529 # begin execute tests
530
530
531 def test_execute_reply(self):
531 def test_execute_reply(self):
532 e0 = self.client[self.client.ids[0]]
532 e0 = self.client[self.client.ids[0]]
533 e0.block = True
533 e0.block = True
534 ar = e0.execute("5", silent=False)
534 ar = e0.execute("5", silent=False)
535 er = ar.get()
535 er = ar.get()
536 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
536 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
537 self.assertEqual(er.pyout['data']['text/plain'], '5')
537 self.assertEqual(er.pyout['data']['text/plain'], '5')
538
538
539 def test_execute_reply_stdout(self):
539 def test_execute_reply_stdout(self):
540 e0 = self.client[self.client.ids[0]]
540 e0 = self.client[self.client.ids[0]]
541 e0.block = True
541 e0.block = True
542 ar = e0.execute("print (5)", silent=False)
542 ar = e0.execute("print (5)", silent=False)
543 er = ar.get()
543 er = ar.get()
544 self.assertEqual(er.stdout.strip(), '5')
544 self.assertEqual(er.stdout.strip(), '5')
545
545
546 def test_execute_pyout(self):
546 def test_execute_pyout(self):
547 """execute triggers pyout with silent=False"""
547 """execute triggers pyout with silent=False"""
548 view = self.client[:]
548 view = self.client[:]
549 ar = view.execute("5", silent=False, block=True)
549 ar = view.execute("5", silent=False, block=True)
550
550
551 expected = [{'text/plain' : '5'}] * len(view)
551 expected = [{'text/plain' : '5'}] * len(view)
552 mimes = [ out['data'] for out in ar.pyout ]
552 mimes = [ out['data'] for out in ar.pyout ]
553 self.assertEqual(mimes, expected)
553 self.assertEqual(mimes, expected)
554
554
555 def test_execute_silent(self):
555 def test_execute_silent(self):
556 """execute does not trigger pyout with silent=True"""
556 """execute does not trigger pyout with silent=True"""
557 view = self.client[:]
557 view = self.client[:]
558 ar = view.execute("5", block=True)
558 ar = view.execute("5", block=True)
559 expected = [None] * len(view)
559 expected = [None] * len(view)
560 self.assertEqual(ar.pyout, expected)
560 self.assertEqual(ar.pyout, expected)
561
561
562 def test_execute_magic(self):
562 def test_execute_magic(self):
563 """execute accepts IPython commands"""
563 """execute accepts IPython commands"""
564 view = self.client[:]
564 view = self.client[:]
565 view.execute("a = 5")
565 view.execute("a = 5")
566 ar = view.execute("%whos", block=True)
566 ar = view.execute("%whos", block=True)
567 # this will raise, if that failed
567 # this will raise, if that failed
568 ar.get(5)
568 ar.get(5)
569 for stdout in ar.stdout:
569 for stdout in ar.stdout:
570 lines = stdout.splitlines()
570 lines = stdout.splitlines()
571 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
571 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
572 found = False
572 found = False
573 for line in lines[2:]:
573 for line in lines[2:]:
574 split = line.split()
574 split = line.split()
575 if split == ['a', 'int', '5']:
575 if split == ['a', 'int', '5']:
576 found = True
576 found = True
577 break
577 break
578 self.assertTrue(found, "whos output wrong: %s" % stdout)
578 self.assertTrue(found, "whos output wrong: %s" % stdout)
579
579
580 def test_execute_displaypub(self):
580 def test_execute_displaypub(self):
581 """execute tracks display_pub output"""
581 """execute tracks display_pub output"""
582 view = self.client[:]
582 view = self.client[:]
583 view.execute("from IPython.core.display import *")
583 view.execute("from IPython.core.display import *")
584 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
584 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
585
585
586 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
586 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
587 for outputs in ar.outputs:
587 for outputs in ar.outputs:
588 mimes = [ out['data'] for out in outputs ]
588 mimes = [ out['data'] for out in outputs ]
589 self.assertEqual(mimes, expected)
589 self.assertEqual(mimes, expected)
590
590
591 def test_apply_displaypub(self):
591 def test_apply_displaypub(self):
592 """apply tracks display_pub output"""
592 """apply tracks display_pub output"""
593 view = self.client[:]
593 view = self.client[:]
594 view.execute("from IPython.core.display import *")
594 view.execute("from IPython.core.display import *")
595
595
596 @interactive
596 @interactive
597 def publish():
597 def publish():
598 [ display(i) for i in range(5) ]
598 [ display(i) for i in range(5) ]
599
599
600 ar = view.apply_async(publish)
600 ar = view.apply_async(publish)
601 ar.get(5)
601 ar.get(5)
602 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
602 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
603 for outputs in ar.outputs:
603 for outputs in ar.outputs:
604 mimes = [ out['data'] for out in outputs ]
604 mimes = [ out['data'] for out in outputs ]
605 self.assertEqual(mimes, expected)
605 self.assertEqual(mimes, expected)
606
606
607 def test_execute_raises(self):
607 def test_execute_raises(self):
608 """exceptions in execute requests raise appropriately"""
608 """exceptions in execute requests raise appropriately"""
609 view = self.client[-1]
609 view = self.client[-1]
610 ar = view.execute("1/0")
610 ar = view.execute("1/0")
611 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
611 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
612
612
613 def test_remoteerror_render_exception(self):
613 def test_remoteerror_render_exception(self):
614 """RemoteErrors get nice tracebacks"""
614 """RemoteErrors get nice tracebacks"""
615 view = self.client[-1]
615 view = self.client[-1]
616 ar = view.execute("1/0")
616 ar = view.execute("1/0")
617 ip = get_ipython()
617 ip = get_ipython()
618 ip.user_ns['ar'] = ar
618 ip.user_ns['ar'] = ar
619 with capture_output() as io:
619 with capture_output() as io:
620 ip.run_cell("ar.get(2)")
620 ip.run_cell("ar.get(2)")
621
621
622 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
622 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
623
623
624 def test_compositeerror_render_exception(self):
624 def test_compositeerror_render_exception(self):
625 """CompositeErrors get nice tracebacks"""
625 """CompositeErrors get nice tracebacks"""
626 view = self.client[:]
626 view = self.client[:]
627 ar = view.execute("1/0")
627 ar = view.execute("1/0")
628 ip = get_ipython()
628 ip = get_ipython()
629 ip.user_ns['ar'] = ar
629 ip.user_ns['ar'] = ar
630
630
631 with capture_output() as io:
631 with capture_output() as io:
632 ip.run_cell("ar.get(2)")
632 ip.run_cell("ar.get(2)")
633
633
634 count = min(error.CompositeError.tb_limit, len(view))
634 count = min(error.CompositeError.tb_limit, len(view))
635
635
636 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
636 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
637 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
637 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
638 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
638 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
639
639
640 def test_compositeerror_truncate(self):
640 def test_compositeerror_truncate(self):
641 """Truncate CompositeErrors with many exceptions"""
641 """Truncate CompositeErrors with many exceptions"""
642 view = self.client[:]
642 view = self.client[:]
643 msg_ids = []
643 msg_ids = []
644 for i in range(10):
644 for i in range(10):
645 ar = view.execute("1/0")
645 ar = view.execute("1/0")
646 msg_ids.extend(ar.msg_ids)
646 msg_ids.extend(ar.msg_ids)
647
647
648 ar = self.client.get_result(msg_ids)
648 ar = self.client.get_result(msg_ids)
649 try:
649 try:
650 ar.get()
650 ar.get()
651 except error.CompositeError as _e:
651 except error.CompositeError as _e:
652 e = _e
652 e = _e
653 else:
653 else:
654 self.fail("Should have raised CompositeError")
654 self.fail("Should have raised CompositeError")
655
655
656 lines = e.render_traceback()
656 lines = e.render_traceback()
657 with capture_output() as io:
657 with capture_output() as io:
658 e.print_traceback()
658 e.print_traceback()
659
659
660 self.assertTrue("more exceptions" in lines[-1])
660 self.assertTrue("more exceptions" in lines[-1])
661 count = e.tb_limit
661 count = e.tb_limit
662
662
663 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
663 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
664 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
664 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
665 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
665 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
666
666
667 @dec.skipif_not_matplotlib
667 @dec.skipif_not_matplotlib
668 def test_magic_pylab(self):
668 def test_magic_pylab(self):
669 """%pylab works on engines"""
669 """%pylab works on engines"""
670 view = self.client[-1]
670 view = self.client[-1]
671 ar = view.execute("%pylab inline")
671 ar = view.execute("%pylab inline")
672 # at least check if this raised:
672 # at least check if this raised:
673 reply = ar.get(5)
673 reply = ar.get(5)
674 # include imports, in case user config
674 # include imports, in case user config
675 ar = view.execute("plot(rand(100))", silent=False)
675 ar = view.execute("plot(rand(100))", silent=False)
676 reply = ar.get(5)
676 reply = ar.get(5)
677 self.assertEqual(len(reply.outputs), 1)
677 self.assertEqual(len(reply.outputs), 1)
678 output = reply.outputs[0]
678 output = reply.outputs[0]
679 self.assertTrue("data" in output)
679 self.assertTrue("data" in output)
680 data = output['data']
680 data = output['data']
681 self.assertTrue("image/png" in data)
681 self.assertTrue("image/png" in data)
682
682
683 def test_func_default_func(self):
683 def test_func_default_func(self):
684 """interactively defined function as apply func default"""
684 """interactively defined function as apply func default"""
685 def foo():
685 def foo():
686 return 'foo'
686 return 'foo'
687
687
688 def bar(f=foo):
688 def bar(f=foo):
689 return f()
689 return f()
690
690
691 view = self.client[-1]
691 view = self.client[-1]
692 ar = view.apply_async(bar)
692 ar = view.apply_async(bar)
693 r = ar.get(10)
693 r = ar.get(10)
694 self.assertEqual(r, 'foo')
694 self.assertEqual(r, 'foo')
695 def test_data_pub_single(self):
695 def test_data_pub_single(self):
696 view = self.client[-1]
696 view = self.client[-1]
697 ar = view.execute('\n'.join([
697 ar = view.execute('\n'.join([
698 'from IPython.kernel.zmq.datapub import publish_data',
698 'from IPython.kernel.zmq.datapub import publish_data',
699 'for i in range(5):',
699 'for i in range(5):',
700 ' publish_data(dict(i=i))'
700 ' publish_data(dict(i=i))'
701 ]), block=False)
701 ]), block=False)
702 self.assertTrue(isinstance(ar.data, dict))
702 self.assertTrue(isinstance(ar.data, dict))
703 ar.get(5)
703 ar.get(5)
704 self.assertEqual(ar.data, dict(i=4))
704 self.assertEqual(ar.data, dict(i=4))
705
705
706 def test_data_pub(self):
706 def test_data_pub(self):
707 view = self.client[:]
707 view = self.client[:]
708 ar = view.execute('\n'.join([
708 ar = view.execute('\n'.join([
709 'from IPython.kernel.zmq.datapub import publish_data',
709 'from IPython.kernel.zmq.datapub import publish_data',
710 'for i in range(5):',
710 'for i in range(5):',
711 ' publish_data(dict(i=i))'
711 ' publish_data(dict(i=i))'
712 ]), block=False)
712 ]), block=False)
713 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
713 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
714 ar.get(5)
714 ar.get(5)
715 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
715 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
716
716
717 def test_can_list_arg(self):
717 def test_can_list_arg(self):
718 """args in lists are canned"""
718 """args in lists are canned"""
719 view = self.client[-1]
719 view = self.client[-1]
720 view['a'] = 128
720 view['a'] = 128
721 rA = pmod.Reference('a')
721 rA = pmod.Reference('a')
722 ar = view.apply_async(lambda x: x, [rA])
722 ar = view.apply_async(lambda x: x, [rA])
723 r = ar.get(5)
723 r = ar.get(5)
724 self.assertEqual(r, [128])
724 self.assertEqual(r, [128])
725
725
726 def test_can_dict_arg(self):
726 def test_can_dict_arg(self):
727 """args in dicts are canned"""
727 """args in dicts are canned"""
728 view = self.client[-1]
728 view = self.client[-1]
729 view['a'] = 128
729 view['a'] = 128
730 rA = pmod.Reference('a')
730 rA = pmod.Reference('a')
731 ar = view.apply_async(lambda x: x, dict(foo=rA))
731 ar = view.apply_async(lambda x: x, dict(foo=rA))
732 r = ar.get(5)
732 r = ar.get(5)
733 self.assertEqual(r, dict(foo=128))
733 self.assertEqual(r, dict(foo=128))
734
734
735 def test_can_list_kwarg(self):
735 def test_can_list_kwarg(self):
736 """kwargs in lists are canned"""
736 """kwargs in lists are canned"""
737 view = self.client[-1]
737 view = self.client[-1]
738 view['a'] = 128
738 view['a'] = 128
739 rA = pmod.Reference('a')
739 rA = pmod.Reference('a')
740 ar = view.apply_async(lambda x=5: x, x=[rA])
740 ar = view.apply_async(lambda x=5: x, x=[rA])
741 r = ar.get(5)
741 r = ar.get(5)
742 self.assertEqual(r, [128])
742 self.assertEqual(r, [128])
743
743
744 def test_can_dict_kwarg(self):
744 def test_can_dict_kwarg(self):
745 """kwargs in dicts are canned"""
745 """kwargs in dicts are canned"""
746 view = self.client[-1]
746 view = self.client[-1]
747 view['a'] = 128
747 view['a'] = 128
748 rA = pmod.Reference('a')
748 rA = pmod.Reference('a')
749 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
749 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
750 r = ar.get(5)
750 r = ar.get(5)
751 self.assertEqual(r, dict(foo=128))
751 self.assertEqual(r, dict(foo=128))
752
752
753 def test_map_ref(self):
753 def test_map_ref(self):
754 """view.map works with references"""
754 """view.map works with references"""
755 view = self.client[:]
755 view = self.client[:]
756 ranks = sorted(self.client.ids)
756 ranks = sorted(self.client.ids)
757 view.scatter('rank', ranks, flatten=True)
757 view.scatter('rank', ranks, flatten=True)
758 rrank = pmod.Reference('rank')
758 rrank = pmod.Reference('rank')
759
759
760 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
760 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
761 drank = amr.get(5)
761 drank = amr.get(5)
762 self.assertEqual(drank, [ r*2 for r in ranks ])
762 self.assertEqual(drank, [ r*2 for r in ranks ])
763
763
764 def test_nested_getitem_setitem(self):
764 def test_nested_getitem_setitem(self):
765 """get and set with view['a.b']"""
765 """get and set with view['a.b']"""
766 view = self.client[-1]
766 view = self.client[-1]
767 view.execute('\n'.join([
767 view.execute('\n'.join([
768 'class A(object): pass',
768 'class A(object): pass',
769 'a = A()',
769 'a = A()',
770 'a.b = 128',
770 'a.b = 128',
771 ]), block=True)
771 ]), block=True)
772 ra = pmod.Reference('a')
772 ra = pmod.Reference('a')
773
773
774 r = view.apply_sync(lambda x: x.b, ra)
774 r = view.apply_sync(lambda x: x.b, ra)
775 self.assertEqual(r, 128)
775 self.assertEqual(r, 128)
776 self.assertEqual(view['a.b'], 128)
776 self.assertEqual(view['a.b'], 128)
777
777
778 view['a.b'] = 0
778 view['a.b'] = 0
779
779
780 r = view.apply_sync(lambda x: x.b, ra)
780 r = view.apply_sync(lambda x: x.b, ra)
781 self.assertEqual(r, 0)
781 self.assertEqual(r, 0)
782 self.assertEqual(view['a.b'], 0)
782 self.assertEqual(view['a.b'], 0)
783
783
784 def test_return_namedtuple(self):
784 def test_return_namedtuple(self):
785 def namedtuplify(x, y):
785 def namedtuplify(x, y):
786 from IPython.parallel.tests.test_view import point
786 from IPython.parallel.tests.test_view import point
787 return point(x, y)
787 return point(x, y)
788
788
789 view = self.client[-1]
789 view = self.client[-1]
790 p = view.apply_sync(namedtuplify, 1, 2)
790 p = view.apply_sync(namedtuplify, 1, 2)
791 self.assertEqual(p.x, 1)
791 self.assertEqual(p.x, 1)
792 self.assertEqual(p.y, 2)
792 self.assertEqual(p.y, 2)
793
793
794 def test_apply_namedtuple(self):
794 def test_apply_namedtuple(self):
795 def echoxy(p):
795 def echoxy(p):
796 return p.y, p.x
796 return p.y, p.x
797
797
798 view = self.client[-1]
798 view = self.client[-1]
799 tup = view.apply_sync(echoxy, point(1, 2))
799 tup = view.apply_sync(echoxy, point(1, 2))
800 self.assertEqual(tup, (2,1))
800 self.assertEqual(tup, (2,1))
801
801
General Comments 0
You need to be logged in to leave comments. Login now