Improvements to dependency handling...
MinRK
@@ -1,115 +1,127 @@
1 1 # encoding: utf-8
2 2
3 3 """Pickle related utilities. Perhaps this should be called 'can'."""
4 4
5 5 __docformat__ = "restructuredtext en"
6 6
7 7 #-------------------------------------------------------------------------------
8 8 # Copyright (C) 2008 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-------------------------------------------------------------------------------
13 13
14 14 #-------------------------------------------------------------------------------
15 15 # Imports
16 16 #-------------------------------------------------------------------------------
17 17
18 18 from types import FunctionType
19 import copy
19 20
20 # contents of codeutil should either be in here, or codeutil belongs in IPython/util
21 21 from IPython.zmq.parallel.dependency import dependent
22
22 23 import codeutil
23 24
25 #-------------------------------------------------------------------------------
26 # Classes
27 #-------------------------------------------------------------------------------
28
29
24 30 class CannedObject(object):
25 31 def __init__(self, obj, keys=[]):
26 32 self.keys = keys
27 self.obj = obj
33 self.obj = copy.copy(obj)
28 34 for key in keys:
29 setattr(obj, key, can(getattr(obj, key)))
35 setattr(self.obj, key, can(getattr(obj, key)))
30 36
31 37
32 38 def getObject(self, g=None):
33 39 if g is None:
34 40 g = globals()
35 41 for key in self.keys:
36 42 setattr(self.obj, key, uncan(getattr(self.obj, key), g))
37 43 return self.obj
38 44
39 45
40 46
41 47 class CannedFunction(CannedObject):
42 48
43 49 def __init__(self, f):
44 50 self._checkType(f)
45 51 self.code = f.func_code
52 self.__name__ = f.__name__
46 53
47 54 def _checkType(self, obj):
48 55 assert isinstance(obj, FunctionType), "Not a function type"
49 56
50 57 def getFunction(self, g=None):
51 58 if g is None:
52 59 g = globals()
53 60 newFunc = FunctionType(self.code, g)
54 61 return newFunc
55 62
63 #-------------------------------------------------------------------------------
64 # Functions
65 #-------------------------------------------------------------------------------
66
67
56 68 def can(obj):
57 69 if isinstance(obj, FunctionType):
58 70 return CannedFunction(obj)
59 71 elif isinstance(obj, dependent):
60 72 keys = ('f','df')
61 73 return CannedObject(obj, keys=keys)
62 74 elif isinstance(obj,dict):
63 75 return canDict(obj)
64 76 elif isinstance(obj, (list,tuple)):
65 77 return canSequence(obj)
66 78 else:
67 79 return obj
68 80
69 81 def canDict(obj):
70 82 if isinstance(obj, dict):
71 83 newobj = {}
72 84 for k, v in obj.iteritems():
73 85 newobj[k] = can(v)
74 86 return newobj
75 87 else:
76 88 return obj
77 89
78 90 def canSequence(obj):
79 91 if isinstance(obj, (list, tuple)):
80 92 t = type(obj)
81 93 return t([can(i) for i in obj])
82 94 else:
83 95 return obj
84 96
85 97 def uncan(obj, g=None):
86 98 if isinstance(obj, CannedFunction):
87 99 return obj.getFunction(g)
88 100 elif isinstance(obj, CannedObject):
89 101 return obj.getObject(g)
90 102 elif isinstance(obj,dict):
91 103 return uncanDict(obj)
92 104 elif isinstance(obj, (list,tuple)):
93 105 return uncanSequence(obj)
94 106 else:
95 107 return obj
96 108
97 109 def uncanDict(obj, g=None):
98 110 if isinstance(obj, dict):
99 111 newobj = {}
100 112 for k, v in obj.iteritems():
101 113 newobj[k] = uncan(v,g)
102 114 return newobj
103 115 else:
104 116 return obj
105 117
106 118 def uncanSequence(obj, g=None):
107 119 if isinstance(obj, (list, tuple)):
108 120 t = type(obj)
109 121 return t([uncan(i,g) for i in obj])
110 122 else:
111 123 return obj
112 124
113 125
114 126 def rebindFunctionGlobals(f, glbls):
115 127 return FunctionType(f.func_code, glbls)
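
For orientation, a minimal sketch (not part of the commit) of the round-trip this module provides. `double` is a made-up function; `can`, `uncan`, and the sequence helpers are the ones defined above::

    from types import FunctionType

    def double(x):
        return 2 * x

    canned = can(double)                 # -> CannedFunction wrapping double.func_code
    restored = uncan(canned, globals())  # rebuilt against the supplied globals
    assert isinstance(restored, FunctionType)
    assert restored(21) == 42

    # containers are canned element-wise, preserving their type
    seq = canSequence([double, {'f': double}])
    back = uncanSequence(seq, globals())
    assert back[0](3) == 6 and back[1]['f'](3) == 6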
@@ -1,188 +1,200 @@
1 1 """AsyncResult objects for the client"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 from IPython.external.decorator import decorator
14 14 import error
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Classes
18 18 #-----------------------------------------------------------------------------
19 19
20 20 @decorator
21 21 def check_ready(f, self, *args, **kwargs):
22 22 """Call spin() to sync state prior to calling the method."""
23 23 self.wait(0)
24 24 if not self._ready:
25 25 raise error.TimeoutError("result not ready")
26 26 return f(self, *args, **kwargs)
27 27
28 28 class AsyncResult(object):
29 29 """Class for representing results of non-blocking calls.
30 30
31 31 Provides the same interface as :py:class:`multiprocessing.AsyncResult`.
32 32 """
33 33 def __init__(self, client, msg_ids, fname=''):
34 34 self._client = client
35 35 self.msg_ids = msg_ids
36 36 self._fname=fname
37 37 self._ready = False
38 38 self._success = None
39 self._flatten_result = len(msg_ids) == 1
39 40
40 41 def __repr__(self):
41 42 if self._ready:
42 43 return "<%s: finished>"%(self.__class__.__name__)
43 44 else:
44 45 return "<%s: %s>"%(self.__class__.__name__,self._fname)
45 46
46 47
47 48 def _reconstruct_result(self, res):
48 49 """
49 50 Override me in subclasses for turning a list of results
50 51 into the expected form.
51 52 """
52 if len(self.msg_ids) == 1:
53 if self._flatten_result:
53 54 return res[0]
54 55 else:
55 56 return res
56 57
57 58 def get(self, timeout=-1):
58 59 """Return the result when it arrives.
59 60
60 61 If `timeout` is not ``None`` and the result does not arrive within
61 62 `timeout` seconds then ``TimeoutError`` is raised. If the
62 63 remote call raised an exception then that exception will be reraised
63 64 by get().
64 65 """
65 66 if not self.ready():
66 67 self.wait(timeout)
67 68
68 69 if self._ready:
69 70 if self._success:
70 71 return self._result
71 72 else:
72 73 raise self._exception
73 74 else:
74 75 raise error.TimeoutError("Result not ready.")
75 76
76 77 def ready(self):
77 78 """Return whether the call has completed."""
78 79 if not self._ready:
79 80 self.wait(0)
80 81 return self._ready
81 82
82 83 def wait(self, timeout=-1):
83 84 """Wait until the result is available or until `timeout` seconds pass.
84 85 """
85 86 if self._ready:
86 87 return
87 88 self._ready = self._client.barrier(self.msg_ids, timeout)
88 89 if self._ready:
89 90 try:
90 91 results = map(self._client.results.get, self.msg_ids)
91 92 self._result = results
92 93 results = error.collect_exceptions(results, self._fname)
93 94 self._result = self._reconstruct_result(results)
94 95 except Exception, e:
95 96 self._exception = e
96 97 self._success = False
97 98 else:
98 99 self._success = True
99 100 finally:
100 101 self._metadata = map(self._client.metadata.get, self.msg_ids)
101 102
102 103
103 104 def successful(self):
104 105 """Return whether the call completed without raising an exception.
105 106
106 107 Will raise ``AssertionError`` if the result is not ready.
107 108 """
108 109 assert self._ready
109 110 return self._success
110 111
111 112 #----------------------------------------------------------------
112 113 # Extra methods not in mp.pool.AsyncResult
113 114 #----------------------------------------------------------------
114 115
115 116 def get_dict(self, timeout=-1):
116 117 """Get the results as a dict, keyed by engine_id."""
117 118 results = self.get(timeout)
118 engine_ids = [md['engine_id'] for md in self._metadata ]
119 engine_ids = [ md['engine_id'] for md in self._metadata ]
119 120 bycount = sorted(engine_ids, key=lambda k: engine_ids.count(k))
120 121 maxcount = bycount.count(bycount[-1])
121 122 if maxcount > 1:
122 123 raise ValueError("Cannot build dict, %i jobs ran on engine #%i"%(
123 124 maxcount, bycount[-1]))
124 125
125 126 return dict(zip(engine_ids,results))
126 127
127 128 @property
128 129 @check_ready
129 130 def result(self):
130 131 """result property."""
131 132 return self._result
132 133
134 # abbreviated alias:
135 r = result
136
133 137 @property
134 138 @check_ready
135 139 def metadata(self):
136 140 """metadata property."""
137 return self._metadata
141 if self._flatten_result:
142 return self._metadata[0]
143 else:
144 return self._metadata
138 145
139 146 @property
140 147 def result_dict(self):
141 148 """result property as a dict."""
142 149 return self.get_dict(0)
143 150
144 151 def __dict__(self):
145 152 return self.get_dict(0)
146 153
147 154 #-------------------------------------
148 155 # dict-access
149 156 #-------------------------------------
150 157
151 158 @check_ready
152 159 def __getitem__(self, key):
153 160 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
154 161 """
155 162 if isinstance(key, int):
156 163 return error.collect_exceptions([self._result[key]], self._fname)[0]
157 164 elif isinstance(key, slice):
158 165 return error.collect_exceptions(self._result[key], self._fname)
159 166 elif isinstance(key, basestring):
160 return [ md[key] for md in self._metadata ]
167 values = [ md[key] for md in self._metadata ]
168 if self._flatten_result:
169 return values[0]
170 else:
171 return values
161 172 else:
162 173 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
163 174
164 175 @check_ready
165 176 def __getattr__(self, key):
166 177 """getattr maps to getitem for convenient access to metadata."""
167 178 if key not in self._metadata[0].keys():
168 179 raise AttributeError("%r object has no attribute %r"%(
169 180 self.__class__.__name__, key))
170 181 return self.__getitem__(key)
171 182
172 183
173 184 class AsyncMapResult(AsyncResult):
174 185 """Class for representing results of non-blocking gathers.
175 186
176 187 This will properly reconstruct the gather.
177 188 """
178 189
179 190 def __init__(self, client, msg_ids, mapObject, fname=''):
180 self._mapObject = mapObject
181 191 AsyncResult.__init__(self, client, msg_ids, fname=fname)
192 self._mapObject = mapObject
193 self._flatten_result = False
182 194
183 195 def _reconstruct_result(self, res):
184 196 """Perform the gather on the actual results."""
185 197 return self._mapObject.joinPartitions(res)
186 198
187 199
188 200 __all__ = ['AsyncResult', 'AsyncMapResult'] No newline at end of file
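
A hedged sketch of the single-result flattening introduced above. `rc` stands for a connected `Client` (defined in the next file) and `double` is illustrative::

    def double(x):
        return 2 * x

    ar = rc.apply(double, (5,), targets=0, block=False)
    ar.wait()              # one msg_id -> _flatten_result is True
    print ar.get()         # 10, not [10]
    print ar.engine_id     # metadata via __getattr__, also flattened

    ar2 = rc.apply(double, (5,), targets=[0, 1], block=False)
    print ar2.get()        # two targets -> list of results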
@@ -1,1164 +1,1206 @@
1 1 """A semi-synchronous Client for the ZMQ controller"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import os
14 14 import time
15 15 from getpass import getpass
16 16 from pprint import pprint
17 17 from datetime import datetime
18 18
19 19 import zmq
20 20 from zmq.eventloop import ioloop, zmqstream
21 21
22 22 from IPython.external.decorator import decorator
23 23 from IPython.zmq import tunnel
24 24
25 25 import streamsession as ss
26 26 # from remotenamespace import RemoteNamespace
27 27 from view import DirectView, LoadBalancedView
28 28 from dependency import Dependency, depend, require
29 29 import error
30 30 import map as Map
31 31 from asyncresult import AsyncResult, AsyncMapResult
32 32 from remotefunction import remote,parallel,ParallelFunction,RemoteFunction
33 33 from util import ReverseDict
34 34
35 35 #--------------------------------------------------------------------------
36 36 # helpers for implementing old MEC API via client.apply
37 37 #--------------------------------------------------------------------------
38 38
39 39 def _push(ns):
40 40 """helper method for implementing `client.push` via `client.apply`"""
41 41 globals().update(ns)
42 42
43 43 def _pull(keys):
44 44 """helper method for implementing `client.pull` via `client.apply`"""
45 45 g = globals()
46 46 if isinstance(keys, (list,tuple, set)):
47 47 for key in keys:
48 48 if not g.has_key(key):
49 49 raise NameError("name '%s' is not defined"%key)
50 50 return map(g.get, keys)
51 51 else:
52 52 if not g.has_key(keys):
53 53 raise NameError("name '%s' is not defined"%keys)
54 54 return g.get(keys)
55 55
56 56 def _clear():
57 57 """helper method for implementing `client.clear` via `client.apply`"""
58 58 globals().clear()
59 59
60 60 def _execute(code):
61 61 """helper method for implementing `client.execute` via `client.apply`"""
62 62 exec code in globals()
63 63
64 64
65 65 #--------------------------------------------------------------------------
66 66 # Decorators for Client methods
67 67 #--------------------------------------------------------------------------
68 68
69 69 @decorator
70 70 def spinfirst(f, self, *args, **kwargs):
71 71 """Call spin() to sync state prior to calling the method."""
72 72 self.spin()
73 73 return f(self, *args, **kwargs)
74 74
75 75 @decorator
76 76 def defaultblock(f, self, *args, **kwargs):
77 77 """Default to self.block; preserve self.block."""
78 78 block = kwargs.get('block',None)
79 79 block = self.block if block is None else block
80 80 saveblock = self.block
81 81 self.block = block
82 82 try:
83 83 ret = f(self, *args, **kwargs)
84 84 finally:
85 85 self.block = saveblock
86 86 return ret
87 87
88 88
89 89 #--------------------------------------------------------------------------
90 90 # Classes
91 91 #--------------------------------------------------------------------------
92 92
93 93 class Metadata(dict):
94 """Subclass of dict for initializing metadata values."""
94 """Subclass of dict for initializing metadata values.
95
96 Attribute access works on keys.
97
98 These objects have a strict set of keys - errors will raise if you try
99 to add new keys.
100 """
95 101 def __init__(self, *args, **kwargs):
96 102 dict.__init__(self)
97 103 md = {'msg_id' : None,
98 104 'submitted' : None,
99 105 'started' : None,
100 106 'completed' : None,
101 107 'received' : None,
102 108 'engine_uuid' : None,
103 109 'engine_id' : None,
104 110 'follow' : None,
105 111 'after' : None,
106 112 'status' : None,
107 113
108 114 'pyin' : None,
109 115 'pyout' : None,
110 116 'pyerr' : None,
111 117 'stdout' : '',
112 118 'stderr' : '',
113 119 }
114 120 self.update(md)
115 121 self.update(dict(*args, **kwargs))
122
123 def __getattr__(self, key):
124 """getattr aliased to getitem"""
125 if key in self.iterkeys():
126 return self[key]
127 else:
128 raise AttributeError(key)
116 129
130 def __setattr__(self, key, value):
131 """setattr aliased to setitem, with strict"""
132 if key in self.iterkeys():
133 self[key] = value
134 else:
135 raise AttributeError(key)
136
137 def __setitem__(self, key, value):
138 """strict static key enforcement"""
139 if key in self.iterkeys():
140 dict.__setitem__(self, key, value)
141 else:
142 raise KeyError(key)
117 143
118 144
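
A quick sketch (not part of the commit) of the strict-key behavior the new `Metadata` class enforces::

    md = Metadata(engine_id=0, status='ok')
    print md.engine_id          # attribute access is aliased to item access -> 0
    md.status = 'resubmitted'   # setting a known key is allowed
    md['stdout'] += 'hello\n'   # so is updating one by item

    try:
        md['not_a_key'] = 1     # unknown keys raise KeyError...
    except KeyError:
        pass
    try:
        md.not_a_key            # ...and unknown attributes raise AttributeError
    except AttributeError:
        pass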
119 145 class Client(object):
120 146 """A semi-synchronous client to the IPython ZMQ controller
121 147
122 148 Parameters
123 149 ----------
124 150
125 151 addr : bytes; zmq url, e.g. 'tcp://127.0.0.1:10101'
126 152 The address of the controller's registration socket.
127 153 [Default: 'tcp://127.0.0.1:10101']
128 154 context : zmq.Context
129 155 Pass an existing zmq.Context instance, otherwise the client will create its own
130 156 username : bytes
131 157 set username to be passed to the Session object
132 158 debug : bool
133 159 flag for lots of message printing for debug purposes
134 160
135 161 #-------------- ssh related args ----------------
136 162 # These are args for configuring the ssh tunnel to be used
137 163 # credentials are used to forward connections over ssh to the Controller
138 164 # Note that the ip given in `addr` needs to be relative to sshserver
139 165 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
140 166 # and set sshserver as the same machine the Controller is on. However,
141 167 # the only requirement is that sshserver is able to see the Controller
142 168 # (i.e. is within the same trusted network).
143 169
144 170 sshserver : str
145 171 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
146 172 If keyfile or password is specified, and this is not, it will default to
147 173 the ip given in addr.
148 174 sshkey : str; path to public ssh key file
149 175 This specifies a key to be used in ssh login, default None.
150 176 Regular default ssh keys will be used without specifying this argument.
151 177 password : str;
152 178 Your ssh password to sshserver. Note that if this is left None,
153 179 you will be prompted for it if passwordless key based login is unavailable.
154 180
155 181 #------- exec authentication args -------
156 182 # If even localhost is untrusted, you can have some protection against
157 183 # unauthorized execution by using a key. Messages are still sent
158 184 # as cleartext, so if someone can snoop your loopback traffic this will
159 185 # not help anything.
160 186
161 187 exec_key : str
162 188 an authentication key or file containing a key
163 189 default: None
164 190
165 191
166 192 Attributes
167 193 ----------
168 194 ids : set of int engine IDs
169 195 requesting the ids attribute always synchronizes
170 196 the registration state. To request ids without synchronization,
171 197 use semi-private _ids attributes.
172 198
173 199 history : list of msg_ids
174 200 a list of msg_ids, keeping track of all the execution
175 201 messages you have submitted in order.
176 202
177 203 outstanding : set of msg_ids
178 204 a set of msg_ids that have been submitted, but whose
179 205 results have not yet been received.
180 206
181 207 results : dict
182 208 a dict of all our results, keyed by msg_id
183 209
184 210 block : bool
185 211 determines default behavior when block not specified
186 212 in execution methods
187 213
188 214 Methods
189 215 -------
190 216 spin : flushes incoming results and registration state changes
191 217 control methods spin, and requesting `ids` also ensures up to date
192 218
193 219 barrier : wait on one or more msg_ids
194 220
195 221 execution methods: apply/apply_bound/apply_to/apply_bound
196 222 legacy: execute, run
197 223
198 224 query methods: queue_status, get_result, purge
199 225
200 226 control methods: abort, kill
201 227
202 228 """
203 229
204 230
205 231 _connected=False
206 232 _ssh=False
207 233 _engines=None
208 234 _addr='tcp://127.0.0.1:10101'
209 235 _registration_socket=None
210 236 _query_socket=None
211 237 _control_socket=None
212 238 _iopub_socket=None
213 239 _notification_socket=None
214 240 _mux_socket=None
215 241 _task_socket=None
216 242 block = False
217 243 outstanding=None
218 244 results = None
219 245 history = None
220 246 debug = False
221 247 targets = None
222 248
223 249 def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False,
224 250 sshserver=None, sshkey=None, password=None, paramiko=None,
225 251 exec_key=None,):
226 252 if context is None:
227 253 context = zmq.Context()
228 254 self.context = context
229 255 self.targets = 'all'
230 256 self._addr = addr
231 257 self._ssh = bool(sshserver or sshkey or password)
232 258 if self._ssh and sshserver is None:
233 259 # default to the same
234 260 sshserver = addr.split('://')[1].split(':')[0]
235 261 if self._ssh and password is None:
236 262 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
237 263 password=False
238 264 else:
239 265 password = getpass("SSH Password for %s: "%sshserver)
240 266 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
241 267
242 268 if exec_key is not None and os.path.isfile(exec_key):
243 269 arg = 'keyfile'
244 270 else:
245 271 arg = 'key'
246 272 key_arg = {arg:exec_key}
247 273 if username is None:
248 274 self.session = ss.StreamSession(**key_arg)
249 275 else:
250 276 self.session = ss.StreamSession(username, **key_arg)
251 277 self._registration_socket = self.context.socket(zmq.XREQ)
252 278 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
253 279 if self._ssh:
254 280 tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
255 281 else:
256 282 self._registration_socket.connect(addr)
257 283 self._engines = ReverseDict()
258 284 self._ids = set()
259 285 self.outstanding=set()
260 286 self.results = {}
261 287 self.metadata = {}
262 288 self.history = []
263 289 self.debug = debug
264 290 self.session.debug = debug
265 291
266 292 self._notification_handlers = {'registration_notification' : self._register_engine,
267 293 'unregistration_notification' : self._unregister_engine,
268 294 }
269 295 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
270 296 'apply_reply' : self._handle_apply_reply}
271 297 self._connect(sshserver, ssh_kwargs)
272 298
273 299
274 300 @property
275 301 def ids(self):
276 302 """Always up to date ids property."""
277 303 self._flush_notifications()
278 304 return self._ids
279 305
280 306 def _update_engines(self, engines):
281 307 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
282 308 for k,v in engines.iteritems():
283 309 eid = int(k)
284 310 self._engines[eid] = bytes(v) # force not unicode
285 311 self._ids.add(eid)
286 312
287 313 def _build_targets(self, targets):
288 314 """Turn valid target IDs or 'all' into two lists:
289 315 (int_ids, uuids).
290 316 """
291 317 if targets is None:
292 318 targets = self._ids
293 319 elif isinstance(targets, str):
294 320 if targets.lower() == 'all':
295 321 targets = self._ids
296 322 else:
297 323 raise TypeError("%r not valid str target, must be 'all'"%(targets))
298 324 elif isinstance(targets, int):
299 325 targets = [targets]
300 326 return [self._engines[t] for t in targets], list(targets)
301 327
302 328 def _connect(self, sshserver, ssh_kwargs):
303 329 """setup all our socket connections to the controller. This is called from
304 330 __init__."""
305 331 if self._connected:
306 332 return
307 333 self._connected=True
308 334
309 335 def connect_socket(s, addr):
310 336 if self._ssh:
311 337 return tunnel.tunnel_connection(s, addr, sshserver, **ssh_kwargs)
312 338 else:
313 339 return s.connect(addr)
314 340
315 341 self.session.send(self._registration_socket, 'connection_request')
316 342 idents,msg = self.session.recv(self._registration_socket,mode=0)
317 343 if self.debug:
318 344 pprint(msg)
319 345 msg = ss.Message(msg)
320 346 content = msg.content
321 347 if content.status == 'ok':
322 348 if content.mux:
323 349 self._mux_socket = self.context.socket(zmq.PAIR)
324 350 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
325 351 connect_socket(self._mux_socket, content.mux)
326 352 if content.task:
327 353 self._task_socket = self.context.socket(zmq.PAIR)
328 354 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
329 355 connect_socket(self._task_socket, content.task)
330 356 if content.notification:
331 357 self._notification_socket = self.context.socket(zmq.SUB)
332 358 connect_socket(self._notification_socket, content.notification)
333 359 self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
334 360 if content.query:
335 361 self._query_socket = self.context.socket(zmq.PAIR)
336 362 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
337 363 connect_socket(self._query_socket, content.query)
338 364 if content.control:
339 365 self._control_socket = self.context.socket(zmq.PAIR)
340 366 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
341 367 connect_socket(self._control_socket, content.control)
342 368 if content.iopub:
343 369 self._iopub_socket = self.context.socket(zmq.SUB)
344 370 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, '')
345 371 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
346 372 connect_socket(self._iopub_socket, content.iopub)
347 373 self._update_engines(dict(content.engines))
348 374
349 375 else:
350 376 self._connected = False
351 377 raise Exception("Failed to connect!")
352 378
353 379 #--------------------------------------------------------------------------
354 380 # handlers and callbacks for incoming messages
355 381 #--------------------------------------------------------------------------
356 382
357 383 def _register_engine(self, msg):
358 384 """Register a new engine, and update our connection info."""
359 385 content = msg['content']
360 386 eid = content['id']
361 387 d = {eid : content['queue']}
362 388 self._update_engines(d)
363 389 self._ids.add(int(eid))
364 390
365 391 def _unregister_engine(self, msg):
366 392 """Unregister an engine that has died."""
367 393 content = msg['content']
368 394 eid = int(content['id'])
369 395 if eid in self._ids:
370 396 self._ids.remove(eid)
371 397 self._engines.pop(eid)
372 398
373 399 def _extract_metadata(self, header, parent, content):
374 400 md = {'msg_id' : parent['msg_id'],
375 'submitted' : datetime.strptime(parent['date'], ss.ISO8601),
376 'started' : datetime.strptime(header['started'], ss.ISO8601),
377 'completed' : datetime.strptime(header['date'], ss.ISO8601),
378 401 'received' : datetime.now(),
379 'engine_uuid' : header['engine'],
380 'engine_id' : self._engines.get(header['engine'], None),
402 'engine_uuid' : header.get('engine', None),
381 403 'follow' : parent['follow'],
382 404 'after' : parent['after'],
383 405 'status' : content['status'],
384 406 }
407
408 if md['engine_uuid'] is not None:
409 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
410
411 if 'date' in parent:
412 md['submitted'] = datetime.strptime(parent['date'], ss.ISO8601)
413 if 'started' in header:
414 md['started'] = datetime.strptime(header['started'], ss.ISO8601)
415 if 'date' in header:
416 md['completed'] = datetime.strptime(header['date'], ss.ISO8601)
385 417 return md
386 418
387 419 def _handle_execute_reply(self, msg):
388 420 """Save the reply to an execute_request into our results.
389 421
390 422 execute messages are never actually used. apply is used instead.
391 423 """
392 424
393 425 parent = msg['parent_header']
394 426 msg_id = parent['msg_id']
395 427 if msg_id not in self.outstanding:
396 print("got unknown result: %s"%msg_id)
428 if msg_id in self.history:
429 print ("got stale result: %s"%msg_id)
430 else:
431 print ("got unknown result: %s"%msg_id)
397 432 else:
398 433 self.outstanding.remove(msg_id)
399 434 self.results[msg_id] = ss.unwrap_exception(msg['content'])
400 435
401 436 def _handle_apply_reply(self, msg):
402 437 """Save the reply to an apply_request into our results."""
403 438 parent = msg['parent_header']
404 439 msg_id = parent['msg_id']
405 440 if msg_id not in self.outstanding:
406 print ("got unknown result: %s"%msg_id)
441 if msg_id in self.history:
442 print ("got stale result: %s"%msg_id)
443 print self.results[msg_id]
444 print msg
445 else:
446 print ("got unknown result: %s"%msg_id)
407 447 else:
408 448 self.outstanding.remove(msg_id)
409 449 content = msg['content']
410 450 header = msg['header']
411 451
412 452 # construct metadata:
413 453 md = self.metadata.setdefault(msg_id, Metadata())
414 454 md.update(self._extract_metadata(header, parent, content))
415 455 self.metadata[msg_id] = md
416 456
417 457 # construct result:
418 458 if content['status'] == 'ok':
419 459 self.results[msg_id] = ss.unserialize_object(msg['buffers'])[0]
420 460 elif content['status'] == 'aborted':
421 461 self.results[msg_id] = error.AbortedTask(msg_id)
422 462 elif content['status'] == 'resubmitted':
423 463 # TODO: handle resubmission
424 464 pass
425 465 else:
426 466 e = ss.unwrap_exception(content)
427 e_uuid = e.engine_info['engineid']
428 eid = self._engines[e_uuid]
429 e.engine_info['engineid'] = eid
467 if e.engine_info:
468 e_uuid = e.engine_info['engineid']
469 eid = self._engines[e_uuid]
470 e.engine_info['engineid'] = eid
430 471 self.results[msg_id] = e
431 472
432 473 def _flush_notifications(self):
433 474 """Flush notifications of engine registrations waiting
434 475 in ZMQ queue."""
435 476 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
436 477 while msg is not None:
437 478 if self.debug:
438 479 pprint(msg)
439 480 msg = msg[-1]
440 481 msg_type = msg['msg_type']
441 482 handler = self._notification_handlers.get(msg_type, None)
442 483 if handler is None:
443 484 raise Exception("Unhandled message type: %s"%msg.msg_type)
444 485 else:
445 486 handler(msg)
446 487 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
447 488
448 489 def _flush_results(self, sock):
449 490 """Flush task or queue results waiting in ZMQ queue."""
450 491 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
451 492 while msg is not None:
452 493 if self.debug:
453 494 pprint(msg)
454 495 msg = msg[-1]
455 496 msg_type = msg['msg_type']
456 497 handler = self._queue_handlers.get(msg_type, None)
457 498 if handler is None:
458 499 raise Exception("Unhandled message type: %s"%msg.msg_type)
459 500 else:
460 501 handler(msg)
461 502 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
462 503
463 504 def _flush_control(self, sock):
464 505 """Flush replies from the control channel waiting
465 506 in the ZMQ queue.
466 507
467 508 Currently: ignore them."""
468 509 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
469 510 while msg is not None:
470 511 if self.debug:
471 512 pprint(msg)
472 513 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
473 514
474 515 def _flush_iopub(self, sock):
475 516 """Flush replies from the iopub channel waiting
476 517 in the ZMQ queue.
477 518 """
478 519 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
479 520 while msg is not None:
480 521 if self.debug:
481 522 pprint(msg)
482 523 msg = msg[-1]
483 524 parent = msg['parent_header']
484 525 msg_id = parent['msg_id']
485 526 content = msg['content']
486 527 header = msg['header']
487 528 msg_type = msg['msg_type']
488 529
489 530 # init metadata:
490 531 md = self.metadata.setdefault(msg_id, Metadata())
491 532
492 533 if msg_type == 'stream':
493 534 name = content['name']
494 535 s = md[name] or ''
495 536 md[name] = s + content['data']
496 537 elif msg_type == 'pyerr':
497 538 md.update({'pyerr' : ss.unwrap_exception(content)})
498 539 else:
499 540 md.update({msg_type : content['data']})
500 541
501 542 self.metadata[msg_id] = md
502 543
503 544 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
504 545
505 546 #--------------------------------------------------------------------------
506 547 # getitem
507 548 #--------------------------------------------------------------------------
508 549
509 550 def __getitem__(self, key):
510 551 """Dict access returns DirectView multiplexer objects or,
511 552 if key is None, a LoadBalancedView."""
512 553 if key is None:
513 554 return LoadBalancedView(self)
514 555 if isinstance(key, int):
515 556 if key not in self.ids:
516 557 raise IndexError("No such engine: %i"%key)
517 558 return DirectView(self, key)
518 559
519 560 if isinstance(key, slice):
520 561 indices = range(len(self.ids))[key]
521 562 ids = sorted(self._ids)
522 563 key = [ ids[i] for i in indices ]
523 564 # newkeys = sorted(self._ids)[thekeys[k]]
524 565
525 566 if isinstance(key, (tuple, list, xrange)):
526 567 _,targets = self._build_targets(list(key))
527 568 return DirectView(self, targets)
528 569 else:
529 570 raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
530 571
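
The getitem protocol above admits several key types; assuming a connected client `rc`, the views it hands back look like::

    rc[None]       # LoadBalancedView over the task queue
    rc[0]          # DirectView on engine 0
    rc[::2]        # DirectView on every other engine
    rc[[0, 2]]     # DirectView on engines 0 and 2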
531 572 #--------------------------------------------------------------------------
532 573 # Begin public methods
533 574 #--------------------------------------------------------------------------
534 575
535 576 @property
536 577 def remote(self):
537 578 """property for convenient RemoteFunction generation.
538 579
539 580 >>> @client.remote
540 581 ... def f():
541 582 import os
542 583 print (os.getpid())
543 584 """
544 585 return remote(self, block=self.block)
545 586
546 587 def spin(self):
547 588 """Flush any registration notifications and execution results
548 589 waiting in the ZMQ queue.
549 590 """
550 591 if self._notification_socket:
551 592 self._flush_notifications()
552 593 if self._mux_socket:
553 594 self._flush_results(self._mux_socket)
554 595 if self._task_socket:
555 596 self._flush_results(self._task_socket)
556 597 if self._control_socket:
557 598 self._flush_control(self._control_socket)
558 599 if self._iopub_socket:
559 600 self._flush_iopub(self._iopub_socket)
560 601
561 602 def barrier(self, msg_ids=None, timeout=-1):
562 603 """waits on one or more `msg_ids`, for up to `timeout` seconds.
563 604
564 605 Parameters
565 606 ----------
566 607 msg_ids : int, str, or list of ints and/or strs, or one or more AsyncResult objects
567 608 ints are indices to self.history
568 609 strs are msg_ids
569 610 default: wait on all outstanding messages
570 611 timeout : float
571 612 a time in seconds, after which to give up.
572 613 default is -1, which means no timeout
573 614
574 615 Returns
575 616 -------
576 617 True : when all msg_ids are done
577 618 False : timeout reached, some msg_ids still outstanding
578 619 """
579 620 tic = time.time()
580 621 if msg_ids is None:
581 622 theids = self.outstanding
582 623 else:
583 624 if isinstance(msg_ids, (int, str, AsyncResult)):
584 625 msg_ids = [msg_ids]
585 626 theids = set()
586 627 for msg_id in msg_ids:
587 628 if isinstance(msg_id, int):
588 629 msg_id = self.history[msg_id]
589 630 elif isinstance(msg_id, AsyncResult):
590 631 map(theids.add, msg_id.msg_ids)
591 632 continue
592 633 theids.add(msg_id)
593 634 if not theids.intersection(self.outstanding):
594 635 return True
595 636 self.spin()
596 637 while theids.intersection(self.outstanding):
597 638 if timeout >= 0 and ( time.time()-tic ) > timeout:
598 639 break
599 640 time.sleep(1e-3)
600 641 self.spin()
601 642 return len(theids.intersection(self.outstanding)) == 0
602 643
603 644 #--------------------------------------------------------------------------
604 645 # Control methods
605 646 #--------------------------------------------------------------------------
606 647
607 648 @spinfirst
608 649 @defaultblock
609 650 def clear(self, targets=None, block=None):
610 651 """Clear the namespace in target(s)."""
611 652 targets = self._build_targets(targets)[0]
612 653 for t in targets:
613 654 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
614 655 error = False
615 656 if self.block:
616 657 for i in range(len(targets)):
617 658 idents,msg = self.session.recv(self._control_socket,0)
618 659 if self.debug:
619 660 pprint(msg)
620 661 if msg['content']['status'] != 'ok':
621 662 error = ss.unwrap_exception(msg['content'])
622 663 if error:
623 664 return error
624 665
625 666
626 667 @spinfirst
627 668 @defaultblock
628 669 def abort(self, msg_ids = None, targets=None, block=None):
629 670 """Abort the execution queues of target(s)."""
630 671 targets = self._build_targets(targets)[0]
631 672 if isinstance(msg_ids, basestring):
632 673 msg_ids = [msg_ids]
633 674 content = dict(msg_ids=msg_ids)
634 675 for t in targets:
635 676 self.session.send(self._control_socket, 'abort_request',
636 677 content=content, ident=t)
637 678 error = False
638 679 if self.block:
639 680 for i in range(len(targets)):
640 681 idents,msg = self.session.recv(self._control_socket,0)
641 682 if self.debug:
642 683 pprint(msg)
643 684 if msg['content']['status'] != 'ok':
644 685 error = ss.unwrap_exception(msg['content'])
645 686 if error:
646 687 return error
647 688
648 689 @spinfirst
649 690 @defaultblock
650 691 def shutdown(self, targets=None, restart=False, controller=False, block=None):
651 692 """Terminates one or more engine processes, optionally including the controller."""
652 693 if controller:
653 694 targets = 'all'
654 695 targets = self._build_targets(targets)[0]
655 696 for t in targets:
656 697 self.session.send(self._control_socket, 'shutdown_request',
657 698 content={'restart':restart},ident=t)
658 699 error = False
659 700 if block or controller:
660 701 for i in range(len(targets)):
661 702 idents,msg = self.session.recv(self._control_socket,0)
662 703 if self.debug:
663 704 pprint(msg)
664 705 if msg['content']['status'] != 'ok':
665 706 error = ss.unwrap_exception(msg['content'])
666 707
667 708 if controller:
668 709 time.sleep(0.25)
669 710 self.session.send(self._query_socket, 'shutdown_request')
670 711 idents,msg = self.session.recv(self._query_socket, 0)
671 712 if self.debug:
672 713 pprint(msg)
673 714 if msg['content']['status'] != 'ok':
674 715 error = ss.unwrap_exception(msg['content'])
675 716
676 717 if error:
677 718 raise error
678 719
679 720 #--------------------------------------------------------------------------
680 721 # Execution methods
681 722 #--------------------------------------------------------------------------
682 723
683 724 @defaultblock
684 725 def execute(self, code, targets='all', block=None):
685 726 """Executes `code` on `targets` in blocking or nonblocking manner.
686 727
687 728 ``execute`` is always `bound` (affects engine namespace)
688 729
689 730 Parameters
690 731 ----------
691 732 code : str
692 733 the code string to be executed
693 734 targets : int/str/list of ints/strs
694 735 the engines on which to execute
695 736 default : all
696 737 block : bool
697 738 whether or not to wait until done to return
698 739 default: self.block
699 740 """
700 741 result = self.apply(_execute, (code,), targets=targets, block=self.block, bound=True)
701 742 return result
702 743
703 744 def run(self, code, block=None):
704 745 """Runs `code` on an engine.
705 746
706 747 Calls to this are load-balanced.
707 748
708 749 ``run`` is never `bound` (no effect on engine namespace)
709 750
710 751 Parameters
711 752 ----------
712 753 code : str
713 754 the code string to be executed
714 755 block : bool
715 756 whether or not to wait until done
716 757
717 758 """
718 759 result = self.apply(_execute, (code,), targets=None, block=block, bound=False)
719 760 return result
720 761
721 762 def _maybe_raise(self, result):
722 763 """wrapper for maybe raising an exception if apply failed."""
723 764 if isinstance(result, error.RemoteError):
724 765 raise result
725 766
726 767 return result
727 768
728 769 def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
729 770 after=None, follow=None):
730 771 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
731 772
732 773 This is the central execution command for the client.
733 774
734 775 Parameters
735 776 ----------
736 777
737 778 f : function
738 779         The function to be called remotely
739 780 args : tuple/list
740 781 The positional arguments passed to `f`
741 782 kwargs : dict
742 783 The keyword arguments passed to `f`
743 784 bound : bool (default: True)
744 785 Whether to execute in the Engine(s) namespace, or in a clean
745 786 namespace not affecting the engine.
746 787 block : bool (default: self.block)
747 788 Whether to wait for the result, or return immediately.
748 789 False:
749 790 returns msg_id(s)
750 791 if multiple targets:
751 792 list of ids
752 793 True:
753 794 returns actual result(s) of f(*args, **kwargs)
754 795 if multiple targets:
755 796 dict of results, by engine ID
756 797 targets : int,list of ints, 'all', None
757 798 Specify the destination of the job.
758 799 if None:
759 800 Submit via Task queue for load-balancing.
760 801 if 'all':
761 802 Run on all active engines
762 803 if list:
763 804 Run on each specified engine
764 805 if int:
765 806 Run on single engine
766 807
767 808 after : Dependency or collection of msg_ids
768 809 Only for load-balanced execution (targets=None)
769 810 Specify a list of msg_ids as a time-based dependency.
770 811 This job will only be run *after* the dependencies
771 812 have been met.
772 813
773 814 follow : Dependency or collection of msg_ids
774 815 Only for load-balanced execution (targets=None)
775 816 Specify a list of msg_ids as a location-based dependency.
776 817 This job will only be run on an engine where this dependency
777 818 is met.
778 819
779 820 Returns
780 821 -------
781 822 if block is False:
782 823 if single target:
783 824 return msg_id
784 825 else:
785 826 return list of msg_ids
786 827 ? (should this be dict like block=True) ?
787 828 else:
788 829 if single target:
789 830 return result of f(*args, **kwargs)
790 831 else:
791 832 return dict of results, keyed by engine
792 833 """
793 834
794 835 # defaults:
795 836 block = block if block is not None else self.block
796 837 args = args if args is not None else []
797 838 kwargs = kwargs if kwargs is not None else {}
798 839
799 840         # enforce types of f,args,kwargs
800 841 if not callable(f):
801 842 raise TypeError("f must be callable, not %s"%type(f))
802 843 if not isinstance(args, (tuple, list)):
803 844 raise TypeError("args must be tuple or list, not %s"%type(args))
804 845 if not isinstance(kwargs, dict):
805 846 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
806 847
807 848 if isinstance(after, Dependency):
808 849 after = after.as_dict()
809 850 elif isinstance(after, AsyncResult):
810 851 after=after.msg_ids
811 852 elif after is None:
812 853 after = []
813 854 if isinstance(follow, Dependency):
855 # if len(follow) > 1 and follow.mode == 'all':
856 # warn("complex follow-dependencies are not rigorously tested for reachability", UserWarning)
814 857 follow = follow.as_dict()
815 858 elif isinstance(follow, AsyncResult):
816 859 follow=follow.msg_ids
817 860 elif follow is None:
818 861 follow = []
819 862 options = dict(bound=bound, block=block, after=after, follow=follow)
820 863
821 864 if targets is None:
822 865 return self._apply_balanced(f, args, kwargs, **options)
823 866 else:
824 867 return self._apply_direct(f, args, kwargs, targets=targets, **options)
825 868
826 869 def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
827 870 after=None, follow=None):
828 871 """The underlying method for applying functions in a load balanced
829 872 manner, via the task queue."""
830
831 873 subheader = dict(after=after, follow=follow)
832 874 bufs = ss.pack_apply_message(f,args,kwargs)
833 875 content = dict(bound=bound)
834 876
835 877 msg = self.session.send(self._task_socket, "apply_request",
836 878 content=content, buffers=bufs, subheader=subheader)
837 879 msg_id = msg['msg_id']
838 880 self.outstanding.add(msg_id)
839 881 self.history.append(msg_id)
840 882 ar = AsyncResult(self, [msg_id], fname=f.__name__)
841 883 if block:
842 884 return ar.get()
843 885 else:
844 886 return ar
845 887
846 888 def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
847 889 after=None, follow=None):
848 890         """The underlying method for applying functions to specific engines
849 891 via the MUX queue."""
850 892
851 893 queues,targets = self._build_targets(targets)
852 894
853 895 subheader = dict(after=after, follow=follow)
854 896 content = dict(bound=bound)
855 897 bufs = ss.pack_apply_message(f,args,kwargs)
856 898
857 899 msg_ids = []
858 900 for queue in queues:
859 901 msg = self.session.send(self._mux_socket, "apply_request",
860 902 content=content, buffers=bufs,ident=queue, subheader=subheader)
861 903 msg_id = msg['msg_id']
862 904 self.outstanding.add(msg_id)
863 905 self.history.append(msg_id)
864 906 msg_ids.append(msg_id)
865 907 ar = AsyncResult(self, msg_ids, fname=f.__name__)
866 908 if block:
867 909 return ar.get()
868 910 else:
869 911 return ar
870 912
871 913 #--------------------------------------------------------------------------
872 914 # Map and decorators
873 915 #--------------------------------------------------------------------------
874 916
875 917 def map(self, f, *sequences):
876 918 """Parallel version of builtin `map`, using all our engines."""
877 919 pf = ParallelFunction(self, f, block=self.block,
878 920 bound=True, targets='all')
879 921 return pf.map(*sequences)
880 922
881 923 def parallel(self, bound=True, targets='all', block=True):
882 924 """Decorator for making a ParallelFunction."""
883 925 return parallel(self, bound=bound, targets=targets, block=block)
884 926
885 927 def remote(self, bound=True, targets='all', block=True):
886 928 """Decorator for making a RemoteFunction."""
887 929 return remote(self, bound=bound, targets=targets, block=block)
888 930
889 931 #--------------------------------------------------------------------------
890 932 # Data movement
891 933 #--------------------------------------------------------------------------
892 934
893 935 @defaultblock
894 936 def push(self, ns, targets='all', block=None):
895 937 """Push the contents of `ns` into the namespace on `target`"""
896 938 if not isinstance(ns, dict):
897 939 raise TypeError("Must be a dict, not %s"%type(ns))
898 940 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True)
899 941 return result
900 942
901 943 @defaultblock
902 944 def pull(self, keys, targets='all', block=None):
903 945 """Pull objects from `target`'s namespace by `keys`"""
904 946 if isinstance(keys, str):
905 947 pass
906 948 elif isinstance(keys, (list,tuple,set)):
907 949 for key in keys:
908 950 if not isinstance(key, str):
909 951 raise TypeError
910 952 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
911 953 return result
912 954
913 955 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
914 956 """
915 957 Partition a Python sequence and send the partitions to a set of engines.
916 958 """
917 959 block = block if block is not None else self.block
918 960 targets = self._build_targets(targets)[-1]
919 961 mapObject = Map.dists[dist]()
920 962 nparts = len(targets)
921 963 msg_ids = []
922 964 for index, engineid in enumerate(targets):
923 965 partition = mapObject.getPartition(seq, index, nparts)
924 966 if flatten and len(partition) == 1:
925 967 r = self.push({key: partition[0]}, targets=engineid, block=False)
926 968 else:
927 969 r = self.push({key: partition}, targets=engineid, block=False)
928 970 msg_ids.extend(r.msg_ids)
929 971 r = AsyncResult(self, msg_ids, fname='scatter')
930 972 if block:
931 973 return r.get()
932 974 else:
933 975 return r
934 976
935 977 def gather(self, key, dist='b', targets='all', block=None):
936 978 """
937 979 Gather a partitioned sequence on a set of engines as a single local seq.
938 980 """
939 981 block = block if block is not None else self.block
940 982
941 983 targets = self._build_targets(targets)[-1]
942 984 mapObject = Map.dists[dist]()
943 985 msg_ids = []
944 986 for index, engineid in enumerate(targets):
945 987 msg_ids.extend(self.pull(key, targets=engineid,block=False).msg_ids)
946 988
947 989 r = AsyncMapResult(self, msg_ids, mapObject, fname='gather')
948 990 if block:
949 991 return r.get()
950 992 else:
951 993 return r
952 994
953 995 #--------------------------------------------------------------------------
954 996 # Query methods
955 997 #--------------------------------------------------------------------------
956 998
957 999 @spinfirst
958 1000 def get_results(self, msg_ids, status_only=False):
959 1001 """Returns the result of the execute or task request with `msg_ids`.
960 1002
961 1003 Parameters
962 1004 ----------
963 1005 msg_ids : list of ints or msg_ids
964 1006 if int:
965 1007 Passed as index to self.history for convenience.
966 1008 status_only : bool (default: False)
967 1009 if False:
968 1010 return the actual results
969 1011
970 1012 Returns
971 1013 -------
972 1014
973 1015 results : dict
974 1016 There will always be the keys 'pending' and 'completed', which will
975 1017 be lists of msg_ids.
976 1018 """
977 1019 if not isinstance(msg_ids, (list,tuple)):
978 1020 msg_ids = [msg_ids]
979 1021 theids = []
980 1022 for msg_id in msg_ids:
981 1023 if isinstance(msg_id, int):
982 1024 msg_id = self.history[msg_id]
983 1025 if not isinstance(msg_id, str):
984 1026 raise TypeError("msg_ids must be str, not %r"%msg_id)
985 1027 theids.append(msg_id)
986 1028
987 1029 completed = []
988 1030 local_results = {}
989 1031 # temporarily disable local shortcut
990 1032 # for msg_id in list(theids):
991 1033 # if msg_id in self.results:
992 1034 # completed.append(msg_id)
993 1035 # local_results[msg_id] = self.results[msg_id]
994 1036 # theids.remove(msg_id)
995 1037
996 1038 if theids: # some not locally cached
997 1039 content = dict(msg_ids=theids, status_only=status_only)
998 1040 msg = self.session.send(self._query_socket, "result_request", content=content)
999 1041 zmq.select([self._query_socket], [], [])
1000 1042 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1001 1043 if self.debug:
1002 1044 pprint(msg)
1003 1045 content = msg['content']
1004 1046 if content['status'] != 'ok':
1005 1047 raise ss.unwrap_exception(content)
1006 1048 buffers = msg['buffers']
1007 1049 else:
1008 1050 content = dict(completed=[],pending=[])
1009 1051
1010 1052 content['completed'].extend(completed)
1011 1053
1012 1054 if status_only:
1013 1055 return content
1014 1056
1015 1057 failures = []
1016 1058 # load cached results into result:
1017 1059 content.update(local_results)
1018 1060 # update cache with results:
1019 1061 for msg_id in sorted(theids):
1020 1062 if msg_id in content['completed']:
1021 1063 rec = content[msg_id]
1022 1064 parent = rec['header']
1023 1065 header = rec['result_header']
1024 1066 rcontent = rec['result_content']
1025 1067 iodict = rec['io']
1026 1068 if isinstance(rcontent, str):
1027 1069 rcontent = self.session.unpack(rcontent)
1028 1070
1029 1071 md = self.metadata.setdefault(msg_id, Metadata())
1030 1072 md.update(self._extract_metadata(header, parent, rcontent))
1031 1073 md.update(iodict)
1032 1074
1033 1075 if rcontent['status'] == 'ok':
1034 1076 res,buffers = ss.unserialize_object(buffers)
1035 1077 else:
1036 1078 res = ss.unwrap_exception(rcontent)
1037 1079 failures.append(res)
1038 1080
1039 1081 self.results[msg_id] = res
1040 1082 content[msg_id] = res
1041 1083
1042 1084 error.collect_exceptions(failures, "get_results")
1043 1085 return content
1044 1086
1045 1087 @spinfirst
1046 1088 def queue_status(self, targets=None, verbose=False):
1047 1089 """Fetch the status of engine queues.
1048 1090
1049 1091 Parameters
1050 1092 ----------
1051 1093 targets : int/str/list of ints/strs
1052 1094 the engines on which to execute
1053 1095 default : all
1054 1096 verbose : bool
1055 1097 Whether to return lengths only, or lists of ids for each element
1056 1098 """
1057 1099 targets = self._build_targets(targets)[1]
1058 1100 content = dict(targets=targets, verbose=verbose)
1059 1101 self.session.send(self._query_socket, "queue_request", content=content)
1060 1102 idents,msg = self.session.recv(self._query_socket, 0)
1061 1103 if self.debug:
1062 1104 pprint(msg)
1063 1105 content = msg['content']
1064 1106 status = content.pop('status')
1065 1107 if status != 'ok':
1066 1108 raise ss.unwrap_exception(content)
1067 1109 return ss.rekey(content)
1068 1110
1069 1111 @spinfirst
1070 1112 def purge_results(self, msg_ids=[], targets=[]):
1071 1113 """Tell the controller to forget results.
1072 1114
1073 1115 Individual results can be purged by msg_id, or the entire
1074 1116 history of specific targets can be purged.
1075 1117
1076 1118 Parameters
1077 1119 ----------
1078 1120 msg_ids : str or list of strs
1079 1121 the msg_ids whose results should be forgotten.
1080 1122 targets : int/str/list of ints/strs
1081 1123 The targets, by uuid or int_id, whose entire history is to be purged.
1082 1124 Use `targets='all'` to scrub everything from the controller's memory.
1083 1125
1084 1126 default : None
1085 1127 """
1086 1128 if not targets and not msg_ids:
1087 1129 raise ValueError
1088 1130 if targets:
1089 1131 targets = self._build_targets(targets)[1]
1090 1132 content = dict(targets=targets, msg_ids=msg_ids)
1091 1133 self.session.send(self._query_socket, "purge_request", content=content)
1092 1134 idents, msg = self.session.recv(self._query_socket, 0)
1093 1135 if self.debug:
1094 1136 pprint(msg)
1095 1137 content = msg['content']
1096 1138 if content['status'] != 'ok':
1097 1139 raise ss.unwrap_exception(content)
1098 1140
1099 1141 #----------------------------------------
1100 1142 # activate for %px,%autopx magics
1101 1143 #----------------------------------------
1102 1144 def activate(self):
1103 1145 """Make this `View` active for parallel magic commands.
1104 1146
1105 1147 IPython has a magic command syntax to work with `MultiEngineClient` objects.
1106 1148 In a given IPython session there is a single active one. While
1107 1149 there can be many `Views` created and used by the user,
1108 1150 there is only one active one. The active `View` is used whenever
1109 1151 the magic commands %px and %autopx are used.
1110 1152
1111 1153 The activate() method is called on a given `View` to make it
1112 1154 active. Once this has been done, the magic commands can be used.
1113 1155 """
1114 1156
1115 1157 try:
1116 1158 # This is injected into __builtins__.
1117 1159 ip = get_ipython()
1118 1160 except NameError:
1119 1161 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
1120 1162 else:
1121 1163 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
1122 1164 if pmagic is not None:
1123 1165 pmagic.active_multiengine_client = self
1124 1166 else:
1125 1167 print "You must first load the parallelmagic extension " \
1126 1168 "by doing '%load_ext parallelmagic'"
1127 1169
1128 1170 class AsynClient(Client):
1129 1171 """An Asynchronous client, using the Tornado Event Loop.
1130 1172 !!!unfinished!!!"""
1131 1173 io_loop = None
1132 1174 _queue_stream = None
1133 1175 _notifier_stream = None
1134 1176 _task_stream = None
1135 1177 _control_stream = None
1136 1178
1137 1179 def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
1138 1180 Client.__init__(self, addr, context, username, debug)
1139 1181 if io_loop is None:
1140 1182 io_loop = ioloop.IOLoop.instance()
1141 1183 self.io_loop = io_loop
1142 1184
1143 1185 self._queue_stream = zmqstream.ZMQStream(self._mux_socket, io_loop)
1144 1186 self._control_stream = zmqstream.ZMQStream(self._control_socket, io_loop)
1145 1187 self._task_stream = zmqstream.ZMQStream(self._task_socket, io_loop)
1146 1188 self._notification_stream = zmqstream.ZMQStream(self._notification_socket, io_loop)
1147 1189
1148 1190 def spin(self):
1149 1191 for stream in (self.queue_stream, self.notifier_stream,
1150 1192 self.task_stream, self.control_stream):
1151 1193 stream.flush()
1152 1194
1153 1195 __all__ = [ 'Client',
1154 1196 'depend',
1155 1197 'require',
1156 1198 'remote',
1157 1199 'parallel',
1158 1200 'RemoteFunction',
1159 1201 'ParallelFunction',
1160 1202 'DirectView',
1161 1203 'LoadBalancedView',
1162 1204 'AsyncResult',
1163 1205 'AsyncMapResult'
1164 1206 ]
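
Tying the pieces together, a hypothetical session exercising the load-balanced `apply` path with a time-based dependency. The address, import path, and return values are made up; note that `apply` accepts an `AsyncResult` directly as `after`, per the coercion shown above::

    from IPython.zmq.parallel.client import Client

    rc = Client('tcp://127.0.0.1:10101')

    def first():
        return 'first'

    def second():
        return 'second'

    ar = rc.apply(first, targets=None, block=False)   # via the task queue
    ar2 = rc.apply(second, targets=None, block=False, after=ar)

    rc.barrier()          # wait on all outstanding msg_ids
    print ar2.get()       # -> 'second', run only after `first` completed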
@@ -1,90 +1,111 @@
1 1 """Dependency utilities"""
2 2
3 3 from IPython.external.decorator import decorator
4 from error import UnmetDependency
5
4 6
5 7 # flags
6 8 ALL = 1 << 0
7 9 ANY = 1 << 1
8 10 HERE = 1 << 2
9 11 ANYWHERE = 1 << 3
10 12
11 class UnmetDependency(Exception):
12 pass
13
14 13
15 14 class depend(object):
16 15 """Dependency decorator, for use with tasks."""
17 16 def __init__(self, f, *args, **kwargs):
18 17 self.f = f
19 18 self.args = args
20 19 self.kwargs = kwargs
21 20
22 21 def __call__(self, f):
23 22 return dependent(f, self.f, *self.args, **self.kwargs)
24 23
25 24 class dependent(object):
26 25 """A function that depends on another function.
27 26 This is an object to prevent the closure used
28 27 in traditional decorators, which are not picklable.
29 28 """
30 29
31 30 def __init__(self, f, df, *dargs, **dkwargs):
32 31 self.f = f
33 self.func_name = self.f.func_name
32 self.func_name = getattr(f, '__name__', 'f')
34 33 self.df = df
35 34 self.dargs = dargs
36 35 self.dkwargs = dkwargs
37 36
38 37 def __call__(self, *args, **kwargs):
39 38 if self.df(*self.dargs, **self.dkwargs) is False:
40 39 raise UnmetDependency()
41 40 return self.f(*args, **kwargs)
41
42 @property
43 def __name__(self):
44 return self.func_name
42 45
43 46 def _require(*names):
44 47 for name in names:
45 48 try:
46 49 __import__(name)
47 50 except ImportError:
48 51 return False
49 52 return True
50 53
51 54 def require(*names):
52 55 return depend(_require, *names)
53 56
54 57 class Dependency(set):
55 58 """An object for representing a set of dependencies.
56 59
57 60 Subclassed from set()."""
58 61
59 62 mode='all'
63 success_only=True
60 64
61 def __init__(self, dependencies=[], mode='all'):
65 def __init__(self, dependencies=[], mode='all', success_only=True):
62 66 if isinstance(dependencies, dict):
63 67 # load from dict
64 dependencies = dependencies.get('dependencies', [])
65 68 mode = dependencies.get('mode', mode)
69 success_only = dependencies.get('success_only', success_only)
70 dependencies = dependencies.get('dependencies', [])
66 71 set.__init__(self, dependencies)
67 72 self.mode = mode.lower()
73 self.success_only=success_only
68 74 if self.mode not in ('any', 'all'):
69 75 raise NotImplementedError("Only any|all supported, not %r"%mode)
70 76
71 def check(self, completed):
77 def check(self, completed, failed=None):
78 if failed is not None and not self.success_only:
79 completed = completed.union(failed)
72 80 if len(self) == 0:
73 81 return True
74 82 if self.mode == 'all':
75 83 return self.issubset(completed)
76 84 elif self.mode == 'any':
77 85 return not self.isdisjoint(completed)
78 86 else:
79 87 raise NotImplementedError("Only any|all supported, not %r"%self.mode)
80 88
89 def unreachable(self, failed):
90 if len(self) == 0 or len(failed) == 0 or not self.success_only:
91 return False
92 # print self, self.success_only, self.mode, failed
93 if self.mode == 'all':
94 return not self.isdisjoint(failed)
95 elif self.mode == 'any':
96 return self.issubset(failed)
97 else:
98 raise NotImplementedError("Only any|all supported, not %r"%self.mode)
99
100
81 101 def as_dict(self):
82 102 """Represent this dependency as a dict. For json compatibility."""
83 103 return dict(
84 104 dependencies=list(self),
85 mode=self.mode
105 mode=self.mode,
106 success_only=self.success_only,
86 107 )
87 108
88 109
89 __all__ = ['UnmetDependency', 'depend', 'require', 'Dependency']
110 __all__ = ['depend', 'require', 'Dependency']
90 111
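# A small illustration of Dependency semantics (the msg_ids are made up):

dep = Dependency(['a', 'b'], mode='any')
assert dep.check(set(['a']), set())          # any one dependency met suffices

lax = Dependency(['a', 'b'], mode='all', success_only=False)
assert lax.check(set(['a']), set(['b']))     # failures also count here

imp = Dependency(['a', 'b'], mode='all')
assert imp.unreachable(set(['b']))           # 'b' failed, so 'all' is impossible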
@@ -1,283 +1,289 b''
1 1 # encoding: utf-8
2 2
3 3 """Classes and functions for kernel related errors and exceptions."""
4 4 from __future__ import print_function
5 5
6 6 __docformat__ = "restructuredtext en"
7 7
8 8 # Tell nose to skip this module
9 9 __test__ = {}
10 10
11 11 #-------------------------------------------------------------------------------
12 12 # Copyright (C) 2008 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-------------------------------------------------------------------------------
17 17
18 18 #-------------------------------------------------------------------------------
19 19 # Error classes
20 20 #-------------------------------------------------------------------------------
21 21 class IPythonError(Exception):
22 22 """Base exception that all of our exceptions inherit from.
23 23
24 24 This can be raised by code that doesn't have any more specific
25 25 information."""
26 26
27 27 pass
28 28
29 29 # Exceptions associated with the controller objects
30 30 class ControllerError(IPythonError): pass
31 31
32 32 class ControllerCreationError(ControllerError): pass
33 33
34 34
35 35 # Exceptions associated with the Engines
36 36 class EngineError(IPythonError): pass
37 37
38 38 class EngineCreationError(EngineError): pass
39 39
40 40 class KernelError(IPythonError):
41 41 pass
42 42
43 43 class NotDefined(KernelError):
44 44 def __init__(self, name):
45 45 self.name = name
46 46 self.args = (name,)
47 47
48 48 def __repr__(self):
49 49 return '<NotDefined: %s>' % self.name
50 50
51 51 __str__ = __repr__
52 52
53 53
54 54 class QueueCleared(KernelError):
55 55 pass
56 56
57 57
58 58 class IdInUse(KernelError):
59 59 pass
60 60
61 61
62 62 class ProtocolError(KernelError):
63 63 pass
64 64
65 65
66 66 class ConnectionError(KernelError):
67 67 pass
68 68
69 69
70 70 class InvalidEngineID(KernelError):
71 71 pass
72 72
73 73
74 74 class NoEnginesRegistered(KernelError):
75 75 pass
76 76
77 77
78 78 class InvalidClientID(KernelError):
79 79 pass
80 80
81 81
82 82 class InvalidDeferredID(KernelError):
83 83 pass
84 84
85 85
86 86 class SerializationError(KernelError):
87 87 pass
88 88
89 89
90 90 class MessageSizeError(KernelError):
91 91 pass
92 92
93 93
94 94 class PBMessageSizeError(MessageSizeError):
95 95 pass
96 96
97 97
98 98 class ResultNotCompleted(KernelError):
99 99 pass
100 100
101 101
102 102 class ResultAlreadyRetrieved(KernelError):
103 103 pass
104 104
105 105 class ClientError(KernelError):
106 106 pass
107 107
108 108
109 109 class TaskAborted(KernelError):
110 110 pass
111 111
112 112
113 113 class TaskTimeout(KernelError):
114 114 pass
115 115
116 116
117 117 class NotAPendingResult(KernelError):
118 118 pass
119 119
120 120
121 121 class UnpickleableException(KernelError):
122 122 pass
123 123
124 124
125 125 class AbortedPendingDeferredError(KernelError):
126 126 pass
127 127
128 128
129 129 class InvalidProperty(KernelError):
130 130 pass
131 131
132 132
133 133 class MissingBlockArgument(KernelError):
134 134 pass
135 135
136 136
137 137 class StopLocalExecution(KernelError):
138 138 pass
139 139
140 140
141 141 class SecurityError(KernelError):
142 142 pass
143 143
144 144
145 145 class FileTimeoutError(KernelError):
146 146 pass
147 147
148 148 class TimeoutError(KernelError):
149 149 pass
150 150
151 class UnmetDependency(KernelError):
152 pass
153
154 class ImpossibleDependency(UnmetDependency):
155 pass
156
151 157 class RemoteError(KernelError):
152 158 """Error raised elsewhere"""
153 159 ename=None
154 160 evalue=None
155 161 traceback=None
156 162 engine_info=None
157 163
158 164 def __init__(self, ename, evalue, traceback, engine_info=None):
159 165 self.ename=ename
160 166 self.evalue=evalue
161 167 self.traceback=traceback
162 168 self.engine_info=engine_info or {}
163 169 self.args=(ename, evalue)
164 170
165 171 def __repr__(self):
166 172 engineid = self.engine_info.get('engineid', ' ')
167 173 return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)
168 174
169 175 def __str__(self):
170 176 sig = "%s(%s)"%(self.ename, self.evalue)
171 177 if self.traceback:
172 178 return sig + '\n' + self.traceback
173 179 else:
174 180 return sig
175 181
176 182
177 183 class TaskRejectError(KernelError):
178 184 """Exception to raise when a task should be rejected by an engine.
179 185
180 186 This exception can be used to allow a task running on an engine to test
181 187 if the engine (or the user's namespace on the engine) has the needed
182 188 task dependencies. If not, the task should raise this exception. For
183 189 the task to be retried on another engine, the task should be created
184 190 with the `retries` argument > 1.
185 191
186 192 The advantage of this approach over our older properties system is that
187 193 tasks have full access to the user's namespace on the engines and the
188 194 properties don't have to be managed or tested by the controller.
189 195 """
190 196
191 197
192 198 class CompositeError(KernelError):
193 199 """Error for representing possibly multiple errors on engines"""
194 200 def __init__(self, message, elist):
195 201 Exception.__init__(self, *(message, elist))
196 202 # Don't use pack_exception because it will conflict with the .message
197 203 # attribute that is being deprecated in 2.6 and beyond.
198 204 self.msg = message
199 205 self.elist = elist
200 206 self.args = [ e[0] for e in elist ]
201 207
202 208 def _get_engine_str(self, ei):
203 209 if not ei:
204 210 return '[Engine Exception]'
205 211 else:
206 212 return '[%i:%s]: ' % (ei['engineid'], ei['method'])
207 213
208 214 def _get_traceback(self, ev):
209 215 try:
210 216 tb = ev._ipython_traceback_text
211 217 except AttributeError:
212 218 return 'No traceback available'
213 219 else:
214 220 return tb
215 221
216 222 def __str__(self):
217 223 s = str(self.msg)
218 224 for en, ev, etb, ei in self.elist:
219 225 engine_str = self._get_engine_str(ei)
220 226 s = s + '\n' + engine_str + en + ': ' + str(ev)
221 227 return s
222 228
223 229 def __repr__(self):
224 230 return "CompositeError(%i)"%len(self.elist)
225 231
226 232 def print_tracebacks(self, excid=None):
227 233 if excid is None:
228 234 for (en,ev,etb,ei) in self.elist:
229 235 print (self._get_engine_str(ei))
230 236 print (etb or 'No traceback available')
231 237 print ()
232 238 else:
233 239 try:
234 240 en,ev,etb,ei = self.elist[excid]
235 241 except:
236 242 raise IndexError("an exception with index %i does not exist"%excid)
237 243 else:
238 244 print (self._get_engine_str(ei))
239 245 print (etb or 'No traceback available')
240 246
241 247 def raise_exception(self, excid=0):
242 248 try:
243 249 en,ev,etb,ei = self.elist[excid]
244 250 except:
245 251 raise IndexError("an exception with index %i does not exist"%excid)
246 252 else:
247 253 raise RemoteError(en, ev, etb, ei)
251 257
252 258
253 259 def collect_exceptions(rdict_or_list, method='unspecified'):
254 260 """check a result dict for errors, and raise CompositeError if any exist.
255 261 Passthrough otherwise."""
256 262 elist = []
257 263 if isinstance(rdict_or_list, dict):
258 264 rlist = rdict_or_list.values()
259 265 else:
260 266 rlist = rdict_or_list
261 267 for r in rlist:
262 268 if isinstance(r, RemoteError):
263 269 en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
264 270 # Sometimes we could have CompositeError in our list. Just take
265 271 # the errors out of them and put them in our new list. This
266 272 # has the effect of flattening lists of CompositeErrors into one
267 273 # CompositeError
268 274 if en=='CompositeError':
269 275 for e in ev.elist:
270 276 elist.append(e)
271 277 else:
272 278 elist.append((en, ev, etb, ei))
273 279 if len(elist)==0:
274 280 return rdict_or_list
275 281 else:
276 282 msg = "one or more exceptions from call to method: %s" % (method)
277 283 # This silliness is needed so the debugger has access to the exception
278 284 # instance (e in this case)
279 285 try:
280 286 raise CompositeError(msg, elist)
281 287 except CompositeError, e:
282 288 raise e
283 289
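# A sketch of consuming a CompositeError; the RemoteError below stands in
# for real results returned from a multi-engine call.
results = [RemoteError('NameError', "name 'a' is not defined", '', {}),
           'ok']
try:
    collect_exceptions(results, method='execute')
except CompositeError, e:
    print(e)                  # one-line summary per failure
    e.print_tracebacks()      # full remote tracebacks, where available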
@@ -1,426 +1,509 b''
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6 """
7 7
8 8 #----------------------------------------------------------------------
9 9 # Imports
10 10 #----------------------------------------------------------------------
11 11
12 12 from __future__ import print_function
13 13 from random import randint,random
14 14 import logging
15 15 from types import FunctionType
16 16
17 17 try:
18 18 import numpy
19 19 except ImportError:
20 20 numpy = None
21 21
22 22 import zmq
23 23 from zmq.eventloop import ioloop, zmqstream
24 24
25 25 # local imports
26 26 from IPython.external.decorator import decorator
27 27 from IPython.config.configurable import Configurable
28 28 from IPython.utils.traitlets import Instance, Dict, List, Set
29 29
30 import sys
31 import error
30 31 from client import Client
31 32 from dependency import Dependency
32 33 import streamsession as ss
33 34 from entry_point import connect_logger, local_logger
34 35
35 36
36 37 @decorator
37 38 def logged(f,self,*args,**kwargs):
38 39 # print ("#--------------------")
39 40 logging.debug("scheduler::%s(*%s,**%s)"%(f.func_name, args, kwargs))
40 41 # print ("#--")
41 42 return f(self,*args, **kwargs)
42 43
43 44 #----------------------------------------------------------------------
44 45 # Chooser functions
45 46 #----------------------------------------------------------------------
46 47
47 48 def plainrandom(loads):
48 49 """Plain random pick."""
49 50 n = len(loads)
50 51 return randint(0,n-1)
51 52
52 53 def lru(loads):
53 54 """Always pick the front of the line.
54 55
55 56 The content of `loads` is ignored.
56 57
57 58 Assumes LRU ordering of loads, with oldest first.
58 59 """
59 60 return 0
60 61
61 62 def twobin(loads):
62 63 """Pick two at random, use the LRU of the two.
63 64
64 65 The content of loads is ignored.
65 66
66 67 Assumes LRU ordering of loads, with oldest first.
67 68 """
68 69 n = len(loads)
69 70 a = randint(0,n-1)
70 71 b = randint(0,n-1)
71 72 return min(a,b)
72 73
73 74 def weighted(loads):
74 75 """Pick two at random using inverse load as weight.
75 76
76 77 Return the less loaded of the two.
77 78 """
78 79 # weight 0 a million times more than 1:
79 80 weights = 1./(1e-6+numpy.array(loads))
80 81 sums = weights.cumsum()
81 82 t = sums[-1]
82 83 x = random()*t
83 84 y = random()*t
84 85 idx = 0
85 86 idy = 0
86 87 while sums[idx] < x:
87 88 idx += 1
88 89 while sums[idy] < y:
89 90 idy += 1
90 91 if weights[idy] > weights[idx]:
91 92 return idy
92 93 else:
93 94 return idx
94 95
95 96 def leastload(loads):
96 97 """Always choose the lowest load.
97 98
98 99 If the lowest load occurs more than once, the first
99 100 occurrence will be used. If loads has LRU ordering, this means
100 101 the LRU of those with the lowest load is chosen.
101 102 """
102 103 return loads.index(min(loads))
103 104
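# A worked example of the choosers above (index 0 is the LRU engine):
#
#   loads = [2, 0, 1]
#   lru(loads)       -> 0   front of the line; loads ignored
#   leastload(loads) -> 1   the engine with no outstanding jobs
#   twobin(loads)    -> min of two random indices, biased toward the LRU end
#   weighted(loads)  -> favors lightly loaded engines in proportion to 1/load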
104 105 #---------------------------------------------------------------------
105 106 # Classes
106 107 #---------------------------------------------------------------------
108 # store empty default dependency:
109 MET = Dependency([])
110
107 111 class TaskScheduler(Configurable):
108 112 """Python TaskScheduler object.
109 113
110 114 This is the simplest object that supports msg_id based
111 115 DAG dependencies. *Only* task msg_ids are checked, not
112 116 msg_ids of jobs submitted via the MUX queue.
113 117
114 118 """
115 119
116 120 # input arguments:
117 121 scheme = Instance(FunctionType, default=leastload) # function for determining the destination
118 122 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
119 123 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
120 124 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
121 125 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
122 126 io_loop = Instance(ioloop.IOLoop)
123 127
124 128 # internals:
125 129 dependencies = Dict() # dict by msg_id of [ msg_ids that depend on key ]
126 130 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
127 131 pending = Dict() # dict by engine_uuid of submitted tasks
128 132 completed = Dict() # dict by engine_uuid of completed tasks
133 failed = Dict() # dict by engine_uuid of failed tasks
134 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
129 135 clients = Dict() # dict by msg_id for who submitted the task
130 136 targets = List() # list of target IDENTs
131 137 loads = List() # list of engine loads
132 all_done = Set() # set of all completed tasks
138 all_completed = Set() # set of all completed tasks
139 all_failed = Set() # set of all failed tasks
140 all_done = Set() # set of all finished tasks=union(completed,failed)
133 141 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
134 142 session = Instance(ss.StreamSession)
135 143
136 144
137 145 def __init__(self, **kwargs):
138 146 super(TaskScheduler, self).__init__(**kwargs)
139 147
140 148 self.session = ss.StreamSession(username="TaskScheduler")
141 149
142 150 self.engine_stream.on_recv(self.dispatch_result, copy=False)
143 151 self._notification_handlers = dict(
144 152 registration_notification = self._register_engine,
145 153 unregistration_notification = self._unregister_engine
146 154 )
147 155 self.notifier_stream.on_recv(self.dispatch_notification)
148 156 logging.info("Scheduler started...%r"%self)
149 157
150 158 def resume_receiving(self):
151 159 """Resume accepting jobs."""
152 160 self.client_stream.on_recv(self.dispatch_submission, copy=False)
153 161
154 162 def stop_receiving(self):
155 163 """Stop accepting jobs while there are no engines.
156 164 Leave them in the ZMQ queue."""
157 165 self.client_stream.on_recv(None)
158 166
159 167 #-----------------------------------------------------------------------
160 168 # [Un]Registration Handling
161 169 #-----------------------------------------------------------------------
162 170
163 171 def dispatch_notification(self, msg):
164 172 """dispatch register/unregister events."""
165 173 idents,msg = self.session.feed_identities(msg)
166 174 msg = self.session.unpack_message(msg)
167 175 msg_type = msg['msg_type']
168 176 handler = self._notification_handlers.get(msg_type, None)
169 177 if handler is None:
170 178 raise Exception("Unhandled message type: %s"%msg_type)
171 179 else:
172 180 try:
173 181 handler(str(msg['content']['queue']))
174 182 except KeyError:
175 183 logging.error("task::Invalid notification msg: %s"%msg)
176 184
177 185 @logged
178 186 def _register_engine(self, uid):
179 187 """New engine with ident `uid` became available."""
180 188 # head of the line:
181 189 self.targets.insert(0,uid)
182 190 self.loads.insert(0,0)
183 191 # initialize sets
184 192 self.completed[uid] = set()
193 self.failed[uid] = set()
185 194 self.pending[uid] = {}
186 195 if len(self.targets) == 1:
187 196 self.resume_receiving()
188 197
189 198 def _unregister_engine(self, uid):
190 199 """Existing engine with ident `uid` became unavailable."""
191 200 if len(self.targets) == 1:
192 201 # this was our only engine
193 202 self.stop_receiving()
194 203
195 204 # handle any potentially finished tasks:
196 205 self.engine_stream.flush()
197 206
198 207 self.completed.pop(uid)
208 self.failed.pop(uid)
209 # don't pop destinations, because it might be used later
210 # map(self.destinations.pop, self.completed.pop(uid))
211 # map(self.destinations.pop, self.failed.pop(uid))
212
199 213 lost = self.pending.pop(uid)
200 214
201 215 idx = self.targets.index(uid)
202 216 self.targets.pop(idx)
203 217 self.loads.pop(idx)
204 218
205 219 self.handle_stranded_tasks(lost)
206 220
207 221 def handle_stranded_tasks(self, lost):
208 222 """Deal with jobs resident in an engine that died."""
209 223 # TODO: resubmit the tasks?
210 224 for msg_id in lost:
211 225 pass
212 226
213 227
214 228 #-----------------------------------------------------------------------
215 229 # Job Submission
216 230 #-----------------------------------------------------------------------
217 231 @logged
218 232 def dispatch_submission(self, raw_msg):
219 233 """Dispatch job submission to appropriate handlers."""
220 234 # ensure targets up to date:
221 235 self.notifier_stream.flush()
222 236 try:
223 237 idents, msg = self.session.feed_identities(raw_msg, copy=False)
224 238 except Exception as e:
225 239 logging.error("task::Invaid msg: %s"%msg)
226 240 return
227 241
228 242 # send to monitor
229 243 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
230 244
231 245 msg = self.session.unpack_message(msg, content=False, copy=False)
232 246 header = msg['header']
233 247 msg_id = header['msg_id']
234 248
235 249 # time dependencies
236 250 after = Dependency(header.get('after', []))
237 251 if after.mode == 'all':
238 after.difference_update(self.all_done)
239 if after.check(self.all_done):
252 after.difference_update(self.all_completed)
253 if not after.success_only:
254 after.difference_update(self.all_failed)
255 if after.check(self.all_completed, self.all_failed):
240 256 # recast as empty set, if `after` already met,
241 257 # to prevent unnecessary set comparisons
242 after = Dependency([])
258 after = MET
243 259
244 260 # location dependencies
245 261 follow = Dependency(header.get('follow', []))
246 if len(after) == 0:
262
263 # check if unreachable:
264 if after.unreachable(self.all_failed) or follow.unreachable(self.all_failed):
265 self.depending[msg_id] = [raw_msg,MET,MET]
266 return self.fail_unreachable(msg_id)
267
268 if after.check(self.all_completed, self.all_failed):
247 269 # time deps already met, try to run
248 270 if not self.maybe_run(msg_id, raw_msg, follow):
249 271 # can't run yet
250 272 self.save_unmet(msg_id, raw_msg, after, follow)
251 273 else:
252 274 self.save_unmet(msg_id, raw_msg, after, follow)
253 275
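# Sketch of what dispatch_submission consumes: a task header carrying
# msg_id dependencies (the ids here are made up):
#
#   header = {'msg_id': 'c',
#             'after':  ['a'],   # run only once task 'a' has finished
#             'follow': ['b']}   # and only on an engine where 'b' ran
#
# Both lists become Dependency sets, checked against all_completed and
# all_failed before the task is submitted or saved as unmet.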
254 276 @logged
277 def fail_unreachable(self, msg_id):
278 """a message has become unreachable"""
279 if msg_id not in self.depending:
280 logging.error("msg %r already failed!"%msg_id)
281 return
282 raw_msg, after, follow = self.depending.pop(msg_id)
283 for mid in follow.union(after):
284 if mid in self.dependencies:
285 self.dependencies[mid].remove(msg_id)
286
287 idents,msg = self.session.feed_identities(raw_msg, copy=False)
288 msg = self.session.unpack_message(msg, copy=False, content=False)
289 header = msg['header']
290
291 try:
292 raise error.ImpossibleDependency()
293 except:
294 content = ss.wrap_exception()
295
296 self.all_done.add(msg_id)
297 self.all_failed.add(msg_id)
298
299 msg = self.session.send(self.client_stream, 'apply_reply', content,
300 parent=header, ident=idents)
301 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
302
303 self.update_dependencies(msg_id, success=False)
304
305 @logged
255 306 def maybe_run(self, msg_id, raw_msg, follow=None):
256 307 """check location dependencies, and run if they are met."""
257 308
258 309 if follow:
259 310 def can_run(idx):
260 311 target = self.targets[idx]
261 312 return target not in self.blacklist.get(msg_id, []) and\
262 follow.check(self.completed[target])
313 follow.check(self.completed[target], self.failed[target])
263 314
264 315 indices = filter(can_run, range(len(self.targets)))
265 316 if not indices:
317 # TODO evaluate unmeetable follow dependencies
318 if follow.mode == 'all':
319 dests = set()
320 relevant = self.all_completed if follow.success_only else self.all_done
321 for m in follow.intersection(relevant):
322 dests.add(self.destinations[m])
323 if len(dests) > 1:
324 self.fail_unreachable(msg_id)
325
326
266 327 return False
267 328 else:
268 329 indices = None
269 330
270 331 self.submit_task(msg_id, raw_msg, follow, indices)
271 332 return True
272 333
273 334 @logged
274 def save_unmet(self, msg_id, msg, after, follow):
335 def save_unmet(self, msg_id, raw_msg, after, follow):
275 336 """Save a message for later submission when its dependencies are met."""
276 self.depending[msg_id] = (msg_id,msg,after,follow)
277 # track the ids in both follow/after, but not those already completed
337 self.depending[msg_id] = [raw_msg,after,follow]
338 # track the ids in follow or after, but not those already finished
278 339 for dep_id in after.union(follow).difference(self.all_done):
279 340 if dep_id not in self.dependencies:
280 341 self.dependencies[dep_id] = set()
281 342 self.dependencies[dep_id].add(msg_id)
282 343
283 344 @logged
284 345 def submit_task(self, msg_id, msg, follow=None, indices=None):
285 346 """Submit a task to any of a subset of our targets."""
286 347 if indices:
287 348 loads = [self.loads[i] for i in indices]
288 349 else:
289 350 loads = self.loads
290 351 idx = self.scheme(loads)
291 352 if indices:
292 353 idx = indices[idx]
293 354 target = self.targets[idx]
294 355 # print (target, map(str, msg[:3]))
295 356 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
296 357 self.engine_stream.send_multipart(msg, copy=False)
297 358 self.add_job(idx)
298 359 self.pending[target][msg_id] = (msg, follow)
299 360 content = dict(msg_id=msg_id, engine_id=target)
300 361 self.session.send(self.mon_stream, 'task_destination', content=content,
301 362 ident=['tracktask',self.session.session])
302 363
303 364 #-----------------------------------------------------------------------
304 365 # Result Handling
305 366 #-----------------------------------------------------------------------
306 367 @logged
307 368 def dispatch_result(self, raw_msg):
308 369 try:
309 370 idents,msg = self.session.feed_identities(raw_msg, copy=False)
310 371 except Exception as e:
311 372 logging.error("task::Invaid result: %s"%msg)
312 373 return
313 374 msg = self.session.unpack_message(msg, content=False, copy=False)
314 375 header = msg['header']
315 376 if header.get('dependencies_met', True):
316 self.handle_result_success(idents, msg['parent_header'], raw_msg)
317 # send to monitor
377 success = (header['status'] == 'ok')
378 self.handle_result(idents, msg['parent_header'], raw_msg, success)
379 # send to Hub monitor
318 380 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
319 381 else:
320 382 self.handle_unmet_dependency(idents, msg['parent_header'])
321 383
322 384 @logged
323 def handle_result_success(self, idents, parent, raw_msg):
385 def handle_result(self, idents, parent, raw_msg, success=True):
324 386 # first, relay result to client
325 387 engine = idents[0]
326 388 client = idents[1]
327 389 # swap_ids for XREP-XREP mirror
328 390 raw_msg[:2] = [client,engine]
329 391 # print (map(str, raw_msg[:4]))
330 392 self.client_stream.send_multipart(raw_msg, copy=False)
331 393 # now, update our data structures
332 394 msg_id = parent['msg_id']
333 395 self.pending[engine].pop(msg_id)
334 self.completed[engine].add(msg_id)
396 if success:
397 self.completed[engine].add(msg_id)
398 self.all_completed.add(msg_id)
399 else:
400 self.failed[engine].add(msg_id)
401 self.all_failed.add(msg_id)
335 402 self.all_done.add(msg_id)
403 self.destinations[msg_id] = engine
336 404
337 self.update_dependencies(msg_id)
405 self.update_dependencies(msg_id, success)
338 406
339 407 @logged
340 408 def handle_unmet_dependency(self, idents, parent):
341 409 engine = idents[0]
342 410 msg_id = parent['msg_id']
343 411 if msg_id not in self.blacklist:
344 412 self.blacklist[msg_id] = set()
345 413 self.blacklist[msg_id].add(engine)
346 414 raw_msg,follow = self.pending[engine].pop(msg_id)
347 415 if not self.maybe_run(msg_id, raw_msg, follow):
348 416 # resubmit failed, put it back in our dependency tree
349 self.save_unmet(msg_id, raw_msg, Dependency(), follow)
417 self.save_unmet(msg_id, raw_msg, MET, follow)
350 418 pass
419
351 420 @logged
352 def update_dependencies(self, dep_id):
421 def update_dependencies(self, dep_id, success=True):
353 422 """dep_id just finished. Update our dependency
354 423 table and submit any jobs that just became runnable."""
355
424 # print ("\n\n***********")
425 # pprint (dep_id)
426 # pprint (self.dependencies)
427 # pprint (self.depending)
428 # pprint (self.all_completed)
429 # pprint (self.all_failed)
430 # print ("\n\n***********\n\n")
356 431 if dep_id not in self.dependencies:
357 432 return
358 433 jobs = self.dependencies.pop(dep_id)
359 for job in jobs:
360 msg_id, raw_msg, after, follow = self.depending[job]
361 if dep_id in after:
362 after.remove(dep_id)
363 if not after: # time deps met, maybe run
434
435 for msg_id in jobs:
436 raw_msg, after, follow = self.depending[msg_id]
437 # if dep_id in after:
438 # if after.mode == 'all' and (success or not after.success_only):
439 # after.remove(dep_id)
440
441 if after.unreachable(self.all_failed) or follow.unreachable(self.all_failed):
442 self.fail_unreachable(msg_id)
443
444 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
445 self.depending[msg_id][1] = MET
364 446 if self.maybe_run(msg_id, raw_msg, follow):
365 self.depending.pop(job)
366 for mid in follow:
447
448 self.depending.pop(msg_id)
449 for mid in follow.union(after):
367 450 if mid in self.dependencies:
368 451 self.dependencies[mid].remove(msg_id)
369 452
370 453 #----------------------------------------------------------------------
371 454 # methods to be overridden by subclasses
372 455 #----------------------------------------------------------------------
373 456
374 457 def add_job(self, idx):
375 458 """Called after self.targets[idx] just got the job with header.
376 459 Override with subclasses. The default ordering is simple LRU.
377 460 The default loads are the number of outstanding jobs."""
378 461 self.loads[idx] += 1
379 462 for lis in (self.targets, self.loads):
380 463 lis.append(lis.pop(idx))
381 464
382 465
383 466 def finish_job(self, idx):
384 467 """Called after self.targets[idx] just finished a job.
385 468 Override in subclasses."""
386 469 self.loads[idx] -= 1
387 470
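# Illustration of the default LRU bookkeeping in add_job (values made up):
#
#   targets = ['e0', 'e1', 'e2'];  loads = [0, 0, 0]
#   add_job(0)  =>  targets == ['e1', 'e2', 'e0'];  loads == [0, 0, 1]
#
# The engine that just received work rotates to the back of the line, so
# lru() (which always returns index 0) picks the least-recently-used next.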
388 471
389 472
390 473 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, log_addr=None, loglevel=logging.DEBUG, scheme='weighted'):
391 474 from zmq.eventloop import ioloop
392 475 from zmq.eventloop.zmqstream import ZMQStream
393 476
394 477 ctx = zmq.Context()
395 478 loop = ioloop.IOLoop()
396 479
397 480 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
398 481 ins.bind(in_addr)
399 482 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
400 483 outs.bind(out_addr)
401 484 mons = ZMQStream(ctx.socket(zmq.PUB),loop)
402 485 mons.connect(mon_addr)
403 486 nots = ZMQStream(ctx.socket(zmq.SUB),loop)
404 487 nots.setsockopt(zmq.SUBSCRIBE, '')
405 488 nots.connect(not_addr)
406 489
407 490 scheme = globals().get(scheme, None)
408 491 # setup logging
409 492 if log_addr:
410 493 connect_logger(ctx, log_addr, root="scheduler", loglevel=loglevel)
411 494 else:
412 495 local_logger(loglevel)
413 496
414 497 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
415 498 mon_stream=mons,notifier_stream=nots,
416 499 scheme=scheme,io_loop=loop)
417 500
418 501 try:
419 502 loop.start()
420 503 except KeyboardInterrupt:
421 504 print ("interrupted, exiting...", file=sys.__stderr__)
422 505
423 506
424 507 if __name__ == '__main__':
425 508 iface = 'tcp://127.0.0.1:%i'
426 509 launch_scheduler(iface%12345,iface%12346,iface%12347,iface%12348)
@@ -1,490 +1,483 b''
1 1 #!/usr/bin/env python
2 2 """
3 3 Kernel adapted from kernel.py to use ZMQ Streams
4 4 """
5 5
6 6 #-----------------------------------------------------------------------------
7 7 # Imports
8 8 #-----------------------------------------------------------------------------
9 9
10 10 # Standard library imports.
11 11 from __future__ import print_function
12 12 import __builtin__
13 13 from code import CommandCompiler
14 14 import os
15 15 import sys
16 16 import time
17 17 import traceback
18 18 import logging
19 19 from datetime import datetime
20 20 from signal import SIGTERM, SIGKILL
21 21 from pprint import pprint
22 22
23 23 # System library imports.
24 24 import zmq
25 25 from zmq.eventloop import ioloop, zmqstream
26 26
27 27 # Local imports.
28 28 from IPython.core import ultratb
29 29 from IPython.utils.traitlets import HasTraits, Instance, List, Int, Dict, Set, Str
30 30 from IPython.zmq.completer import KernelCompleter
31 31 from IPython.zmq.iostream import OutStream
32 32 from IPython.zmq.displayhook import DisplayHook
33 33
34 34 from factory import SessionFactory
35 35 from streamsession import StreamSession, Message, extract_header, serialize_object,\
36 36 unpack_apply_message, ISO8601, wrap_exception
37 from dependency import UnmetDependency
38 37 import heartmonitor
39 38 from client import Client
40 39
41 40 def printer(*args):
42 41 pprint(args, stream=sys.__stdout__)
43 42
44 43
45 44 class _Passer:
46 45 """Empty class that implements `send()` that does nothing."""
47 46 def send(self, *args, **kwargs):
48 47 pass
49 48 send_multipart = send
50 49
51 50
52 51 #-----------------------------------------------------------------------------
53 52 # Main kernel class
54 53 #-----------------------------------------------------------------------------
55 54
56 55 class Kernel(SessionFactory):
57 56
58 57 #---------------------------------------------------------------------------
59 58 # Kernel interface
60 59 #---------------------------------------------------------------------------
61 60
62 61 # kwargs:
63 62 int_id = Int(-1, config=True)
64 63 user_ns = Dict(config=True)
65 64 exec_lines = List(config=True)
66 65
67 66 control_stream = Instance(zmqstream.ZMQStream)
68 67 task_stream = Instance(zmqstream.ZMQStream)
69 68 iopub_stream = Instance(zmqstream.ZMQStream)
70 69 client = Instance('IPython.zmq.parallel.client.Client')
71 70
72 71 # internals
73 72 shell_streams = List()
74 73 compiler = Instance(CommandCompiler, (), {})
75 74 completer = Instance(KernelCompleter)
76 75
77 76 aborted = Set()
78 77 shell_handlers = Dict()
79 78 control_handlers = Dict()
80 79
81 80 def _set_prefix(self):
82 81 self.prefix = "engine.%s"%self.int_id
83 82
84 83 def _connect_completer(self):
85 84 self.completer = KernelCompleter(self.user_ns)
86 85
87 86 def __init__(self, **kwargs):
88 87 super(Kernel, self).__init__(**kwargs)
89 88 self._set_prefix()
90 89 self._connect_completer()
91 90
92 91 self.on_trait_change(self._set_prefix, 'id')
93 92 self.on_trait_change(self._connect_completer, 'user_ns')
94 93
95 94 # Build dict of handlers for message types
96 95 for msg_type in ['execute_request', 'complete_request', 'apply_request',
97 96 'clear_request']:
98 97 self.shell_handlers[msg_type] = getattr(self, msg_type)
99 98
100 99 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
101 100 self.control_handlers[msg_type] = getattr(self, msg_type)
102 101
103 102 self._initial_exec_lines()
104 103
105 104 def _wrap_exception(self, method=None):
106 105 e_info = dict(engineid=self.ident, method=method)
107 106 content=wrap_exception(e_info)
108 107 return content
109 108
110 109 def _initial_exec_lines(self):
111 110 s = _Passer()
112 111 content = dict(silent=True, user_variables=[], user_expressions=[])
113 112 for line in self.exec_lines:
114 113 logging.debug("executing initialization: %s"%line)
115 114 content.update({'code':line})
116 115 msg = self.session.msg('execute_request', content)
117 116 self.execute_request(s, [], msg)
118 117
119 118
120 119 #-------------------- control handlers -----------------------------
121 120 def abort_queues(self):
122 121 for stream in self.shell_streams:
123 122 if stream:
124 123 self.abort_queue(stream)
125 124
126 125 def abort_queue(self, stream):
127 126 while True:
128 127 try:
129 128 msg = self.session.recv(stream, zmq.NOBLOCK,content=True)
130 129 except zmq.ZMQError as e:
131 130 if e.errno == zmq.EAGAIN:
132 131 break
133 132 else:
134 133 return
135 134 else:
136 135 if msg is None:
137 136 return
138 137 else:
139 138 idents,msg = msg
140 139
141 140 # assert self.reply_socketly_socket.rcvmore(), "Unexpected missing message part."
142 141 # msg = self.reply_socket.recv_json()
143 142 logging.info("Aborting:")
144 143 logging.info(str(msg))
145 144 msg_type = msg['msg_type']
146 145 reply_type = msg_type.split('_')[0] + '_reply'
147 146 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
148 147 # self.reply_socket.send(ident,zmq.SNDMORE)
149 148 # self.reply_socket.send_json(reply_msg)
150 149 reply_msg = self.session.send(stream, reply_type,
151 150 content={'status' : 'aborted'}, parent=msg, ident=idents)[0]
152 151 logging.debug(str(reply_msg))
153 152 # We need to wait a bit for requests to come in. This can probably
154 153 # be set shorter for true asynchronous clients.
155 154 time.sleep(0.05)
156 155
157 156 def abort_request(self, stream, ident, parent):
158 157 """abort a specifig msg by id"""
159 158 msg_ids = parent['content'].get('msg_ids', None)
160 159 if isinstance(msg_ids, basestring):
161 160 msg_ids = [msg_ids]
162 161 if not msg_ids:
163 162 self.abort_queues()
164 163 for mid in msg_ids:
165 164 self.aborted.add(str(mid))
166 165
167 166 content = dict(status='ok')
168 167 reply_msg = self.session.send(stream, 'abort_reply', content=content,
169 168 parent=parent, ident=ident)[0]
170 169 logging.debug(str(reply_msg))
171 170
172 171 def shutdown_request(self, stream, ident, parent):
173 172 """kill ourself. This should really be handled in an external process"""
174 173 try:
175 174 self.abort_queues()
176 175 except:
177 176 content = self._wrap_exception('shutdown')
178 177 else:
179 178 content = dict(parent['content'])
180 179 content['status'] = 'ok'
181 180 msg = self.session.send(stream, 'shutdown_reply',
182 181 content=content, parent=parent, ident=ident)
183 182 # msg = self.session.send(self.pub_socket, 'shutdown_reply',
184 183 # content, parent, ident)
185 184 # print >> sys.__stdout__, msg
186 185 # time.sleep(0.2)
187 186 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
188 187 dc.start()
189 188
190 189 def dispatch_control(self, msg):
191 190 idents,msg = self.session.feed_identities(msg, copy=False)
192 191 try:
193 192 msg = self.session.unpack_message(msg, content=True, copy=False)
194 193 except:
195 194 logging.error("Invalid Message", exc_info=True)
196 195 return
197 196
198 197 header = msg['header']
199 198 msg_id = header['msg_id']
200 199
201 200 handler = self.control_handlers.get(msg['msg_type'], None)
202 201 if handler is None:
203 202 logging.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
204 203 else:
205 204 handler(self.control_stream, idents, msg)
206 205
207 206
208 207 #-------------------- queue helpers ------------------------------
209 208
210 209 def check_dependencies(self, dependencies):
211 210 if not dependencies:
212 211 return True
213 212 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
214 213 anyorall = dependencies[0]
215 214 dependencies = dependencies[1]
216 215 else:
217 216 anyorall = 'all'
218 217 results = self.client.get_results(dependencies,status_only=True)
219 218 if results['status'] != 'ok':
220 219 return False
221 220
222 221 if anyorall == 'any':
223 222 if not results['completed']:
224 223 return False
225 224 else:
226 225 if results['pending']:
227 226 return False
228 227
229 228 return True
230 229
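# The dependency shapes accepted by check_dependencies above (ids made up):
#
#   check_dependencies(['msg_a', 'msg_b'])           # implicit 'all'
#   check_dependencies(['any', ['msg_a', 'msg_b']])  # met if either finished
#   check_dependencies(['all', ['msg_a', 'msg_b']])  # met only if both finished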
231 230 def check_aborted(self, msg_id):
232 231 return msg_id in self.aborted
233 232
234 233 #-------------------- queue handlers -----------------------------
235 234
236 235 def clear_request(self, stream, idents, parent):
237 236 """Clear our namespace."""
238 237 self.user_ns = {}
239 238 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
240 239 content = dict(status='ok'))
241 240 self._initial_exec_lines()
242 241
243 242 def execute_request(self, stream, ident, parent):
244 243 logging.debug('execute request %s'%parent)
245 244 try:
246 245 code = parent[u'content'][u'code']
247 246 except:
248 247 logging.error("Got bad msg: %s"%parent, exc_info=True)
249 248 return
250 249 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
251 250 ident='%s.pyin'%self.prefix)
252 251 started = datetime.now().strftime(ISO8601)
253 252 try:
254 253 comp_code = self.compiler(code, '<zmq-kernel>')
255 254 # allow for not overriding displayhook
256 255 if hasattr(sys.displayhook, 'set_parent'):
257 256 sys.displayhook.set_parent(parent)
258 257 sys.stdout.set_parent(parent)
259 258 sys.stderr.set_parent(parent)
260 259 exec comp_code in self.user_ns, self.user_ns
261 260 except:
262 261 exc_content = self._wrap_exception('execute')
263 262 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
264 263 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
265 264 ident='%s.pyerr'%self.prefix)
266 265 reply_content = exc_content
267 266 else:
268 267 reply_content = {'status' : 'ok'}
269 # reply_msg = self.session.msg(u'execute_reply', reply_content, parent)
270 # self.reply_socket.send(ident, zmq.SNDMORE)
271 # self.reply_socket.send_json(reply_msg)
268
272 269 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
273 270 ident=ident, subheader = dict(started=started))
274 271 logging.debug(str(reply_msg))
275 272 if reply_msg['content']['status'] == u'error':
276 273 self.abort_queues()
277 274
278 275 def complete_request(self, stream, ident, parent):
279 276 matches = {'matches' : self.complete(parent),
280 277 'status' : 'ok'}
281 278 completion_msg = self.session.send(stream, 'complete_reply',
282 279 matches, parent, ident)
283 280 # print >> sys.__stdout__, completion_msg
284 281
285 282 def complete(self, msg):
286 283 return self.completer.complete(msg.content.line, msg.content.text)
287 284
288 285 def apply_request(self, stream, ident, parent):
289 286 # print (parent)
290 287 try:
291 288 content = parent[u'content']
292 289 bufs = parent[u'buffers']
293 290 msg_id = parent['header']['msg_id']
294 291 bound = content.get('bound', False)
295 292 except:
296 293 logging.error("Got bad msg: %s"%parent, exc_info=True)
297 294 return
298 295 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
299 296 # self.iopub_stream.send(pyin_msg)
300 297 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
301 298 sub = {'dependencies_met' : True, 'engine' : self.ident,
302 299 'started': datetime.now().strftime(ISO8601)}
303 300 try:
304 301 # allow for not overriding displayhook
305 302 if hasattr(sys.displayhook, 'set_parent'):
306 303 sys.displayhook.set_parent(parent)
307 304 sys.stdout.set_parent(parent)
308 305 sys.stderr.set_parent(parent)
309 306 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
310 307 if bound:
311 308 working = self.user_ns
312 309 suffix = str(msg_id).replace("-","")
313 310 prefix = "_"
314 311
315 312 else:
316 313 working = dict()
317 314 suffix = prefix = "_" # prevent keyword collisions with lambda
318 315 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
320 if hasattr(f, 'func_name'):
321 fname = f.func_name
322 else:
323 fname = f.__name__
317 fname = getattr(f, '__name__', 'f')
324 318
325 319 fname = prefix+fname.strip('<>')+suffix
326 320 argname = prefix+"args"+suffix
327 321 kwargname = prefix+"kwargs"+suffix
328 322 resultname = prefix+"result"+suffix
329 323
330 324 ns = { fname : f, argname : args, kwargname : kwargs }
331 325 # print ns
332 326 working.update(ns)
333 327 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
334 328 exec code in working, working
335 329 result = working.get(resultname)
336 330 # clear the namespace
337 331 if bound:
338 332 for key in ns.iterkeys():
339 333 self.user_ns.pop(key)
340 334 else:
341 335 del working
342 336
343 337 packed_result,buf = serialize_object(result)
344 338 result_buf = [packed_result]+buf
345 339 except:
346 340 exc_content = self._wrap_exception('apply')
347 341 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
348 342 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
349 343 ident='%s.pyerr'%self.prefix)
350 344 reply_content = exc_content
351 345 result_buf = []
352 346
353 if exc_content['ename'] == UnmetDependency.__name__:
347 if exc_content['ename'] == 'UnmetDependency':
354 348 sub['dependencies_met'] = False
355 349 else:
356 350 reply_content = {'status' : 'ok'}
357 # reply_msg = self.session.msg(u'execute_reply', reply_content, parent)
358 # self.reply_socket.send(ident, zmq.SNDMORE)
359 # self.reply_socket.send_json(reply_msg)
351
352 # put 'ok'/'error' status in header, for scheduler introspection:
353 sub['status'] = reply_content['status']
354
360 355 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
361 356 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
362 # print(Message(reply_msg), file=sys.__stdout__)
357
363 358 # if reply_msg['content']['status'] == u'error':
364 359 # self.abort_queues()
365 360
366 361 def dispatch_queue(self, stream, msg):
367 362 self.control_stream.flush()
368 363 idents,msg = self.session.feed_identities(msg, copy=False)
369 364 try:
370 365 msg = self.session.unpack_message(msg, content=True, copy=False)
371 366 except:
372 367 logging.error("Invalid Message", exc_info=True)
373 368 return
374 369
375 370
376 371 header = msg['header']
377 372 msg_id = header['msg_id']
378 373 if self.check_aborted(msg_id):
379 374 self.aborted.remove(msg_id)
380 375 # is it safe to assume a msg_id will not be resubmitted?
381 376 reply_type = msg['msg_type'].split('_')[0] + '_reply'
382 377 reply_msg = self.session.send(stream, reply_type,
383 378 content={'status' : 'aborted'}, parent=msg, ident=idents)
384 379 return
385 380 handler = self.shell_handlers.get(msg['msg_type'], None)
386 381 if handler is None:
387 382 logging.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
388 383 else:
389 384 handler(stream, idents, msg)
390 385
391 386 def start(self):
392 387 #### stream mode:
393 388 if self.control_stream:
394 389 self.control_stream.on_recv(self.dispatch_control, copy=False)
395 390 self.control_stream.on_err(printer)
396 391
397 392 def make_dispatcher(stream):
398 393 def dispatcher(msg):
399 394 return self.dispatch_queue(stream, msg)
400 395 return dispatcher
401 396
402 397 for s in self.shell_streams:
403 # s.on_recv(printer)
404 398 s.on_recv(make_dispatcher(s), copy=False)
405 # s.on_err(printer)
399 s.on_err(printer)
406 400
407 401 if self.iopub_stream:
408 402 self.iopub_stream.on_err(printer)
409 # self.iopub_stream.on_send(printer)
410 403
411 404 #### while True mode:
412 405 # while True:
413 406 # idle = True
414 407 # try:
415 408 # msg = self.shell_stream.socket.recv_multipart(
416 409 # zmq.NOBLOCK, copy=False)
417 410 # except zmq.ZMQError, e:
418 411 # if e.errno != zmq.EAGAIN:
419 412 # raise e
420 413 # else:
421 414 # idle=False
422 415 # self.dispatch_queue(self.shell_stream, msg)
423 416 #
424 417 # if not self.task_stream.empty():
425 418 # idle=False
426 419 # msg = self.task_stream.recv_multipart()
427 420 # self.dispatch_queue(self.task_stream, msg)
428 421 # if idle:
429 422 # # don't busywait
430 423 # time.sleep(1e-3)
431 424
432 425 def make_kernel(int_id, identity, control_addr, shell_addrs, iopub_addr, hb_addrs,
433 426 client_addr=None, loop=None, context=None, key=None,
434 427 out_stream_factory=OutStream, display_hook_factory=DisplayHook):
435 428 """NO LONGER IN USE"""
436 429 # create loop, context, and session:
437 430 if loop is None:
438 431 loop = ioloop.IOLoop.instance()
439 432 if context is None:
440 433 context = zmq.Context()
441 434 c = context
442 435 session = StreamSession(key=key)
443 436 # print (session.key)
444 437 # print (control_addr, shell_addrs, iopub_addr, hb_addrs)
445 438
446 439 # create Control Stream
447 440 control_stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
448 441 control_stream.setsockopt(zmq.IDENTITY, identity)
449 442 control_stream.connect(control_addr)
450 443
451 444 # create Shell Streams (MUX, Task, etc.):
452 445 shell_streams = []
453 446 for addr in shell_addrs:
454 447 stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
455 448 stream.setsockopt(zmq.IDENTITY, identity)
456 449 stream.connect(addr)
457 450 shell_streams.append(stream)
458 451
459 452 # create iopub stream:
460 453 iopub_stream = zmqstream.ZMQStream(c.socket(zmq.PUB), loop)
461 454 iopub_stream.setsockopt(zmq.IDENTITY, identity)
462 455 iopub_stream.connect(iopub_addr)
463 456
464 457 # Redirect input streams and set a display hook.
465 458 if out_stream_factory:
466 459 sys.stdout = out_stream_factory(session, iopub_stream, u'stdout')
467 460 sys.stdout.topic = 'engine.%i.stdout'%int_id
468 461 sys.stderr = out_stream_factory(session, iopub_stream, u'stderr')
469 462 sys.stderr.topic = 'engine.%i.stderr'%int_id
470 463 if display_hook_factory:
471 464 sys.displayhook = display_hook_factory(session, iopub_stream)
472 465 sys.displayhook.topic = 'engine.%i.pyout'%int_id
473 466
474 467
475 468 # launch heartbeat
476 469 heart = heartmonitor.Heart(*map(str, hb_addrs), heart_id=identity)
477 470 heart.start()
478 471
479 472 # create (optional) Client
480 473 if client_addr:
481 474 client = Client(client_addr, username=identity)
482 475 else:
483 476 client = None
484 477
485 478 kernel = Kernel(id=int_id, session=session, control_stream=control_stream,
486 479 shell_streams=shell_streams, iopub_stream=iopub_stream,
487 480 client=client, loop=loop)
488 481 kernel.start()
489 482 return loop, c, kernel
490 483
@@ -1,549 +1,549 b''
1 1 #!/usr/bin/env python
2 2 """edited session.py to work with streams, and move msg_type to the header
3 3 """
4 4
5 5
6 6 import os
7 7 import sys
8 8 import traceback
9 9 import pprint
10 10 import uuid
11 11 from datetime import datetime
12 12
13 13 import zmq
14 14 from zmq.utils import jsonapi
15 15 from zmq.eventloop.zmqstream import ZMQStream
16 16
17 17 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
18 18 from IPython.utils.newserialized import serialize, unserialize
19 19
20 20 from IPython.zmq.parallel.error import RemoteError
21 21
22 22 try:
23 23 import cPickle
24 24 pickle = cPickle
25 25 except:
26 26 cPickle = None
27 27 import pickle
28 28
29 29 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
30 30 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
31 31 if json_name in ('jsonlib', 'jsonlib2'):
32 32 use_json = True
33 33 elif json_name:
34 34 if cPickle is None:
35 35 use_json = True
36 36 else:
37 37 use_json = False
38 38 else:
39 39 use_json = False
40 40
41 41 def squash_unicode(obj):
42 42 if isinstance(obj,dict):
43 43 for key in obj.keys():
44 44 obj[key] = squash_unicode(obj[key])
45 45 if isinstance(key, unicode):
46 46 obj[squash_unicode(key)] = obj.pop(key)
47 47 elif isinstance(obj, list):
48 48 for i,v in enumerate(obj):
49 49 obj[i] = squash_unicode(v)
50 50 elif isinstance(obj, unicode):
51 51 obj = obj.encode('utf8')
52 52 return obj
53 53
54 54 json_packer = jsonapi.dumps
55 55 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
56 56
57 57 pickle_packer = lambda o: pickle.dumps(o,-1)
58 58 pickle_unpacker = pickle.loads
59 59
60 60 if use_json:
61 61 default_packer = json_packer
62 62 default_unpacker = json_unpacker
63 63 else:
64 64 default_packer = pickle_packer
65 65 default_unpacker = pickle_unpacker
66 66
67 67
68 68 DELIM="<IDS|MSG>"
69 69 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
70 70
71 71 def wrap_exception(engine_info={}):
72 72 etype, evalue, tb = sys.exc_info()
73 73 stb = traceback.format_exception(etype, evalue, tb)
74 74 exc_content = {
75 75 'status' : 'error',
76 76 'traceback' : stb,
77 77 'ename' : unicode(etype.__name__),
78 78 'evalue' : unicode(evalue),
79 79 'engine_info' : engine_info
80 80 }
81 81 return exc_content
82 82
83 83 def unwrap_exception(content):
84 84 err = RemoteError(content['ename'], content['evalue'],
85 85 ''.join(content['traceback']),
86 86 content.get('engine_info', {}))
87 87 return err
88 88
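# Round-trip sketch for the two helpers above:
try:
    1/0
except:
    content = wrap_exception({'engineid': 0, 'method': 'apply'})
err = unwrap_exception(content)
# err is now a RemoteError('ZeroDivisionError', ...) carrying the traceback text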
89 89
90 90 class Message(object):
91 91 """A simple message object that maps dict keys to attributes.
92 92
93 93 A Message can be created from a dict and a dict from a Message instance
94 94 simply by calling dict(msg_obj)."""
95 95
96 96 def __init__(self, msg_dict):
97 97 dct = self.__dict__
98 98 for k, v in dict(msg_dict).iteritems():
99 99 if isinstance(v, dict):
100 100 v = Message(v)
101 101 dct[k] = v
102 102
103 103 # Having this iterator lets dict(msg_obj) work out of the box.
104 104 def __iter__(self):
105 105 return iter(self.__dict__.iteritems())
106 106
107 107 def __repr__(self):
108 108 return repr(self.__dict__)
109 109
110 110 def __str__(self):
111 111 return pprint.pformat(self.__dict__)
112 112
113 113 def __contains__(self, k):
114 114 return k in self.__dict__
115 115
116 116 def __getitem__(self, k):
117 117 return self.__dict__[k]
118 118
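# Quick illustration of Message's dict/attribute duality:
m = Message({'header': {'msg_id': 'abc'}, 'content': {'code': 'a=1'}})
m.header.msg_id    # -> 'abc'; nested dicts become Messages
m['content']       # item access works too
dict(m)            # back to a (shallow) dict of the top-level fields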
119 119
120 120 def msg_header(msg_id, msg_type, username, session):
121 121 date=datetime.now().strftime(ISO8601)
122 122 return locals()
123 123
124 124 def extract_header(msg_or_header):
125 125 """Given a message or header, return the header."""
126 126 if not msg_or_header:
127 127 return {}
128 128 try:
129 129 # See if msg_or_header is the entire message.
130 130 h = msg_or_header['header']
131 131 except KeyError:
132 132 try:
133 133 # See if msg_or_header is just the header
134 134 h = msg_or_header['msg_id']
135 135 except KeyError:
136 136 raise
137 137 else:
138 138 h = msg_or_header
139 139 if not isinstance(h, dict):
140 140 h = dict(h)
141 141 return h
142 142
143 143 def rekey(dikt):
144 144 """Rekey a dict that has been forced to use str keys where there should be
145 145 ints by json. This belongs in the jsonutil added by fperez."""
146 146 for k in dikt.iterkeys():
147 147 if isinstance(k, str):
148 148 ik=fk=None
149 149 try:
150 150 ik = int(k)
151 151 except ValueError:
152 152 try:
153 153 fk = float(k)
154 154 except ValueError:
155 155 continue
156 156 if ik is not None:
157 157 nk = ik
158 158 else:
159 159 nk = fk
160 160 if nk in dikt:
161 161 raise KeyError("already have key %r"%nk)
162 162 dikt[nk] = dikt.pop(k)
163 163 return dikt
164 164
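# rekey in action: json coerced the int keys to strings on the wire
engines = rekey({'0': 'engine-a', '1': 'engine-b'})
# engines == {0: 'engine-a', 1: 'engine-b'}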
165 165 def serialize_object(obj, threshold=64e-6):
166 166 """Serialize an object into a list of sendable buffers.
167 167
168 168 Parameters
169 169 ----------
170 170
171 171 obj : object
172 172 The object to be serialized
173 173 threshold : float
174 174 The size threshold above which data is sent as a separate buffer, rather than double-pickled.
175 175
176 176
177 177 Returns
178 178 -------
179 179 ('pmd', [bufs]) :
180 180 where pmd is the pickled metadata wrapper,
181 181 bufs is a list of data buffers
182 182 """
183 183 databuffers = []
184 184 if isinstance(obj, (list, tuple)):
185 185 clist = canSequence(obj)
186 186 slist = map(serialize, clist)
187 187 for s in slist:
188 188 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
189 189 databuffers.append(s.getData())
190 190 s.data = None
191 191 return pickle.dumps(slist,-1), databuffers
192 192 elif isinstance(obj, dict):
193 193 sobj = {}
194 194 for k in sorted(obj.iterkeys()):
195 195 s = serialize(can(obj[k]))
196 196 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
197 197 databuffers.append(s.getData())
198 198 s.data = None
199 199 sobj[k] = s
200 200 return pickle.dumps(sobj,-1),databuffers
201 201 else:
202 202 s = serialize(can(obj))
203 203 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
204 204 databuffers.append(s.getData())
205 205 s.data = None
206 206 return pickle.dumps(s,-1),databuffers
207 207
208 208
209 209 def unserialize_object(bufs):
210 210 """reconstruct an object serialized by serialize_object from data buffers."""
211 211 bufs = list(bufs)
212 212 sobj = pickle.loads(bufs.pop(0))
213 213 if isinstance(sobj, (list, tuple)):
214 214 for s in sobj:
215 215 if s.data is None:
216 216 s.data = bufs.pop(0)
217 217 return uncanSequence(map(unserialize, sobj)), bufs
218 218 elif isinstance(sobj, dict):
219 219 newobj = {}
220 220 for k in sorted(sobj.iterkeys()):
221 221 s = sobj[k]
222 222 if s.data is None:
223 223 s.data = bufs.pop(0)
224 224 newobj[k] = uncan(unserialize(s))
225 225 return newobj, bufs
226 226 else:
227 227 if sobj.data is None:
228 228 sobj.data = bufs.pop(0)
229 229 return uncan(unserialize(sobj)), bufs
230 230
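# Round-trip sketch for the serialize/unserialize pair above:
pmd, bufs = serialize_object({'x': [1, 2, 3]})
obj, remaining = unserialize_object([pmd] + bufs)
# obj == {'x': [1, 2, 3]} and remaining == []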
231 231 def pack_apply_message(f, args, kwargs, threshold=64e-6):
232 232 """pack up a function, args, and kwargs to be sent over the wire
233 233 as a series of buffers. Any object whose data is larger than `threshold`
234 234 will not have its data copied (currently only numpy arrays support zero-copy)"""
235 235 msg = [pickle.dumps(can(f),-1)]
236 236 databuffers = [] # for large objects
237 237 sargs, bufs = serialize_object(args,threshold)
238 238 msg.append(sargs)
239 239 databuffers.extend(bufs)
240 240 skwargs, bufs = serialize_object(kwargs,threshold)
241 241 msg.append(skwargs)
242 242 databuffers.extend(bufs)
243 243 msg.extend(databuffers)
244 244 return msg
245 245
246 246 def unpack_apply_message(bufs, g=None, copy=True):
247 247 """unpack f,args,kwargs from buffers packed by pack_apply_message()
248 248 Returns: original f,args,kwargs"""
249 249 bufs = list(bufs) # allow us to pop
250 250 assert len(bufs) >= 3, "not enough buffers!"
251 251 if not copy:
252 252 for i in range(3):
253 253 bufs[i] = bufs[i].bytes
254 254 cf = pickle.loads(bufs.pop(0))
255 255 sargs = list(pickle.loads(bufs.pop(0)))
256 256 skwargs = dict(pickle.loads(bufs.pop(0)))
257 257 # print sargs, skwargs
258 258 f = uncan(cf, g)
259 259 for sa in sargs:
260 260 if sa.data is None:
261 261 m = bufs.pop(0)
262 262 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
263 263 if copy:
264 264 sa.data = buffer(m)
265 265 else:
266 266 sa.data = m.buffer
267 267 else:
268 268 if copy:
269 269 sa.data = m
270 270 else:
271 271 sa.data = m.bytes
272 272
273 273 args = uncanSequence(map(unserialize, sargs), g)
274 274 kwargs = {}
275 275 for k in sorted(skwargs.iterkeys()):
276 276 sa = skwargs[k]
277 277 if sa.data is None:
278 278 sa.data = bufs.pop(0)
279 279 kwargs[k] = uncan(unserialize(sa), g)
280 280
281 281 return f,args,kwargs
282 282
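End-to-end, the apply path round-trips like this (a minimal sketch; `add` is a hypothetical plain function, which is the case `can` handles):

    def add(a, b, scale=1):
        return scale * (a + b)

    # msg is a flat list of buffers, ready for send_multipart
    msg = pack_apply_message(add, (2, 3), dict(scale=10))
    f, args, kwargs = unpack_apply_message(msg)
    assert f(*args, **kwargs) == 50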
283 283 class StreamSession(object):
284 284 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
285 285 debug=False
286 286 key=None
287 287
288 288 def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
289 289 if username is None:
290 290 username = os.environ.get('USER','username')
291 291 self.username = username
292 292 if session is None:
293 293 self.session = str(uuid.uuid4())
294 294 else:
295 295 self.session = session
296 296 self.msg_id = str(uuid.uuid4())
297 297 if packer is None:
298 298 self.pack = default_packer
299 299 else:
300 300 if not callable(packer):
301 301 raise TypeError("packer must be callable, not %s"%type(packer))
302 302 self.pack = packer
303 303
304 304 if unpacker is None:
305 305 self.unpack = default_unpacker
306 306 else:
307 307 if not callable(unpacker):
308 308 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
309 309 self.unpack = unpacker
310 310
311 311 if key is not None and keyfile is not None:
312 312 raise TypeError("Must specify key OR keyfile, not both")
313 313 if keyfile is not None:
314 314 with open(keyfile) as f:
315 315 self.key = f.read().strip()
316 316 else:
317 317 self.key = key
318 318 # print key, keyfile, self.key
319 319 self.none = self.pack({})
320 320
321 321 def msg_header(self, msg_type):
322 322 h = msg_header(self.msg_id, msg_type, self.username, self.session)
323 323 self.msg_id = str(uuid.uuid4())
324 324 return h
325 325
326 326 def msg(self, msg_type, content=None, parent=None, subheader=None):
327 327 msg = {}
328 328 msg['header'] = self.msg_header(msg_type)
329 329 msg['msg_id'] = msg['header']['msg_id']
330 330 msg['parent_header'] = {} if parent is None else extract_header(parent)
331 331 msg['msg_type'] = msg_type
332 332 msg['content'] = {} if content is None else content
333 333 sub = {} if subheader is None else subheader
334 334 msg['header'].update(sub)
335 335 return msg
336 336
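For instance, a session can swap in JSON packing instead of the default packer (a sketch; JSON only round-trips simple content dicts):

    import json
    s = StreamSession(username='case', packer=json.dumps, unpacker=json.loads)
    m = s.msg('execute_request', content=dict(code='a=5'))
    # each call to msg() consumes the current uuid and allocates the next one
    assert m['msg_type'] == 'execute_request'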
337 337 def check_key(self, msg_or_header):
338 338 """Check that a message's header has the right key"""
339 339 if self.key is None:
340 340 return True
341 341 header = extract_header(msg_or_header)
342 342 return header.get('key', None) == self.key
343 343
344 344
345 345 def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
346 346 """Build and send a message via stream or socket.
347 347
348 348 Parameters
349 349 ----------
350 350
351 351 stream : zmq.Socket or ZMQStream
352 352 the socket-like object used to send the data
353 353 msg_or_type : str or Message/dict
354 354 Normally a msg_type string; pass a whole Message or dict only when
355 355 re-sending an existing message.
356 356
357 357 Returns
358 358 -------
359 359 msg : dict
360 360 the constructed message that was sent,
361 361 containing the headers and content
362 362
363 363 """
364 364 if isinstance(msg_or_type, (Message, dict)):
365 365 # we got a Message, not a msg_type
366 366 # don't build a new Message
367 367 msg = msg_or_type
368 368 content = msg['content']
369 369 else:
370 370 msg = self.msg(msg_or_type, content, parent, subheader)
371 371 buffers = [] if buffers is None else buffers
372 372 to_send = []
373 373 if isinstance(ident, list):
374 374 # accept list of idents
375 375 to_send.extend(ident)
376 376 elif ident is not None:
377 377 to_send.append(ident)
378 378 to_send.append(DELIM)
379 379 if self.key is not None:
380 380 to_send.append(self.key)
381 381 to_send.append(self.pack(msg['header']))
382 382 to_send.append(self.pack(msg['parent_header']))
383 383
384 384 if content is None:
385 385 content = self.none
386 386 elif isinstance(content, dict):
387 387 content = self.pack(content)
388 388 elif isinstance(content, str):
389 389 # content is already packed, as in a relayed message
390 390 pass
391 391 else:
392 392 raise TypeError("Content incorrect type: %s"%type(content))
393 393 to_send.append(content)
394 394 flag = 0
395 395 if buffers:
396 396 flag = zmq.SNDMORE
397 397 stream.send_multipart(to_send, flag, copy=False)
398 398 for b in buffers[:-1]:
399 399 stream.send(b, flag, copy=False)
400 400 if buffers:
401 401 stream.send(buffers[-1], copy=False)
402 omsg = Message(msg)
402 # omsg = Message(msg)
403 403 if self.debug:
404 pprint.pprint(omsg)
404 pprint.pprint(msg)
405 405 pprint.pprint(to_send)
406 406 pprint.pprint(buffers)
407 return omsg
407 return msg
408 408
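Typical send-side usage over a real socket (the endpoint here is hypothetical; XREQ is the pyzmq name of this era for the dealer socket):

    ctx = zmq.Context()
    sock = ctx.socket(zmq.XREQ)
    sock.connect('tcp://127.0.0.1:10101')   # hypothetical address
    s = StreamSession(username='client')
    # wire format: [ident.., DELIM, key?, header, parent_header, content, buffers..]
    s.send(sock, 'execute_request', content=dict(code='a=5'))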
409 409 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
410 410 """Send a raw message via ident path.
411 411
412 412 Parameters
413 413 ----------
414 414 msg : list of sendable buffers"""
415 415 to_send = []
416 416 if isinstance(ident, str):
417 417 ident = [ident]
418 418 if ident is not None:
419 419 to_send.extend(ident)
420 420 to_send.append(DELIM)
421 421 if self.key is not None:
422 422 to_send.append(self.key)
423 423 to_send.extend(msg)
424 424 stream.send_multipart(to_send, flags, copy=copy)
425 425
426 426 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
427 427 """receives and unpacks a message
428 428 returns (idents, msg)"""
429 429 if isinstance(socket, ZMQStream):
430 430 socket = socket.socket
431 431 try:
432 432 msg = socket.recv_multipart(mode)
433 433 except zmq.ZMQError as e:
434 434 if e.errno == zmq.EAGAIN:
435 435 # We can convert EAGAIN to None as we know in this case
436 436 # recv_multipart won't return None.
437 437 return None
438 438 else:
439 439 raise
440 440 # return an actual Message object
441 441 # determine the number of idents by trying to unpack them.
442 442 # this is terrible:
443 443 idents, msg = self.feed_identities(msg, copy)
444 444 try:
445 445 return idents, self.unpack_message(msg, content=content, copy=copy)
446 446 except Exception as e:
447 447 print (idents, msg)
448 448 # TODO: handle it
449 449 raise e
450 450
451 451 def feed_identities(self, msg, copy=True):
452 452 """feed until DELIM is reached, then return the prefix as idents and remainder as
453 453 msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.
454 454
455 455 Parameters
456 456 ----------
457 457 msg : a list of Message or bytes objects
458 458 the message to be split
459 459 copy : bool
460 460 flag determining whether the arguments are bytes or Messages
461 461
462 462 Returns
463 463 -------
464 464 (idents,msg) : two lists
465 465 idents will always be a list of bytes - the identity prefix
466 466 msg will be a list of bytes or Messages, unchanged from input
467 467 msg should be unpackable via self.unpack_message at this point.
468 468 """
469 469 msg = list(msg)
470 470 idents = []
471 471 while len(msg) > 3:
472 472 if copy:
473 473 s = msg[0]
474 474 else:
475 475 s = msg[0].bytes
476 476 if s == DELIM:
477 477 msg.pop(0)
478 478 break
479 479 else:
480 480 idents.append(s)
481 481 msg.pop(0)
482 482
483 483 return idents, msg
484 484
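For example, splitting a routed message (a sketch; DELIM is the module-level delimiter also used by send above):

    s = StreamSession()
    frames = ['client.0', 'engine.1', DELIM, 'header', 'parent', 'content']
    idents, rest = s.feed_identities(frames)
    assert idents == ['client.0', 'engine.1']
    assert rest == ['header', 'parent', 'content']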
485 485 def unpack_message(self, msg, content=True, copy=True):
486 486 """Return a message object from the format
487 487 sent by self.send.
488 488
489 489 Parameters
490 490 ----------
491 491
492 492 content : bool (True)
493 493 whether to unpack the content dict (True),
494 494 or leave it serialized (False)
495 495
496 496 copy : bool (True)
497 497 whether the message frames are bytes (True),
498 498 or non-copying zmq Message objects (False)
499 499
500 500 """
501 501 ikey = int(self.key is not None)
502 502 minlen = 3 + ikey
503 503 if len(msg) < minlen:
504 504 raise TypeError("malformed message, must have at least %i elements"%minlen)
505 505 message = {}
506 506 if not copy:
507 507 for i in range(minlen):
508 508 msg[i] = msg[i].bytes
509 509 if ikey:
510 510 if self.key != msg[0]:
511 511 raise KeyError("Invalid Session Key: %s"%msg[0])
512 512 message['header'] = self.unpack(msg[ikey+0])
513 513 message['msg_type'] = message['header']['msg_type']
514 514 message['parent_header'] = self.unpack(msg[ikey+1])
515 515 if content:
516 516 message['content'] = self.unpack(msg[ikey+2])
517 517 else:
518 518 message['content'] = msg[ikey+2]
519 519
520 520 # message['buffers'] = msg[3:]
521 521 # else:
522 522 # message['header'] = self.unpack(msg[0].bytes)
523 523 # message['msg_type'] = message['header']['msg_type']
524 524 # message['parent_header'] = self.unpack(msg[1].bytes)
525 525 # if content:
526 526 # message['content'] = self.unpack(msg[2].bytes)
527 527 # else:
528 528 # message['content'] = msg[2].bytes
529 529
530 530 message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
531 531 return message
532 532
533 533
534 534
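Tying the two ends together without a socket (a sketch, assuming no key is configured so ikey == 0):

    s = StreamSession()
    m = s.msg('apply_request', content=dict(a=1))
    wire = [s.pack(m['header']), s.pack(m['parent_header']), s.pack(m['content'])]
    m2 = s.unpack_message(wire)
    assert m2['msg_type'] == m['msg_type']
    assert m2['content'] == m['content']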
535 535 def test_msg2obj():
536 536 am = dict(x=1)
537 537 ao = Message(am)
538 538 assert ao.x == am['x']
539 539
540 540 am['y'] = dict(z=1)
541 541 ao = Message(am)
542 542 assert ao.y.z == am['y']['z']
543 543
544 544 k1, k2 = 'y', 'z'
545 545 assert ao[k1][k2] == am[k1][k2]
546 546
547 547 am2 = dict(ao)
548 548 assert am['x'] == am2['x']
549 549 assert am['y']['z'] == am2['y']['z']