##// END OF EJS Templates
Fix parallel.client.View map() on numpy arrays
Samuel Ainsworth -
Show More
@@ -1,282 +1,283 b''
1 """Remote Functions and decorators for Views.
1 """Remote Functions and decorators for Views.
2
2
3 Authors:
3 Authors:
4
4
5 * Brian Granger
5 * Brian Granger
6 * Min RK
6 * Min RK
7 """
7 """
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2010-2011 The IPython Development Team
9 # Copyright (C) 2010-2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import sys
21 import sys
22 import warnings
22 import warnings
23
23
24 from IPython.external.decorator import decorator
24 from IPython.external.decorator import decorator
25 from IPython.testing.skipdoctest import skip_doctest
25 from IPython.testing.skipdoctest import skip_doctest
26
26
27 from . import map as Map
27 from . import map as Map
28 from .asyncresult import AsyncMapResult
28 from .asyncresult import AsyncMapResult
29
29
30 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
31 # Functions and Decorators
31 # Functions and Decorators
32 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
33
33
34 @skip_doctest
34 @skip_doctest
35 def remote(view, block=None, **flags):
35 def remote(view, block=None, **flags):
36 """Turn a function into a remote function.
36 """Turn a function into a remote function.
37
37
38 This method can be used for map:
38 This method can be used for map:
39
39
40 In [1]: @remote(view,block=True)
40 In [1]: @remote(view,block=True)
41 ...: def func(a):
41 ...: def func(a):
42 ...: pass
42 ...: pass
43 """
43 """
44
44
45 def remote_function(f):
45 def remote_function(f):
46 return RemoteFunction(view, f, block=block, **flags)
46 return RemoteFunction(view, f, block=block, **flags)
47 return remote_function
47 return remote_function
48
48
49 @skip_doctest
49 @skip_doctest
50 def parallel(view, dist='b', block=None, ordered=True, **flags):
50 def parallel(view, dist='b', block=None, ordered=True, **flags):
51 """Turn a function into a parallel remote function.
51 """Turn a function into a parallel remote function.
52
52
53 This method can be used for map:
53 This method can be used for map:
54
54
55 In [1]: @parallel(view, block=True)
55 In [1]: @parallel(view, block=True)
56 ...: def func(a):
56 ...: def func(a):
57 ...: pass
57 ...: pass
58 """
58 """
59
59
60 def parallel_function(f):
60 def parallel_function(f):
61 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
61 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
62 return parallel_function
62 return parallel_function
63
63
64 def getname(f):
64 def getname(f):
65 """Get the name of an object.
65 """Get the name of an object.
66
66
67 For use in case of callables that are not functions, and
67 For use in case of callables that are not functions, and
68 thus may not have __name__ defined.
68 thus may not have __name__ defined.
69
69
70 Order: f.__name__ > f.name > str(f)
70 Order: f.__name__ > f.name > str(f)
71 """
71 """
72 try:
72 try:
73 return f.__name__
73 return f.__name__
74 except:
74 except:
75 pass
75 pass
76 try:
76 try:
77 return f.name
77 return f.name
78 except:
78 except:
79 pass
79 pass
80
80
81 return str(f)
81 return str(f)
82
82
83 @decorator
83 @decorator
84 def sync_view_results(f, self, *args, **kwargs):
84 def sync_view_results(f, self, *args, **kwargs):
85 """sync relevant results from self.client to our results attribute.
85 """sync relevant results from self.client to our results attribute.
86
86
87 This is a clone of view.sync_results, but for remote functions
87 This is a clone of view.sync_results, but for remote functions
88 """
88 """
89 view = self.view
89 view = self.view
90 if view._in_sync_results:
90 if view._in_sync_results:
91 return f(self, *args, **kwargs)
91 return f(self, *args, **kwargs)
92 print 'in sync results', f
92 print 'in sync results', f
93 view._in_sync_results = True
93 view._in_sync_results = True
94 try:
94 try:
95 ret = f(self, *args, **kwargs)
95 ret = f(self, *args, **kwargs)
96 finally:
96 finally:
97 view._in_sync_results = False
97 view._in_sync_results = False
98 view._sync_results()
98 view._sync_results()
99 return ret
99 return ret
100
100
101 #--------------------------------------------------------------------------
101 #--------------------------------------------------------------------------
102 # Classes
102 # Classes
103 #--------------------------------------------------------------------------
103 #--------------------------------------------------------------------------
104
104
105 class RemoteFunction(object):
105 class RemoteFunction(object):
106 """Turn an existing function into a remote function.
106 """Turn an existing function into a remote function.
107
107
108 Parameters
108 Parameters
109 ----------
109 ----------
110
110
111 view : View instance
111 view : View instance
112 The view to be used for execution
112 The view to be used for execution
113 f : callable
113 f : callable
114 The function to be wrapped into a remote function
114 The function to be wrapped into a remote function
115 block : bool [default: None]
115 block : bool [default: None]
116 Whether to wait for results or not. The default behavior is
116 Whether to wait for results or not. The default behavior is
117 to use the current `block` attribute of `view`
117 to use the current `block` attribute of `view`
118
118
119 **flags : remaining kwargs are passed to View.temp_flags
119 **flags : remaining kwargs are passed to View.temp_flags
120 """
120 """
121
121
122 view = None # the remote connection
122 view = None # the remote connection
123 func = None # the wrapped function
123 func = None # the wrapped function
124 block = None # whether to block
124 block = None # whether to block
125 flags = None # dict of extra kwargs for temp_flags
125 flags = None # dict of extra kwargs for temp_flags
126
126
127 def __init__(self, view, f, block=None, **flags):
127 def __init__(self, view, f, block=None, **flags):
128 self.view = view
128 self.view = view
129 self.func = f
129 self.func = f
130 self.block=block
130 self.block=block
131 self.flags=flags
131 self.flags=flags
132
132
133 def __call__(self, *args, **kwargs):
133 def __call__(self, *args, **kwargs):
134 block = self.view.block if self.block is None else self.block
134 block = self.view.block if self.block is None else self.block
135 with self.view.temp_flags(block=block, **self.flags):
135 with self.view.temp_flags(block=block, **self.flags):
136 return self.view.apply(self.func, *args, **kwargs)
136 return self.view.apply(self.func, *args, **kwargs)
137
137
138
138
139 class ParallelFunction(RemoteFunction):
139 class ParallelFunction(RemoteFunction):
140 """Class for mapping a function to sequences.
140 """Class for mapping a function to sequences.
141
141
142 This will distribute the sequences according the a mapper, and call
142 This will distribute the sequences according the a mapper, and call
143 the function on each sub-sequence. If called via map, then the function
143 the function on each sub-sequence. If called via map, then the function
144 will be called once on each element, rather that each sub-sequence.
144 will be called once on each element, rather that each sub-sequence.
145
145
146 Parameters
146 Parameters
147 ----------
147 ----------
148
148
149 view : View instance
149 view : View instance
150 The view to be used for execution
150 The view to be used for execution
151 f : callable
151 f : callable
152 The function to be wrapped into a remote function
152 The function to be wrapped into a remote function
153 dist : str [default: 'b']
153 dist : str [default: 'b']
154 The key for which mapObject to use to distribute sequences
154 The key for which mapObject to use to distribute sequences
155 options are:
155 options are:
156 * 'b' : use contiguous chunks in order
156 * 'b' : use contiguous chunks in order
157 * 'r' : use round-robin striping
157 * 'r' : use round-robin striping
158 block : bool [default: None]
158 block : bool [default: None]
159 Whether to wait for results or not. The default behavior is
159 Whether to wait for results or not. The default behavior is
160 to use the current `block` attribute of `view`
160 to use the current `block` attribute of `view`
161 chunksize : int or None
161 chunksize : int or None
162 The size of chunk to use when breaking up sequences in a load-balanced manner
162 The size of chunk to use when breaking up sequences in a load-balanced manner
163 ordered : bool [default: True]
163 ordered : bool [default: True]
164 Whether the result should be kept in order. If False,
164 Whether the result should be kept in order. If False,
165 results become available as they arrive, regardless of submission order.
165 results become available as they arrive, regardless of submission order.
166 **flags : remaining kwargs are passed to View.temp_flags
166 **flags : remaining kwargs are passed to View.temp_flags
167 """
167 """
168
168
169 chunksize = None
169 chunksize = None
170 ordered = None
170 ordered = None
171 mapObject = None
171 mapObject = None
172 _mapping = False
172 _mapping = False
173
173
174 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
174 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
175 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
175 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
176 self.chunksize = chunksize
176 self.chunksize = chunksize
177 self.ordered = ordered
177 self.ordered = ordered
178
178
179 mapClass = Map.dists[dist]
179 mapClass = Map.dists[dist]
180 self.mapObject = mapClass()
180 self.mapObject = mapClass()
181
181
182 @sync_view_results
182 @sync_view_results
183 def __call__(self, *sequences):
183 def __call__(self, *sequences):
184 client = self.view.client
184 client = self.view.client
185
185
186 lens = []
186 lens = []
187 maxlen = minlen = -1
187 maxlen = minlen = -1
188 for i, seq in enumerate(sequences):
188 for i, seq in enumerate(sequences):
189 try:
189 try:
190 n = len(seq)
190 n = len(seq)
191 except Exception:
191 except Exception:
192 seq = list(seq)
192 seq = list(seq)
193 if isinstance(sequences, tuple):
193 if isinstance(sequences, tuple):
194 # can't alter a tuple
194 # can't alter a tuple
195 sequences = list(sequences)
195 sequences = list(sequences)
196 sequences[i] = seq
196 sequences[i] = seq
197 n = len(seq)
197 n = len(seq)
198 if n > maxlen:
198 if n > maxlen:
199 maxlen = n
199 maxlen = n
200 if minlen == -1 or n < minlen:
200 if minlen == -1 or n < minlen:
201 minlen = n
201 minlen = n
202 lens.append(n)
202 lens.append(n)
203
203
204 # check that the length of sequences match
204 # check that the length of sequences match
205 if not self._mapping and minlen != maxlen:
205 if not self._mapping and minlen != maxlen:
206 msg = 'all sequences must have equal length, but have %s' % lens
206 msg = 'all sequences must have equal length, but have %s' % lens
207 raise ValueError(msg)
207 raise ValueError(msg)
208
208
209 balanced = 'Balanced' in self.view.__class__.__name__
209 balanced = 'Balanced' in self.view.__class__.__name__
210 if balanced:
210 if balanced:
211 if self.chunksize:
211 if self.chunksize:
212 nparts = maxlen // self.chunksize + int(maxlen % self.chunksize > 0)
212 nparts = maxlen // self.chunksize + int(maxlen % self.chunksize > 0)
213 else:
213 else:
214 nparts = maxlen
214 nparts = maxlen
215 targets = [None]*nparts
215 targets = [None]*nparts
216 else:
216 else:
217 if self.chunksize:
217 if self.chunksize:
218 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
218 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
219 # multiplexed:
219 # multiplexed:
220 targets = self.view.targets
220 targets = self.view.targets
221 # 'all' is lazily evaluated at execution time, which is now:
221 # 'all' is lazily evaluated at execution time, which is now:
222 if targets == 'all':
222 if targets == 'all':
223 targets = client._build_targets(targets)[1]
223 targets = client._build_targets(targets)[1]
224 elif isinstance(targets, int):
224 elif isinstance(targets, int):
225 # single-engine view, targets must be iterable
225 # single-engine view, targets must be iterable
226 targets = [targets]
226 targets = [targets]
227 nparts = len(targets)
227 nparts = len(targets)
228
228
229 msg_ids = []
229 msg_ids = []
230 for index, t in enumerate(targets):
230 for index, t in enumerate(targets):
231 args = []
231 args = []
232 for seq in sequences:
232 for seq in sequences:
233 part = self.mapObject.getPartition(seq, index, nparts, maxlen)
233 part = self.mapObject.getPartition(seq, index, nparts, maxlen)
234 args.append(part)
234 args.append(part)
235 if not any(args):
235
236 if sum([len(arg) for arg in args]) == 0:
236 continue
237 continue
237
238
238 if self._mapping:
239 if self._mapping:
239 if sys.version_info[0] >= 3:
240 if sys.version_info[0] >= 3:
240 f = lambda f, *sequences: list(map(f, *sequences))
241 f = lambda f, *sequences: list(map(f, *sequences))
241 else:
242 else:
242 f = map
243 f = map
243 args = [self.func] + args
244 args = [self.func] + args
244 else:
245 else:
245 f=self.func
246 f=self.func
246
247
247 view = self.view if balanced else client[t]
248 view = self.view if balanced else client[t]
248 with view.temp_flags(block=False, **self.flags):
249 with view.temp_flags(block=False, **self.flags):
249 ar = view.apply(f, *args)
250 ar = view.apply(f, *args)
250
251
251 msg_ids.extend(ar.msg_ids)
252 msg_ids.extend(ar.msg_ids)
252
253
253 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
254 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
254 fname=getname(self.func),
255 fname=getname(self.func),
255 ordered=self.ordered
256 ordered=self.ordered
256 )
257 )
257
258
258 if self.block:
259 if self.block:
259 try:
260 try:
260 return r.get()
261 return r.get()
261 except KeyboardInterrupt:
262 except KeyboardInterrupt:
262 return r
263 return r
263 else:
264 else:
264 return r
265 return r
265
266
266 def map(self, *sequences):
267 def map(self, *sequences):
267 """call a function on each element of one or more sequence(s) remotely.
268 """call a function on each element of one or more sequence(s) remotely.
268 This should behave very much like the builtin map, but return an AsyncMapResult
269 This should behave very much like the builtin map, but return an AsyncMapResult
269 if self.block is False.
270 if self.block is False.
270
271
271 That means it can take generators (will be cast to lists locally),
272 That means it can take generators (will be cast to lists locally),
272 and mismatched sequence lengths will be padded with None.
273 and mismatched sequence lengths will be padded with None.
273 """
274 """
274 # set _mapping as a flag for use inside self.__call__
275 # set _mapping as a flag for use inside self.__call__
275 self._mapping = True
276 self._mapping = True
276 try:
277 try:
277 ret = self(*sequences)
278 ret = self(*sequences)
278 finally:
279 finally:
279 self._mapping = False
280 self._mapping = False
280 return ret
281 return ret
281
282
282 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
283 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,789 +1,801 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import platform
20 import platform
21 import time
21 import time
22 from collections import namedtuple
22 from collections import namedtuple
23 from tempfile import mktemp
23 from tempfile import mktemp
24 from StringIO import StringIO
24 from StringIO import StringIO
25
25
26 import zmq
26 import zmq
27 from nose import SkipTest
27 from nose import SkipTest
28 from nose.plugins.attrib import attr
28 from nose.plugins.attrib import attr
29
29
30 from IPython.testing import decorators as dec
30 from IPython.testing import decorators as dec
31 from IPython.testing.ipunittest import ParametricTestCase
31 from IPython.testing.ipunittest import ParametricTestCase
32 from IPython.utils.io import capture_output
32 from IPython.utils.io import capture_output
33
33
34 from IPython import parallel as pmod
34 from IPython import parallel as pmod
35 from IPython.parallel import error
35 from IPython.parallel import error
36 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
36 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
37 from IPython.parallel import DirectView
37 from IPython.parallel import DirectView
38 from IPython.parallel.util import interactive
38 from IPython.parallel.util import interactive
39
39
40 from IPython.parallel.tests import add_engines
40 from IPython.parallel.tests import add_engines
41
41
42 from .clienttest import ClusterTestCase, crash, wait, skip_without
42 from .clienttest import ClusterTestCase, crash, wait, skip_without
43
43
44 def setup():
44 def setup():
45 add_engines(3, total=True)
45 add_engines(3, total=True)
46
46
47 point = namedtuple("point", "x y")
47 point = namedtuple("point", "x y")
48
48
49 class TestView(ClusterTestCase, ParametricTestCase):
49 class TestView(ClusterTestCase, ParametricTestCase):
50
50
51 def setUp(self):
51 def setUp(self):
52 # On Win XP, wait for resource cleanup, else parallel test group fails
52 # On Win XP, wait for resource cleanup, else parallel test group fails
53 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
53 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
54 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
54 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
55 time.sleep(2)
55 time.sleep(2)
56 super(TestView, self).setUp()
56 super(TestView, self).setUp()
57
57
58 @attr('crash')
58 @attr('crash')
59 def test_z_crash_mux(self):
59 def test_z_crash_mux(self):
60 """test graceful handling of engine death (direct)"""
60 """test graceful handling of engine death (direct)"""
61 # self.add_engines(1)
61 # self.add_engines(1)
62 eid = self.client.ids[-1]
62 eid = self.client.ids[-1]
63 ar = self.client[eid].apply_async(crash)
63 ar = self.client[eid].apply_async(crash)
64 self.assertRaisesRemote(error.EngineError, ar.get, 10)
64 self.assertRaisesRemote(error.EngineError, ar.get, 10)
65 eid = ar.engine_id
65 eid = ar.engine_id
66 tic = time.time()
66 tic = time.time()
67 while eid in self.client.ids and time.time()-tic < 5:
67 while eid in self.client.ids and time.time()-tic < 5:
68 time.sleep(.01)
68 time.sleep(.01)
69 self.client.spin()
69 self.client.spin()
70 self.assertFalse(eid in self.client.ids, "Engine should have died")
70 self.assertFalse(eid in self.client.ids, "Engine should have died")
71
71
72 def test_push_pull(self):
72 def test_push_pull(self):
73 """test pushing and pulling"""
73 """test pushing and pulling"""
74 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
74 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
75 t = self.client.ids[-1]
75 t = self.client.ids[-1]
76 v = self.client[t]
76 v = self.client[t]
77 push = v.push
77 push = v.push
78 pull = v.pull
78 pull = v.pull
79 v.block=True
79 v.block=True
80 nengines = len(self.client)
80 nengines = len(self.client)
81 push({'data':data})
81 push({'data':data})
82 d = pull('data')
82 d = pull('data')
83 self.assertEqual(d, data)
83 self.assertEqual(d, data)
84 self.client[:].push({'data':data})
84 self.client[:].push({'data':data})
85 d = self.client[:].pull('data', block=True)
85 d = self.client[:].pull('data', block=True)
86 self.assertEqual(d, nengines*[data])
86 self.assertEqual(d, nengines*[data])
87 ar = push({'data':data}, block=False)
87 ar = push({'data':data}, block=False)
88 self.assertTrue(isinstance(ar, AsyncResult))
88 self.assertTrue(isinstance(ar, AsyncResult))
89 r = ar.get()
89 r = ar.get()
90 ar = self.client[:].pull('data', block=False)
90 ar = self.client[:].pull('data', block=False)
91 self.assertTrue(isinstance(ar, AsyncResult))
91 self.assertTrue(isinstance(ar, AsyncResult))
92 r = ar.get()
92 r = ar.get()
93 self.assertEqual(r, nengines*[data])
93 self.assertEqual(r, nengines*[data])
94 self.client[:].push(dict(a=10,b=20))
94 self.client[:].push(dict(a=10,b=20))
95 r = self.client[:].pull(('a','b'), block=True)
95 r = self.client[:].pull(('a','b'), block=True)
96 self.assertEqual(r, nengines*[[10,20]])
96 self.assertEqual(r, nengines*[[10,20]])
97
97
98 def test_push_pull_function(self):
98 def test_push_pull_function(self):
99 "test pushing and pulling functions"
99 "test pushing and pulling functions"
100 def testf(x):
100 def testf(x):
101 return 2.0*x
101 return 2.0*x
102
102
103 t = self.client.ids[-1]
103 t = self.client.ids[-1]
104 v = self.client[t]
104 v = self.client[t]
105 v.block=True
105 v.block=True
106 push = v.push
106 push = v.push
107 pull = v.pull
107 pull = v.pull
108 execute = v.execute
108 execute = v.execute
109 push({'testf':testf})
109 push({'testf':testf})
110 r = pull('testf')
110 r = pull('testf')
111 self.assertEqual(r(1.0), testf(1.0))
111 self.assertEqual(r(1.0), testf(1.0))
112 execute('r = testf(10)')
112 execute('r = testf(10)')
113 r = pull('r')
113 r = pull('r')
114 self.assertEqual(r, testf(10))
114 self.assertEqual(r, testf(10))
115 ar = self.client[:].push({'testf':testf}, block=False)
115 ar = self.client[:].push({'testf':testf}, block=False)
116 ar.get()
116 ar.get()
117 ar = self.client[:].pull('testf', block=False)
117 ar = self.client[:].pull('testf', block=False)
118 rlist = ar.get()
118 rlist = ar.get()
119 for r in rlist:
119 for r in rlist:
120 self.assertEqual(r(1.0), testf(1.0))
120 self.assertEqual(r(1.0), testf(1.0))
121 execute("def g(x): return x*x")
121 execute("def g(x): return x*x")
122 r = pull(('testf','g'))
122 r = pull(('testf','g'))
123 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
123 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
124
124
125 def test_push_function_globals(self):
125 def test_push_function_globals(self):
126 """test that pushed functions have access to globals"""
126 """test that pushed functions have access to globals"""
127 @interactive
127 @interactive
128 def geta():
128 def geta():
129 return a
129 return a
130 # self.add_engines(1)
130 # self.add_engines(1)
131 v = self.client[-1]
131 v = self.client[-1]
132 v.block=True
132 v.block=True
133 v['f'] = geta
133 v['f'] = geta
134 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
134 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
135 v.execute('a=5')
135 v.execute('a=5')
136 v.execute('b=f()')
136 v.execute('b=f()')
137 self.assertEqual(v['b'], 5)
137 self.assertEqual(v['b'], 5)
138
138
139 def test_push_function_defaults(self):
139 def test_push_function_defaults(self):
140 """test that pushed functions preserve default args"""
140 """test that pushed functions preserve default args"""
141 def echo(a=10):
141 def echo(a=10):
142 return a
142 return a
143 v = self.client[-1]
143 v = self.client[-1]
144 v.block=True
144 v.block=True
145 v['f'] = echo
145 v['f'] = echo
146 v.execute('b=f()')
146 v.execute('b=f()')
147 self.assertEqual(v['b'], 10)
147 self.assertEqual(v['b'], 10)
148
148
149 def test_get_result(self):
149 def test_get_result(self):
150 """test getting results from the Hub."""
150 """test getting results from the Hub."""
151 c = pmod.Client(profile='iptest')
151 c = pmod.Client(profile='iptest')
152 # self.add_engines(1)
152 # self.add_engines(1)
153 t = c.ids[-1]
153 t = c.ids[-1]
154 v = c[t]
154 v = c[t]
155 v2 = self.client[t]
155 v2 = self.client[t]
156 ar = v.apply_async(wait, 1)
156 ar = v.apply_async(wait, 1)
157 # give the monitor time to notice the message
157 # give the monitor time to notice the message
158 time.sleep(.25)
158 time.sleep(.25)
159 ahr = v2.get_result(ar.msg_ids[0])
159 ahr = v2.get_result(ar.msg_ids[0])
160 self.assertTrue(isinstance(ahr, AsyncHubResult))
160 self.assertTrue(isinstance(ahr, AsyncHubResult))
161 self.assertEqual(ahr.get(), ar.get())
161 self.assertEqual(ahr.get(), ar.get())
162 ar2 = v2.get_result(ar.msg_ids[0])
162 ar2 = v2.get_result(ar.msg_ids[0])
163 self.assertFalse(isinstance(ar2, AsyncHubResult))
163 self.assertFalse(isinstance(ar2, AsyncHubResult))
164 c.spin()
164 c.spin()
165 c.close()
165 c.close()
166
166
167 def test_run_newline(self):
167 def test_run_newline(self):
168 """test that run appends newline to files"""
168 """test that run appends newline to files"""
169 tmpfile = mktemp()
169 tmpfile = mktemp()
170 with open(tmpfile, 'w') as f:
170 with open(tmpfile, 'w') as f:
171 f.write("""def g():
171 f.write("""def g():
172 return 5
172 return 5
173 """)
173 """)
174 v = self.client[-1]
174 v = self.client[-1]
175 v.run(tmpfile, block=True)
175 v.run(tmpfile, block=True)
176 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
176 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
177
177
178 def test_apply_tracked(self):
178 def test_apply_tracked(self):
179 """test tracking for apply"""
179 """test tracking for apply"""
180 # self.add_engines(1)
180 # self.add_engines(1)
181 t = self.client.ids[-1]
181 t = self.client.ids[-1]
182 v = self.client[t]
182 v = self.client[t]
183 v.block=False
183 v.block=False
184 def echo(n=1024*1024, **kwargs):
184 def echo(n=1024*1024, **kwargs):
185 with v.temp_flags(**kwargs):
185 with v.temp_flags(**kwargs):
186 return v.apply(lambda x: x, 'x'*n)
186 return v.apply(lambda x: x, 'x'*n)
187 ar = echo(1, track=False)
187 ar = echo(1, track=False)
188 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
189 self.assertTrue(ar.sent)
189 self.assertTrue(ar.sent)
190 ar = echo(track=True)
190 ar = echo(track=True)
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 self.assertEqual(ar.sent, ar._tracker.done)
192 self.assertEqual(ar.sent, ar._tracker.done)
193 ar._tracker.wait()
193 ar._tracker.wait()
194 self.assertTrue(ar.sent)
194 self.assertTrue(ar.sent)
195
195
196 def test_push_tracked(self):
196 def test_push_tracked(self):
197 t = self.client.ids[-1]
197 t = self.client.ids[-1]
198 ns = dict(x='x'*1024*1024)
198 ns = dict(x='x'*1024*1024)
199 v = self.client[t]
199 v = self.client[t]
200 ar = v.push(ns, block=False, track=False)
200 ar = v.push(ns, block=False, track=False)
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 self.assertTrue(ar.sent)
202 self.assertTrue(ar.sent)
203
203
204 ar = v.push(ns, block=False, track=True)
204 ar = v.push(ns, block=False, track=True)
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 ar._tracker.wait()
206 ar._tracker.wait()
207 self.assertEqual(ar.sent, ar._tracker.done)
207 self.assertEqual(ar.sent, ar._tracker.done)
208 self.assertTrue(ar.sent)
208 self.assertTrue(ar.sent)
209 ar.get()
209 ar.get()
210
210
211 def test_scatter_tracked(self):
211 def test_scatter_tracked(self):
212 t = self.client.ids
212 t = self.client.ids
213 x='x'*1024*1024
213 x='x'*1024*1024
214 ar = self.client[t].scatter('x', x, block=False, track=False)
214 ar = self.client[t].scatter('x', x, block=False, track=False)
215 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
215 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
216 self.assertTrue(ar.sent)
216 self.assertTrue(ar.sent)
217
217
218 ar = self.client[t].scatter('x', x, block=False, track=True)
218 ar = self.client[t].scatter('x', x, block=False, track=True)
219 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
219 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
220 self.assertEqual(ar.sent, ar._tracker.done)
220 self.assertEqual(ar.sent, ar._tracker.done)
221 ar._tracker.wait()
221 ar._tracker.wait()
222 self.assertTrue(ar.sent)
222 self.assertTrue(ar.sent)
223 ar.get()
223 ar.get()
224
224
225 def test_remote_reference(self):
225 def test_remote_reference(self):
226 v = self.client[-1]
226 v = self.client[-1]
227 v['a'] = 123
227 v['a'] = 123
228 ra = pmod.Reference('a')
228 ra = pmod.Reference('a')
229 b = v.apply_sync(lambda x: x, ra)
229 b = v.apply_sync(lambda x: x, ra)
230 self.assertEqual(b, 123)
230 self.assertEqual(b, 123)
231
231
232
232
233 def test_scatter_gather(self):
233 def test_scatter_gather(self):
234 view = self.client[:]
234 view = self.client[:]
235 seq1 = range(16)
235 seq1 = range(16)
236 view.scatter('a', seq1)
236 view.scatter('a', seq1)
237 seq2 = view.gather('a', block=True)
237 seq2 = view.gather('a', block=True)
238 self.assertEqual(seq2, seq1)
238 self.assertEqual(seq2, seq1)
239 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
239 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
240
240
241 @skip_without('numpy')
241 @skip_without('numpy')
242 def test_scatter_gather_numpy(self):
242 def test_scatter_gather_numpy(self):
243 import numpy
243 import numpy
244 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
244 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
245 view = self.client[:]
245 view = self.client[:]
246 a = numpy.arange(64)
246 a = numpy.arange(64)
247 view.scatter('a', a, block=True)
247 view.scatter('a', a, block=True)
248 b = view.gather('a', block=True)
248 b = view.gather('a', block=True)
249 assert_array_equal(b, a)
249 assert_array_equal(b, a)
250
250
251 def test_scatter_gather_lazy(self):
251 def test_scatter_gather_lazy(self):
252 """scatter/gather with targets='all'"""
252 """scatter/gather with targets='all'"""
253 view = self.client.direct_view(targets='all')
253 view = self.client.direct_view(targets='all')
254 x = range(64)
254 x = range(64)
255 view.scatter('x', x)
255 view.scatter('x', x)
256 gathered = view.gather('x', block=True)
256 gathered = view.gather('x', block=True)
257 self.assertEqual(gathered, x)
257 self.assertEqual(gathered, x)
258
258
259
259
260 @dec.known_failure_py3
260 @dec.known_failure_py3
261 @skip_without('numpy')
261 @skip_without('numpy')
262 def test_push_numpy_nocopy(self):
262 def test_push_numpy_nocopy(self):
263 import numpy
263 import numpy
264 view = self.client[:]
264 view = self.client[:]
265 a = numpy.arange(64)
265 a = numpy.arange(64)
266 view['A'] = a
266 view['A'] = a
267 @interactive
267 @interactive
268 def check_writeable(x):
268 def check_writeable(x):
269 return x.flags.writeable
269 return x.flags.writeable
270
270
271 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
271 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
272 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
272 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
273
273
274 view.push(dict(B=a))
274 view.push(dict(B=a))
275 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
275 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
276 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
276 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
277
277
278 @skip_without('numpy')
278 @skip_without('numpy')
279 def test_apply_numpy(self):
279 def test_apply_numpy(self):
280 """view.apply(f, ndarray)"""
280 """view.apply(f, ndarray)"""
281 import numpy
281 import numpy
282 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
282 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
283
283
284 A = numpy.random.random((100,100))
284 A = numpy.random.random((100,100))
285 view = self.client[-1]
285 view = self.client[-1]
286 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
286 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
287 B = A.astype(dt)
287 B = A.astype(dt)
288 C = view.apply_sync(lambda x:x, B)
288 C = view.apply_sync(lambda x:x, B)
289 assert_array_equal(B,C)
289 assert_array_equal(B,C)
290
290
291 @skip_without('numpy')
291 @skip_without('numpy')
292 def test_push_pull_recarray(self):
292 def test_push_pull_recarray(self):
293 """push/pull recarrays"""
293 """push/pull recarrays"""
294 import numpy
294 import numpy
295 from numpy.testing.utils import assert_array_equal
295 from numpy.testing.utils import assert_array_equal
296
296
297 view = self.client[-1]
297 view = self.client[-1]
298
298
299 R = numpy.array([
299 R = numpy.array([
300 (1, 'hi', 0.),
300 (1, 'hi', 0.),
301 (2**30, 'there', 2.5),
301 (2**30, 'there', 2.5),
302 (-99999, 'world', -12345.6789),
302 (-99999, 'world', -12345.6789),
303 ], [('n', int), ('s', '|S10'), ('f', float)])
303 ], [('n', int), ('s', '|S10'), ('f', float)])
304
304
305 view['RR'] = R
305 view['RR'] = R
306 R2 = view['RR']
306 R2 = view['RR']
307
307
308 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
308 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
309 self.assertEqual(r_dtype, R.dtype)
309 self.assertEqual(r_dtype, R.dtype)
310 self.assertEqual(r_shape, R.shape)
310 self.assertEqual(r_shape, R.shape)
311 self.assertEqual(R2.dtype, R.dtype)
311 self.assertEqual(R2.dtype, R.dtype)
312 self.assertEqual(R2.shape, R.shape)
312 self.assertEqual(R2.shape, R.shape)
313 assert_array_equal(R2, R)
313 assert_array_equal(R2, R)
314
314
315 @skip_without('pandas')
315 @skip_without('pandas')
316 def test_push_pull_timeseries(self):
316 def test_push_pull_timeseries(self):
317 """push/pull pandas.TimeSeries"""
317 """push/pull pandas.TimeSeries"""
318 import pandas
318 import pandas
319
319
320 ts = pandas.TimeSeries(range(10))
320 ts = pandas.TimeSeries(range(10))
321
321
322 view = self.client[-1]
322 view = self.client[-1]
323
323
324 view.push(dict(ts=ts), block=True)
324 view.push(dict(ts=ts), block=True)
325 rts = view['ts']
325 rts = view['ts']
326
326
327 self.assertEqual(type(rts), type(ts))
327 self.assertEqual(type(rts), type(ts))
328 self.assertTrue((ts == rts).all())
328 self.assertTrue((ts == rts).all())
329
329
330 def test_map(self):
330 def test_map(self):
331 view = self.client[:]
331 view = self.client[:]
332 def f(x):
332 def f(x):
333 return x**2
333 return x**2
334 data = range(16)
334 data = range(16)
335 r = view.map_sync(f, data)
335 r = view.map_sync(f, data)
336 self.assertEqual(r, map(f, data))
336 self.assertEqual(r, map(f, data))
337
337
338 def test_map_iterable(self):
338 def test_map_iterable(self):
339 """test map on iterables (direct)"""
339 """test map on iterables (direct)"""
340 view = self.client[:]
340 view = self.client[:]
341 # 101 is prime, so it won't be evenly distributed
341 # 101 is prime, so it won't be evenly distributed
342 arr = range(101)
342 arr = range(101)
343 # ensure it will be an iterator, even in Python 3
343 # ensure it will be an iterator, even in Python 3
344 it = iter(arr)
344 it = iter(arr)
345 r = view.map_sync(lambda x:x, arr)
345 r = view.map_sync(lambda x:x, arr)
346 self.assertEqual(r, list(arr))
346 self.assertEqual(r, list(arr))
347
348 @skip_without('numpy')
349 def test_map_numpy(self):
350 """test map on numpy arrays (direct)"""
351 import numpy
352 from numpy.testing.utils import assert_array_equal
353
354 view = self.client[:]
355 # 101 is prime, so it won't be evenly distributed
356 arr = numpy.arange(101)
357 r = view.map_sync(lambda x: x, arr)
358 assert_array_equal(r, arr)
347
359
348 def test_scatter_gather_nonblocking(self):
360 def test_scatter_gather_nonblocking(self):
349 data = range(16)
361 data = range(16)
350 view = self.client[:]
362 view = self.client[:]
351 view.scatter('a', data, block=False)
363 view.scatter('a', data, block=False)
352 ar = view.gather('a', block=False)
364 ar = view.gather('a', block=False)
353 self.assertEqual(ar.get(), data)
365 self.assertEqual(ar.get(), data)
354
366
355 @skip_without('numpy')
367 @skip_without('numpy')
356 def test_scatter_gather_numpy_nonblocking(self):
368 def test_scatter_gather_numpy_nonblocking(self):
357 import numpy
369 import numpy
358 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
370 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
359 a = numpy.arange(64)
371 a = numpy.arange(64)
360 view = self.client[:]
372 view = self.client[:]
361 ar = view.scatter('a', a, block=False)
373 ar = view.scatter('a', a, block=False)
362 self.assertTrue(isinstance(ar, AsyncResult))
374 self.assertTrue(isinstance(ar, AsyncResult))
363 amr = view.gather('a', block=False)
375 amr = view.gather('a', block=False)
364 self.assertTrue(isinstance(amr, AsyncMapResult))
376 self.assertTrue(isinstance(amr, AsyncMapResult))
365 assert_array_equal(amr.get(), a)
377 assert_array_equal(amr.get(), a)
366
378
367 def test_execute(self):
379 def test_execute(self):
368 view = self.client[:]
380 view = self.client[:]
369 # self.client.debug=True
381 # self.client.debug=True
370 execute = view.execute
382 execute = view.execute
371 ar = execute('c=30', block=False)
383 ar = execute('c=30', block=False)
372 self.assertTrue(isinstance(ar, AsyncResult))
384 self.assertTrue(isinstance(ar, AsyncResult))
373 ar = execute('d=[0,1,2]', block=False)
385 ar = execute('d=[0,1,2]', block=False)
374 self.client.wait(ar, 1)
386 self.client.wait(ar, 1)
375 self.assertEqual(len(ar.get()), len(self.client))
387 self.assertEqual(len(ar.get()), len(self.client))
376 for c in view['c']:
388 for c in view['c']:
377 self.assertEqual(c, 30)
389 self.assertEqual(c, 30)
378
390
379 def test_abort(self):
391 def test_abort(self):
380 view = self.client[-1]
392 view = self.client[-1]
381 ar = view.execute('import time; time.sleep(1)', block=False)
393 ar = view.execute('import time; time.sleep(1)', block=False)
382 ar2 = view.apply_async(lambda : 2)
394 ar2 = view.apply_async(lambda : 2)
383 ar3 = view.apply_async(lambda : 3)
395 ar3 = view.apply_async(lambda : 3)
384 view.abort(ar2)
396 view.abort(ar2)
385 view.abort(ar3.msg_ids)
397 view.abort(ar3.msg_ids)
386 self.assertRaises(error.TaskAborted, ar2.get)
398 self.assertRaises(error.TaskAborted, ar2.get)
387 self.assertRaises(error.TaskAborted, ar3.get)
399 self.assertRaises(error.TaskAborted, ar3.get)
388
400
389 def test_abort_all(self):
401 def test_abort_all(self):
390 """view.abort() aborts all outstanding tasks"""
402 """view.abort() aborts all outstanding tasks"""
391 view = self.client[-1]
403 view = self.client[-1]
392 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
404 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
393 view.abort()
405 view.abort()
394 view.wait(timeout=5)
406 view.wait(timeout=5)
395 for ar in ars[5:]:
407 for ar in ars[5:]:
396 self.assertRaises(error.TaskAborted, ar.get)
408 self.assertRaises(error.TaskAborted, ar.get)
397
409
398 def test_temp_flags(self):
410 def test_temp_flags(self):
399 view = self.client[-1]
411 view = self.client[-1]
400 view.block=True
412 view.block=True
401 with view.temp_flags(block=False):
413 with view.temp_flags(block=False):
402 self.assertFalse(view.block)
414 self.assertFalse(view.block)
403 self.assertTrue(view.block)
415 self.assertTrue(view.block)
404
416
405 @dec.known_failure_py3
417 @dec.known_failure_py3
406 def test_importer(self):
418 def test_importer(self):
407 view = self.client[-1]
419 view = self.client[-1]
408 view.clear(block=True)
420 view.clear(block=True)
409 with view.importer:
421 with view.importer:
410 import re
422 import re
411
423
412 @interactive
424 @interactive
413 def findall(pat, s):
425 def findall(pat, s):
414 # this globals() step isn't necessary in real code
426 # this globals() step isn't necessary in real code
415 # only to prevent a closure in the test
427 # only to prevent a closure in the test
416 re = globals()['re']
428 re = globals()['re']
417 return re.findall(pat, s)
429 return re.findall(pat, s)
418
430
419 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
431 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
420
432
421 def test_unicode_execute(self):
433 def test_unicode_execute(self):
422 """test executing unicode strings"""
434 """test executing unicode strings"""
423 v = self.client[-1]
435 v = self.client[-1]
424 v.block=True
436 v.block=True
425 if sys.version_info[0] >= 3:
437 if sys.version_info[0] >= 3:
426 code="a='é'"
438 code="a='é'"
427 else:
439 else:
428 code=u"a=u'é'"
440 code=u"a=u'é'"
429 v.execute(code)
441 v.execute(code)
430 self.assertEqual(v['a'], u'é')
442 self.assertEqual(v['a'], u'é')
431
443
432 def test_unicode_apply_result(self):
444 def test_unicode_apply_result(self):
433 """test unicode apply results"""
445 """test unicode apply results"""
434 v = self.client[-1]
446 v = self.client[-1]
435 r = v.apply_sync(lambda : u'é')
447 r = v.apply_sync(lambda : u'é')
436 self.assertEqual(r, u'é')
448 self.assertEqual(r, u'é')
437
449
438 def test_unicode_apply_arg(self):
450 def test_unicode_apply_arg(self):
439 """test passing unicode arguments to apply"""
451 """test passing unicode arguments to apply"""
440 v = self.client[-1]
452 v = self.client[-1]
441
453
442 @interactive
454 @interactive
443 def check_unicode(a, check):
455 def check_unicode(a, check):
444 assert isinstance(a, unicode), "%r is not unicode"%a
456 assert isinstance(a, unicode), "%r is not unicode"%a
445 assert isinstance(check, bytes), "%r is not bytes"%check
457 assert isinstance(check, bytes), "%r is not bytes"%check
446 assert a.encode('utf8') == check, "%s != %s"%(a,check)
458 assert a.encode('utf8') == check, "%s != %s"%(a,check)
447
459
448 for s in [ u'é', u'ßø®∫',u'asdf' ]:
460 for s in [ u'é', u'ßø®∫',u'asdf' ]:
449 try:
461 try:
450 v.apply_sync(check_unicode, s, s.encode('utf8'))
462 v.apply_sync(check_unicode, s, s.encode('utf8'))
451 except error.RemoteError as e:
463 except error.RemoteError as e:
452 if e.ename == 'AssertionError':
464 if e.ename == 'AssertionError':
453 self.fail(e.evalue)
465 self.fail(e.evalue)
454 else:
466 else:
455 raise e
467 raise e
456
468
457 def test_map_reference(self):
469 def test_map_reference(self):
458 """view.map(<Reference>, *seqs) should work"""
470 """view.map(<Reference>, *seqs) should work"""
459 v = self.client[:]
471 v = self.client[:]
460 v.scatter('n', self.client.ids, flatten=True)
472 v.scatter('n', self.client.ids, flatten=True)
461 v.execute("f = lambda x,y: x*y")
473 v.execute("f = lambda x,y: x*y")
462 rf = pmod.Reference('f')
474 rf = pmod.Reference('f')
463 nlist = list(range(10))
475 nlist = list(range(10))
464 mlist = nlist[::-1]
476 mlist = nlist[::-1]
465 expected = [ m*n for m,n in zip(mlist, nlist) ]
477 expected = [ m*n for m,n in zip(mlist, nlist) ]
466 result = v.map_sync(rf, mlist, nlist)
478 result = v.map_sync(rf, mlist, nlist)
467 self.assertEqual(result, expected)
479 self.assertEqual(result, expected)
468
480
469 def test_apply_reference(self):
481 def test_apply_reference(self):
470 """view.apply(<Reference>, *args) should work"""
482 """view.apply(<Reference>, *args) should work"""
471 v = self.client[:]
483 v = self.client[:]
472 v.scatter('n', self.client.ids, flatten=True)
484 v.scatter('n', self.client.ids, flatten=True)
473 v.execute("f = lambda x: n*x")
485 v.execute("f = lambda x: n*x")
474 rf = pmod.Reference('f')
486 rf = pmod.Reference('f')
475 result = v.apply_sync(rf, 5)
487 result = v.apply_sync(rf, 5)
476 expected = [ 5*id for id in self.client.ids ]
488 expected = [ 5*id for id in self.client.ids ]
477 self.assertEqual(result, expected)
489 self.assertEqual(result, expected)
478
490
479 def test_eval_reference(self):
491 def test_eval_reference(self):
480 v = self.client[self.client.ids[0]]
492 v = self.client[self.client.ids[0]]
481 v['g'] = range(5)
493 v['g'] = range(5)
482 rg = pmod.Reference('g[0]')
494 rg = pmod.Reference('g[0]')
483 echo = lambda x:x
495 echo = lambda x:x
484 self.assertEqual(v.apply_sync(echo, rg), 0)
496 self.assertEqual(v.apply_sync(echo, rg), 0)
485
497
486 def test_reference_nameerror(self):
498 def test_reference_nameerror(self):
487 v = self.client[self.client.ids[0]]
499 v = self.client[self.client.ids[0]]
488 r = pmod.Reference('elvis_has_left')
500 r = pmod.Reference('elvis_has_left')
489 echo = lambda x:x
501 echo = lambda x:x
490 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
502 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
491
503
492 def test_single_engine_map(self):
504 def test_single_engine_map(self):
493 e0 = self.client[self.client.ids[0]]
505 e0 = self.client[self.client.ids[0]]
494 r = range(5)
506 r = range(5)
495 check = [ -1*i for i in r ]
507 check = [ -1*i for i in r ]
496 result = e0.map_sync(lambda x: -1*x, r)
508 result = e0.map_sync(lambda x: -1*x, r)
497 self.assertEqual(result, check)
509 self.assertEqual(result, check)
498
510
499 def test_len(self):
511 def test_len(self):
500 """len(view) makes sense"""
512 """len(view) makes sense"""
501 e0 = self.client[self.client.ids[0]]
513 e0 = self.client[self.client.ids[0]]
502 yield self.assertEqual(len(e0), 1)
514 yield self.assertEqual(len(e0), 1)
503 v = self.client[:]
515 v = self.client[:]
504 yield self.assertEqual(len(v), len(self.client.ids))
516 yield self.assertEqual(len(v), len(self.client.ids))
505 v = self.client.direct_view('all')
517 v = self.client.direct_view('all')
506 yield self.assertEqual(len(v), len(self.client.ids))
518 yield self.assertEqual(len(v), len(self.client.ids))
507 v = self.client[:2]
519 v = self.client[:2]
508 yield self.assertEqual(len(v), 2)
520 yield self.assertEqual(len(v), 2)
509 v = self.client[:1]
521 v = self.client[:1]
510 yield self.assertEqual(len(v), 1)
522 yield self.assertEqual(len(v), 1)
511 v = self.client.load_balanced_view()
523 v = self.client.load_balanced_view()
512 yield self.assertEqual(len(v), len(self.client.ids))
524 yield self.assertEqual(len(v), len(self.client.ids))
513 # parametric tests seem to require manual closing?
525 # parametric tests seem to require manual closing?
514 self.client.close()
526 self.client.close()
515
527
516
528
517 # begin execute tests
529 # begin execute tests
518
530
519 def test_execute_reply(self):
531 def test_execute_reply(self):
520 e0 = self.client[self.client.ids[0]]
532 e0 = self.client[self.client.ids[0]]
521 e0.block = True
533 e0.block = True
522 ar = e0.execute("5", silent=False)
534 ar = e0.execute("5", silent=False)
523 er = ar.get()
535 er = ar.get()
524 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
536 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
525 self.assertEqual(er.pyout['data']['text/plain'], '5')
537 self.assertEqual(er.pyout['data']['text/plain'], '5')
526
538
527 def test_execute_reply_stdout(self):
539 def test_execute_reply_stdout(self):
528 e0 = self.client[self.client.ids[0]]
540 e0 = self.client[self.client.ids[0]]
529 e0.block = True
541 e0.block = True
530 ar = e0.execute("print (5)", silent=False)
542 ar = e0.execute("print (5)", silent=False)
531 er = ar.get()
543 er = ar.get()
532 self.assertEqual(er.stdout.strip(), '5')
544 self.assertEqual(er.stdout.strip(), '5')
533
545
534 def test_execute_pyout(self):
546 def test_execute_pyout(self):
535 """execute triggers pyout with silent=False"""
547 """execute triggers pyout with silent=False"""
536 view = self.client[:]
548 view = self.client[:]
537 ar = view.execute("5", silent=False, block=True)
549 ar = view.execute("5", silent=False, block=True)
538
550
539 expected = [{'text/plain' : '5'}] * len(view)
551 expected = [{'text/plain' : '5'}] * len(view)
540 mimes = [ out['data'] for out in ar.pyout ]
552 mimes = [ out['data'] for out in ar.pyout ]
541 self.assertEqual(mimes, expected)
553 self.assertEqual(mimes, expected)
542
554
543 def test_execute_silent(self):
555 def test_execute_silent(self):
544 """execute does not trigger pyout with silent=True"""
556 """execute does not trigger pyout with silent=True"""
545 view = self.client[:]
557 view = self.client[:]
546 ar = view.execute("5", block=True)
558 ar = view.execute("5", block=True)
547 expected = [None] * len(view)
559 expected = [None] * len(view)
548 self.assertEqual(ar.pyout, expected)
560 self.assertEqual(ar.pyout, expected)
549
561
550 def test_execute_magic(self):
562 def test_execute_magic(self):
551 """execute accepts IPython commands"""
563 """execute accepts IPython commands"""
552 view = self.client[:]
564 view = self.client[:]
553 view.execute("a = 5")
565 view.execute("a = 5")
554 ar = view.execute("%whos", block=True)
566 ar = view.execute("%whos", block=True)
555 # this will raise, if that failed
567 # this will raise, if that failed
556 ar.get(5)
568 ar.get(5)
557 for stdout in ar.stdout:
569 for stdout in ar.stdout:
558 lines = stdout.splitlines()
570 lines = stdout.splitlines()
559 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
571 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
560 found = False
572 found = False
561 for line in lines[2:]:
573 for line in lines[2:]:
562 split = line.split()
574 split = line.split()
563 if split == ['a', 'int', '5']:
575 if split == ['a', 'int', '5']:
564 found = True
576 found = True
565 break
577 break
566 self.assertTrue(found, "whos output wrong: %s" % stdout)
578 self.assertTrue(found, "whos output wrong: %s" % stdout)
567
579
568 def test_execute_displaypub(self):
580 def test_execute_displaypub(self):
569 """execute tracks display_pub output"""
581 """execute tracks display_pub output"""
570 view = self.client[:]
582 view = self.client[:]
571 view.execute("from IPython.core.display import *")
583 view.execute("from IPython.core.display import *")
572 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
584 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
573
585
574 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
586 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
575 for outputs in ar.outputs:
587 for outputs in ar.outputs:
576 mimes = [ out['data'] for out in outputs ]
588 mimes = [ out['data'] for out in outputs ]
577 self.assertEqual(mimes, expected)
589 self.assertEqual(mimes, expected)
578
590
579 def test_apply_displaypub(self):
591 def test_apply_displaypub(self):
580 """apply tracks display_pub output"""
592 """apply tracks display_pub output"""
581 view = self.client[:]
593 view = self.client[:]
582 view.execute("from IPython.core.display import *")
594 view.execute("from IPython.core.display import *")
583
595
584 @interactive
596 @interactive
585 def publish():
597 def publish():
586 [ display(i) for i in range(5) ]
598 [ display(i) for i in range(5) ]
587
599
588 ar = view.apply_async(publish)
600 ar = view.apply_async(publish)
589 ar.get(5)
601 ar.get(5)
590 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
602 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
591 for outputs in ar.outputs:
603 for outputs in ar.outputs:
592 mimes = [ out['data'] for out in outputs ]
604 mimes = [ out['data'] for out in outputs ]
593 self.assertEqual(mimes, expected)
605 self.assertEqual(mimes, expected)
594
606
595 def test_execute_raises(self):
607 def test_execute_raises(self):
596 """exceptions in execute requests raise appropriately"""
608 """exceptions in execute requests raise appropriately"""
597 view = self.client[-1]
609 view = self.client[-1]
598 ar = view.execute("1/0")
610 ar = view.execute("1/0")
599 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
611 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
600
612
601 def test_remoteerror_render_exception(self):
613 def test_remoteerror_render_exception(self):
602 """RemoteErrors get nice tracebacks"""
614 """RemoteErrors get nice tracebacks"""
603 view = self.client[-1]
615 view = self.client[-1]
604 ar = view.execute("1/0")
616 ar = view.execute("1/0")
605 ip = get_ipython()
617 ip = get_ipython()
606 ip.user_ns['ar'] = ar
618 ip.user_ns['ar'] = ar
607 with capture_output() as io:
619 with capture_output() as io:
608 ip.run_cell("ar.get(2)")
620 ip.run_cell("ar.get(2)")
609
621
610 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
622 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
611
623
612 def test_compositeerror_render_exception(self):
624 def test_compositeerror_render_exception(self):
613 """CompositeErrors get nice tracebacks"""
625 """CompositeErrors get nice tracebacks"""
614 view = self.client[:]
626 view = self.client[:]
615 ar = view.execute("1/0")
627 ar = view.execute("1/0")
616 ip = get_ipython()
628 ip = get_ipython()
617 ip.user_ns['ar'] = ar
629 ip.user_ns['ar'] = ar
618
630
619 with capture_output() as io:
631 with capture_output() as io:
620 ip.run_cell("ar.get(2)")
632 ip.run_cell("ar.get(2)")
621
633
622 count = min(error.CompositeError.tb_limit, len(view))
634 count = min(error.CompositeError.tb_limit, len(view))
623
635
624 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
636 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
625 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
637 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
626 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
638 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
627
639
628 def test_compositeerror_truncate(self):
640 def test_compositeerror_truncate(self):
629 """Truncate CompositeErrors with many exceptions"""
641 """Truncate CompositeErrors with many exceptions"""
630 view = self.client[:]
642 view = self.client[:]
631 msg_ids = []
643 msg_ids = []
632 for i in range(10):
644 for i in range(10):
633 ar = view.execute("1/0")
645 ar = view.execute("1/0")
634 msg_ids.extend(ar.msg_ids)
646 msg_ids.extend(ar.msg_ids)
635
647
636 ar = self.client.get_result(msg_ids)
648 ar = self.client.get_result(msg_ids)
637 try:
649 try:
638 ar.get()
650 ar.get()
639 except error.CompositeError as _e:
651 except error.CompositeError as _e:
640 e = _e
652 e = _e
641 else:
653 else:
642 self.fail("Should have raised CompositeError")
654 self.fail("Should have raised CompositeError")
643
655
644 lines = e.render_traceback()
656 lines = e.render_traceback()
645 with capture_output() as io:
657 with capture_output() as io:
646 e.print_traceback()
658 e.print_traceback()
647
659
648 self.assertTrue("more exceptions" in lines[-1])
660 self.assertTrue("more exceptions" in lines[-1])
649 count = e.tb_limit
661 count = e.tb_limit
650
662
651 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
663 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
652 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
664 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
653 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
665 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
654
666
655 @dec.skipif_not_matplotlib
667 @dec.skipif_not_matplotlib
656 def test_magic_pylab(self):
668 def test_magic_pylab(self):
657 """%pylab works on engines"""
669 """%pylab works on engines"""
658 view = self.client[-1]
670 view = self.client[-1]
659 ar = view.execute("%pylab inline")
671 ar = view.execute("%pylab inline")
660 # at least check if this raised:
672 # at least check if this raised:
661 reply = ar.get(5)
673 reply = ar.get(5)
662 # include imports, in case user config
674 # include imports, in case user config
663 ar = view.execute("plot(rand(100))", silent=False)
675 ar = view.execute("plot(rand(100))", silent=False)
664 reply = ar.get(5)
676 reply = ar.get(5)
665 self.assertEqual(len(reply.outputs), 1)
677 self.assertEqual(len(reply.outputs), 1)
666 output = reply.outputs[0]
678 output = reply.outputs[0]
667 self.assertTrue("data" in output)
679 self.assertTrue("data" in output)
668 data = output['data']
680 data = output['data']
669 self.assertTrue("image/png" in data)
681 self.assertTrue("image/png" in data)
670
682
671 def test_func_default_func(self):
683 def test_func_default_func(self):
672 """interactively defined function as apply func default"""
684 """interactively defined function as apply func default"""
673 def foo():
685 def foo():
674 return 'foo'
686 return 'foo'
675
687
676 def bar(f=foo):
688 def bar(f=foo):
677 return f()
689 return f()
678
690
679 view = self.client[-1]
691 view = self.client[-1]
680 ar = view.apply_async(bar)
692 ar = view.apply_async(bar)
681 r = ar.get(10)
693 r = ar.get(10)
682 self.assertEqual(r, 'foo')
694 self.assertEqual(r, 'foo')
683 def test_data_pub_single(self):
695 def test_data_pub_single(self):
684 view = self.client[-1]
696 view = self.client[-1]
685 ar = view.execute('\n'.join([
697 ar = view.execute('\n'.join([
686 'from IPython.kernel.zmq.datapub import publish_data',
698 'from IPython.kernel.zmq.datapub import publish_data',
687 'for i in range(5):',
699 'for i in range(5):',
688 ' publish_data(dict(i=i))'
700 ' publish_data(dict(i=i))'
689 ]), block=False)
701 ]), block=False)
690 self.assertTrue(isinstance(ar.data, dict))
702 self.assertTrue(isinstance(ar.data, dict))
691 ar.get(5)
703 ar.get(5)
692 self.assertEqual(ar.data, dict(i=4))
704 self.assertEqual(ar.data, dict(i=4))
693
705
694 def test_data_pub(self):
706 def test_data_pub(self):
695 view = self.client[:]
707 view = self.client[:]
696 ar = view.execute('\n'.join([
708 ar = view.execute('\n'.join([
697 'from IPython.kernel.zmq.datapub import publish_data',
709 'from IPython.kernel.zmq.datapub import publish_data',
698 'for i in range(5):',
710 'for i in range(5):',
699 ' publish_data(dict(i=i))'
711 ' publish_data(dict(i=i))'
700 ]), block=False)
712 ]), block=False)
701 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
713 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
702 ar.get(5)
714 ar.get(5)
703 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
715 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
704
716
705 def test_can_list_arg(self):
717 def test_can_list_arg(self):
706 """args in lists are canned"""
718 """args in lists are canned"""
707 view = self.client[-1]
719 view = self.client[-1]
708 view['a'] = 128
720 view['a'] = 128
709 rA = pmod.Reference('a')
721 rA = pmod.Reference('a')
710 ar = view.apply_async(lambda x: x, [rA])
722 ar = view.apply_async(lambda x: x, [rA])
711 r = ar.get(5)
723 r = ar.get(5)
712 self.assertEqual(r, [128])
724 self.assertEqual(r, [128])
713
725
714 def test_can_dict_arg(self):
726 def test_can_dict_arg(self):
715 """args in dicts are canned"""
727 """args in dicts are canned"""
716 view = self.client[-1]
728 view = self.client[-1]
717 view['a'] = 128
729 view['a'] = 128
718 rA = pmod.Reference('a')
730 rA = pmod.Reference('a')
719 ar = view.apply_async(lambda x: x, dict(foo=rA))
731 ar = view.apply_async(lambda x: x, dict(foo=rA))
720 r = ar.get(5)
732 r = ar.get(5)
721 self.assertEqual(r, dict(foo=128))
733 self.assertEqual(r, dict(foo=128))
722
734
723 def test_can_list_kwarg(self):
735 def test_can_list_kwarg(self):
724 """kwargs in lists are canned"""
736 """kwargs in lists are canned"""
725 view = self.client[-1]
737 view = self.client[-1]
726 view['a'] = 128
738 view['a'] = 128
727 rA = pmod.Reference('a')
739 rA = pmod.Reference('a')
728 ar = view.apply_async(lambda x=5: x, x=[rA])
740 ar = view.apply_async(lambda x=5: x, x=[rA])
729 r = ar.get(5)
741 r = ar.get(5)
730 self.assertEqual(r, [128])
742 self.assertEqual(r, [128])
731
743
732 def test_can_dict_kwarg(self):
744 def test_can_dict_kwarg(self):
733 """kwargs in dicts are canned"""
745 """kwargs in dicts are canned"""
734 view = self.client[-1]
746 view = self.client[-1]
735 view['a'] = 128
747 view['a'] = 128
736 rA = pmod.Reference('a')
748 rA = pmod.Reference('a')
737 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
749 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
738 r = ar.get(5)
750 r = ar.get(5)
739 self.assertEqual(r, dict(foo=128))
751 self.assertEqual(r, dict(foo=128))
740
752
741 def test_map_ref(self):
753 def test_map_ref(self):
742 """view.map works with references"""
754 """view.map works with references"""
743 view = self.client[:]
755 view = self.client[:]
744 ranks = sorted(self.client.ids)
756 ranks = sorted(self.client.ids)
745 view.scatter('rank', ranks, flatten=True)
757 view.scatter('rank', ranks, flatten=True)
746 rrank = pmod.Reference('rank')
758 rrank = pmod.Reference('rank')
747
759
748 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
760 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
749 drank = amr.get(5)
761 drank = amr.get(5)
750 self.assertEqual(drank, [ r*2 for r in ranks ])
762 self.assertEqual(drank, [ r*2 for r in ranks ])
751
763
752 def test_nested_getitem_setitem(self):
764 def test_nested_getitem_setitem(self):
753 """get and set with view['a.b']"""
765 """get and set with view['a.b']"""
754 view = self.client[-1]
766 view = self.client[-1]
755 view.execute('\n'.join([
767 view.execute('\n'.join([
756 'class A(object): pass',
768 'class A(object): pass',
757 'a = A()',
769 'a = A()',
758 'a.b = 128',
770 'a.b = 128',
759 ]), block=True)
771 ]), block=True)
760 ra = pmod.Reference('a')
772 ra = pmod.Reference('a')
761
773
762 r = view.apply_sync(lambda x: x.b, ra)
774 r = view.apply_sync(lambda x: x.b, ra)
763 self.assertEqual(r, 128)
775 self.assertEqual(r, 128)
764 self.assertEqual(view['a.b'], 128)
776 self.assertEqual(view['a.b'], 128)
765
777
766 view['a.b'] = 0
778 view['a.b'] = 0
767
779
768 r = view.apply_sync(lambda x: x.b, ra)
780 r = view.apply_sync(lambda x: x.b, ra)
769 self.assertEqual(r, 0)
781 self.assertEqual(r, 0)
770 self.assertEqual(view['a.b'], 0)
782 self.assertEqual(view['a.b'], 0)
771
783
772 def test_return_namedtuple(self):
784 def test_return_namedtuple(self):
773 def namedtuplify(x, y):
785 def namedtuplify(x, y):
774 from IPython.parallel.tests.test_view import point
786 from IPython.parallel.tests.test_view import point
775 return point(x, y)
787 return point(x, y)
776
788
777 view = self.client[-1]
789 view = self.client[-1]
778 p = view.apply_sync(namedtuplify, 1, 2)
790 p = view.apply_sync(namedtuplify, 1, 2)
779 self.assertEqual(p.x, 1)
791 self.assertEqual(p.x, 1)
780 self.assertEqual(p.y, 2)
792 self.assertEqual(p.y, 2)
781
793
782 def test_apply_namedtuple(self):
794 def test_apply_namedtuple(self):
783 def echoxy(p):
795 def echoxy(p):
784 return p.y, p.x
796 return p.y, p.x
785
797
786 view = self.client[-1]
798 view = self.client[-1]
787 tup = view.apply_sync(echoxy, point(1, 2))
799 tup = view.apply_sync(echoxy, point(1, 2))
788 self.assertEqual(tup, (2,1))
800 self.assertEqual(tup, (2,1))
789
801
General Comments 0
You need to be logged in to leave comments. Login now