##// END OF EJS Templates
Fixing docstrings and a few more places for msg_id/msg_type.
Brian E. Granger
Show More
@@ -1,1291 +1,1291 b''
1 1 #!/usr/bin/env python
2 2 """The IPython Controller Hub with 0MQ
3 3 This is the master object that handles connections from engines and clients,
4 4 and monitors traffic through the various queues.
5 5
6 6 Authors:
7 7
8 8 * Min RK
9 9 """
10 10 #-----------------------------------------------------------------------------
11 11 # Copyright (C) 2010 The IPython Development Team
12 12 #
13 13 # Distributed under the terms of the BSD License. The full license is in
14 14 # the file COPYING, distributed as part of this software.
15 15 #-----------------------------------------------------------------------------
16 16
17 17 #-----------------------------------------------------------------------------
18 18 # Imports
19 19 #-----------------------------------------------------------------------------
20 20 from __future__ import print_function
21 21
22 22 import sys
23 23 import time
24 24 from datetime import datetime
25 25
26 26 import zmq
27 27 from zmq.eventloop import ioloop
28 28 from zmq.eventloop.zmqstream import ZMQStream
29 29
30 30 # internal:
31 31 from IPython.utils.importstring import import_item
32 32 from IPython.utils.traitlets import (
33 33 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
34 34 )
35 35
36 36 from IPython.parallel import error, util
37 37 from IPython.parallel.factory import RegistrationFactory
38 38
39 39 from IPython.zmq.session import SessionFactory
40 40
41 41 from .heartmonitor import HeartMonitor
42 42
43 43 #-----------------------------------------------------------------------------
44 44 # Code
45 45 #-----------------------------------------------------------------------------
46 46
47 47 def _passer(*args, **kwargs):
48 48 return
49 49
50 50 def _printer(*args, **kwargs):
51 51 print (args)
52 52 print (kwargs)
53 53
def empty_record():
    """Return a dict with every TaskRecord key initialized to its empty value."""
    record = dict.fromkeys((
        'msg_id', 'header', 'content', 'buffers', 'submitted',
        'client_uuid', 'engine_uuid', 'started', 'completed',
        'resubmitted', 'result_header', 'result_content',
        'result_buffers', 'queue', 'pyin', 'pyout', 'pyerr',
    ))
    # stream output accumulates as strings, so these start empty rather than None
    record['stdout'] = ''
    record['stderr'] = ''
    return record
77 77
def init_record(msg):
    """Build a fresh TaskRecord dict from a request message.

    Fields known at submission time are copied out of *msg*; everything
    else starts as None (or '' for the accumulating stream buffers).
    """
    record = {
        'client_uuid' : None,
        'engine_uuid' : None,
        'started': None,
        'completed': None,
        'resubmitted': None,
        'result_header' : None,
        'result_content' : None,
        'result_buffers' : None,
        'queue' : None,
        'pyin' : None,
        'pyout': None,
        'pyerr': None,
        'stdout': '',
        'stderr': '',
    }
    # submission-time fields come straight from the request message
    header = msg['header']
    record['msg_id'] = header['msg_id']
    record['header'] = header
    record['content'] = msg['content']
    record['buffers'] = msg['buffers']
    record['submitted'] = header['date']
    return record
102 102
103 103
class EngineConnector(HasTraits):
    """A simple object for accessing the various zmq connections of an engine.

    Attributes are:
    id (int): engine ID
    queue (bytes): identity of the engine's queue (XREQ) socket
    control (bytes): identity of the engine's control socket
    registration (bytes): identity of the registration XREQ socket
    heartbeat (bytes): identity of the heartbeat XREQ socket
    pending (set): msg_ids of requests currently pending on this engine
    """
    id=Int(0)
    queue=CBytes()
    control=CBytes()
    registration=CBytes()
    heartbeat=CBytes()
    pending=Set()
119 119
class HubFactory(RegistrationFactory):
    """The Configurable for setting up a Hub.

    Holds the (mostly port) configuration for the controller, and builds
    the Hub plus its sockets in init_hub().  Unconfigured ports default
    to random unused ports via util.select_random_ports.
    """

    # port-pairs for monitoredqueues:
    hb = Tuple(Int,Int,config=True,
        help="""XREQ/SUB Port pair for Engine heartbeats""")
    def _hb_default(self):
        # pick two unused ports when not explicitly configured
        return tuple(util.select_random_ports(2))

    mux = Tuple(Int,Int,config=True,
        help="""Engine/Client Port pair for MUX queue""")

    def _mux_default(self):
        return tuple(util.select_random_ports(2))

    task = Tuple(Int,Int,config=True,
        help="""Engine/Client Port pair for Task queue""")
    def _task_default(self):
        return tuple(util.select_random_ports(2))

    control = Tuple(Int,Int,config=True,
        help="""Engine/Client Port pair for Control queue""")

    def _control_default(self):
        return tuple(util.select_random_ports(2))

    iopub = Tuple(Int,Int,config=True,
        help="""Engine/Client Port pair for IOPub relay""")

    def _iopub_default(self):
        return tuple(util.select_random_ports(2))

    # single ports:
    mon_port = Int(config=True,
        help="""Monitor (SUB) port for queue traffic""")

    def _mon_port_default(self):
        return util.select_random_ports(1)[0]

    notifier_port = Int(config=True,
        help="""PUB port for sending engine status notifications""")

    def _notifier_port_default(self):
        return util.select_random_ports(1)[0]

    engine_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for engine connections. [default: loopback]")
    engine_transport = Unicode('tcp', config=True,
        help="0MQ transport for engine connections. [default: tcp]")

    client_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for client connections. [default: loopback]")
    client_transport = Unicode('tcp', config=True,
        help="0MQ transport for client connections. [default : tcp]")

    monitor_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for monitor messages. [default: loopback]")
    monitor_transport = Unicode('tcp', config=True,
        help="0MQ transport for monitor messages. [default : tcp]")

    # derived full 0MQ url for the monitor socket; maintained by
    # _update_monitor_url whenever transport/ip/port change
    monitor_url = Unicode('')

    db_class = DottedObjectName('IPython.parallel.controller.dictdb.DictDB',
        config=True, help="""The class to use for the DB backend""")

    # not configurable
    db = Instance('IPython.parallel.controller.dictdb.BaseDB')
    heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')

    def _ip_changed(self, name, old, new):
        # setting the generic `ip` trait fans out to all three interfaces
        self.engine_ip = new
        self.client_ip = new
        self.monitor_ip = new
        self._update_monitor_url()

    def _update_monitor_url(self):
        # recompute the derived monitor_url after any address change
        self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)

    def _transport_changed(self, name, old, new):
        # setting the generic `transport` trait fans out to all three interfaces
        self.engine_transport = new
        self.client_transport = new
        self.monitor_transport = new
        self._update_monitor_url()

    def __init__(self, **kwargs):
        super(HubFactory, self).__init__(**kwargs)
        self._update_monitor_url()


    def construct(self):
        """Build the Hub and all of its sockets."""
        self.init_hub()

    def start(self):
        """Start the heart monitor (the Hub itself is event-driven)."""
        self.heartmonitor.start()
        self.log.info("Heartmonitor started")

    def init_hub(self):
        """Create all sockets, connect the DB backend, and construct the Hub."""
        # interface templates; the %i port slot is filled in per-socket below
        client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
        engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"

        ctx = self.context
        loop = self.loop

        # Registrar socket
        q = ZMQStream(ctx.socket(zmq.XREP), loop)
        q.bind(client_iface % self.regport)
        self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
        if self.client_ip != self.engine_ip:
            # engines connect on a different interface; bind that one too
            q.bind(engine_iface % self.regport)
            self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))

        ### Engine connections ###

        # heartbeat
        hpub = ctx.socket(zmq.PUB)
        hpub.bind(engine_iface % self.hb[0])
        hrep = ctx.socket(zmq.XREP)
        hrep.bind(engine_iface % self.hb[1])
        self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
                                pingstream=ZMQStream(hpub,loop),
                                pongstream=ZMQStream(hrep,loop)
                            )

        ### Client connections ###
        # Notifier socket
        n = ZMQStream(ctx.socket(zmq.PUB), loop)
        n.bind(client_iface%self.notifier_port)

        ### build and launch the queues ###

        # monitor socket: subscribes to everything the queues publish
        sub = ctx.socket(zmq.SUB)
        sub.setsockopt(zmq.SUBSCRIBE, b"")
        sub.bind(self.monitor_url)
        sub.bind('inproc://monitor')
        sub = ZMQStream(sub, loop)

        # connect the db
        self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
        # cdir = self.config.Global.cluster_dir
        self.db = import_item(str(self.db_class))(session=self.session.session,
                                            config=self.config, log=self.log)
        # brief pause to let the DB backend finish connecting
        time.sleep(.25)
        try:
            scheme = self.config.TaskScheduler.scheme_name
        except AttributeError:
            # scheduler scheme not configured; fall back to its default
            from .scheduler import TaskScheduler
            scheme = TaskScheduler.scheme_name.get_default_value()
        # build connection dicts
        self.engine_info = {
            'control' : engine_iface%self.control[1],
            'mux': engine_iface%self.mux[1],
            'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
            'task' : engine_iface%self.task[1],
            'iopub' : engine_iface%self.iopub[1],
            # 'monitor' : engine_iface%self.mon_port,
            }

        self.client_info = {
            'control' : client_iface%self.control[0],
            'mux': client_iface%self.mux[0],
            'task' : (scheme, client_iface%self.task[0]),
            'iopub' : client_iface%self.iopub[0],
            'notification': client_iface%self.notifier_port
            }
        self.log.debug("Hub engine addrs: %s"%self.engine_info)
        self.log.debug("Hub client addrs: %s"%self.client_info)

        # resubmit stream: lets the Hub re-inject tasks into the scheduler
        r = ZMQStream(ctx.socket(zmq.XREQ), loop)
        url = util.disambiguate_url(self.client_info['task'][-1])
        r.setsockopt(zmq.IDENTITY, util.asbytes(self.session.session))
        r.connect(url)

        self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
                query=q, notifier=n, resubmit=r, db=self.db,
                engine_info=self.engine_info, client_info=self.client_info,
                log=self.log)
