##// END OF EJS Templates
adapt kernel/error.py to zmq, improve error propagation.
MinRK -
Show More
@@ -0,0 +1,276 b''
1 # encoding: utf-8
2
3 """Classes and functions for kernel related errors and exceptions."""
4 from __future__ import print_function
5
6 __docformat__ = "restructuredtext en"
7
8 # Tell nose to skip this module
9 __test__ = {}
10
11 #-------------------------------------------------------------------------------
12 # Copyright (C) 2008 The IPython Development Team
13 #
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
16 #-------------------------------------------------------------------------------
17
18 #-------------------------------------------------------------------------------
19 # Error classes
20 #-------------------------------------------------------------------------------
class IPythonError(Exception):
    """Base exception that all of our exceptions inherit from.

    This can be raised by code that doesn't have any more specific
    information."""
    pass


# Exceptions associated with the controller objects
class ControllerError(IPythonError):
    """Base class for errors in the controller (not the engines)."""
    pass

class ControllerCreationError(ControllerError):
    """Raised when a controller cannot be created."""
    pass


# Exceptions associated with the Engines
class EngineError(IPythonError):
    """Base class for errors on an engine."""
    pass

class EngineCreationError(EngineError):
    """Raised when an engine cannot be created."""
    pass


class KernelError(IPythonError):
    """Base class for all kernel-related errors."""
    pass


class NotDefined(KernelError):
    """Raised when a requested name is not defined in the namespace."""

    def __init__(self, name):
        self.name = name
        self.args = (name,)

    def __repr__(self):
        return '<NotDefined: %s>' % self.name

    __str__ = __repr__


class QueueCleared(KernelError):
    """Raised to signal that a pending-result queue was cleared."""
    pass


class IdInUse(KernelError):
    """Raised when registering an id that is already taken."""
    pass


class ProtocolError(KernelError):
    """Raised on a violation of the messaging protocol."""
    pass


class ConnectionError(KernelError):
    """Raised when a connection to the controller/engine fails."""
    pass


class InvalidEngineID(KernelError):
    """Raised when an engine id does not refer to a registered engine."""
    pass


class NoEnginesRegistered(KernelError):
    """Raised when an operation requires engines but none are registered."""
    pass


class InvalidClientID(KernelError):
    """Raised when a client id does not refer to a known client."""
    pass


class InvalidDeferredID(KernelError):
    """Raised when a deferred id does not refer to a known deferred."""
    pass


class SerializationError(KernelError):
    """Raised when an object cannot be serialized for transport."""
    pass


class MessageSizeError(KernelError):
    """Raised when a message exceeds the allowed size."""
    pass


class PBMessageSizeError(MessageSizeError):
    """Raised when a PB message exceeds the allowed size."""
    pass


class ResultNotCompleted(KernelError):
    """Raised when a result is requested before it has completed."""
    pass


class ResultAlreadyRetrieved(KernelError):
    """Raised when a result is requested a second time."""
    pass


class ClientError(KernelError):
    """Base class for errors raised on the client side."""
    pass


class TaskAborted(KernelError):
    """Raised when a task has been aborted."""
    pass


class TaskTimeout(KernelError):
    """Raised when a task times out."""
    pass


class NotAPendingResult(KernelError):
    """Raised when an object is not a PendingResult where one is required."""
    pass


class UnpickleableException(KernelError):
    """Raised when a remote exception cannot be unpickled locally."""
    pass


class AbortedPendingDeferredError(KernelError):
    """Raised when a pending deferred has been aborted."""
    pass


class InvalidProperty(KernelError):
    """Raised when an engine property is invalid."""
    pass


class MissingBlockArgument(KernelError):
    """Raised when a required `block` argument is missing."""
    pass


class StopLocalExecution(KernelError):
    """Raised to halt local execution of parallel code."""
    pass


class SecurityError(KernelError):
    """Raised on an authentication/security failure."""
    pass


class FileTimeoutError(KernelError):
    """Raised when waiting for a file times out."""
    pass


class RemoteError(KernelError):
    """Error raised elsewhere (on an engine), carried back to the client.

    Attributes
    ----------
    ename : str
        class name of the remote exception
    evalue : str
        value (message) of the remote exception
    traceback : str
        remote traceback, as text
    engine_info : dict
        metadata about the engine on which the error occurred
    """
    ename = None
    evalue = None
    traceback = None
    engine_info = None

    def __init__(self, ename, evalue, traceback, engine_info=None):
        self.ename = ename
        self.evalue = evalue
        self.traceback = traceback
        self.engine_info = engine_info or {}
        self.args = (ename, evalue)

    def __repr__(self):
        engineid = self.engine_info.get('engineid', ' ')
        return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)

    def __str__(self):
        sig = "%s(%s)"%(self.ename, self.evalue)
        if self.traceback:
            return sig + '\n' + self.traceback
        else:
            return sig


class TaskRejectError(KernelError):
    """Exception to raise when a task should be rejected by an engine.

    This exception can be used to allow a task running on an engine to test
    if the engine (or the user's namespace on the engine) has the needed
    task dependencies. If not, the task should raise this exception. For
    the task to be retried on another engine, the task should be created
    with the `retries` argument > 1.

    The advantage of this approach over our older properties system is that
    tasks have full access to the user's namespace on the engines and the
    properties don't have to be managed or tested by the controller.
    """


class CompositeError(KernelError):
    """Error for representing possibly multiple errors on engines"""
    def __init__(self, message, elist):
        Exception.__init__(self, *(message, elist))
        # Don't use pack_exception because it will conflict with the .message
        # attribute that is being deprecated in 2.6 and beyond.
        self.msg = message
        # elist is a list of (ename, evalue, traceback, engine_info) tuples
        self.elist = elist
        self.args = [ e[0] for e in elist ]

    def _get_engine_str(self, ei):
        # ei is the engine_info dict; may be empty/None for unknown engines
        if not ei:
            return '[Engine Exception]'
        else:
            return '[%i:%s]: ' % (ei['engineid'], ei['method'])

    def _get_traceback(self, ev):
        try:
            tb = ev._ipython_traceback_text
        except AttributeError:
            return 'No traceback available'
        else:
            return tb

    def __str__(self):
        s = str(self.msg)
        for en, ev, etb, ei in self.elist:
            engine_str = self._get_engine_str(ei)
            s = s + '\n' + engine_str + en + ': ' + str(ev)
        return s

    def __repr__(self):
        return "CompositeError(%i)"%len(self.elist)

    def print_tracebacks(self, excid=None):
        """Print remote traceback(s): all of them if `excid` is None."""
        if excid is None:
            for (en, ev, etb, ei) in self.elist:
                print(self._get_engine_str(ei))
                print(etb or 'No traceback available')
                print()
        else:
            try:
                en, ev, etb, ei = self.elist[excid]
            except IndexError:
                raise IndexError("an exception with index %i does not exist"%excid)
            else:
                print(self._get_engine_str(ei))
                print(etb or 'No traceback available')

    def raise_exception(self, excid=0):
        """Re-raise the error at index `excid` as a RemoteError."""
        try:
            en, ev, etb, ei = self.elist[excid]
        except IndexError:
            raise IndexError("an exception with index %i does not exist"%excid)
        else:
            # BUG FIX: previous version caught the RemoteError and stored
            # sys.exc_info() without re-raising (and `sys` was never
            # imported), silently swallowing the error instead of
            # propagating it.
            raise RemoteError(en, ev, etb, ei)


def collect_exceptions(rdict, method):
    """check a result dict for errors, and raise CompositeError if any exist.
    Passthrough otherwise."""
    elist = []
    for r in rdict.values():
        if isinstance(r, RemoteError):
            en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
            # Sometimes we could have CompositeError in our list. Just take
            # the errors out of them and put them in our new list. This
            # has the effect of flattening lists of CompositeErrors into one
            # CompositeError
            if en == 'CompositeError':
                for e in ev.elist:
                    elist.append(e)
            else:
                elist.append((en, ev, etb, ei))
    if len(elist) == 0:
        return rdict
    else:
        msg = "one or more exceptions from call to method: %s" % (method)
        # This silliness is needed so the debugger has access to the exception
        # instance (e in this case)
        try:
            raise CompositeError(msg, elist)
        except CompositeError as e:
            # `except X, e` was Python-2-only syntax; `as` works on 2.6+/3.x
            raise e
