##// END OF EJS Templates
fix/test pushed function globals
MinRK -
Show More
@@ -1,127 +1,127 b''
1 1 # encoding: utf-8
2 2
3 3 """Pickle related utilities. Perhaps this should be called 'can'."""
4 4
5 5 __docformat__ = "restructuredtext en"
6 6
7 7 #-------------------------------------------------------------------------------
8 8 # Copyright (C) 2008 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-------------------------------------------------------------------------------
13 13
14 14 #-------------------------------------------------------------------------------
15 15 # Imports
16 16 #-------------------------------------------------------------------------------
17 17
18 18 from types import FunctionType
19 19 import copy
20 20
21 21 from IPython.zmq.parallel.dependency import dependent
22 22
23 23 import codeutil
24 24
25 25 #-------------------------------------------------------------------------------
26 26 # Classes
27 27 #-------------------------------------------------------------------------------
28 28
29 29
30 30 class CannedObject(object):
31 31 def __init__(self, obj, keys=[]):
32 32 self.keys = keys
33 33 self.obj = copy.copy(obj)
34 34 for key in keys:
35 35 setattr(self.obj, key, can(getattr(obj, key)))
36 36
37 37
38 38 def getObject(self, g=None):
39 39 if g is None:
40 40 g = globals()
41 41 for key in self.keys:
42 42 setattr(self.obj, key, uncan(getattr(self.obj, key), g))
43 43 return self.obj
44 44
45 45
46 46
47 47 class CannedFunction(CannedObject):
48 48
49 49 def __init__(self, f):
50 50 self._checkType(f)
51 51 self.code = f.func_code
52 52 self.__name__ = f.__name__
53 53
54 54 def _checkType(self, obj):
55 55 assert isinstance(obj, FunctionType), "Not a function type"
56 56
57 57 def getFunction(self, g=None):
58 58 if g is None:
59 59 g = globals()
60 60 newFunc = FunctionType(self.code, g)
61 61 return newFunc
62 62
63 63 #-------------------------------------------------------------------------------
64 64 # Functions
65 65 #-------------------------------------------------------------------------------
66 66
67 67
68 68 def can(obj):
69 69 if isinstance(obj, FunctionType):
70 70 return CannedFunction(obj)
71 71 elif isinstance(obj, dependent):
72 72 keys = ('f','df')
73 73 return CannedObject(obj, keys=keys)
74 74 elif isinstance(obj,dict):
75 75 return canDict(obj)
76 76 elif isinstance(obj, (list,tuple)):
77 77 return canSequence(obj)
78 78 else:
79 79 return obj
80 80
81 81 def canDict(obj):
82 82 if isinstance(obj, dict):
83 83 newobj = {}
84 84 for k, v in obj.iteritems():
85 85 newobj[k] = can(v)
86 86 return newobj
87 87 else:
88 88 return obj
89 89
90 90 def canSequence(obj):
91 91 if isinstance(obj, (list, tuple)):
92 92 t = type(obj)
93 93 return t([can(i) for i in obj])
94 94 else:
95 95 return obj
96 96
97 97 def uncan(obj, g=None):
98 98 if isinstance(obj, CannedFunction):
99 99 return obj.getFunction(g)
100 100 elif isinstance(obj, CannedObject):
101 101 return obj.getObject(g)
102 102 elif isinstance(obj,dict):
103 return uncanDict(obj)
103 return uncanDict(obj, g)
104 104 elif isinstance(obj, (list,tuple)):
105 return uncanSequence(obj)
105 return uncanSequence(obj, g)
106 106 else:
107 107 return obj
108 108
109 109 def uncanDict(obj, g=None):
110 110 if isinstance(obj, dict):
111 111 newobj = {}
112 112 for k, v in obj.iteritems():
113 113 newobj[k] = uncan(v,g)
114 114 return newobj
115 115 else:
116 116 return obj
117 117
118 118 def uncanSequence(obj, g=None):
119 119 if isinstance(obj, (list, tuple)):
120 120 t = type(obj)
121 121 return t([uncan(i,g) for i in obj])
122 122 else:
123 123 return obj
124 124
125 125
126 126 def rebindFunctionGlobals(f, glbls):
127 127 return FunctionType(f.func_code, glbls)
@@ -1,295 +1,292 b''
1 1 # encoding: utf-8
2 2
3 3 """Classes and functions for kernel related errors and exceptions."""
4 4 from __future__ import print_function
5 5
6 6 __docformat__ = "restructuredtext en"
7 7
8 8 # Tell nose to skip this module
9 9 __test__ = {}
10 10
11 11 #-------------------------------------------------------------------------------
12 12 # Copyright (C) 2008 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-------------------------------------------------------------------------------
17 17
18 18 #-------------------------------------------------------------------------------
19 19 # Error classes
20 20 #-------------------------------------------------------------------------------
21 21 class IPythonError(Exception):
22 22 """Base exception that all of our exceptions inherit from.
23 23
24 24 This can be raised by code that doesn't have any more specific
25 25 information."""
26 26
27 27 pass
28 28
29 29 # Exceptions associated with the controller objects
30 30 class ControllerError(IPythonError): pass
31 31
32 32 class ControllerCreationError(ControllerError): pass
33 33
34 34
35 35 # Exceptions associated with the Engines
36 36 class EngineError(IPythonError): pass
37 37
38 38 class EngineCreationError(EngineError): pass
39 39
40 40 class KernelError(IPythonError):
41 41 pass
42 42
43 43 class NotDefined(KernelError):
44 44 def __init__(self, name):
45 45 self.name = name
46 46 self.args = (name,)
47 47
48 48 def __repr__(self):
49 49 return '<NotDefined: %s>' % self.name
50 50
51 51 __str__ = __repr__
52 52
53 53
54 54 class QueueCleared(KernelError):
55 55 pass
56 56
57 57
58 58 class IdInUse(KernelError):
59 59 pass
60 60
61 61
62 62 class ProtocolError(KernelError):
63 63 pass
64 64
65 65
66 66 class ConnectionError(KernelError):
67 67 pass
68 68
69 69
70 70 class InvalidEngineID(KernelError):
71 71 pass
72 72
73 73
74 74 class NoEnginesRegistered(KernelError):
75 75 pass
76 76
77 77
78 78 class InvalidClientID(KernelError):
79 79 pass
80 80
81 81
82 82 class InvalidDeferredID(KernelError):
83 83 pass
84 84
85 85
86 86 class SerializationError(KernelError):
87 87 pass
88 88
89 89
90 90 class MessageSizeError(KernelError):
91 91 pass
92 92
93 93
94 94 class PBMessageSizeError(MessageSizeError):
95 95 pass
96 96
97 97
98 98 class ResultNotCompleted(KernelError):
99 99 pass
100 100
101 101
102 102 class ResultAlreadyRetrieved(KernelError):
103 103 pass
104 104
105 105 class ClientError(KernelError):
106 106 pass
107 107
108 108
109 109 class TaskAborted(KernelError):
110 110 pass
111 111
112 112
113 113 class TaskTimeout(KernelError):
114 114 pass
115 115
116 116
117 117 class NotAPendingResult(KernelError):
118 118 pass
119 119
120 120
121 121 class UnpickleableException(KernelError):
122 122 pass
123 123
124 124
125 125 class AbortedPendingDeferredError(KernelError):
126 126 pass
127 127
128 128
129 129 class InvalidProperty(KernelError):
130 130 pass
131 131
132 132
133 133 class MissingBlockArgument(KernelError):
134 134 pass
135 135
136 136
137 137 class StopLocalExecution(KernelError):
138 138 pass
139 139
140 140
141 141 class SecurityError(KernelError):
142 142 pass
143 143
144 144
145 145 class FileTimeoutError(KernelError):
146 146 pass
147 147
148 148 class TimeoutError(KernelError):
149 149 pass
150 150
151 151 class UnmetDependency(KernelError):
152 152 pass
153 153
154 154 class ImpossibleDependency(UnmetDependency):
155 155 pass
156 156
157 157 class DependencyTimeout(ImpossibleDependency):
158 158 pass
159 159
160 160 class InvalidDependency(ImpossibleDependency):
161 161 pass
162 162
163 163 class RemoteError(KernelError):
164 164 """Error raised elsewhere"""
165 165 ename=None
166 166 evalue=None
167 167 traceback=None
168 168 engine_info=None
169 169
170 170 def __init__(self, ename, evalue, traceback, engine_info=None):
171 171 self.ename=ename
172 172 self.evalue=evalue
173 173 self.traceback=traceback
174 174 self.engine_info=engine_info or {}
175 175 self.args=(ename, evalue)
176 176
177 177 def __repr__(self):
178 178 engineid = self.engine_info.get('engineid', ' ')
179 179 return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)
180 180
181 181 def __str__(self):
182 182 sig = "%s(%s)"%(self.ename, self.evalue)
183 183 if self.traceback:
184 184 return sig + '\n' + self.traceback
185 185 else:
186 186 return sig
187 187
188 188
189 189 class TaskRejectError(KernelError):
190 190 """Exception to raise when a task should be rejected by an engine.
191 191
192 192 This exception can be used to allow a task running on an engine to test
193 193 if the engine (or the user's namespace on the engine) has the needed
194 194 task dependencies. If not, the task should raise this exception. For
195 195 the task to be retried on another engine, the task should be created
196 196 with the `retries` argument > 1.
197 197
198 198 The advantage of this approach over our older properties system is that
199 199 tasks have full access to the user's namespace on the engines and the
200 200 properties don't have to be managed or tested by the controller.
201 201 """
202 202
203 203
204 class CompositeError(KernelError):
204 class CompositeError(RemoteError):
205 205 """Error for representing possibly multiple errors on engines"""
206 206 def __init__(self, message, elist):
207 207 Exception.__init__(self, *(message, elist))
208 208 # Don't use pack_exception because it will conflict with the .message
209 209 # attribute that is being deprecated in 2.6 and beyond.
210 210 self.msg = message
211 211 self.elist = elist
212 212 self.args = [ e[0] for e in elist ]
213 213
214 214 def _get_engine_str(self, ei):
215 215 if not ei:
216 216 return '[Engine Exception]'
217 217 else:
218 return '[%i:%s]: ' % (ei['engineid'], ei['method'])
218 return '[%s:%s]: ' % (ei['engineid'], ei['method'])
219 219
220 220 def _get_traceback(self, ev):
221 221 try:
222 222 tb = ev._ipython_traceback_text
223 223 except AttributeError:
224 224 return 'No traceback available'
225 225 else:
226 226 return tb
227 227
228 228 def __str__(self):
229 229 s = str(self.msg)
230 230 for en, ev, etb, ei in self.elist:
231 231 engine_str = self._get_engine_str(ei)
232 232 s = s + '\n' + engine_str + en + ': ' + str(ev)
233 233 return s
234 234
235 235 def __repr__(self):
236 236 return "CompositeError(%i)"%len(self.elist)
237 237
238 238 def print_tracebacks(self, excid=None):
239 239 if excid is None:
240 240 for (en,ev,etb,ei) in self.elist:
241 241 print (self._get_engine_str(ei))
242 242 print (etb or 'No traceback available')
243 243 print ()
244 244 else:
245 245 try:
246 246 en,ev,etb,ei = self.elist[excid]
247 247 except:
248 248 raise IndexError("an exception with index %i does not exist"%excid)
249 249 else:
250 250 print (self._get_engine_str(ei))
251 251 print (etb or 'No traceback available')
252 252
253 253 def raise_exception(self, excid=0):
254 254 try:
255 255 en,ev,etb,ei = self.elist[excid]
256 256 except:
257 257 raise IndexError("an exception with index %i does not exist"%excid)
258 258 else:
259 try:
260 raise RemoteError(en, ev, etb, ei)
261 except:
262 et,ev,tb = sys.exc_info()
259 raise RemoteError(en, ev, etb, ei)
263 260
264 261
265 262 def collect_exceptions(rdict_or_list, method='unspecified'):
266 263 """check a result dict for errors, and raise CompositeError if any exist.
267 264 Passthrough otherwise."""
268 265 elist = []
269 266 if isinstance(rdict_or_list, dict):
270 267 rlist = rdict_or_list.values()
271 268 else:
272 269 rlist = rdict_or_list
273 270 for r in rlist:
274 271 if isinstance(r, RemoteError):
275 272 en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
276 273 # Sometimes we could have CompositeError in our list. Just take
277 274 # the errors out of them and put them in our new list. This
278 275 # has the effect of flattening lists of CompositeErrors into one
279 276 # CompositeError
280 277 if en=='CompositeError':
281 278 for e in ev.elist:
282 279 elist.append(e)
283 280 else:
284 281 elist.append((en, ev, etb, ei))
285 282 if len(elist)==0:
286 283 return rdict_or_list
287 284 else:
288 285 msg = "one or more exceptions from call to method: %s" % (method)
289 286 # This silliness is needed so the debugger has access to the exception
290 287 # instance (e in this case)
291 288 try:
292 289 raise CompositeError(msg, elist)
293 except CompositeError, e:
290 except CompositeError as e:
294 291 raise e
295 292
@@ -1,85 +1,96 b''
1 1 import time
2 2 from signal import SIGINT
3 3 from multiprocessing import Process
4 4
5 5 from nose import SkipTest
6 6
7 7 from zmq.tests import BaseZMQTestCase
8 8
9 9 from IPython.external.decorator import decorator
10 10
11 from IPython.zmq.parallel import error
12 from IPython.zmq.parallel.client import Client
11 13 from IPython.zmq.parallel.ipcluster import launch_process
12 14 from IPython.zmq.parallel.entry_point import select_random_ports
13 from IPython.zmq.parallel.client import Client
14 15 from IPython.zmq.parallel.tests import processes,add_engine
15 16
16 17 # simple tasks for use in apply tests
17 18
18 19 def segfault():
19 20 """"""
20 21 import ctypes
21 22 ctypes.memset(-1,0,1)
22 23
23 24 def wait(n):
24 25 """sleep for a time"""
25 26 import time
26 27 time.sleep(n)
27 28 return n
28 29
29 30 def raiser(eclass):
30 31 """raise an exception"""
31 32 raise eclass()
32 33
33 34 # test decorator for skipping tests when libraries are unavailable
34 35 def skip_without(*names):
35 36 """skip a test if some names are not importable"""
36 37 @decorator
37 38 def skip_without_names(f, *args, **kwargs):
38 39 """decorator to skip tests in the absence of numpy."""
39 40 for name in names:
40 41 try:
41 42 __import__(name)
42 43 except ImportError:
43 44 raise SkipTest
44 45 return f(*args, **kwargs)
45 46 return skip_without_names
46 47
47 48
48 49 class ClusterTestCase(BaseZMQTestCase):
49 50
50 51 def add_engines(self, n=1, block=True):
51 52 """add multiple engines to our cluster"""
52 53 for i in range(n):
53 54 self.engines.append(add_engine())
54 55 if block:
55 56 self.wait_on_engines()
56 57
57 58 def wait_on_engines(self, timeout=5):
58 59 """wait for our engines to connect."""
59 60 n = len(self.engines)+self.base_engine_count
60 61 tic = time.time()
61 62 while time.time()-tic < timeout and len(self.client.ids) < n:
62 63 time.sleep(0.1)
63 64
64 65 assert not self.client.ids < n, "waiting for engines timed out"
65 66
66 67 def connect_client(self):
67 68 """connect a client with my Context, and track its sockets for cleanup"""
68 69 c = Client(profile='iptest',context=self.context)
69 70 for name in filter(lambda n:n.endswith('socket'), dir(c)):
70 71 self.sockets.append(getattr(c, name))
71 72 return c
72 73
74 def assertRaisesRemote(self, etype, f, *args, **kwargs):
75 try:
76 f(*args, **kwargs)
77 except error.CompositeError as e:
78 e.raise_exception()
79 except error.RemoteError as e:
80 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(e.ename, etype.__name__))
81 else:
82 self.fail("should have raised a RemoteError")
83
73 84 def setUp(self):
74 85 BaseZMQTestCase.setUp(self)
75 86 self.client = self.connect_client()
76 87 self.base_engine_count=len(self.client.ids)
77 88 self.engines=[]
78 89
79 90 def tearDown(self):
80 91 [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ]
81 92 # while len(self.client.ids) > self.base_engine_count:
82 93 # time.sleep(.1)
83 94 del self.engines
84 95 BaseZMQTestCase.tearDown(self)
85 96 No newline at end of file
@@ -1,117 +1,134 b''
1 1 import time
2 2
3 3 import nose.tools as nt
4 4
5 5 from IPython.zmq.parallel.asyncresult import AsyncResult
6 6 from IPython.zmq.parallel.view import LoadBalancedView, DirectView
7 7
8 8 from clienttest import ClusterTestCase, segfault
9 9
10 10 class TestClient(ClusterTestCase):
11 11
12 12 def test_ids(self):
13 13 self.assertEquals(len(self.client.ids), 1)
14 14 self.add_engines(3)
15 15 self.assertEquals(len(self.client.ids), 4)
16 16
17 17 def test_segfault(self):
18 18 self.add_engines(1)
19 19 eid = self.client.ids[-1]
20 20 self.client[eid].apply(segfault)
21 21 while eid in self.client.ids:
22 22 time.sleep(.01)
23 23 self.client.spin()
24 24
25 25 def test_view_indexing(self):
26 26 self.add_engines(4)
27 27 targets = self.client._build_targets('all')[-1]
28 28 v = self.client[:]
29 29 self.assertEquals(v.targets, targets)
30 30 t = self.client.ids[2]
31 31 v = self.client[t]
32 32 self.assert_(isinstance(v, DirectView))
33 33 self.assertEquals(v.targets, t)
34 34 t = self.client.ids[2:4]
35 35 v = self.client[t]
36 36 self.assert_(isinstance(v, DirectView))
37 37 self.assertEquals(v.targets, t)
38 38 v = self.client[::2]
39 39 self.assert_(isinstance(v, DirectView))
40 40 self.assertEquals(v.targets, targets[::2])
41 41 v = self.client[1::3]
42 42 self.assert_(isinstance(v, DirectView))
43 43 self.assertEquals(v.targets, targets[1::3])
44 44 v = self.client[:-3]
45 45 self.assert_(isinstance(v, DirectView))
46 46 self.assertEquals(v.targets, targets[:-3])
47 v = self.client[-1]
48 self.assert_(isinstance(v, DirectView))
49 self.assertEquals(v.targets, targets[-1])
47 50 nt.assert_raises(TypeError, lambda : self.client[None])
48 51
49 52 def test_view_cache(self):
50 """test blocking and non-blocking behavior"""
53 """test that multiple view requests return the same object"""
51 54 v = self.client[:2]
52 55 v2 =self.client[:2]
53 56 self.assertTrue(v is v2)
54 57 v = self.client.view()
55 58 v2 = self.client.view(balanced=True)
56 59 self.assertTrue(v is v2)
57 60
58 61 def test_targets(self):
59 62 """test various valid targets arguments"""
60 63 pass
61 64
62 65 def test_clear(self):
63 66 """test clear behavior"""
64 67 # self.add_engines(4)
65 68 # self.client.push()
66 69
67 70 def test_push_pull(self):
71 """test pushing and pulling"""
68 72 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
69 73 self.add_engines(4)
70 74 push = self.client.push
71 75 pull = self.client.pull
72 76 self.client.block=True
73 77 nengines = len(self.client)
74 78 push({'data':data}, targets=0)
75 79 d = pull('data', targets=0)
76 80 self.assertEquals(d, data)
77 81 push({'data':data})
78 82 d = pull('data')
79 83 self.assertEquals(d, nengines*[data])
80 84 ar = push({'data':data}, block=False)
81 85 self.assertTrue(isinstance(ar, AsyncResult))
82 86 r = ar.get()
83 87 ar = pull('data', block=False)
84 88 self.assertTrue(isinstance(ar, AsyncResult))
85 89 r = ar.get()
86 90 self.assertEquals(r, nengines*[data])
87 91 push(dict(a=10,b=20))
88 92 r = pull(('a','b'))
89 93 self.assertEquals(r, nengines*[[10,20]])
90 94
91 95 def test_push_pull_function(self):
96 "test pushing and pulling functions"
92 97 def testf(x):
93 98 return 2.0*x
94 99
95 100 self.add_engines(4)
96 101 self.client.block=True
97 102 push = self.client.push
98 103 pull = self.client.pull
99 104 execute = self.client.execute
100 105 push({'testf':testf}, targets=0)
101 106 r = pull('testf', targets=0)
102 107 self.assertEqual(r(1.0), testf(1.0))
103 108 execute('r = testf(10)', targets=0)
104 109 r = pull('r', targets=0)
105 110 self.assertEquals(r, testf(10))
106 111 ar = push({'testf':testf}, block=False)
107 112 ar.get()
108 113 ar = pull('testf', block=False)
109 114 rlist = ar.get()
110 115 for r in rlist:
111 116 self.assertEqual(r(1.0), testf(1.0))
112 117 execute("def g(x): return x*x", targets=0)
113 118 r = pull(('testf','g'),targets=0)
114 119 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
115
120
121 def test_push_function_globals(self):
122 """test that pushed functions have access to globals"""
123 def geta():
124 return a
125 self.add_engines(1)
126 v = self.client[-1]
127 v.block=True
128 v['f'] = geta
129 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
130 v.execute('a=5')
131 v.execute('b=f()')
132 self.assertEquals(v['b'], 5)
116 133
117 134 No newline at end of file
General Comments 0
You need to be logged in to leave comments. Login now