300 300
301 301 class Hub(SessionFactory):
302 302 """The IPython Controller Hub with 0MQ connections
303 303
304 304 Parameters
305 305 ==========
306 306 loop: zmq IOLoop instance
307 307 session: Session object
308 308 <removed> context: zmq context for creating new connections (?)
309 309 queue: ZMQStream for monitoring the command queue (SUB)
310 310 query: ZMQStream for engine registration and client queries requests (XREP)
311 311 heartbeat: HeartMonitor object checking the pulse of the engines
312 312 notifier: ZMQStream for broadcasting engine registration changes (PUB)
313 313 db: connection to db for out of memory logging of commands
314 314 NotImplemented
315 315 engine_info: dict of zmq connection information for engines to connect
316 316 to the queues.
317 317 client_info: dict of zmq connection information for clients to connect
318 318 to the queues.
319 319 """
320 320 # internal data structures:
321 321 ids=Set() # engine IDs
322 322 keytable=Dict()
323 323 by_ident=Dict()
324 324 engines=Dict()
325 325 clients=Dict()
326 326 hearts=Dict()
327 327 pending=Set()
328 328 queues=Dict() # pending msg_ids keyed by engine_id
329 329 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
330 330 completed=Dict() # completed msg_ids keyed by engine_id
331 331 all_completed=Set() # completed msg_ids keyed by engine_id
332 332 dead_engines=Set() # completed msg_ids keyed by engine_id
333 333 unassigned=Set() # set of task msg_ids not yet assigned a destination
334 334 incoming_registrations=Dict()
335 335 registration_timeout=Int()
336 336 _idcounter=Int(0)
337 337
338 338 # objects from constructor:
339 339 query=Instance(ZMQStream)
340 340 monitor=Instance(ZMQStream)
341 341 notifier=Instance(ZMQStream)
342 342 resubmit=Instance(ZMQStream)
343 343 heartmonitor=Instance(HeartMonitor)
344 344 db=Instance(object)
345 345 client_info=Dict()
346 346 engine_info=Dict()
347 347
348 348
    def __init__(self, **kwargs):
        """Wire up the Hub: validate connection info and register callbacks.

        Expected constructor traits (passed as keyword args):
        # universal:
        loop: IOLoop for creating future connections
        session: Session object for sending serialized data
        # engine:
        query: ZMQStream for engine+client registration and client requests
        heartbeat: HeartMonitor object for tracking engines
        # extra:
        db: DB backend for out-of-memory record storage
        engine_info: zmq address/protocol dict for engine connections
        client_info: zmq address/protocol dict for client connections
        """

        super(Hub, self).__init__(**kwargs)
        # give engines at least two heartbeat periods (min 5s) to register
        self.registration_timeout = max(5000, 2*self.heartmonitor.period)

        # validate connection dicts:
        for k,v in self.client_info.iteritems():
            if k == 'task':
                # 'task' is a (scheme, url) pair; only the url is validated
                util.validate_url_container(v[1])
            else:
                util.validate_url_container(v)
        # util.validate_url_container(self.client_info)
        util.validate_url_container(self.engine_info)

        # register our callbacks
        self.query.on_recv(self.dispatch_query)
        self.monitor.on_recv(self.dispatch_monitor_traffic)

        self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
        self.heartmonitor.add_new_heart_handler(self.handle_new_heart)

        # monitor topics (bytes) -> handlers for queue/iopub traffic
        self.monitor_handlers = {b'in' : self.save_queue_request,
                                b'out': self.save_queue_result,
                                b'intask': self.save_task_request,
                                b'outtask': self.save_task_result,
                                b'tracktask': self.save_task_destination,
                                b'incontrol': _passer,
                                b'outcontrol': _passer,
                                b'iopub': self.save_iopub_message,
                                }

        # client msg_type -> handler for registration/query requests
        self.query_handlers = {'queue_request': self.queue_status,
                                'result_request': self.get_results,
                                'history_request': self.get_history,
                                'db_request': self.db_query,
                                'purge_request': self.purge_results,
                                'load_request': self.check_load,
                                'resubmit_request': self.resubmit_task,
                                'shutdown_request': self.shutdown_request,
                                'registration_request' : self.register_engine,
                                'unregistration_request' : self.unregister_engine,
                                'connection_request': self.connection_request,
                                }

        # ignore resubmit replies
        self.resubmit.on_recv(lambda msg: None, copy=False)

        self.log.info("hub::created hub")