@@ -1,960 +1,975 b''
1 1 """A semi-synchronous Client for the ZMQ controller"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 from __future__ import print_function
14 14
15 15 import os
16 16 import time
17 17 from getpass import getpass
18 18 from pprint import pprint
19 19
20 20 import zmq
21 21 from zmq.eventloop import ioloop, zmqstream
22 22
23 23 from IPython.external.decorator import decorator
24 24 from IPython.zmq import tunnel
25 25
26 26 import streamsession as ss
27 27 # from remotenamespace import RemoteNamespace
28 28 from view import DirectView, LoadBalancedView
29 29 from dependency import Dependency, depend, require
30 import error
30 31
31 32 def _push(ns):
32 33 globals().update(ns)
33 34
34 35 def _pull(keys):
35 36 g = globals()
36 37 if isinstance(keys, (list,tuple, set)):
37 38 for key in keys:
38 39 if not g.has_key(key):
39 40 raise NameError("name '%s' is not defined"%key)
40 41 return map(g.get, keys)
41 42 else:
42 43 if not g.has_key(keys):
43 44 raise NameError("name '%s' is not defined"%keys)
44 45 return g.get(keys)
45 46
46 47 def _clear():
47 48 globals().clear()
48 49
def execute(code):
    """Execute the string *code* in this module's global namespace.

    Uses the function form of exec, valid on Python 2.6+ and 3.x; the
    original ``exec code in globals()`` statement is a SyntaxError on
    Python 3.
    """
    exec(code, globals())
51 52
52 53 #--------------------------------------------------------------------------
53 54 # Decorators for Client methods
54 55 #--------------------------------------------------------------------------
55 56
@decorator
def spinfirst(f, self, *a, **kw):
    """Synchronize client state with the controller (via spin()) before
    invoking the wrapped method."""
    self.spin()
    return f(self, *a, **kw)
61 62
@decorator
def defaultblock(f, self, *args, **kwargs):
    """Default to self.block; preserve self.block.

    Temporarily sets `self.block` from the `block` keyword (falling back
    to the current `self.block`), calls `f`, and restores the original
    value afterwards.
    """
    block = kwargs.get('block', None)
    block = self.block if block is None else block
    saveblock = self.block
    self.block = block
    try:
        ret = f(self, *args, **kwargs)
    finally:
        # BUG FIX: restore even when f raises, so one failed call cannot
        # leave the client stuck with a temporary block setting
        self.block = saveblock
    return ret
72 73
def remote(client, bound=False, block=None, targets=None):
    """Decorator factory: turn a plain function into a RemoteFunction.

    Usage::

        @remote(client, block=True)
        def func(a):
            ...
    """
    def make_remote(f):
        # defer all argument handling to RemoteFunction itself
        return RemoteFunction(client, f, bound, block, targets)
    return make_remote
84 85
85 86 #--------------------------------------------------------------------------
86 87 # Classes
87 88 #--------------------------------------------------------------------------
88 89
class RemoteFunction(object):
    """Wrap an existing function as a remote function.

    Parameters
    ----------

    client : Client instance
        The client used to connect to engines
    f : callable
        The function to be wrapped into a remote function
    bound : bool [default: False]
        Whether calls affect the remote namespace
    block : bool [default: None]
        Whether to wait for results or not. The default behavior is
        to use the current `block` attribute of `client`
    targets : valid target list [default: all]
        The targets on which to execute.
    """

    client = None   # the remote connection
    func = None     # the wrapped function
    block = None    # whether to block
    bound = None    # whether to affect the namespace
    targets = None  # where to execute

    def __init__(self, client, f, bound=False, block=None, targets=None):
        self.client = client
        self.func = f
        self.block = block
        self.bound = bound
        self.targets = targets

    def __call__(self, *args, **kwargs):
        # delegate the actual submission to the client
        return self.client.apply(self.func, args=args, kwargs=kwargs,
                                 block=self.block, targets=self.targets,
                                 bound=self.bound)
124 125
125 126
class AbortedTask(object):
    """Lightweight marker object describing an aborted task."""
    def __init__(self, msg_id):
        # only the id of the aborted message is recorded
        self.msg_id = msg_id
130 131
class ControllerError(Exception):
    """Exception Class for errors in the controller (not the Engine)."""
    def __init__(self, etype, evalue, tb):
        # keep the three pieces of the exception triple
        self.etype, self.evalue, self.traceback = etype, evalue, tb
137
class ResultDict(dict):
    """A subclass of dict that raises errors if it has them."""
    def __getitem__(self, key):
        value = dict.__getitem__(self, key)
        # stored failures propagate as exceptions on access
        if isinstance(value, error.KernelError):
            raise value
        return value
139
138 140 class Client(object):
139 141 """A semi-synchronous client to the IPython ZMQ controller
140 142
141 143 Parameters
142 144 ----------
143 145
144 146 addr : bytes; zmq url, e.g. 'tcp://127.0.0.1:10101'
145 147 The address of the controller's registration socket.
146 148 [Default: 'tcp://127.0.0.1:10101']
147 149 context : zmq.Context
148 150 Pass an existing zmq.Context instance, otherwise the client will create its own
149 151 username : bytes
150 152 set username to be passed to the Session object
151 153 debug : bool
152 154 flag for lots of message printing for debug purposes
153 155
154 156 #-------------- ssh related args ----------------
155 157 # These are args for configuring the ssh tunnel to be used
156 158 # credentials are used to forward connections over ssh to the Controller
157 159 # Note that the ip given in `addr` needs to be relative to sshserver
158 160 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
159 161 # and set sshserver as the same machine the Controller is on. However,
160 162 # the only requirement is that sshserver is able to see the Controller
161 163 # (i.e. is within the same trusted network).
162 164
163 165 sshserver : str
164 166 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
165 167 If keyfile or password is specified, and this is not, it will default to
166 168 the ip given in addr.
167 169 sshkey : str; path to public ssh key file
168 170 This specifies a key to be used in ssh login, default None.
169 171 Regular default ssh keys will be used without specifying this argument.
170 172 password : str;
171 173 Your ssh password to sshserver. Note that if this is left None,
172 174 you will be prompted for it if passwordless key based login is unavailable.
173 175
174 176 #------- exec authentication args -------
175 177 # If even localhost is untrusted, you can have some protection against
176 178 # unauthorized execution by using a key. Messages are still sent
177 179 # as cleartext, so if someone can snoop your loopback traffic this will
178 180 # not help anything.
179 181
180 182 exec_key : str
181 183 an authentication key or file containing a key
182 184 default: None
183 185
184 186
185 187 Attributes
186 188 ----------
187 189 ids : set of int engine IDs
188 190 requesting the ids attribute always synchronizes
189 191 the registration state. To request ids without synchronization,
190 192 use semi-private _ids attributes.
191 193
192 194 history : list of msg_ids
193 195 a list of msg_ids, keeping track of all the execution
194 196 messages you have submitted in order.
195 197
196 198 outstanding : set of msg_ids
197 199 a set of msg_ids that have been submitted, but whose
198 200 results have not yet been received.
199 201
200 202 results : dict
201 203 a dict of all our results, keyed by msg_id
202 204
203 205 block : bool
204 206 determines default behavior when block not specified
205 207 in execution methods
206 208
207 209 Methods
208 210 -------
209 211 spin : flushes incoming results and registration state changes
210 212 control methods spin, and requesting `ids` also ensures up to date
211 213
212 214 barrier : wait on one or more msg_ids
213 215
214 216 execution methods: apply/apply_bound/apply_to/apply_bound
215 217 legacy: execute, run
216 218
217 219 query methods: queue_status, get_result, purge
218 220
219 221 control methods: abort, kill
220 222
221 223 """
222 224
223 225
224 226 _connected=False
225 227 _ssh=False
226 228 _engines=None
227 229 _addr='tcp://127.0.0.1:10101'
228 230 _registration_socket=None
229 231 _query_socket=None
230 232 _control_socket=None
231 233 _notification_socket=None
232 234 _mux_socket=None
233 235 _task_socket=None
234 236 block = False
235 237 outstanding=None
236 238 results = None
237 239 history = None
238 240 debug = False
239 241
    def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            exec_key=None,):
        """Build the session, optionally set up ssh tunnelling, open the
        registration socket, and connect to the controller.

        See the class docstring for the meaning of each parameter.
        """
        if context is None:
            context = zmq.Context()
        self.context = context
        self._addr = addr
        # any ssh-related argument implies tunnelled connections
        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to the machine the controller address points at
            sshserver = addr.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password=False
            else:
                # interactive prompt only when key-based login failed
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

        # exec_key may be either a literal key or a path to a keyfile
        if exec_key is not None and os.path.isfile(exec_key):
            arg = 'keyfile'
        else:
            arg = 'key'
        key_arg = {arg:exec_key}
        if username is None:
            self.session = ss.StreamSession(**key_arg)
        else:
            self.session = ss.StreamSession(username, **key_arg)
        # the registration socket identifies us to the controller by session id
        self._registration_socket = self.context.socket(zmq.XREQ)
        self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
        if self._ssh:
            tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
        else:
            self._registration_socket.connect(addr)
        # per-instance state (class-level defaults would be shared)
        self._engines = {}
        self._ids = set()
        self.outstanding=set()
        self.results = {}
        self.history = []
        self.debug = debug
        self.session.debug = debug

        # dispatch tables for incoming messages, keyed by msg_type
        self._notification_handlers = {'registration_notification' : self._register_engine,
                                    'unregistration_notification' : self._unregister_engine,
                                    }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        # establish the remaining socket connections to the controller
        self._connect(sshserver, ssh_kwargs)
287 289
288 290
289 291 @property
290 292 def ids(self):
291 293 """Always up to date ids property."""
292 294 self._flush_notifications()
293 295 return self._ids
294 296
295 297 def _update_engines(self, engines):
296 298 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
297 299 for k,v in engines.iteritems():
298 300 eid = int(k)
299 301 self._engines[eid] = bytes(v) # force not unicode
300 302 self._ids.add(eid)
301 303
302 304 def _build_targets(self, targets):
303 305 """Turn valid target IDs or 'all' into two lists:
304 306 (int_ids, uuids).
305 307 """
306 308 if targets is None:
307 309 targets = self._ids
308 310 elif isinstance(targets, str):
309 311 if targets.lower() == 'all':
310 312 targets = self._ids
311 313 else:
312 314 raise TypeError("%r not valid str target, must be 'all'"%(targets))
313 315 elif isinstance(targets, int):
314 316 targets = [targets]
315 317 return [self._engines[t] for t in targets], list(targets)
316 318
    def _connect(self, sshserver, ssh_kwargs):
        """setup all our socket connections to the controller. This is called from
        __init__."""
        # guard against connecting twice
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, addr):
            # tunnel through ssh when configured, otherwise connect directly
            if self._ssh:
                return tunnel.tunnel_connection(s, addr, sshserver, **ssh_kwargs)
            else:
                return s.connect(addr)

        # handshake: ask the controller which socket endpoints it offers
        self.session.send(self._registration_socket, 'connection_request')
        idents,msg = self.session.recv(self._registration_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = ss.Message(msg)
        content = msg.content
        if content.status == 'ok':
            # each field of the reply carries the address of one controller
            # socket; only connect to those the controller advertises
            if content.queue:
                self._mux_socket = self.context.socket(zmq.PAIR)
                self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._mux_socket, content.queue)
            if content.task:
                self._task_socket = self.context.socket(zmq.PAIR)
                self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._task_socket, content.task)
            if content.notification:
                # notifications are broadcast; subscribe to everything
                self._notification_socket = self.context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
            if content.query:
                self._query_socket = self.context.socket(zmq.PAIR)
                self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self.context.socket(zmq.PAIR)
                self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._control_socket, content.control)
            # seed our engine table from the controller's registry
            self._update_engines(dict(content.engines))

        else:
            # mark as disconnected so a later retry is possible
            self._connected = False
            raise Exception("Failed to connect!")
362 364
363 365 #--------------------------------------------------------------------------
364 366 # handlers and callbacks for incoming messages
365 367 #--------------------------------------------------------------------------
366 368
367 369 def _register_engine(self, msg):
368 370 """Register a new engine, and update our connection info."""
369 371 content = msg['content']
370 372 eid = content['id']
371 373 d = {eid : content['queue']}
372 374 self._update_engines(d)
373 375 self._ids.add(int(eid))
374 376
375 377 def _unregister_engine(self, msg):
376 378 """Unregister an engine that has died."""
377 379 content = msg['content']
378 380 eid = int(content['id'])
379 381 if eid in self._ids:
380 382 self._ids.remove(eid)
381 383 self._engines.pop(eid)
382 384
383 385 def _handle_execute_reply(self, msg):
384 386 """Save the reply to an execute_request into our results."""
385 387 parent = msg['parent_header']
386 388 msg_id = parent['msg_id']
387 389 if msg_id not in self.outstanding:
388 390 print("got unknown result: %s"%msg_id)
389 391 else:
390 392 self.outstanding.remove(msg_id)
391 393 self.results[msg_id] = ss.unwrap_exception(msg['content'])
392 394
393 395 def _handle_apply_reply(self, msg):
394 396 """Save the reply to an apply_request into our results."""
395 397 parent = msg['parent_header']
396 398 msg_id = parent['msg_id']
397 399 if msg_id not in self.outstanding:
398 400 print ("got unknown result: %s"%msg_id)
399 401 else:
400 402 self.outstanding.remove(msg_id)
401 403 content = msg['content']
402 404 if content['status'] == 'ok':
403 405 self.results[msg_id] = ss.unserialize_object(msg['buffers'])
404 406 elif content['status'] == 'aborted':
405 self.results[msg_id] = AbortedTask(msg_id)
407 self.results[msg_id] = error.AbortedTask(msg_id)
406 408 elif content['status'] == 'resubmitted':
407 409 # TODO: handle resubmission
408 410 pass
409 411 else:
410 self.results[msg_id] = ss.unwrap_exception(content)
412 e = ss.unwrap_exception(content)
413 e_uuid = e.engine_info['engineid']
414 for k,v in self._engines.iteritems():
415 if v == e_uuid:
416 e.engine_info['engineid'] = k
417 break
418 self.results[msg_id] = e
411 419
412 420 def _flush_notifications(self):
413 421 """Flush notifications of engine registrations waiting
414 422 in ZMQ queue."""
415 423 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
416 424 while msg is not None:
417 425 if self.debug:
418 426 pprint(msg)
419 427 msg = msg[-1]
420 428 msg_type = msg['msg_type']
421 429 handler = self._notification_handlers.get(msg_type, None)
422 430 if handler is None:
423 431 raise Exception("Unhandled message type: %s"%msg.msg_type)
424 432 else:
425 433 handler(msg)
426 434 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
427 435
428 436 def _flush_results(self, sock):
429 437 """Flush task or queue results waiting in ZMQ queue."""
430 438 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
431 439 while msg is not None:
432 440 if self.debug:
433 441 pprint(msg)
434 442 msg = msg[-1]
435 443 msg_type = msg['msg_type']
436 444 handler = self._queue_handlers.get(msg_type, None)
437 445 if handler is None:
438 446 raise Exception("Unhandled message type: %s"%msg.msg_type)
439 447 else:
440 448 handler(msg)
441 449 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
442 450
443 451 def _flush_control(self, sock):
444 452 """Flush replies from the control channel waiting
445 453 in the ZMQ queue.
446 454
447 455 Currently: ignore them."""
448 456 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
449 457 while msg is not None:
450 458 if self.debug:
451 459 pprint(msg)
452 460 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
453 461
454 462 #--------------------------------------------------------------------------
455 463 # getitem
456 464 #--------------------------------------------------------------------------
457 465
458 466 def __getitem__(self, key):
459 467 """Dict access returns DirectView multiplexer objects or,
460 468 if key is None, a LoadBalancedView."""
461 469 if key is None:
462 470 return LoadBalancedView(self)
463 471 if isinstance(key, int):
464 472 if key not in self.ids:
465 473 raise IndexError("No such engine: %i"%key)
466 474 return DirectView(self, key)
467 475
468 476 if isinstance(key, slice):
469 477 indices = range(len(self.ids))[key]
470 478 ids = sorted(self._ids)
471 479 key = [ ids[i] for i in indices ]
472 480 # newkeys = sorted(self._ids)[thekeys[k]]
473 481
474 482 if isinstance(key, (tuple, list, xrange)):
475 483 _,targets = self._build_targets(list(key))
476 484 return DirectView(self, targets)
477 485 else:
478 486 raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
479 487
480 488 #--------------------------------------------------------------------------
481 489 # Begin public methods
482 490 #--------------------------------------------------------------------------
483 491
484 492 def spin(self):
485 493 """Flush any registration notifications and execution results
486 494 waiting in the ZMQ queue.
487 495 """
488 496 if self._notification_socket:
489 497 self._flush_notifications()
490 498 if self._mux_socket:
491 499 self._flush_results(self._mux_socket)
492 500 if self._task_socket:
493 501 self._flush_results(self._task_socket)
494 502 if self._control_socket:
495 503 self._flush_control(self._control_socket)
496 504
497 505 def barrier(self, msg_ids=None, timeout=-1):
498 506 """waits on one or more `msg_ids`, for up to `timeout` seconds.
499 507
500 508 Parameters
501 509 ----------
502 510 msg_ids : int, str, or list of ints and/or strs
503 511 ints are indices to self.history
504 512 strs are msg_ids
505 513 default: wait on all outstanding messages
506 514 timeout : float
507 515 a time in seconds, after which to give up.
508 516 default is -1, which means no timeout
509 517
510 518 Returns
511 519 -------
512 520 True : when all msg_ids are done
513 521 False : timeout reached, some msg_ids still outstanding
514 522 """
515 523 tic = time.time()
516 524 if msg_ids is None:
517 525 theids = self.outstanding
518 526 else:
519 527 if isinstance(msg_ids, (int, str)):
520 528 msg_ids = [msg_ids]
521 529 theids = set()
522 530 for msg_id in msg_ids:
523 531 if isinstance(msg_id, int):
524 532 msg_id = self.history[msg_id]
525 533 theids.add(msg_id)
526 534 self.spin()
527 535 while theids.intersection(self.outstanding):
528 536 if timeout >= 0 and ( time.time()-tic ) > timeout:
529 537 break
530 538 time.sleep(1e-3)
531 539 self.spin()
532 540 return len(theids.intersection(self.outstanding)) == 0
533 541
534 542 #--------------------------------------------------------------------------
535 543 # Control methods
536 544 #--------------------------------------------------------------------------
537 545
538 546 @spinfirst
539 547 @defaultblock
540 548 def clear(self, targets=None, block=None):
541 549 """Clear the namespace in target(s)."""
542 550 targets = self._build_targets(targets)[0]
543 551 for t in targets:
544 552 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
545 553 error = False
546 554 if self.block:
547 555 for i in range(len(targets)):
548 556 idents,msg = self.session.recv(self._control_socket,0)
549 557 if self.debug:
550 558 pprint(msg)
551 559 if msg['content']['status'] != 'ok':
552 560 error = ss.unwrap_exception(msg['content'])
553 561 if error:
554 562 return error
555 563
556 564
557 565 @spinfirst
558 566 @defaultblock
559 567 def abort(self, msg_ids = None, targets=None, block=None):
560 568 """Abort the execution queues of target(s)."""
561 569 targets = self._build_targets(targets)[0]
562 570 if isinstance(msg_ids, basestring):
563 571 msg_ids = [msg_ids]
564 572 content = dict(msg_ids=msg_ids)
565 573 for t in targets:
566 574 self.session.send(self._control_socket, 'abort_request',
567 575 content=content, ident=t)
568 576 error = False
569 577 if self.block:
570 578 for i in range(len(targets)):
571 579 idents,msg = self.session.recv(self._control_socket,0)
572 580 if self.debug:
573 581 pprint(msg)
574 582 if msg['content']['status'] != 'ok':
575 583 error = ss.unwrap_exception(msg['content'])
576 584 if error:
577 585 return error
578 586
579 587 @spinfirst
580 588 @defaultblock
581 589 def shutdown(self, targets=None, restart=False, controller=False, block=None):
582 590 """Terminates one or more engine processes, optionally including the controller."""
583 591 if controller:
584 592 targets = 'all'
585 593 targets = self._build_targets(targets)[0]
586 594 for t in targets:
587 595 self.session.send(self._control_socket, 'shutdown_request',
588 596 content={'restart':restart},ident=t)
589 597 error = False
590 598 if block or controller:
591 599 for i in range(len(targets)):
592 600 idents,msg = self.session.recv(self._control_socket,0)
593 601 if self.debug:
594 602 pprint(msg)
595 603 if msg['content']['status'] != 'ok':
596 604 error = ss.unwrap_exception(msg['content'])
597 605
598 606 if controller:
599 607 time.sleep(0.25)
600 608 self.session.send(self._query_socket, 'shutdown_request')
601 609 idents,msg = self.session.recv(self._query_socket, 0)
602 610 if self.debug:
603 611 pprint(msg)
604 612 if msg['content']['status'] != 'ok':
605 613 error = ss.unwrap_exception(msg['content'])
606 614
607 615 if error:
608 616 return error
609 617
610 618 #--------------------------------------------------------------------------
611 619 # Execution methods
612 620 #--------------------------------------------------------------------------
613 621
614 622 @defaultblock
615 623 def execute(self, code, targets='all', block=None):
616 624 """Executes `code` on `targets` in blocking or nonblocking manner.
617 625
618 626 Parameters
619 627 ----------
620 628 code : str
621 629 the code string to be executed
622 630 targets : int/str/list of ints/strs
623 631 the engines on which to execute
624 632 default : all
625 633 block : bool
626 634 whether or not to wait until done to return
627 635 default: self.block
628 636 """
629 637 # block = self.block if block is None else block
630 638 # saveblock = self.block
631 639 # self.block = block
632 640 result = self.apply(execute, (code,), targets=targets, block=block, bound=True)
633 641 # self.block = saveblock
634 642 return result
635 643
636 644 def run(self, code, block=None):
637 645 """Runs `code` on an engine.
638 646
639 647 Calls to this are load-balanced.
640 648
641 649 Parameters
642 650 ----------
643 651 code : str
644 652 the code string to be executed
645 653 block : bool
646 654 whether or not to wait until done
647 655
648 656 """
649 657 result = self.apply(execute, (code,), targets=None, block=block, bound=False)
650 658 return result
651 659
660 def _maybe_raise(self, result):
661 """wrapper for maybe raising an exception if apply failed."""
662 if isinstance(result, error.RemoteError):
663 raise result
664
665 return result
666
652 667 def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
653 668 after=None, follow=None):
654 669 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
655 670
656 671 This is the central execution command for the client.
657 672
658 673 Parameters
659 674 ----------
660 675
661 676 f : function
662 677 The fuction to be called remotely
663 678 args : tuple/list
664 679 The positional arguments passed to `f`
665 680 kwargs : dict
666 681 The keyword arguments passed to `f`
667 682 bound : bool (default: True)
668 683 Whether to execute in the Engine(s) namespace, or in a clean
669 684 namespace not affecting the engine.
670 685 block : bool (default: self.block)
671 686 Whether to wait for the result, or return immediately.
672 687 False:
673 688 returns msg_id(s)
674 689 if multiple targets:
675 690 list of ids
676 691 True:
677 692 returns actual result(s) of f(*args, **kwargs)
678 693 if multiple targets:
679 694 dict of results, by engine ID
680 695 targets : int,list of ints, 'all', None
681 696 Specify the destination of the job.
682 697 if None:
683 698 Submit via Task queue for load-balancing.
684 699 if 'all':
685 700 Run on all active engines
686 701 if list:
687 702 Run on each specified engine
688 703 if int:
689 704 Run on single engine
690 705
691 706 after : Dependency or collection of msg_ids
692 707 Only for load-balanced execution (targets=None)
693 708 Specify a list of msg_ids as a time-based dependency.
694 709 This job will only be run *after* the dependencies
695 710 have been met.
696 711
697 712 follow : Dependency or collection of msg_ids
698 713 Only for load-balanced execution (targets=None)
699 714 Specify a list of msg_ids as a location-based dependency.
700 715 This job will only be run on an engine where this dependency
701 716 is met.
702 717
703 718 Returns
704 719 -------
705 720 if block is False:
706 721 if single target:
707 722 return msg_id
708 723 else:
709 724 return list of msg_ids
710 725 ? (should this be dict like block=True) ?
711 726 else:
712 727 if single target:
713 728 return result of f(*args, **kwargs)
714 729 else:
715 730 return dict of results, keyed by engine
716 731 """
717 732
718 733 # defaults:
719 734 block = block if block is not None else self.block
720 735 args = args if args is not None else []
721 736 kwargs = kwargs if kwargs is not None else {}
722 737
723 738 # enforce types of f,args,kwrags
724 739 if not callable(f):
725 740 raise TypeError("f must be callable, not %s"%type(f))
726 741 if not isinstance(args, (tuple, list)):
727 742 raise TypeError("args must be tuple or list, not %s"%type(args))
728 743 if not isinstance(kwargs, dict):
729 744 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
730 745
731 746 options = dict(bound=bound, block=block, after=after, follow=follow)
732 747
733 748 if targets is None:
734 749 return self._apply_balanced(f, args, kwargs, **options)
735 750 else:
736 751 return self._apply_direct(f, args, kwargs, targets=targets, **options)
737 752
738 753 def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
739 754 after=None, follow=None):
740 755 """The underlying method for applying functions in a load balanced
741 756 manner, via the task queue."""
742 757 if isinstance(after, Dependency):
743 758 after = after.as_dict()
744 759 elif after is None:
745 760 after = []
746 761 if isinstance(follow, Dependency):
747 762 follow = follow.as_dict()
748 763 elif follow is None:
749 764 follow = []
750 765 subheader = dict(after=after, follow=follow)
751 766
752 767 bufs = ss.pack_apply_message(f,args,kwargs)
753 768 content = dict(bound=bound)
754 769 msg = self.session.send(self._task_socket, "apply_request",
755 770 content=content, buffers=bufs, subheader=subheader)
756 771 msg_id = msg['msg_id']
757 772 self.outstanding.add(msg_id)
758 773 self.history.append(msg_id)
759 774 if block:
760 775 self.barrier(msg_id)
761 return self.results[msg_id]
776 return self._maybe_raise(self.results[msg_id])
762 777 else:
763 778 return msg_id
764 779
765 780 def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
766 781 after=None, follow=None):
767 782 """Then underlying method for applying functions to specific engines
768 783 via the MUX queue."""
769 784
770 785 queues,targets = self._build_targets(targets)
771 786 bufs = ss.pack_apply_message(f,args,kwargs)
772 787 if isinstance(after, Dependency):
773 788 after = after.as_dict()
774 789 elif after is None:
775 790 after = []
776 791 if isinstance(follow, Dependency):
777 792 follow = follow.as_dict()
778 793 elif follow is None:
779 794 follow = []
780 795 subheader = dict(after=after, follow=follow)
781 796 content = dict(bound=bound)
782 797 msg_ids = []
783 798 for queue in queues:
784 799 msg = self.session.send(self._mux_socket, "apply_request",
785 800 content=content, buffers=bufs,ident=queue, subheader=subheader)
786 801 msg_id = msg['msg_id']
787 802 self.outstanding.add(msg_id)
788 803 self.history.append(msg_id)
789 804 msg_ids.append(msg_id)
790 805 if block:
791 806 self.barrier(msg_ids)
792 807 else:
793 808 if len(msg_ids) == 1:
794 809 return msg_ids[0]
795 810 else:
796 811 return msg_ids
797 812 if len(msg_ids) == 1:
798 return self.results[msg_ids[0]]
813 return self._maybe_raise(self.results[msg_ids[0]])
799 814 else:
800 815 result = {}
801 816 for target,mid in zip(targets, msg_ids):
802 817 result[target] = self.results[mid]
803 return result
818 return error.collect_exceptions(result, f.__name__)
804 819
805 820 #--------------------------------------------------------------------------
806 821 # Data movement
807 822 #--------------------------------------------------------------------------
808 823
809 824 @defaultblock
810 825 def push(self, ns, targets=None, block=None):
811 826 """Push the contents of `ns` into the namespace on `target`"""
812 827 if not isinstance(ns, dict):
813 828 raise TypeError("Must be a dict, not %s"%type(ns))
814 829 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True)
815 830 return result
816 831
817 832 @defaultblock
818 833 def pull(self, keys, targets=None, block=True):
819 834 """Pull objects from `target`'s namespace by `keys`"""
820 835 if isinstance(keys, str):
821 836 pass
822 837 elif isinstance(keys, (list,tuple,set)):
823 838 for key in keys:
824 839 if not isinstance(key, str):
825 840 raise TypeError
826 841 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
827 842 return result
828 843
829 844 #--------------------------------------------------------------------------
830 845 # Query methods
831 846 #--------------------------------------------------------------------------
832 847
833 848 @spinfirst
834 849 def get_results(self, msg_ids, status_only=False):
835 850 """Returns the result of the execute or task request with `msg_ids`.
836 851
837 852 Parameters
838 853 ----------
839 854 msg_ids : list of ints or msg_ids
840 855 if int:
841 856 Passed as index to self.history for convenience.
842 857 status_only : bool (default: False)
843 858 if False:
844 859 return the actual results
845 860 """
846 861 if not isinstance(msg_ids, (list,tuple)):
847 862 msg_ids = [msg_ids]
848 863 theids = []
849 864 for msg_id in msg_ids:
850 865 if isinstance(msg_id, int):
851 866 msg_id = self.history[msg_id]
852 867 if not isinstance(msg_id, str):
853 868 raise TypeError("msg_ids must be str, not %r"%msg_id)
854 869 theids.append(msg_id)
855 870
856 871 completed = []
857 872 local_results = {}
858 873 for msg_id in list(theids):
859 874 if msg_id in self.results:
860 875 completed.append(msg_id)
861 876 local_results[msg_id] = self.results[msg_id]
862 877 theids.remove(msg_id)
863 878
864 879 if theids: # some not locally cached
865 880 content = dict(msg_ids=theids, status_only=status_only)
866 881 msg = self.session.send(self._query_socket, "result_request", content=content)
867 882 zmq.select([self._query_socket], [], [])
868 883 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
869 884 if self.debug:
870 885 pprint(msg)
871 886 content = msg['content']
872 887 if content['status'] != 'ok':
873 888 raise ss.unwrap_exception(content)
874 889 else:
875 890 content = dict(completed=[],pending=[])
876 891 if not status_only:
877 892 # load cached results into result:
878 893 content['completed'].extend(completed)
879 894 content.update(local_results)
880 895 # update cache with results:
881 896 for msg_id in msg_ids:
882 897 if msg_id in content['completed']:
883 898 self.results[msg_id] = content[msg_id]
884 899 return content
885 900
886 901 @spinfirst
887 902 def queue_status(self, targets=None, verbose=False):
888 903 """Fetch the status of engine queues.
889 904
890 905 Parameters
891 906 ----------
892 907 targets : int/str/list of ints/strs
893 908 the engines on which to execute
894 909 default : all
895 910 verbose : bool
896 911 whether to return lengths only, or lists of ids for each element
897 912 """
898 913 targets = self._build_targets(targets)[1]
899 914 content = dict(targets=targets, verbose=verbose)
900 915 self.session.send(self._query_socket, "queue_request", content=content)
901 916 idents,msg = self.session.recv(self._query_socket, 0)
902 917 if self.debug:
903 918 pprint(msg)
904 919 content = msg['content']
905 920 status = content.pop('status')
906 921 if status != 'ok':
907 922 raise ss.unwrap_exception(content)
908 923 return content
909 924
910 925 @spinfirst
911 926 def purge_results(self, msg_ids=[], targets=[]):
912 927 """Tell the controller to forget results.
913 928
914 929 Individual results can be purged by msg_id, or the entire
915 930 history of specific targets can
916 931
917 932 Parameters
918 933 ----------
919 934 targets : int/str/list of ints/strs
920 935 the targets
921 936 default : None
922 937 """
923 938 if not targets and not msg_ids:
924 939 raise ValueError
925 940 if targets:
926 941 targets = self._build_targets(targets)[1]
927 942 content = dict(targets=targets, msg_ids=msg_ids)
928 943 self.session.send(self._query_socket, "purge_request", content=content)
929 944 idents, msg = self.session.recv(self._query_socket, 0)
930 945 if self.debug:
931 946 pprint(msg)
932 947 content = msg['content']
933 948 if content['status'] != 'ok':
934 949 raise ss.unwrap_exception(content)
935 950
936 951 class AsynClient(Client):
937 952 """An Asynchronous client, using the Tornado Event Loop.
938 953 !!!unfinished!!!"""
939 954 io_loop = None
940 955 _queue_stream = None
941 956 _notifier_stream = None
942 957 _task_stream = None
943 958 _control_stream = None
944 959
945 960 def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
946 961 Client.__init__(self, addr, context, username, debug)
947 962 if io_loop is None:
948 963 io_loop = ioloop.IOLoop.instance()
949 964 self.io_loop = io_loop
950 965
951 966 self._queue_stream = zmqstream.ZMQStream(self._mux_socket, io_loop)
952 967 self._control_stream = zmqstream.ZMQStream(self._control_socket, io_loop)
953 968 self._task_stream = zmqstream.ZMQStream(self._task_socket, io_loop)
954 969 self._notification_stream = zmqstream.ZMQStream(self._notification_socket, io_loop)
955 970
956 971 def spin(self):
957 972 for stream in (self.queue_stream, self.notifier_stream,
958 973 self.task_stream, self.control_stream):
959 974 stream.flush()
960 975
@@ -1,423 +1,430 b''
1 1 #!/usr/bin/env python
2 2 """
3 3 Kernel adapted from kernel.py to use ZMQ Streams
4 4 """
5 5
6 6 #-----------------------------------------------------------------------------
7 7 # Imports
8 8 #-----------------------------------------------------------------------------
9 9
10 10 # Standard library imports.
11 11 from __future__ import print_function
12 12 import __builtin__
13 13 from code import CommandCompiler
14 14 import os
15 15 import sys
16 16 import time
17 17 import traceback
18 18 from datetime import datetime
19 19 from signal import SIGTERM, SIGKILL
20 20 from pprint import pprint
21 21
22 22 # System library imports.
23 23 import zmq
24 24 from zmq.eventloop import ioloop, zmqstream
25 25
26 26 # Local imports.
27 from IPython.core import ultratb
27 28 from IPython.utils.traitlets import HasTraits, Instance, List
28 29 from IPython.zmq.completer import KernelCompleter
29 30 from IPython.zmq.log import logger # a Logger object
30 31
31 32 from streamsession import StreamSession, Message, extract_header, serialize_object,\
32 33 unpack_apply_message, ISO8601, wrap_exception
33 34 from dependency import UnmetDependency
34 35 import heartmonitor
35 36 from client import Client
36 37
37 38 def printer(*args):
38 39 pprint(args)
39 40
40 41 #-----------------------------------------------------------------------------
41 42 # Main kernel class
42 43 #-----------------------------------------------------------------------------
43 44
44 45 class Kernel(HasTraits):
45 46
46 47 #---------------------------------------------------------------------------
47 48 # Kernel interface
48 49 #---------------------------------------------------------------------------
49 50
50 51 session = Instance(StreamSession)
51 52 shell_streams = Instance(list)
52 53 control_stream = Instance(zmqstream.ZMQStream)
53 54 task_stream = Instance(zmqstream.ZMQStream)
54 55 iopub_stream = Instance(zmqstream.ZMQStream)
55 56 client = Instance(Client)
56 57 loop = Instance(ioloop.IOLoop)
57 58
58 59 def __init__(self, **kwargs):
59 60 super(Kernel, self).__init__(**kwargs)
60 61 self.identity = self.shell_streams[0].getsockopt(zmq.IDENTITY)
61 62 self.user_ns = {}
62 63 self.history = []
63 64 self.compiler = CommandCompiler()
64 65 self.completer = KernelCompleter(self.user_ns)
65 66 self.aborted = set()
66 67
67 68 # Build dict of handlers for message types
68 69 self.shell_handlers = {}
69 70 self.control_handlers = {}
70 71 for msg_type in ['execute_request', 'complete_request', 'apply_request',
71 72 'clear_request']:
72 73 self.shell_handlers[msg_type] = getattr(self, msg_type)
73 74
74 75 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
75 76 self.control_handlers[msg_type] = getattr(self, msg_type)
76
77
78
79 def _wrap_exception(self, method=None):
80 e_info = dict(engineid=self.identity, method=method)
81 content=wrap_exception(e_info)
82 return content
83
77 84 #-------------------- control handlers -----------------------------
78 85 def abort_queues(self):
79 86 for stream in self.shell_streams:
80 87 if stream:
81 88 self.abort_queue(stream)
82 89
83 90 def abort_queue(self, stream):
84 91 while True:
85 92 try:
86 93 msg = self.session.recv(stream, zmq.NOBLOCK,content=True)
87 94 except zmq.ZMQError as e:
88 95 if e.errno == zmq.EAGAIN:
89 96 break
90 97 else:
91 98 return
92 99 else:
93 100 if msg is None:
94 101 return
95 102 else:
96 103 idents,msg = msg
97 104
98 105 # assert self.reply_socketly_socket.rcvmore(), "Unexpected missing message part."
99 106 # msg = self.reply_socket.recv_json()
100 107 print ("Aborting:", file=sys.__stdout__)
101 108 print (Message(msg), file=sys.__stdout__)
102 109 msg_type = msg['msg_type']
103 110 reply_type = msg_type.split('_')[0] + '_reply'
104 111 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
105 112 # self.reply_socket.send(ident,zmq.SNDMORE)
106 113 # self.reply_socket.send_json(reply_msg)
107 114 reply_msg = self.session.send(stream, reply_type,
108 115 content={'status' : 'aborted'}, parent=msg, ident=idents)[0]
109 116 print(Message(reply_msg), file=sys.__stdout__)
110 117 # We need to wait a bit for requests to come in. This can probably
111 118 # be set shorter for true asynchronous clients.
112 119 time.sleep(0.05)
113 120
114 121 def abort_request(self, stream, ident, parent):
115 122 """abort a specifig msg by id"""
116 123 msg_ids = parent['content'].get('msg_ids', None)
117 124 if isinstance(msg_ids, basestring):
118 125 msg_ids = [msg_ids]
119 126 if not msg_ids:
120 127 self.abort_queues()
121 128 for mid in msg_ids:
122 129 self.aborted.add(str(mid))
123 130
124 131 content = dict(status='ok')
125 132 reply_msg = self.session.send(stream, 'abort_reply', content=content,
126 133 parent=parent, ident=ident)[0]
127 134 print(Message(reply_msg), file=sys.__stdout__)
128 135
129 136 def shutdown_request(self, stream, ident, parent):
130 137 """kill ourself. This should really be handled in an external process"""
131 138 try:
132 139 self.abort_queues()
133 140 except:
134 content = wrap_exception()
141 content = self._wrap_exception('shutdown')
135 142 else:
136 143 content = dict(parent['content'])
137 144 content['status'] = 'ok'
138 145 msg = self.session.send(stream, 'shutdown_reply',
139 146 content=content, parent=parent, ident=ident)
140 147 # msg = self.session.send(self.pub_socket, 'shutdown_reply',
141 148 # content, parent, ident)
142 149 # print >> sys.__stdout__, msg
143 150 # time.sleep(0.2)
144 151 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
145 152 dc.start()
146 153
147 154 def dispatch_control(self, msg):
148 155 idents,msg = self.session.feed_identities(msg, copy=False)
149 156 try:
150 157 msg = self.session.unpack_message(msg, content=True, copy=False)
151 158 except:
152 159 logger.error("Invalid Message", exc_info=True)
153 160 return
154 161
155 162 header = msg['header']
156 163 msg_id = header['msg_id']
157 164
158 165 handler = self.control_handlers.get(msg['msg_type'], None)
159 166 if handler is None:
160 167 print ("UNKNOWN CONTROL MESSAGE TYPE:", msg, file=sys.__stderr__)
161 168 else:
162 169 handler(self.control_stream, idents, msg)
163 170
164 171
165 172 #-------------------- queue helpers ------------------------------
166 173
167 174 def check_dependencies(self, dependencies):
168 175 if not dependencies:
169 176 return True
170 177 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
171 178 anyorall = dependencies[0]
172 179 dependencies = dependencies[1]
173 180 else:
174 181 anyorall = 'all'
175 182 results = self.client.get_results(dependencies,status_only=True)
176 183 if results['status'] != 'ok':
177 184 return False
178 185
179 186 if anyorall == 'any':
180 187 if not results['completed']:
181 188 return False
182 189 else:
183 190 if results['pending']:
184 191 return False
185 192
186 193 return True
187 194
188 195 def check_aborted(self, msg_id):
189 196 return msg_id in self.aborted
190 197
191 198 #-------------------- queue handlers -----------------------------
192 199
193 200 def clear_request(self, stream, idents, parent):
194 201 """Clear our namespace."""
195 202 self.user_ns = {}
196 203 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
197 204 content = dict(status='ok'))
198 205
199 206 def execute_request(self, stream, ident, parent):
200 207 try:
201 208 code = parent[u'content'][u'code']
202 209 except:
203 210 print("Got bad msg: ", file=sys.__stderr__)
204 211 print(Message(parent), file=sys.__stderr__)
205 212 return
206 213 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
207 214 # self.iopub_stream.send(pyin_msg)
208 215 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
209 216 started = datetime.now().strftime(ISO8601)
210 217 try:
211 218 comp_code = self.compiler(code, '<zmq-kernel>')
212 219 # allow for not overriding displayhook
213 220 if hasattr(sys.displayhook, 'set_parent'):
214 221 sys.displayhook.set_parent(parent)
215 222 exec comp_code in self.user_ns, self.user_ns
216 223 except:
217 exc_content = wrap_exception()
224 exc_content = self._wrap_exception('execute')
218 225 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
219 226 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent)
220 227 reply_content = exc_content
221 228 else:
222 229 reply_content = {'status' : 'ok'}
223 230 # reply_msg = self.session.msg(u'execute_reply', reply_content, parent)
224 231 # self.reply_socket.send(ident, zmq.SNDMORE)
225 232 # self.reply_socket.send_json(reply_msg)
226 233 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
227 234 ident=ident, subheader = dict(started=started))
228 235 print(Message(reply_msg), file=sys.__stdout__)
229 236 if reply_msg['content']['status'] == u'error':
230 237 self.abort_queues()
231 238
232 239 def complete_request(self, stream, ident, parent):
233 240 matches = {'matches' : self.complete(parent),
234 241 'status' : 'ok'}
235 242 completion_msg = self.session.send(stream, 'complete_reply',
236 243 matches, parent, ident)
237 244 # print >> sys.__stdout__, completion_msg
238 245
239 246 def complete(self, msg):
240 247 return self.completer.complete(msg.content.line, msg.content.text)
241 248
242 249 def apply_request(self, stream, ident, parent):
243 250 print (parent)
244 251 try:
245 252 content = parent[u'content']
246 253 bufs = parent[u'buffers']
247 254 msg_id = parent['header']['msg_id']
248 255 bound = content.get('bound', False)
249 256 except:
250 257 print("Got bad msg: ", file=sys.__stderr__)
251 258 print(Message(parent), file=sys.__stderr__)
252 259 return
253 260 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
254 261 # self.iopub_stream.send(pyin_msg)
255 262 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
256 263 sub = {'dependencies_met' : True, 'engine' : self.identity,
257 264 'started': datetime.now().strftime(ISO8601)}
258 265 try:
259 266 # allow for not overriding displayhook
260 267 if hasattr(sys.displayhook, 'set_parent'):
261 268 sys.displayhook.set_parent(parent)
262 269 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
263 270 if bound:
264 271 working = self.user_ns
265 272 suffix = str(msg_id).replace("-","")
266 273 prefix = "_"
267 274
268 275 else:
269 276 working = dict()
270 277 suffix = prefix = "_" # prevent keyword collisions with lambda
271 278 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
272 279 # if f.fun
273 280 fname = prefix+f.func_name.strip('<>')+suffix
274 281 argname = prefix+"args"+suffix
275 282 kwargname = prefix+"kwargs"+suffix
276 283 resultname = prefix+"result"+suffix
277 284
278 285 ns = { fname : f, argname : args, kwargname : kwargs }
279 286 # print ns
280 287 working.update(ns)
281 288 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
282 289 exec code in working, working
283 290 result = working.get(resultname)
284 291 # clear the namespace
285 292 if bound:
286 293 for key in ns.iterkeys():
287 294 self.user_ns.pop(key)
288 295 else:
289 296 del working
290 297
291 298 packed_result,buf = serialize_object(result)
292 299 result_buf = [packed_result]+buf
293 300 except:
294 exc_content = wrap_exception()
301 exc_content = self._wrap_exception('apply')
295 302 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
296 303 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent)
297 304 reply_content = exc_content
298 305 result_buf = []
299 306
300 if etype is UnmetDependency:
307 if exc_content['ename'] == UnmetDependency.__name__:
301 308 sub['dependencies_met'] = False
302 309 else:
303 310 reply_content = {'status' : 'ok'}
304 311 # reply_msg = self.session.msg(u'execute_reply', reply_content, parent)
305 312 # self.reply_socket.send(ident, zmq.SNDMORE)
306 313 # self.reply_socket.send_json(reply_msg)
307 314 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
308 315 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
309 316 print(Message(reply_msg), file=sys.__stdout__)
310 317 # if reply_msg['content']['status'] == u'error':
311 318 # self.abort_queues()
312 319
313 320 def dispatch_queue(self, stream, msg):
314 321 self.control_stream.flush()
315 322 idents,msg = self.session.feed_identities(msg, copy=False)
316 323 try:
317 324 msg = self.session.unpack_message(msg, content=True, copy=False)
318 325 except:
319 326 logger.error("Invalid Message", exc_info=True)
320 327 return
321 328
322 329
323 330 header = msg['header']
324 331 msg_id = header['msg_id']
325 332 if self.check_aborted(msg_id):
326 333 self.aborted.remove(msg_id)
327 334 # is it safe to assume a msg_id will not be resubmitted?
328 335 reply_type = msg['msg_type'].split('_')[0] + '_reply'
329 336 reply_msg = self.session.send(stream, reply_type,
330 337 content={'status' : 'aborted'}, parent=msg, ident=idents)
331 338 return
332 339 handler = self.shell_handlers.get(msg['msg_type'], None)
333 340 if handler is None:
334 341 print ("UNKNOWN MESSAGE TYPE:", msg, file=sys.__stderr__)
335 342 else:
336 343 handler(stream, idents, msg)
337 344
338 345 def start(self):
339 346 #### stream mode:
340 347 if self.control_stream:
341 348 self.control_stream.on_recv(self.dispatch_control, copy=False)
342 349 self.control_stream.on_err(printer)
343 350
344 351 def make_dispatcher(stream):
345 352 def dispatcher(msg):
346 353 return self.dispatch_queue(stream, msg)
347 354 return dispatcher
348 355
349 356 for s in self.shell_streams:
350 357 s.on_recv(make_dispatcher(s), copy=False)
351 358 s.on_err(printer)
352 359
353 360 if self.iopub_stream:
354 361 self.iopub_stream.on_err(printer)
355 362 self.iopub_stream.on_send(printer)
356 363
357 364 #### while True mode:
358 365 # while True:
359 366 # idle = True
360 367 # try:
361 368 # msg = self.shell_stream.socket.recv_multipart(
362 369 # zmq.NOBLOCK, copy=False)
363 370 # except zmq.ZMQError, e:
364 371 # if e.errno != zmq.EAGAIN:
365 372 # raise e
366 373 # else:
367 374 # idle=False
368 375 # self.dispatch_queue(self.shell_stream, msg)
369 376 #
370 377 # if not self.task_stream.empty():
371 378 # idle=False
372 379 # msg = self.task_stream.recv_multipart()
373 380 # self.dispatch_queue(self.task_stream, msg)
374 381 # if idle:
375 382 # # don't busywait
376 383 # time.sleep(1e-3)
377 384
378 385 def make_kernel(identity, control_addr, shell_addrs, iopub_addr, hb_addrs,
379 386 client_addr=None, loop=None, context=None, key=None):
380 387 # create loop, context, and session:
381 388 if loop is None:
382 389 loop = ioloop.IOLoop.instance()
383 390 if context is None:
384 391 context = zmq.Context()
385 392 c = context
386 393 session = StreamSession(key=key)
387 394 # print (session.key)
388 395 print (control_addr, shell_addrs, iopub_addr, hb_addrs)
389 396
390 397 # create Control Stream
391 398 control_stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
392 399 control_stream.setsockopt(zmq.IDENTITY, identity)
393 400 control_stream.connect(control_addr)
394 401
395 402 # create Shell Streams (MUX, Task, etc.):
396 403 shell_streams = []
397 404 for addr in shell_addrs:
398 405 stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
399 406 stream.setsockopt(zmq.IDENTITY, identity)
400 407 stream.connect(addr)
401 408 shell_streams.append(stream)
402 409
403 410 # create iopub stream:
404 411 iopub_stream = zmqstream.ZMQStream(c.socket(zmq.PUB), loop)
405 412 iopub_stream.setsockopt(zmq.IDENTITY, identity)
406 413 iopub_stream.connect(iopub_addr)
407 414
408 415 # launch heartbeat
409 416 heart = heartmonitor.Heart(*map(str, hb_addrs), heart_id=identity)
410 417 heart.start()
411 418
412 419 # create (optional) Client
413 420 if client_addr:
414 421 client = Client(client_addr, username=identity)
415 422 else:
416 423 client = None
417 424
418 425 kernel = Kernel(session=session, control_stream=control_stream,
419 426 shell_streams=shell_streams, iopub_stream=iopub_stream,
420 427 client=client, loop=loop)
421 428 kernel.start()
422 429 return loop, c, kernel
423 430
@@ -1,531 +1,530 b''
1 1 #!/usr/bin/env python
2 2 """edited session.py to work with streams, and move msg_type to the header
3 3 """
4 4
5 5
6 6 import os
7 7 import sys
8 8 import traceback
9 9 import pprint
10 10 import uuid
11 11 from datetime import datetime
12 12
13 13 import zmq
14 14 from zmq.utils import jsonapi
15 15 from zmq.eventloop.zmqstream import ZMQStream
16 16
17 17 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
18 18 from IPython.utils.newserialized import serialize, unserialize
19 19
20 from IPython.zmq.parallel.error import RemoteError
21
20 22 try:
21 23 import cPickle
22 24 pickle = cPickle
23 25 except:
24 26 cPickle = None
25 27 import pickle
26 28
27 29 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
28 30 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
29 31 if json_name in ('jsonlib', 'jsonlib2'):
30 32 use_json = True
31 33 elif json_name:
32 34 if cPickle is None:
33 35 use_json = True
34 36 else:
35 37 use_json = False
36 38 else:
37 39 use_json = False
38 40
39 41 def squash_unicode(obj):
40 42 if isinstance(obj,dict):
41 43 for key in obj.keys():
42 44 obj[key] = squash_unicode(obj[key])
43 45 if isinstance(key, unicode):
44 46 obj[squash_unicode(key)] = obj.pop(key)
45 47 elif isinstance(obj, list):
46 48 for i,v in enumerate(obj):
47 49 obj[i] = squash_unicode(v)
48 50 elif isinstance(obj, unicode):
49 51 obj = obj.encode('utf8')
50 52 return obj
51 53
52 54 if use_json:
53 55 default_packer = jsonapi.dumps
54 56 default_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
55 57 else:
56 58 default_packer = lambda o: pickle.dumps(o,-1)
57 59 default_unpacker = pickle.loads
58 60
59 61
60 62 DELIM="<IDS|MSG>"
61 63 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
62 64
63 def wrap_exception():
65 def wrap_exception(engine_info={}):
64 66 etype, evalue, tb = sys.exc_info()
65 tb = traceback.format_exception(etype, evalue, tb)
67 stb = traceback.format_exception(etype, evalue, tb)
66 68 exc_content = {
67 69 'status' : 'error',
68 'traceback' : [ line.encode('utf8') for line in tb ],
69 'etype' : str(etype).encode('utf8'),
70 'evalue' : evalue.encode('utf8')
70 'traceback' : stb,
71 'ename' : unicode(etype.__name__),
72 'evalue' : unicode(evalue),
73 'engine_info' : engine_info
71 74 }
72 75 return exc_content
73 76
74 class KernelError(Exception):
75 pass
76
77 77 def unwrap_exception(content):
78 err = KernelError(content['etype'], content['evalue'])
79 err.evalue = content['evalue']
80 err.etype = content['etype']
81 err.traceback = ''.join(content['traceback'])
78 err = RemoteError(content['ename'], content['evalue'],
79 ''.join(content['traceback']),
80 content.get('engine_info', {}))
82 81 return err
83 82
84 83
85 84 class Message(object):
86 85 """A simple message object that maps dict keys to attributes.
87 86
88 87 A Message can be created from a dict and a dict from a Message instance
89 88 simply by calling dict(msg_obj)."""
90 89
91 90 def __init__(self, msg_dict):
92 91 dct = self.__dict__
93 92 for k, v in dict(msg_dict).iteritems():
94 93 if isinstance(v, dict):
95 94 v = Message(v)
96 95 dct[k] = v
97 96
98 97 # Having this iterator lets dict(msg_obj) work out of the box.
99 98 def __iter__(self):
100 99 return iter(self.__dict__.iteritems())
101 100
102 101 def __repr__(self):
103 102 return repr(self.__dict__)
104 103
105 104 def __str__(self):
106 105 return pprint.pformat(self.__dict__)
107 106
108 107 def __contains__(self, k):
109 108 return k in self.__dict__
110 109
111 110 def __getitem__(self, k):
112 111 return self.__dict__[k]
113 112
114 113
115 114 def msg_header(msg_id, msg_type, username, session):
116 115 date=datetime.now().strftime(ISO8601)
117 116 return locals()
118 117
119 118 def extract_header(msg_or_header):
120 119 """Given a message or header, return the header."""
121 120 if not msg_or_header:
122 121 return {}
123 122 try:
124 123 # See if msg_or_header is the entire message.
125 124 h = msg_or_header['header']
126 125 except KeyError:
127 126 try:
128 127 # See if msg_or_header is just the header
129 128 h = msg_or_header['msg_id']
130 129 except KeyError:
131 130 raise
132 131 else:
133 132 h = msg_or_header
134 133 if not isinstance(h, dict):
135 134 h = dict(h)
136 135 return h
137 136
138 137 def rekey(dikt):
139 138 """Rekey a dict that has been forced to use str keys where there should be
140 139 ints by json. This belongs in the jsonutil added by fperez."""
141 140 for k in dikt.iterkeys():
142 141 if isinstance(k, str):
143 142 ik=fk=None
144 143 try:
145 144 ik = int(k)
146 145 except ValueError:
147 146 try:
148 147 fk = float(k)
149 148 except ValueError:
150 149 continue
151 150 if ik is not None:
152 151 nk = ik
153 152 else:
154 153 nk = fk
155 154 if nk in dikt:
156 155 raise KeyError("already have key %r"%nk)
157 156 dikt[nk] = dikt.pop(k)
158 157 return dikt
159 158
160 159 def serialize_object(obj, threshold=64e-6):
161 160 """Serialize an object into a list of sendable buffers.
162 161
163 162 Parameters
164 163 ----------
165 164
166 165 obj : object
167 166 The object to be serialized
168 167 threshold : float
169 168 The threshold for not double-pickling the content.
170 169
171 170
172 171 Returns
173 172 -------
174 173 ('pmd', [bufs]) :
175 174 where pmd is the pickled metadata wrapper,
176 175 bufs is a list of data buffers
177 176 """
178 177 databuffers = []
179 178 if isinstance(obj, (list, tuple)):
180 179 clist = canSequence(obj)
181 180 slist = map(serialize, clist)
182 181 for s in slist:
183 182 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
184 183 databuffers.append(s.getData())
185 184 s.data = None
186 185 return pickle.dumps(slist,-1), databuffers
187 186 elif isinstance(obj, dict):
188 187 sobj = {}
189 188 for k in sorted(obj.iterkeys()):
190 189 s = serialize(can(obj[k]))
191 190 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
192 191 databuffers.append(s.getData())
193 192 s.data = None
194 193 sobj[k] = s
195 194 return pickle.dumps(sobj,-1),databuffers
196 195 else:
197 196 s = serialize(can(obj))
198 197 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
199 198 databuffers.append(s.getData())
200 199 s.data = None
201 200 return pickle.dumps(s,-1),databuffers
202 201
203 202
204 203 def unserialize_object(bufs):
205 204 """reconstruct an object serialized by serialize_object from data buffers"""
206 205 bufs = list(bufs)
207 206 sobj = pickle.loads(bufs.pop(0))
208 207 if isinstance(sobj, (list, tuple)):
209 208 for s in sobj:
210 209 if s.data is None:
211 210 s.data = bufs.pop(0)
212 211 return uncanSequence(map(unserialize, sobj))
213 212 elif isinstance(sobj, dict):
214 213 newobj = {}
215 214 for k in sorted(sobj.iterkeys()):
216 215 s = sobj[k]
217 216 if s.data is None:
218 217 s.data = bufs.pop(0)
219 218 newobj[k] = uncan(unserialize(s))
220 219 return newobj
221 220 else:
222 221 if sobj.data is None:
223 222 sobj.data = bufs.pop(0)
224 223 return uncan(unserialize(sobj))
225 224
226 225 def pack_apply_message(f, args, kwargs, threshold=64e-6):
227 226 """pack up a function, args, and kwargs to be sent over the wire
228 227 as a series of buffers. Any object whose data is larger than `threshold`
 229 228 will not have its data copied (currently only numpy arrays support zero-copy)"""
