##// END OF EJS Templates
allow map / parallel function for single-engine views
MinRK -
Show More
@@ -1,241 +1,244 b''
1 """Remote Functions and decorators for Views.
1 """Remote Functions and decorators for Views.
2
2
3 Authors:
3 Authors:
4
4
5 * Brian Granger
5 * Brian Granger
6 * Min RK
6 * Min RK
7 """
7 """
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2010-2011 The IPython Development Team
9 # Copyright (C) 2010-2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import sys
21 import sys
22 import warnings
22 import warnings
23
23
24 from IPython.testing.skipdoctest import skip_doctest
24 from IPython.testing.skipdoctest import skip_doctest
25
25
26 from . import map as Map
26 from . import map as Map
27 from .asyncresult import AsyncMapResult
27 from .asyncresult import AsyncMapResult
28
28
29 #-----------------------------------------------------------------------------
29 #-----------------------------------------------------------------------------
30 # Functions and Decorators
30 # Functions and Decorators
31 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
32
32
33 @skip_doctest
33 @skip_doctest
34 def remote(view, block=None, **flags):
34 def remote(view, block=None, **flags):
35 """Turn a function into a remote function.
35 """Turn a function into a remote function.
36
36
37 This method can be used for map:
37 This method can be used for map:
38
38
39 In [1]: @remote(view,block=True)
39 In [1]: @remote(view,block=True)
40 ...: def func(a):
40 ...: def func(a):
41 ...: pass
41 ...: pass
42 """
42 """
43
43
44 def remote_function(f):
44 def remote_function(f):
45 return RemoteFunction(view, f, block=block, **flags)
45 return RemoteFunction(view, f, block=block, **flags)
46 return remote_function
46 return remote_function
47
47
48 @skip_doctest
48 @skip_doctest
49 def parallel(view, dist='b', block=None, ordered=True, **flags):
49 def parallel(view, dist='b', block=None, ordered=True, **flags):
50 """Turn a function into a parallel remote function.
50 """Turn a function into a parallel remote function.
51
51
52 This method can be used for map:
52 This method can be used for map:
53
53
54 In [1]: @parallel(view, block=True)
54 In [1]: @parallel(view, block=True)
55 ...: def func(a):
55 ...: def func(a):
56 ...: pass
56 ...: pass
57 """
57 """
58
58
59 def parallel_function(f):
59 def parallel_function(f):
60 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
60 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
61 return parallel_function
61 return parallel_function
62
62
63 def getname(f):
63 def getname(f):
64 """Get the name of an object.
64 """Get the name of an object.
65
65
66 For use in case of callables that are not functions, and
66 For use in case of callables that are not functions, and
67 thus may not have __name__ defined.
67 thus may not have __name__ defined.
68
68
69 Order: f.__name__ > f.name > str(f)
69 Order: f.__name__ > f.name > str(f)
70 """
70 """
71 try:
71 try:
72 return f.__name__
72 return f.__name__
73 except:
73 except:
74 pass
74 pass
75 try:
75 try:
76 return f.name
76 return f.name
77 except:
77 except:
78 pass
78 pass
79
79
80 return str(f)
80 return str(f)
81
81
82 #--------------------------------------------------------------------------
82 #--------------------------------------------------------------------------
83 # Classes
83 # Classes
84 #--------------------------------------------------------------------------
84 #--------------------------------------------------------------------------
85
85
86 class RemoteFunction(object):
86 class RemoteFunction(object):
87 """Turn an existing function into a remote function.
87 """Turn an existing function into a remote function.
88
88
89 Parameters
89 Parameters
90 ----------
90 ----------
91
91
92 view : View instance
92 view : View instance
93 The view to be used for execution
93 The view to be used for execution
94 f : callable
94 f : callable
95 The function to be wrapped into a remote function
95 The function to be wrapped into a remote function
96 block : bool [default: None]
96 block : bool [default: None]
97 Whether to wait for results or not. The default behavior is
97 Whether to wait for results or not. The default behavior is
98 to use the current `block` attribute of `view`
98 to use the current `block` attribute of `view`
99
99
100 **flags : remaining kwargs are passed to View.temp_flags
100 **flags : remaining kwargs are passed to View.temp_flags
101 """
101 """
102
102
103 view = None # the remote connection
103 view = None # the remote connection
104 func = None # the wrapped function
104 func = None # the wrapped function
105 block = None # whether to block
105 block = None # whether to block
106 flags = None # dict of extra kwargs for temp_flags
106 flags = None # dict of extra kwargs for temp_flags
107
107
108 def __init__(self, view, f, block=None, **flags):
108 def __init__(self, view, f, block=None, **flags):
109 self.view = view
109 self.view = view
110 self.func = f
110 self.func = f
111 self.block=block
111 self.block=block
112 self.flags=flags
112 self.flags=flags
113
113
114 def __call__(self, *args, **kwargs):
114 def __call__(self, *args, **kwargs):
115 block = self.view.block if self.block is None else self.block
115 block = self.view.block if self.block is None else self.block
116 with self.view.temp_flags(block=block, **self.flags):
116 with self.view.temp_flags(block=block, **self.flags):
117 return self.view.apply(self.func, *args, **kwargs)
117 return self.view.apply(self.func, *args, **kwargs)
118
118
119
119
120 class ParallelFunction(RemoteFunction):
120 class ParallelFunction(RemoteFunction):
121 """Class for mapping a function to sequences.
121 """Class for mapping a function to sequences.
122
122
123 This will distribute the sequences according the a mapper, and call
123 This will distribute the sequences according the a mapper, and call
124 the function on each sub-sequence. If called via map, then the function
124 the function on each sub-sequence. If called via map, then the function
125 will be called once on each element, rather that each sub-sequence.
125 will be called once on each element, rather that each sub-sequence.
126
126
127 Parameters
127 Parameters
128 ----------
128 ----------
129
129
130 view : View instance
130 view : View instance
131 The view to be used for execution
131 The view to be used for execution
132 f : callable
132 f : callable
133 The function to be wrapped into a remote function
133 The function to be wrapped into a remote function
134 dist : str [default: 'b']
134 dist : str [default: 'b']
135 The key for which mapObject to use to distribute sequences
135 The key for which mapObject to use to distribute sequences
136 options are:
136 options are:
137 * 'b' : use contiguous chunks in order
137 * 'b' : use contiguous chunks in order
138 * 'r' : use round-robin striping
138 * 'r' : use round-robin striping
139 block : bool [default: None]
139 block : bool [default: None]
140 Whether to wait for results or not. The default behavior is
140 Whether to wait for results or not. The default behavior is
141 to use the current `block` attribute of `view`
141 to use the current `block` attribute of `view`
142 chunksize : int or None
142 chunksize : int or None
143 The size of chunk to use when breaking up sequences in a load-balanced manner
143 The size of chunk to use when breaking up sequences in a load-balanced manner
144 ordered : bool [default: True]
144 ordered : bool [default: True]
145 Whether
145 Whether
146 **flags : remaining kwargs are passed to View.temp_flags
146 **flags : remaining kwargs are passed to View.temp_flags
147 """
147 """
148
148
149 chunksize=None
149 chunksize=None
150 ordered=None
150 ordered=None
151 mapObject=None
151 mapObject=None
152
152
153 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
153 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
154 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
154 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
155 self.chunksize = chunksize
155 self.chunksize = chunksize
156 self.ordered = ordered
156 self.ordered = ordered
157
157
158 mapClass = Map.dists[dist]
158 mapClass = Map.dists[dist]
159 self.mapObject = mapClass()
159 self.mapObject = mapClass()
160
160
161 def __call__(self, *sequences):
161 def __call__(self, *sequences):
162 client = self.view.client
162 client = self.view.client
163
163
164 # check that the length of sequences match
164 # check that the length of sequences match
165 len_0 = len(sequences[0])
165 len_0 = len(sequences[0])
166 for s in sequences:
166 for s in sequences:
167 if len(s)!=len_0:
167 if len(s)!=len_0:
168 msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s))
168 msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s))
169 raise ValueError(msg)
169 raise ValueError(msg)
170 balanced = 'Balanced' in self.view.__class__.__name__
170 balanced = 'Balanced' in self.view.__class__.__name__
171 if balanced:
171 if balanced:
172 if self.chunksize:
172 if self.chunksize:
173 nparts = len_0//self.chunksize + int(len_0%self.chunksize > 0)
173 nparts = len_0//self.chunksize + int(len_0%self.chunksize > 0)
174 else:
174 else:
175 nparts = len_0
175 nparts = len_0
176 targets = [None]*nparts
176 targets = [None]*nparts
177 else:
177 else:
178 if self.chunksize:
178 if self.chunksize:
179 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
179 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
180 # multiplexed:
180 # multiplexed:
181 targets = self.view.targets
181 targets = self.view.targets
182 # 'all' is lazily evaluated at execution time, which is now:
182 # 'all' is lazily evaluated at execution time, which is now:
183 if targets == 'all':
183 if targets == 'all':
184 targets = client._build_targets(targets)[1]
184 targets = client._build_targets(targets)[1]
185 elif isinstance(targets, int):
186 # single-engine view, targets must be iterable
187 targets = [targets]
185 nparts = len(targets)
188 nparts = len(targets)
186
189
187 msg_ids = []
190 msg_ids = []
188 for index, t in enumerate(targets):
191 for index, t in enumerate(targets):
189 args = []
192 args = []
190 for seq in sequences:
193 for seq in sequences:
191 part = self.mapObject.getPartition(seq, index, nparts)
194 part = self.mapObject.getPartition(seq, index, nparts)
192 if len(part) == 0:
195 if len(part) == 0:
193 continue
196 continue
194 else:
197 else:
195 args.append(part)
198 args.append(part)
196 if not args:
199 if not args:
197 continue
200 continue
198
201
199 # print (args)
202 # print (args)
200 if hasattr(self, '_map'):
203 if hasattr(self, '_map'):
201 if sys.version_info[0] >= 3:
204 if sys.version_info[0] >= 3:
202 f = lambda f, *sequences: list(map(f, *sequences))
205 f = lambda f, *sequences: list(map(f, *sequences))
203 else:
206 else:
204 f = map
207 f = map
205 args = [self.func]+args
208 args = [self.func]+args
206 else:
209 else:
207 f=self.func
210 f=self.func
208
211
209 view = self.view if balanced else client[t]
212 view = self.view if balanced else client[t]
210 with view.temp_flags(block=False, **self.flags):
213 with view.temp_flags(block=False, **self.flags):
211 ar = view.apply(f, *args)
214 ar = view.apply(f, *args)
212
215
213 msg_ids.append(ar.msg_ids[0])
216 msg_ids.append(ar.msg_ids[0])
214
217
215 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
218 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
216 fname=getname(self.func),
219 fname=getname(self.func),
217 ordered=self.ordered
220 ordered=self.ordered
218 )
221 )
219
222
220 if self.block:
223 if self.block:
221 try:
224 try:
222 return r.get()
225 return r.get()
223 except KeyboardInterrupt:
226 except KeyboardInterrupt:
224 return r
227 return r
225 else:
228 else:
226 return r
229 return r
227
230
228 def map(self, *sequences):
231 def map(self, *sequences):
229 """call a function on each element of a sequence remotely.
232 """call a function on each element of a sequence remotely.
230 This should behave very much like the builtin map, but return an AsyncMapResult
233 This should behave very much like the builtin map, but return an AsyncMapResult
231 if self.block is False.
234 if self.block is False.
232 """
235 """
233 # set _map as a flag for use inside self.__call__
236 # set _map as a flag for use inside self.__call__
234 self._map = True
237 self._map = True
235 try:
238 try:
236 ret = self.__call__(*sequences)
239 ret = self.__call__(*sequences)
237 finally:
240 finally:
238 del self._map
241 del self._map
239 return ret
242 return ret
240
243
241 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
244 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,537 +1,543 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import time
20 import time
21 from tempfile import mktemp
21 from tempfile import mktemp
22 from StringIO import StringIO
22 from StringIO import StringIO
23
23
24 import zmq
24 import zmq
25 from nose import SkipTest
25 from nose import SkipTest
26
26
27 from IPython.testing import decorators as dec
27 from IPython.testing import decorators as dec
28
28
29 from IPython import parallel as pmod
29 from IPython import parallel as pmod
30 from IPython.parallel import error
30 from IPython.parallel import error
31 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
31 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
32 from IPython.parallel import DirectView
32 from IPython.parallel import DirectView
33 from IPython.parallel.util import interactive
33 from IPython.parallel.util import interactive
34
34
35 from IPython.parallel.tests import add_engines
35 from IPython.parallel.tests import add_engines
36
36
37 from .clienttest import ClusterTestCase, crash, wait, skip_without
37 from .clienttest import ClusterTestCase, crash, wait, skip_without
38
38
39 def setup():
39 def setup():
40 add_engines(3, total=True)
40 add_engines(3, total=True)
41
41
42 class TestView(ClusterTestCase):
42 class TestView(ClusterTestCase):
43
43
44 def test_z_crash_mux(self):
44 def test_z_crash_mux(self):
45 """test graceful handling of engine death (direct)"""
45 """test graceful handling of engine death (direct)"""
46 raise SkipTest("crash tests disabled, due to undesirable crash reports")
46 raise SkipTest("crash tests disabled, due to undesirable crash reports")
47 # self.add_engines(1)
47 # self.add_engines(1)
48 eid = self.client.ids[-1]
48 eid = self.client.ids[-1]
49 ar = self.client[eid].apply_async(crash)
49 ar = self.client[eid].apply_async(crash)
50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
51 eid = ar.engine_id
51 eid = ar.engine_id
52 tic = time.time()
52 tic = time.time()
53 while eid in self.client.ids and time.time()-tic < 5:
53 while eid in self.client.ids and time.time()-tic < 5:
54 time.sleep(.01)
54 time.sleep(.01)
55 self.client.spin()
55 self.client.spin()
56 self.assertFalse(eid in self.client.ids, "Engine should have died")
56 self.assertFalse(eid in self.client.ids, "Engine should have died")
57
57
58 def test_push_pull(self):
58 def test_push_pull(self):
59 """test pushing and pulling"""
59 """test pushing and pulling"""
60 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
60 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
61 t = self.client.ids[-1]
61 t = self.client.ids[-1]
62 v = self.client[t]
62 v = self.client[t]
63 push = v.push
63 push = v.push
64 pull = v.pull
64 pull = v.pull
65 v.block=True
65 v.block=True
66 nengines = len(self.client)
66 nengines = len(self.client)
67 push({'data':data})
67 push({'data':data})
68 d = pull('data')
68 d = pull('data')
69 self.assertEquals(d, data)
69 self.assertEquals(d, data)
70 self.client[:].push({'data':data})
70 self.client[:].push({'data':data})
71 d = self.client[:].pull('data', block=True)
71 d = self.client[:].pull('data', block=True)
72 self.assertEquals(d, nengines*[data])
72 self.assertEquals(d, nengines*[data])
73 ar = push({'data':data}, block=False)
73 ar = push({'data':data}, block=False)
74 self.assertTrue(isinstance(ar, AsyncResult))
74 self.assertTrue(isinstance(ar, AsyncResult))
75 r = ar.get()
75 r = ar.get()
76 ar = self.client[:].pull('data', block=False)
76 ar = self.client[:].pull('data', block=False)
77 self.assertTrue(isinstance(ar, AsyncResult))
77 self.assertTrue(isinstance(ar, AsyncResult))
78 r = ar.get()
78 r = ar.get()
79 self.assertEquals(r, nengines*[data])
79 self.assertEquals(r, nengines*[data])
80 self.client[:].push(dict(a=10,b=20))
80 self.client[:].push(dict(a=10,b=20))
81 r = self.client[:].pull(('a','b'), block=True)
81 r = self.client[:].pull(('a','b'), block=True)
82 self.assertEquals(r, nengines*[[10,20]])
82 self.assertEquals(r, nengines*[[10,20]])
83
83
84 def test_push_pull_function(self):
84 def test_push_pull_function(self):
85 "test pushing and pulling functions"
85 "test pushing and pulling functions"
86 def testf(x):
86 def testf(x):
87 return 2.0*x
87 return 2.0*x
88
88
89 t = self.client.ids[-1]
89 t = self.client.ids[-1]
90 v = self.client[t]
90 v = self.client[t]
91 v.block=True
91 v.block=True
92 push = v.push
92 push = v.push
93 pull = v.pull
93 pull = v.pull
94 execute = v.execute
94 execute = v.execute
95 push({'testf':testf})
95 push({'testf':testf})
96 r = pull('testf')
96 r = pull('testf')
97 self.assertEqual(r(1.0), testf(1.0))
97 self.assertEqual(r(1.0), testf(1.0))
98 execute('r = testf(10)')
98 execute('r = testf(10)')
99 r = pull('r')
99 r = pull('r')
100 self.assertEquals(r, testf(10))
100 self.assertEquals(r, testf(10))
101 ar = self.client[:].push({'testf':testf}, block=False)
101 ar = self.client[:].push({'testf':testf}, block=False)
102 ar.get()
102 ar.get()
103 ar = self.client[:].pull('testf', block=False)
103 ar = self.client[:].pull('testf', block=False)
104 rlist = ar.get()
104 rlist = ar.get()
105 for r in rlist:
105 for r in rlist:
106 self.assertEqual(r(1.0), testf(1.0))
106 self.assertEqual(r(1.0), testf(1.0))
107 execute("def g(x): return x*x")
107 execute("def g(x): return x*x")
108 r = pull(('testf','g'))
108 r = pull(('testf','g'))
109 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
109 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
110
110
111 def test_push_function_globals(self):
111 def test_push_function_globals(self):
112 """test that pushed functions have access to globals"""
112 """test that pushed functions have access to globals"""
113 @interactive
113 @interactive
114 def geta():
114 def geta():
115 return a
115 return a
116 # self.add_engines(1)
116 # self.add_engines(1)
117 v = self.client[-1]
117 v = self.client[-1]
118 v.block=True
118 v.block=True
119 v['f'] = geta
119 v['f'] = geta
120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
121 v.execute('a=5')
121 v.execute('a=5')
122 v.execute('b=f()')
122 v.execute('b=f()')
123 self.assertEquals(v['b'], 5)
123 self.assertEquals(v['b'], 5)
124
124
125 def test_push_function_defaults(self):
125 def test_push_function_defaults(self):
126 """test that pushed functions preserve default args"""
126 """test that pushed functions preserve default args"""
127 def echo(a=10):
127 def echo(a=10):
128 return a
128 return a
129 v = self.client[-1]
129 v = self.client[-1]
130 v.block=True
130 v.block=True
131 v['f'] = echo
131 v['f'] = echo
132 v.execute('b=f()')
132 v.execute('b=f()')
133 self.assertEquals(v['b'], 10)
133 self.assertEquals(v['b'], 10)
134
134
135 def test_get_result(self):
135 def test_get_result(self):
136 """test getting results from the Hub."""
136 """test getting results from the Hub."""
137 c = pmod.Client(profile='iptest')
137 c = pmod.Client(profile='iptest')
138 # self.add_engines(1)
138 # self.add_engines(1)
139 t = c.ids[-1]
139 t = c.ids[-1]
140 v = c[t]
140 v = c[t]
141 v2 = self.client[t]
141 v2 = self.client[t]
142 ar = v.apply_async(wait, 1)
142 ar = v.apply_async(wait, 1)
143 # give the monitor time to notice the message
143 # give the monitor time to notice the message
144 time.sleep(.25)
144 time.sleep(.25)
145 ahr = v2.get_result(ar.msg_ids)
145 ahr = v2.get_result(ar.msg_ids)
146 self.assertTrue(isinstance(ahr, AsyncHubResult))
146 self.assertTrue(isinstance(ahr, AsyncHubResult))
147 self.assertEquals(ahr.get(), ar.get())
147 self.assertEquals(ahr.get(), ar.get())
148 ar2 = v2.get_result(ar.msg_ids)
148 ar2 = v2.get_result(ar.msg_ids)
149 self.assertFalse(isinstance(ar2, AsyncHubResult))
149 self.assertFalse(isinstance(ar2, AsyncHubResult))
150 c.spin()
150 c.spin()
151 c.close()
151 c.close()
152
152
153 def test_run_newline(self):
153 def test_run_newline(self):
154 """test that run appends newline to files"""
154 """test that run appends newline to files"""
155 tmpfile = mktemp()
155 tmpfile = mktemp()
156 with open(tmpfile, 'w') as f:
156 with open(tmpfile, 'w') as f:
157 f.write("""def g():
157 f.write("""def g():
158 return 5
158 return 5
159 """)
159 """)
160 v = self.client[-1]
160 v = self.client[-1]
161 v.run(tmpfile, block=True)
161 v.run(tmpfile, block=True)
162 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
162 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
163
163
164 def test_apply_tracked(self):
164 def test_apply_tracked(self):
165 """test tracking for apply"""
165 """test tracking for apply"""
166 # self.add_engines(1)
166 # self.add_engines(1)
167 t = self.client.ids[-1]
167 t = self.client.ids[-1]
168 v = self.client[t]
168 v = self.client[t]
169 v.block=False
169 v.block=False
170 def echo(n=1024*1024, **kwargs):
170 def echo(n=1024*1024, **kwargs):
171 with v.temp_flags(**kwargs):
171 with v.temp_flags(**kwargs):
172 return v.apply(lambda x: x, 'x'*n)
172 return v.apply(lambda x: x, 'x'*n)
173 ar = echo(1, track=False)
173 ar = echo(1, track=False)
174 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
174 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
175 self.assertTrue(ar.sent)
175 self.assertTrue(ar.sent)
176 ar = echo(track=True)
176 ar = echo(track=True)
177 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
177 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
178 self.assertEquals(ar.sent, ar._tracker.done)
178 self.assertEquals(ar.sent, ar._tracker.done)
179 ar._tracker.wait()
179 ar._tracker.wait()
180 self.assertTrue(ar.sent)
180 self.assertTrue(ar.sent)
181
181
182 def test_push_tracked(self):
182 def test_push_tracked(self):
183 t = self.client.ids[-1]
183 t = self.client.ids[-1]
184 ns = dict(x='x'*1024*1024)
184 ns = dict(x='x'*1024*1024)
185 v = self.client[t]
185 v = self.client[t]
186 ar = v.push(ns, block=False, track=False)
186 ar = v.push(ns, block=False, track=False)
187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 self.assertTrue(ar.sent)
188 self.assertTrue(ar.sent)
189
189
190 ar = v.push(ns, block=False, track=True)
190 ar = v.push(ns, block=False, track=True)
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 ar._tracker.wait()
192 ar._tracker.wait()
193 self.assertEquals(ar.sent, ar._tracker.done)
193 self.assertEquals(ar.sent, ar._tracker.done)
194 self.assertTrue(ar.sent)
194 self.assertTrue(ar.sent)
195 ar.get()
195 ar.get()
196
196
197 def test_scatter_tracked(self):
197 def test_scatter_tracked(self):
198 t = self.client.ids
198 t = self.client.ids
199 x='x'*1024*1024
199 x='x'*1024*1024
200 ar = self.client[t].scatter('x', x, block=False, track=False)
200 ar = self.client[t].scatter('x', x, block=False, track=False)
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 self.assertTrue(ar.sent)
202 self.assertTrue(ar.sent)
203
203
204 ar = self.client[t].scatter('x', x, block=False, track=True)
204 ar = self.client[t].scatter('x', x, block=False, track=True)
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 self.assertEquals(ar.sent, ar._tracker.done)
206 self.assertEquals(ar.sent, ar._tracker.done)
207 ar._tracker.wait()
207 ar._tracker.wait()
208 self.assertTrue(ar.sent)
208 self.assertTrue(ar.sent)
209 ar.get()
209 ar.get()
210
210
211 def test_remote_reference(self):
211 def test_remote_reference(self):
212 v = self.client[-1]
212 v = self.client[-1]
213 v['a'] = 123
213 v['a'] = 123
214 ra = pmod.Reference('a')
214 ra = pmod.Reference('a')
215 b = v.apply_sync(lambda x: x, ra)
215 b = v.apply_sync(lambda x: x, ra)
216 self.assertEquals(b, 123)
216 self.assertEquals(b, 123)
217
217
218
218
219 def test_scatter_gather(self):
219 def test_scatter_gather(self):
220 view = self.client[:]
220 view = self.client[:]
221 seq1 = range(16)
221 seq1 = range(16)
222 view.scatter('a', seq1)
222 view.scatter('a', seq1)
223 seq2 = view.gather('a', block=True)
223 seq2 = view.gather('a', block=True)
224 self.assertEquals(seq2, seq1)
224 self.assertEquals(seq2, seq1)
225 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
225 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
226
226
227 @skip_without('numpy')
227 @skip_without('numpy')
228 def test_scatter_gather_numpy(self):
228 def test_scatter_gather_numpy(self):
229 import numpy
229 import numpy
230 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
230 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
231 view = self.client[:]
231 view = self.client[:]
232 a = numpy.arange(64)
232 a = numpy.arange(64)
233 view.scatter('a', a)
233 view.scatter('a', a)
234 b = view.gather('a', block=True)
234 b = view.gather('a', block=True)
235 assert_array_equal(b, a)
235 assert_array_equal(b, a)
236
236
237 @skip_without('numpy')
237 @skip_without('numpy')
238 def test_push_numpy_nocopy(self):
238 def test_push_numpy_nocopy(self):
239 import numpy
239 import numpy
240 view = self.client[:]
240 view = self.client[:]
241 a = numpy.arange(64)
241 a = numpy.arange(64)
242 view['A'] = a
242 view['A'] = a
243 @interactive
243 @interactive
244 def check_writeable(x):
244 def check_writeable(x):
245 return x.flags.writeable
245 return x.flags.writeable
246
246
247 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
247 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
248 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
248 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
249
249
250 view.push(dict(B=a))
250 view.push(dict(B=a))
251 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
251 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
252 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
252 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
253
253
254 @skip_without('numpy')
254 @skip_without('numpy')
255 def test_apply_numpy(self):
255 def test_apply_numpy(self):
256 """view.apply(f, ndarray)"""
256 """view.apply(f, ndarray)"""
257 import numpy
257 import numpy
258 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
258 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
259
259
260 A = numpy.random.random((100,100))
260 A = numpy.random.random((100,100))
261 view = self.client[-1]
261 view = self.client[-1]
262 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
262 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
263 B = A.astype(dt)
263 B = A.astype(dt)
264 C = view.apply_sync(lambda x:x, B)
264 C = view.apply_sync(lambda x:x, B)
265 assert_array_equal(B,C)
265 assert_array_equal(B,C)
266
266
267 def test_map(self):
267 def test_map(self):
268 view = self.client[:]
268 view = self.client[:]
269 def f(x):
269 def f(x):
270 return x**2
270 return x**2
271 data = range(16)
271 data = range(16)
272 r = view.map_sync(f, data)
272 r = view.map_sync(f, data)
273 self.assertEquals(r, map(f, data))
273 self.assertEquals(r, map(f, data))
274
274
275 def test_map_iterable(self):
275 def test_map_iterable(self):
276 """test map on iterables (direct)"""
276 """test map on iterables (direct)"""
277 view = self.client[:]
277 view = self.client[:]
278 # 101 is prime, so it won't be evenly distributed
278 # 101 is prime, so it won't be evenly distributed
279 arr = range(101)
279 arr = range(101)
280 # ensure it will be an iterator, even in Python 3
280 # ensure it will be an iterator, even in Python 3
281 it = iter(arr)
281 it = iter(arr)
282 r = view.map_sync(lambda x:x, arr)
282 r = view.map_sync(lambda x:x, arr)
283 self.assertEquals(r, list(arr))
283 self.assertEquals(r, list(arr))
284
284
285 def test_scatterGatherNonblocking(self):
285 def test_scatterGatherNonblocking(self):
286 data = range(16)
286 data = range(16)
287 view = self.client[:]
287 view = self.client[:]
288 view.scatter('a', data, block=False)
288 view.scatter('a', data, block=False)
289 ar = view.gather('a', block=False)
289 ar = view.gather('a', block=False)
290 self.assertEquals(ar.get(), data)
290 self.assertEquals(ar.get(), data)
291
291
292 @skip_without('numpy')
292 @skip_without('numpy')
293 def test_scatter_gather_numpy_nonblocking(self):
293 def test_scatter_gather_numpy_nonblocking(self):
294 import numpy
294 import numpy
295 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
295 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
296 a = numpy.arange(64)
296 a = numpy.arange(64)
297 view = self.client[:]
297 view = self.client[:]
298 ar = view.scatter('a', a, block=False)
298 ar = view.scatter('a', a, block=False)
299 self.assertTrue(isinstance(ar, AsyncResult))
299 self.assertTrue(isinstance(ar, AsyncResult))
300 amr = view.gather('a', block=False)
300 amr = view.gather('a', block=False)
301 self.assertTrue(isinstance(amr, AsyncMapResult))
301 self.assertTrue(isinstance(amr, AsyncMapResult))
302 assert_array_equal(amr.get(), a)
302 assert_array_equal(amr.get(), a)
303
303
304 def test_execute(self):
304 def test_execute(self):
305 view = self.client[:]
305 view = self.client[:]
306 # self.client.debug=True
306 # self.client.debug=True
307 execute = view.execute
307 execute = view.execute
308 ar = execute('c=30', block=False)
308 ar = execute('c=30', block=False)
309 self.assertTrue(isinstance(ar, AsyncResult))
309 self.assertTrue(isinstance(ar, AsyncResult))
310 ar = execute('d=[0,1,2]', block=False)
310 ar = execute('d=[0,1,2]', block=False)
311 self.client.wait(ar, 1)
311 self.client.wait(ar, 1)
312 self.assertEquals(len(ar.get()), len(self.client))
312 self.assertEquals(len(ar.get()), len(self.client))
313 for c in view['c']:
313 for c in view['c']:
314 self.assertEquals(c, 30)
314 self.assertEquals(c, 30)
315
315
316 def test_abort(self):
316 def test_abort(self):
317 view = self.client[-1]
317 view = self.client[-1]
318 ar = view.execute('import time; time.sleep(1)', block=False)
318 ar = view.execute('import time; time.sleep(1)', block=False)
319 ar2 = view.apply_async(lambda : 2)
319 ar2 = view.apply_async(lambda : 2)
320 ar3 = view.apply_async(lambda : 3)
320 ar3 = view.apply_async(lambda : 3)
321 view.abort(ar2)
321 view.abort(ar2)
322 view.abort(ar3.msg_ids)
322 view.abort(ar3.msg_ids)
323 self.assertRaises(error.TaskAborted, ar2.get)
323 self.assertRaises(error.TaskAborted, ar2.get)
324 self.assertRaises(error.TaskAborted, ar3.get)
324 self.assertRaises(error.TaskAborted, ar3.get)
325
325
326 def test_abort_all(self):
326 def test_abort_all(self):
327 """view.abort() aborts all outstanding tasks"""
327 """view.abort() aborts all outstanding tasks"""
328 view = self.client[-1]
328 view = self.client[-1]
329 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
329 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
330 view.abort()
330 view.abort()
331 view.wait(timeout=5)
331 view.wait(timeout=5)
332 for ar in ars[5:]:
332 for ar in ars[5:]:
333 self.assertRaises(error.TaskAborted, ar.get)
333 self.assertRaises(error.TaskAborted, ar.get)
334
334
335 def test_temp_flags(self):
335 def test_temp_flags(self):
336 view = self.client[-1]
336 view = self.client[-1]
337 view.block=True
337 view.block=True
338 with view.temp_flags(block=False):
338 with view.temp_flags(block=False):
339 self.assertFalse(view.block)
339 self.assertFalse(view.block)
340 self.assertTrue(view.block)
340 self.assertTrue(view.block)
341
341
342 @dec.known_failure_py3
342 @dec.known_failure_py3
343 def test_importer(self):
343 def test_importer(self):
344 view = self.client[-1]
344 view = self.client[-1]
345 view.clear(block=True)
345 view.clear(block=True)
346 with view.importer:
346 with view.importer:
347 import re
347 import re
348
348
349 @interactive
349 @interactive
350 def findall(pat, s):
350 def findall(pat, s):
351 # this globals() step isn't necessary in real code
351 # this globals() step isn't necessary in real code
352 # only to prevent a closure in the test
352 # only to prevent a closure in the test
353 re = globals()['re']
353 re = globals()['re']
354 return re.findall(pat, s)
354 return re.findall(pat, s)
355
355
356 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
356 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
357
357
358 # parallel magic tests
358 # parallel magic tests
359
359
360 def test_magic_px_blocking(self):
360 def test_magic_px_blocking(self):
361 ip = get_ipython()
361 ip = get_ipython()
362 v = self.client[-1]
362 v = self.client[-1]
363 v.activate()
363 v.activate()
364 v.block=True
364 v.block=True
365
365
366 ip.magic_px('a=5')
366 ip.magic_px('a=5')
367 self.assertEquals(v['a'], 5)
367 self.assertEquals(v['a'], 5)
368 ip.magic_px('a=10')
368 ip.magic_px('a=10')
369 self.assertEquals(v['a'], 10)
369 self.assertEquals(v['a'], 10)
370 sio = StringIO()
370 sio = StringIO()
371 savestdout = sys.stdout
371 savestdout = sys.stdout
372 sys.stdout = sio
372 sys.stdout = sio
373 # just 'print a' worst ~99% of the time, but this ensures that
373 # just 'print a' worst ~99% of the time, but this ensures that
374 # the stdout message has arrived when the result is finished:
374 # the stdout message has arrived when the result is finished:
375 ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
375 ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
376 sys.stdout = savestdout
376 sys.stdout = savestdout
377 buf = sio.getvalue()
377 buf = sio.getvalue()
378 self.assertTrue('[stdout:' in buf, buf)
378 self.assertTrue('[stdout:' in buf, buf)
379 self.assertTrue(buf.rstrip().endswith('10'))
379 self.assertTrue(buf.rstrip().endswith('10'))
380 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
380 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
381
381
382 def test_magic_px_nonblocking(self):
382 def test_magic_px_nonblocking(self):
383 ip = get_ipython()
383 ip = get_ipython()
384 v = self.client[-1]
384 v = self.client[-1]
385 v.activate()
385 v.activate()
386 v.block=False
386 v.block=False
387
387
388 ip.magic_px('a=5')
388 ip.magic_px('a=5')
389 self.assertEquals(v['a'], 5)
389 self.assertEquals(v['a'], 5)
390 ip.magic_px('a=10')
390 ip.magic_px('a=10')
391 self.assertEquals(v['a'], 10)
391 self.assertEquals(v['a'], 10)
392 sio = StringIO()
392 sio = StringIO()
393 savestdout = sys.stdout
393 savestdout = sys.stdout
394 sys.stdout = sio
394 sys.stdout = sio
395 ip.magic_px('print a')
395 ip.magic_px('print a')
396 sys.stdout = savestdout
396 sys.stdout = savestdout
397 buf = sio.getvalue()
397 buf = sio.getvalue()
398 self.assertFalse('[stdout:%i]'%v.targets in buf)
398 self.assertFalse('[stdout:%i]'%v.targets in buf)
399 ip.magic_px('1/0')
399 ip.magic_px('1/0')
400 ar = v.get_result(-1)
400 ar = v.get_result(-1)
401 self.assertRaisesRemote(ZeroDivisionError, ar.get)
401 self.assertRaisesRemote(ZeroDivisionError, ar.get)
402
402
403 def test_magic_autopx_blocking(self):
403 def test_magic_autopx_blocking(self):
404 ip = get_ipython()
404 ip = get_ipython()
405 v = self.client[-1]
405 v = self.client[-1]
406 v.activate()
406 v.activate()
407 v.block=True
407 v.block=True
408
408
409 sio = StringIO()
409 sio = StringIO()
410 savestdout = sys.stdout
410 savestdout = sys.stdout
411 sys.stdout = sio
411 sys.stdout = sio
412 ip.magic_autopx()
412 ip.magic_autopx()
413 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
413 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
414 ip.run_cell('print b')
414 ip.run_cell('print b')
415 ip.run_cell("b/c")
415 ip.run_cell("b/c")
416 ip.run_code(compile('b*=2', '', 'single'))
416 ip.run_code(compile('b*=2', '', 'single'))
417 ip.magic_autopx()
417 ip.magic_autopx()
418 sys.stdout = savestdout
418 sys.stdout = savestdout
419 output = sio.getvalue().strip()
419 output = sio.getvalue().strip()
420 self.assertTrue(output.startswith('%autopx enabled'))
420 self.assertTrue(output.startswith('%autopx enabled'))
421 self.assertTrue(output.endswith('%autopx disabled'))
421 self.assertTrue(output.endswith('%autopx disabled'))
422 self.assertTrue('RemoteError: ZeroDivisionError' in output)
422 self.assertTrue('RemoteError: ZeroDivisionError' in output)
423 ar = v.get_result(-2)
423 ar = v.get_result(-2)
424 self.assertEquals(v['a'], 5)
424 self.assertEquals(v['a'], 5)
425 self.assertEquals(v['b'], 20)
425 self.assertEquals(v['b'], 20)
426 self.assertRaisesRemote(ZeroDivisionError, ar.get)
426 self.assertRaisesRemote(ZeroDivisionError, ar.get)
427
427
428 def test_magic_autopx_nonblocking(self):
428 def test_magic_autopx_nonblocking(self):
429 ip = get_ipython()
429 ip = get_ipython()
430 v = self.client[-1]
430 v = self.client[-1]
431 v.activate()
431 v.activate()
432 v.block=False
432 v.block=False
433
433
434 sio = StringIO()
434 sio = StringIO()
435 savestdout = sys.stdout
435 savestdout = sys.stdout
436 sys.stdout = sio
436 sys.stdout = sio
437 ip.magic_autopx()
437 ip.magic_autopx()
438 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
438 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
439 ip.run_cell('print b')
439 ip.run_cell('print b')
440 ip.run_cell("b/c")
440 ip.run_cell("b/c")
441 ip.run_code(compile('b*=2', '', 'single'))
441 ip.run_code(compile('b*=2', '', 'single'))
442 ip.magic_autopx()
442 ip.magic_autopx()
443 sys.stdout = savestdout
443 sys.stdout = savestdout
444 output = sio.getvalue().strip()
444 output = sio.getvalue().strip()
445 self.assertTrue(output.startswith('%autopx enabled'))
445 self.assertTrue(output.startswith('%autopx enabled'))
446 self.assertTrue(output.endswith('%autopx disabled'))
446 self.assertTrue(output.endswith('%autopx disabled'))
447 self.assertFalse('ZeroDivisionError' in output)
447 self.assertFalse('ZeroDivisionError' in output)
448 ar = v.get_result(-2)
448 ar = v.get_result(-2)
449 self.assertEquals(v['a'], 5)
449 self.assertEquals(v['a'], 5)
450 self.assertEquals(v['b'], 20)
450 self.assertEquals(v['b'], 20)
451 self.assertRaisesRemote(ZeroDivisionError, ar.get)
451 self.assertRaisesRemote(ZeroDivisionError, ar.get)
452
452
453 def test_magic_result(self):
453 def test_magic_result(self):
454 ip = get_ipython()
454 ip = get_ipython()
455 v = self.client[-1]
455 v = self.client[-1]
456 v.activate()
456 v.activate()
457 v['a'] = 111
457 v['a'] = 111
458 ra = v['a']
458 ra = v['a']
459
459
460 ar = ip.magic_result()
460 ar = ip.magic_result()
461 self.assertEquals(ar.msg_ids, [v.history[-1]])
461 self.assertEquals(ar.msg_ids, [v.history[-1]])
462 self.assertEquals(ar.get(), 111)
462 self.assertEquals(ar.get(), 111)
463 ar = ip.magic_result('-2')
463 ar = ip.magic_result('-2')
464 self.assertEquals(ar.msg_ids, [v.history[-2]])
464 self.assertEquals(ar.msg_ids, [v.history[-2]])
465
465
466 def test_unicode_execute(self):
466 def test_unicode_execute(self):
467 """test executing unicode strings"""
467 """test executing unicode strings"""
468 v = self.client[-1]
468 v = self.client[-1]
469 v.block=True
469 v.block=True
470 if sys.version_info[0] >= 3:
470 if sys.version_info[0] >= 3:
471 code="a='é'"
471 code="a='é'"
472 else:
472 else:
473 code=u"a=u'é'"
473 code=u"a=u'é'"
474 v.execute(code)
474 v.execute(code)
475 self.assertEquals(v['a'], u'é')
475 self.assertEquals(v['a'], u'é')
476
476
477 def test_unicode_apply_result(self):
477 def test_unicode_apply_result(self):
478 """test unicode apply results"""
478 """test unicode apply results"""
479 v = self.client[-1]
479 v = self.client[-1]
480 r = v.apply_sync(lambda : u'é')
480 r = v.apply_sync(lambda : u'é')
481 self.assertEquals(r, u'é')
481 self.assertEquals(r, u'é')
482
482
483 def test_unicode_apply_arg(self):
483 def test_unicode_apply_arg(self):
484 """test passing unicode arguments to apply"""
484 """test passing unicode arguments to apply"""
485 v = self.client[-1]
485 v = self.client[-1]
486
486
487 @interactive
487 @interactive
488 def check_unicode(a, check):
488 def check_unicode(a, check):
489 assert isinstance(a, unicode), "%r is not unicode"%a
489 assert isinstance(a, unicode), "%r is not unicode"%a
490 assert isinstance(check, bytes), "%r is not bytes"%check
490 assert isinstance(check, bytes), "%r is not bytes"%check
491 assert a.encode('utf8') == check, "%s != %s"%(a,check)
491 assert a.encode('utf8') == check, "%s != %s"%(a,check)
492
492
493 for s in [ u'é', u'ßø®∫',u'asdf' ]:
493 for s in [ u'é', u'ßø®∫',u'asdf' ]:
494 try:
494 try:
495 v.apply_sync(check_unicode, s, s.encode('utf8'))
495 v.apply_sync(check_unicode, s, s.encode('utf8'))
496 except error.RemoteError as e:
496 except error.RemoteError as e:
497 if e.ename == 'AssertionError':
497 if e.ename == 'AssertionError':
498 self.fail(e.evalue)
498 self.fail(e.evalue)
499 else:
499 else:
500 raise e
500 raise e
501
501
502 def test_map_reference(self):
502 def test_map_reference(self):
503 """view.map(<Reference>, *seqs) should work"""
503 """view.map(<Reference>, *seqs) should work"""
504 v = self.client[:]
504 v = self.client[:]
505 v.scatter('n', self.client.ids, flatten=True)
505 v.scatter('n', self.client.ids, flatten=True)
506 v.execute("f = lambda x,y: x*y")
506 v.execute("f = lambda x,y: x*y")
507 rf = pmod.Reference('f')
507 rf = pmod.Reference('f')
508 nlist = list(range(10))
508 nlist = list(range(10))
509 mlist = nlist[::-1]
509 mlist = nlist[::-1]
510 expected = [ m*n for m,n in zip(mlist, nlist) ]
510 expected = [ m*n for m,n in zip(mlist, nlist) ]
511 result = v.map_sync(rf, mlist, nlist)
511 result = v.map_sync(rf, mlist, nlist)
512 self.assertEquals(result, expected)
512 self.assertEquals(result, expected)
513
513
514 def test_apply_reference(self):
514 def test_apply_reference(self):
515 """view.apply(<Reference>, *args) should work"""
515 """view.apply(<Reference>, *args) should work"""
516 v = self.client[:]
516 v = self.client[:]
517 v.scatter('n', self.client.ids, flatten=True)
517 v.scatter('n', self.client.ids, flatten=True)
518 v.execute("f = lambda x: n*x")
518 v.execute("f = lambda x: n*x")
519 rf = pmod.Reference('f')
519 rf = pmod.Reference('f')
520 result = v.apply_sync(rf, 5)
520 result = v.apply_sync(rf, 5)
521 expected = [ 5*id for id in self.client.ids ]
521 expected = [ 5*id for id in self.client.ids ]
522 self.assertEquals(result, expected)
522 self.assertEquals(result, expected)
523
523
524 def test_eval_reference(self):
524 def test_eval_reference(self):
525 v = self.client[self.client.ids[0]]
525 v = self.client[self.client.ids[0]]
526 v['g'] = range(5)
526 v['g'] = range(5)
527 rg = pmod.Reference('g[0]')
527 rg = pmod.Reference('g[0]')
528 echo = lambda x:x
528 echo = lambda x:x
529 self.assertEquals(v.apply_sync(echo, rg), 0)
529 self.assertEquals(v.apply_sync(echo, rg), 0)
530
530
531 def test_reference_nameerror(self):
531 def test_reference_nameerror(self):
532 v = self.client[self.client.ids[0]]
532 v = self.client[self.client.ids[0]]
533 r = pmod.Reference('elvis_has_left')
533 r = pmod.Reference('elvis_has_left')
534 echo = lambda x:x
534 echo = lambda x:x
535 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
535 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
536
536
537 def test_single_engine_map(self):
538 e0 = self.client[self.client.ids[0]]
539 r = range(5)
540 check = [ -1*i for i in r ]
541 result = e0.map_sync(lambda x: -1*x, r)
542 self.assertEquals(result, check)
537
543
General Comments 0
You need to be logged in to leave comments. Login now