410 410
411 411 @property
412 412 def _next_id(self):
413 413 """gemerate a new ID.
414 414
415 415 No longer reuse old ids, just count from 0."""
416 416 newid = self._idcounter
417 417 self._idcounter += 1
418 418 return newid
419 419 # newid = 0
420 420 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
421 421 # # print newid, self.ids, self.incoming_registrations
422 422 # while newid in self.ids or newid in incoming:
423 423 # newid += 1
424 424 # return newid
425 425
426 426 #-----------------------------------------------------------------------------
427 427 # message validation
428 428 #-----------------------------------------------------------------------------
429 429
430 430 def _validate_targets(self, targets):
431 431 """turn any valid targets argument into a list of integer ids"""
432 432 if targets is None:
433 433 # default to all
434 434 targets = self.ids
435 435
436 436 if isinstance(targets, (int,str,unicode)):
437 437 # only one target specified
438 438 targets = [targets]
439 439 _targets = []
440 440 for t in targets:
441 441 # map raw identities to ids
442 442 if isinstance(t, (str,unicode)):
443 443 t = self.by_ident.get(t, t)
444 444 _targets.append(t)
445 445 targets = _targets
446 446 bad_targets = [ t for t in targets if t not in self.ids ]
447 447 if bad_targets:
448 448 raise IndexError("No Such Engine: %r"%bad_targets)
449 449 if not targets:
450 450 raise IndexError("No Engines Registered")
451 451 return targets
452 452
453 453 #-----------------------------------------------------------------------------
454 454 # dispatch methods (1 per stream)
455 455 #-----------------------------------------------------------------------------
456 456
457 457
458 458 def dispatch_monitor_traffic(self, msg):
459 459 """all ME and Task queue messages come through here, as well as
460 460 IOPub traffic."""
461 461 self.log.debug("monitor traffic: %r"%msg[:2])
462 462 switch = msg[0]
463 463 try:
464 464 idents, msg = self.session.feed_identities(msg[1:])
465 465 except ValueError:
466 466 idents=[]
467 467 if not idents:
468 468 self.log.error("Bad Monitor Message: %r"%msg)
469 469 return
470 470 handler = self.monitor_handlers.get(switch, None)
471 471 if handler is not None:
472 472 handler(idents, msg)
473 473 else:
474 474 self.log.error("Invalid monitor topic: %r"%switch)
475 475
476 476
477 477 def dispatch_query(self, msg):
478 478 """Route registration requests and queries from clients."""
479 479 try:
480 480 idents, msg = self.session.feed_identities(msg)
481 481 except ValueError:
482 482 idents = []
483 483 if not idents:
484 484 self.log.error("Bad Query Message: %r"%msg)
485 485 return
486 486 client_id = idents[0]
487 487 try:
488 488 msg = self.session.unserialize(msg, content=True)
489 489 except Exception:
490 490 content = error.wrap_exception()
491 491 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
492 492 self.session.send(self.query, "hub_error", ident=client_id,
493 493 content=content)
494 494 return
495 495 # print client_id, header, parent, content
496 496 #switch on message type:
497 497 msg_type = msg['header']['msg_type']
498 498 self.log.info("client::client %r requested %r"%(client_id, msg_type))
499 499 handler = self.query_handlers.get(msg_type, None)
500 500 try:
501 501 assert handler is not None, "Bad Message Type: %r"%msg_type
502 502 except:
503 503 content = error.wrap_exception()
504 504 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
505 505 self.session.send(self.query, "hub_error", ident=client_id,
506 506 content=content)
507 507 return
508 508
509 509 else:
510 510 handler(idents, msg)
511 511
    def dispatch_db(self, msg):
        """Dispatch a message to the DB stream.  Not yet implemented."""
        raise NotImplementedError
515 515
516 516 #---------------------------------------------------------------------------
517 517 # handler methods (1 per event)
518 518 #---------------------------------------------------------------------------
519 519
520 520 #----------------------- Heartbeat --------------------------------------
521 521
522 522 def handle_new_heart(self, heart):
523 523 """handler to attach to heartbeater.
524 524 Called when a new heart starts to beat.
525 525 Triggers completion of registration."""
526 526 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
527 527 if heart not in self.incoming_registrations:
528 528 self.log.info("heartbeat::ignoring new heart: %r"%heart)
529 529 else:
530 530 self.finish_registration(heart)
531 531
532 532
533 533 def handle_heart_failure(self, heart):
534 534 """handler to attach to heartbeater.
535 535 called when a previously registered heart fails to respond to beat request.
536 536 triggers unregistration"""
537 537 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
538 538 eid = self.hearts.get(heart, None)
539 539 queue = self.engines[eid].queue
540 540 if eid is None:
541 541 self.log.info("heartbeat::ignoring heart failure %r"%heart)
542 542 else:
543 543 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
544 544
545 545 #----------------------- MUX Queue Traffic ------------------------------
546 546
    def save_queue_request(self, idents, msg):
        """Save a request submitted to an engine via the MUX queue.

        Builds a TaskRecord from the message, merges it with any record
        already in the DB (iopub output may have arrived first), and
        tracks the msg_id as pending on the target engine.
        """
        if len(idents) < 2:
            # need at least (queue_id, client_id) to attribute the message
            self.log.error("invalid identity prefix: %r"%idents)
            return
        queue_id, client_id = idents[:2]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::target %r not registered"%queue_id)
            self.log.debug("queue:: valid are: %r"%(self.by_ident.keys()))
            return
        record = init_record(msg)
        msg_id = record['msg_id']
        # Unicode in records
        record['engine_uuid'] = queue_id.decode('ascii')
        record['client_uuid'] = client_id.decode('ascii')
        record['queue'] = 'mux'

        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            for key,evalue in existing.iteritems():
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
                elif evalue and not rvalue:
                    # keep the value that arrived first
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
        except KeyError:
            # no existing record: this is the first we've heard of msg_id
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r"%msg_id, exc_info=True)


        self.pending.add(msg_id)
        self.queues[eid].append(msg_id)
