##// END OF EJS Templates
Merge pull request #1852 from minrk/resubmitbufs...
Fernando Perez -
r7301:344e5435 merge
parent child Browse files
Show More
@@ -1,1314 +1,1314
1 1 """The IPython Controller Hub with 0MQ
2 2 This is the master object that handles connections from engines and clients,
3 3 and monitors traffic through the various queues.
4 4
5 5 Authors:
6 6
7 7 * Min RK
8 8 """
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (C) 2010-2011 The IPython Development Team
11 11 #
12 12 # Distributed under the terms of the BSD License. The full license is in
13 13 # the file COPYING, distributed as part of this software.
14 14 #-----------------------------------------------------------------------------
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Imports
18 18 #-----------------------------------------------------------------------------
19 19 from __future__ import print_function
20 20
21 21 import sys
22 22 import time
23 23 from datetime import datetime
24 24
25 25 import zmq
26 26 from zmq.eventloop import ioloop
27 27 from zmq.eventloop.zmqstream import ZMQStream
28 28
29 29 # internal:
30 30 from IPython.utils.importstring import import_item
31 31 from IPython.utils.py3compat import cast_bytes
32 32 from IPython.utils.traitlets import (
33 33 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
34 34 )
35 35
36 36 from IPython.parallel import error, util
37 37 from IPython.parallel.factory import RegistrationFactory
38 38
39 39 from IPython.zmq.session import SessionFactory
40 40
41 41 from .heartmonitor import HeartMonitor
42 42
43 43 #-----------------------------------------------------------------------------
44 44 # Code
45 45 #-----------------------------------------------------------------------------
46 46
47 47 def _passer(*args, **kwargs):
48 48 return
49 49
50 50 def _printer(*args, **kwargs):
51 51 print (args)
52 52 print (kwargs)
53 53
def empty_record():
    """Return an empty dict with all record keys.

    Every key maps to None except the stdout/stderr accumulators,
    which start as empty strings so stream output can be appended.
    """
    record = dict.fromkeys([
        'msg_id', 'header', 'content', 'buffers', 'submitted',
        'client_uuid', 'engine_uuid', 'started', 'completed',
        'resubmitted', 'received', 'result_header', 'result_content',
        'result_buffers', 'queue', 'pyin', 'pyout', 'pyerr',
    ])
    record['stdout'] = ''
    record['stderr'] = ''
    return record
78 78
def init_record(msg):
    """Initialize a TaskRecord based on a request message.

    The fields that can be taken directly from the request (msg_id,
    header, content, buffers, submitted) are filled in; the remaining
    keys start out None, except stdout/stderr which start empty.
    """
    header = msg['header']
    record = dict.fromkeys([
        'client_uuid', 'engine_uuid', 'started', 'completed',
        'resubmitted', 'received', 'result_header', 'result_content',
        'result_buffers', 'queue', 'pyin', 'pyout', 'pyerr',
    ])
    record.update(
        msg_id=header['msg_id'],
        header=header,
        content=msg['content'],
        buffers=msg['buffers'],
        submitted=header['date'],
        stdout='',
        stderr='',
    )
    return record
104 104
105 105
class EngineConnector(HasTraits):
    """A simple object for accessing the various zmq connections of an object.
    Attributes are:
    id (int): engine ID
    queue (bytes): identity of the engine's mux-queue (XREQ) socket
    control (bytes): identity of the engine's control (XREQ) socket
    registration (bytes): identity of registration XREQ socket
    heartbeat (bytes): identity of heartbeat XREQ socket
    pending (set): msg_ids of requests currently outstanding on this engine
    """
    id=Integer(0)
    queue=CBytes()
    control=CBytes()
    registration=CBytes()
    heartbeat=CBytes()
    pending=Set()
121 121
class HubFactory(RegistrationFactory):
    """The Configurable for setting up a Hub.

    Holds the (configurable) port/transport/IP assignments for all of
    the Hub's sockets, and builds the Hub itself in `init_hub`.
    Unset ports default to randomly selected free ports.
    """

    # port-pairs for monitoredqueues:
    hb = Tuple(Integer,Integer,config=True,
        help="""XREQ/SUB Port pair for Engine heartbeats""")
    def _hb_default(self):
        return tuple(util.select_random_ports(2))

    mux = Tuple(Integer,Integer,config=True,
        help="""Engine/Client Port pair for MUX queue""")

    def _mux_default(self):
        return tuple(util.select_random_ports(2))

    task = Tuple(Integer,Integer,config=True,
        help="""Engine/Client Port pair for Task queue""")
    def _task_default(self):
        return tuple(util.select_random_ports(2))

    control = Tuple(Integer,Integer,config=True,
        help="""Engine/Client Port pair for Control queue""")

    def _control_default(self):
        return tuple(util.select_random_ports(2))

    iopub = Tuple(Integer,Integer,config=True,
        help="""Engine/Client Port pair for IOPub relay""")

    def _iopub_default(self):
        return tuple(util.select_random_ports(2))

    # single ports:
    mon_port = Integer(config=True,
        help="""Monitor (SUB) port for queue traffic""")

    def _mon_port_default(self):
        return util.select_random_ports(1)[0]

    notifier_port = Integer(config=True,
        help="""PUB port for sending engine status notifications""")

    def _notifier_port_default(self):
        return util.select_random_ports(1)[0]

    engine_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for engine connections. [default: loopback]")
    engine_transport = Unicode('tcp', config=True,
        help="0MQ transport for engine connections. [default: tcp]")

    client_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for client connections. [default: loopback]")
    client_transport = Unicode('tcp', config=True,
        help="0MQ transport for client connections. [default : tcp]")

    monitor_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for monitor messages. [default: loopback]")
    monitor_transport = Unicode('tcp', config=True,
        help="0MQ transport for monitor messages. [default : tcp]")

    # derived from monitor_transport/ip/port; kept in sync by the
    # _*_changed handlers below
    monitor_url = Unicode('')

    db_class = DottedObjectName('IPython.parallel.controller.dictdb.DictDB',
        config=True, help="""The class to use for the DB backend""")

    # not configurable
    db = Instance('IPython.parallel.controller.dictdb.BaseDB')
    heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')

    def _ip_changed(self, name, old, new):
        """Propagate a change of the generic `ip` trait to all three IPs."""
        self.engine_ip = new
        self.client_ip = new
        self.monitor_ip = new
        self._update_monitor_url()

    def _update_monitor_url(self):
        """Recompute monitor_url from the current transport/ip/port."""
        self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)

    def _transport_changed(self, name, old, new):
        """Propagate a change of the generic `transport` trait everywhere."""
        self.engine_transport = new
        self.client_transport = new
        self.monitor_transport = new
        self._update_monitor_url()

    def __init__(self, **kwargs):
        super(HubFactory, self).__init__(**kwargs)
        # monitor_url depends on traits that may have been set via kwargs
        self._update_monitor_url()


    def construct(self):
        """Build the Hub (entry point used by the controller app)."""
        self.init_hub()

    def start(self):
        """Start the heartbeat monitor; sockets are already bound."""
        self.heartmonitor.start()
        self.log.info("Heartmonitor started")

    def init_hub(self):
        """Create all of the Hub's sockets/streams and the Hub object itself."""
        # interface templates: transport://ip:%i, port filled in per-socket
        client_iface = "%s://%s:" % (self.client_transport, self.client_ip) + "%i"
        engine_iface = "%s://%s:" % (self.engine_transport, self.engine_ip) + "%i"

        ctx = self.context
        loop = self.loop

        # Registrar socket
        q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
        q.bind(client_iface % self.regport)
        self.log.info("Hub listening on %s for registration.", client_iface % self.regport)
        if self.client_ip != self.engine_ip:
            # engines connect on a different interface than clients
            q.bind(engine_iface % self.regport)
            self.log.info("Hub listening on %s for registration.", engine_iface % self.regport)

        ### Engine connections ###

        # heartbeat: PUB pings out, ROUTER collects pongs
        hpub = ctx.socket(zmq.PUB)
        hpub.bind(engine_iface % self.hb[0])
        hrep = ctx.socket(zmq.ROUTER)
        hrep.bind(engine_iface % self.hb[1])
        self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
                                pingstream=ZMQStream(hpub,loop),
                                pongstream=ZMQStream(hrep,loop)
                            )

        ### Client connections ###
        # Notifier socket
        n = ZMQStream(ctx.socket(zmq.PUB), loop)
        n.bind(client_iface%self.notifier_port)

        ### build and launch the queues ###

        # monitor socket: subscribes to everything the queues relay
        sub = ctx.socket(zmq.SUB)
        sub.setsockopt(zmq.SUBSCRIBE, b"")
        sub.bind(self.monitor_url)
        sub.bind('inproc://monitor')
        sub = ZMQStream(sub, loop)

        # connect the db
        self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
        # cdir = self.config.Global.cluster_dir
        self.db = import_item(str(self.db_class))(session=self.session.session,
                                            config=self.config, log=self.log)
        time.sleep(.25)
        try:
            scheme = self.config.TaskScheduler.scheme_name
        except AttributeError:
            # scheme not configured; fall back on the TaskScheduler default
            from .scheduler import TaskScheduler
            scheme = TaskScheduler.scheme_name.get_default_value()
        # build connection dicts
        self.engine_info = {
            'control' : engine_iface%self.control[1],
            'mux': engine_iface%self.mux[1],
            'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
            'task' : engine_iface%self.task[1],
            'iopub' : engine_iface%self.iopub[1],
            # 'monitor' : engine_iface%self.mon_port,
            }

        self.client_info = {
            'control' : client_iface%self.control[0],
            'mux': client_iface%self.mux[0],
            'task' : (scheme, client_iface%self.task[0]),
            'iopub' : client_iface%self.iopub[0],
            'notification': client_iface%self.notifier_port
            }
        self.log.debug("Hub engine addrs: %s", self.engine_info)
        self.log.debug("Hub client addrs: %s", self.client_info)

        # resubmit stream: lets the Hub re-enqueue tasks via the scheduler
        r = ZMQStream(ctx.socket(zmq.DEALER), loop)
        url = util.disambiguate_url(self.client_info['task'][-1])
        r.setsockopt(zmq.IDENTITY, self.session.bsession)
        r.connect(url)

        self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
                query=q, notifier=n, resubmit=r, db=self.db,
                engine_info=self.engine_info, client_info=self.client_info,
                log=self.log)
