##// END OF EJS Templates
Fix parallel.client.View map() on numpy arrays
Samuel Ainsworth -
Show More
@@ -1,282 +1,283 b''
1 1 """Remote Functions and decorators for Views.
2 2
3 3 Authors:
4 4
5 5 * Brian Granger
6 6 * Min RK
7 7 """
8 8 #-----------------------------------------------------------------------------
9 9 # Copyright (C) 2010-2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-----------------------------------------------------------------------------
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import sys
22 22 import warnings
23 23
24 24 from IPython.external.decorator import decorator
25 25 from IPython.testing.skipdoctest import skip_doctest
26 26
27 27 from . import map as Map
28 28 from .asyncresult import AsyncMapResult
29 29
30 30 #-----------------------------------------------------------------------------
31 31 # Functions and Decorators
32 32 #-----------------------------------------------------------------------------
33 33
34 34 @skip_doctest
35 35 def remote(view, block=None, **flags):
36 36 """Turn a function into a remote function.
37 37
38 38 This method can be used for map:
39 39
40 40 In [1]: @remote(view,block=True)
41 41 ...: def func(a):
42 42 ...: pass
43 43 """
44 44
45 45 def remote_function(f):
46 46 return RemoteFunction(view, f, block=block, **flags)
47 47 return remote_function
48 48
49 49 @skip_doctest
50 50 def parallel(view, dist='b', block=None, ordered=True, **flags):
51 51 """Turn a function into a parallel remote function.
52 52
53 53 This method can be used for map:
54 54
55 55 In [1]: @parallel(view, block=True)
56 56 ...: def func(a):
57 57 ...: pass
58 58 """
59 59
60 60 def parallel_function(f):
61 61 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
62 62 return parallel_function
63 63
64 64 def getname(f):
65 65 """Get the name of an object.
66 66
67 67 For use in case of callables that are not functions, and
68 68 thus may not have __name__ defined.
69 69
70 70 Order: f.__name__ > f.name > str(f)
71 71 """
72 72 try:
73 73 return f.__name__
74 74 except:
75 75 pass
76 76 try:
77 77 return f.name
78 78 except:
79 79 pass
80 80
81 81 return str(f)
82 82
83 83 @decorator
84 84 def sync_view_results(f, self, *args, **kwargs):
85 85 """sync relevant results from self.client to our results attribute.
86 86
87 87 This is a clone of view.sync_results, but for remote functions
88 88 """
89 89 view = self.view
90 90 if view._in_sync_results:
91 91 return f(self, *args, **kwargs)
92 92 print 'in sync results', f
93 93 view._in_sync_results = True
94 94 try:
95 95 ret = f(self, *args, **kwargs)
96 96 finally:
97 97 view._in_sync_results = False
98 98 view._sync_results()
99 99 return ret
100 100
101 101 #--------------------------------------------------------------------------
102 102 # Classes
103 103 #--------------------------------------------------------------------------
104 104
105 105 class RemoteFunction(object):
106 106 """Turn an existing function into a remote function.
107 107
108 108 Parameters
109 109 ----------
110 110
111 111 view : View instance
112 112 The view to be used for execution
113 113 f : callable
114 114 The function to be wrapped into a remote function
115 115 block : bool [default: None]
116 116 Whether to wait for results or not. The default behavior is
117 117 to use the current `block` attribute of `view`
118 118
119 119 **flags : remaining kwargs are passed to View.temp_flags
120 120 """
121 121
122 122 view = None # the remote connection
123 123 func = None # the wrapped function
124 124 block = None # whether to block
125 125 flags = None # dict of extra kwargs for temp_flags
126 126
127 127 def __init__(self, view, f, block=None, **flags):
128 128 self.view = view
129 129 self.func = f
130 130 self.block=block
131 131 self.flags=flags
132 132
133 133 def __call__(self, *args, **kwargs):
134 134 block = self.view.block if self.block is None else self.block
135 135 with self.view.temp_flags(block=block, **self.flags):
136 136 return self.view.apply(self.func, *args, **kwargs)
137 137
138 138
139 139 class ParallelFunction(RemoteFunction):
140 140 """Class for mapping a function to sequences.
141 141
142 142 This will distribute the sequences according the a mapper, and call
143 143 the function on each sub-sequence. If called via map, then the function
144 144 will be called once on each element, rather that each sub-sequence.
145 145
146 146 Parameters
147 147 ----------
148 148
149 149 view : View instance
150 150 The view to be used for execution
151 151 f : callable
152 152 The function to be wrapped into a remote function
153 153 dist : str [default: 'b']
154 154 The key for which mapObject to use to distribute sequences
155 155 options are:
156 156 * 'b' : use contiguous chunks in order
157 157 * 'r' : use round-robin striping
158 158 block : bool [default: None]
159 159 Whether to wait for results or not. The default behavior is
160 160 to use the current `block` attribute of `view`
161 161 chunksize : int or None
162 162 The size of chunk to use when breaking up sequences in a load-balanced manner
163 163 ordered : bool [default: True]
164 164 Whether the result should be kept in order. If False,
165 165 results become available as they arrive, regardless of submission order.
166 166 **flags : remaining kwargs are passed to View.temp_flags
167 167 """
168 168
169 169 chunksize = None
170 170 ordered = None
171 171 mapObject = None
172 172 _mapping = False
173 173
174 174 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
175 175 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
176 176 self.chunksize = chunksize
177 177 self.ordered = ordered
178 178
179 179 mapClass = Map.dists[dist]
180 180 self.mapObject = mapClass()
181 181
182 182 @sync_view_results
183 183 def __call__(self, *sequences):
184 184 client = self.view.client
185 185
186 186 lens = []
187 187 maxlen = minlen = -1
188 188 for i, seq in enumerate(sequences):
189 189 try:
190 190 n = len(seq)
191 191 except Exception:
192 192 seq = list(seq)
193 193 if isinstance(sequences, tuple):
194 194 # can't alter a tuple
195 195 sequences = list(sequences)
196 196 sequences[i] = seq
197 197 n = len(seq)
198 198 if n > maxlen:
199 199 maxlen = n
200 200 if minlen == -1 or n < minlen:
201 201 minlen = n
202 202 lens.append(n)
203 203
204 204 # check that the length of sequences match
205 205 if not self._mapping and minlen != maxlen:
206 206 msg = 'all sequences must have equal length, but have %s' % lens
207 207 raise ValueError(msg)
208 208
209 209 balanced = 'Balanced' in self.view.__class__.__name__
210 210 if balanced:
211 211 if self.chunksize:
212 212 nparts = maxlen // self.chunksize + int(maxlen % self.chunksize > 0)
213 213 else:
214 214 nparts = maxlen
215 215 targets = [None]*nparts
216 216 else:
217 217 if self.chunksize:
218 218 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
219 219 # multiplexed:
220 220 targets = self.view.targets
221 221 # 'all' is lazily evaluated at execution time, which is now:
222 222 if targets == 'all':
223 223 targets = client._build_targets(targets)[1]
224 224 elif isinstance(targets, int):
225 225 # single-engine view, targets must be iterable
226 226 targets = [targets]
227 227 nparts = len(targets)
228 228
229 229 msg_ids = []
230 230 for index, t in enumerate(targets):
231 231 args = []
232 232 for seq in sequences:
233 233 part = self.mapObject.getPartition(seq, index, nparts, maxlen)
234 234 args.append(part)
235 if not any(args):
235
236 if sum([len(arg) for arg in args]) == 0:
236 237 continue
237 238
238 239 if self._mapping:
239 240 if sys.version_info[0] >= 3:
240 241 f = lambda f, *sequences: list(map(f, *sequences))
241 242 else:
242 243 f = map
243 244 args = [self.func] + args
244 245 else:
245 246 f=self.func
246 247
247 248 view = self.view if balanced else client[t]
248 249 with view.temp_flags(block=False, **self.flags):
249 250 ar = view.apply(f, *args)
250 251
251 252 msg_ids.extend(ar.msg_ids)
252 253
253 254 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
254 255 fname=getname(self.func),
255 256 ordered=self.ordered
256 257 )
257 258
258 259 if self.block:
259 260 try:
260 261 return r.get()
261 262 except KeyboardInterrupt:
262 263 return r
263 264 else:
264 265 return r
265 266
266 267 def map(self, *sequences):
267 268 """call a function on each element of one or more sequence(s) remotely.
268 269 This should behave very much like the builtin map, but return an AsyncMapResult
269 270 if self.block is False.
270 271
271 272 That means it can take generators (will be cast to lists locally),
272 273 and mismatched sequence lengths will be padded with None.
273 274 """
274 275 # set _mapping as a flag for use inside self.__call__
275 276 self._mapping = True
276 277 try:
277 278 ret = self(*sequences)
278 279 finally:
279 280 self._mapping = False
280 281 return ret
281 282
282 283 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,789 +1,801 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import platform
21 21 import time
22 22 from collections import namedtuple
23 23 from tempfile import mktemp
24 24 from StringIO import StringIO
25 25
26 26 import zmq
27 27 from nose import SkipTest
28 28 from nose.plugins.attrib import attr
29 29
30 30 from IPython.testing import decorators as dec
31 31 from IPython.testing.ipunittest import ParametricTestCase
32 32 from IPython.utils.io import capture_output
33 33
34 34 from IPython import parallel as pmod
35 35 from IPython.parallel import error
36 36 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
37 37 from IPython.parallel import DirectView
38 38 from IPython.parallel.util import interactive
39 39
40 40 from IPython.parallel.tests import add_engines
41 41
42 42 from .clienttest import ClusterTestCase, crash, wait, skip_without
43 43
44 44 def setup():
45 45 add_engines(3, total=True)
46 46
47 47 point = namedtuple("point", "x y")
48 48
49 49 class TestView(ClusterTestCase, ParametricTestCase):
50 50
51 51 def setUp(self):
52 52 # On Win XP, wait for resource cleanup, else parallel test group fails
53 53 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
54 54 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
55 55 time.sleep(2)
56 56 super(TestView, self).setUp()
57 57
58 58 @attr('crash')
59 59 def test_z_crash_mux(self):
60 60 """test graceful handling of engine death (direct)"""
61 61 # self.add_engines(1)
62 62 eid = self.client.ids[-1]
63 63 ar = self.client[eid].apply_async(crash)
64 64 self.assertRaisesRemote(error.EngineError, ar.get, 10)
65 65 eid = ar.engine_id
66 66 tic = time.time()
67 67 while eid in self.client.ids and time.time()-tic < 5:
68 68 time.sleep(.01)
69 69 self.client.spin()
70 70 self.assertFalse(eid in self.client.ids, "Engine should have died")
71 71
72 72 def test_push_pull(self):
73 73 """test pushing and pulling"""
74 74 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
75 75 t = self.client.ids[-1]
76 76 v = self.client[t]
77 77 push = v.push
78 78 pull = v.pull
79 79 v.block=True
80 80 nengines = len(self.client)
81 81 push({'data':data})
82 82 d = pull('data')
83 83 self.assertEqual(d, data)
84 84 self.client[:].push({'data':data})
85 85 d = self.client[:].pull('data', block=True)
86 86 self.assertEqual(d, nengines*[data])
87 87 ar = push({'data':data}, block=False)
88 88 self.assertTrue(isinstance(ar, AsyncResult))
89 89 r = ar.get()
90 90 ar = self.client[:].pull('data', block=False)
91 91 self.assertTrue(isinstance(ar, AsyncResult))
92 92 r = ar.get()
93 93 self.assertEqual(r, nengines*[data])
94 94 self.client[:].push(dict(a=10,b=20))
95 95 r = self.client[:].pull(('a','b'), block=True)
96 96 self.assertEqual(r, nengines*[[10,20]])
97 97
98 98 def test_push_pull_function(self):
99 99 "test pushing and pulling functions"
100 100 def testf(x):
101 101 return 2.0*x
102 102
103 103 t = self.client.ids[-1]
104 104 v = self.client[t]
105 105 v.block=True
106 106 push = v.push
107 107 pull = v.pull
108 108 execute = v.execute
109 109 push({'testf':testf})
110 110 r = pull('testf')
111 111 self.assertEqual(r(1.0), testf(1.0))
112 112 execute('r = testf(10)')
113 113 r = pull('r')
114 114 self.assertEqual(r, testf(10))
115 115 ar = self.client[:].push({'testf':testf}, block=False)
116 116 ar.get()
117 117 ar = self.client[:].pull('testf', block=False)
118 118 rlist = ar.get()
119 119 for r in rlist:
120 120 self.assertEqual(r(1.0), testf(1.0))
121 121 execute("def g(x): return x*x")
122 122 r = pull(('testf','g'))
123 123 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
124 124
125 125 def test_push_function_globals(self):
126 126 """test that pushed functions have access to globals"""
127 127 @interactive
128 128 def geta():
129 129 return a
130 130 # self.add_engines(1)
131 131 v = self.client[-1]
132 132 v.block=True
133 133 v['f'] = geta
134 134 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
135 135 v.execute('a=5')
136 136 v.execute('b=f()')
137 137 self.assertEqual(v['b'], 5)
138 138
139 139 def test_push_function_defaults(self):
140 140 """test that pushed functions preserve default args"""
141 141 def echo(a=10):
142 142 return a
143 143 v = self.client[-1]
144 144 v.block=True
145 145 v['f'] = echo
146 146 v.execute('b=f()')
147 147 self.assertEqual(v['b'], 10)
148 148
149 149 def test_get_result(self):
150 150 """test getting results from the Hub."""
151 151 c = pmod.Client(profile='iptest')
152 152 # self.add_engines(1)
153 153 t = c.ids[-1]
154 154 v = c[t]
155 155 v2 = self.client[t]
156 156 ar = v.apply_async(wait, 1)
157 157 # give the monitor time to notice the message
158 158 time.sleep(.25)
159 159 ahr = v2.get_result(ar.msg_ids[0])
160 160 self.assertTrue(isinstance(ahr, AsyncHubResult))
161 161 self.assertEqual(ahr.get(), ar.get())
162 162 ar2 = v2.get_result(ar.msg_ids[0])
163 163 self.assertFalse(isinstance(ar2, AsyncHubResult))
164 164 c.spin()
165 165 c.close()
166 166
167 167 def test_run_newline(self):
168 168 """test that run appends newline to files"""
169 169 tmpfile = mktemp()
170 170 with open(tmpfile, 'w') as f:
171 171 f.write("""def g():
172 172 return 5
173 173 """)
174 174 v = self.client[-1]
175 175 v.run(tmpfile, block=True)
176 176 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
177 177
178 178 def test_apply_tracked(self):
179 179 """test tracking for apply"""
180 180 # self.add_engines(1)
181 181 t = self.client.ids[-1]
182 182 v = self.client[t]
183 183 v.block=False
184 184 def echo(n=1024*1024, **kwargs):
185 185 with v.temp_flags(**kwargs):
186 186 return v.apply(lambda x: x, 'x'*n)
187 187 ar = echo(1, track=False)
188 188 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
189 189 self.assertTrue(ar.sent)
190 190 ar = echo(track=True)
191 191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 192 self.assertEqual(ar.sent, ar._tracker.done)
193 193 ar._tracker.wait()
194 194 self.assertTrue(ar.sent)
195 195
196 196 def test_push_tracked(self):
197 197 t = self.client.ids[-1]
198 198 ns = dict(x='x'*1024*1024)
199 199 v = self.client[t]
200 200 ar = v.push(ns, block=False, track=False)
201 201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 202 self.assertTrue(ar.sent)
203 203
204 204 ar = v.push(ns, block=False, track=True)
205 205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 206 ar._tracker.wait()
207 207 self.assertEqual(ar.sent, ar._tracker.done)
208 208 self.assertTrue(ar.sent)
209 209 ar.get()
210 210
211 211 def test_scatter_tracked(self):
212 212 t = self.client.ids
213 213 x='x'*1024*1024
214 214 ar = self.client[t].scatter('x', x, block=False, track=False)
215 215 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
216 216 self.assertTrue(ar.sent)
217 217
218 218 ar = self.client[t].scatter('x', x, block=False, track=True)
219 219 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
220 220 self.assertEqual(ar.sent, ar._tracker.done)
221 221 ar._tracker.wait()
222 222 self.assertTrue(ar.sent)
223 223 ar.get()
224 224
225 225 def test_remote_reference(self):
226 226 v = self.client[-1]
227 227 v['a'] = 123
228 228 ra = pmod.Reference('a')
229 229 b = v.apply_sync(lambda x: x, ra)
230 230 self.assertEqual(b, 123)
231 231
232 232
233 233 def test_scatter_gather(self):
234 234 view = self.client[:]
235 235 seq1 = range(16)
236 236 view.scatter('a', seq1)
237 237 seq2 = view.gather('a', block=True)
238 238 self.assertEqual(seq2, seq1)
239 239 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
240 240
241 241 @skip_without('numpy')
242 242 def test_scatter_gather_numpy(self):
243 243 import numpy
244 244 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
245 245 view = self.client[:]
246 246 a = numpy.arange(64)
247 247 view.scatter('a', a, block=True)
248 248 b = view.gather('a', block=True)
249 249 assert_array_equal(b, a)
250 250
251 251 def test_scatter_gather_lazy(self):
252 252 """scatter/gather with targets='all'"""
253 253 view = self.client.direct_view(targets='all')
254 254 x = range(64)
255 255 view.scatter('x', x)
256 256 gathered = view.gather('x', block=True)
257 257 self.assertEqual(gathered, x)
258 258
259 259
260 260 @dec.known_failure_py3
261 261 @skip_without('numpy')
262 262 def test_push_numpy_nocopy(self):
263 263 import numpy
264 264 view = self.client[:]
265 265 a = numpy.arange(64)
266 266 view['A'] = a
267 267 @interactive
268 268 def check_writeable(x):
269 269 return x.flags.writeable
270 270
271 271 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
272 272 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
273 273
274 274 view.push(dict(B=a))
275 275 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
276 276 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
277 277
278 278 @skip_without('numpy')
279 279 def test_apply_numpy(self):
280 280 """view.apply(f, ndarray)"""
281 281 import numpy
282 282 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
283 283
284 284 A = numpy.random.random((100,100))
285 285 view = self.client[-1]
286 286 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
287 287 B = A.astype(dt)
288 288 C = view.apply_sync(lambda x:x, B)
289 289 assert_array_equal(B,C)
290 290
291 291 @skip_without('numpy')
292 292 def test_push_pull_recarray(self):
293 293 """push/pull recarrays"""
294 294 import numpy
295 295 from numpy.testing.utils import assert_array_equal
296 296
297 297 view = self.client[-1]
298 298
299 299 R = numpy.array([
300 300 (1, 'hi', 0.),
301 301 (2**30, 'there', 2.5),
302 302 (-99999, 'world', -12345.6789),
303 303 ], [('n', int), ('s', '|S10'), ('f', float)])
304 304
305 305 view['RR'] = R
306 306 R2 = view['RR']
307 307
308 308 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
309 309 self.assertEqual(r_dtype, R.dtype)
310 310 self.assertEqual(r_shape, R.shape)
311 311 self.assertEqual(R2.dtype, R.dtype)
312 312 self.assertEqual(R2.shape, R.shape)
313 313 assert_array_equal(R2, R)
314 314
315 315 @skip_without('pandas')
316 316 def test_push_pull_timeseries(self):
317 317 """push/pull pandas.TimeSeries"""
318 318 import pandas
319 319
320 320 ts = pandas.TimeSeries(range(10))
321 321
322 322 view = self.client[-1]
323 323
324 324 view.push(dict(ts=ts), block=True)
325 325 rts = view['ts']
326 326
327 327 self.assertEqual(type(rts), type(ts))
328 328 self.assertTrue((ts == rts).all())
329 329
330 330 def test_map(self):
331 331 view = self.client[:]
332 332 def f(x):
333 333 return x**2
334 334 data = range(16)
335 335 r = view.map_sync(f, data)
336 336 self.assertEqual(r, map(f, data))
337 337
338 338 def test_map_iterable(self):
339 339 """test map on iterables (direct)"""
340 340 view = self.client[:]
341 341 # 101 is prime, so it won't be evenly distributed
342 342 arr = range(101)
343 343 # ensure it will be an iterator, even in Python 3
344 344 it = iter(arr)
345 345 r = view.map_sync(lambda x:x, arr)
346 346 self.assertEqual(r, list(arr))
347
348 @skip_without('numpy')
349 def test_map_numpy(self):
350 """test map on numpy arrays (direct)"""
351 import numpy
352 from numpy.testing.utils import assert_array_equal
353
354 view = self.client[:]
355 # 101 is prime, so it won't be evenly distributed
356 arr = numpy.arange(101)
357 r = view.map_sync(lambda x: x, arr)
358 assert_array_equal(r, arr)
347 359
348 360 def test_scatter_gather_nonblocking(self):
349 361 data = range(16)
350 362 view = self.client[:]
351 363 view.scatter('a', data, block=False)
352 364 ar = view.gather('a', block=False)
353 365 self.assertEqual(ar.get(), data)
354 366
355 367 @skip_without('numpy')
356 368 def test_scatter_gather_numpy_nonblocking(self):
357 369 import numpy
358 370 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
359 371 a = numpy.arange(64)
360 372 view = self.client[:]
361 373 ar = view.scatter('a', a, block=False)
362 374 self.assertTrue(isinstance(ar, AsyncResult))
363 375 amr = view.gather('a', block=False)
364 376 self.assertTrue(isinstance(amr, AsyncMapResult))
365 377 assert_array_equal(amr.get(), a)
366 378
367 379 def test_execute(self):
368 380 view = self.client[:]
369 381 # self.client.debug=True
370 382 execute = view.execute
371 383 ar = execute('c=30', block=False)
372 384 self.assertTrue(isinstance(ar, AsyncResult))
373 385 ar = execute('d=[0,1,2]', block=False)
374 386 self.client.wait(ar, 1)
375 387 self.assertEqual(len(ar.get()), len(self.client))
376 388 for c in view['c']:
377 389 self.assertEqual(c, 30)
378 390
379 391 def test_abort(self):
380 392 view = self.client[-1]
381 393 ar = view.execute('import time; time.sleep(1)', block=False)
382 394 ar2 = view.apply_async(lambda : 2)
383 395 ar3 = view.apply_async(lambda : 3)
384 396 view.abort(ar2)
385 397 view.abort(ar3.msg_ids)
386 398 self.assertRaises(error.TaskAborted, ar2.get)
387 399 self.assertRaises(error.TaskAborted, ar3.get)
388 400
389 401 def test_abort_all(self):
390 402 """view.abort() aborts all outstanding tasks"""
391 403 view = self.client[-1]
392 404 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
393 405 view.abort()
394 406 view.wait(timeout=5)
395 407 for ar in ars[5:]:
396 408 self.assertRaises(error.TaskAborted, ar.get)
397 409
398 410 def test_temp_flags(self):
399 411 view = self.client[-1]
400 412 view.block=True
401 413 with view.temp_flags(block=False):
402 414 self.assertFalse(view.block)
403 415 self.assertTrue(view.block)
404 416
405 417 @dec.known_failure_py3
406 418 def test_importer(self):
407 419 view = self.client[-1]
408 420 view.clear(block=True)
409 421 with view.importer:
410 422 import re
411 423
412 424 @interactive
413 425 def findall(pat, s):
414 426 # this globals() step isn't necessary in real code
415 427 # only to prevent a closure in the test
416 428 re = globals()['re']
417 429 return re.findall(pat, s)
418 430
419 431 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
420 432
421 433 def test_unicode_execute(self):
422 434 """test executing unicode strings"""
423 435 v = self.client[-1]
424 436 v.block=True
425 437 if sys.version_info[0] >= 3:
426 438 code="a='é'"
427 439 else:
428 440 code=u"a=u'é'"
429 441 v.execute(code)
430 442 self.assertEqual(v['a'], u'é')
431 443
432 444 def test_unicode_apply_result(self):
433 445 """test unicode apply results"""
434 446 v = self.client[-1]
435 447 r = v.apply_sync(lambda : u'é')
436 448 self.assertEqual(r, u'é')
437 449
438 450 def test_unicode_apply_arg(self):
439 451 """test passing unicode arguments to apply"""
440 452 v = self.client[-1]
441 453
442 454 @interactive
443 455 def check_unicode(a, check):
444 456 assert isinstance(a, unicode), "%r is not unicode"%a
445 457 assert isinstance(check, bytes), "%r is not bytes"%check
446 458 assert a.encode('utf8') == check, "%s != %s"%(a,check)
447 459
448 460 for s in [ u'é', u'ßø®∫',u'asdf' ]:
449 461 try:
450 462 v.apply_sync(check_unicode, s, s.encode('utf8'))
451 463 except error.RemoteError as e:
452 464 if e.ename == 'AssertionError':
453 465 self.fail(e.evalue)
454 466 else:
455 467 raise e
456 468
457 469 def test_map_reference(self):
458 470 """view.map(<Reference>, *seqs) should work"""
459 471 v = self.client[:]
460 472 v.scatter('n', self.client.ids, flatten=True)
461 473 v.execute("f = lambda x,y: x*y")
462 474 rf = pmod.Reference('f')
463 475 nlist = list(range(10))
464 476 mlist = nlist[::-1]
465 477 expected = [ m*n for m,n in zip(mlist, nlist) ]
466 478 result = v.map_sync(rf, mlist, nlist)
467 479 self.assertEqual(result, expected)
468 480
469 481 def test_apply_reference(self):
470 482 """view.apply(<Reference>, *args) should work"""
471 483 v = self.client[:]
472 484 v.scatter('n', self.client.ids, flatten=True)
473 485 v.execute("f = lambda x: n*x")
474 486 rf = pmod.Reference('f')
475 487 result = v.apply_sync(rf, 5)
476 488 expected = [ 5*id for id in self.client.ids ]
477 489 self.assertEqual(result, expected)
478 490
479 491 def test_eval_reference(self):
480 492 v = self.client[self.client.ids[0]]
481 493 v['g'] = range(5)
482 494 rg = pmod.Reference('g[0]')
483 495 echo = lambda x:x
484 496 self.assertEqual(v.apply_sync(echo, rg), 0)
485 497
486 498 def test_reference_nameerror(self):
487 499 v = self.client[self.client.ids[0]]
488 500 r = pmod.Reference('elvis_has_left')
489 501 echo = lambda x:x
490 502 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
491 503
492 504 def test_single_engine_map(self):
493 505 e0 = self.client[self.client.ids[0]]
494 506 r = range(5)
495 507 check = [ -1*i for i in r ]
496 508 result = e0.map_sync(lambda x: -1*x, r)
497 509 self.assertEqual(result, check)
498 510
499 511 def test_len(self):
500 512 """len(view) makes sense"""
501 513 e0 = self.client[self.client.ids[0]]
502 514 yield self.assertEqual(len(e0), 1)
503 515 v = self.client[:]
504 516 yield self.assertEqual(len(v), len(self.client.ids))
505 517 v = self.client.direct_view('all')
506 518 yield self.assertEqual(len(v), len(self.client.ids))
507 519 v = self.client[:2]
508 520 yield self.assertEqual(len(v), 2)
509 521 v = self.client[:1]
510 522 yield self.assertEqual(len(v), 1)
511 523 v = self.client.load_balanced_view()
512 524 yield self.assertEqual(len(v), len(self.client.ids))
513 525 # parametric tests seem to require manual closing?
514 526 self.client.close()
515 527
516 528
517 529 # begin execute tests
518 530
519 531 def test_execute_reply(self):
520 532 e0 = self.client[self.client.ids[0]]
521 533 e0.block = True
522 534 ar = e0.execute("5", silent=False)
523 535 er = ar.get()
524 536 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
525 537 self.assertEqual(er.pyout['data']['text/plain'], '5')
526 538
527 539 def test_execute_reply_stdout(self):
528 540 e0 = self.client[self.client.ids[0]]
529 541 e0.block = True
530 542 ar = e0.execute("print (5)", silent=False)
531 543 er = ar.get()
532 544 self.assertEqual(er.stdout.strip(), '5')
533 545
534 546 def test_execute_pyout(self):
535 547 """execute triggers pyout with silent=False"""
536 548 view = self.client[:]
537 549 ar = view.execute("5", silent=False, block=True)
538 550
539 551 expected = [{'text/plain' : '5'}] * len(view)
540 552 mimes = [ out['data'] for out in ar.pyout ]
541 553 self.assertEqual(mimes, expected)
542 554
543 555 def test_execute_silent(self):
544 556 """execute does not trigger pyout with silent=True"""
545 557 view = self.client[:]
546 558 ar = view.execute("5", block=True)
547 559 expected = [None] * len(view)
548 560 self.assertEqual(ar.pyout, expected)
549 561
550 562 def test_execute_magic(self):
551 563 """execute accepts IPython commands"""
552 564 view = self.client[:]
553 565 view.execute("a = 5")
554 566 ar = view.execute("%whos", block=True)
555 567 # this will raise, if that failed
556 568 ar.get(5)
557 569 for stdout in ar.stdout:
558 570 lines = stdout.splitlines()
559 571 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
560 572 found = False
561 573 for line in lines[2:]:
562 574 split = line.split()
563 575 if split == ['a', 'int', '5']:
564 576 found = True
565 577 break
566 578 self.assertTrue(found, "whos output wrong: %s" % stdout)
567 579
568 580 def test_execute_displaypub(self):
569 581 """execute tracks display_pub output"""
570 582 view = self.client[:]
571 583 view.execute("from IPython.core.display import *")
572 584 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
573 585
574 586 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
575 587 for outputs in ar.outputs:
576 588 mimes = [ out['data'] for out in outputs ]
577 589 self.assertEqual(mimes, expected)
578 590
579 591 def test_apply_displaypub(self):
580 592 """apply tracks display_pub output"""
581 593 view = self.client[:]
582 594 view.execute("from IPython.core.display import *")
583 595
584 596 @interactive
585 597 def publish():
586 598 [ display(i) for i in range(5) ]
587 599
588 600 ar = view.apply_async(publish)
589 601 ar.get(5)
590 602 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
591 603 for outputs in ar.outputs:
592 604 mimes = [ out['data'] for out in outputs ]
593 605 self.assertEqual(mimes, expected)
594 606
595 607 def test_execute_raises(self):
596 608 """exceptions in execute requests raise appropriately"""
597 609 view = self.client[-1]
598 610 ar = view.execute("1/0")
599 611 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
600 612
601 613 def test_remoteerror_render_exception(self):
602 614 """RemoteErrors get nice tracebacks"""
603 615 view = self.client[-1]
604 616 ar = view.execute("1/0")
605 617 ip = get_ipython()
606 618 ip.user_ns['ar'] = ar
607 619 with capture_output() as io:
608 620 ip.run_cell("ar.get(2)")
609 621
610 622 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
611 623
612 624 def test_compositeerror_render_exception(self):
613 625 """CompositeErrors get nice tracebacks"""
614 626 view = self.client[:]
615 627 ar = view.execute("1/0")
616 628 ip = get_ipython()
617 629 ip.user_ns['ar'] = ar
618 630
619 631 with capture_output() as io:
620 632 ip.run_cell("ar.get(2)")
621 633
622 634 count = min(error.CompositeError.tb_limit, len(view))
623 635
624 636 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
625 637 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
626 638 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
627 639
628 640 def test_compositeerror_truncate(self):
629 641 """Truncate CompositeErrors with many exceptions"""
630 642 view = self.client[:]
631 643 msg_ids = []
632 644 for i in range(10):
633 645 ar = view.execute("1/0")
634 646 msg_ids.extend(ar.msg_ids)
635 647
636 648 ar = self.client.get_result(msg_ids)
637 649 try:
638 650 ar.get()
639 651 except error.CompositeError as _e:
640 652 e = _e
641 653 else:
642 654 self.fail("Should have raised CompositeError")
643 655
644 656 lines = e.render_traceback()
645 657 with capture_output() as io:
646 658 e.print_traceback()
647 659
648 660 self.assertTrue("more exceptions" in lines[-1])
649 661 count = e.tb_limit
650 662
651 663 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
652 664 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
653 665 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
654 666
655 667 @dec.skipif_not_matplotlib
656 668 def test_magic_pylab(self):
657 669 """%pylab works on engines"""
658 670 view = self.client[-1]
659 671 ar = view.execute("%pylab inline")
660 672 # at least check if this raised:
661 673 reply = ar.get(5)
662 674 # include imports, in case user config
663 675 ar = view.execute("plot(rand(100))", silent=False)
664 676 reply = ar.get(5)
665 677 self.assertEqual(len(reply.outputs), 1)
666 678 output = reply.outputs[0]
667 679 self.assertTrue("data" in output)
668 680 data = output['data']
669 681 self.assertTrue("image/png" in data)
670 682
671 683 def test_func_default_func(self):
672 684 """interactively defined function as apply func default"""
673 685 def foo():
674 686 return 'foo'
675 687
676 688 def bar(f=foo):
677 689 return f()
678 690
679 691 view = self.client[-1]
680 692 ar = view.apply_async(bar)
681 693 r = ar.get(10)
682 694 self.assertEqual(r, 'foo')
683 695 def test_data_pub_single(self):
684 696 view = self.client[-1]
685 697 ar = view.execute('\n'.join([
686 698 'from IPython.kernel.zmq.datapub import publish_data',
687 699 'for i in range(5):',
688 700 ' publish_data(dict(i=i))'
689 701 ]), block=False)
690 702 self.assertTrue(isinstance(ar.data, dict))
691 703 ar.get(5)
692 704 self.assertEqual(ar.data, dict(i=4))
693 705
694 706 def test_data_pub(self):
695 707 view = self.client[:]
696 708 ar = view.execute('\n'.join([
697 709 'from IPython.kernel.zmq.datapub import publish_data',
698 710 'for i in range(5):',
699 711 ' publish_data(dict(i=i))'
700 712 ]), block=False)
701 713 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
702 714 ar.get(5)
703 715 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
704 716
705 717 def test_can_list_arg(self):
706 718 """args in lists are canned"""
707 719 view = self.client[-1]
708 720 view['a'] = 128
709 721 rA = pmod.Reference('a')
710 722 ar = view.apply_async(lambda x: x, [rA])
711 723 r = ar.get(5)
712 724 self.assertEqual(r, [128])
713 725
714 726 def test_can_dict_arg(self):
715 727 """args in dicts are canned"""
716 728 view = self.client[-1]
717 729 view['a'] = 128
718 730 rA = pmod.Reference('a')
719 731 ar = view.apply_async(lambda x: x, dict(foo=rA))
720 732 r = ar.get(5)
721 733 self.assertEqual(r, dict(foo=128))
722 734
723 735 def test_can_list_kwarg(self):
724 736 """kwargs in lists are canned"""
725 737 view = self.client[-1]
726 738 view['a'] = 128
727 739 rA = pmod.Reference('a')
728 740 ar = view.apply_async(lambda x=5: x, x=[rA])
729 741 r = ar.get(5)
730 742 self.assertEqual(r, [128])
731 743
732 744 def test_can_dict_kwarg(self):
733 745 """kwargs in dicts are canned"""
734 746 view = self.client[-1]
735 747 view['a'] = 128
736 748 rA = pmod.Reference('a')
737 749 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
738 750 r = ar.get(5)
739 751 self.assertEqual(r, dict(foo=128))
740 752
741 753 def test_map_ref(self):
742 754 """view.map works with references"""
743 755 view = self.client[:]
744 756 ranks = sorted(self.client.ids)
745 757 view.scatter('rank', ranks, flatten=True)
746 758 rrank = pmod.Reference('rank')
747 759
748 760 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
749 761 drank = amr.get(5)
750 762 self.assertEqual(drank, [ r*2 for r in ranks ])
751 763
752 764 def test_nested_getitem_setitem(self):
753 765 """get and set with view['a.b']"""
754 766 view = self.client[-1]
755 767 view.execute('\n'.join([
756 768 'class A(object): pass',
757 769 'a = A()',
758 770 'a.b = 128',
759 771 ]), block=True)
760 772 ra = pmod.Reference('a')
761 773
762 774 r = view.apply_sync(lambda x: x.b, ra)
763 775 self.assertEqual(r, 128)
764 776 self.assertEqual(view['a.b'], 128)
765 777
766 778 view['a.b'] = 0
767 779
768 780 r = view.apply_sync(lambda x: x.b, ra)
769 781 self.assertEqual(r, 0)
770 782 self.assertEqual(view['a.b'], 0)
771 783
772 784 def test_return_namedtuple(self):
773 785 def namedtuplify(x, y):
774 786 from IPython.parallel.tests.test_view import point
775 787 return point(x, y)
776 788
777 789 view = self.client[-1]
778 790 p = view.apply_sync(namedtuplify, 1, 2)
779 791 self.assertEqual(p.x, 1)
780 792 self.assertEqual(p.y, 2)
781 793
782 794 def test_apply_namedtuple(self):
783 795 def echoxy(p):
784 796 return p.y, p.x
785 797
786 798 view = self.client[-1]
787 799 tup = view.apply_sync(echoxy, point(1, 2))
788 800 self.assertEqual(tup, (2,1))
789 801
General Comments 0
You need to be logged in to leave comments. Login now