592 592
    def save_queue_result(self, idents, msg):
        """Save the result of a request that ran via the MUX queue.

        Marks the parent msg_id completed for its engine and stores the
        result header/content/buffers in the DB.
        """
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %r"%idents)
            return

        client_id, queue_id = idents[:2]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("queue::engine %r sent invalid message to %r: %r"%(
                    queue_id,client_id, msg), exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
            return

        parent = msg['parent_header']
        if not parent:
            # a reply with no parent cannot be matched to a request
            return
        msg_id = parent['msg_id']
        if msg_id in self.pending:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            self.queues[eid].remove(msg_id)
            self.completed[eid].append(msg_id)
        elif msg_id not in self.all_completed:
            # it could be a result from a dead engine that died before delivering the
            # result
            self.log.warn("queue:: unknown msg finished %r"%msg_id)
            return
        # update record anyway, because the unregistration could have been premature
        rheader = msg['header']
        completed = rheader['date']
        started = rheader.get('started', None)
        result = {
            'result_header' : rheader,
            'result_content': msg['content'],
            'started' : started,
            'completed' : completed
        }

        result['result_buffers'] = msg['buffers']
        try:
            self.db.update_record(msg_id, result)
        except Exception:
            self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
641 641
642 642
643 643 #--------------------- Task Queue Traffic ------------------------------
644 644
    def save_task_request(self, idents, msg):
        """Save the submission of a task.

        Records the task as pending and unassigned, and merges the new
        record with any existing DB entry (iopub may have arrived first,
        or this may be a resubmission of an earlier task).
        """
        client_id = idents[0]

        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("task::client %r sent invalid task message: %r"%(
                    client_id, msg), exc_info=True)
            return
        record = init_record(msg)

        record['client_uuid'] = client_id
        record['queue'] = 'task'
        header = msg['header']
        msg_id = header['msg_id']
        self.pending.add(msg_id)
        # unassigned until the scheduler reports a destination engine
        self.unassigned.add(msg_id)
        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            if existing['resubmitted']:
                for key in ('submitted', 'client_uuid', 'buffers'):
                    # don't clobber these keys on resubmit
                    # submitted and client_uuid should be different
                    # and buffers might be big, and shouldn't have changed
                    record.pop(key)
            # still check content,header which should not change
            # but are not expensive to compare as buffers

            for key,evalue in existing.iteritems():
                if key.endswith('buffers'):
                    # don't compare buffers
                    continue
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
                elif evalue and not rvalue:
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
        except KeyError:
            # no existing record: first sighting of this msg_id
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
        except Exception:
            self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
695 695
    def save_task_result(self, idents, msg):
        """save the result of a completed task.

        Moves the parent msg_id from pending to completed (crediting the
        engine named in the reply header, if known) and stores the result
        in the DB.
        """
        client_id = idents[0]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("task::invalid task result message send to %r: %r"%(
                    client_id, msg), exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            # print msg
            self.log.warn("Task %r had no parent!"%msg)
            return
        msg_id = parent['msg_id']
        if msg_id in self.unassigned:
            self.unassigned.remove(msg_id)

        header = msg['header']
        # engine that executed the task, if it reported one
        engine_uuid = header.get('engine', None)
        eid = self.by_ident.get(engine_uuid, None)

        if msg_id in self.pending:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            if eid is not None:
                self.completed[eid].append(msg_id)
                if msg_id in self.tasks[eid]:
                    self.tasks[eid].remove(msg_id)
            completed = header['date']
            started = header.get('started', None)
            result = {
                'result_header' : header,
                'result_content': msg['content'],
                'started' : started,
                'completed' : completed,
                'engine_uuid': engine_uuid
            }

            result['result_buffers'] = msg['buffers']
            try:
                self.db.update_record(msg_id, result)
            except Exception:
                self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)

        else:
            self.log.debug("task::unknown task %r finished"%msg_id)
744 744
745 745 def save_task_destination(self, idents, msg):
746 746 try:
747 747 msg = self.session.unserialize(msg, content=True)
748 748 except Exception:
749 749 self.log.error("task::invalid task tracking message", exc_info=True)
750 750 return
751 751 content = msg['content']
752 752 # print (content)
753 753 msg_id = content['msg_id']
754 754 engine_uuid = content['engine_id']
755 755 eid = self.by_ident[util.asbytes(engine_uuid)]
756 756
757 757 self.log.info("task::task %r arrived on %r"%(msg_id, eid))
758 758 if msg_id in self.unassigned:
759 759 self.unassigned.remove(msg_id)
760 760 # else:
761 761 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
762 762
763 763 self.tasks[eid].append(msg_id)
764 764 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
765 765 try:
766 766 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
767 767 except Exception:
768 768 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
769 769
770 770
771 771 def mia_task_request(self, idents, msg):
772 772 raise NotImplementedError
773 773 client_id = idents[0]
774 774 # content = dict(mia=self.mia,status='ok')
775 775 # self.session.send('mia_reply', content=content, idents=client_id)
776 776
777 777
778 778 #--------------------- IOPub Traffic ------------------------------
779 779
    def save_iopub_message(self, topics, msg):
        """save an iopub message into the db.

        Creates a stub record if the parent request hasn't been seen yet,
        then folds the iopub content into the record: stream output is
        appended, pyin/pyerr are stored, anything else stores its 'data'.
        """
        # print (topics)
        try:
            msg = self.session.unserialize(msg, content=True)
        except Exception:
            self.log.error("iopub::invalid IOPub message", exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            # iopub output must be attributed to a parent request
            self.log.error("iopub::invalid IOPub message: %r"%msg)
            return
        msg_id = parent['msg_id']
        msg_type = msg['header']['msg_type']
        content = msg['content']

        # ensure msg_id is in db
        try:
            rec = self.db.get_record(msg_id)
        except KeyError:
            # iopub arrived before the request itself; start a stub record
            rec = empty_record()
            rec['msg_id'] = msg_id
            self.db.add_record(msg_id, rec)
        # stream
        d = {}
        if msg_type == 'stream':
            # append to the accumulated stdout/stderr text
            name = content['name']
            s = rec[name] or ''
            d[name] = s + content['data']

        elif msg_type == 'pyerr':
            d['pyerr'] = content
        elif msg_type == 'pyin':
            d['pyin'] = content['code']
        else:
            d[msg_type] = content.get('data', '')

        try:
            self.db.update_record(msg_id, d)
        except Exception:
            self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
822 822
823 823
824 824
825 825 #-------------------------------------------------------------------------
826 826 # Registration requests
827 827 #-------------------------------------------------------------------------
828 828
829 829 def connection_request(self, client_id, msg):
830 830 """Reply with connection addresses for clients."""
831 831 self.log.info("client::client %r connected"%client_id)
832 832 content = dict(status='ok')
833 833 content.update(self.client_info)
834 834 jsonable = {}
835 835 for k,v in self.keytable.iteritems():
836 836 if v not in self.dead_engines:
837 837 jsonable[str(k)] = v.decode('ascii')
838 838 content['engines'] = jsonable
839 839 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
840 840
    def register_engine(self, reg, msg):
        """Register a new engine.

        Assigns a new engine id, checks the requested queue/heartbeat
        identities for conflicts, replies with a registration_reply, and
        either finishes registration (heart already beating) or schedules
        a timeout purge while waiting for the first heartbeat.
        """
        content = msg['content']
        try:
            queue = util.asbytes(content['queue'])
        except KeyError:
            self.log.error("registration::queue not specified", exc_info=True)
            return
        heart = content.get('heartbeat', None)
        if heart:
            heart = util.asbytes(heart)
        """register a new engine, and create the socket(s) necessary"""
        eid = self._next_id
        # print (eid, queue, reg, heart)

        self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))

        content = dict(id=eid,status='ok')
        content.update(self.engine_info)
        # check if requesting available IDs:
        if queue in self.by_ident:
            # queue identity already taken by a registered engine
            try:
                raise KeyError("queue_id %r in use"%queue)
            except:
                content = error.wrap_exception()
                self.log.error("queue_id %r in use"%queue, exc_info=True)
        elif heart in self.hearts: # need to check unique hearts?
            try:
                raise KeyError("heart_id %r in use"%heart)
            except:
                self.log.error("heart_id %r in use"%heart, exc_info=True)
                content = error.wrap_exception()
        else:
            # also check registrations that are still in flight
            for h, pack in self.incoming_registrations.iteritems():
                if heart == h:
                    try:
                        raise KeyError("heart_id %r in use"%heart)
                    except:
                        self.log.error("heart_id %r in use"%heart, exc_info=True)
                        content = error.wrap_exception()
                    break
                elif queue == pack[1]:
                    try:
                        raise KeyError("queue_id %r in use"%queue)
                    except:
                        self.log.error("queue_id %r in use"%queue, exc_info=True)
                        content = error.wrap_exception()
                    break

        msg = self.session.send(self.query, "registration_reply",
                content=content,
                ident=reg)

        if content['status'] == 'ok':
            if heart in self.heartmonitor.hearts:
                # already beating
                self.incoming_registrations[heart] = (eid,queue,reg[0],None)
                self.finish_registration(heart)
            else:
                # wait for the first heartbeat; purge if it never arrives
                purge = lambda : self._purge_stalled_registration(heart)
                dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
                dc.start()
                self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
        else:
            self.log.error("registration::registration %i failed: %r"%(eid, content['evalue']))
        return eid
907 907
    def unregister_engine(self, ident, msg):
        """Unregister an engine that explicitly requested to leave."""
        try:
            eid = msg['content']['id']
        except:
            self.log.error("registration::bad engine id for unregistration: %r"%ident, exc_info=True)
            return
        self.log.info("registration::unregister_engine(%r)"%eid)
        # print (eid)
        uuid = self.keytable[eid]
        content=dict(id=eid, queue=uuid.decode('ascii'))
        # mark dead rather than removing outright, so late messages from the
        # engine can still be recognized
        self.dead_engines.add(uuid)
        # self.ids.remove(eid)
        # uuid = self.keytable.pop(eid)
        #
        # ec = self.engines.pop(eid)
        # self.hearts.pop(ec.heartbeat)
        # self.by_ident.pop(ec.queue)
        # self.completed.pop(eid)
        # after a grace period, fail any tasks still assigned to this engine
        handleit = lambda : self._handle_stranded_msgs(eid, uuid)
        dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
        dc.start()
        ############## TODO: HANDLE IT ################

        if self.notifier:
            self.session.send(self.notifier, "unregistration_notification", content=content)
934 934
935 935 def _handle_stranded_msgs(self, eid, uuid):
936 936 """Handle messages known to be on an engine when the engine unregisters.
937 937
938 938 It is possible that this will fire prematurely - that is, an engine will
939 939 go down after completing a result, and the client will be notified
940 940 that the result failed and later receive the actual result.
941 941 """
942 942
943 943 outstanding = self.queues[eid]
944 944
945 945 for msg_id in outstanding:
946 946 self.pending.remove(msg_id)
947 947 self.all_completed.add(msg_id)
948 948 try:
949 949 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
950 950 except:
951 951 content = error.wrap_exception()
952 952 # build a fake header:
953 953 header = {}
954 954 header['engine'] = uuid
955 955 header['date'] = datetime.now()
956 956 rec = dict(result_content=content, result_header=header, result_buffers=[])
957 957 rec['completed'] = header['date']
958 958 rec['engine_uuid'] = uuid
959 959 try:
960 960 self.db.update_record(msg_id, rec)
961 961 except Exception:
962 962 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
963 963
964 964
    def finish_registration(self, heart):
        """Second half of engine registration, called after our HeartMonitor
        has received a beat from the Engine's Heart."""
        try:
            (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
        except KeyError:
            self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
            return
        self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
        if purge is not None:
            # cancel the pending stalled-registration purge callback
            purge.stop()
        control = queue
        # the engine is now live: record it in every lookup table
        self.ids.add(eid)
        self.keytable[eid] = queue
        self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
                control=control, heartbeat=heart)
        self.by_ident[queue] = eid
        self.queues[eid] = list()
        self.tasks[eid] = list()
        self.completed[eid] = list()
        self.hearts[heart] = eid
        content = dict(id=eid, queue=self.engines[eid].queue.decode('ascii'))
        if self.notifier:
            self.session.send(self.notifier, "registration_notification", content=content)
        self.log.info("engine::Engine Connected: %i"%eid)