230 229 msg = [pickle.dumps(can(f),-1)]
231 230 databuffers = [] # for large objects
232 231 sargs, bufs = serialize_object(args,threshold)
233 232 msg.append(sargs)
234 233 databuffers.extend(bufs)
235 234 skwargs, bufs = serialize_object(kwargs,threshold)
236 235 msg.append(skwargs)
237 236 databuffers.extend(bufs)
238 237 msg.extend(databuffers)
239 238 return msg
240 239
241 240 def unpack_apply_message(bufs, g=None, copy=True):
242 241 """unpack f,args,kwargs from buffers packed by pack_apply_message()
243 242 Returns: original f,args,kwargs"""
244 243 bufs = list(bufs) # allow us to pop
245 244 assert len(bufs) >= 3, "not enough buffers!"
246 245 if not copy:
247 246 for i in range(3):
248 247 bufs[i] = bufs[i].bytes
249 248 cf = pickle.loads(bufs.pop(0))
250 249 sargs = list(pickle.loads(bufs.pop(0)))
251 250 skwargs = dict(pickle.loads(bufs.pop(0)))
252 251 # print sargs, skwargs
253 252 f = uncan(cf, g)
254 253 for sa in sargs:
255 254 if sa.data is None:
256 255 m = bufs.pop(0)
257 256 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
258 257 if copy:
259 258 sa.data = buffer(m)
260 259 else:
261 260 sa.data = m.buffer
262 261 else:
263 262 if copy:
264 263 sa.data = m
265 264 else:
266 265 sa.data = m.bytes
267 266
268 267 args = uncanSequence(map(unserialize, sargs), g)
269 268 kwargs = {}
270 269 for k in sorted(skwargs.iterkeys()):
271 270 sa = skwargs[k]
272 271 if sa.data is None:
273 272 sa.data = bufs.pop(0)
274 273 kwargs[k] = uncan(unserialize(sa), g)
275 274
276 275 return f,args,kwargs
277 276
278 277 class StreamSession(object):
279 278 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
280 279 debug=False
281 280 key=None
282 281
283 282 def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
284 283 if username is None:
285 284 username = os.environ.get('USER','username')
286 285 self.username = username
287 286 if session is None:
288 287 self.session = str(uuid.uuid4())
289 288 else:
290 289 self.session = session
291 290 self.msg_id = str(uuid.uuid4())
292 291 if packer is None:
293 292 self.pack = default_packer
294 293 else:
295 294 if not callable(packer):
296 295 raise TypeError("packer must be callable, not %s"%type(packer))
297 296 self.pack = packer
298 297
299 298 if unpacker is None:
300 299 self.unpack = default_unpacker
301 300 else:
302 301 if not callable(unpacker):
303 302 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
304 303 self.unpack = unpacker
305 304
306 305 if key is not None and keyfile is not None:
307 306 raise TypeError("Must specify key OR keyfile, not both")
308 307 if keyfile is not None:
309 308 with open(keyfile) as f:
310 309 self.key = f.read().strip()
311 310 else:
312 311 self.key = key
313 312 # print key, keyfile, self.key
314 313 self.none = self.pack({})
315 314
316 315 def msg_header(self, msg_type):
317 316 h = msg_header(self.msg_id, msg_type, self.username, self.session)
318 317 self.msg_id = str(uuid.uuid4())
319 318 return h
320 319
321 320 def msg(self, msg_type, content=None, parent=None, subheader=None):
322 321 msg = {}
323 322 msg['header'] = self.msg_header(msg_type)
324 323 msg['msg_id'] = msg['header']['msg_id']
325 324 msg['parent_header'] = {} if parent is None else extract_header(parent)
326 325 msg['msg_type'] = msg_type
327 326 msg['content'] = {} if content is None else content
328 327 sub = {} if subheader is None else subheader
329 328 msg['header'].update(sub)
330 329 return msg
331 330
332 331 def check_key(self, msg_or_header):
333 332 """Check that a message's header has the right key"""
334 333 if self.key is None:
335 334 return True
336 335 header = extract_header(msg_or_header)
337 336 return header.get('key', None) == self.key
338 337
339 338
340 339 def send(self, stream, msg_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
341 340 """Build and send a message via stream or socket.
342 341
343 342 Parameters
344 343 ----------
345 344
346 345 stream : zmq.Socket or ZMQStream
347 346 the socket-like object used to send the data
348 347 msg_type : str or Message/dict
 349 348 Normally, msg_type will be the message type string; if a Message or dict is passed instead, it is sent as-is rather than building a new message.
350 349
351 350
352 351
353 352 Returns
354 353 -------
355 354 (msg,sent) : tuple
356 355 msg : Message
357 356 the nice wrapped dict-like object containing the headers
358 357
359 358 """
360 359 if isinstance(msg_type, (Message, dict)):
361 360 # we got a Message, not a msg_type
362 361 # don't build a new Message
363 362 msg = msg_type
364 363 content = msg['content']
365 364 else:
366 365 msg = self.msg(msg_type, content, parent, subheader)
367 366 buffers = [] if buffers is None else buffers
368 367 to_send = []
369 368 if isinstance(ident, list):
370 369 # accept list of idents
371 370 to_send.extend(ident)
372 371 elif ident is not None:
373 372 to_send.append(ident)
374 373 to_send.append(DELIM)
375 374 if self.key is not None:
376 375 to_send.append(self.key)
377 376 to_send.append(self.pack(msg['header']))
378 377 to_send.append(self.pack(msg['parent_header']))
379 378
380 379 if content is None:
381 380 content = self.none
382 381 elif isinstance(content, dict):
383 382 content = self.pack(content)
384 383 elif isinstance(content, str):
385 384 # content is already packed, as in a relayed message
386 385 pass
387 386 else:
388 387 raise TypeError("Content incorrect type: %s"%type(content))
389 388 to_send.append(content)
390 389 flag = 0
391 390 if buffers:
392 391 flag = zmq.SNDMORE
393 392 stream.send_multipart(to_send, flag, copy=False)
394 393 for b in buffers[:-1]:
395 394 stream.send(b, flag, copy=False)
396 395 if buffers:
397 396 stream.send(buffers[-1], copy=False)
398 397 omsg = Message(msg)
399 398 if self.debug:
400 399 pprint.pprint(omsg)
401 400 pprint.pprint(to_send)
402 401 pprint.pprint(buffers)
403 402 return omsg
404 403
405 def send_raw(self, stream, msg, flags=0, copy=True, idents=None):
404 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
406 405 """Send a raw message via idents.
407 406
408 407 Parameters
409 408 ----------
410 409 msg : list of sendable buffers"""
411 410 to_send = []
412 411 if isinstance(ident, str):
413 412 ident = [ident]
414 413 if ident is not None:
415 414 to_send.extend(ident)
416 415 to_send.append(DELIM)
417 416 if self.key is not None:
418 417 to_send.append(self.key)
419 418 to_send.extend(msg)
420 419 stream.send_multipart(msg, flags, copy=copy)
421 420
422 421 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
423 422 """receives and unpacks a message
424 423 returns [idents], msg"""
425 424 if isinstance(socket, ZMQStream):
426 425 socket = socket.socket
427 426 try:
428 427 msg = socket.recv_multipart(mode)
429 428 except zmq.ZMQError as e:
430 429 if e.errno == zmq.EAGAIN:
431 430 # We can convert EAGAIN to None as we know in this case
432 431 # recv_json won't return None.
433 432 return None
434 433 else:
435 434 raise
436 435 # return an actual Message object
437 436 # determine the number of idents by trying to unpack them.
438 437 # this is terrible:
439 438 idents, msg = self.feed_identities(msg, copy)
440 439 try:
441 440 return idents, self.unpack_message(msg, content=content, copy=copy)
442 441 except Exception as e:
443 442 print (idents, msg)
444 443 # TODO: handle it
445 444 raise e
446 445
447 446 def feed_identities(self, msg, copy=True):
448 447 """This is a completely horrible thing, but it strips the zmq
449 448 ident prefixes off of a message. It will break if any identities
450 449 are unpackable by self.unpack."""
451 450 msg = list(msg)
452 451 idents = []
453 452 while len(msg) > 3:
454 453 if copy:
455 454 s = msg[0]
456 455 else:
457 456 s = msg[0].bytes
458 457 if s == DELIM:
459 458 msg.pop(0)
460 459 break
461 460 else:
462 461 idents.append(s)
463 462 msg.pop(0)
464 463
465 464 return idents, msg
466 465
467 466 def unpack_message(self, msg, content=True, copy=True):
468 467 """Return a message object from the format
469 468 sent by self.send.
470 469
471 470 Parameters:
472 471 -----------
473 472
474 473 content : bool (True)
475 474 whether to unpack the content dict (True),
476 475 or leave it serialized (False)
477 476
478 477 copy : bool (True)
479 478 whether to return the bytes (True),
480 479 or the non-copying Message object in each place (False)
481 480
482 481 """
483 482 ikey = int(self.key is not None)
484 483 minlen = 3 + ikey
485 484 if not len(msg) >= minlen:
486 485 raise TypeError("malformed message, must have at least %i elements"%minlen)
487 486 message = {}
488 487 if not copy:
489 488 for i in range(minlen):
490 489 msg[i] = msg[i].bytes
491 490 if ikey:
492 491 if not self.key == msg[0]:
493 492 raise KeyError("Invalid Session Key: %s"%msg[0])
494 493 message['header'] = self.unpack(msg[ikey+0])
495 494 message['msg_type'] = message['header']['msg_type']
496 495 message['parent_header'] = self.unpack(msg[ikey+1])
497 496 if content:
498 497 message['content'] = self.unpack(msg[ikey+2])
499 498 else:
500 499 message['content'] = msg[ikey+2]
501 500
502 501 # message['buffers'] = msg[3:]
503 502 # else:
504 503 # message['header'] = self.unpack(msg[0].bytes)
505 504 # message['msg_type'] = message['header']['msg_type']
506 505 # message['parent_header'] = self.unpack(msg[1].bytes)
507 506 # if content:
508 507 # message['content'] = self.unpack(msg[2].bytes)
509 508 # else:
510 509 # message['content'] = msg[2].bytes
511 510
512 511 message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
513 512 return message
514 513
515 514
516 515
517 516 def test_msg2obj():
518 517 am = dict(x=1)
519 518 ao = Message(am)
520 519 assert ao.x == am['x']
521 520
522 521 am['y'] = dict(z=1)
523 522 ao = Message(am)
524 523 assert ao.y.z == am['y']['z']
525 524
526 525 k1, k2 = 'y', 'z'
527 526 assert ao[k1][k2] == am[k1][k2]
528 527
529 528 am2 = dict(ao)
530 529 assert am['x'] == am2['x']
531 530 assert am['y']['z'] == am2['y']['z']
General Comments 0
You need to be logged in to leave comments. Login now