##// END OF EJS Templates
Merge pull request #2305 from minrk/render_traceback...
Min RK -
r8296:58335ad7 merge
parent child Browse files
Show More
@@ -1,335 +1,341 b''
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """Classes and functions for kernel related errors and exceptions.
3 """Classes and functions for kernel related errors and exceptions.
4
4
5 Authors:
5 Authors:
6
6
7 * Brian Granger
7 * Brian Granger
8 * Min RK
8 * Min RK
9 """
9 """
10 from __future__ import print_function
10 from __future__ import print_function
11
11
12 import sys
12 import sys
13 import traceback
13 import traceback
14
14
15 __docformat__ = "restructuredtext en"
15 __docformat__ = "restructuredtext en"
16
16
17 # Tell nose to skip this module
17 # Tell nose to skip this module
18 __test__ = {}
18 __test__ = {}
19
19
20 #-------------------------------------------------------------------------------
20 #-------------------------------------------------------------------------------
21 # Copyright (C) 2008-2011 The IPython Development Team
21 # Copyright (C) 2008-2011 The IPython Development Team
22 #
22 #
23 # Distributed under the terms of the BSD License. The full license is in
23 # Distributed under the terms of the BSD License. The full license is in
24 # the file COPYING, distributed as part of this software.
24 # the file COPYING, distributed as part of this software.
25 #-------------------------------------------------------------------------------
25 #-------------------------------------------------------------------------------
26
26
27 #-------------------------------------------------------------------------------
27 #-------------------------------------------------------------------------------
28 # Error classes
28 # Error classes
29 #-------------------------------------------------------------------------------
29 #-------------------------------------------------------------------------------
30 class IPythonError(Exception):
30 class IPythonError(Exception):
31 """Base exception that all of our exceptions inherit from.
31 """Base exception that all of our exceptions inherit from.
32
32
33 This can be raised by code that doesn't have any more specific
33 This can be raised by code that doesn't have any more specific
34 information."""
34 information."""
35
35
36 pass
36 pass
37
37
38 # Exceptions associated with the controller objects
38 # Exceptions associated with the controller objects
39 class ControllerError(IPythonError): pass
39 class ControllerError(IPythonError): pass
40
40
41 class ControllerCreationError(ControllerError): pass
41 class ControllerCreationError(ControllerError): pass
42
42
43
43
44 # Exceptions associated with the Engines
44 # Exceptions associated with the Engines
45 class EngineError(IPythonError): pass
45 class EngineError(IPythonError): pass
46
46
47 class EngineCreationError(EngineError): pass
47 class EngineCreationError(EngineError): pass
48
48
49 class KernelError(IPythonError):
49 class KernelError(IPythonError):
50 pass
50 pass
51
51
52 class NotDefined(KernelError):
52 class NotDefined(KernelError):
53 def __init__(self, name):
53 def __init__(self, name):
54 self.name = name
54 self.name = name
55 self.args = (name,)
55 self.args = (name,)
56
56
57 def __repr__(self):
57 def __repr__(self):
58 return '<NotDefined: %s>' % self.name
58 return '<NotDefined: %s>' % self.name
59
59
60 __str__ = __repr__
60 __str__ = __repr__
61
61
62
62
63 class QueueCleared(KernelError):
63 class QueueCleared(KernelError):
64 pass
64 pass
65
65
66
66
67 class IdInUse(KernelError):
67 class IdInUse(KernelError):
68 pass
68 pass
69
69
70
70
71 class ProtocolError(KernelError):
71 class ProtocolError(KernelError):
72 pass
72 pass
73
73
74
74
75 class ConnectionError(KernelError):
75 class ConnectionError(KernelError):
76 pass
76 pass
77
77
78
78
79 class InvalidEngineID(KernelError):
79 class InvalidEngineID(KernelError):
80 pass
80 pass
81
81
82
82
83 class NoEnginesRegistered(KernelError):
83 class NoEnginesRegistered(KernelError):
84 pass
84 pass
85
85
86
86
87 class InvalidClientID(KernelError):
87 class InvalidClientID(KernelError):
88 pass
88 pass
89
89
90
90
91 class InvalidDeferredID(KernelError):
91 class InvalidDeferredID(KernelError):
92 pass
92 pass
93
93
94
94
95 class SerializationError(KernelError):
95 class SerializationError(KernelError):
96 pass
96 pass
97
97
98
98
99 class MessageSizeError(KernelError):
99 class MessageSizeError(KernelError):
100 pass
100 pass
101
101
102
102
103 class PBMessageSizeError(MessageSizeError):
103 class PBMessageSizeError(MessageSizeError):
104 pass
104 pass
105
105
106
106
107 class ResultNotCompleted(KernelError):
107 class ResultNotCompleted(KernelError):
108 pass
108 pass
109
109
110
110
111 class ResultAlreadyRetrieved(KernelError):
111 class ResultAlreadyRetrieved(KernelError):
112 pass
112 pass
113
113
114 class ClientError(KernelError):
114 class ClientError(KernelError):
115 pass
115 pass
116
116
117
117
118 class TaskAborted(KernelError):
118 class TaskAborted(KernelError):
119 pass
119 pass
120
120
121
121
122 class TaskTimeout(KernelError):
122 class TaskTimeout(KernelError):
123 pass
123 pass
124
124
125
125
126 class NotAPendingResult(KernelError):
126 class NotAPendingResult(KernelError):
127 pass
127 pass
128
128
129
129
130 class UnpickleableException(KernelError):
130 class UnpickleableException(KernelError):
131 pass
131 pass
132
132
133
133
134 class AbortedPendingDeferredError(KernelError):
134 class AbortedPendingDeferredError(KernelError):
135 pass
135 pass
136
136
137
137
138 class InvalidProperty(KernelError):
138 class InvalidProperty(KernelError):
139 pass
139 pass
140
140
141
141
142 class MissingBlockArgument(KernelError):
142 class MissingBlockArgument(KernelError):
143 pass
143 pass
144
144
145
145
146 class StopLocalExecution(KernelError):
146 class StopLocalExecution(KernelError):
147 pass
147 pass
148
148
149
149
150 class SecurityError(KernelError):
150 class SecurityError(KernelError):
151 pass
151 pass
152
152
153
153
154 class FileTimeoutError(KernelError):
154 class FileTimeoutError(KernelError):
155 pass
155 pass
156
156
157 class TimeoutError(KernelError):
157 class TimeoutError(KernelError):
158 pass
158 pass
159
159
160 class UnmetDependency(KernelError):
160 class UnmetDependency(KernelError):
161 pass
161 pass
162
162
163 class ImpossibleDependency(UnmetDependency):
163 class ImpossibleDependency(UnmetDependency):
164 pass
164 pass
165
165
166 class DependencyTimeout(ImpossibleDependency):
166 class DependencyTimeout(ImpossibleDependency):
167 pass
167 pass
168
168
169 class InvalidDependency(ImpossibleDependency):
169 class InvalidDependency(ImpossibleDependency):
170 pass
170 pass
171
171
172 class RemoteError(KernelError):
172 class RemoteError(KernelError):
173 """Error raised elsewhere"""
173 """Error raised elsewhere"""
174 ename=None
174 ename=None
175 evalue=None
175 evalue=None
176 traceback=None
176 traceback=None
177 engine_info=None
177 engine_info=None
178
178
179 def __init__(self, ename, evalue, traceback, engine_info=None):
179 def __init__(self, ename, evalue, traceback, engine_info=None):
180 self.ename=ename
180 self.ename=ename
181 self.evalue=evalue
181 self.evalue=evalue
182 self.traceback=traceback
182 self.traceback=traceback
183 self.engine_info=engine_info or {}
183 self.engine_info=engine_info or {}
184 self.args=(ename, evalue)
184 self.args=(ename, evalue)
185
185
186 def __repr__(self):
186 def __repr__(self):
187 engineid = self.engine_info.get('engine_id', ' ')
187 engineid = self.engine_info.get('engine_id', ' ')
188 return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)
188 return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)
189
189
190 def __str__(self):
190 def __str__(self):
191 return "%s(%s)" % (self.ename, self.evalue)
191 return "%s(%s)" % (self.ename, self.evalue)
192
192
193 def render_traceback(self):
193 def render_traceback(self):
194 """render traceback to a list of lines"""
194 """render traceback to a list of lines"""
195 return (self.traceback or "No traceback available").splitlines()
195 return (self.traceback or "No traceback available").splitlines()
196
196
197 # Special method for custom tracebacks within IPython
197 def _render_traceback_(self):
198 _render_traceback_ = render_traceback
198 """Special method for custom tracebacks within IPython.
199
199
200 This will be called by IPython instead of displaying the local traceback.
201
202 It should return a traceback rendered as a list of lines.
203 """
204 return self.render_traceback()
205
200 def print_traceback(self, excid=None):
206 def print_traceback(self, excid=None):
201 """print my traceback"""
207 """print my traceback"""
202 print('\n'.join(self.render_traceback()))
208 print('\n'.join(self.render_traceback()))
203
209
204
210
205
211
206
212
207 class TaskRejectError(KernelError):
213 class TaskRejectError(KernelError):
208 """Exception to raise when a task should be rejected by an engine.
214 """Exception to raise when a task should be rejected by an engine.
209
215
210 This exception can be used to allow a task running on an engine to test
216 This exception can be used to allow a task running on an engine to test
211 if the engine (or the user's namespace on the engine) has the needed
217 if the engine (or the user's namespace on the engine) has the needed
212 task dependencies. If not, the task should raise this exception. For
218 task dependencies. If not, the task should raise this exception. For
213 the task to be retried on another engine, the task should be created
219 the task to be retried on another engine, the task should be created
214 with the `retries` argument > 1.
220 with the `retries` argument > 1.
215
221
216 The advantage of this approach over our older properties system is that
222 The advantage of this approach over our older properties system is that
217 tasks have full access to the user's namespace on the engines and the
223 tasks have full access to the user's namespace on the engines and the
218 properties don't have to be managed or tested by the controller.
224 properties don't have to be managed or tested by the controller.
219 """
225 """
220
226
221
227
222 class CompositeError(RemoteError):
228 class CompositeError(RemoteError):
223 """Error for representing possibly multiple errors on engines"""
229 """Error for representing possibly multiple errors on engines"""
224 def __init__(self, message, elist):
230 def __init__(self, message, elist):
225 Exception.__init__(self, *(message, elist))
231 Exception.__init__(self, *(message, elist))
226 # Don't use pack_exception because it will conflict with the .message
232 # Don't use pack_exception because it will conflict with the .message
227 # attribute that is being deprecated in 2.6 and beyond.
233 # attribute that is being deprecated in 2.6 and beyond.
228 self.msg = message
234 self.msg = message
229 self.elist = elist
235 self.elist = elist
230 self.args = [ e[0] for e in elist ]
236 self.args = [ e[0] for e in elist ]
231
237
232 def _get_engine_str(self, ei):
238 def _get_engine_str(self, ei):
233 if not ei:
239 if not ei:
234 return '[Engine Exception]'
240 return '[Engine Exception]'
235 else:
241 else:
236 return '[%s:%s]: ' % (ei['engine_id'], ei['method'])
242 return '[%s:%s]: ' % (ei['engine_id'], ei['method'])
237
243
238 def _get_traceback(self, ev):
244 def _get_traceback(self, ev):
239 try:
245 try:
240 tb = ev._ipython_traceback_text
246 tb = ev._ipython_traceback_text
241 except AttributeError:
247 except AttributeError:
242 return 'No traceback available'
248 return 'No traceback available'
243 else:
249 else:
244 return tb
250 return tb
245
251
246 def __str__(self):
252 def __str__(self):
247 s = str(self.msg)
253 s = str(self.msg)
248 for en, ev, etb, ei in self.elist:
254 for en, ev, etb, ei in self.elist:
249 engine_str = self._get_engine_str(ei)
255 engine_str = self._get_engine_str(ei)
250 s = s + '\n' + engine_str + en + ': ' + str(ev)
256 s = s + '\n' + engine_str + en + ': ' + str(ev)
251 return s
257 return s
252
258
253 def __repr__(self):
259 def __repr__(self):
254 return "CompositeError(%i)"%len(self.elist)
260 return "CompositeError(%i)"%len(self.elist)
255
261
256 def render_traceback(self, excid=None):
262 def render_traceback(self, excid=None):
257 """render one or all of my tracebacks to a list of lines"""
263 """render one or all of my tracebacks to a list of lines"""
258 lines = []
264 lines = []
259 if excid is None:
265 if excid is None:
260 for (en,ev,etb,ei) in self.elist:
266 for (en,ev,etb,ei) in self.elist:
261 lines.append(self._get_engine_str(ei))
267 lines.append(self._get_engine_str(ei))
262 lines.extend((etb or 'No traceback available').splitlines())
268 lines.extend((etb or 'No traceback available').splitlines())
263 lines.append('')
269 lines.append('')
264 else:
270 else:
265 try:
271 try:
266 en,ev,etb,ei = self.elist[excid]
272 en,ev,etb,ei = self.elist[excid]
267 except:
273 except:
268 raise IndexError("an exception with index %i does not exist"%excid)
274 raise IndexError("an exception with index %i does not exist"%excid)
269 else:
275 else:
270 lines.append(self._get_engine_str(ei))
276 lines.append(self._get_engine_str(ei))
271 lines.extend((etb or 'No traceback available').splitlines())
277 lines.extend((etb or 'No traceback available').splitlines())
272
278
273 return lines
279 return lines
274
280
275 def print_traceback(self, excid=None):
281 def print_traceback(self, excid=None):
276 print('\n'.join(self.render_traceback(excid)))
282 print('\n'.join(self.render_traceback(excid)))
277
283
278 def raise_exception(self, excid=0):
284 def raise_exception(self, excid=0):
279 try:
285 try:
280 en,ev,etb,ei = self.elist[excid]
286 en,ev,etb,ei = self.elist[excid]
281 except:
287 except:
282 raise IndexError("an exception with index %i does not exist"%excid)
288 raise IndexError("an exception with index %i does not exist"%excid)
283 else:
289 else:
284 raise RemoteError(en, ev, etb, ei)
290 raise RemoteError(en, ev, etb, ei)
285
291
286
292
287 def collect_exceptions(rdict_or_list, method='unspecified'):
293 def collect_exceptions(rdict_or_list, method='unspecified'):
288 """check a result dict for errors, and raise CompositeError if any exist.
294 """check a result dict for errors, and raise CompositeError if any exist.
289 Passthrough otherwise."""
295 Passthrough otherwise."""
290 elist = []
296 elist = []
291 if isinstance(rdict_or_list, dict):
297 if isinstance(rdict_or_list, dict):
292 rlist = rdict_or_list.values()
298 rlist = rdict_or_list.values()
293 else:
299 else:
294 rlist = rdict_or_list
300 rlist = rdict_or_list
295 for r in rlist:
301 for r in rlist:
296 if isinstance(r, RemoteError):
302 if isinstance(r, RemoteError):
297 en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
303 en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
298 # Sometimes we could have CompositeError in our list. Just take
304 # Sometimes we could have CompositeError in our list. Just take
299 # the errors out of them and put them in our new list. This
305 # the errors out of them and put them in our new list. This
300 # has the effect of flattening lists of CompositeErrors into one
306 # has the effect of flattening lists of CompositeErrors into one
301 # CompositeError
307 # CompositeError
302 if en=='CompositeError':
308 if en=='CompositeError':
303 for e in ev.elist:
309 for e in ev.elist:
304 elist.append(e)
310 elist.append(e)
305 else:
311 else:
306 elist.append((en, ev, etb, ei))
312 elist.append((en, ev, etb, ei))
307 if len(elist)==0:
313 if len(elist)==0:
308 return rdict_or_list
314 return rdict_or_list
309 else:
315 else:
310 msg = "one or more exceptions from call to method: %s" % (method)
316 msg = "one or more exceptions from call to method: %s" % (method)
311 # This silliness is needed so the debugger has access to the exception
317 # This silliness is needed so the debugger has access to the exception
312 # instance (e in this case)
318 # instance (e in this case)
313 try:
319 try:
314 raise CompositeError(msg, elist)
320 raise CompositeError(msg, elist)
315 except CompositeError as e:
321 except CompositeError as e:
316 raise e
322 raise e
317
323
318 def wrap_exception(engine_info={}):
324 def wrap_exception(engine_info={}):
319 etype, evalue, tb = sys.exc_info()
325 etype, evalue, tb = sys.exc_info()
320 stb = traceback.format_exception(etype, evalue, tb)
326 stb = traceback.format_exception(etype, evalue, tb)
321 exc_content = {
327 exc_content = {
322 'status' : 'error',
328 'status' : 'error',
323 'traceback' : stb,
329 'traceback' : stb,
324 'ename' : unicode(etype.__name__),
330 'ename' : unicode(etype.__name__),
325 'evalue' : unicode(evalue),
331 'evalue' : unicode(evalue),
326 'engine_info' : engine_info
332 'engine_info' : engine_info
327 }
333 }
328 return exc_content
334 return exc_content
329
335
330 def unwrap_exception(content):
336 def unwrap_exception(content):
331 err = RemoteError(content['ename'], content['evalue'],
337 err = RemoteError(content['ename'], content['evalue'],
332 ''.join(content['traceback']),
338 ''.join(content['traceback']),
333 content.get('engine_info', {}))
339 content.get('engine_info', {}))
334 return err
340 return err
335
341
@@ -1,678 +1,703 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import platform
20 import platform
21 import time
21 import time
22 from tempfile import mktemp
22 from tempfile import mktemp
23 from StringIO import StringIO
23 from StringIO import StringIO
24
24
25 import zmq
25 import zmq
26 from nose import SkipTest
26 from nose import SkipTest
27
27
28 from IPython.testing import decorators as dec
28 from IPython.testing import decorators as dec
29 from IPython.testing.ipunittest import ParametricTestCase
29 from IPython.testing.ipunittest import ParametricTestCase
30 from IPython.utils.io import capture_output
30
31
31 from IPython import parallel as pmod
32 from IPython import parallel as pmod
32 from IPython.parallel import error
33 from IPython.parallel import error
33 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
34 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
34 from IPython.parallel import DirectView
35 from IPython.parallel import DirectView
35 from IPython.parallel.util import interactive
36 from IPython.parallel.util import interactive
36
37
37 from IPython.parallel.tests import add_engines
38 from IPython.parallel.tests import add_engines
38
39
39 from .clienttest import ClusterTestCase, crash, wait, skip_without
40 from .clienttest import ClusterTestCase, crash, wait, skip_without
40
41
41 def setup():
42 def setup():
42 add_engines(3, total=True)
43 add_engines(3, total=True)
43
44
44 class TestView(ClusterTestCase, ParametricTestCase):
45 class TestView(ClusterTestCase, ParametricTestCase):
45
46
46 def setUp(self):
47 def setUp(self):
47 # On Win XP, wait for resource cleanup, else parallel test group fails
48 # On Win XP, wait for resource cleanup, else parallel test group fails
48 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
49 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
49 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
50 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
50 time.sleep(2)
51 time.sleep(2)
51 super(TestView, self).setUp()
52 super(TestView, self).setUp()
52
53
53 def test_z_crash_mux(self):
54 def test_z_crash_mux(self):
54 """test graceful handling of engine death (direct)"""
55 """test graceful handling of engine death (direct)"""
55 raise SkipTest("crash tests disabled, due to undesirable crash reports")
56 raise SkipTest("crash tests disabled, due to undesirable crash reports")
56 # self.add_engines(1)
57 # self.add_engines(1)
57 eid = self.client.ids[-1]
58 eid = self.client.ids[-1]
58 ar = self.client[eid].apply_async(crash)
59 ar = self.client[eid].apply_async(crash)
59 self.assertRaisesRemote(error.EngineError, ar.get, 10)
60 self.assertRaisesRemote(error.EngineError, ar.get, 10)
60 eid = ar.engine_id
61 eid = ar.engine_id
61 tic = time.time()
62 tic = time.time()
62 while eid in self.client.ids and time.time()-tic < 5:
63 while eid in self.client.ids and time.time()-tic < 5:
63 time.sleep(.01)
64 time.sleep(.01)
64 self.client.spin()
65 self.client.spin()
65 self.assertFalse(eid in self.client.ids, "Engine should have died")
66 self.assertFalse(eid in self.client.ids, "Engine should have died")
66
67
67 def test_push_pull(self):
68 def test_push_pull(self):
68 """test pushing and pulling"""
69 """test pushing and pulling"""
69 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
70 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
70 t = self.client.ids[-1]
71 t = self.client.ids[-1]
71 v = self.client[t]
72 v = self.client[t]
72 push = v.push
73 push = v.push
73 pull = v.pull
74 pull = v.pull
74 v.block=True
75 v.block=True
75 nengines = len(self.client)
76 nengines = len(self.client)
76 push({'data':data})
77 push({'data':data})
77 d = pull('data')
78 d = pull('data')
78 self.assertEqual(d, data)
79 self.assertEqual(d, data)
79 self.client[:].push({'data':data})
80 self.client[:].push({'data':data})
80 d = self.client[:].pull('data', block=True)
81 d = self.client[:].pull('data', block=True)
81 self.assertEqual(d, nengines*[data])
82 self.assertEqual(d, nengines*[data])
82 ar = push({'data':data}, block=False)
83 ar = push({'data':data}, block=False)
83 self.assertTrue(isinstance(ar, AsyncResult))
84 self.assertTrue(isinstance(ar, AsyncResult))
84 r = ar.get()
85 r = ar.get()
85 ar = self.client[:].pull('data', block=False)
86 ar = self.client[:].pull('data', block=False)
86 self.assertTrue(isinstance(ar, AsyncResult))
87 self.assertTrue(isinstance(ar, AsyncResult))
87 r = ar.get()
88 r = ar.get()
88 self.assertEqual(r, nengines*[data])
89 self.assertEqual(r, nengines*[data])
89 self.client[:].push(dict(a=10,b=20))
90 self.client[:].push(dict(a=10,b=20))
90 r = self.client[:].pull(('a','b'), block=True)
91 r = self.client[:].pull(('a','b'), block=True)
91 self.assertEqual(r, nengines*[[10,20]])
92 self.assertEqual(r, nengines*[[10,20]])
92
93
93 def test_push_pull_function(self):
94 def test_push_pull_function(self):
94 "test pushing and pulling functions"
95 "test pushing and pulling functions"
95 def testf(x):
96 def testf(x):
96 return 2.0*x
97 return 2.0*x
97
98
98 t = self.client.ids[-1]
99 t = self.client.ids[-1]
99 v = self.client[t]
100 v = self.client[t]
100 v.block=True
101 v.block=True
101 push = v.push
102 push = v.push
102 pull = v.pull
103 pull = v.pull
103 execute = v.execute
104 execute = v.execute
104 push({'testf':testf})
105 push({'testf':testf})
105 r = pull('testf')
106 r = pull('testf')
106 self.assertEqual(r(1.0), testf(1.0))
107 self.assertEqual(r(1.0), testf(1.0))
107 execute('r = testf(10)')
108 execute('r = testf(10)')
108 r = pull('r')
109 r = pull('r')
109 self.assertEqual(r, testf(10))
110 self.assertEqual(r, testf(10))
110 ar = self.client[:].push({'testf':testf}, block=False)
111 ar = self.client[:].push({'testf':testf}, block=False)
111 ar.get()
112 ar.get()
112 ar = self.client[:].pull('testf', block=False)
113 ar = self.client[:].pull('testf', block=False)
113 rlist = ar.get()
114 rlist = ar.get()
114 for r in rlist:
115 for r in rlist:
115 self.assertEqual(r(1.0), testf(1.0))
116 self.assertEqual(r(1.0), testf(1.0))
116 execute("def g(x): return x*x")
117 execute("def g(x): return x*x")
117 r = pull(('testf','g'))
118 r = pull(('testf','g'))
118 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
119 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
119
120
120 def test_push_function_globals(self):
121 def test_push_function_globals(self):
121 """test that pushed functions have access to globals"""
122 """test that pushed functions have access to globals"""
122 @interactive
123 @interactive
123 def geta():
124 def geta():
124 return a
125 return a
125 # self.add_engines(1)
126 # self.add_engines(1)
126 v = self.client[-1]
127 v = self.client[-1]
127 v.block=True
128 v.block=True
128 v['f'] = geta
129 v['f'] = geta
129 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
130 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
130 v.execute('a=5')
131 v.execute('a=5')
131 v.execute('b=f()')
132 v.execute('b=f()')
132 self.assertEqual(v['b'], 5)
133 self.assertEqual(v['b'], 5)
133
134
134 def test_push_function_defaults(self):
135 def test_push_function_defaults(self):
135 """test that pushed functions preserve default args"""
136 """test that pushed functions preserve default args"""
136 def echo(a=10):
137 def echo(a=10):
137 return a
138 return a
138 v = self.client[-1]
139 v = self.client[-1]
139 v.block=True
140 v.block=True
140 v['f'] = echo
141 v['f'] = echo
141 v.execute('b=f()')
142 v.execute('b=f()')
142 self.assertEqual(v['b'], 10)
143 self.assertEqual(v['b'], 10)
143
144
144 def test_get_result(self):
145 def test_get_result(self):
145 """test getting results from the Hub."""
146 """test getting results from the Hub."""
146 c = pmod.Client(profile='iptest')
147 c = pmod.Client(profile='iptest')
147 # self.add_engines(1)
148 # self.add_engines(1)
148 t = c.ids[-1]
149 t = c.ids[-1]
149 v = c[t]
150 v = c[t]
150 v2 = self.client[t]
151 v2 = self.client[t]
151 ar = v.apply_async(wait, 1)
152 ar = v.apply_async(wait, 1)
152 # give the monitor time to notice the message
153 # give the monitor time to notice the message
153 time.sleep(.25)
154 time.sleep(.25)
154 ahr = v2.get_result(ar.msg_ids)
155 ahr = v2.get_result(ar.msg_ids)
155 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 self.assertEqual(ahr.get(), ar.get())
157 self.assertEqual(ahr.get(), ar.get())
157 ar2 = v2.get_result(ar.msg_ids)
158 ar2 = v2.get_result(ar.msg_ids)
158 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 c.spin()
160 c.spin()
160 c.close()
161 c.close()
161
162
162 def test_run_newline(self):
163 def test_run_newline(self):
163 """test that run appends newline to files"""
164 """test that run appends newline to files"""
164 tmpfile = mktemp()
165 tmpfile = mktemp()
165 with open(tmpfile, 'w') as f:
166 with open(tmpfile, 'w') as f:
166 f.write("""def g():
167 f.write("""def g():
167 return 5
168 return 5
168 """)
169 """)
169 v = self.client[-1]
170 v = self.client[-1]
170 v.run(tmpfile, block=True)
171 v.run(tmpfile, block=True)
171 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
172 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
172
173
173 def test_apply_tracked(self):
174 def test_apply_tracked(self):
174 """test tracking for apply"""
175 """test tracking for apply"""
175 # self.add_engines(1)
176 # self.add_engines(1)
176 t = self.client.ids[-1]
177 t = self.client.ids[-1]
177 v = self.client[t]
178 v = self.client[t]
178 v.block=False
179 v.block=False
179 def echo(n=1024*1024, **kwargs):
180 def echo(n=1024*1024, **kwargs):
180 with v.temp_flags(**kwargs):
181 with v.temp_flags(**kwargs):
181 return v.apply(lambda x: x, 'x'*n)
182 return v.apply(lambda x: x, 'x'*n)
182 ar = echo(1, track=False)
183 ar = echo(1, track=False)
183 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
184 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
184 self.assertTrue(ar.sent)
185 self.assertTrue(ar.sent)
185 ar = echo(track=True)
186 ar = echo(track=True)
186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertEqual(ar.sent, ar._tracker.done)
188 self.assertEqual(ar.sent, ar._tracker.done)
188 ar._tracker.wait()
189 ar._tracker.wait()
189 self.assertTrue(ar.sent)
190 self.assertTrue(ar.sent)
190
191
191 def test_push_tracked(self):
192 def test_push_tracked(self):
192 t = self.client.ids[-1]
193 t = self.client.ids[-1]
193 ns = dict(x='x'*1024*1024)
194 ns = dict(x='x'*1024*1024)
194 v = self.client[t]
195 v = self.client[t]
195 ar = v.push(ns, block=False, track=False)
196 ar = v.push(ns, block=False, track=False)
196 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
197 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
197 self.assertTrue(ar.sent)
198 self.assertTrue(ar.sent)
198
199
199 ar = v.push(ns, block=False, track=True)
200 ar = v.push(ns, block=False, track=True)
200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 ar._tracker.wait()
202 ar._tracker.wait()
202 self.assertEqual(ar.sent, ar._tracker.done)
203 self.assertEqual(ar.sent, ar._tracker.done)
203 self.assertTrue(ar.sent)
204 self.assertTrue(ar.sent)
204 ar.get()
205 ar.get()
205
206
206 def test_scatter_tracked(self):
207 def test_scatter_tracked(self):
207 t = self.client.ids
208 t = self.client.ids
208 x='x'*1024*1024
209 x='x'*1024*1024
209 ar = self.client[t].scatter('x', x, block=False, track=False)
210 ar = self.client[t].scatter('x', x, block=False, track=False)
210 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
211 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
211 self.assertTrue(ar.sent)
212 self.assertTrue(ar.sent)
212
213
213 ar = self.client[t].scatter('x', x, block=False, track=True)
214 ar = self.client[t].scatter('x', x, block=False, track=True)
214 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
215 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
215 self.assertEqual(ar.sent, ar._tracker.done)
216 self.assertEqual(ar.sent, ar._tracker.done)
216 ar._tracker.wait()
217 ar._tracker.wait()
217 self.assertTrue(ar.sent)
218 self.assertTrue(ar.sent)
218 ar.get()
219 ar.get()
219
220
220 def test_remote_reference(self):
221 def test_remote_reference(self):
221 v = self.client[-1]
222 v = self.client[-1]
222 v['a'] = 123
223 v['a'] = 123
223 ra = pmod.Reference('a')
224 ra = pmod.Reference('a')
224 b = v.apply_sync(lambda x: x, ra)
225 b = v.apply_sync(lambda x: x, ra)
225 self.assertEqual(b, 123)
226 self.assertEqual(b, 123)
226
227
227
228
228 def test_scatter_gather(self):
229 def test_scatter_gather(self):
229 view = self.client[:]
230 view = self.client[:]
230 seq1 = range(16)
231 seq1 = range(16)
231 view.scatter('a', seq1)
232 view.scatter('a', seq1)
232 seq2 = view.gather('a', block=True)
233 seq2 = view.gather('a', block=True)
233 self.assertEqual(seq2, seq1)
234 self.assertEqual(seq2, seq1)
234 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
235 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
235
236
236 @skip_without('numpy')
237 @skip_without('numpy')
237 def test_scatter_gather_numpy(self):
238 def test_scatter_gather_numpy(self):
238 import numpy
239 import numpy
239 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
240 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
240 view = self.client[:]
241 view = self.client[:]
241 a = numpy.arange(64)
242 a = numpy.arange(64)
242 view.scatter('a', a, block=True)
243 view.scatter('a', a, block=True)
243 b = view.gather('a', block=True)
244 b = view.gather('a', block=True)
244 assert_array_equal(b, a)
245 assert_array_equal(b, a)
245
246
246 def test_scatter_gather_lazy(self):
247 def test_scatter_gather_lazy(self):
247 """scatter/gather with targets='all'"""
248 """scatter/gather with targets='all'"""
248 view = self.client.direct_view(targets='all')
249 view = self.client.direct_view(targets='all')
249 x = range(64)
250 x = range(64)
250 view.scatter('x', x)
251 view.scatter('x', x)
251 gathered = view.gather('x', block=True)
252 gathered = view.gather('x', block=True)
252 self.assertEqual(gathered, x)
253 self.assertEqual(gathered, x)
253
254
254
255
255 @dec.known_failure_py3
256 @dec.known_failure_py3
256 @skip_without('numpy')
257 @skip_without('numpy')
257 def test_push_numpy_nocopy(self):
258 def test_push_numpy_nocopy(self):
258 import numpy
259 import numpy
259 view = self.client[:]
260 view = self.client[:]
260 a = numpy.arange(64)
261 a = numpy.arange(64)
261 view['A'] = a
262 view['A'] = a
262 @interactive
263 @interactive
263 def check_writeable(x):
264 def check_writeable(x):
264 return x.flags.writeable
265 return x.flags.writeable
265
266
266 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
267 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
267 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
268 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
268
269
269 view.push(dict(B=a))
270 view.push(dict(B=a))
270 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
271 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
271 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
272 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
272
273
273 @skip_without('numpy')
274 @skip_without('numpy')
274 def test_apply_numpy(self):
275 def test_apply_numpy(self):
275 """view.apply(f, ndarray)"""
276 """view.apply(f, ndarray)"""
276 import numpy
277 import numpy
277 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
278 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
278
279
279 A = numpy.random.random((100,100))
280 A = numpy.random.random((100,100))
280 view = self.client[-1]
281 view = self.client[-1]
281 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
282 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
282 B = A.astype(dt)
283 B = A.astype(dt)
283 C = view.apply_sync(lambda x:x, B)
284 C = view.apply_sync(lambda x:x, B)
284 assert_array_equal(B,C)
285 assert_array_equal(B,C)
285
286
286 @skip_without('numpy')
287 @skip_without('numpy')
287 def test_push_pull_recarray(self):
288 def test_push_pull_recarray(self):
288 """push/pull recarrays"""
289 """push/pull recarrays"""
289 import numpy
290 import numpy
290 from numpy.testing.utils import assert_array_equal
291 from numpy.testing.utils import assert_array_equal
291
292
292 view = self.client[-1]
293 view = self.client[-1]
293
294
294 R = numpy.array([
295 R = numpy.array([
295 (1, 'hi', 0.),
296 (1, 'hi', 0.),
296 (2**30, 'there', 2.5),
297 (2**30, 'there', 2.5),
297 (-99999, 'world', -12345.6789),
298 (-99999, 'world', -12345.6789),
298 ], [('n', int), ('s', '|S10'), ('f', float)])
299 ], [('n', int), ('s', '|S10'), ('f', float)])
299
300
300 view['RR'] = R
301 view['RR'] = R
301 R2 = view['RR']
302 R2 = view['RR']
302
303
303 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
304 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
304 self.assertEqual(r_dtype, R.dtype)
305 self.assertEqual(r_dtype, R.dtype)
305 self.assertEqual(r_shape, R.shape)
306 self.assertEqual(r_shape, R.shape)
306 self.assertEqual(R2.dtype, R.dtype)
307 self.assertEqual(R2.dtype, R.dtype)
307 self.assertEqual(R2.shape, R.shape)
308 self.assertEqual(R2.shape, R.shape)
308 assert_array_equal(R2, R)
309 assert_array_equal(R2, R)
309
310
310 def test_map(self):
311 def test_map(self):
311 view = self.client[:]
312 view = self.client[:]
312 def f(x):
313 def f(x):
313 return x**2
314 return x**2
314 data = range(16)
315 data = range(16)
315 r = view.map_sync(f, data)
316 r = view.map_sync(f, data)
316 self.assertEqual(r, map(f, data))
317 self.assertEqual(r, map(f, data))
317
318
318 def test_map_iterable(self):
319 def test_map_iterable(self):
319 """test map on iterables (direct)"""
320 """test map on iterables (direct)"""
320 view = self.client[:]
321 view = self.client[:]
321 # 101 is prime, so it won't be evenly distributed
322 # 101 is prime, so it won't be evenly distributed
322 arr = range(101)
323 arr = range(101)
323 # ensure it will be an iterator, even in Python 3
324 # ensure it will be an iterator, even in Python 3
324 it = iter(arr)
325 it = iter(arr)
325 r = view.map_sync(lambda x:x, arr)
326 r = view.map_sync(lambda x:x, arr)
326 self.assertEqual(r, list(arr))
327 self.assertEqual(r, list(arr))
327
328
328 def test_scatter_gather_nonblocking(self):
329 def test_scatter_gather_nonblocking(self):
329 data = range(16)
330 data = range(16)
330 view = self.client[:]
331 view = self.client[:]
331 view.scatter('a', data, block=False)
332 view.scatter('a', data, block=False)
332 ar = view.gather('a', block=False)
333 ar = view.gather('a', block=False)
333 self.assertEqual(ar.get(), data)
334 self.assertEqual(ar.get(), data)
334
335
335 @skip_without('numpy')
336 @skip_without('numpy')
336 def test_scatter_gather_numpy_nonblocking(self):
337 def test_scatter_gather_numpy_nonblocking(self):
337 import numpy
338 import numpy
338 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
339 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
339 a = numpy.arange(64)
340 a = numpy.arange(64)
340 view = self.client[:]
341 view = self.client[:]
341 ar = view.scatter('a', a, block=False)
342 ar = view.scatter('a', a, block=False)
342 self.assertTrue(isinstance(ar, AsyncResult))
343 self.assertTrue(isinstance(ar, AsyncResult))
343 amr = view.gather('a', block=False)
344 amr = view.gather('a', block=False)
344 self.assertTrue(isinstance(amr, AsyncMapResult))
345 self.assertTrue(isinstance(amr, AsyncMapResult))
345 assert_array_equal(amr.get(), a)
346 assert_array_equal(amr.get(), a)
346
347
347 def test_execute(self):
348 def test_execute(self):
348 view = self.client[:]
349 view = self.client[:]
349 # self.client.debug=True
350 # self.client.debug=True
350 execute = view.execute
351 execute = view.execute
351 ar = execute('c=30', block=False)
352 ar = execute('c=30', block=False)
352 self.assertTrue(isinstance(ar, AsyncResult))
353 self.assertTrue(isinstance(ar, AsyncResult))
353 ar = execute('d=[0,1,2]', block=False)
354 ar = execute('d=[0,1,2]', block=False)
354 self.client.wait(ar, 1)
355 self.client.wait(ar, 1)
355 self.assertEqual(len(ar.get()), len(self.client))
356 self.assertEqual(len(ar.get()), len(self.client))
356 for c in view['c']:
357 for c in view['c']:
357 self.assertEqual(c, 30)
358 self.assertEqual(c, 30)
358
359
359 def test_abort(self):
360 def test_abort(self):
360 view = self.client[-1]
361 view = self.client[-1]
361 ar = view.execute('import time; time.sleep(1)', block=False)
362 ar = view.execute('import time; time.sleep(1)', block=False)
362 ar2 = view.apply_async(lambda : 2)
363 ar2 = view.apply_async(lambda : 2)
363 ar3 = view.apply_async(lambda : 3)
364 ar3 = view.apply_async(lambda : 3)
364 view.abort(ar2)
365 view.abort(ar2)
365 view.abort(ar3.msg_ids)
366 view.abort(ar3.msg_ids)
366 self.assertRaises(error.TaskAborted, ar2.get)
367 self.assertRaises(error.TaskAborted, ar2.get)
367 self.assertRaises(error.TaskAborted, ar3.get)
368 self.assertRaises(error.TaskAborted, ar3.get)
368
369
369 def test_abort_all(self):
370 def test_abort_all(self):
370 """view.abort() aborts all outstanding tasks"""
371 """view.abort() aborts all outstanding tasks"""
371 view = self.client[-1]
372 view = self.client[-1]
372 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
373 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
373 view.abort()
374 view.abort()
374 view.wait(timeout=5)
375 view.wait(timeout=5)
375 for ar in ars[5:]:
376 for ar in ars[5:]:
376 self.assertRaises(error.TaskAborted, ar.get)
377 self.assertRaises(error.TaskAborted, ar.get)
377
378
378 def test_temp_flags(self):
379 def test_temp_flags(self):
379 view = self.client[-1]
380 view = self.client[-1]
380 view.block=True
381 view.block=True
381 with view.temp_flags(block=False):
382 with view.temp_flags(block=False):
382 self.assertFalse(view.block)
383 self.assertFalse(view.block)
383 self.assertTrue(view.block)
384 self.assertTrue(view.block)
384
385
385 @dec.known_failure_py3
386 @dec.known_failure_py3
386 def test_importer(self):
387 def test_importer(self):
387 view = self.client[-1]
388 view = self.client[-1]
388 view.clear(block=True)
389 view.clear(block=True)
389 with view.importer:
390 with view.importer:
390 import re
391 import re
391
392
392 @interactive
393 @interactive
393 def findall(pat, s):
394 def findall(pat, s):
394 # this globals() step isn't necessary in real code
395 # this globals() step isn't necessary in real code
395 # only to prevent a closure in the test
396 # only to prevent a closure in the test
396 re = globals()['re']
397 re = globals()['re']
397 return re.findall(pat, s)
398 return re.findall(pat, s)
398
399
399 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
400 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
400
401
401 def test_unicode_execute(self):
402 def test_unicode_execute(self):
402 """test executing unicode strings"""
403 """test executing unicode strings"""
403 v = self.client[-1]
404 v = self.client[-1]
404 v.block=True
405 v.block=True
405 if sys.version_info[0] >= 3:
406 if sys.version_info[0] >= 3:
406 code="a='é'"
407 code="a='é'"
407 else:
408 else:
408 code=u"a=u'é'"
409 code=u"a=u'é'"
409 v.execute(code)
410 v.execute(code)
410 self.assertEqual(v['a'], u'é')
411 self.assertEqual(v['a'], u'é')
411
412
412 def test_unicode_apply_result(self):
413 def test_unicode_apply_result(self):
413 """test unicode apply results"""
414 """test unicode apply results"""
414 v = self.client[-1]
415 v = self.client[-1]
415 r = v.apply_sync(lambda : u'é')
416 r = v.apply_sync(lambda : u'é')
416 self.assertEqual(r, u'é')
417 self.assertEqual(r, u'é')
417
418
418 def test_unicode_apply_arg(self):
419 def test_unicode_apply_arg(self):
419 """test passing unicode arguments to apply"""
420 """test passing unicode arguments to apply"""
420 v = self.client[-1]
421 v = self.client[-1]
421
422
422 @interactive
423 @interactive
423 def check_unicode(a, check):
424 def check_unicode(a, check):
424 assert isinstance(a, unicode), "%r is not unicode"%a
425 assert isinstance(a, unicode), "%r is not unicode"%a
425 assert isinstance(check, bytes), "%r is not bytes"%check
426 assert isinstance(check, bytes), "%r is not bytes"%check
426 assert a.encode('utf8') == check, "%s != %s"%(a,check)
427 assert a.encode('utf8') == check, "%s != %s"%(a,check)
427
428
428 for s in [ u'é', u'ßø®∫',u'asdf' ]:
429 for s in [ u'é', u'ßø®∫',u'asdf' ]:
429 try:
430 try:
430 v.apply_sync(check_unicode, s, s.encode('utf8'))
431 v.apply_sync(check_unicode, s, s.encode('utf8'))
431 except error.RemoteError as e:
432 except error.RemoteError as e:
432 if e.ename == 'AssertionError':
433 if e.ename == 'AssertionError':
433 self.fail(e.evalue)
434 self.fail(e.evalue)
434 else:
435 else:
435 raise e
436 raise e
436
437
437 def test_map_reference(self):
438 def test_map_reference(self):
438 """view.map(<Reference>, *seqs) should work"""
439 """view.map(<Reference>, *seqs) should work"""
439 v = self.client[:]
440 v = self.client[:]
440 v.scatter('n', self.client.ids, flatten=True)
441 v.scatter('n', self.client.ids, flatten=True)
441 v.execute("f = lambda x,y: x*y")
442 v.execute("f = lambda x,y: x*y")
442 rf = pmod.Reference('f')
443 rf = pmod.Reference('f')
443 nlist = list(range(10))
444 nlist = list(range(10))
444 mlist = nlist[::-1]
445 mlist = nlist[::-1]
445 expected = [ m*n for m,n in zip(mlist, nlist) ]
446 expected = [ m*n for m,n in zip(mlist, nlist) ]
446 result = v.map_sync(rf, mlist, nlist)
447 result = v.map_sync(rf, mlist, nlist)
447 self.assertEqual(result, expected)
448 self.assertEqual(result, expected)
448
449
449 def test_apply_reference(self):
450 def test_apply_reference(self):
450 """view.apply(<Reference>, *args) should work"""
451 """view.apply(<Reference>, *args) should work"""
451 v = self.client[:]
452 v = self.client[:]
452 v.scatter('n', self.client.ids, flatten=True)
453 v.scatter('n', self.client.ids, flatten=True)
453 v.execute("f = lambda x: n*x")
454 v.execute("f = lambda x: n*x")
454 rf = pmod.Reference('f')
455 rf = pmod.Reference('f')
455 result = v.apply_sync(rf, 5)
456 result = v.apply_sync(rf, 5)
456 expected = [ 5*id for id in self.client.ids ]
457 expected = [ 5*id for id in self.client.ids ]
457 self.assertEqual(result, expected)
458 self.assertEqual(result, expected)
458
459
459 def test_eval_reference(self):
460 def test_eval_reference(self):
460 v = self.client[self.client.ids[0]]
461 v = self.client[self.client.ids[0]]
461 v['g'] = range(5)
462 v['g'] = range(5)
462 rg = pmod.Reference('g[0]')
463 rg = pmod.Reference('g[0]')
463 echo = lambda x:x
464 echo = lambda x:x
464 self.assertEqual(v.apply_sync(echo, rg), 0)
465 self.assertEqual(v.apply_sync(echo, rg), 0)
465
466
466 def test_reference_nameerror(self):
467 def test_reference_nameerror(self):
467 v = self.client[self.client.ids[0]]
468 v = self.client[self.client.ids[0]]
468 r = pmod.Reference('elvis_has_left')
469 r = pmod.Reference('elvis_has_left')
469 echo = lambda x:x
470 echo = lambda x:x
470 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
471 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
471
472
472 def test_single_engine_map(self):
473 def test_single_engine_map(self):
473 e0 = self.client[self.client.ids[0]]
474 e0 = self.client[self.client.ids[0]]
474 r = range(5)
475 r = range(5)
475 check = [ -1*i for i in r ]
476 check = [ -1*i for i in r ]
476 result = e0.map_sync(lambda x: -1*x, r)
477 result = e0.map_sync(lambda x: -1*x, r)
477 self.assertEqual(result, check)
478 self.assertEqual(result, check)
478
479
479 def test_len(self):
480 def test_len(self):
480 """len(view) makes sense"""
481 """len(view) makes sense"""
481 e0 = self.client[self.client.ids[0]]
482 e0 = self.client[self.client.ids[0]]
482 yield self.assertEqual(len(e0), 1)
483 yield self.assertEqual(len(e0), 1)
483 v = self.client[:]
484 v = self.client[:]
484 yield self.assertEqual(len(v), len(self.client.ids))
485 yield self.assertEqual(len(v), len(self.client.ids))
485 v = self.client.direct_view('all')
486 v = self.client.direct_view('all')
486 yield self.assertEqual(len(v), len(self.client.ids))
487 yield self.assertEqual(len(v), len(self.client.ids))
487 v = self.client[:2]
488 v = self.client[:2]
488 yield self.assertEqual(len(v), 2)
489 yield self.assertEqual(len(v), 2)
489 v = self.client[:1]
490 v = self.client[:1]
490 yield self.assertEqual(len(v), 1)
491 yield self.assertEqual(len(v), 1)
491 v = self.client.load_balanced_view()
492 v = self.client.load_balanced_view()
492 yield self.assertEqual(len(v), len(self.client.ids))
493 yield self.assertEqual(len(v), len(self.client.ids))
493 # parametric tests seem to require manual closing?
494 # parametric tests seem to require manual closing?
494 self.client.close()
495 self.client.close()
495
496
496
497
497 # begin execute tests
498 # begin execute tests
498
499
499 def test_execute_reply(self):
500 def test_execute_reply(self):
500 e0 = self.client[self.client.ids[0]]
501 e0 = self.client[self.client.ids[0]]
501 e0.block = True
502 e0.block = True
502 ar = e0.execute("5", silent=False)
503 ar = e0.execute("5", silent=False)
503 er = ar.get()
504 er = ar.get()
504 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
505 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
505 self.assertEqual(er.pyout['data']['text/plain'], '5')
506 self.assertEqual(er.pyout['data']['text/plain'], '5')
506
507
507 def test_execute_reply_stdout(self):
508 def test_execute_reply_stdout(self):
508 e0 = self.client[self.client.ids[0]]
509 e0 = self.client[self.client.ids[0]]
509 e0.block = True
510 e0.block = True
510 ar = e0.execute("print (5)", silent=False)
511 ar = e0.execute("print (5)", silent=False)
511 er = ar.get()
512 er = ar.get()
512 self.assertEqual(er.stdout.strip(), '5')
513 self.assertEqual(er.stdout.strip(), '5')
513
514
514 def test_execute_pyout(self):
515 def test_execute_pyout(self):
515 """execute triggers pyout with silent=False"""
516 """execute triggers pyout with silent=False"""
516 view = self.client[:]
517 view = self.client[:]
517 ar = view.execute("5", silent=False, block=True)
518 ar = view.execute("5", silent=False, block=True)
518
519
519 expected = [{'text/plain' : '5'}] * len(view)
520 expected = [{'text/plain' : '5'}] * len(view)
520 mimes = [ out['data'] for out in ar.pyout ]
521 mimes = [ out['data'] for out in ar.pyout ]
521 self.assertEqual(mimes, expected)
522 self.assertEqual(mimes, expected)
522
523
523 def test_execute_silent(self):
524 def test_execute_silent(self):
524 """execute does not trigger pyout with silent=True"""
525 """execute does not trigger pyout with silent=True"""
525 view = self.client[:]
526 view = self.client[:]
526 ar = view.execute("5", block=True)
527 ar = view.execute("5", block=True)
527 expected = [None] * len(view)
528 expected = [None] * len(view)
528 self.assertEqual(ar.pyout, expected)
529 self.assertEqual(ar.pyout, expected)
529
530
530 def test_execute_magic(self):
531 def test_execute_magic(self):
531 """execute accepts IPython commands"""
532 """execute accepts IPython commands"""
532 view = self.client[:]
533 view = self.client[:]
533 view.execute("a = 5")
534 view.execute("a = 5")
534 ar = view.execute("%whos", block=True)
535 ar = view.execute("%whos", block=True)
535 # this will raise, if that failed
536 # this will raise, if that failed
536 ar.get(5)
537 ar.get(5)
537 for stdout in ar.stdout:
538 for stdout in ar.stdout:
538 lines = stdout.splitlines()
539 lines = stdout.splitlines()
539 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
540 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
540 found = False
541 found = False
541 for line in lines[2:]:
542 for line in lines[2:]:
542 split = line.split()
543 split = line.split()
543 if split == ['a', 'int', '5']:
544 if split == ['a', 'int', '5']:
544 found = True
545 found = True
545 break
546 break
546 self.assertTrue(found, "whos output wrong: %s" % stdout)
547 self.assertTrue(found, "whos output wrong: %s" % stdout)
547
548
548 def test_execute_displaypub(self):
549 def test_execute_displaypub(self):
549 """execute tracks display_pub output"""
550 """execute tracks display_pub output"""
550 view = self.client[:]
551 view = self.client[:]
551 view.execute("from IPython.core.display import *")
552 view.execute("from IPython.core.display import *")
552 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
553 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
553
554
554 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
555 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
555 for outputs in ar.outputs:
556 for outputs in ar.outputs:
556 mimes = [ out['data'] for out in outputs ]
557 mimes = [ out['data'] for out in outputs ]
557 self.assertEqual(mimes, expected)
558 self.assertEqual(mimes, expected)
558
559
559 def test_apply_displaypub(self):
560 def test_apply_displaypub(self):
560 """apply tracks display_pub output"""
561 """apply tracks display_pub output"""
561 view = self.client[:]
562 view = self.client[:]
562 view.execute("from IPython.core.display import *")
563 view.execute("from IPython.core.display import *")
563
564
564 @interactive
565 @interactive
565 def publish():
566 def publish():
566 [ display(i) for i in range(5) ]
567 [ display(i) for i in range(5) ]
567
568
568 ar = view.apply_async(publish)
569 ar = view.apply_async(publish)
569 ar.get(5)
570 ar.get(5)
570 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
571 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
571 for outputs in ar.outputs:
572 for outputs in ar.outputs:
572 mimes = [ out['data'] for out in outputs ]
573 mimes = [ out['data'] for out in outputs ]
573 self.assertEqual(mimes, expected)
574 self.assertEqual(mimes, expected)
574
575
575 def test_execute_raises(self):
576 def test_execute_raises(self):
576 """exceptions in execute requests raise appropriately"""
577 """exceptions in execute requests raise appropriately"""
577 view = self.client[-1]
578 view = self.client[-1]
578 ar = view.execute("1/0")
579 ar = view.execute("1/0")
579 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
580 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
580
581
582 def test_remoteerror_render_exception(self):
583 """RemoteErrors get nice tracebacks"""
584 view = self.client[-1]
585 ar = view.execute("1/0")
586 ip = get_ipython()
587 ip.user_ns['ar'] = ar
588 with capture_output() as io:
589 ip.run_cell("ar.get(2)")
590
591 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
592
593 def test_compositeerror_render_exception(self):
594 """CompositeErrors get nice tracebacks"""
595 view = self.client[:]
596 ar = view.execute("1/0")
597 ip = get_ipython()
598 ip.user_ns['ar'] = ar
599 with capture_output() as io:
600 ip.run_cell("ar.get(2)")
601
602 self.assertEqual(io.stdout.count('ZeroDivisionError'), len(view) * 2, io.stdout)
603 self.assertEqual(io.stdout.count('integer division'), len(view), io.stdout)
604 self.assertEqual(io.stdout.count(':execute'), len(view), io.stdout)
605
581 @dec.skipif_not_matplotlib
606 @dec.skipif_not_matplotlib
582 def test_magic_pylab(self):
607 def test_magic_pylab(self):
583 """%pylab works on engines"""
608 """%pylab works on engines"""
584 view = self.client[-1]
609 view = self.client[-1]
585 ar = view.execute("%pylab inline")
610 ar = view.execute("%pylab inline")
586 # at least check if this raised:
611 # at least check if this raised:
587 reply = ar.get(5)
612 reply = ar.get(5)
588 # include imports, in case user config
613 # include imports, in case user config
589 ar = view.execute("plot(rand(100))", silent=False)
614 ar = view.execute("plot(rand(100))", silent=False)
590 reply = ar.get(5)
615 reply = ar.get(5)
591 self.assertEqual(len(reply.outputs), 1)
616 self.assertEqual(len(reply.outputs), 1)
592 output = reply.outputs[0]
617 output = reply.outputs[0]
593 self.assertTrue("data" in output)
618 self.assertTrue("data" in output)
594 data = output['data']
619 data = output['data']
595 self.assertTrue("image/png" in data)
620 self.assertTrue("image/png" in data)
596
621
597 def test_func_default_func(self):
622 def test_func_default_func(self):
598 """interactively defined function as apply func default"""
623 """interactively defined function as apply func default"""
599 def foo():
624 def foo():
600 return 'foo'
625 return 'foo'
601
626
602 def bar(f=foo):
627 def bar(f=foo):
603 return f()
628 return f()
604
629
605 view = self.client[-1]
630 view = self.client[-1]
606 ar = view.apply_async(bar)
631 ar = view.apply_async(bar)
607 r = ar.get(10)
632 r = ar.get(10)
608 self.assertEqual(r, 'foo')
633 self.assertEqual(r, 'foo')
609 def test_data_pub_single(self):
634 def test_data_pub_single(self):
610 view = self.client[-1]
635 view = self.client[-1]
611 ar = view.execute('\n'.join([
636 ar = view.execute('\n'.join([
612 'from IPython.zmq.datapub import publish_data',
637 'from IPython.zmq.datapub import publish_data',
613 'for i in range(5):',
638 'for i in range(5):',
614 ' publish_data(dict(i=i))'
639 ' publish_data(dict(i=i))'
615 ]), block=False)
640 ]), block=False)
616 self.assertTrue(isinstance(ar.data, dict))
641 self.assertTrue(isinstance(ar.data, dict))
617 ar.get(5)
642 ar.get(5)
618 self.assertEqual(ar.data, dict(i=4))
643 self.assertEqual(ar.data, dict(i=4))
619
644
620 def test_data_pub(self):
645 def test_data_pub(self):
621 view = self.client[:]
646 view = self.client[:]
622 ar = view.execute('\n'.join([
647 ar = view.execute('\n'.join([
623 'from IPython.zmq.datapub import publish_data',
648 'from IPython.zmq.datapub import publish_data',
624 'for i in range(5):',
649 'for i in range(5):',
625 ' publish_data(dict(i=i))'
650 ' publish_data(dict(i=i))'
626 ]), block=False)
651 ]), block=False)
627 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
652 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
628 ar.get(5)
653 ar.get(5)
629 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
654 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
630
655
631 def test_can_list_arg(self):
656 def test_can_list_arg(self):
632 """args in lists are canned"""
657 """args in lists are canned"""
633 view = self.client[-1]
658 view = self.client[-1]
634 view['a'] = 128
659 view['a'] = 128
635 rA = pmod.Reference('a')
660 rA = pmod.Reference('a')
636 ar = view.apply_async(lambda x: x, [rA])
661 ar = view.apply_async(lambda x: x, [rA])
637 r = ar.get(5)
662 r = ar.get(5)
638 self.assertEqual(r, [128])
663 self.assertEqual(r, [128])
639
664
640 def test_can_dict_arg(self):
665 def test_can_dict_arg(self):
641 """args in dicts are canned"""
666 """args in dicts are canned"""
642 view = self.client[-1]
667 view = self.client[-1]
643 view['a'] = 128
668 view['a'] = 128
644 rA = pmod.Reference('a')
669 rA = pmod.Reference('a')
645 ar = view.apply_async(lambda x: x, dict(foo=rA))
670 ar = view.apply_async(lambda x: x, dict(foo=rA))
646 r = ar.get(5)
671 r = ar.get(5)
647 self.assertEqual(r, dict(foo=128))
672 self.assertEqual(r, dict(foo=128))
648
673
649 def test_can_list_kwarg(self):
674 def test_can_list_kwarg(self):
650 """kwargs in lists are canned"""
675 """kwargs in lists are canned"""
651 view = self.client[-1]
676 view = self.client[-1]
652 view['a'] = 128
677 view['a'] = 128
653 rA = pmod.Reference('a')
678 rA = pmod.Reference('a')
654 ar = view.apply_async(lambda x=5: x, x=[rA])
679 ar = view.apply_async(lambda x=5: x, x=[rA])
655 r = ar.get(5)
680 r = ar.get(5)
656 self.assertEqual(r, [128])
681 self.assertEqual(r, [128])
657
682
658 def test_can_dict_kwarg(self):
683 def test_can_dict_kwarg(self):
659 """kwargs in dicts are canned"""
684 """kwargs in dicts are canned"""
660 view = self.client[-1]
685 view = self.client[-1]
661 view['a'] = 128
686 view['a'] = 128
662 rA = pmod.Reference('a')
687 rA = pmod.Reference('a')
663 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
688 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
664 r = ar.get(5)
689 r = ar.get(5)
665 self.assertEqual(r, dict(foo=128))
690 self.assertEqual(r, dict(foo=128))
666
691
667 def test_map_ref(self):
692 def test_map_ref(self):
668 """view.map works with references"""
693 """view.map works with references"""
669 view = self.client[:]
694 view = self.client[:]
670 ranks = sorted(self.client.ids)
695 ranks = sorted(self.client.ids)
671 view.scatter('rank', ranks, flatten=True)
696 view.scatter('rank', ranks, flatten=True)
672 rrank = pmod.Reference('rank')
697 rrank = pmod.Reference('rank')
673
698
674 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
699 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
675 drank = amr.get(5)
700 drank = amr.get(5)
676 self.assertEqual(drank, [ r*2 for r in ranks ])
701 self.assertEqual(drank, [ r*2 for r in ranks ])
677
702
678
703
General Comments 0
You need to be logged in to leave comments. Login now