##// END OF EJS Templates
allow map / parallel function for single-engine views
MinRK -
Show More
@@ -1,241 +1,244 b''
1 1 """Remote Functions and decorators for Views.
2 2
3 3 Authors:
4 4
5 5 * Brian Granger
6 6 * Min RK
7 7 """
8 8 #-----------------------------------------------------------------------------
9 9 # Copyright (C) 2010-2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-----------------------------------------------------------------------------
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import sys
22 22 import warnings
23 23
24 24 from IPython.testing.skipdoctest import skip_doctest
25 25
26 26 from . import map as Map
27 27 from .asyncresult import AsyncMapResult
28 28
29 29 #-----------------------------------------------------------------------------
30 30 # Functions and Decorators
31 31 #-----------------------------------------------------------------------------
32 32
33 33 @skip_doctest
34 34 def remote(view, block=None, **flags):
35 35 """Turn a function into a remote function.
36 36
37 37 This method can be used for map:
38 38
39 39 In [1]: @remote(view,block=True)
40 40 ...: def func(a):
41 41 ...: pass
42 42 """
43 43
44 44 def remote_function(f):
45 45 return RemoteFunction(view, f, block=block, **flags)
46 46 return remote_function
47 47
48 48 @skip_doctest
49 49 def parallel(view, dist='b', block=None, ordered=True, **flags):
50 50 """Turn a function into a parallel remote function.
51 51
52 52 This method can be used for map:
53 53
54 54 In [1]: @parallel(view, block=True)
55 55 ...: def func(a):
56 56 ...: pass
57 57 """
58 58
59 59 def parallel_function(f):
60 60 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
61 61 return parallel_function
62 62
63 63 def getname(f):
64 64 """Get the name of an object.
65 65
66 66 For use in case of callables that are not functions, and
67 67 thus may not have __name__ defined.
68 68
69 69 Order: f.__name__ > f.name > str(f)
70 70 """
71 71 try:
72 72 return f.__name__
73 73 except:
74 74 pass
75 75 try:
76 76 return f.name
77 77 except:
78 78 pass
79 79
80 80 return str(f)
81 81
82 82 #--------------------------------------------------------------------------
83 83 # Classes
84 84 #--------------------------------------------------------------------------
85 85
86 86 class RemoteFunction(object):
87 87 """Turn an existing function into a remote function.
88 88
89 89 Parameters
90 90 ----------
91 91
92 92 view : View instance
93 93 The view to be used for execution
94 94 f : callable
95 95 The function to be wrapped into a remote function
96 96 block : bool [default: None]
97 97 Whether to wait for results or not. The default behavior is
98 98 to use the current `block` attribute of `view`
99 99
100 100 **flags : remaining kwargs are passed to View.temp_flags
101 101 """
102 102
103 103 view = None # the remote connection
104 104 func = None # the wrapped function
105 105 block = None # whether to block
106 106 flags = None # dict of extra kwargs for temp_flags
107 107
108 108 def __init__(self, view, f, block=None, **flags):
109 109 self.view = view
110 110 self.func = f
111 111 self.block=block
112 112 self.flags=flags
113 113
114 114 def __call__(self, *args, **kwargs):
115 115 block = self.view.block if self.block is None else self.block
116 116 with self.view.temp_flags(block=block, **self.flags):
117 117 return self.view.apply(self.func, *args, **kwargs)
118 118
119 119
120 120 class ParallelFunction(RemoteFunction):
121 121 """Class for mapping a function to sequences.
122 122
123 123 This will distribute the sequences according the a mapper, and call
124 124 the function on each sub-sequence. If called via map, then the function
125 125 will be called once on each element, rather that each sub-sequence.
126 126
127 127 Parameters
128 128 ----------
129 129
130 130 view : View instance
131 131 The view to be used for execution
132 132 f : callable
133 133 The function to be wrapped into a remote function
134 134 dist : str [default: 'b']
135 135 The key for which mapObject to use to distribute sequences
136 136 options are:
137 137 * 'b' : use contiguous chunks in order
138 138 * 'r' : use round-robin striping
139 139 block : bool [default: None]
140 140 Whether to wait for results or not. The default behavior is
141 141 to use the current `block` attribute of `view`
142 142 chunksize : int or None
143 143 The size of chunk to use when breaking up sequences in a load-balanced manner
144 144 ordered : bool [default: True]
145 145 Whether
146 146 **flags : remaining kwargs are passed to View.temp_flags
147 147 """
148 148
149 149 chunksize=None
150 150 ordered=None
151 151 mapObject=None
152 152
153 153 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
154 154 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
155 155 self.chunksize = chunksize
156 156 self.ordered = ordered
157 157
158 158 mapClass = Map.dists[dist]
159 159 self.mapObject = mapClass()
160 160
161 161 def __call__(self, *sequences):
162 162 client = self.view.client
163 163
164 164 # check that the length of sequences match
165 165 len_0 = len(sequences[0])
166 166 for s in sequences:
167 167 if len(s)!=len_0:
168 168 msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s))
169 169 raise ValueError(msg)
170 170 balanced = 'Balanced' in self.view.__class__.__name__
171 171 if balanced:
172 172 if self.chunksize:
173 173 nparts = len_0//self.chunksize + int(len_0%self.chunksize > 0)
174 174 else:
175 175 nparts = len_0
176 176 targets = [None]*nparts
177 177 else:
178 178 if self.chunksize:
179 179 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
180 180 # multiplexed:
181 181 targets = self.view.targets
182 182 # 'all' is lazily evaluated at execution time, which is now:
183 183 if targets == 'all':
184 184 targets = client._build_targets(targets)[1]
185 elif isinstance(targets, int):
186 # single-engine view, targets must be iterable
187 targets = [targets]
185 188 nparts = len(targets)
186 189
187 190 msg_ids = []
188 191 for index, t in enumerate(targets):
189 192 args = []
190 193 for seq in sequences:
191 194 part = self.mapObject.getPartition(seq, index, nparts)
192 195 if len(part) == 0:
193 196 continue
194 197 else:
195 198 args.append(part)
196 199 if not args:
197 200 continue
198 201
199 202 # print (args)
200 203 if hasattr(self, '_map'):
201 204 if sys.version_info[0] >= 3:
202 205 f = lambda f, *sequences: list(map(f, *sequences))
203 206 else:
204 207 f = map
205 208 args = [self.func]+args
206 209 else:
207 210 f=self.func
208 211
209 212 view = self.view if balanced else client[t]
210 213 with view.temp_flags(block=False, **self.flags):
211 214 ar = view.apply(f, *args)
212 215
213 216 msg_ids.append(ar.msg_ids[0])
214 217
215 218 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
216 219 fname=getname(self.func),
217 220 ordered=self.ordered
218 221 )
219 222
220 223 if self.block:
221 224 try:
222 225 return r.get()
223 226 except KeyboardInterrupt:
224 227 return r
225 228 else:
226 229 return r
227 230
228 231 def map(self, *sequences):
229 232 """call a function on each element of a sequence remotely.
230 233 This should behave very much like the builtin map, but return an AsyncMapResult
231 234 if self.block is False.
232 235 """
233 236 # set _map as a flag for use inside self.__call__
234 237 self._map = True
235 238 try:
236 239 ret = self.__call__(*sequences)
237 240 finally:
238 241 del self._map
239 242 return ret
240 243
241 244 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,537 +1,543 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import time
21 21 from tempfile import mktemp
22 22 from StringIO import StringIO
23 23
24 24 import zmq
25 25 from nose import SkipTest
26 26
27 27 from IPython.testing import decorators as dec
28 28
29 29 from IPython import parallel as pmod
30 30 from IPython.parallel import error
31 31 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
32 32 from IPython.parallel import DirectView
33 33 from IPython.parallel.util import interactive
34 34
35 35 from IPython.parallel.tests import add_engines
36 36
37 37 from .clienttest import ClusterTestCase, crash, wait, skip_without
38 38
39 39 def setup():
40 40 add_engines(3, total=True)
41 41
42 42 class TestView(ClusterTestCase):
43 43
44 44 def test_z_crash_mux(self):
45 45 """test graceful handling of engine death (direct)"""
46 46 raise SkipTest("crash tests disabled, due to undesirable crash reports")
47 47 # self.add_engines(1)
48 48 eid = self.client.ids[-1]
49 49 ar = self.client[eid].apply_async(crash)
50 50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
51 51 eid = ar.engine_id
52 52 tic = time.time()
53 53 while eid in self.client.ids and time.time()-tic < 5:
54 54 time.sleep(.01)
55 55 self.client.spin()
56 56 self.assertFalse(eid in self.client.ids, "Engine should have died")
57 57
58 58 def test_push_pull(self):
59 59 """test pushing and pulling"""
60 60 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
61 61 t = self.client.ids[-1]
62 62 v = self.client[t]
63 63 push = v.push
64 64 pull = v.pull
65 65 v.block=True
66 66 nengines = len(self.client)
67 67 push({'data':data})
68 68 d = pull('data')
69 69 self.assertEquals(d, data)
70 70 self.client[:].push({'data':data})
71 71 d = self.client[:].pull('data', block=True)
72 72 self.assertEquals(d, nengines*[data])
73 73 ar = push({'data':data}, block=False)
74 74 self.assertTrue(isinstance(ar, AsyncResult))
75 75 r = ar.get()
76 76 ar = self.client[:].pull('data', block=False)
77 77 self.assertTrue(isinstance(ar, AsyncResult))
78 78 r = ar.get()
79 79 self.assertEquals(r, nengines*[data])
80 80 self.client[:].push(dict(a=10,b=20))
81 81 r = self.client[:].pull(('a','b'), block=True)
82 82 self.assertEquals(r, nengines*[[10,20]])
83 83
84 84 def test_push_pull_function(self):
85 85 "test pushing and pulling functions"
86 86 def testf(x):
87 87 return 2.0*x
88 88
89 89 t = self.client.ids[-1]
90 90 v = self.client[t]
91 91 v.block=True
92 92 push = v.push
93 93 pull = v.pull
94 94 execute = v.execute
95 95 push({'testf':testf})
96 96 r = pull('testf')
97 97 self.assertEqual(r(1.0), testf(1.0))
98 98 execute('r = testf(10)')
99 99 r = pull('r')
100 100 self.assertEquals(r, testf(10))
101 101 ar = self.client[:].push({'testf':testf}, block=False)
102 102 ar.get()
103 103 ar = self.client[:].pull('testf', block=False)
104 104 rlist = ar.get()
105 105 for r in rlist:
106 106 self.assertEqual(r(1.0), testf(1.0))
107 107 execute("def g(x): return x*x")
108 108 r = pull(('testf','g'))
109 109 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
110 110
111 111 def test_push_function_globals(self):
112 112 """test that pushed functions have access to globals"""
113 113 @interactive
114 114 def geta():
115 115 return a
116 116 # self.add_engines(1)
117 117 v = self.client[-1]
118 118 v.block=True
119 119 v['f'] = geta
120 120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
121 121 v.execute('a=5')
122 122 v.execute('b=f()')
123 123 self.assertEquals(v['b'], 5)
124 124
125 125 def test_push_function_defaults(self):
126 126 """test that pushed functions preserve default args"""
127 127 def echo(a=10):
128 128 return a
129 129 v = self.client[-1]
130 130 v.block=True
131 131 v['f'] = echo
132 132 v.execute('b=f()')
133 133 self.assertEquals(v['b'], 10)
134 134
135 135 def test_get_result(self):
136 136 """test getting results from the Hub."""
137 137 c = pmod.Client(profile='iptest')
138 138 # self.add_engines(1)
139 139 t = c.ids[-1]
140 140 v = c[t]
141 141 v2 = self.client[t]
142 142 ar = v.apply_async(wait, 1)
143 143 # give the monitor time to notice the message
144 144 time.sleep(.25)
145 145 ahr = v2.get_result(ar.msg_ids)
146 146 self.assertTrue(isinstance(ahr, AsyncHubResult))
147 147 self.assertEquals(ahr.get(), ar.get())
148 148 ar2 = v2.get_result(ar.msg_ids)
149 149 self.assertFalse(isinstance(ar2, AsyncHubResult))
150 150 c.spin()
151 151 c.close()
152 152
153 153 def test_run_newline(self):
154 154 """test that run appends newline to files"""
155 155 tmpfile = mktemp()
156 156 with open(tmpfile, 'w') as f:
157 157 f.write("""def g():
158 158 return 5
159 159 """)
160 160 v = self.client[-1]
161 161 v.run(tmpfile, block=True)
162 162 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
163 163
164 164 def test_apply_tracked(self):
165 165 """test tracking for apply"""
166 166 # self.add_engines(1)
167 167 t = self.client.ids[-1]
168 168 v = self.client[t]
169 169 v.block=False
170 170 def echo(n=1024*1024, **kwargs):
171 171 with v.temp_flags(**kwargs):
172 172 return v.apply(lambda x: x, 'x'*n)
173 173 ar = echo(1, track=False)
174 174 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
175 175 self.assertTrue(ar.sent)
176 176 ar = echo(track=True)
177 177 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
178 178 self.assertEquals(ar.sent, ar._tracker.done)
179 179 ar._tracker.wait()
180 180 self.assertTrue(ar.sent)
181 181
182 182 def test_push_tracked(self):
183 183 t = self.client.ids[-1]
184 184 ns = dict(x='x'*1024*1024)
185 185 v = self.client[t]
186 186 ar = v.push(ns, block=False, track=False)
187 187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 188 self.assertTrue(ar.sent)
189 189
190 190 ar = v.push(ns, block=False, track=True)
191 191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 192 ar._tracker.wait()
193 193 self.assertEquals(ar.sent, ar._tracker.done)
194 194 self.assertTrue(ar.sent)
195 195 ar.get()
196 196
197 197 def test_scatter_tracked(self):
198 198 t = self.client.ids
199 199 x='x'*1024*1024
200 200 ar = self.client[t].scatter('x', x, block=False, track=False)
201 201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 202 self.assertTrue(ar.sent)
203 203
204 204 ar = self.client[t].scatter('x', x, block=False, track=True)
205 205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 206 self.assertEquals(ar.sent, ar._tracker.done)
207 207 ar._tracker.wait()
208 208 self.assertTrue(ar.sent)
209 209 ar.get()
210 210
211 211 def test_remote_reference(self):
212 212 v = self.client[-1]
213 213 v['a'] = 123
214 214 ra = pmod.Reference('a')
215 215 b = v.apply_sync(lambda x: x, ra)
216 216 self.assertEquals(b, 123)
217 217
218 218
219 219 def test_scatter_gather(self):
220 220 view = self.client[:]
221 221 seq1 = range(16)
222 222 view.scatter('a', seq1)
223 223 seq2 = view.gather('a', block=True)
224 224 self.assertEquals(seq2, seq1)
225 225 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
226 226
227 227 @skip_without('numpy')
228 228 def test_scatter_gather_numpy(self):
229 229 import numpy
230 230 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
231 231 view = self.client[:]
232 232 a = numpy.arange(64)
233 233 view.scatter('a', a)
234 234 b = view.gather('a', block=True)
235 235 assert_array_equal(b, a)
236 236
237 237 @skip_without('numpy')
238 238 def test_push_numpy_nocopy(self):
239 239 import numpy
240 240 view = self.client[:]
241 241 a = numpy.arange(64)
242 242 view['A'] = a
243 243 @interactive
244 244 def check_writeable(x):
245 245 return x.flags.writeable
246 246
247 247 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
248 248 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
249 249
250 250 view.push(dict(B=a))
251 251 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
252 252 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
253 253
254 254 @skip_without('numpy')
255 255 def test_apply_numpy(self):
256 256 """view.apply(f, ndarray)"""
257 257 import numpy
258 258 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
259 259
260 260 A = numpy.random.random((100,100))
261 261 view = self.client[-1]
262 262 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
263 263 B = A.astype(dt)
264 264 C = view.apply_sync(lambda x:x, B)
265 265 assert_array_equal(B,C)
266 266
267 267 def test_map(self):
268 268 view = self.client[:]
269 269 def f(x):
270 270 return x**2
271 271 data = range(16)
272 272 r = view.map_sync(f, data)
273 273 self.assertEquals(r, map(f, data))
274 274
275 275 def test_map_iterable(self):
276 276 """test map on iterables (direct)"""
277 277 view = self.client[:]
278 278 # 101 is prime, so it won't be evenly distributed
279 279 arr = range(101)
280 280 # ensure it will be an iterator, even in Python 3
281 281 it = iter(arr)
282 282 r = view.map_sync(lambda x:x, arr)
283 283 self.assertEquals(r, list(arr))
284 284
285 285 def test_scatterGatherNonblocking(self):
286 286 data = range(16)
287 287 view = self.client[:]
288 288 view.scatter('a', data, block=False)
289 289 ar = view.gather('a', block=False)
290 290 self.assertEquals(ar.get(), data)
291 291
292 292 @skip_without('numpy')
293 293 def test_scatter_gather_numpy_nonblocking(self):
294 294 import numpy
295 295 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
296 296 a = numpy.arange(64)
297 297 view = self.client[:]
298 298 ar = view.scatter('a', a, block=False)
299 299 self.assertTrue(isinstance(ar, AsyncResult))
300 300 amr = view.gather('a', block=False)
301 301 self.assertTrue(isinstance(amr, AsyncMapResult))
302 302 assert_array_equal(amr.get(), a)
303 303
304 304 def test_execute(self):
305 305 view = self.client[:]
306 306 # self.client.debug=True
307 307 execute = view.execute
308 308 ar = execute('c=30', block=False)
309 309 self.assertTrue(isinstance(ar, AsyncResult))
310 310 ar = execute('d=[0,1,2]', block=False)
311 311 self.client.wait(ar, 1)
312 312 self.assertEquals(len(ar.get()), len(self.client))
313 313 for c in view['c']:
314 314 self.assertEquals(c, 30)
315 315
316 316 def test_abort(self):
317 317 view = self.client[-1]
318 318 ar = view.execute('import time; time.sleep(1)', block=False)
319 319 ar2 = view.apply_async(lambda : 2)
320 320 ar3 = view.apply_async(lambda : 3)
321 321 view.abort(ar2)
322 322 view.abort(ar3.msg_ids)
323 323 self.assertRaises(error.TaskAborted, ar2.get)
324 324 self.assertRaises(error.TaskAborted, ar3.get)
325 325
326 326 def test_abort_all(self):
327 327 """view.abort() aborts all outstanding tasks"""
328 328 view = self.client[-1]
329 329 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
330 330 view.abort()
331 331 view.wait(timeout=5)
332 332 for ar in ars[5:]:
333 333 self.assertRaises(error.TaskAborted, ar.get)
334 334
335 335 def test_temp_flags(self):
336 336 view = self.client[-1]
337 337 view.block=True
338 338 with view.temp_flags(block=False):
339 339 self.assertFalse(view.block)
340 340 self.assertTrue(view.block)
341 341
342 342 @dec.known_failure_py3
343 343 def test_importer(self):
344 344 view = self.client[-1]
345 345 view.clear(block=True)
346 346 with view.importer:
347 347 import re
348 348
349 349 @interactive
350 350 def findall(pat, s):
351 351 # this globals() step isn't necessary in real code
352 352 # only to prevent a closure in the test
353 353 re = globals()['re']
354 354 return re.findall(pat, s)
355 355
356 356 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
357 357
358 358 # parallel magic tests
359 359
360 360 def test_magic_px_blocking(self):
361 361 ip = get_ipython()
362 362 v = self.client[-1]
363 363 v.activate()
364 364 v.block=True
365 365
366 366 ip.magic_px('a=5')
367 367 self.assertEquals(v['a'], 5)
368 368 ip.magic_px('a=10')
369 369 self.assertEquals(v['a'], 10)
370 370 sio = StringIO()
371 371 savestdout = sys.stdout
372 372 sys.stdout = sio
373 373 # just 'print a' worst ~99% of the time, but this ensures that
374 374 # the stdout message has arrived when the result is finished:
375 375 ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
376 376 sys.stdout = savestdout
377 377 buf = sio.getvalue()
378 378 self.assertTrue('[stdout:' in buf, buf)
379 379 self.assertTrue(buf.rstrip().endswith('10'))
380 380 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
381 381
382 382 def test_magic_px_nonblocking(self):
383 383 ip = get_ipython()
384 384 v = self.client[-1]
385 385 v.activate()
386 386 v.block=False
387 387
388 388 ip.magic_px('a=5')
389 389 self.assertEquals(v['a'], 5)
390 390 ip.magic_px('a=10')
391 391 self.assertEquals(v['a'], 10)
392 392 sio = StringIO()
393 393 savestdout = sys.stdout
394 394 sys.stdout = sio
395 395 ip.magic_px('print a')
396 396 sys.stdout = savestdout
397 397 buf = sio.getvalue()
398 398 self.assertFalse('[stdout:%i]'%v.targets in buf)
399 399 ip.magic_px('1/0')
400 400 ar = v.get_result(-1)
401 401 self.assertRaisesRemote(ZeroDivisionError, ar.get)
402 402
403 403 def test_magic_autopx_blocking(self):
404 404 ip = get_ipython()
405 405 v = self.client[-1]
406 406 v.activate()
407 407 v.block=True
408 408
409 409 sio = StringIO()
410 410 savestdout = sys.stdout
411 411 sys.stdout = sio
412 412 ip.magic_autopx()
413 413 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
414 414 ip.run_cell('print b')
415 415 ip.run_cell("b/c")
416 416 ip.run_code(compile('b*=2', '', 'single'))
417 417 ip.magic_autopx()
418 418 sys.stdout = savestdout
419 419 output = sio.getvalue().strip()
420 420 self.assertTrue(output.startswith('%autopx enabled'))
421 421 self.assertTrue(output.endswith('%autopx disabled'))
422 422 self.assertTrue('RemoteError: ZeroDivisionError' in output)
423 423 ar = v.get_result(-2)
424 424 self.assertEquals(v['a'], 5)
425 425 self.assertEquals(v['b'], 20)
426 426 self.assertRaisesRemote(ZeroDivisionError, ar.get)
427 427
428 428 def test_magic_autopx_nonblocking(self):
429 429 ip = get_ipython()
430 430 v = self.client[-1]
431 431 v.activate()
432 432 v.block=False
433 433
434 434 sio = StringIO()
435 435 savestdout = sys.stdout
436 436 sys.stdout = sio
437 437 ip.magic_autopx()
438 438 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
439 439 ip.run_cell('print b')
440 440 ip.run_cell("b/c")
441 441 ip.run_code(compile('b*=2', '', 'single'))
442 442 ip.magic_autopx()
443 443 sys.stdout = savestdout
444 444 output = sio.getvalue().strip()
445 445 self.assertTrue(output.startswith('%autopx enabled'))
446 446 self.assertTrue(output.endswith('%autopx disabled'))
447 447 self.assertFalse('ZeroDivisionError' in output)
448 448 ar = v.get_result(-2)
449 449 self.assertEquals(v['a'], 5)
450 450 self.assertEquals(v['b'], 20)
451 451 self.assertRaisesRemote(ZeroDivisionError, ar.get)
452 452
453 453 def test_magic_result(self):
454 454 ip = get_ipython()
455 455 v = self.client[-1]
456 456 v.activate()
457 457 v['a'] = 111
458 458 ra = v['a']
459 459
460 460 ar = ip.magic_result()
461 461 self.assertEquals(ar.msg_ids, [v.history[-1]])
462 462 self.assertEquals(ar.get(), 111)
463 463 ar = ip.magic_result('-2')
464 464 self.assertEquals(ar.msg_ids, [v.history[-2]])
465 465
466 466 def test_unicode_execute(self):
467 467 """test executing unicode strings"""
468 468 v = self.client[-1]
469 469 v.block=True
470 470 if sys.version_info[0] >= 3:
471 471 code="a='é'"
472 472 else:
473 473 code=u"a=u'é'"
474 474 v.execute(code)
475 475 self.assertEquals(v['a'], u'é')
476 476
477 477 def test_unicode_apply_result(self):
478 478 """test unicode apply results"""
479 479 v = self.client[-1]
480 480 r = v.apply_sync(lambda : u'é')
481 481 self.assertEquals(r, u'é')
482 482
483 483 def test_unicode_apply_arg(self):
484 484 """test passing unicode arguments to apply"""
485 485 v = self.client[-1]
486 486
487 487 @interactive
488 488 def check_unicode(a, check):
489 489 assert isinstance(a, unicode), "%r is not unicode"%a
490 490 assert isinstance(check, bytes), "%r is not bytes"%check
491 491 assert a.encode('utf8') == check, "%s != %s"%(a,check)
492 492
493 493 for s in [ u'é', u'ßø®∫',u'asdf' ]:
494 494 try:
495 495 v.apply_sync(check_unicode, s, s.encode('utf8'))
496 496 except error.RemoteError as e:
497 497 if e.ename == 'AssertionError':
498 498 self.fail(e.evalue)
499 499 else:
500 500 raise e
501 501
502 502 def test_map_reference(self):
503 503 """view.map(<Reference>, *seqs) should work"""
504 504 v = self.client[:]
505 505 v.scatter('n', self.client.ids, flatten=True)
506 506 v.execute("f = lambda x,y: x*y")
507 507 rf = pmod.Reference('f')
508 508 nlist = list(range(10))
509 509 mlist = nlist[::-1]
510 510 expected = [ m*n for m,n in zip(mlist, nlist) ]
511 511 result = v.map_sync(rf, mlist, nlist)
512 512 self.assertEquals(result, expected)
513 513
514 514 def test_apply_reference(self):
515 515 """view.apply(<Reference>, *args) should work"""
516 516 v = self.client[:]
517 517 v.scatter('n', self.client.ids, flatten=True)
518 518 v.execute("f = lambda x: n*x")
519 519 rf = pmod.Reference('f')
520 520 result = v.apply_sync(rf, 5)
521 521 expected = [ 5*id for id in self.client.ids ]
522 522 self.assertEquals(result, expected)
523 523
524 524 def test_eval_reference(self):
525 525 v = self.client[self.client.ids[0]]
526 526 v['g'] = range(5)
527 527 rg = pmod.Reference('g[0]')
528 528 echo = lambda x:x
529 529 self.assertEquals(v.apply_sync(echo, rg), 0)
530 530
531 531 def test_reference_nameerror(self):
532 532 v = self.client[self.client.ids[0]]
533 533 r = pmod.Reference('elvis_has_left')
534 534 echo = lambda x:x
535 535 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
536 536
537 def test_single_engine_map(self):
538 e0 = self.client[self.client.ids[0]]
539 r = range(5)
540 check = [ -1*i for i in r ]
541 result = e0.map_sync(lambda x: -1*x, r)
542 self.assertEquals(result, check)
537 543
General Comments 0
You need to be logged in to leave comments. Login now