990 990
991 991 def _purge_stalled_registration(self, heart):
992 992 if heart in self.incoming_registrations:
993 993 eid = self.incoming_registrations.pop(heart)[0]
994 994 self.log.info("registration::purging stalled registration: %i"%eid)
995 995 else:
996 996 pass
997 997
998 998 #-------------------------------------------------------------------------
999 999 # Client Requests
1000 1000 #-------------------------------------------------------------------------
1001 1001
1002 1002 def shutdown_request(self, client_id, msg):
1003 1003 """handle shutdown request."""
1004 1004 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1005 1005 # also notify other clients of shutdown
1006 1006 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1007 1007 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1008 1008 dc.start()
1009 1009
    def _shutdown(self):
        """Terminate the Hub process (scheduled by shutdown_request)."""
        self.log.info("hub::hub shutting down.")
        # brief pause so the log message / pending sends can flush
        time.sleep(0.1)
        sys.exit(0)
1014 1014
1015 1015
1016 1016 def check_load(self, client_id, msg):
1017 1017 content = msg['content']
1018 1018 try:
1019 1019 targets = content['targets']
1020 1020 targets = self._validate_targets(targets)
1021 1021 except:
1022 1022 content = error.wrap_exception()
1023 1023 self.session.send(self.query, "hub_error",
1024 1024 content=content, ident=client_id)
1025 1025 return
1026 1026
1027 1027 content = dict(status='ok')
1028 1028 # loads = {}
1029 1029 for t in targets:
1030 1030 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1031 1031 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1032 1032
1033 1033
1034 1034 def queue_status(self, client_id, msg):
1035 1035 """Return the Queue status of one or more targets.
1036 1036 if verbose: return the msg_ids
1037 1037 else: return len of each type.
1038 1038 keys: queue (pending MUX jobs)
1039 1039 tasks (pending Task jobs)
1040 1040 completed (finished jobs from both queues)"""
1041 1041 content = msg['content']
1042 1042 targets = content['targets']
1043 1043 try:
1044 1044 targets = self._validate_targets(targets)
1045 1045 except:
1046 1046 content = error.wrap_exception()
1047 1047 self.session.send(self.query, "hub_error",
1048 1048 content=content, ident=client_id)
1049 1049 return
1050 1050 verbose = content.get('verbose', False)
1051 1051 content = dict(status='ok')
1052 1052 for t in targets:
1053 1053 queue = self.queues[t]
1054 1054 completed = self.completed[t]
1055 1055 tasks = self.tasks[t]
1056 1056 if not verbose:
1057 1057 queue = len(queue)
1058 1058 completed = len(completed)
1059 1059 tasks = len(tasks)
1060 1060 content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1061 1061 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1062 1062 # print (content)
1063 1063 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1064 1064
    def purge_results(self, client_id, msg):
        """Purge results from memory. This method is more valuable before we move
        to a DB based message storage mechanism."""
        content = msg['content']
        self.log.info("Dropping records with %s", content)
        msg_ids = content.get('msg_ids', [])
        reply = dict(status='ok')
        if msg_ids == 'all':
            # drop every completed record
            try:
                self.db.drop_matching_records(dict(completed={'$ne':None}))
            except Exception:
                reply = error.wrap_exception()
        else:
            # refuse to drop anything that is still pending
            pending = filter(lambda m: m in self.pending, msg_ids)
            if pending:
                try:
                    raise IndexError("msg pending: %r"%pending[0])
                except:
                    reply = error.wrap_exception()
            else:
                try:
                    self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
                except Exception:
                    reply = error.wrap_exception()

        if reply['status'] == 'ok':
            # optionally also purge all completed records per engine
            eids = content.get('engine_ids', [])
            for eid in eids:
                if eid not in self.engines:
                    try:
                        raise IndexError("No such engine: %i"%eid)
                    except:
                        reply = error.wrap_exception()
                    break
                uid = self.engines[eid].queue
                try:
                    self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
                except Exception:
                    reply = error.wrap_exception()
                    break

        self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1107 1107