301 301
302 302
class Hub(SessionFactory):
    """The IPython Controller Hub with 0MQ connections

    Parameters
    ==========
    loop: zmq IOLoop instance
    session: Session object
    queue: ZMQStream for monitoring the command queue (SUB)
    query: ZMQStream for engine registration and client queries requests (XREP)
    heartbeat: HeartMonitor object checking the pulse of the engines
    notifier: ZMQStream for broadcasting engine registration changes (PUB)
    db: connection to db for out of memory logging of commands
        NotImplemented
    engine_info: dict of zmq connection information for engines to connect
        to the queues.
    client_info: dict of zmq connection information for engines to connect
        to the queues.
    """
    # internal data structures:
    ids=Set() # engine IDs
    keytable=Dict() # engine id -> engine uuid (bytes)
    by_ident=Dict() # engine uuid (bytes) -> engine id
    engines=Dict() # engine id -> connector object (exposes .queue)
    clients=Dict()
    hearts=Dict() # heart identity -> engine id
    pending=Set() # msg_ids submitted but not yet completed
    queues=Dict() # pending msg_ids keyed by engine_id
    tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
    completed=Dict() # completed msg_ids keyed by engine_id
    all_completed=Set() # set of all completed msg_ids, regardless of engine
    dead_engines=Set() # uuids of engines that have been unregistered
    unassigned=Set() # set of task msg_ids not yet assigned a destination
    incoming_registrations=Dict()
    registration_timeout=Integer()
    _idcounter=Integer(0)

    # objects from constructor:
    query=Instance(ZMQStream)
    monitor=Instance(ZMQStream)
    notifier=Instance(ZMQStream)
    resubmit=Instance(ZMQStream)
    heartmonitor=Instance(HeartMonitor)
    db=Instance(object)
    client_info=Dict()
    engine_info=Dict()
350 350
    def __init__(self, **kwargs):
        """Construct the Hub: validate connection info and wire up handlers.

        # universal:
        loop: IOLoop for creating future connections
        session: streamsession for sending serialized data
        # engine:
        queue: ZMQStream for monitoring queue messages
        query: ZMQStream for engine+client registration and client requests
        heartbeat: HeartMonitor object for tracking engines
        # extra:
        db: ZMQStream for db connection (NotImplemented)
        engine_info: zmq address/protocol dict for engine connections
        client_info: zmq address/protocol dict for client connections
        """

        super(Hub, self).__init__(**kwargs)
        # give engines at least two heartbeat periods (min 5s) to register
        self.registration_timeout = max(5000, 2*self.heartmonitor.period)

        # validate connection dicts:
        for k,v in self.client_info.iteritems():
            if k == 'task':
                # the 'task' entry is a (scheme, url) pair; only validate the url
                util.validate_url_container(v[1])
            else:
                util.validate_url_container(v)
        # util.validate_url_container(self.client_info)
        util.validate_url_container(self.engine_info)

        # register our callbacks
        self.query.on_recv(self.dispatch_query)
        self.monitor.on_recv(self.dispatch_monitor_traffic)

        self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
        self.heartmonitor.add_new_heart_handler(self.handle_new_heart)

        # monitor messages are dispatched on their topic (first frame)
        self.monitor_handlers = {b'in' : self.save_queue_request,
                                b'out': self.save_queue_result,
                                b'intask': self.save_task_request,
                                b'outtask': self.save_task_result,
                                b'tracktask': self.save_task_destination,
                                b'incontrol': _passer,
                                b'outcontrol': _passer,
                                b'iopub': self.save_iopub_message,
                                }

        # client queries are dispatched on their msg_type
        self.query_handlers = {'queue_request': self.queue_status,
                                'result_request': self.get_results,
                                'history_request': self.get_history,
                                'db_request': self.db_query,
                                'purge_request': self.purge_results,
                                'load_request': self.check_load,
                                'resubmit_request': self.resubmit_task,
                                'shutdown_request': self.shutdown_request,
                                'registration_request' : self.register_engine,
                                'unregistration_request' : self.unregister_engine,
                                'connection_request': self.connection_request,
                                }

        # ignore resubmit replies
        self.resubmit.on_recv(lambda msg: None, copy=False)

        self.log.info("hub::created hub")
