##// END OF EJS Templates
fix/test pushed function globals
MinRK -
Show More
@@ -1,127 +1,127 b''
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """Pickle related utilities. Perhaps this should be called 'can'."""
3 """Pickle related utilities. Perhaps this should be called 'can'."""
4
4
5 __docformat__ = "restructuredtext en"
5 __docformat__ = "restructuredtext en"
6
6
7 #-------------------------------------------------------------------------------
7 #-------------------------------------------------------------------------------
8 # Copyright (C) 2008 The IPython Development Team
8 # Copyright (C) 2008 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14 #-------------------------------------------------------------------------------
14 #-------------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-------------------------------------------------------------------------------
16 #-------------------------------------------------------------------------------
17
17
18 from types import FunctionType
18 from types import FunctionType
19 import copy
19 import copy
20
20
21 from IPython.zmq.parallel.dependency import dependent
21 from IPython.zmq.parallel.dependency import dependent
22
22
23 import codeutil
23 import codeutil
24
24
25 #-------------------------------------------------------------------------------
25 #-------------------------------------------------------------------------------
26 # Classes
26 # Classes
27 #-------------------------------------------------------------------------------
27 #-------------------------------------------------------------------------------
28
28
29
29
30 class CannedObject(object):
30 class CannedObject(object):
31 def __init__(self, obj, keys=[]):
31 def __init__(self, obj, keys=[]):
32 self.keys = keys
32 self.keys = keys
33 self.obj = copy.copy(obj)
33 self.obj = copy.copy(obj)
34 for key in keys:
34 for key in keys:
35 setattr(self.obj, key, can(getattr(obj, key)))
35 setattr(self.obj, key, can(getattr(obj, key)))
36
36
37
37
38 def getObject(self, g=None):
38 def getObject(self, g=None):
39 if g is None:
39 if g is None:
40 g = globals()
40 g = globals()
41 for key in self.keys:
41 for key in self.keys:
42 setattr(self.obj, key, uncan(getattr(self.obj, key), g))
42 setattr(self.obj, key, uncan(getattr(self.obj, key), g))
43 return self.obj
43 return self.obj
44
44
45
45
46
46
47 class CannedFunction(CannedObject):
47 class CannedFunction(CannedObject):
48
48
49 def __init__(self, f):
49 def __init__(self, f):
50 self._checkType(f)
50 self._checkType(f)
51 self.code = f.func_code
51 self.code = f.func_code
52 self.__name__ = f.__name__
52 self.__name__ = f.__name__
53
53
54 def _checkType(self, obj):
54 def _checkType(self, obj):
55 assert isinstance(obj, FunctionType), "Not a function type"
55 assert isinstance(obj, FunctionType), "Not a function type"
56
56
57 def getFunction(self, g=None):
57 def getFunction(self, g=None):
58 if g is None:
58 if g is None:
59 g = globals()
59 g = globals()
60 newFunc = FunctionType(self.code, g)
60 newFunc = FunctionType(self.code, g)
61 return newFunc
61 return newFunc
62
62
63 #-------------------------------------------------------------------------------
63 #-------------------------------------------------------------------------------
64 # Functions
64 # Functions
65 #-------------------------------------------------------------------------------
65 #-------------------------------------------------------------------------------
66
66
67
67
68 def can(obj):
68 def can(obj):
69 if isinstance(obj, FunctionType):
69 if isinstance(obj, FunctionType):
70 return CannedFunction(obj)
70 return CannedFunction(obj)
71 elif isinstance(obj, dependent):
71 elif isinstance(obj, dependent):
72 keys = ('f','df')
72 keys = ('f','df')
73 return CannedObject(obj, keys=keys)
73 return CannedObject(obj, keys=keys)
74 elif isinstance(obj,dict):
74 elif isinstance(obj,dict):
75 return canDict(obj)
75 return canDict(obj)
76 elif isinstance(obj, (list,tuple)):
76 elif isinstance(obj, (list,tuple)):
77 return canSequence(obj)
77 return canSequence(obj)
78 else:
78 else:
79 return obj
79 return obj
80
80
81 def canDict(obj):
81 def canDict(obj):
82 if isinstance(obj, dict):
82 if isinstance(obj, dict):
83 newobj = {}
83 newobj = {}
84 for k, v in obj.iteritems():
84 for k, v in obj.iteritems():
85 newobj[k] = can(v)
85 newobj[k] = can(v)
86 return newobj
86 return newobj
87 else:
87 else:
88 return obj
88 return obj
89
89
90 def canSequence(obj):
90 def canSequence(obj):
91 if isinstance(obj, (list, tuple)):
91 if isinstance(obj, (list, tuple)):
92 t = type(obj)
92 t = type(obj)
93 return t([can(i) for i in obj])
93 return t([can(i) for i in obj])
94 else:
94 else:
95 return obj
95 return obj
96
96
97 def uncan(obj, g=None):
97 def uncan(obj, g=None):
98 if isinstance(obj, CannedFunction):
98 if isinstance(obj, CannedFunction):
99 return obj.getFunction(g)
99 return obj.getFunction(g)
100 elif isinstance(obj, CannedObject):
100 elif isinstance(obj, CannedObject):
101 return obj.getObject(g)
101 return obj.getObject(g)
102 elif isinstance(obj,dict):
102 elif isinstance(obj,dict):
103 return uncanDict(obj)
103 return uncanDict(obj, g)
104 elif isinstance(obj, (list,tuple)):
104 elif isinstance(obj, (list,tuple)):
105 return uncanSequence(obj)
105 return uncanSequence(obj, g)
106 else:
106 else:
107 return obj
107 return obj
108
108
109 def uncanDict(obj, g=None):
109 def uncanDict(obj, g=None):
110 if isinstance(obj, dict):
110 if isinstance(obj, dict):
111 newobj = {}
111 newobj = {}
112 for k, v in obj.iteritems():
112 for k, v in obj.iteritems():
113 newobj[k] = uncan(v,g)
113 newobj[k] = uncan(v,g)
114 return newobj
114 return newobj
115 else:
115 else:
116 return obj
116 return obj
117
117
118 def uncanSequence(obj, g=None):
118 def uncanSequence(obj, g=None):
119 if isinstance(obj, (list, tuple)):
119 if isinstance(obj, (list, tuple)):
120 t = type(obj)
120 t = type(obj)
121 return t([uncan(i,g) for i in obj])
121 return t([uncan(i,g) for i in obj])
122 else:
122 else:
123 return obj
123 return obj
124
124
125
125
126 def rebindFunctionGlobals(f, glbls):
126 def rebindFunctionGlobals(f, glbls):
127 return FunctionType(f.func_code, glbls)
127 return FunctionType(f.func_code, glbls)
@@ -1,295 +1,292 b''
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """Classes and functions for kernel related errors and exceptions."""
3 """Classes and functions for kernel related errors and exceptions."""
4 from __future__ import print_function
4 from __future__ import print_function
5
5
6 __docformat__ = "restructuredtext en"
6 __docformat__ = "restructuredtext en"
7
7
8 # Tell nose to skip this module
8 # Tell nose to skip this module
9 __test__ = {}
9 __test__ = {}
10
10
11 #-------------------------------------------------------------------------------
11 #-------------------------------------------------------------------------------
12 # Copyright (C) 2008 The IPython Development Team
12 # Copyright (C) 2008 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-------------------------------------------------------------------------------
16 #-------------------------------------------------------------------------------
17
17
18 #-------------------------------------------------------------------------------
18 #-------------------------------------------------------------------------------
19 # Error classes
19 # Error classes
20 #-------------------------------------------------------------------------------
20 #-------------------------------------------------------------------------------
21 class IPythonError(Exception):
21 class IPythonError(Exception):
22 """Base exception that all of our exceptions inherit from.
22 """Base exception that all of our exceptions inherit from.
23
23
24 This can be raised by code that doesn't have any more specific
24 This can be raised by code that doesn't have any more specific
25 information."""
25 information."""
26
26
27 pass
27 pass
28
28
29 # Exceptions associated with the controller objects
29 # Exceptions associated with the controller objects
30 class ControllerError(IPythonError): pass
30 class ControllerError(IPythonError): pass
31
31
32 class ControllerCreationError(ControllerError): pass
32 class ControllerCreationError(ControllerError): pass
33
33
34
34
35 # Exceptions associated with the Engines
35 # Exceptions associated with the Engines
36 class EngineError(IPythonError): pass
36 class EngineError(IPythonError): pass
37
37
38 class EngineCreationError(EngineError): pass
38 class EngineCreationError(EngineError): pass
39
39
40 class KernelError(IPythonError):
40 class KernelError(IPythonError):
41 pass
41 pass
42
42
43 class NotDefined(KernelError):
43 class NotDefined(KernelError):
44 def __init__(self, name):
44 def __init__(self, name):
45 self.name = name
45 self.name = name
46 self.args = (name,)
46 self.args = (name,)
47
47
48 def __repr__(self):
48 def __repr__(self):
49 return '<NotDefined: %s>' % self.name
49 return '<NotDefined: %s>' % self.name
50
50
51 __str__ = __repr__
51 __str__ = __repr__
52
52
53
53
54 class QueueCleared(KernelError):
54 class QueueCleared(KernelError):
55 pass
55 pass
56
56
57
57
58 class IdInUse(KernelError):
58 class IdInUse(KernelError):
59 pass
59 pass
60
60
61
61
62 class ProtocolError(KernelError):
62 class ProtocolError(KernelError):
63 pass
63 pass
64
64
65
65
66 class ConnectionError(KernelError):
66 class ConnectionError(KernelError):
67 pass
67 pass
68
68
69
69
70 class InvalidEngineID(KernelError):
70 class InvalidEngineID(KernelError):
71 pass
71 pass
72
72
73
73
74 class NoEnginesRegistered(KernelError):
74 class NoEnginesRegistered(KernelError):
75 pass
75 pass
76
76
77
77
78 class InvalidClientID(KernelError):
78 class InvalidClientID(KernelError):
79 pass
79 pass
80
80
81
81
82 class InvalidDeferredID(KernelError):
82 class InvalidDeferredID(KernelError):
83 pass
83 pass
84
84
85
85
86 class SerializationError(KernelError):
86 class SerializationError(KernelError):
87 pass
87 pass
88
88
89
89
90 class MessageSizeError(KernelError):
90 class MessageSizeError(KernelError):
91 pass
91 pass
92
92
93
93
94 class PBMessageSizeError(MessageSizeError):
94 class PBMessageSizeError(MessageSizeError):
95 pass
95 pass
96
96
97
97
98 class ResultNotCompleted(KernelError):
98 class ResultNotCompleted(KernelError):
99 pass
99 pass
100
100
101
101
102 class ResultAlreadyRetrieved(KernelError):
102 class ResultAlreadyRetrieved(KernelError):
103 pass
103 pass
104
104
105 class ClientError(KernelError):
105 class ClientError(KernelError):
106 pass
106 pass
107
107
108
108
109 class TaskAborted(KernelError):
109 class TaskAborted(KernelError):
110 pass
110 pass
111
111
112
112
113 class TaskTimeout(KernelError):
113 class TaskTimeout(KernelError):
114 pass
114 pass
115
115
116
116
117 class NotAPendingResult(KernelError):
117 class NotAPendingResult(KernelError):
118 pass
118 pass
119
119
120
120
121 class UnpickleableException(KernelError):
121 class UnpickleableException(KernelError):
122 pass
122 pass
123
123
124
124
125 class AbortedPendingDeferredError(KernelError):
125 class AbortedPendingDeferredError(KernelError):
126 pass
126 pass
127
127
128
128
129 class InvalidProperty(KernelError):
129 class InvalidProperty(KernelError):
130 pass
130 pass
131
131
132
132
133 class MissingBlockArgument(KernelError):
133 class MissingBlockArgument(KernelError):
134 pass
134 pass
135
135
136
136
137 class StopLocalExecution(KernelError):
137 class StopLocalExecution(KernelError):
138 pass
138 pass
139
139
140
140
141 class SecurityError(KernelError):
141 class SecurityError(KernelError):
142 pass
142 pass
143
143
144
144
145 class FileTimeoutError(KernelError):
145 class FileTimeoutError(KernelError):
146 pass
146 pass
147
147
148 class TimeoutError(KernelError):
148 class TimeoutError(KernelError):
149 pass
149 pass
150
150
151 class UnmetDependency(KernelError):
151 class UnmetDependency(KernelError):
152 pass
152 pass
153
153
154 class ImpossibleDependency(UnmetDependency):
154 class ImpossibleDependency(UnmetDependency):
155 pass
155 pass
156
156
157 class DependencyTimeout(ImpossibleDependency):
157 class DependencyTimeout(ImpossibleDependency):
158 pass
158 pass
159
159
160 class InvalidDependency(ImpossibleDependency):
160 class InvalidDependency(ImpossibleDependency):
161 pass
161 pass
162
162
163 class RemoteError(KernelError):
163 class RemoteError(KernelError):
164 """Error raised elsewhere"""
164 """Error raised elsewhere"""
165 ename=None
165 ename=None
166 evalue=None
166 evalue=None
167 traceback=None
167 traceback=None
168 engine_info=None
168 engine_info=None
169
169
170 def __init__(self, ename, evalue, traceback, engine_info=None):
170 def __init__(self, ename, evalue, traceback, engine_info=None):
171 self.ename=ename
171 self.ename=ename
172 self.evalue=evalue
172 self.evalue=evalue
173 self.traceback=traceback
173 self.traceback=traceback
174 self.engine_info=engine_info or {}
174 self.engine_info=engine_info or {}
175 self.args=(ename, evalue)
175 self.args=(ename, evalue)
176
176
177 def __repr__(self):
177 def __repr__(self):
178 engineid = self.engine_info.get('engineid', ' ')
178 engineid = self.engine_info.get('engineid', ' ')
179 return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)
179 return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)
180
180
181 def __str__(self):
181 def __str__(self):
182 sig = "%s(%s)"%(self.ename, self.evalue)
182 sig = "%s(%s)"%(self.ename, self.evalue)
183 if self.traceback:
183 if self.traceback:
184 return sig + '\n' + self.traceback
184 return sig + '\n' + self.traceback
185 else:
185 else:
186 return sig
186 return sig
187
187
188
188
189 class TaskRejectError(KernelError):
189 class TaskRejectError(KernelError):
190 """Exception to raise when a task should be rejected by an engine.
190 """Exception to raise when a task should be rejected by an engine.
191
191
192 This exception can be used to allow a task running on an engine to test
192 This exception can be used to allow a task running on an engine to test
193 if the engine (or the user's namespace on the engine) has the needed
193 if the engine (or the user's namespace on the engine) has the needed
194 task dependencies. If not, the task should raise this exception. For
194 task dependencies. If not, the task should raise this exception. For
195 the task to be retried on another engine, the task should be created
195 the task to be retried on another engine, the task should be created
196 with the `retries` argument > 1.
196 with the `retries` argument > 1.
197
197
198 The advantage of this approach over our older properties system is that
198 The advantage of this approach over our older properties system is that
199 tasks have full access to the user's namespace on the engines and the
199 tasks have full access to the user's namespace on the engines and the
200 properties don't have to be managed or tested by the controller.
200 properties don't have to be managed or tested by the controller.
201 """
201 """
202
202
203
203
204 class CompositeError(KernelError):
204 class CompositeError(RemoteError):
205 """Error for representing possibly multiple errors on engines"""
205 """Error for representing possibly multiple errors on engines"""
206 def __init__(self, message, elist):
206 def __init__(self, message, elist):
207 Exception.__init__(self, *(message, elist))
207 Exception.__init__(self, *(message, elist))
208 # Don't use pack_exception because it will conflict with the .message
208 # Don't use pack_exception because it will conflict with the .message
209 # attribute that is being deprecated in 2.6 and beyond.
209 # attribute that is being deprecated in 2.6 and beyond.
210 self.msg = message
210 self.msg = message
211 self.elist = elist
211 self.elist = elist
212 self.args = [ e[0] for e in elist ]
212 self.args = [ e[0] for e in elist ]
213
213
214 def _get_engine_str(self, ei):
214 def _get_engine_str(self, ei):
215 if not ei:
215 if not ei:
216 return '[Engine Exception]'
216 return '[Engine Exception]'
217 else:
217 else:
218 return '[%i:%s]: ' % (ei['engineid'], ei['method'])
218 return '[%s:%s]: ' % (ei['engineid'], ei['method'])
219
219
220 def _get_traceback(self, ev):
220 def _get_traceback(self, ev):
221 try:
221 try:
222 tb = ev._ipython_traceback_text
222 tb = ev._ipython_traceback_text
223 except AttributeError:
223 except AttributeError:
224 return 'No traceback available'
224 return 'No traceback available'
225 else:
225 else:
226 return tb
226 return tb
227
227
228 def __str__(self):
228 def __str__(self):
229 s = str(self.msg)
229 s = str(self.msg)
230 for en, ev, etb, ei in self.elist:
230 for en, ev, etb, ei in self.elist:
231 engine_str = self._get_engine_str(ei)
231 engine_str = self._get_engine_str(ei)
232 s = s + '\n' + engine_str + en + ': ' + str(ev)
232 s = s + '\n' + engine_str + en + ': ' + str(ev)
233 return s
233 return s
234
234
235 def __repr__(self):
235 def __repr__(self):
236 return "CompositeError(%i)"%len(self.elist)
236 return "CompositeError(%i)"%len(self.elist)
237
237
238 def print_tracebacks(self, excid=None):
238 def print_tracebacks(self, excid=None):
239 if excid is None:
239 if excid is None:
240 for (en,ev,etb,ei) in self.elist:
240 for (en,ev,etb,ei) in self.elist:
241 print (self._get_engine_str(ei))
241 print (self._get_engine_str(ei))
242 print (etb or 'No traceback available')
242 print (etb or 'No traceback available')
243 print ()
243 print ()
244 else:
244 else:
245 try:
245 try:
246 en,ev,etb,ei = self.elist[excid]
246 en,ev,etb,ei = self.elist[excid]
247 except:
247 except:
248 raise IndexError("an exception with index %i does not exist"%excid)
248 raise IndexError("an exception with index %i does not exist"%excid)
249 else:
249 else:
250 print (self._get_engine_str(ei))
250 print (self._get_engine_str(ei))
251 print (etb or 'No traceback available')
251 print (etb or 'No traceback available')
252
252
253 def raise_exception(self, excid=0):
253 def raise_exception(self, excid=0):
254 try:
254 try:
255 en,ev,etb,ei = self.elist[excid]
255 en,ev,etb,ei = self.elist[excid]
256 except:
256 except:
257 raise IndexError("an exception with index %i does not exist"%excid)
257 raise IndexError("an exception with index %i does not exist"%excid)
258 else:
258 else:
259 try:
259 raise RemoteError(en, ev, etb, ei)
260 raise RemoteError(en, ev, etb, ei)
261 except:
262 et,ev,tb = sys.exc_info()
263
260
264
261
265 def collect_exceptions(rdict_or_list, method='unspecified'):
262 def collect_exceptions(rdict_or_list, method='unspecified'):
266 """check a result dict for errors, and raise CompositeError if any exist.
263 """check a result dict for errors, and raise CompositeError if any exist.
267 Passthrough otherwise."""
264 Passthrough otherwise."""
268 elist = []
265 elist = []
269 if isinstance(rdict_or_list, dict):
266 if isinstance(rdict_or_list, dict):
270 rlist = rdict_or_list.values()
267 rlist = rdict_or_list.values()
271 else:
268 else:
272 rlist = rdict_or_list
269 rlist = rdict_or_list
273 for r in rlist:
270 for r in rlist:
274 if isinstance(r, RemoteError):
271 if isinstance(r, RemoteError):
275 en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
272 en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
276 # Sometimes we could have CompositeError in our list. Just take
273 # Sometimes we could have CompositeError in our list. Just take
277 # the errors out of them and put them in our new list. This
274 # the errors out of them and put them in our new list. This
278 # has the effect of flattening lists of CompositeErrors into one
275 # has the effect of flattening lists of CompositeErrors into one
279 # CompositeError
276 # CompositeError
280 if en=='CompositeError':
277 if en=='CompositeError':
281 for e in ev.elist:
278 for e in ev.elist:
282 elist.append(e)
279 elist.append(e)
283 else:
280 else:
284 elist.append((en, ev, etb, ei))
281 elist.append((en, ev, etb, ei))
285 if len(elist)==0:
282 if len(elist)==0:
286 return rdict_or_list
283 return rdict_or_list
287 else:
284 else:
288 msg = "one or more exceptions from call to method: %s" % (method)
285 msg = "one or more exceptions from call to method: %s" % (method)
289 # This silliness is needed so the debugger has access to the exception
286 # This silliness is needed so the debugger has access to the exception
290 # instance (e in this case)
287 # instance (e in this case)
291 try:
288 try:
292 raise CompositeError(msg, elist)
289 raise CompositeError(msg, elist)
293 except CompositeError, e:
290 except CompositeError as e:
294 raise e
291 raise e
295
292
@@ -1,85 +1,96 b''
1 import time
1 import time
2 from signal import SIGINT
2 from signal import SIGINT
3 from multiprocessing import Process
3 from multiprocessing import Process
4
4
5 from nose import SkipTest
5 from nose import SkipTest
6
6
7 from zmq.tests import BaseZMQTestCase
7 from zmq.tests import BaseZMQTestCase
8
8
9 from IPython.external.decorator import decorator
9 from IPython.external.decorator import decorator
10
10
11 from IPython.zmq.parallel import error
12 from IPython.zmq.parallel.client import Client
11 from IPython.zmq.parallel.ipcluster import launch_process
13 from IPython.zmq.parallel.ipcluster import launch_process
12 from IPython.zmq.parallel.entry_point import select_random_ports
14 from IPython.zmq.parallel.entry_point import select_random_ports
13 from IPython.zmq.parallel.client import Client
14 from IPython.zmq.parallel.tests import processes,add_engine
15 from IPython.zmq.parallel.tests import processes,add_engine
15
16
16 # simple tasks for use in apply tests
17 # simple tasks for use in apply tests
17
18
18 def segfault():
19 def segfault():
19 """"""
20 """"""
20 import ctypes
21 import ctypes
21 ctypes.memset(-1,0,1)
22 ctypes.memset(-1,0,1)
22
23
23 def wait(n):
24 def wait(n):
24 """sleep for a time"""
25 """sleep for a time"""
25 import time
26 import time
26 time.sleep(n)
27 time.sleep(n)
27 return n
28 return n
28
29
29 def raiser(eclass):
30 def raiser(eclass):
30 """raise an exception"""
31 """raise an exception"""
31 raise eclass()
32 raise eclass()
32
33
33 # test decorator for skipping tests when libraries are unavailable
34 # test decorator for skipping tests when libraries are unavailable
34 def skip_without(*names):
35 def skip_without(*names):
35 """skip a test if some names are not importable"""
36 """skip a test if some names are not importable"""
36 @decorator
37 @decorator
37 def skip_without_names(f, *args, **kwargs):
38 def skip_without_names(f, *args, **kwargs):
38 """decorator to skip tests in the absence of numpy."""
39 """decorator to skip tests in the absence of numpy."""
39 for name in names:
40 for name in names:
40 try:
41 try:
41 __import__(name)
42 __import__(name)
42 except ImportError:
43 except ImportError:
43 raise SkipTest
44 raise SkipTest
44 return f(*args, **kwargs)
45 return f(*args, **kwargs)
45 return skip_without_names
46 return skip_without_names
46
47
47
48
48 class ClusterTestCase(BaseZMQTestCase):
49 class ClusterTestCase(BaseZMQTestCase):
49
50
50 def add_engines(self, n=1, block=True):
51 def add_engines(self, n=1, block=True):
51 """add multiple engines to our cluster"""
52 """add multiple engines to our cluster"""
52 for i in range(n):
53 for i in range(n):
53 self.engines.append(add_engine())
54 self.engines.append(add_engine())
54 if block:
55 if block:
55 self.wait_on_engines()
56 self.wait_on_engines()
56
57
57 def wait_on_engines(self, timeout=5):
58 def wait_on_engines(self, timeout=5):
58 """wait for our engines to connect."""
59 """wait for our engines to connect."""
59 n = len(self.engines)+self.base_engine_count
60 n = len(self.engines)+self.base_engine_count
60 tic = time.time()
61 tic = time.time()
61 while time.time()-tic < timeout and len(self.client.ids) < n:
62 while time.time()-tic < timeout and len(self.client.ids) < n:
62 time.sleep(0.1)
63 time.sleep(0.1)
63
64
64 assert not self.client.ids < n, "waiting for engines timed out"
65 assert not self.client.ids < n, "waiting for engines timed out"
65
66
66 def connect_client(self):
67 def connect_client(self):
67 """connect a client with my Context, and track its sockets for cleanup"""
68 """connect a client with my Context, and track its sockets for cleanup"""
68 c = Client(profile='iptest',context=self.context)
69 c = Client(profile='iptest',context=self.context)
69 for name in filter(lambda n:n.endswith('socket'), dir(c)):
70 for name in filter(lambda n:n.endswith('socket'), dir(c)):
70 self.sockets.append(getattr(c, name))
71 self.sockets.append(getattr(c, name))
71 return c
72 return c
72
73
74 def assertRaisesRemote(self, etype, f, *args, **kwargs):
75 try:
76 f(*args, **kwargs)
77 except error.CompositeError as e:
78 e.raise_exception()
79 except error.RemoteError as e:
80 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(e.ename, etype.__name__))
81 else:
82 self.fail("should have raised a RemoteError")
83
73 def setUp(self):
84 def setUp(self):
74 BaseZMQTestCase.setUp(self)
85 BaseZMQTestCase.setUp(self)
75 self.client = self.connect_client()
86 self.client = self.connect_client()
76 self.base_engine_count=len(self.client.ids)
87 self.base_engine_count=len(self.client.ids)
77 self.engines=[]
88 self.engines=[]
78
89
79 def tearDown(self):
90 def tearDown(self):
80 [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ]
91 [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ]
81 # while len(self.client.ids) > self.base_engine_count:
92 # while len(self.client.ids) > self.base_engine_count:
82 # time.sleep(.1)
93 # time.sleep(.1)
83 del self.engines
94 del self.engines
84 BaseZMQTestCase.tearDown(self)
95 BaseZMQTestCase.tearDown(self)
85 No newline at end of file
96
@@ -1,117 +1,134 b''
1 import time
1 import time
2
2
3 import nose.tools as nt
3 import nose.tools as nt
4
4
5 from IPython.zmq.parallel.asyncresult import AsyncResult
5 from IPython.zmq.parallel.asyncresult import AsyncResult
6 from IPython.zmq.parallel.view import LoadBalancedView, DirectView
6 from IPython.zmq.parallel.view import LoadBalancedView, DirectView
7
7
8 from clienttest import ClusterTestCase, segfault
8 from clienttest import ClusterTestCase, segfault
9
9
10 class TestClient(ClusterTestCase):
10 class TestClient(ClusterTestCase):
11
11
12 def test_ids(self):
12 def test_ids(self):
13 self.assertEquals(len(self.client.ids), 1)
13 self.assertEquals(len(self.client.ids), 1)
14 self.add_engines(3)
14 self.add_engines(3)
15 self.assertEquals(len(self.client.ids), 4)
15 self.assertEquals(len(self.client.ids), 4)
16
16
17 def test_segfault(self):
17 def test_segfault(self):
18 self.add_engines(1)
18 self.add_engines(1)
19 eid = self.client.ids[-1]
19 eid = self.client.ids[-1]
20 self.client[eid].apply(segfault)
20 self.client[eid].apply(segfault)
21 while eid in self.client.ids:
21 while eid in self.client.ids:
22 time.sleep(.01)
22 time.sleep(.01)
23 self.client.spin()
23 self.client.spin()
24
24
25 def test_view_indexing(self):
25 def test_view_indexing(self):
26 self.add_engines(4)
26 self.add_engines(4)
27 targets = self.client._build_targets('all')[-1]
27 targets = self.client._build_targets('all')[-1]
28 v = self.client[:]
28 v = self.client[:]
29 self.assertEquals(v.targets, targets)
29 self.assertEquals(v.targets, targets)
30 t = self.client.ids[2]
30 t = self.client.ids[2]
31 v = self.client[t]
31 v = self.client[t]
32 self.assert_(isinstance(v, DirectView))
32 self.assert_(isinstance(v, DirectView))
33 self.assertEquals(v.targets, t)
33 self.assertEquals(v.targets, t)
34 t = self.client.ids[2:4]
34 t = self.client.ids[2:4]
35 v = self.client[t]
35 v = self.client[t]
36 self.assert_(isinstance(v, DirectView))
36 self.assert_(isinstance(v, DirectView))
37 self.assertEquals(v.targets, t)
37 self.assertEquals(v.targets, t)
38 v = self.client[::2]
38 v = self.client[::2]
39 self.assert_(isinstance(v, DirectView))
39 self.assert_(isinstance(v, DirectView))
40 self.assertEquals(v.targets, targets[::2])
40 self.assertEquals(v.targets, targets[::2])
41 v = self.client[1::3]
41 v = self.client[1::3]
42 self.assert_(isinstance(v, DirectView))
42 self.assert_(isinstance(v, DirectView))
43 self.assertEquals(v.targets, targets[1::3])
43 self.assertEquals(v.targets, targets[1::3])
44 v = self.client[:-3]
44 v = self.client[:-3]
45 self.assert_(isinstance(v, DirectView))
45 self.assert_(isinstance(v, DirectView))
46 self.assertEquals(v.targets, targets[:-3])
46 self.assertEquals(v.targets, targets[:-3])
47 v = self.client[-1]
48 self.assert_(isinstance(v, DirectView))
49 self.assertEquals(v.targets, targets[-1])
47 nt.assert_raises(TypeError, lambda : self.client[None])
50 nt.assert_raises(TypeError, lambda : self.client[None])
48
51
49 def test_view_cache(self):
52 def test_view_cache(self):
50 """test blocking and non-blocking behavior"""
53 """test that multiple view requests return the same object"""
51 v = self.client[:2]
54 v = self.client[:2]
52 v2 =self.client[:2]
55 v2 =self.client[:2]
53 self.assertTrue(v is v2)
56 self.assertTrue(v is v2)
54 v = self.client.view()
57 v = self.client.view()
55 v2 = self.client.view(balanced=True)
58 v2 = self.client.view(balanced=True)
56 self.assertTrue(v is v2)
59 self.assertTrue(v is v2)
57
60
58 def test_targets(self):
61 def test_targets(self):
59 """test various valid targets arguments"""
62 """test various valid targets arguments"""
60 pass
63 pass
61
64
62 def test_clear(self):
65 def test_clear(self):
63 """test clear behavior"""
66 """test clear behavior"""
64 # self.add_engines(4)
67 # self.add_engines(4)
65 # self.client.push()
68 # self.client.push()
66
69
67 def test_push_pull(self):
70 def test_push_pull(self):
71 """test pushing and pulling"""
68 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
72 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
69 self.add_engines(4)
73 self.add_engines(4)
70 push = self.client.push
74 push = self.client.push
71 pull = self.client.pull
75 pull = self.client.pull
72 self.client.block=True
76 self.client.block=True
73 nengines = len(self.client)
77 nengines = len(self.client)
74 push({'data':data}, targets=0)
78 push({'data':data}, targets=0)
75 d = pull('data', targets=0)
79 d = pull('data', targets=0)
76 self.assertEquals(d, data)
80 self.assertEquals(d, data)
77 push({'data':data})
81 push({'data':data})
78 d = pull('data')
82 d = pull('data')
79 self.assertEquals(d, nengines*[data])
83 self.assertEquals(d, nengines*[data])
80 ar = push({'data':data}, block=False)
84 ar = push({'data':data}, block=False)
81 self.assertTrue(isinstance(ar, AsyncResult))
85 self.assertTrue(isinstance(ar, AsyncResult))
82 r = ar.get()
86 r = ar.get()
83 ar = pull('data', block=False)
87 ar = pull('data', block=False)
84 self.assertTrue(isinstance(ar, AsyncResult))
88 self.assertTrue(isinstance(ar, AsyncResult))
85 r = ar.get()
89 r = ar.get()
86 self.assertEquals(r, nengines*[data])
90 self.assertEquals(r, nengines*[data])
87 push(dict(a=10,b=20))
91 push(dict(a=10,b=20))
88 r = pull(('a','b'))
92 r = pull(('a','b'))
89 self.assertEquals(r, nengines*[[10,20]])
93 self.assertEquals(r, nengines*[[10,20]])
90
94
91 def test_push_pull_function(self):
95 def test_push_pull_function(self):
96 "test pushing and pulling functions"
92 def testf(x):
97 def testf(x):
93 return 2.0*x
98 return 2.0*x
94
99
95 self.add_engines(4)
100 self.add_engines(4)
96 self.client.block=True
101 self.client.block=True
97 push = self.client.push
102 push = self.client.push
98 pull = self.client.pull
103 pull = self.client.pull
99 execute = self.client.execute
104 execute = self.client.execute
100 push({'testf':testf}, targets=0)
105 push({'testf':testf}, targets=0)
101 r = pull('testf', targets=0)
106 r = pull('testf', targets=0)
102 self.assertEqual(r(1.0), testf(1.0))
107 self.assertEqual(r(1.0), testf(1.0))
103 execute('r = testf(10)', targets=0)
108 execute('r = testf(10)', targets=0)
104 r = pull('r', targets=0)
109 r = pull('r', targets=0)
105 self.assertEquals(r, testf(10))
110 self.assertEquals(r, testf(10))
106 ar = push({'testf':testf}, block=False)
111 ar = push({'testf':testf}, block=False)
107 ar.get()
112 ar.get()
108 ar = pull('testf', block=False)
113 ar = pull('testf', block=False)
109 rlist = ar.get()
114 rlist = ar.get()
110 for r in rlist:
115 for r in rlist:
111 self.assertEqual(r(1.0), testf(1.0))
116 self.assertEqual(r(1.0), testf(1.0))
112 execute("def g(x): return x*x", targets=0)
117 execute("def g(x): return x*x", targets=0)
113 r = pull(('testf','g'),targets=0)
118 r = pull(('testf','g'),targets=0)
114 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
119 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
115
120
121 def test_push_function_globals(self):
122 """test that pushed functions have access to globals"""
123 def geta():
124 return a
125 self.add_engines(1)
126 v = self.client[-1]
127 v.block=True
128 v['f'] = geta
129 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
130 v.execute('a=5')
131 v.execute('b=f()')
132 self.assertEquals(v['b'], 5)
116
133
117 No newline at end of file
134
General Comments 0
You need to be logged in to leave comments. Login now