1108 1108 def resubmit_task(self, client_id, msg):
1109 1109 """Resubmit one or more tasks."""
1110 1110 def finish(reply):
1111 1111 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1112 1112
1113 1113 content = msg['content']
1114 1114 msg_ids = content['msg_ids']
1115 1115 reply = dict(status='ok')
1116 1116 try:
1117 1117 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1118 1118 'header', 'content', 'buffers'])
1119 1119 except Exception:
1120 1120 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1121 1121 return finish(error.wrap_exception())
1122 1122
1123 1123 # validate msg_ids
1124 1124 found_ids = [ rec['msg_id'] for rec in records ]
1125 1125 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1126 1126 if len(records) > len(msg_ids):
1127 1127 try:
1128 1128 raise RuntimeError("DB appears to be in an inconsistent state."
1129 1129 "More matching records were found than should exist")
1130 1130 except Exception:
1131 1131 return finish(error.wrap_exception())
1132 1132 elif len(records) < len(msg_ids):
1133 1133 missing = [ m for m in msg_ids if m not in found_ids ]
1134 1134 try:
1135 1135 raise KeyError("No such msg(s): %r"%missing)
1136 1136 except KeyError:
1137 1137 return finish(error.wrap_exception())
1138 1138 elif invalid_ids:
1139 1139 msg_id = invalid_ids[0]
1140 1140 try:
1141 1141 raise ValueError("Task %r appears to be inflight"%(msg_id))
1142 1142 except Exception:
1143 1143 return finish(error.wrap_exception())
1144 1144
1145 1145 # clear the existing records
1146 1146 now = datetime.now()
1147 1147 rec = empty_record()
1148 1148 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1149 1149 rec['resubmitted'] = now
1150 1150 rec['queue'] = 'task'
1151 1151 rec['client_uuid'] = client_id[0]
1152 1152 try:
1153 1153 for msg_id in msg_ids:
1154 1154 self.all_completed.discard(msg_id)
1155 1155 self.db.update_record(msg_id, rec)
1156 1156 except Exception:
1157 1157 self.log.error('db::db error upating record', exc_info=True)
1158 1158 reply = error.wrap_exception()
1159 1159 else:
1160 1160 # send the messages
1161 1161 for rec in records:
1162 1162 header = rec['header']
1163 1163 # include resubmitted in header to prevent digest collision
1164 1164 header['resubmitted'] = now
1165 1165 msg = self.session.msg(header['msg_type'])
1166 1166 msg['content'] = rec['content']
1167 1167 msg['header'] = header
1168 msg['msg_id'] = rec['msg_id']
1168 msg['header']['msg_id'] = rec['msg_id']
1169 1169 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1170 1170
1171 1171 finish(dict(status='ok'))
1172 1172
1173 1173
1174 1174 def _extract_record(self, rec):
1175 1175 """decompose a TaskRecord dict into subsection of reply for get_result"""
1176 1176 io_dict = {}
1177 1177 for key in 'pyin pyout pyerr stdout stderr'.split():
1178 1178 io_dict[key] = rec[key]
1179 1179 content = { 'result_content': rec['result_content'],
1180 1180 'header': rec['header'],
1181 1181 'result_header' : rec['result_header'],
1182 1182 'io' : io_dict,
1183 1183 }
1184 1184 if rec['result_buffers']:
1185 1185 buffers = map(bytes, rec['result_buffers'])
1186 1186 else:
1187 1187 buffers = []
1188 1188
1189 1189 return content, buffers
1190 1190
1191 1191 def get_results(self, client_id, msg):
1192 1192 """Get the result of 1 or more messages."""
1193 1193 content = msg['content']
1194 1194 msg_ids = sorted(set(content['msg_ids']))
1195 1195 statusonly = content.get('status_only', False)
1196 1196 pending = []
1197 1197 completed = []
1198 1198 content = dict(status='ok')
1199 1199 content['pending'] = pending
1200 1200 content['completed'] = completed
1201 1201 buffers = []
1202 1202 if not statusonly:
1203 1203 try:
1204 1204 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1205 1205 # turn match list into dict, for faster lookup
1206 1206 records = {}
1207 1207 for rec in matches:
1208 1208 records[rec['msg_id']] = rec
1209 1209 except Exception:
1210 1210 content = error.wrap_exception()
1211 1211 self.session.send(self.query, "result_reply", content=content,
1212 1212 parent=msg, ident=client_id)
1213 1213 return
1214 1214 else:
1215 1215 records = {}
1216 1216 for msg_id in msg_ids:
1217 1217 if msg_id in self.pending:
1218 1218 pending.append(msg_id)
1219 1219 elif msg_id in self.all_completed:
1220 1220 completed.append(msg_id)
1221 1221 if not statusonly:
1222 1222 c,bufs = self._extract_record(records[msg_id])
1223 1223 content[msg_id] = c
1224 1224 buffers.extend(bufs)
1225 1225 elif msg_id in records:
1226 1226 if rec['completed']:
1227 1227 completed.append(msg_id)
1228 1228 c,bufs = self._extract_record(records[msg_id])
1229 1229 content[msg_id] = c
1230 1230 buffers.extend(bufs)
1231 1231 else:
1232 1232 pending.append(msg_id)
1233 1233 else:
1234 1234 try:
1235 1235 raise KeyError('No such message: '+msg_id)
1236 1236 except:
1237 1237 content = error.wrap_exception()
1238 1238 break
1239 1239 self.session.send(self.query, "result_reply", content=content,
1240 1240 parent=msg, ident=client_id,
1241 1241 buffers=buffers)
1242 1242
1243 1243 def get_history(self, client_id, msg):
1244 1244 """Get a list of all msg_ids in our DB records"""
1245 1245 try:
1246 1246 msg_ids = self.db.get_history()
1247 1247 except Exception as e:
1248 1248 content = error.wrap_exception()
1249 1249 else:
1250 1250 content = dict(status='ok', history=msg_ids)
1251 1251
1252 1252 self.session.send(self.query, "history_reply", content=content,
1253 1253 parent=msg, ident=client_id)
1254 1254
    def db_query(self, client_id, msg):
        """Perform a raw query on the task record database."""
        content = msg['content']
        query = content.get('query', {})
        keys = content.get('keys', None)
        buffers = []
        empty = list()
        try:
            records = self.db.find_records(query, keys)
        except Exception as e:
            content = error.wrap_exception()
        else:
            # extract buffers from reply content:
            # buffer_lens / result_buffer_lens record how many of the flat
            # `buffers` list belong to each record, so the client can re-slice
            if keys is not None:
                # only extract the buffer fields the caller asked for
                buffer_lens = [] if 'buffers' in keys else None
                result_buffer_lens = [] if 'result_buffers' in keys else None
            else:
                buffer_lens = []
                result_buffer_lens = []

            for rec in records:
                # buffers may be None, so double check
                if buffer_lens is not None:
                    b = rec.pop('buffers', empty) or empty
                    buffer_lens.append(len(b))
                    buffers.extend(b)
                if result_buffer_lens is not None:
                    rb = rec.pop('result_buffers', empty) or empty
                    result_buffer_lens.append(len(rb))
                    buffers.extend(rb)
            content = dict(status='ok', records=records, buffer_lens=buffer_lens,
                                        result_buffer_lens=result_buffer_lens)
        # self.log.debug (content)
        self.session.send(self.query, "db_reply", content=content,
                                            parent=msg, ident=client_id,
                                            buffers=buffers)
1291 1291
@@ -1,691 +1,696 b''
1 1 #!/usr/bin/env python
2 2 """Session object for building, serializing, sending, and receiving messages in
3 3 IPython. The Session object supports serialization, HMAC signatures, and
4 4 metadata on messages.
5 5
6 6 Also defined here are utilities for working with Sessions:
7 7 * A SessionFactory to be used as a base class for configurables that work with
8 8 Sessions.
9 9 * A Message object for convenience that allows attribute-access to the msg dict.
10 10
11 11 Authors:
12 12
13 13 * Min RK
14 14 * Brian Granger
15 15 * Fernando Perez
16 16 """
17 17 #-----------------------------------------------------------------------------
18 18 # Copyright (C) 2010-2011 The IPython Development Team
19 19 #
20 20 # Distributed under the terms of the BSD License. The full license is in
21 21 # the file COPYING, distributed as part of this software.
22 22 #-----------------------------------------------------------------------------
23 23
24 24 #-----------------------------------------------------------------------------
25 25 # Imports
26 26 #-----------------------------------------------------------------------------
27 27
28 28 import hmac
29 29 import logging
30 30 import os
31 31 import pprint
32 32 import uuid
33 33 from datetime import datetime
34 34
35 35 try:
36 36 import cPickle
37 37 pickle = cPickle
38 38 except:
39 39 cPickle = None
40 40 import pickle
41 41
42 42 import zmq
43 43 from zmq.utils import jsonapi
44 44 from zmq.eventloop.ioloop import IOLoop
45 45 from zmq.eventloop.zmqstream import ZMQStream
46 46
47 47 from IPython.config.configurable import Configurable, LoggingConfigurable
48 48 from IPython.utils.importstring import import_item
49 49 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
50 50 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
51 51 DottedObjectName)
52 52
53 53 #-----------------------------------------------------------------------------
54 54 # utility functions
55 55 #-----------------------------------------------------------------------------
56 56
def squash_unicode(obj):
    """coerce unicode back to bytestrings."""
    if isinstance(obj, dict):
        # squash values first, then re-bind any unicode keys
        for key in obj.keys():
            obj[key] = squash_unicode(obj[key])
            if isinstance(key, unicode):
                obj[squash_unicode(key)] = obj.pop(key)
    elif isinstance(obj, list):
        # in-place element-wise squash
        obj[:] = [squash_unicode(item) for item in obj]
    elif isinstance(obj, unicode):
        obj = obj.encode('utf8')
    return obj
70 70
71 71 #-----------------------------------------------------------------------------
72 72 # globals and defaults
73 73 #-----------------------------------------------------------------------------
# jsonlib spells the fallback-serializer kwarg 'on_unknown'; stdlib json uses 'default'
key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
json_packer = lambda obj: jsonapi.dumps(obj, **{key:date_default})
json_unpacker = lambda s: extract_dates(jsonapi.loads(s))

# -1 selects the highest available pickle protocol
pickle_packer = lambda o: pickle.dumps(o,-1)
pickle_unpacker = pickle.loads

default_packer = json_packer
default_unpacker = json_unpacker


# wire-format delimiter separating routing identities from the message frames
DELIM=b"<IDS|MSG>"
86 86
87 87 #-----------------------------------------------------------------------------
88 88 # Classes
89 89 #-----------------------------------------------------------------------------
90 90
class SessionFactory(LoggingConfigurable):
    """The Base class for configurables that have a Session, Context, logger,
    and IOLoop.
    """

    logname = Unicode('')
    def _logname_changed(self, name, old, new):
        # keep the logger in sync with the configured log name
        self.log = logging.getLogger(new)

    # not configurable:
    context = Instance('zmq.Context')
    def _context_default(self):
        # share the process-global 0MQ context by default
        return zmq.Context.instance()

    session = Instance('IPython.zmq.session.Session')

    loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
    def _loop_default(self):
        return IOLoop.instance()

    def __init__(self, **kwargs):
        super(SessionFactory, self).__init__(**kwargs)

        if self.session is None:
            # construct the session
            self.session = Session(**kwargs)