412 412
413 413 @property
414 414 def _next_id(self):
415 415 """gemerate a new ID.
416 416
417 417 No longer reuse old ids, just count from 0."""
418 418 newid = self._idcounter
419 419 self._idcounter += 1
420 420 return newid
421 421 # newid = 0
422 422 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
423 423 # # print newid, self.ids, self.incoming_registrations
424 424 # while newid in self.ids or newid in incoming:
425 425 # newid += 1
426 426 # return newid
427 427
428 428 #-----------------------------------------------------------------------------
429 429 # message validation
430 430 #-----------------------------------------------------------------------------
431 431
    def _validate_targets(self, targets):
        """Turn any valid targets argument into a list of integer engine ids.

        targets may be None (meaning all engines), a single int id or
        str/unicode uuid, or a list of those.  Raises IndexError if any
        target is unknown or if no engines are registered.
        """
        if targets is None:
            # default to all
            return self.ids

        if isinstance(targets, (int,str,unicode)):
            # only one target specified
            targets = [targets]
        _targets = []
        for t in targets:
            # map raw identities to ids
            if isinstance(t, (str,unicode)):
                t = self.by_ident.get(cast_bytes(t), t)
            _targets.append(t)
        targets = _targets
        bad_targets = [ t for t in targets if t not in self.ids ]
        if bad_targets:
            raise IndexError("No Such Engine: %r" % bad_targets)
        if not targets:
            raise IndexError("No Engines Registered")
        return targets
454 454
455 455 #-----------------------------------------------------------------------------
456 456 # dispatch methods (1 per stream)
457 457 #-----------------------------------------------------------------------------
458 458
459 459
460 460 @util.log_errors
461 461 def dispatch_monitor_traffic(self, msg):
462 462 """all ME and Task queue messages come through here, as well as
463 463 IOPub traffic."""
464 464 self.log.debug("monitor traffic: %r", msg[0])
465 465 switch = msg[0]
466 466 try:
467 467 idents, msg = self.session.feed_identities(msg[1:])
468 468 except ValueError:
469 469 idents=[]
470 470 if not idents:
471 471 self.log.error("Monitor message without topic: %r", msg)
472 472 return
473 473 handler = self.monitor_handlers.get(switch, None)
474 474 if handler is not None:
475 475 handler(idents, msg)
476 476 else:
477 477 self.log.error("Unrecognized monitor topic: %r", switch)
478 478
479 479
    @util.log_errors
    def dispatch_query(self, msg):
        """Route registration requests and queries from clients.

        Unparseable messages and unknown msg_types are reported back to
        the client as a 'hub_error' reply rather than silently dropped.
        """
        try:
            idents, msg = self.session.feed_identities(msg)
        except ValueError:
            idents = []
        if not idents:
            self.log.error("Bad Query Message: %r", msg)
            return
        client_id = idents[0]
        try:
            msg = self.session.unserialize(msg, content=True)
        except Exception:
            # reply with the wrapped exception so the client isn't left hanging
            content = error.wrap_exception()
            self.log.error("Bad Query Message: %r", msg, exc_info=True)
            self.session.send(self.query, "hub_error", ident=client_id,
                    content=content)
            return
        # print client_id, header, parent, content
        #switch on message type:
        msg_type = msg['header']['msg_type']
        self.log.info("client::client %r requested %r", client_id, msg_type)
        handler = self.query_handlers.get(msg_type, None)
        try:
            assert handler is not None, "Bad Message Type: %r" % msg_type
        except:
            content = error.wrap_exception()
            self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
            self.session.send(self.query, "hub_error", ident=client_id,
                    content=content)
            return

        else:
            handler(idents, msg)
