##// END OF EJS Templates
Merge pull request #5766 from minrk/map-length-0...
Thomas Kluyver -
r16513:7aeb0546 merge
parent child Browse files
Show More
@@ -1,285 +1,276 b''
1 """Remote Functions and decorators for Views.
1 """Remote Functions and decorators for Views."""
2 2
3 Authors:
4
5 * Brian Granger
6 * Min RK
7 """
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2010-2011 The IPython Development Team
10 #
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
14
15 #-----------------------------------------------------------------------------
16 # Imports
17 #-----------------------------------------------------------------------------
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
18 5
19 6 from __future__ import division
20 7
21 8 import sys
22 9 import warnings
23 10
24 11 from IPython.external.decorator import decorator
25 12 from IPython.testing.skipdoctest import skip_doctest
26 13
27 14 from . import map as Map
28 15 from .asyncresult import AsyncMapResult
29 16
30 17 #-----------------------------------------------------------------------------
31 18 # Functions and Decorators
32 19 #-----------------------------------------------------------------------------
33 20
34 21 @skip_doctest
35 22 def remote(view, block=None, **flags):
36 23 """Turn a function into a remote function.
37 24
38 25 This method can be used for map:
39 26
40 27 In [1]: @remote(view,block=True)
41 28 ...: def func(a):
42 29 ...: pass
43 30 """
44 31
45 32 def remote_function(f):
46 33 return RemoteFunction(view, f, block=block, **flags)
47 34 return remote_function
48 35
49 36 @skip_doctest
50 37 def parallel(view, dist='b', block=None, ordered=True, **flags):
51 38 """Turn a function into a parallel remote function.
52 39
53 40 This method can be used for map:
54 41
55 42 In [1]: @parallel(view, block=True)
56 43 ...: def func(a):
57 44 ...: pass
58 45 """
59 46
60 47 def parallel_function(f):
61 48 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
62 49 return parallel_function
63 50
64 51 def getname(f):
65 52 """Get the name of an object.
66 53
67 54 For use in case of callables that are not functions, and
68 55 thus may not have __name__ defined.
69 56
70 57 Order: f.__name__ > f.name > str(f)
71 58 """
72 59 try:
73 60 return f.__name__
74 61 except:
75 62 pass
76 63 try:
77 64 return f.name
78 65 except:
79 66 pass
80 67
81 68 return str(f)
82 69
83 70 @decorator
84 71 def sync_view_results(f, self, *args, **kwargs):
85 72 """sync relevant results from self.client to our results attribute.
86 73
87 74 This is a clone of view.sync_results, but for remote functions
88 75 """
89 76 view = self.view
90 77 if view._in_sync_results:
91 78 return f(self, *args, **kwargs)
92 79 view._in_sync_results = True
93 80 try:
94 81 ret = f(self, *args, **kwargs)
95 82 finally:
96 83 view._in_sync_results = False
97 84 view._sync_results()
98 85 return ret
99 86
100 87 #--------------------------------------------------------------------------
101 88 # Classes
102 89 #--------------------------------------------------------------------------
103 90
104 91 class RemoteFunction(object):
105 92 """Turn an existing function into a remote function.
106 93
107 94 Parameters
108 95 ----------
109 96
110 97 view : View instance
111 98 The view to be used for execution
112 99 f : callable
113 100 The function to be wrapped into a remote function
114 101 block : bool [default: None]
115 102 Whether to wait for results or not. The default behavior is
116 103 to use the current `block` attribute of `view`
117 104
118 105 **flags : remaining kwargs are passed to View.temp_flags
119 106 """
120 107
121 108 view = None # the remote connection
122 109 func = None # the wrapped function
123 110 block = None # whether to block
124 111 flags = None # dict of extra kwargs for temp_flags
125 112
126 113 def __init__(self, view, f, block=None, **flags):
127 114 self.view = view
128 115 self.func = f
129 116 self.block=block
130 117 self.flags=flags
131 118
132 119 def __call__(self, *args, **kwargs):
133 120 block = self.view.block if self.block is None else self.block
134 121 with self.view.temp_flags(block=block, **self.flags):
135 122 return self.view.apply(self.func, *args, **kwargs)
136 123
137 124
138 125 class ParallelFunction(RemoteFunction):
139 126 """Class for mapping a function to sequences.
140 127
141 128 This will distribute the sequences according the a mapper, and call
142 129 the function on each sub-sequence. If called via map, then the function
143 130 will be called once on each element, rather that each sub-sequence.
144 131
145 132 Parameters
146 133 ----------
147 134
148 135 view : View instance
149 136 The view to be used for execution
150 137 f : callable
151 138 The function to be wrapped into a remote function
152 139 dist : str [default: 'b']
153 140 The key for which mapObject to use to distribute sequences
154 141 options are:
155 142
156 143 * 'b' : use contiguous chunks in order
157 144 * 'r' : use round-robin striping
158 145
159 146 block : bool [default: None]
160 147 Whether to wait for results or not. The default behavior is
161 148 to use the current `block` attribute of `view`
162 149 chunksize : int or None
163 150 The size of chunk to use when breaking up sequences in a load-balanced manner
164 151 ordered : bool [default: True]
165 152 Whether the result should be kept in order. If False,
166 153 results become available as they arrive, regardless of submission order.
167 154 **flags
168 155 remaining kwargs are passed to View.temp_flags
169 156 """
170 157
171 158 chunksize = None
172 159 ordered = None
173 160 mapObject = None
174 161 _mapping = False
175 162
176 163 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
177 164 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
178 165 self.chunksize = chunksize
179 166 self.ordered = ordered
180 167
181 168 mapClass = Map.dists[dist]
182 169 self.mapObject = mapClass()
183 170
184 171 @sync_view_results
185 172 def __call__(self, *sequences):
186 173 client = self.view.client
187 174
188 175 lens = []
189 176 maxlen = minlen = -1
190 177 for i, seq in enumerate(sequences):
191 178 try:
192 179 n = len(seq)
193 180 except Exception:
194 181 seq = list(seq)
195 182 if isinstance(sequences, tuple):
196 183 # can't alter a tuple
197 184 sequences = list(sequences)
198 185 sequences[i] = seq
199 186 n = len(seq)
200 187 if n > maxlen:
201 188 maxlen = n
202 189 if minlen == -1 or n < minlen:
203 190 minlen = n
204 191 lens.append(n)
205 192
193 if maxlen == 0:
194 # nothing to iterate over
195 return []
196
206 197 # check that the length of sequences match
207 198 if not self._mapping and minlen != maxlen:
208 199 msg = 'all sequences must have equal length, but have %s' % lens
209 200 raise ValueError(msg)
210 201
211 202 balanced = 'Balanced' in self.view.__class__.__name__
212 203 if balanced:
213 204 if self.chunksize:
214 205 nparts = maxlen // self.chunksize + int(maxlen % self.chunksize > 0)
215 206 else:
216 207 nparts = maxlen
217 208 targets = [None]*nparts
218 209 else:
219 210 if self.chunksize:
220 211 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
221 212 # multiplexed:
222 213 targets = self.view.targets
223 214 # 'all' is lazily evaluated at execution time, which is now:
224 215 if targets == 'all':
225 216 targets = client._build_targets(targets)[1]
226 217 elif isinstance(targets, int):
227 218 # single-engine view, targets must be iterable
228 219 targets = [targets]
229 220 nparts = len(targets)
230 221
231 222 msg_ids = []
232 223 for index, t in enumerate(targets):
233 224 args = []
234 225 for seq in sequences:
235 226 part = self.mapObject.getPartition(seq, index, nparts, maxlen)
236 227 args.append(part)
237 228
238 229 if sum([len(arg) for arg in args]) == 0:
239 230 continue
240 231
241 232 if self._mapping:
242 233 if sys.version_info[0] >= 3:
243 234 f = lambda f, *sequences: list(map(f, *sequences))
244 235 else:
245 236 f = map
246 237 args = [self.func] + args
247 238 else:
248 239 f=self.func
249 240
250 241 view = self.view if balanced else client[t]
251 242 with view.temp_flags(block=False, **self.flags):
252 243 ar = view.apply(f, *args)
253 244
254 245 msg_ids.extend(ar.msg_ids)
255 246
256 247 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
257 248 fname=getname(self.func),
258 249 ordered=self.ordered
259 250 )
260 251
261 252 if self.block:
262 253 try:
263 254 return r.get()
264 255 except KeyboardInterrupt:
265 256 return r
266 257 else:
267 258 return r
268 259
269 260 def map(self, *sequences):
270 261 """call a function on each element of one or more sequence(s) remotely.
271 262 This should behave very much like the builtin map, but return an AsyncMapResult
272 263 if self.block is False.
273 264
274 265 That means it can take generators (will be cast to lists locally),
275 266 and mismatched sequence lengths will be padded with None.
276 267 """
277 268 # set _mapping as a flag for use inside self.__call__
278 269 self._mapping = True
279 270 try:
280 271 ret = self(*sequences)
281 272 finally:
282 273 self._mapping = False
283 274 return ret
284 275
285 276 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,849 +1,842 b''
1 1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects"""
3 3
4 Authors:
5
6 * Min RK
7 """
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
10 #
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
14
15 #-------------------------------------------------------------------------------
16 # Imports
17 #-------------------------------------------------------------------------------
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
18 6
19 7 import base64
20 8 import sys
21 9 import platform
22 10 import time
23 11 from collections import namedtuple
24 12 from tempfile import NamedTemporaryFile
25 13
26 14 import zmq
27 15 from nose.plugins.attrib import attr
28 16
29 17 from IPython.testing import decorators as dec
30 18 from IPython.utils.io import capture_output
31 19 from IPython.utils.py3compat import unicode_type
32 20
33 21 from IPython import parallel as pmod
34 22 from IPython.parallel import error
35 23 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
36 24 from IPython.parallel.util import interactive
37 25
38 26 from IPython.parallel.tests import add_engines
39 27
40 28 from .clienttest import ClusterTestCase, crash, wait, skip_without
41 29
42 30 def setup():
43 31 add_engines(3, total=True)
44 32
45 33 point = namedtuple("point", "x y")
46 34
47 35 class TestView(ClusterTestCase):
48 36
49 37 def setUp(self):
50 38 # On Win XP, wait for resource cleanup, else parallel test group fails
51 39 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
52 40 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
53 41 time.sleep(2)
54 42 super(TestView, self).setUp()
55 43
56 44 @attr('crash')
57 45 def test_z_crash_mux(self):
58 46 """test graceful handling of engine death (direct)"""
59 47 # self.add_engines(1)
60 48 eid = self.client.ids[-1]
61 49 ar = self.client[eid].apply_async(crash)
62 50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
63 51 eid = ar.engine_id
64 52 tic = time.time()
65 53 while eid in self.client.ids and time.time()-tic < 5:
66 54 time.sleep(.01)
67 55 self.client.spin()
68 56 self.assertFalse(eid in self.client.ids, "Engine should have died")
69 57
70 58 def test_push_pull(self):
71 59 """test pushing and pulling"""
72 60 data = dict(a=10, b=1.05, c=list(range(10)), d={'e':(1,2),'f':'hi'})
73 61 t = self.client.ids[-1]
74 62 v = self.client[t]
75 63 push = v.push
76 64 pull = v.pull
77 65 v.block=True
78 66 nengines = len(self.client)
79 67 push({'data':data})
80 68 d = pull('data')
81 69 self.assertEqual(d, data)
82 70 self.client[:].push({'data':data})
83 71 d = self.client[:].pull('data', block=True)
84 72 self.assertEqual(d, nengines*[data])
85 73 ar = push({'data':data}, block=False)
86 74 self.assertTrue(isinstance(ar, AsyncResult))
87 75 r = ar.get()
88 76 ar = self.client[:].pull('data', block=False)
89 77 self.assertTrue(isinstance(ar, AsyncResult))
90 78 r = ar.get()
91 79 self.assertEqual(r, nengines*[data])
92 80 self.client[:].push(dict(a=10,b=20))
93 81 r = self.client[:].pull(('a','b'), block=True)
94 82 self.assertEqual(r, nengines*[[10,20]])
95 83
96 84 def test_push_pull_function(self):
97 85 "test pushing and pulling functions"
98 86 def testf(x):
99 87 return 2.0*x
100 88
101 89 t = self.client.ids[-1]
102 90 v = self.client[t]
103 91 v.block=True
104 92 push = v.push
105 93 pull = v.pull
106 94 execute = v.execute
107 95 push({'testf':testf})
108 96 r = pull('testf')
109 97 self.assertEqual(r(1.0), testf(1.0))
110 98 execute('r = testf(10)')
111 99 r = pull('r')
112 100 self.assertEqual(r, testf(10))
113 101 ar = self.client[:].push({'testf':testf}, block=False)
114 102 ar.get()
115 103 ar = self.client[:].pull('testf', block=False)
116 104 rlist = ar.get()
117 105 for r in rlist:
118 106 self.assertEqual(r(1.0), testf(1.0))
119 107 execute("def g(x): return x*x")
120 108 r = pull(('testf','g'))
121 109 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
122 110
123 111 def test_push_function_globals(self):
124 112 """test that pushed functions have access to globals"""
125 113 @interactive
126 114 def geta():
127 115 return a
128 116 # self.add_engines(1)
129 117 v = self.client[-1]
130 118 v.block=True
131 119 v['f'] = geta
132 120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
133 121 v.execute('a=5')
134 122 v.execute('b=f()')
135 123 self.assertEqual(v['b'], 5)
136 124
137 125 def test_push_function_defaults(self):
138 126 """test that pushed functions preserve default args"""
139 127 def echo(a=10):
140 128 return a
141 129 v = self.client[-1]
142 130 v.block=True
143 131 v['f'] = echo
144 132 v.execute('b=f()')
145 133 self.assertEqual(v['b'], 10)
146 134
147 135 def test_get_result(self):
148 136 """test getting results from the Hub."""
149 137 c = pmod.Client(profile='iptest')
150 138 # self.add_engines(1)
151 139 t = c.ids[-1]
152 140 v = c[t]
153 141 v2 = self.client[t]
154 142 ar = v.apply_async(wait, 1)
155 143 # give the monitor time to notice the message
156 144 time.sleep(.25)
157 145 ahr = v2.get_result(ar.msg_ids[0])
158 146 self.assertTrue(isinstance(ahr, AsyncHubResult))
159 147 self.assertEqual(ahr.get(), ar.get())
160 148 ar2 = v2.get_result(ar.msg_ids[0])
161 149 self.assertFalse(isinstance(ar2, AsyncHubResult))
162 150 c.spin()
163 151 c.close()
164 152
165 153 def test_run_newline(self):
166 154 """test that run appends newline to files"""
167 155 with NamedTemporaryFile('w', delete=False) as f:
168 156 f.write("""def g():
169 157 return 5
170 158 """)
171 159 v = self.client[-1]
172 160 v.run(f.name, block=True)
173 161 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
174 162
175 163 def test_apply_tracked(self):
176 164 """test tracking for apply"""
177 165 # self.add_engines(1)
178 166 t = self.client.ids[-1]
179 167 v = self.client[t]
180 168 v.block=False
181 169 def echo(n=1024*1024, **kwargs):
182 170 with v.temp_flags(**kwargs):
183 171 return v.apply(lambda x: x, 'x'*n)
184 172 ar = echo(1, track=False)
185 173 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
186 174 self.assertTrue(ar.sent)
187 175 ar = echo(track=True)
188 176 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
189 177 self.assertEqual(ar.sent, ar._tracker.done)
190 178 ar._tracker.wait()
191 179 self.assertTrue(ar.sent)
192 180
193 181 def test_push_tracked(self):
194 182 t = self.client.ids[-1]
195 183 ns = dict(x='x'*1024*1024)
196 184 v = self.client[t]
197 185 ar = v.push(ns, block=False, track=False)
198 186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
199 187 self.assertTrue(ar.sent)
200 188
201 189 ar = v.push(ns, block=False, track=True)
202 190 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
203 191 ar._tracker.wait()
204 192 self.assertEqual(ar.sent, ar._tracker.done)
205 193 self.assertTrue(ar.sent)
206 194 ar.get()
207 195
208 196 def test_scatter_tracked(self):
209 197 t = self.client.ids
210 198 x='x'*1024*1024
211 199 ar = self.client[t].scatter('x', x, block=False, track=False)
212 200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
213 201 self.assertTrue(ar.sent)
214 202
215 203 ar = self.client[t].scatter('x', x, block=False, track=True)
216 204 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
217 205 self.assertEqual(ar.sent, ar._tracker.done)
218 206 ar._tracker.wait()
219 207 self.assertTrue(ar.sent)
220 208 ar.get()
221 209
222 210 def test_remote_reference(self):
223 211 v = self.client[-1]
224 212 v['a'] = 123
225 213 ra = pmod.Reference('a')
226 214 b = v.apply_sync(lambda x: x, ra)
227 215 self.assertEqual(b, 123)
228 216
229 217
230 218 def test_scatter_gather(self):
231 219 view = self.client[:]
232 220 seq1 = list(range(16))
233 221 view.scatter('a', seq1)
234 222 seq2 = view.gather('a', block=True)
235 223 self.assertEqual(seq2, seq1)
236 224 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
237 225
238 226 @skip_without('numpy')
239 227 def test_scatter_gather_numpy(self):
240 228 import numpy
241 229 from numpy.testing.utils import assert_array_equal
242 230 view = self.client[:]
243 231 a = numpy.arange(64)
244 232 view.scatter('a', a, block=True)
245 233 b = view.gather('a', block=True)
246 234 assert_array_equal(b, a)
247 235
248 236 def test_scatter_gather_lazy(self):
249 237 """scatter/gather with targets='all'"""
250 238 view = self.client.direct_view(targets='all')
251 239 x = list(range(64))
252 240 view.scatter('x', x)
253 241 gathered = view.gather('x', block=True)
254 242 self.assertEqual(gathered, x)
255 243
256 244
257 245 @dec.known_failure_py3
258 246 @skip_without('numpy')
259 247 def test_push_numpy_nocopy(self):
260 248 import numpy
261 249 view = self.client[:]
262 250 a = numpy.arange(64)
263 251 view['A'] = a
264 252 @interactive
265 253 def check_writeable(x):
266 254 return x.flags.writeable
267 255
268 256 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
269 257 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
270 258
271 259 view.push(dict(B=a))
272 260 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
273 261 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
274 262
275 263 @skip_without('numpy')
276 264 def test_apply_numpy(self):
277 265 """view.apply(f, ndarray)"""
278 266 import numpy
279 267 from numpy.testing.utils import assert_array_equal
280 268
281 269 A = numpy.random.random((100,100))
282 270 view = self.client[-1]
283 271 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
284 272 B = A.astype(dt)
285 273 C = view.apply_sync(lambda x:x, B)
286 274 assert_array_equal(B,C)
287 275
288 276 @skip_without('numpy')
289 277 def test_apply_numpy_object_dtype(self):
290 278 """view.apply(f, ndarray) with dtype=object"""
291 279 import numpy
292 280 from numpy.testing.utils import assert_array_equal
293 281 view = self.client[-1]
294 282
295 283 A = numpy.array([dict(a=5)])
296 284 B = view.apply_sync(lambda x:x, A)
297 285 assert_array_equal(A,B)
298 286
299 287 A = numpy.array([(0, dict(b=10))], dtype=[('i', int), ('o', object)])
300 288 B = view.apply_sync(lambda x:x, A)
301 289 assert_array_equal(A,B)
302 290
303 291 @skip_without('numpy')
304 292 def test_push_pull_recarray(self):
305 293 """push/pull recarrays"""
306 294 import numpy
307 295 from numpy.testing.utils import assert_array_equal
308 296
309 297 view = self.client[-1]
310 298
311 299 R = numpy.array([
312 300 (1, 'hi', 0.),
313 301 (2**30, 'there', 2.5),
314 302 (-99999, 'world', -12345.6789),
315 303 ], [('n', int), ('s', '|S10'), ('f', float)])
316 304
317 305 view['RR'] = R
318 306 R2 = view['RR']
319 307
320 308 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
321 309 self.assertEqual(r_dtype, R.dtype)
322 310 self.assertEqual(r_shape, R.shape)
323 311 self.assertEqual(R2.dtype, R.dtype)
324 312 self.assertEqual(R2.shape, R.shape)
325 313 assert_array_equal(R2, R)
326 314
327 315 @skip_without('pandas')
328 316 def test_push_pull_timeseries(self):
329 317 """push/pull pandas.TimeSeries"""
330 318 import pandas
331 319
332 320 ts = pandas.TimeSeries(list(range(10)))
333 321
334 322 view = self.client[-1]
335 323
336 324 view.push(dict(ts=ts), block=True)
337 325 rts = view['ts']
338 326
339 327 self.assertEqual(type(rts), type(ts))
340 328 self.assertTrue((ts == rts).all())
341 329
342 330 def test_map(self):
343 331 view = self.client[:]
344 332 def f(x):
345 333 return x**2
346 334 data = list(range(16))
347 335 r = view.map_sync(f, data)
348 336 self.assertEqual(r, list(map(f, data)))
349 337
338 def test_map_empty_sequence(self):
339 view = self.client[:]
340 r = view.map_sync(lambda x: x, [])
341 self.assertEqual(r, [])
342
350 343 def test_map_iterable(self):
351 344 """test map on iterables (direct)"""
352 345 view = self.client[:]
353 346 # 101 is prime, so it won't be evenly distributed
354 347 arr = range(101)
355 348 # ensure it will be an iterator, even in Python 3
356 349 it = iter(arr)
357 350 r = view.map_sync(lambda x: x, it)
358 351 self.assertEqual(r, list(arr))
359 352
360 353 @skip_without('numpy')
361 354 def test_map_numpy(self):
362 355 """test map on numpy arrays (direct)"""
363 356 import numpy
364 357 from numpy.testing.utils import assert_array_equal
365 358
366 359 view = self.client[:]
367 360 # 101 is prime, so it won't be evenly distributed
368 361 arr = numpy.arange(101)
369 362 r = view.map_sync(lambda x: x, arr)
370 363 assert_array_equal(r, arr)
371 364
372 365 def test_scatter_gather_nonblocking(self):
373 366 data = list(range(16))
374 367 view = self.client[:]
375 368 view.scatter('a', data, block=False)
376 369 ar = view.gather('a', block=False)
377 370 self.assertEqual(ar.get(), data)
378 371
379 372 @skip_without('numpy')
380 373 def test_scatter_gather_numpy_nonblocking(self):
381 374 import numpy
382 375 from numpy.testing.utils import assert_array_equal
383 376 a = numpy.arange(64)
384 377 view = self.client[:]
385 378 ar = view.scatter('a', a, block=False)
386 379 self.assertTrue(isinstance(ar, AsyncResult))
387 380 amr = view.gather('a', block=False)
388 381 self.assertTrue(isinstance(amr, AsyncMapResult))
389 382 assert_array_equal(amr.get(), a)
390 383
391 384 def test_execute(self):
392 385 view = self.client[:]
393 386 # self.client.debug=True
394 387 execute = view.execute
395 388 ar = execute('c=30', block=False)
396 389 self.assertTrue(isinstance(ar, AsyncResult))
397 390 ar = execute('d=[0,1,2]', block=False)
398 391 self.client.wait(ar, 1)
399 392 self.assertEqual(len(ar.get()), len(self.client))
400 393 for c in view['c']:
401 394 self.assertEqual(c, 30)
402 395
403 396 def test_abort(self):
404 397 view = self.client[-1]
405 398 ar = view.execute('import time; time.sleep(1)', block=False)
406 399 ar2 = view.apply_async(lambda : 2)
407 400 ar3 = view.apply_async(lambda : 3)
408 401 view.abort(ar2)
409 402 view.abort(ar3.msg_ids)
410 403 self.assertRaises(error.TaskAborted, ar2.get)
411 404 self.assertRaises(error.TaskAborted, ar3.get)
412 405
413 406 def test_abort_all(self):
414 407 """view.abort() aborts all outstanding tasks"""
415 408 view = self.client[-1]
416 409 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
417 410 view.abort()
418 411 view.wait(timeout=5)
419 412 for ar in ars[5:]:
420 413 self.assertRaises(error.TaskAborted, ar.get)
421 414
422 415 def test_temp_flags(self):
423 416 view = self.client[-1]
424 417 view.block=True
425 418 with view.temp_flags(block=False):
426 419 self.assertFalse(view.block)
427 420 self.assertTrue(view.block)
428 421
429 422 @dec.known_failure_py3
430 423 def test_importer(self):
431 424 view = self.client[-1]
432 425 view.clear(block=True)
433 426 with view.importer:
434 427 import re
435 428
436 429 @interactive
437 430 def findall(pat, s):
438 431 # this globals() step isn't necessary in real code
439 432 # only to prevent a closure in the test
440 433 re = globals()['re']
441 434 return re.findall(pat, s)
442 435
443 436 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
444 437
445 438 def test_unicode_execute(self):
446 439 """test executing unicode strings"""
447 440 v = self.client[-1]
448 441 v.block=True
449 442 if sys.version_info[0] >= 3:
450 443 code="a='é'"
451 444 else:
452 445 code=u"a=u'é'"
453 446 v.execute(code)
454 447 self.assertEqual(v['a'], u'é')
455 448
456 449 def test_unicode_apply_result(self):
457 450 """test unicode apply results"""
458 451 v = self.client[-1]
459 452 r = v.apply_sync(lambda : u'é')
460 453 self.assertEqual(r, u'é')
461 454
462 455 def test_unicode_apply_arg(self):
463 456 """test passing unicode arguments to apply"""
464 457 v = self.client[-1]
465 458
466 459 @interactive
467 460 def check_unicode(a, check):
468 461 assert not isinstance(a, bytes), "%r is bytes, not unicode"%a
469 462 assert isinstance(check, bytes), "%r is not bytes"%check
470 463 assert a.encode('utf8') == check, "%s != %s"%(a,check)
471 464
472 465 for s in [ u'é', u'ßø®∫',u'asdf' ]:
473 466 try:
474 467 v.apply_sync(check_unicode, s, s.encode('utf8'))
475 468 except error.RemoteError as e:
476 469 if e.ename == 'AssertionError':
477 470 self.fail(e.evalue)
478 471 else:
479 472 raise e
480 473
481 474 def test_map_reference(self):
482 475 """view.map(<Reference>, *seqs) should work"""
483 476 v = self.client[:]
484 477 v.scatter('n', self.client.ids, flatten=True)
485 478 v.execute("f = lambda x,y: x*y")
486 479 rf = pmod.Reference('f')
487 480 nlist = list(range(10))
488 481 mlist = nlist[::-1]
489 482 expected = [ m*n for m,n in zip(mlist, nlist) ]
490 483 result = v.map_sync(rf, mlist, nlist)
491 484 self.assertEqual(result, expected)
492 485
493 486 def test_apply_reference(self):
494 487 """view.apply(<Reference>, *args) should work"""
495 488 v = self.client[:]
496 489 v.scatter('n', self.client.ids, flatten=True)
497 490 v.execute("f = lambda x: n*x")
498 491 rf = pmod.Reference('f')
499 492 result = v.apply_sync(rf, 5)
500 493 expected = [ 5*id for id in self.client.ids ]
501 494 self.assertEqual(result, expected)
502 495
503 496 def test_eval_reference(self):
504 497 v = self.client[self.client.ids[0]]
505 498 v['g'] = list(range(5))
506 499 rg = pmod.Reference('g[0]')
507 500 echo = lambda x:x
508 501 self.assertEqual(v.apply_sync(echo, rg), 0)
509 502
510 503 def test_reference_nameerror(self):
511 504 v = self.client[self.client.ids[0]]
512 505 r = pmod.Reference('elvis_has_left')
513 506 echo = lambda x:x
514 507 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
515 508
516 509 def test_single_engine_map(self):
517 510 e0 = self.client[self.client.ids[0]]
518 511 r = list(range(5))
519 512 check = [ -1*i for i in r ]
520 513 result = e0.map_sync(lambda x: -1*x, r)
521 514 self.assertEqual(result, check)
522 515
523 516 def test_len(self):
524 517 """len(view) makes sense"""
525 518 e0 = self.client[self.client.ids[0]]
526 519 self.assertEqual(len(e0), 1)
527 520 v = self.client[:]
528 521 self.assertEqual(len(v), len(self.client.ids))
529 522 v = self.client.direct_view('all')
530 523 self.assertEqual(len(v), len(self.client.ids))
531 524 v = self.client[:2]
532 525 self.assertEqual(len(v), 2)
533 526 v = self.client[:1]
534 527 self.assertEqual(len(v), 1)
535 528 v = self.client.load_balanced_view()
536 529 self.assertEqual(len(v), len(self.client.ids))
537 530
538 531
539 532 # begin execute tests
540 533
541 534 def test_execute_reply(self):
542 535 e0 = self.client[self.client.ids[0]]
543 536 e0.block = True
544 537 ar = e0.execute("5", silent=False)
545 538 er = ar.get()
546 539 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
547 540 self.assertEqual(er.pyout['data']['text/plain'], '5')
548 541
549 542 def test_execute_reply_rich(self):
550 543 e0 = self.client[self.client.ids[0]]
551 544 e0.block = True
552 545 e0.execute("from IPython.display import Image, HTML")
553 546 ar = e0.execute("Image(data=b'garbage', format='png', width=10)", silent=False)
554 547 er = ar.get()
555 548 b64data = base64.encodestring(b'garbage').decode('ascii')
556 549 self.assertEqual(er._repr_png_(), (b64data, dict(width=10)))
557 550 ar = e0.execute("HTML('<b>bold</b>')", silent=False)
558 551 er = ar.get()
559 552 self.assertEqual(er._repr_html_(), "<b>bold</b>")
560 553
561 554 def test_execute_reply_stdout(self):
562 555 e0 = self.client[self.client.ids[0]]
563 556 e0.block = True
564 557 ar = e0.execute("print (5)", silent=False)
565 558 er = ar.get()
566 559 self.assertEqual(er.stdout.strip(), '5')
567 560
568 561 def test_execute_pyout(self):
569 562 """execute triggers pyout with silent=False"""
570 563 view = self.client[:]
571 564 ar = view.execute("5", silent=False, block=True)
572 565
573 566 expected = [{'text/plain' : '5'}] * len(view)
574 567 mimes = [ out['data'] for out in ar.pyout ]
575 568 self.assertEqual(mimes, expected)
576 569
577 570 def test_execute_silent(self):
578 571 """execute does not trigger pyout with silent=True"""
579 572 view = self.client[:]
580 573 ar = view.execute("5", block=True)
581 574 expected = [None] * len(view)
582 575 self.assertEqual(ar.pyout, expected)
583 576
584 577 def test_execute_magic(self):
585 578 """execute accepts IPython commands"""
586 579 view = self.client[:]
587 580 view.execute("a = 5")
588 581 ar = view.execute("%whos", block=True)
589 582 # this will raise, if that failed
590 583 ar.get(5)
591 584 for stdout in ar.stdout:
592 585 lines = stdout.splitlines()
593 586 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
594 587 found = False
595 588 for line in lines[2:]:
596 589 split = line.split()
597 590 if split == ['a', 'int', '5']:
598 591 found = True
599 592 break
600 593 self.assertTrue(found, "whos output wrong: %s" % stdout)
601 594
602 595 def test_execute_displaypub(self):
603 596 """execute tracks display_pub output"""
604 597 view = self.client[:]
605 598 view.execute("from IPython.core.display import *")
606 599 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
607 600
608 601 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
609 602 for outputs in ar.outputs:
610 603 mimes = [ out['data'] for out in outputs ]
611 604 self.assertEqual(mimes, expected)
612 605
613 606 def test_apply_displaypub(self):
614 607 """apply tracks display_pub output"""
615 608 view = self.client[:]
616 609 view.execute("from IPython.core.display import *")
617 610
618 611 @interactive
619 612 def publish():
620 613 [ display(i) for i in range(5) ]
621 614
622 615 ar = view.apply_async(publish)
623 616 ar.get(5)
624 617 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
625 618 for outputs in ar.outputs:
626 619 mimes = [ out['data'] for out in outputs ]
627 620 self.assertEqual(mimes, expected)
628 621
629 622 def test_execute_raises(self):
630 623 """exceptions in execute requests raise appropriately"""
631 624 view = self.client[-1]
632 625 ar = view.execute("1/0")
633 626 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
634 627
635 628 def test_remoteerror_render_exception(self):
636 629 """RemoteErrors get nice tracebacks"""
637 630 view = self.client[-1]
638 631 ar = view.execute("1/0")
639 632 ip = get_ipython()
640 633 ip.user_ns['ar'] = ar
641 634 with capture_output() as io:
642 635 ip.run_cell("ar.get(2)")
643 636
644 637 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
645 638
646 639 def test_compositeerror_render_exception(self):
647 640 """CompositeErrors get nice tracebacks"""
648 641 view = self.client[:]
649 642 ar = view.execute("1/0")
650 643 ip = get_ipython()
651 644 ip.user_ns['ar'] = ar
652 645
653 646 with capture_output() as io:
654 647 ip.run_cell("ar.get(2)")
655 648
656 649 count = min(error.CompositeError.tb_limit, len(view))
657 650
658 651 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
659 652 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
660 653 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
661 654
662 655 def test_compositeerror_truncate(self):
663 656 """Truncate CompositeErrors with many exceptions"""
664 657 view = self.client[:]
665 658 msg_ids = []
666 659 for i in range(10):
667 660 ar = view.execute("1/0")
668 661 msg_ids.extend(ar.msg_ids)
669 662
670 663 ar = self.client.get_result(msg_ids)
671 664 try:
672 665 ar.get()
673 666 except error.CompositeError as _e:
674 667 e = _e
675 668 else:
676 669 self.fail("Should have raised CompositeError")
677 670
678 671 lines = e.render_traceback()
679 672 with capture_output() as io:
680 673 e.print_traceback()
681 674
682 675 self.assertTrue("more exceptions" in lines[-1])
683 676 count = e.tb_limit
684 677
685 678 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
686 679 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
687 680 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
688 681
689 682 @dec.skipif_not_matplotlib
690 683 def test_magic_pylab(self):
691 684 """%pylab works on engines"""
692 685 view = self.client[-1]
693 686 ar = view.execute("%pylab inline")
694 687 # at least check if this raised:
695 688 reply = ar.get(5)
696 689 # include imports, in case user config
697 690 ar = view.execute("plot(rand(100))", silent=False)
698 691 reply = ar.get(5)
699 692 self.assertEqual(len(reply.outputs), 1)
700 693 output = reply.outputs[0]
701 694 self.assertTrue("data" in output)
702 695 data = output['data']
703 696 self.assertTrue("image/png" in data)
704 697
705 698 def test_func_default_func(self):
706 699 """interactively defined function as apply func default"""
707 700 def foo():
708 701 return 'foo'
709 702
710 703 def bar(f=foo):
711 704 return f()
712 705
713 706 view = self.client[-1]
714 707 ar = view.apply_async(bar)
715 708 r = ar.get(10)
716 709 self.assertEqual(r, 'foo')
717 710 def test_data_pub_single(self):
718 711 view = self.client[-1]
719 712 ar = view.execute('\n'.join([
720 713 'from IPython.kernel.zmq.datapub import publish_data',
721 714 'for i in range(5):',
722 715 ' publish_data(dict(i=i))'
723 716 ]), block=False)
724 717 self.assertTrue(isinstance(ar.data, dict))
725 718 ar.get(5)
726 719 self.assertEqual(ar.data, dict(i=4))
727 720
728 721 def test_data_pub(self):
729 722 view = self.client[:]
730 723 ar = view.execute('\n'.join([
731 724 'from IPython.kernel.zmq.datapub import publish_data',
732 725 'for i in range(5):',
733 726 ' publish_data(dict(i=i))'
734 727 ]), block=False)
735 728 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
736 729 ar.get(5)
737 730 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
738 731
739 732 def test_can_list_arg(self):
740 733 """args in lists are canned"""
741 734 view = self.client[-1]
742 735 view['a'] = 128
743 736 rA = pmod.Reference('a')
744 737 ar = view.apply_async(lambda x: x, [rA])
745 738 r = ar.get(5)
746 739 self.assertEqual(r, [128])
747 740
748 741 def test_can_dict_arg(self):
749 742 """args in dicts are canned"""
750 743 view = self.client[-1]
751 744 view['a'] = 128
752 745 rA = pmod.Reference('a')
753 746 ar = view.apply_async(lambda x: x, dict(foo=rA))
754 747 r = ar.get(5)
755 748 self.assertEqual(r, dict(foo=128))
756 749
757 750 def test_can_list_kwarg(self):
758 751 """kwargs in lists are canned"""
759 752 view = self.client[-1]
760 753 view['a'] = 128
761 754 rA = pmod.Reference('a')
762 755 ar = view.apply_async(lambda x=5: x, x=[rA])
763 756 r = ar.get(5)
764 757 self.assertEqual(r, [128])
765 758
766 759 def test_can_dict_kwarg(self):
767 760 """kwargs in dicts are canned"""
768 761 view = self.client[-1]
769 762 view['a'] = 128
770 763 rA = pmod.Reference('a')
771 764 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
772 765 r = ar.get(5)
773 766 self.assertEqual(r, dict(foo=128))
774 767
775 768 def test_map_ref(self):
776 769 """view.map works with references"""
777 770 view = self.client[:]
778 771 ranks = sorted(self.client.ids)
779 772 view.scatter('rank', ranks, flatten=True)
780 773 rrank = pmod.Reference('rank')
781 774
782 775 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
783 776 drank = amr.get(5)
784 777 self.assertEqual(drank, [ r*2 for r in ranks ])
785 778
786 779 def test_nested_getitem_setitem(self):
787 780 """get and set with view['a.b']"""
788 781 view = self.client[-1]
789 782 view.execute('\n'.join([
790 783 'class A(object): pass',
791 784 'a = A()',
792 785 'a.b = 128',
793 786 ]), block=True)
794 787 ra = pmod.Reference('a')
795 788
796 789 r = view.apply_sync(lambda x: x.b, ra)
797 790 self.assertEqual(r, 128)
798 791 self.assertEqual(view['a.b'], 128)
799 792
800 793 view['a.b'] = 0
801 794
802 795 r = view.apply_sync(lambda x: x.b, ra)
803 796 self.assertEqual(r, 0)
804 797 self.assertEqual(view['a.b'], 0)
805 798
806 799 def test_return_namedtuple(self):
807 800 def namedtuplify(x, y):
808 801 from IPython.parallel.tests.test_view import point
809 802 return point(x, y)
810 803
811 804 view = self.client[-1]
812 805 p = view.apply_sync(namedtuplify, 1, 2)
813 806 self.assertEqual(p.x, 1)
814 807 self.assertEqual(p.y, 2)
815 808
816 809 def test_apply_namedtuple(self):
817 810 def echoxy(p):
818 811 return p.y, p.x
819 812
820 813 view = self.client[-1]
821 814 tup = view.apply_sync(echoxy, point(1, 2))
822 815 self.assertEqual(tup, (2,1))
823 816
824 817 def test_sync_imports(self):
825 818 view = self.client[-1]
826 819 with capture_output() as io:
827 820 with view.sync_imports():
828 821 import IPython
829 822 self.assertIn("IPython", io.stdout)
830 823
831 824 @interactive
832 825 def find_ipython():
833 826 return 'IPython' in globals()
834 827
835 828 assert view.apply_sync(find_ipython)
836 829
837 830 def test_sync_imports_quiet(self):
838 831 view = self.client[-1]
839 832 with capture_output() as io:
840 833 with view.sync_imports(quiet=True):
841 834 import IPython
842 835 self.assertEqual(io.stdout, '')
843 836
844 837 @interactive
845 838 def find_ipython():
846 839 return 'IPython' in globals()
847 840
848 841 assert view.apply_sync(find_ipython)
849 842
General Comments 0
You need to be logged in to leave comments. Login now