118 118
class Message(object):
    """A simple message object that maps dict keys to attributes.

    A Message can be created from a dict and a dict from a Message instance
    simply by calling dict(msg_obj)."""

    def __init__(self, msg_dict):
        dct = self.__dict__
        # items() rather than iteritems(): identical on Python 2, and also
        # works on Python 3
        for k, v in dict(msg_dict).items():
            if isinstance(v, dict):
                # nested dicts become nested Messages, for attribute access
                v = Message(v)
            dct[k] = v

    # Having this iterator lets dict(msg_obj) work out of the box.
    def __iter__(self):
        return iter(self.__dict__.items())

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return pprint.pformat(self.__dict__)

    def __contains__(self, k):
        return k in self.__dict__

    def __getitem__(self, k):
        return self.__dict__[k]
148 148
def msg_header(msg_id, msg_type, username, session):
    """Build a header dict for a new message, stamped with the current time."""
    # explicit dict instead of the locals() trick: same five keys
    return dict(msg_id=msg_id, msg_type=msg_type, username=username,
                session=session, date=datetime.now())
152 152
def extract_header(msg_or_header):
    """Given a message or header, return the header."""
    if not msg_or_header:
        return {}
    header = msg_or_header
    try:
        # full message: pull out its header
        header = msg_or_header['header']
    except KeyError:
        # not a full message: it must itself be a header (i.e. have msg_id),
        # otherwise let the KeyError propagate
        msg_or_header['msg_id']
    return header if isinstance(header, dict) else dict(header)
171 171
class Session(Configurable):
    """Object for handling serialization and sending of messages.

    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects. Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle
    serialization/deserialization, security, and metadata.

    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts. If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : bytes
        the ID of this Session object. The default is to generate a new UUID.
    username : unicode
        username added to message headers. The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature. If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key. If this is set, `key` will be initialized
        to the contents of the file.

    """

    debug=Bool(False, config=True, help="""Debug output in the Session""")
214 214
    packer = DottedObjectName('json',config=True,
            help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")
    def _packer_changed(self, name, old, new):
        # 'json'/'pickle' configure the pack/unpack pair together;
        # anything else is treated as an import path for a custom packer
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
        else:
            self.pack = import_item(str(new))
    unpacker = DottedObjectName('json', config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")
    def _unpacker_changed(self, name, old, new):
        # 'json'/'pickle' set both halves of the pair (they must match);
        # a custom importstring replaces only the unpacker
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
        else:
            self.unpack = import_item(str(new))
241 241
    session = CBytes(b'', config=True,
        help="""The UUID identifying this session.""")
    def _session_default(self):
        return bytes(uuid.uuid4())

    username = Unicode(os.environ.get('USER','username'), config=True,
        help="""Username for the Session. Default is your system username.""")

    # message signature related traits:
    key = CBytes(b'', config=True,
        help="""execution key, for extra authentication.""")
    def _key_changed(self, name, old, new):
        # an empty key disables HMAC signing entirely (auth=None)
        if new:
            self.auth = hmac.HMAC(new)
        else:
            self.auth = None
    auth = Instance(hmac.HMAC)
    digest_history = Set()

    keyfile = Unicode('', config=True,
        help="""path to file containing execution key.""")
    def _keyfile_changed(self, name, old, new):
        # setting keyfile loads the key from disk
        with open(new, 'rb') as f:
            self.key = f.read().strip()

    pack = Any(default_packer) # the actual packer function
    def _pack_changed(self, name, old, new):
        if not callable(new):
            raise TypeError("packer must be callable, not %s"%type(new))

    unpack = Any(default_unpacker) # the actual unpacker function
    def _unpack_changed(self, name, old, new):
        # unpacker is not checked - it is assumed to be
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s"%type(new))
277 277
    def __init__(self, **kwargs):
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts. If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : bytes
            the ID of this Session object. The default is to generate a new
            UUID.
        username : unicode
            username added to message headers. The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature. If unset, messages
            will not be signed or checked.
        keyfile : filepath
            The file containing a key. If this is set, `key` will be
            initialized to the contents of the file.
        """
        super(Session, self).__init__(**kwargs)
        self._check_packers()
        # pre-packed empty dict, used whenever message content is None
        self.none = self.pack({})
314 314
315 315 @property
316 316 def msg_id(self):
317 317 """always return new uuid"""
318 318 return str(uuid.uuid4())
319 319
320 320 def _check_packers(self):
321 321 """check packers for binary data and datetime support."""
322 322 pack = self.pack
323 323 unpack = self.unpack
324 324
325 325 # check simple serialization
326 326 msg = dict(a=[1,'hi'])
327 327 try:
328 328 packed = pack(msg)
329 329 except Exception:
330 330 raise ValueError("packer could not serialize a simple message")
331 331
332 332 # ensure packed message is bytes
333 333 if not isinstance(packed, bytes):
334 334 raise ValueError("message packed to %r, but bytes are required"%type(packed))
335 335
336 336 # check that unpack is pack's inverse
337 337 try:
338 338 unpacked = unpack(packed)
339 339 except Exception:
340 340 raise ValueError("unpacker could not handle the packer's output")
341 341
342 342 # check datetime support
343 343 msg = dict(t=datetime.now())
344 344 try:
345 345 unpacked = unpack(pack(msg))
346 346 except Exception:
347 347 self.pack = lambda o: pack(squash_dates(o))
348 348 self.unpack = lambda s: extract_dates(unpack(s))
349 349
    def msg_header(self, msg_type):
        """Return a new header dict of the given msg_type, stamped with a
        fresh msg_id and this session's username and session id."""
        return msg_header(self.msg_id, msg_type, self.username, self.session)
352 352
353 def msg(self, msg_type, content=None, parent=None, subheader=None):
353 def msg(self, msg_type, content=None, parent=None, subheader=None, header=None):
354 354 """Return the nested message dict.
355 355
356 356 This format is different from what is sent over the wire. The
357 self.serialize method converts this nested message dict to the wire
358 format, which uses a message list.
357 serialize/unserialize methods converts this nested message dict to the wire
358 format, which is a list of message parts.
359 359 """
360 360 msg = {}
361 msg['header'] = self.msg_header(msg_type)
361 msg['header'] = self.msg_header(msg_type) if header is None else header
362 362 msg['parent_header'] = {} if parent is None else extract_header(parent)
363 363 msg['content'] = {} if content is None else content
364 364 sub = {} if subheader is None else subheader
365 365 msg['header'].update(sub)
366 366 return msg
367 367
368 368 def sign(self, msg_list):
369 369 """Sign a message with HMAC digest. If no auth, return b''.
370 370
371 371 Parameters
372 372 ----------
373 373 msg_list : list
374 374 The [p_header,p_parent,p_content] part of the message list.
375 375 """
376 376 if self.auth is None:
377 377 return b''
378 378 h = self.auth.copy()
379 379 for m in msg_list:
380 380 h.update(m)
381 381 return h.hexdigest()
382 382
    def serialize(self, msg, ident=None):
        """Serialize the message components to bytes.

        This is roughly the inverse of unserialize. The serialize/unserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg : dict or Message
            The nested message dict as returned by the self.msg method.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format:
            [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
            buffer1,buffer2,...]. In this list, the p_* entities are
            the packed or serialized versions, so if JSON is used, these
            are utf8 encoded JSON strings.
        """
        content = msg.get('content', {})
        if content is None:
            # pre-packed empty dict (see __init__)
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, unicode):
            # should be bytes, but JSON often spits out unicode
            content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s"%type(content))

        # the signed portion of the wire message: header, parent, content
        real_message = [self.pack(msg['header']),
                        self.pack(msg['parent_header']),
                        content
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        # HMAC signature goes immediately after the delimiter
        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)

        return to_send
438 438
def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
         buffers=None, subheader=None, track=False, header=None):
    """Build and send a message via stream or socket.

    The message format used by this function internally is as follows:

    [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
    buffer1,buffer2,...]

    The serialize/unserialize methods convert the nested message dict into
    this format.

    Parameters
    ----------

    stream : zmq.Socket or ZMQStream
        The socket-like object used to send the data.
    msg_or_type : str or Message/dict
        Normally, msg_or_type will be a msg_type unless a message is being
        sent more than once.
    content : dict or None
        The content of the message (ignored if msg_or_type is a message).
    header : dict or None
        The header dict for the message (ignored if msg_or_type is a
        message).
    parent : Message or dict or None
        The parent or parent header describing the parent of this message
        (ignored if msg_or_type is a message).
    ident : bytes or list of bytes
        The zmq.IDENTITY routing path.
    subheader : dict or None
        Extra header keys for this message's header (ignored if msg_or_type
        is a message).
    buffers : list or None
        The already-serialized buffers to be appended to the message.
    track : bool
        Whether to track.  Only for use with Sockets, because ZMQStream
        objects cannot track messages.

    Returns
    -------
    msg : dict
        The constructed message.
    (msg,tracker) : (dict, MessageTracker)
        if track=True, then a 2-tuple will be returned,
        the first element being the constructed
        message, and the second being the MessageTracker

    """
    if not isinstance(stream, (zmq.Socket, ZMQStream)):
        raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
    elif track and isinstance(stream, ZMQStream):
        # ZMQStream queues sends internally, so there is no tracker to return
        raise TypeError("ZMQStream cannot track messages")

    if isinstance(msg_or_type, (Message, dict)):
        # We got a Message or message dict, not a msg_type so don't
        # build a new Message.
        msg = msg_or_type
    else:
        msg = self.msg(msg_or_type, content=content, parent=parent,
                       subheader=subheader, header=header)

    buffers = [] if buffers is None else buffers
    to_send = self.serialize(msg, ident)
    flag = 0
    if buffers:
        # more frames follow, and tracking must wait for the last one
        flag = zmq.SNDMORE
        _track = False
    else:
        _track = track
    if track:
        tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
    else:
        tracker = stream.send_multipart(to_send, flag, copy=False)
    for b in buffers[:-1]:
        stream.send(b, flag, copy=False)
    if buffers:
        # only the final buffer frame carries the tracker
        if track:
            tracker = stream.send(buffers[-1], copy=False, track=track)
        else:
            tracker = stream.send(buffers[-1], copy=False)

    if self.debug:
        pprint.pprint(msg)
        pprint.pprint(to_send)
        pprint.pprint(buffers)

    msg['tracker'] = tracker

    return msg