515 515
516 516 def dispatch_db(self, msg):
517 517 """"""
518 518 raise NotImplementedError
519 519
520 520 #---------------------------------------------------------------------------
521 521 # handler methods (1 per event)
522 522 #---------------------------------------------------------------------------
523 523
524 524 #----------------------- Heartbeat --------------------------------------
525 525
526 526 def handle_new_heart(self, heart):
527 527 """handler to attach to heartbeater.
528 528 Called when a new heart starts to beat.
529 529 Triggers completion of registration."""
530 530 self.log.debug("heartbeat::handle_new_heart(%r)", heart)
531 531 if heart not in self.incoming_registrations:
532 532 self.log.info("heartbeat::ignoring new heart: %r", heart)
533 533 else:
534 534 self.finish_registration(heart)
535 535
536 536
537 537 def handle_heart_failure(self, heart):
538 538 """handler to attach to heartbeater.
539 539 called when a previously registered heart fails to respond to beat request.
540 540 triggers unregistration"""
541 541 self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
542 542 eid = self.hearts.get(heart, None)
543 543 queue = self.engines[eid].queue
544 544 if eid is None or self.keytable[eid] in self.dead_engines:
545 545 self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
546 546 else:
547 547 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
548 548
549 549 #----------------------- MUX Queue Traffic ------------------------------
550 550
    def save_queue_request(self, idents, msg):
        """Save a request submitted via the MUX queue.

        idents is expected to be [queue_id, client_id].  The request is
        recorded in the db (merging with any iopub data that arrived
        first) and tracked in the pending set and the engine's queue.
        """
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %r", idents)
            return
        queue_id, client_id = idents[:2]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("queue::client %r sent invalid message to %r: %r", client_id, queue_id, msg, exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::target %r not registered", queue_id)
            self.log.debug("queue:: valid are: %r", self.by_ident.keys())
            return
        record = init_record(msg)
        msg_id = record['msg_id']
        self.log.info("queue::client %r submitted request %r to %s", client_id, msg_id, eid)
        # Unicode in records
        record['engine_uuid'] = queue_id.decode('ascii')
        record['client_uuid'] = client_id.decode('ascii')
        record['queue'] = 'mux'

        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            for key,evalue in existing.iteritems():
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
                elif evalue and not rvalue:
                    # keep values already captured (e.g. stream output via iopub)
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r", msg_id, exc_info=True)
        except KeyError:
            # no prior record for this msg_id: the common case
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r", msg_id, exc_info=True)


        self.pending.add(msg_id)
        self.queues[eid].append(msg_id)
597 597
    def save_queue_result(self, idents, msg):
        """Save the result of a request completed on the MUX queue.

        Moves the msg_id from pending to completed and stores the result
        in the db.  Replies whose msg_id is neither pending nor already
        completed are logged and dropped.
        """
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %r", idents)
            return

        client_id, queue_id = idents[:2]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("queue::engine %r sent invalid message to %r: %r",
                    queue_id, client_id, msg, exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::unknown engine %r is sending a reply: ", queue_id)
            return

        parent = msg['parent_header']
        if not parent:
            # a reply with no parent request: nothing to record
            return
        msg_id = parent['msg_id']
        if msg_id in self.pending:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            self.queues[eid].remove(msg_id)
            self.completed[eid].append(msg_id)
            self.log.info("queue::request %r completed on %s", msg_id, eid)
        elif msg_id not in self.all_completed:
            # it could be a result from a dead engine that died before delivering the
            # result
            self.log.warn("queue:: unknown msg finished %r", msg_id)
            return
        # update record anyway, because the unregistration could have been premature
        rheader = msg['header']
        completed = rheader['date']
        started = rheader.get('started', None)
        result = {
            'result_header' : rheader,
            'result_content': msg['content'],
            'received': datetime.now(),
            'started' : started,
            'completed' : completed
        }

        result['result_buffers'] = msg['buffers']
        try:
            self.db.update_record(msg_id, result)
        except Exception:
            self.log.error("DB Error updating record %r", msg_id, exc_info=True)
648 648
649 649
650 650 #--------------------- Task Queue Traffic ------------------------------
651 651
    def save_task_request(self, idents, msg):
        """Save the submission of a task.

        The task starts out unassigned; the scheduler later reports its
        destination via save_task_destination.  On resubmission, keys
        that must not be clobbered are preserved from the prior record.
        """
        client_id = idents[0]

        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("task::client %r sent invalid task message: %r",
                    client_id, msg, exc_info=True)
            return
        record = init_record(msg)

        record['client_uuid'] = client_id.decode('ascii')
        record['queue'] = 'task'
        header = msg['header']
        msg_id = header['msg_id']
        self.pending.add(msg_id)
        self.unassigned.add(msg_id)
        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            if existing['resubmitted']:
                for key in ('submitted', 'client_uuid', 'buffers'):
                    # don't clobber these keys on resubmit
                    # submitted and client_uuid should be different
                    # and buffers might be big, and shouldn't have changed
                    record.pop(key)
                # still check content,header which should not change
                # but are not expensive to compare as buffers

            for key,evalue in existing.iteritems():
                if key.endswith('buffers'):
                    # don't compare buffers
                    continue
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
                elif evalue and not rvalue:
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r", msg_id, exc_info=True)
        except KeyError:
            # first time we see this msg_id: create a fresh record
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r", msg_id, exc_info=True)
        except Exception:
            self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
702 702
    def save_task_result(self, idents, msg):
        """save the result of a completed task."""
        client_id = idents[0]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("task::invalid task result message send to %r: %r",
                    client_id, msg, exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            # print msg
            self.log.warn("Task %r had no parent!", msg)
            return
        msg_id = parent['msg_id']
        if msg_id in self.unassigned:
            # result arrived before (or without) a destination report
            self.unassigned.remove(msg_id)

        header = msg['header']
        # the engine uuid is stamped into the reply header by the scheduler/engine
        engine_uuid = header.get('engine', u'')
        eid = self.by_ident.get(cast_bytes(engine_uuid), None)

        status = header.get('status', None)

        if msg_id in self.pending:
            self.log.info("task::task %r finished on %s", msg_id, eid)
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            if eid is not None:
                # aborted tasks are not counted as completed on the engine
                if status != 'aborted':
                    self.completed[eid].append(msg_id)
                if msg_id in self.tasks[eid]:
                    self.tasks[eid].remove(msg_id)
            completed = header['date']
            started = header.get('started', None)
            result = {
                'result_header' : header,
                'result_content': msg['content'],
                'started' : started,
                'completed' : completed,
                'received' : datetime.now(),
                'engine_uuid': engine_uuid,
            }

            result['result_buffers'] = msg['buffers']
            try:
                self.db.update_record(msg_id, result)
            except Exception:
                self.log.error("DB Error saving task request %r", msg_id, exc_info=True)

        else:
            self.log.debug("task::unknown task %r finished", msg_id)
756 756
    def save_task_destination(self, idents, msg):
        """Record which engine a task was assigned to by the scheduler."""
        try:
            msg = self.session.unserialize(msg, content=True)
        except Exception:
            self.log.error("task::invalid task tracking message", exc_info=True)
            return
        content = msg['content']
        # print (content)
        msg_id = content['msg_id']
        engine_uuid = content['engine_id']
        eid = self.by_ident[cast_bytes(engine_uuid)]

        self.log.info("task::task %r arrived on %r", msg_id, eid)
        if msg_id in self.unassigned:
            self.unassigned.remove(msg_id)
        # else:
        #     self.log.debug("task::task %r not listed as MIA?!"%(msg_id))

        self.tasks[eid].append(msg_id)
        # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
        try:
            self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
        except Exception:
            self.log.error("DB Error saving task destination %r", msg_id, exc_info=True)
781 781
782 782
783 783 def mia_task_request(self, idents, msg):
784 784 raise NotImplementedError
785 785 client_id = idents[0]
786 786 # content = dict(mia=self.mia,status='ok')
787 787 # self.session.send('mia_reply', content=content, idents=client_id)
788 788
789 789
790 790 #--------------------- IOPub Traffic ------------------------------
791 791
    def save_iopub_message(self, topics, msg):
        """save an iopub message into the db"""
        # print (topics)
        try:
            msg = self.session.unserialize(msg, content=True)
        except Exception:
            self.log.error("iopub::invalid IOPub message", exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            self.log.warn("iopub::IOPub message lacks parent: %r", msg)
            return
        msg_id = parent['msg_id']
        msg_type = msg['header']['msg_type']
        content = msg['content']

        # ensure msg_id is in db
        try:
            rec = self.db.get_record(msg_id)
        except KeyError:
            # iopub can arrive before the request itself; start an empty record
            rec = empty_record()
            rec['msg_id'] = msg_id
            self.db.add_record(msg_id, rec)
        # stream
        d = {}
        if msg_type == 'stream':
            # append stream (stdout/stderr) data to what's already recorded
            name = content['name']
            s = rec[name] or ''
            d[name] = s + content['data']

        elif msg_type == 'pyerr':
            d['pyerr'] = content
        elif msg_type == 'pyin':
            d['pyin'] = content['code']
        else:
            d[msg_type] = content.get('data', '')

        try:
            self.db.update_record(msg_id, d)
        except Exception:
            self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)
834 834
835 835
836 836
837 837 #-------------------------------------------------------------------------
838 838 # Registration requests
839 839 #-------------------------------------------------------------------------
840 840
841 841 def connection_request(self, client_id, msg):
842 842 """Reply with connection addresses for clients."""
843 843 self.log.info("client::client %r connected", client_id)
844 844 content = dict(status='ok')
845 845 content.update(self.client_info)
846 846 jsonable = {}
847 847 for k,v in self.keytable.iteritems():
848 848 if v not in self.dead_engines:
849 849 jsonable[str(k)] = v.decode('ascii')
850 850 content['engines'] = jsonable
851 851 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
852 852
    def register_engine(self, reg, msg):
        """Register a new engine.

        Validates the requested queue/heartbeat identities, replies on the
        query channel with a ``registration_reply`` (engine id on success,
        wrapped error on collision), and either finishes registration
        immediately (heart already beating) or parks it in
        ``incoming_registrations`` with a purge timeout.
        """
        content = msg['content']
        try:
            queue = cast_bytes(content['queue'])
        except KeyError:
            self.log.error("registration::queue not specified", exc_info=True)
            return
        heart = content.get('heartbeat', None)
        if heart:
            heart = cast_bytes(heart)
        # NOTE: the expression-string below is a no-op statement left in the
        # original source, not a docstring.
        """register a new engine, and create the socket(s) necessary"""
        eid = self._next_id
        # print (eid, queue, reg, heart)

        self.log.debug("registration::register_engine(%i, %r, %r, %r)", eid, queue, reg, heart)

        content = dict(id=eid,status='ok')
        content.update(self.engine_info)
        # check if requesting available IDs:
        if queue in self.by_ident:
            # raise/catch so error.wrap_exception can capture a real traceback
            try:
                raise KeyError("queue_id %r in use" % queue)
            except:
                content = error.wrap_exception()
                self.log.error("queue_id %r in use", queue, exc_info=True)
        elif heart in self.hearts: # need to check unique hearts?
            try:
                raise KeyError("heart_id %r in use" % heart)
            except:
                self.log.error("heart_id %r in use", heart, exc_info=True)
                content = error.wrap_exception()
        else:
            # also guard against collisions with registrations still pending
            for h, pack in self.incoming_registrations.iteritems():
                if heart == h:
                    try:
                        raise KeyError("heart_id %r in use" % heart)
                    except:
                        self.log.error("heart_id %r in use", heart, exc_info=True)
                        content = error.wrap_exception()
                    break
                elif queue == pack[1]:
                    try:
                        raise KeyError("queue_id %r in use" % queue)
                    except:
                        self.log.error("queue_id %r in use", queue, exc_info=True)
                        content = error.wrap_exception()
                    break

        msg = self.session.send(self.query, "registration_reply",
                content=content,
                ident=reg)

        if content['status'] == 'ok':
            if heart in self.heartmonitor.hearts:
                # already beating
                self.incoming_registrations[heart] = (eid,queue,reg[0],None)
                self.finish_registration(heart)
            else:
                # wait for the first heartbeat; purge if it never arrives
                purge = lambda : self._purge_stalled_registration(heart)
                dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
                dc.start()
                self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
        else:
            self.log.error("registration::registration %i failed: %r", eid, content['evalue'])
        return eid
919 919
    def unregister_engine(self, ident, msg):
        """Unregister an engine that explicitly requested to leave."""
        try:
            eid = msg['content']['id']
        except:
            self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
            return
        self.log.info("registration::unregister_engine(%r)", eid)
        # print (eid)
        uuid = self.keytable[eid]
        content=dict(id=eid, queue=uuid.decode('ascii'))
        # mark dead rather than removing records immediately; the commented
        # cleanup below is deliberately deferred (see TODO)
        self.dead_engines.add(uuid)
        # self.ids.remove(eid)
        # uuid = self.keytable.pop(eid)
        #
        # ec = self.engines.pop(eid)
        # self.hearts.pop(ec.heartbeat)
        # self.by_ident.pop(ec.queue)
        # self.completed.pop(eid)
        # give outstanding messages a grace period before declaring them dead
        handleit = lambda : self._handle_stranded_msgs(eid, uuid)
        dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
        dc.start()
        ############## TODO: HANDLE IT ################

        if self.notifier:
            self.session.send(self.notifier, "unregistration_notification", content=content)
946 946
    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        that the result failed and later receive the actual result.
        """

        outstanding = self.queues[eid]

        for msg_id in outstanding:
            # mark each stranded task as completed-with-error
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            # raise/catch so error.wrap_exception captures a real traceback
            try:
                raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake header:
            header = {}
            header['engine'] = uuid
            header['date'] = datetime.now()
            rec = dict(result_content=content, result_header=header, result_buffers=[])
            rec['completed'] = header['date']
            rec['engine_uuid'] = uuid
            try:
                self.db.update_record(msg_id, rec)
            except Exception:
                self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)
975 975
976 976
    def finish_registration(self, heart):
        """Second half of engine registration, called after our HeartMonitor
        has received a beat from the Engine's Heart."""
        try:
            (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
        except KeyError:
            self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
            return
        self.log.info("registration::finished registering engine %i:%r", eid, queue)
        # cancel the stalled-registration purge callback, if one was scheduled
        if purge is not None:
            purge.stop()
        control = queue
        self.ids.add(eid)
        self.keytable[eid] = queue
        self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
                control=control, heartbeat=heart)
        self.by_ident[queue] = eid
        # fresh per-engine bookkeeping for queued/task/completed jobs
        self.queues[eid] = list()
        self.tasks[eid] = list()
        self.completed[eid] = list()
        self.hearts[heart] = eid
        content = dict(id=eid, queue=self.engines[eid].queue.decode('ascii'))
        if self.notifier:
            self.session.send(self.notifier, "registration_notification", content=content)
        self.log.info("engine::Engine Connected: %i", eid)
1002 1002
1003 1003 def _purge_stalled_registration(self, heart):
1004 1004 if heart in self.incoming_registrations:
1005 1005 eid = self.incoming_registrations.pop(heart)[0]
1006 1006 self.log.info("registration::purging stalled registration: %i", eid)
1007 1007 else:
1008 1008 pass
1009 1009
1010 1010 #-------------------------------------------------------------------------
1011 1011 # Client Requests
1012 1012 #-------------------------------------------------------------------------
1013 1013
    def shutdown_request(self, client_id, msg):
        """handle shutdown request."""
        # acknowledge the requesting client first
        self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
        # also notify other clients of shutdown
        self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
        # delay actual exit by 1s so the replies above can be delivered
        dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
        dc.start()
1021 1021
    def _shutdown(self):
        # Terminate the Hub process; the brief sleep gives pending zmq
        # messages a chance to flush before sys.exit.
        self.log.info("hub::hub shutting down.")
        time.sleep(0.1)
        sys.exit(0)
1026 1026
1027 1027
    def check_load(self, client_id, msg):
        """Reply with the current load (pending MUX + Task jobs) per target engine."""
        content = msg['content']
        try:
            targets = content['targets']
            targets = self._validate_targets(targets)
        except:
            content = error.wrap_exception()
            self.session.send(self.query, "hub_error",
                    content=content, ident=client_id)
            return

        content = dict(status='ok')
        # loads = {}
        for t in targets:
            # load = pending MUX jobs + pending Task jobs for this engine
            content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
        self.session.send(self.query, "load_reply", content=content, ident=client_id)
1044 1044
1045 1045
    def queue_status(self, client_id, msg):
        """Return the Queue status of one or more targets.
        if verbose: return the msg_ids
        else: return len of each type.
        keys: queue (pending MUX jobs)
            tasks (pending Task jobs)
            completed (finished jobs from both queues)"""
        content = msg['content']
        targets = content['targets']
        try:
            targets = self._validate_targets(targets)
        except:
            content = error.wrap_exception()
            self.session.send(self.query, "hub_error",
                    content=content, ident=client_id)
            return
        verbose = content.get('verbose', False)
        content = dict(status='ok')
        for t in targets:
            queue = self.queues[t]
            completed = self.completed[t]
            tasks = self.tasks[t]
            # non-verbose replies carry counts instead of msg_id lists
            if not verbose:
                queue = len(queue)
                completed = len(completed)
                tasks = len(tasks)
            content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
        # tasks not yet assigned to any engine by the scheduler
        content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
        # print (content)
        self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1076 1076
    def purge_results(self, client_id, msg):
        """Purge results from memory. This method is more valuable before we move
        to a DB based message storage mechanism."""
        content = msg['content']
        self.log.info("Dropping records with %s", content)
        msg_ids = content.get('msg_ids', [])
        reply = dict(status='ok')
        if msg_ids == 'all':
            # drop every finished record (completed is not None)
            try:
                self.db.drop_matching_records(dict(completed={'$ne':None}))
            except Exception:
                reply = error.wrap_exception()
        else:
            # refuse to purge anything that is still pending
            pending = filter(lambda m: m in self.pending, msg_ids)
            if pending:
                try:
                    raise IndexError("msg pending: %r" % pending[0])
                except:
                    reply = error.wrap_exception()
            else:
                try:
                    self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
                except Exception:
                    reply = error.wrap_exception()

        if reply['status'] == 'ok':
            # optionally also purge all completed records per engine id
            eids = content.get('engine_ids', [])
            for eid in eids:
                if eid not in self.engines:
                    try:
                        raise IndexError("No such engine: %i" % eid)
                    except:
                        reply = error.wrap_exception()
                    break
                uid = self.engines[eid].queue
                try:
                    self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
                except Exception:
                    reply = error.wrap_exception()
                    break

        self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1119 1119
1120 1120 def resubmit_task(self, client_id, msg):
1121 1121 """Resubmit one or more tasks."""
1122 1122 def finish(reply):
1123 1123 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1124 1124
1125 1125 content = msg['content']
1126 1126 msg_ids = content['msg_ids']
1127 1127 reply = dict(status='ok')
1128 1128 try:
1129 1129 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1130 1130 'header', 'content', 'buffers'])
1131 1131 except Exception:
1132 1132 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1133 1133 return finish(error.wrap_exception())
1134 1134
1135 1135 # validate msg_ids
1136 1136 found_ids = [ rec['msg_id'] for rec in records ]
1137 1137 pending_ids = [ msg_id for msg_id in found_ids if msg_id in self.pending ]
1138 1138 if len(records) > len(msg_ids):
1139 1139 try:
1140 1140 raise RuntimeError("DB appears to be in an inconsistent state."
1141 1141 "More matching records were found than should exist")
1142 1142 except Exception:
1143 1143 return finish(error.wrap_exception())
1144 1144 elif len(records) < len(msg_ids):
1145 1145 missing = [ m for m in msg_ids if m not in found_ids ]
1146 1146 try:
1147 1147 raise KeyError("No such msg(s): %r" % missing)
1148 1148 except KeyError:
1149 1149 return finish(error.wrap_exception())
1150 1150 elif pending_ids:
1151 1151 pass
1152 1152 # no need to raise on resubmit of pending task, now that we
1153 1153 # resubmit under new ID, but do we want to raise anyway?
1154 1154 # msg_id = invalid_ids[0]
1155 1155 # try:
1156 1156 # raise ValueError("Task(s) %r appears to be inflight" % )
1157 1157 # except Exception:
1158 1158 # return finish(error.wrap_exception())
1159 1159
1160 1160 # mapping of original IDs to resubmitted IDs
1161 1161 resubmitted = {}
1162 1162
1163 1163 # send the messages
1164 1164 for rec in records:
1165 1165 header = rec['header']
1166 msg = self.session.msg(header['msg_type'])
1166 msg = self.session.msg(header['msg_type'], parent=header)
1167 1167 msg_id = msg['msg_id']
1168 1168 msg['content'] = rec['content']
1169 1169
1170 1170 # use the old header, but update msg_id and timestamp
1171 1171 fresh = msg['header']
1172 1172 header['msg_id'] = fresh['msg_id']
1173 1173 header['date'] = fresh['date']
1174 1174 msg['header'] = header
1175 1175
1176 1176 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1177 1177
1178 1178 resubmitted[rec['msg_id']] = msg_id
1179 1179 self.pending.add(msg_id)
1180 msg['buffers'] = []
1180 msg['buffers'] = rec['buffers']
1181 1181 try:
1182 1182 self.db.add_record(msg_id, init_record(msg))
1183 1183 except Exception:
1184 1184 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1185 1185
1186 1186 finish(dict(status='ok', resubmitted=resubmitted))
1187 1187
1188 1188 # store the new IDs in the Task DB
1189 1189 for msg_id, resubmit_id in resubmitted.iteritems():
1190 1190 try:
1191 1191 self.db.update_record(msg_id, {'resubmitted' : resubmit_id})
1192 1192 except Exception:
1193 1193 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1194 1194
1195 1195
1196 1196 def _extract_record(self, rec):
1197 1197 """decompose a TaskRecord dict into subsection of reply for get_result"""
1198 1198 io_dict = {}
1199 1199 for key in ('pyin', 'pyout', 'pyerr', 'stdout', 'stderr'):
1200 1200 io_dict[key] = rec[key]
1201 1201 content = { 'result_content': rec['result_content'],
1202 1202 'header': rec['header'],
1203 1203 'result_header' : rec['result_header'],
1204 1204 'received' : rec['received'],
1205 1205 'io' : io_dict,
1206 1206 }
1207 1207 if rec['result_buffers']:
1208 1208 buffers = map(bytes, rec['result_buffers'])
1209 1209 else:
1210 1210 buffers = []
1211 1211
1212 1212 return content, buffers
1213 1213
1214 1214 def get_results(self, client_id, msg):
1215 1215 """Get the result of 1 or more messages."""
1216 1216 content = msg['content']
1217 1217 msg_ids = sorted(set(content['msg_ids']))
1218 1218 statusonly = content.get('status_only', False)
1219 1219 pending = []
1220 1220 completed = []
1221 1221 content = dict(status='ok')
1222 1222 content['pending'] = pending
1223 1223 content['completed'] = completed
1224 1224 buffers = []
1225 1225 if not statusonly:
1226 1226 try:
1227 1227 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1228 1228 # turn match list into dict, for faster lookup
1229 1229 records = {}
1230 1230 for rec in matches:
1231 1231 records[rec['msg_id']] = rec
1232 1232 except Exception:
1233 1233 content = error.wrap_exception()
1234 1234 self.session.send(self.query, "result_reply", content=content,
1235 1235 parent=msg, ident=client_id)
1236 1236 return
1237 1237 else:
1238 1238 records = {}
1239 1239 for msg_id in msg_ids:
1240 1240 if msg_id in self.pending:
1241 1241 pending.append(msg_id)
1242 1242 elif msg_id in self.all_completed:
1243 1243 completed.append(msg_id)
1244 1244 if not statusonly:
1245 1245 c,bufs = self._extract_record(records[msg_id])
1246 1246 content[msg_id] = c
1247 1247 buffers.extend(bufs)
1248 1248 elif msg_id in records:
1249 1249 if rec['completed']:
1250 1250 completed.append(msg_id)
1251 1251 c,bufs = self._extract_record(records[msg_id])
1252 1252 content[msg_id] = c
1253 1253 buffers.extend(bufs)
1254 1254 else:
1255 1255 pending.append(msg_id)
1256 1256 else:
1257 1257 try:
1258 1258 raise KeyError('No such message: '+msg_id)
1259 1259 except:
1260 1260 content = error.wrap_exception()
1261 1261 break
1262 1262 self.session.send(self.query, "result_reply", content=content,
1263 1263 parent=msg, ident=client_id,
1264 1264 buffers=buffers)
1265 1265
1266 1266 def get_history(self, client_id, msg):
1267 1267 """Get a list of all msg_ids in our DB records"""
1268 1268 try:
1269 1269 msg_ids = self.db.get_history()
1270 1270 except Exception as e:
1271 1271 content = error.wrap_exception()
1272 1272 else:
1273 1273 content = dict(status='ok', history=msg_ids)
1274 1274
1275 1275 self.session.send(self.query, "history_reply", content=content,
1276 1276 parent=msg, ident=client_id)
1277 1277
    def db_query(self, client_id, msg):
        """Perform a raw query on the task record database."""
        content = msg['content']
        query = content.get('query', {})
        keys = content.get('keys', None)
        buffers = []
        empty = list()
        try:
            records = self.db.find_records(query, keys)
        except Exception as e:
            content = error.wrap_exception()
        else:
            # extract buffers from reply content:
            # the *_lens lists record, per matched record, how many of the
            # flattened `buffers` belong to it (None = that key not requested)
            if keys is not None:
                buffer_lens = [] if 'buffers' in keys else None
                result_buffer_lens = [] if 'result_buffers' in keys else None
            else:
                buffer_lens = None
                result_buffer_lens = None

            for rec in records:
                # buffers may be None, so double check
                b = rec.pop('buffers', empty) or empty
                if buffer_lens is not None:
                    buffer_lens.append(len(b))
                buffers.extend(b)
                rb = rec.pop('result_buffers', empty) or empty
                if result_buffer_lens is not None:
                    result_buffer_lens.append(len(rb))
                buffers.extend(rb)
            content = dict(status='ok', records=records, buffer_lens=buffer_lens,
                                    result_buffer_lens=result_buffer_lens)
        # self.log.debug (content)
        self.session.send(self.query, "db_reply", content=content,
                    parent=msg, ident=client_id,
                    buffers=buffers)
1314 1314
@@ -1,408 +1,422
1 1 """Tests for parallel client.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import time
22 22 from datetime import datetime
23 23 from tempfile import mktemp
24 24
25 25 import zmq
26 26
27 27 from IPython.parallel.client import client as clientmod
28 28 from IPython.parallel import error
29 29 from IPython.parallel import AsyncResult, AsyncHubResult
30 30 from IPython.parallel import LoadBalancedView, DirectView
31 31
32 32 from clienttest import ClusterTestCase, segfault, wait, add_engines
33 33
def setup():
    # nose module-level setup: ensure at least 4 engines are running in total
    add_engines(4, total=True)
36 36
class TestClient(ClusterTestCase):
    """Integration tests for IPython.parallel Client against a live test cluster.

    These tests talk to a running controller + engines started by the
    ``iptest`` profile; many of them sleep briefly to let the Hub's monitor
    catch up with message traffic.
    """

    def test_ids(self):
        # adding engines should grow client.ids accordingly
        n = len(self.client.ids)
        self.add_engines(2)
        self.assertEquals(len(self.client.ids), n+2)

    def test_view_indexing(self):
        """test index access for views"""
        self.minimum_engines(4)
        targets = self.client._build_targets('all')[-1]
        v = self.client[:]
        self.assertEquals(v.targets, targets)
        t = self.client.ids[2]
        v = self.client[t]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, t)
        t = self.client.ids[2:4]
        v = self.client[t]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, t)
        v = self.client[::2]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[::2])
        v = self.client[1::3]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[1::3])
        v = self.client[:-3]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[:-3])
        v = self.client[-1]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[-1])
        # non-int/slice/list indexing is rejected
        self.assertRaises(TypeError, lambda : self.client[None])

    def test_lbview_targets(self):
        """test load_balanced_view targets"""
        v = self.client.load_balanced_view()
        self.assertEquals(v.targets, None)
        v = self.client.load_balanced_view(-1)
        self.assertEquals(v.targets, [self.client.ids[-1]])
        v = self.client.load_balanced_view('all')
        self.assertEquals(v.targets, None)

    def test_dview_targets(self):
        """test direct_view targets"""
        v = self.client.direct_view()
        self.assertEquals(v.targets, 'all')
        v = self.client.direct_view('all')
        self.assertEquals(v.targets, 'all')
        v = self.client.direct_view(-1)
        self.assertEquals(v.targets, self.client.ids[-1])

    def test_lazy_all_targets(self):
        """test lazy evaluation of rc.direct_view('all')"""
        v = self.client.direct_view()
        self.assertEquals(v.targets, 'all')

        def double(x):
            return x*2
        seq = range(100)
        ref = [ double(x) for x in seq ]

        # add some engines, which should be used
        self.add_engines(1)
        n1 = len(self.client.ids)

        # simple apply
        r = v.apply_sync(lambda : 1)
        self.assertEquals(r, [1] * n1)

        # map goes through remotefunction
        r = v.map_sync(double, seq)
        self.assertEquals(r, ref)

        # add a couple more engines, and try again
        self.add_engines(2)
        n2 = len(self.client.ids)
        self.assertNotEquals(n2, n1)

        # apply
        r = v.apply_sync(lambda : 1)
        self.assertEquals(r, [1] * n2)

        # map
        r = v.map_sync(double, seq)
        self.assertEquals(r, ref)

    def test_targets(self):
        """test various valid targets arguments"""
        build = self.client._build_targets
        ids = self.client.ids
        idents,targets = build(None)
        self.assertEquals(ids, targets)

    def test_clear(self):
        """test clear behavior"""
        self.minimum_engines(2)
        v = self.client[:]
        v.block=True
        v.push(dict(a=5))
        v.pull('a')
        id0 = self.client.ids[-1]
        # clearing one engine removes 'a' only there
        self.client.clear(targets=id0, block=True)
        a = self.client[:-1].get('a')
        self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
        # clearing all engines removes 'a' everywhere
        self.client.clear(block=True)
        for i in self.client.ids:
            self.assertRaisesRemote(NameError, self.client[i].get, 'a')

    def test_get_result(self):
        """test getting results from the Hub."""
        c = clientmod.Client(profile='iptest')
        t = c.ids[-1]
        ar = c[t].apply_async(wait, 1)
        # give the monitor time to notice the message
        time.sleep(.25)
        # a second client sees the task only via the Hub -> AsyncHubResult
        ahr = self.client.get_result(ar.msg_ids)
        self.assertTrue(isinstance(ahr, AsyncHubResult))
        self.assertEquals(ahr.get(), ar.get())
        # once the owning client has the result, it's a plain AsyncResult
        ar2 = self.client.get_result(ar.msg_ids)
        self.assertFalse(isinstance(ar2, AsyncHubResult))
        c.close()

    def test_ids_list(self):
        """test client.ids"""
        ids = self.client.ids
        self.assertEquals(ids, self.client._ids)
        # client.ids must be a copy, not the internal list itself
        self.assertFalse(ids is self.client._ids)
        ids.remove(ids[-1])
        self.assertNotEquals(ids, self.client._ids)

    def test_queue_status(self):
        ids = self.client.ids
        id0 = ids[0]
        qs = self.client.queue_status(targets=id0)
        self.assertTrue(isinstance(qs, dict))
        self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
        allqs = self.client.queue_status()
        self.assertTrue(isinstance(allqs, dict))
        intkeys = list(allqs.keys())
        intkeys.remove('unassigned')
        self.assertEquals(sorted(intkeys), sorted(self.client.ids))
        unassigned = allqs.pop('unassigned')
        for eid,qs in allqs.items():
            self.assertTrue(isinstance(qs, dict))
            self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])

    def test_shutdown(self):
        ids = self.client.ids
        id0 = ids[0]
        self.client.shutdown(id0, block=True)
        # spin until the Hub notices the engine is gone
        while id0 in self.client.ids:
            time.sleep(0.1)
            self.client.spin()

        self.assertRaises(IndexError, lambda : self.client[id0])

    def test_result_status(self):
        pass
        # to be written

    def test_db_query_dt(self):
        """test db query by date"""
        hist = self.client.hub_history()
        middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
        tic = middle['submitted']
        # $lt / $gte on a datetime should partition history exactly
        before = self.client.db_query({'submitted' : {'$lt' : tic}})
        after = self.client.db_query({'submitted' : {'$gte' : tic}})
        self.assertEquals(len(before)+len(after),len(hist))
        for b in before:
            self.assertTrue(b['submitted'] < tic)
        for a in after:
            self.assertTrue(a['submitted'] >= tic)
        same = self.client.db_query({'submitted' : tic})
        for s in same:
            self.assertTrue(s['submitted'] == tic)

    def test_db_query_keys(self):
        """test extracting subset of record keys"""
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
        for rec in found:
            self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))

    def test_db_query_default_keys(self):
        """default db_query excludes buffers"""
        found = self.client.db_query({'msg_id': {'$ne' : ''}})
        for rec in found:
            keys = set(rec.keys())
            self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
            self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)

    def test_db_query_msg_id(self):
        """ensure msg_id is always in db queries"""
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())

    def test_db_query_get_result(self):
        """pop in db_query shouldn't pop from result itself"""
        self.client[:].apply_sync(lambda : 1)
        found = self.client.db_query({'msg_id': {'$ne' : ''}})
        rc2 = clientmod.Client(profile='iptest')
        # If this bug is not fixed, this call will hang:
        ar = rc2.get_result(self.client.history[-1])
        ar.wait(2)
        self.assertTrue(ar.ready())
        ar.get()
        rc2.close()

    def test_db_query_in(self):
        """test db query with '$in','$nin' operators"""
        hist = self.client.hub_history()
        even = hist[::2]
        odd = hist[1::2]
        recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
        found = [ r['msg_id'] for r in recs ]
        self.assertEquals(set(even), set(found))
        recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
        found = [ r['msg_id'] for r in recs ]
        self.assertEquals(set(odd), set(found))

    def test_hub_history(self):
        hist = self.client.hub_history()
        recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
        recdict = {}
        for rec in recs:
            recdict[rec['msg_id']] = rec

        # history must come back ordered by submission time
        latest = datetime(1984,1,1)
        for msg_id in hist:
            rec = recdict[msg_id]
            newt = rec['submitted']
            self.assertTrue(newt >= latest)
            latest = newt
        ar = self.client[-1].apply_async(lambda : 1)
        ar.get()
        time.sleep(0.25)
        self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)

    def _wait_for_idle(self):
        """wait for an engine to become idle, according to the Hub"""
        rc = self.client

        # timeout 2s, polling every 100ms
        for i in range(20):
            qs = rc.queue_status()
            if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
                time.sleep(0.1)
            else:
                break

        # ensure Hub up to date:
        qs = rc.queue_status()
        self.assertEquals(qs['unassigned'], 0)
        for eid in rc.ids:
            self.assertEquals(qs[eid]['tasks'], 0)


    def test_resubmit(self):
        def f():
            import random
            return random.random()
        v = self.client.load_balanced_view()
        ar = v.apply_async(f)
        r1 = ar.get(1)
        # give the Hub a chance to notice:
        self._wait_for_idle()
        ahr = self.client.resubmit(ar.msg_ids)
        r2 = ahr.get(1)
        # re-running random() should give a different value
        self.assertFalse(r1 == r2)

    def test_resubmit_chain(self):
        """resubmit resubmitted tasks"""
        v = self.client.load_balanced_view()
        ar = v.apply_async(lambda x: x, 'x'*1024)
        ar.get()
        self._wait_for_idle()
        ars = [ar]

        # NOTE(review): ar2 is never appended to ars, so each iteration
        # resubmits the same original task -- confirm this is intended
        for i in range(10):
            ar = ars[-1]
            ar2 = self.client.resubmit(ar.msg_ids)

        [ ar.get() for ar in ars ]

    def test_resubmit_header(self):
        """resubmit shouldn't clobber the whole header"""
        def f():
            import random
            return random.random()
        v = self.client.load_balanced_view()
        v.retries = 1
        ar = v.apply_async(f)
        r1 = ar.get(1)
        # give the Hub a chance to notice:
        self._wait_for_idle()
        ahr = self.client.resubmit(ar.msg_ids)
        ahr.get(1)
        time.sleep(0.5)
        records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
        h1,h2 = [ r['header'] for r in records ]
        # only msg_id and date may differ between original and resubmission
        for key in set(h1.keys()).union(set(h2.keys())):
            if key in ('msg_id', 'date'):
                self.assertNotEquals(h1[key], h2[key])
            else:
                self.assertEquals(h1[key], h2[key])

    def test_resubmit_aborted(self):
        def f():
            import random
            return random.random()
        v = self.client.load_balanced_view()
        # restrict to one engine, so we can put a sleep
        # ahead of the task, so it will get aborted
        eid = self.client.ids[-1]
        v.targets = [eid]
        sleep = v.apply_async(time.sleep, 0.5)
        ar = v.apply_async(f)
        ar.abort()
        self.assertRaises(error.TaskAborted, ar.get)
        # Give the Hub a chance to get up to date:
        self._wait_for_idle()
        ahr = self.client.resubmit(ar.msg_ids)
        r2 = ahr.get(1)

    def test_resubmit_inflight(self):
        """resubmit of inflight task"""
        v = self.client.load_balanced_view()
        ar = v.apply_async(time.sleep,1)
        # give the message a chance to arrive
        time.sleep(0.2)
        ahr = self.client.resubmit(ar.msg_ids)
        ar.get(2)
        ahr.get(2)

    def test_resubmit_badkey(self):
        """ensure KeyError on resubmit of nonexistant task"""
        self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])

    def test_purge_results(self):
        # ensure there are some tasks
        for i in range(5):
            self.client[:].apply_sync(lambda : 1)
        # Wait for the Hub to realise the result is done:
        # This prevents a race condition, where we
        # might purge a result the Hub still thinks is pending.
        time.sleep(0.1)
        rc2 = clientmod.Client(profile='iptest')
        hist = self.client.hub_history()
        ahr = rc2.get_result([hist[-1]])
        ahr.wait(10)
        self.client.purge_results(hist[-1])
        newhist = self.client.hub_history()
        self.assertEquals(len(newhist)+1,len(hist))
        rc2.spin()
        rc2.close()

    def test_purge_all_results(self):
        self.client.purge_results('all')
        hist = self.client.hub_history()
        self.assertEquals(len(hist), 0)

    def test_spin_thread(self):
        # background spin thread should keep results flowing in
        self.client.spin_thread(0.01)
        ar = self.client[-1].apply_async(lambda : 1)
        time.sleep(0.1)
        self.assertTrue(ar.wall_time < 0.1,
            "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
        )

    def test_stop_spin_thread(self):
        self.client.spin_thread(0.01)
        self.client.stop_spin_thread()
        ar = self.client[-1].apply_async(lambda : 1)
        time.sleep(0.15)
        self.assertTrue(ar.wall_time > 0.1,
            "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
        )
General Comments 0
You need to be logged in to leave comments. Login now