526
531
def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
    """Send a raw message via ident path.

    This method is used to send an already serialized message.

    Parameters
    ----------
    stream : ZMQStream or Socket
        The ZMQ stream or socket to use for sending the message.
    msg_list : list
        The serialized list of messages to send. This only includes the
        [p_header,p_parent,p_content,buffer1,buffer2,...] portion of
        the message.
    ident : ident or list
        A single ident or a list of idents to use in sending.
    """
    to_send = []
    if isinstance(ident, bytes):
        ident = [ident]
    if ident is not None:
        to_send.extend(ident)

    to_send.append(DELIM)
    to_send.append(self.sign(msg_list))
    to_send.extend(msg_list)
    # BUG FIX: previously sent msg_list here, which dropped the routing
    # idents, the DELIM frame, and the HMAC signature that were just
    # assembled into to_send.
    stream.send_multipart(to_send, flags, copy=copy)
553 558
def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
    """Receive and unpack a message.

    Parameters
    ----------
    socket : ZMQStream or Socket
        The socket or stream to use in receiving.

    Returns
    -------
    [idents], msg
        [idents] is a list of idents and msg is a nested message dict of
        same format as self.msg returns.
    """
    if isinstance(socket, ZMQStream):
        socket = socket.socket
    try:
        msg_list = socket.recv_multipart(mode)
    except zmq.ZMQError as e:
        if e.errno != zmq.EAGAIN:
            raise
        # EAGAIN means nothing was waiting; (None, None) is unambiguous
        # because recv_multipart itself never returns None.
        return None, None
    # split multipart message into identity list and message dict
    # invalid large messages can cause very expensive string comparisons
    idents, msg_list = self.feed_identities(msg_list, copy)
    try:
        return idents, self.unserialize(msg_list, content=content, copy=copy)
    except Exception as e:
        print (idents, msg_list)
        # TODO: handle it
        raise e
588 593
def feed_identities(self, msg_list, copy=True):
    """Split the identities from the rest of the message.

    Feed until DELIM is reached, then return the prefix as idents and
    remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
    but that would be silly.

    Parameters
    ----------
    msg_list : a list of Message or bytes objects
        The message to be split.
    copy : bool
        flag determining whether the arguments are bytes or Messages

    Returns
    -------
    (idents, msg_list) : two lists
        idents will always be a list of bytes, each of which is a ZMQ
        identity. msg_list will be a list of bytes or zmq.Messages of the
        form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
        should be unpackable/unserializable via self.unserialize at this
        point.
    """
    if copy:
        # frames are plain bytes: locate the delimiter directly
        idx = msg_list.index(DELIM)
        return msg_list[:idx], msg_list[idx+1:]
    # zero-copy: frames are zmq.Message objects, so compare their .bytes
    for idx, frame in enumerate(msg_list):
        if frame.bytes == DELIM:
            break
    else:
        raise ValueError("DELIM not in msg_list")
    idents = [frame.bytes for frame in msg_list[:idx]]
    return idents, msg_list[idx+1:]
625 630
def unserialize(self, msg_list, content=True, copy=True):
    """Unserialize a msg_list to a nested message dict.

    This is roughly the inverse of serialize. The serialize/unserialize
    methods work with full message lists, whereas pack/unpack work with
    the individual message parts in the message list.

    Parameters:
    -----------
    msg_list : list of bytes or Message objects
        The list of message parts of the form [HMAC,p_header,p_parent,
        p_content,buffer1,buffer2,...].
    content : bool (True)
        Whether to unpack the content dict (True), or leave it packed
        (False).
    copy : bool (True)
        Whether to return the bytes (True), or the non-copying Message
        object in each place (False).

    Returns
    -------
    msg : dict
        The nested message dict with top-level keys [header, parent_header,
        content, buffers].

    Raises
    ------
    TypeError
        If msg_list has fewer than 4 elements.
    ValueError
        If the HMAC signature is invalid or has been seen before.
    """
    minlen = 4
    # BUG FIX: validate the length before indexing msg_list[0..3] below,
    # so short input raises the intended TypeError instead of IndexError.
    if not len(msg_list) >= minlen:
        raise TypeError("malformed message, must have at least %i elements"%minlen)
    message = {}
    if not copy:
        # zero-copy frames: extract the raw bytes of the fixed parts
        for i in range(minlen):
            msg_list[i] = msg_list[i].bytes
    if self.auth is not None:
        signature = msg_list[0]
        if signature in self.digest_history:
            # replayed signatures are rejected to prevent duplicates
            raise ValueError("Duplicate Signature: %r"%signature)
        self.digest_history.add(signature)
        check = self.sign(msg_list[1:4])
        if not signature == check:
            raise ValueError("Invalid Signature: %r"%signature)
    message['header'] = self.unpack(msg_list[1])
    message['parent_header'] = self.unpack(msg_list[2])
    if content:
        message['content'] = self.unpack(msg_list[3])
    else:
        message['content'] = msg_list[3]

    message['buffers'] = msg_list[4:]
    return message
675 680
def test_msg2obj():
    """Exercise dict -> Message attribute and item access round-trips."""
    source = dict(x=1)
    wrapped = Message(source)
    assert wrapped.x == source['x']

    source['y'] = dict(z=1)
    wrapped = Message(source)
    assert wrapped.y.z == source['y']['z']

    outer, inner = 'y', 'z'
    assert wrapped[outer][inner] == source[outer][inner]

    # converting back to a plain dict preserves the values
    roundtrip = dict(wrapped)
    assert source['x'] == roundtrip['x']
    assert source['y']['z'] == roundtrip['y']['z']
691 696
General Comments 0
You need to be logged in to leave comments. Login now