Commit: receive tasks, even when no engines are registered...
Author: MinRK
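The change under review (visible in the scheduler hunk below) registers `client_stream.on_recv(self.dispatch_submission)` unconditionally in `TaskScheduler.start()` and drops the stop/resume logic tied to engine [un]registration, so submissions are accepted and queued even with zero engines. A toy sketch of the behavioral difference, with hypothetical names only (not the real IPython classes):

class TinyScheduler(object):
    """Toy stand-in for TaskScheduler; names are illustrative."""
    def __init__(self):
        self.targets = []   # registered engine identities
        self.backlog = []   # tasks waiting for an engine

    def submit(self, task):
        # before this commit: submissions were left unread (stream paused)
        # while no engines were registered; now they are always accepted
        self.backlog.append(task)

    def register_engine(self, uid):
        self.targets.append(uid)
        while self.backlog:  # a newly registered engine drains the backlog
            print("assign %r -> %r" % (self.backlog.pop(0), uid))

s = TinyScheduler()
s.submit('t1')           # accepted even though s.targets is empty
s.register_engine('e0')  # prints: assign 't1' -> 'e0'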
IPython/parallel/controller/hub.py
@@ -1,1293 +1,1293 @@
1 1 """The IPython Controller Hub with 0MQ
2 2 This is the master object that handles connections from engines and clients,
3 3 and monitors traffic through the various queues.
4 4
5 5 Authors:
6 6
7 7 * Min RK
8 8 """
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (C) 2010-2011 The IPython Development Team
11 11 #
12 12 # Distributed under the terms of the BSD License. The full license is in
13 13 # the file COPYING, distributed as part of this software.
14 14 #-----------------------------------------------------------------------------
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Imports
18 18 #-----------------------------------------------------------------------------
19 19 from __future__ import print_function
20 20
21 21 import sys
22 22 import time
23 23 from datetime import datetime
24 24
25 25 import zmq
26 26 from zmq.eventloop import ioloop
27 27 from zmq.eventloop.zmqstream import ZMQStream
28 28
29 29 # internal:
30 30 from IPython.utils.importstring import import_item
31 31 from IPython.utils.traitlets import (
32 32 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
33 33 )
34 34
35 35 from IPython.parallel import error, util
36 36 from IPython.parallel.factory import RegistrationFactory
37 37
38 38 from IPython.zmq.session import SessionFactory
39 39
40 40 from .heartmonitor import HeartMonitor
41 41
42 42 #-----------------------------------------------------------------------------
43 43 # Code
44 44 #-----------------------------------------------------------------------------
45 45
46 46 def _passer(*args, **kwargs):
47 47 return
48 48
49 49 def _printer(*args, **kwargs):
50 50 print (args)
51 51 print (kwargs)
52 52
53 53 def empty_record():
54 54 """Return an empty dict with all record keys."""
55 55 return {
56 56 'msg_id' : None,
57 57 'header' : None,
58 58 'content': None,
59 59 'buffers': None,
60 60 'submitted': None,
61 61 'client_uuid' : None,
62 62 'engine_uuid' : None,
63 63 'started': None,
64 64 'completed': None,
65 65 'resubmitted': None,
66 66 'result_header' : None,
67 67 'result_content' : None,
68 68 'result_buffers' : None,
69 69 'queue' : None,
70 70 'pyin' : None,
71 71 'pyout': None,
72 72 'pyerr': None,
73 73 'stdout': '',
74 74 'stderr': '',
75 75 }
76 76
77 77 def init_record(msg):
78 78 """Initialize a TaskRecord based on a request."""
79 79 header = msg['header']
80 80 return {
81 81 'msg_id' : header['msg_id'],
82 82 'header' : header,
83 83 'content': msg['content'],
84 84 'buffers': msg['buffers'],
85 85 'submitted': header['date'],
86 86 'client_uuid' : None,
87 87 'engine_uuid' : None,
88 88 'started': None,
89 89 'completed': None,
90 90 'resubmitted': None,
91 91 'result_header' : None,
92 92 'result_content' : None,
93 93 'result_buffers' : None,
94 94 'queue' : None,
95 95 'pyin' : None,
96 96 'pyout': None,
97 97 'pyerr': None,
98 98 'stdout': '',
99 99 'stderr': '',
100 100 }
101 101
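A hedged usage sketch for `init_record` as defined above; the message dict is a minimal, hypothetical stand-in for the Session message format it expects:

from datetime import datetime

msg = {
    'header': {'msg_id': 'abc123', 'msg_type': 'apply_request',
               'date': datetime(2011, 1, 1)},
    'content': {},
    'buffers': [],
}
rec = init_record(msg)
assert rec['msg_id'] == 'abc123'
assert rec['submitted'] == msg['header']['date']
assert rec['queue'] is None  # filled in later by save_queue_request / save_task_request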
102 102
103 103 class EngineConnector(HasTraits):
104 104 """A simple object for accessing the various zmq connections of an object.
105 105 Attributes are:
106 106 id (int): engine ID
107 107 uuid (str): uuid (unused?)
108 108 queue (str): identity of queue's XREQ socket
109 109 registration (str): identity of registration XREQ socket
110 110 heartbeat (str): identity of heartbeat XREQ socket
111 111 """
112 112 id=Integer(0)
113 113 queue=CBytes()
114 114 control=CBytes()
115 115 registration=CBytes()
116 116 heartbeat=CBytes()
117 117 pending=Set()
118 118
119 119 class HubFactory(RegistrationFactory):
120 120 """The Configurable for setting up a Hub."""
121 121
122 122 # port-pairs for monitoredqueues:
123 123 hb = Tuple(Integer,Integer,config=True,
124 124 help="""XREQ/SUB Port pair for Engine heartbeats""")
125 125 def _hb_default(self):
126 126 return tuple(util.select_random_ports(2))
127 127
128 128 mux = Tuple(Integer,Integer,config=True,
129 129 help="""Engine/Client Port pair for MUX queue""")
130 130
131 131 def _mux_default(self):
132 132 return tuple(util.select_random_ports(2))
133 133
134 134 task = Tuple(Integer,Integer,config=True,
135 135 help="""Engine/Client Port pair for Task queue""")
136 136 def _task_default(self):
137 137 return tuple(util.select_random_ports(2))
138 138
139 139 control = Tuple(Integer,Integer,config=True,
140 140 help="""Engine/Client Port pair for Control queue""")
141 141
142 142 def _control_default(self):
143 143 return tuple(util.select_random_ports(2))
144 144
145 145 iopub = Tuple(Integer,Integer,config=True,
146 146 help="""Engine/Client Port pair for IOPub relay""")
147 147
148 148 def _iopub_default(self):
149 149 return tuple(util.select_random_ports(2))
150 150
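The `_<name>_default` methods above follow the traitlets dynamic-default convention: the default is computed lazily, once per instance, which is why each HubFactory picks its own fresh random ports. A minimal sketch of the pattern (ports hard-coded for illustration):

from IPython.utils.traitlets import HasTraits, Tuple, Integer

class Ports(HasTraits):
    hb = Tuple(Integer, Integer)
    def _hb_default(self):
        # stand-in for util.select_random_ports(2)
        return (5555, 5556)

p = Ports()
print(p.hb)  # (5555, 5556) -- computed on first access, not at class definition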
151 151 # single ports:
152 152 mon_port = Integer(config=True,
153 153 help="""Monitor (SUB) port for queue traffic""")
154 154
155 155 def _mon_port_default(self):
156 156 return util.select_random_ports(1)[0]
157 157
158 158 notifier_port = Integer(config=True,
159 159 help="""PUB port for sending engine status notifications""")
160 160
161 161 def _notifier_port_default(self):
162 162 return util.select_random_ports(1)[0]
163 163
164 164 engine_ip = Unicode('127.0.0.1', config=True,
165 165 help="IP on which to listen for engine connections. [default: loopback]")
166 166 engine_transport = Unicode('tcp', config=True,
167 167 help="0MQ transport for engine connections. [default: tcp]")
168 168
169 169 client_ip = Unicode('127.0.0.1', config=True,
170 170 help="IP on which to listen for client connections. [default: loopback]")
171 171 client_transport = Unicode('tcp', config=True,
172 172 help="0MQ transport for client connections. [default : tcp]")
173 173
174 174 monitor_ip = Unicode('127.0.0.1', config=True,
175 175 help="IP on which to listen for monitor messages. [default: loopback]")
176 176 monitor_transport = Unicode('tcp', config=True,
177 177 help="0MQ transport for monitor messages. [default : tcp]")
178 178
179 179 monitor_url = Unicode('')
180 180
181 181 db_class = DottedObjectName('IPython.parallel.controller.dictdb.DictDB',
182 182 config=True, help="""The class to use for the DB backend""")
183 183
184 184 # not configurable
185 185 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
186 186 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
187 187
188 188 def _ip_changed(self, name, old, new):
189 189 self.engine_ip = new
190 190 self.client_ip = new
191 191 self.monitor_ip = new
192 192 self._update_monitor_url()
193 193
194 194 def _update_monitor_url(self):
195 195 self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)
196 196
197 197 def _transport_changed(self, name, old, new):
198 198 self.engine_transport = new
199 199 self.client_transport = new
200 200 self.monitor_transport = new
201 201 self._update_monitor_url()
202 202
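A hedged sketch of the fan-out these handlers implement: assigning the single `ip` trait (inherited from RegistrationFactory) propagates to the engine, client, and monitor settings via `_ip_changed`, and likewise for `transport`. Assumes a HubFactory can be constructed with default config:

hf = HubFactory()
hf.ip = u'10.0.0.5'  # triggers _ip_changed
assert hf.engine_ip == hf.client_ip == hf.monitor_ip == u'10.0.0.5'
print(hf.monitor_url)  # e.g. tcp://10.0.0.5:<mon_port>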
203 203 def __init__(self, **kwargs):
204 204 super(HubFactory, self).__init__(**kwargs)
205 205 self._update_monitor_url()
206 206
207 207
208 208 def construct(self):
209 209 self.init_hub()
210 210
211 211 def start(self):
212 212 self.heartmonitor.start()
213 213 self.log.info("Heartmonitor started")
214 214
215 215 def init_hub(self):
216 216 """construct"""
217 217 client_iface = "%s://%s:" % (self.client_transport, self.client_ip) + "%i"
218 218 engine_iface = "%s://%s:" % (self.engine_transport, self.engine_ip) + "%i"
219 219
220 220 ctx = self.context
221 221 loop = self.loop
222 222
223 223 # Registrar socket
224 224 q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
225 225 q.bind(client_iface % self.regport)
226 226 self.log.info("Hub listening on %s for registration.", client_iface % self.regport)
227 227 if self.client_ip != self.engine_ip:
228 228 q.bind(engine_iface % self.regport)
229 229 self.log.info("Hub listening on %s for registration.", engine_iface % self.regport)
230 230
231 231 ### Engine connections ###
232 232
233 233 # heartbeat
234 234 hpub = ctx.socket(zmq.PUB)
235 235 hpub.bind(engine_iface % self.hb[0])
236 236 hrep = ctx.socket(zmq.ROUTER)
237 237 hrep.bind(engine_iface % self.hb[1])
238 238 self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
239 239 pingstream=ZMQStream(hpub,loop),
240 240 pongstream=ZMQStream(hrep,loop)
241 241 )
242 242
243 243 ### Client connections ###
244 244 # Notifier socket
245 245 n = ZMQStream(ctx.socket(zmq.PUB), loop)
246 246 n.bind(client_iface%self.notifier_port)
247 247
248 248 ### build and launch the queues ###
249 249
250 250 # monitor socket
251 251 sub = ctx.socket(zmq.SUB)
252 252 sub.setsockopt(zmq.SUBSCRIBE, b"")
253 253 sub.bind(self.monitor_url)
254 254 sub.bind('inproc://monitor')
255 255 sub = ZMQStream(sub, loop)
256 256
257 257 # connect the db
258 258 self.log.info('Hub using DB backend: %r'%(self.db_class.split('.')[-1]))
259 259 # cdir = self.config.Global.cluster_dir
260 260 self.db = import_item(str(self.db_class))(session=self.session.session,
261 261 config=self.config, log=self.log)
262 262 time.sleep(.25)
263 263 try:
264 264 scheme = self.config.TaskScheduler.scheme_name
265 265 except AttributeError:
266 266 from .scheduler import TaskScheduler
267 267 scheme = TaskScheduler.scheme_name.get_default_value()
268 268 # build connection dicts
269 269 self.engine_info = {
270 270 'control' : engine_iface%self.control[1],
271 271 'mux': engine_iface%self.mux[1],
272 272 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
273 273 'task' : engine_iface%self.task[1],
274 274 'iopub' : engine_iface%self.iopub[1],
275 275 # 'monitor' : engine_iface%self.mon_port,
276 276 }
277 277
278 278 self.client_info = {
279 279 'control' : client_iface%self.control[0],
280 280 'mux': client_iface%self.mux[0],
281 281 'task' : (scheme, client_iface%self.task[0]),
282 282 'iopub' : client_iface%self.iopub[0],
283 283 'notification': client_iface%self.notifier_port
284 284 }
285 285 self.log.debug("Hub engine addrs: %s", self.engine_info)
286 286 self.log.debug("Hub client addrs: %s", self.client_info)
287 287
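For illustration, the connection dicts built above come out shaped roughly like this (ports made up; note that the client-side `task` entry carries the scheduler scheme alongside the URL):

engine_info = {
    'control':   'tcp://127.0.0.1:57001',
    'mux':       'tcp://127.0.0.1:57003',
    'heartbeat': ('tcp://127.0.0.1:57005', 'tcp://127.0.0.1:57006'),  # (ping, pong)
    'task':      'tcp://127.0.0.1:57008',
    'iopub':     'tcp://127.0.0.1:57010',
}
client_info = {
    'control':      'tcp://127.0.0.1:57000',
    'mux':          'tcp://127.0.0.1:57002',
    'task':         ('leastload', 'tcp://127.0.0.1:57007'),  # (scheme, url)
    'iopub':        'tcp://127.0.0.1:57009',
    'notification': 'tcp://127.0.0.1:57011',
}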
288 288 # resubmit stream
289 289 r = ZMQStream(ctx.socket(zmq.DEALER), loop)
290 290 url = util.disambiguate_url(self.client_info['task'][-1])
291 291 r.setsockopt(zmq.IDENTITY, self.session.bsession)
292 292 r.connect(url)
293 293
294 294 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
295 295 query=q, notifier=n, resubmit=r, db=self.db,
296 296 engine_info=self.engine_info, client_info=self.client_info,
297 297 log=self.log)
298 298
299 299
300 300 class Hub(SessionFactory):
301 301 """The IPython Controller Hub with 0MQ connections
302 302
303 303 Parameters
304 304 ==========
305 305 loop: zmq IOLoop instance
306 306 session: Session object
307 307 context: zmq context for creating new connections (?)
308 308 queue: ZMQStream for monitoring the command queue (SUB)
309 309 query: ZMQStream for engine registration and client queries requests (XREP)
310 310 heartbeat: HeartMonitor object checking the pulse of the engines
311 311 notifier: ZMQStream for broadcasting engine registration changes (PUB)
312 312 db: connection to db for out-of-memory logging of commands
313 313 (NotImplemented)
314 314 engine_info: dict of zmq connection information for engines to connect
315 315 to the queues.
316 316 client_info: dict of zmq connection information for clients to connect
317 317 to the queues.
318 318 """
319 319 # internal data structures:
320 320 ids=Set() # engine IDs
321 321 keytable=Dict()
322 322 by_ident=Dict()
323 323 engines=Dict()
324 324 clients=Dict()
325 325 hearts=Dict()
326 326 pending=Set()
327 327 queues=Dict() # pending msg_ids keyed by engine_id
328 328 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
329 329 completed=Dict() # completed msg_ids keyed by engine_id
330 330 all_completed=Set() # set of all completed msg_ids
331 331 dead_engines=Set() # set of uuids of dead engines
332 332 unassigned=Set() # set of task msg_ids not yet assigned a destination
333 333 incoming_registrations=Dict()
334 334 registration_timeout=Integer()
335 335 _idcounter=Integer(0)
336 336
337 337 # objects from constructor:
338 338 query=Instance(ZMQStream)
339 339 monitor=Instance(ZMQStream)
340 340 notifier=Instance(ZMQStream)
341 341 resubmit=Instance(ZMQStream)
342 342 heartmonitor=Instance(HeartMonitor)
343 343 db=Instance(object)
344 344 client_info=Dict()
345 345 engine_info=Dict()
346 346
347 347
348 348 def __init__(self, **kwargs):
349 349 """
350 350 # universal:
351 351 loop: IOLoop for creating future connections
352 352 session: streamsession for sending serialized data
353 353 # engine:
354 354 queue: ZMQStream for monitoring queue messages
355 355 query: ZMQStream for engine+client registration and client requests
356 356 heartbeat: HeartMonitor object for tracking engines
357 357 # extra:
358 358 db: ZMQStream for db connection (NotImplemented)
359 359 engine_info: zmq address/protocol dict for engine connections
360 360 client_info: zmq address/protocol dict for client connections
361 361 """
362 362
363 363 super(Hub, self).__init__(**kwargs)
364 364 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
365 365
366 366 # validate connection dicts:
367 367 for k,v in self.client_info.iteritems():
368 368 if k == 'task':
369 369 util.validate_url_container(v[1])
370 370 else:
371 371 util.validate_url_container(v)
372 372 # util.validate_url_container(self.client_info)
373 373 util.validate_url_container(self.engine_info)
374 374
375 375 # register our callbacks
376 376 self.query.on_recv(self.dispatch_query)
377 377 self.monitor.on_recv(self.dispatch_monitor_traffic)
378 378
379 379 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
380 380 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
381 381
382 382 self.monitor_handlers = {b'in' : self.save_queue_request,
383 383 b'out': self.save_queue_result,
384 384 b'intask': self.save_task_request,
385 385 b'outtask': self.save_task_result,
386 386 b'tracktask': self.save_task_destination,
387 387 b'incontrol': _passer,
388 388 b'outcontrol': _passer,
389 389 b'iopub': self.save_iopub_message,
390 390 }
391 391
392 392 self.query_handlers = {'queue_request': self.queue_status,
393 393 'result_request': self.get_results,
394 394 'history_request': self.get_history,
395 395 'db_request': self.db_query,
396 396 'purge_request': self.purge_results,
397 397 'load_request': self.check_load,
398 398 'resubmit_request': self.resubmit_task,
399 399 'shutdown_request': self.shutdown_request,
400 400 'registration_request' : self.register_engine,
401 401 'unregistration_request' : self.unregister_engine,
402 402 'connection_request': self.connection_request,
403 403 }
404 404
405 405 # ignore resubmit replies
406 406 self.resubmit.on_recv(lambda msg: None, copy=False)
407 407
408 408 self.log.info("hub::created hub")
409 409
410 410 @property
411 411 def _next_id(self):
412 412 """gemerate a new ID.
413 413
414 414 No longer reuse old ids, just count from 0."""
415 415 newid = self._idcounter
416 416 self._idcounter += 1
417 417 return newid
418 418 # newid = 0
419 419 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
420 420 # # print newid, self.ids, self.incoming_registrations
421 421 # while newid in self.ids or newid in incoming:
422 422 # newid += 1
423 423 # return newid
424 424
425 425 #-----------------------------------------------------------------------------
426 426 # message validation
427 427 #-----------------------------------------------------------------------------
428 428
429 429 def _validate_targets(self, targets):
430 430 """turn any valid targets argument into a list of integer ids"""
431 431 if targets is None:
432 432 # default to all
433 433 targets = self.ids
434 434
435 435 if isinstance(targets, (int,str,unicode)):
436 436 # only one target specified
437 437 targets = [targets]
438 438 _targets = []
439 439 for t in targets:
440 440 # map raw identities to ids
441 441 if isinstance(t, (str,unicode)):
442 442 t = self.by_ident.get(t, t)
443 443 _targets.append(t)
444 444 targets = _targets
445 445 bad_targets = [ t for t in targets if t not in self.ids ]
446 446 if bad_targets:
447 447 raise IndexError("No Such Engine: %r" % bad_targets)
448 448 if not targets:
449 449 raise IndexError("No Engines Registered")
450 450 return targets
451 451
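A standalone sketch mirroring the normalization rules of `_validate_targets` (engine ids and identities are made up):

ids = set([0, 1])
by_ident = {b'engine-0-uuid': 0, b'engine-1-uuid': 1}

def normalize(targets):
    if targets is None:
        targets = ids                      # default to all engines
    if isinstance(targets, (int, bytes)):
        targets = [targets]                # wrap a scalar
    targets = [by_ident.get(t, t) if isinstance(t, bytes) else t for t in targets]
    bad = [t for t in targets if t not in ids]
    if bad:
        raise IndexError("No Such Engine: %r" % bad)
    return targets

print(normalize(None))              # [0, 1] (all registered engines)
print(normalize(1))                 # [1]
print(normalize(b'engine-0-uuid'))  # [0] -- raw identity mapped via by_ident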
452 452 #-----------------------------------------------------------------------------
453 453 # dispatch methods (1 per stream)
454 454 #-----------------------------------------------------------------------------
455 455
456 456
457 457 def dispatch_monitor_traffic(self, msg):
458 458 """all ME and Task queue messages come through here, as well as
459 459 IOPub traffic."""
460 self.log.debug("monitor traffic: %r", msg[:2])
460 self.log.debug("monitor traffic: %r", msg[0])
461 461 switch = msg[0]
462 462 try:
463 463 idents, msg = self.session.feed_identities(msg[1:])
464 464 except ValueError:
465 465 idents=[]
466 466 if not idents:
467 467 self.log.error("Bad Monitor Message: %r", msg)
468 468 return
469 469 handler = self.monitor_handlers.get(switch, None)
470 470 if handler is not None:
471 471 handler(idents, msg)
472 472 else:
473 473 self.log.error("Invalid monitor topic: %r", switch)
474 474
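A hedged sketch of the frame list `dispatch_monitor_traffic` receives (frame contents are placeholders, and the exact Session wire format may vary by version): the first frame is the topic used to select a handler from `monitor_handlers`, the rest is a routed Session message.

frames = [b'intask',            # topic frame: selects save_task_request
          b'client-identity',   # zmq routing identities ...
          b'<IDS|MSG>',         # ... terminated by the Session delimiter
          b'<hmac-signature>', b'<header>', b'<parent_header>', b'<content>']
# frames[0] picks the handler; feed_identities(frames[1:]) splits the
# identities from the serialized message before it reaches the handler.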
475 475
476 476 def dispatch_query(self, msg):
477 477 """Route registration requests and queries from clients."""
478 478 try:
479 479 idents, msg = self.session.feed_identities(msg)
480 480 except ValueError:
481 481 idents = []
482 482 if not idents:
483 483 self.log.error("Bad Query Message: %r", msg)
484 484 return
485 485 client_id = idents[0]
486 486 try:
487 487 msg = self.session.unserialize(msg, content=True)
488 488 except Exception:
489 489 content = error.wrap_exception()
490 490 self.log.error("Bad Query Message: %r", msg, exc_info=True)
491 491 self.session.send(self.query, "hub_error", ident=client_id,
492 492 content=content)
493 493 return
494 494 # print client_id, header, parent, content
495 495 #switch on message type:
496 496 msg_type = msg['header']['msg_type']
497 497 self.log.info("client::client %r requested %r", client_id, msg_type)
498 498 handler = self.query_handlers.get(msg_type, None)
499 499 try:
500 500 assert handler is not None, "Bad Message Type: %r" % msg_type
501 501 except:
502 502 content = error.wrap_exception()
503 503 self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
504 504 self.session.send(self.query, "hub_error", ident=client_id,
505 505 content=content)
506 506 return
507 507
508 508 else:
509 509 handler(idents, msg)
510 510
511 511 def dispatch_db(self, msg):
512 512 """"""
513 513 raise NotImplementedError
514 514
515 515 #---------------------------------------------------------------------------
516 516 # handler methods (1 per event)
517 517 #---------------------------------------------------------------------------
518 518
519 519 #----------------------- Heartbeat --------------------------------------
520 520
521 521 def handle_new_heart(self, heart):
522 522 """handler to attach to heartbeater.
523 523 Called when a new heart starts to beat.
524 524 Triggers completion of registration."""
525 525 self.log.debug("heartbeat::handle_new_heart(%r)", heart)
526 526 if heart not in self.incoming_registrations:
527 527 self.log.info("heartbeat::ignoring new heart: %r", heart)
528 528 else:
529 529 self.finish_registration(heart)
530 530
531 531
532 532 def handle_heart_failure(self, heart):
533 533 """handler to attach to heartbeater.
534 534 called when a previously registered heart fails to respond to beat request.
535 535 triggers unregistration"""
536 536 self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
537 537 eid = self.hearts.get(heart, None)
538 538 if eid is None or self.keytable[eid] in self.dead_engines:
539 539 self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
540 540 else:
541 541 # only look up the queue identity once eid is known to be valid
542 542 self.unregister_engine(heart, dict(content=dict(id=eid, queue=self.engines[eid].queue)))
543 543
544 544 #----------------------- MUX Queue Traffic ------------------------------
545 545
546 546 def save_queue_request(self, idents, msg):
547 547 if len(idents) < 2:
548 548 self.log.error("invalid identity prefix: %r", idents)
549 549 return
550 550 queue_id, client_id = idents[:2]
551 551 try:
552 552 msg = self.session.unserialize(msg)
553 553 except Exception:
554 554 self.log.error("queue::client %r sent invalid message to %r: %r", client_id, queue_id, msg, exc_info=True)
555 555 return
556 556
557 557 eid = self.by_ident.get(queue_id, None)
558 558 if eid is None:
559 559 self.log.error("queue::target %r not registered", queue_id)
560 560 self.log.debug("queue:: valid are: %r", self.by_ident.keys())
561 561 return
562 562 record = init_record(msg)
563 563 msg_id = record['msg_id']
564 564 self.log.info("queue::client %r submitted request %r to %s", client_id, msg_id, eid)
565 565 # Unicode in records
566 566 record['engine_uuid'] = queue_id.decode('ascii')
567 567 record['client_uuid'] = client_id.decode('ascii')
568 568 record['queue'] = 'mux'
569 569
570 570 try:
571 571 # it's possible iopub arrived first:
572 572 existing = self.db.get_record(msg_id)
573 573 for key,evalue in existing.iteritems():
574 574 rvalue = record.get(key, None)
575 575 if evalue and rvalue and evalue != rvalue:
576 576 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
577 577 elif evalue and not rvalue:
578 578 record[key] = evalue
579 579 try:
580 580 self.db.update_record(msg_id, record)
581 581 except Exception:
582 582 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
583 583 except KeyError:
584 584 try:
585 585 self.db.add_record(msg_id, record)
586 586 except Exception:
587 587 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
588 588
589 589
590 590 self.pending.add(msg_id)
591 591 self.queues[eid].append(msg_id)
592 592
593 593 def save_queue_result(self, idents, msg):
594 594 if len(idents) < 2:
595 595 self.log.error("invalid identity prefix: %r", idents)
596 596 return
597 597
598 598 client_id, queue_id = idents[:2]
599 599 try:
600 600 msg = self.session.unserialize(msg)
601 601 except Exception:
602 602 self.log.error("queue::engine %r sent invalid message to %r: %r",
603 603 queue_id, client_id, msg, exc_info=True)
604 604 return
605 605
606 606 eid = self.by_ident.get(queue_id, None)
607 607 if eid is None:
608 608 self.log.error("queue::unknown engine %r is sending a reply: ", queue_id)
609 609 return
610 610
611 611 parent = msg['parent_header']
612 612 if not parent:
613 613 return
614 614 msg_id = parent['msg_id']
615 615 if msg_id in self.pending:
616 616 self.pending.remove(msg_id)
617 617 self.all_completed.add(msg_id)
618 618 self.queues[eid].remove(msg_id)
619 619 self.completed[eid].append(msg_id)
620 620 self.log.info("queue::request %r completed on %s", msg_id, eid)
621 621 elif msg_id not in self.all_completed:
622 622 # it could be a result from a dead engine that died before delivering the
623 623 # result
624 624 self.log.warn("queue:: unknown msg finished %r", msg_id)
625 625 return
626 626 # update record anyway, because the unregistration could have been premature
627 627 rheader = msg['header']
628 628 completed = rheader['date']
629 629 started = rheader.get('started', None)
630 630 result = {
631 631 'result_header' : rheader,
632 632 'result_content': msg['content'],
633 633 'started' : started,
634 634 'completed' : completed
635 635 }
636 636
637 637 result['result_buffers'] = msg['buffers']
638 638 try:
639 639 self.db.update_record(msg_id, result)
640 640 except Exception:
641 641 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
642 642
643 643
644 644 #--------------------- Task Queue Traffic ------------------------------
645 645
646 646 def save_task_request(self, idents, msg):
647 647 """Save the submission of a task."""
648 648 client_id = idents[0]
649 649
650 650 try:
651 651 msg = self.session.unserialize(msg)
652 652 except Exception:
653 653 self.log.error("task::client %r sent invalid task message: %r",
654 654 client_id, msg, exc_info=True)
655 655 return
656 656 record = init_record(msg)
657 657
658 658 record['client_uuid'] = client_id.decode('ascii')
659 659 record['queue'] = 'task'
660 660 header = msg['header']
661 661 msg_id = header['msg_id']
662 662 self.pending.add(msg_id)
663 663 self.unassigned.add(msg_id)
664 664 try:
665 665 # it's possible iopub arrived first:
666 666 existing = self.db.get_record(msg_id)
667 667 if existing['resubmitted']:
668 668 for key in ('submitted', 'client_uuid', 'buffers'):
669 669 # don't clobber these keys on resubmit
670 670 # submitted and client_uuid should be different
671 671 # and buffers might be big, and shouldn't have changed
672 672 record.pop(key)
673 673 # still check content and header, which should not change
674 674 # but are not as expensive to compare as buffers
675 675
676 676 for key,evalue in existing.iteritems():
677 677 if key.endswith('buffers'):
678 678 # don't compare buffers
679 679 continue
680 680 rvalue = record.get(key, None)
681 681 if evalue and rvalue and evalue != rvalue:
682 682 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
683 683 elif evalue and not rvalue:
684 684 record[key] = evalue
685 685 try:
686 686 self.db.update_record(msg_id, record)
687 687 except Exception:
688 688 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
689 689 except KeyError:
690 690 try:
691 691 self.db.add_record(msg_id, record)
692 692 except Exception:
693 693 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
694 694 except Exception:
695 695 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
696 696
697 697 def save_task_result(self, idents, msg):
698 698 """save the result of a completed task."""
699 699 client_id = idents[0]
700 700 try:
701 701 msg = self.session.unserialize(msg)
702 702 except Exception:
703 703 self.log.error("task::invalid task result message send to %r: %r",
704 704 client_id, msg, exc_info=True)
705 705 return
706 706
707 707 parent = msg['parent_header']
708 708 if not parent:
709 709 # print msg
710 710 self.log.warn("Task %r had no parent!", msg)
711 711 return
712 712 msg_id = parent['msg_id']
713 713 if msg_id in self.unassigned:
714 714 self.unassigned.remove(msg_id)
715 715
716 716 header = msg['header']
717 717 engine_uuid = header.get('engine', None)
718 718 eid = self.by_ident.get(engine_uuid, None)
719 719
720 720 if msg_id in self.pending:
721 721 self.log.info("task::task %r finished on %s", msg_id, eid)
722 722 self.pending.remove(msg_id)
723 723 self.all_completed.add(msg_id)
724 724 if eid is not None:
725 725 self.completed[eid].append(msg_id)
726 726 if msg_id in self.tasks[eid]:
727 727 self.tasks[eid].remove(msg_id)
728 728 completed = header['date']
729 729 started = header.get('started', None)
730 730 result = {
731 731 'result_header' : header,
732 732 'result_content': msg['content'],
733 733 'started' : started,
734 734 'completed' : completed,
735 735 'engine_uuid': engine_uuid
736 736 }
737 737
738 738 result['result_buffers'] = msg['buffers']
739 739 try:
740 740 self.db.update_record(msg_id, result)
741 741 except Exception:
742 742 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
743 743
744 744 else:
745 745 self.log.debug("task::unknown task %r finished", msg_id)
746 746
747 747 def save_task_destination(self, idents, msg):
748 748 try:
749 749 msg = self.session.unserialize(msg, content=True)
750 750 except Exception:
751 751 self.log.error("task::invalid task tracking message", exc_info=True)
752 752 return
753 753 content = msg['content']
754 754 # print (content)
755 755 msg_id = content['msg_id']
756 756 engine_uuid = content['engine_id']
757 757 eid = self.by_ident[util.asbytes(engine_uuid)]
758 758
759 759 self.log.info("task::task %r arrived on %r", msg_id, eid)
760 760 if msg_id in self.unassigned:
761 761 self.unassigned.remove(msg_id)
762 762 # else:
763 763 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
764 764
765 765 self.tasks[eid].append(msg_id)
766 766 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
767 767 try:
768 768 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
769 769 except Exception:
770 770 self.log.error("DB Error saving task destination %r", msg_id, exc_info=True)
771 771
772 772
773 773 def mia_task_request(self, idents, msg):
774 774 raise NotImplementedError
775 775 client_id = idents[0]
776 776 # content = dict(mia=self.mia,status='ok')
777 777 # self.session.send('mia_reply', content=content, idents=client_id)
778 778
779 779
780 780 #--------------------- IOPub Traffic ------------------------------
781 781
782 782 def save_iopub_message(self, topics, msg):
783 783 """save an iopub message into the db"""
784 784 # print (topics)
785 785 try:
786 786 msg = self.session.unserialize(msg, content=True)
787 787 except Exception:
788 788 self.log.error("iopub::invalid IOPub message", exc_info=True)
789 789 return
790 790
791 791 parent = msg['parent_header']
792 792 if not parent:
793 793 self.log.error("iopub::invalid IOPub message: %r", msg)
794 794 return
795 795 msg_id = parent['msg_id']
796 796 msg_type = msg['header']['msg_type']
797 797 content = msg['content']
798 798
799 799 # ensure msg_id is in db
800 800 try:
801 801 rec = self.db.get_record(msg_id)
802 802 except KeyError:
803 803 rec = empty_record()
804 804 rec['msg_id'] = msg_id
805 805 self.db.add_record(msg_id, rec)
806 806 # stream
807 807 d = {}
808 808 if msg_type == 'stream':
809 809 name = content['name']
810 810 s = rec[name] or ''
811 811 d[name] = s + content['data']
812 812
813 813 elif msg_type == 'pyerr':
814 814 d['pyerr'] = content
815 815 elif msg_type == 'pyin':
816 816 d['pyin'] = content['code']
817 817 else:
818 818 d[msg_type] = content.get('data', '')
819 819
820 820 try:
821 821 self.db.update_record(msg_id, d)
822 822 except Exception:
823 823 self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)
824 824
825 825
826 826
827 827 #-------------------------------------------------------------------------
828 828 # Registration requests
829 829 #-------------------------------------------------------------------------
830 830
831 831 def connection_request(self, client_id, msg):
832 832 """Reply with connection addresses for clients."""
833 833 self.log.info("client::client %r connected", client_id)
834 834 content = dict(status='ok')
835 835 content.update(self.client_info)
836 836 jsonable = {}
837 837 for k,v in self.keytable.iteritems():
838 838 if v not in self.dead_engines:
839 839 jsonable[str(k)] = v.decode('ascii')
840 840 content['engines'] = jsonable
841 841 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
842 842
843 843 def register_engine(self, reg, msg):
844 844 """Register a new engine."""
845 845 content = msg['content']
846 846 try:
847 847 queue = util.asbytes(content['queue'])
848 848 except KeyError:
849 849 self.log.error("registration::queue not specified", exc_info=True)
850 850 return
851 851 heart = content.get('heartbeat', None)
852 852 if heart:
853 853 heart = util.asbytes(heart)
854 854 """register a new engine, and create the socket(s) necessary"""
855 855 eid = self._next_id
856 856 # print (eid, queue, reg, heart)
857 857
858 858 self.log.debug("registration::register_engine(%i, %r, %r, %r)", eid, queue, reg, heart)
859 859
860 860 content = dict(id=eid,status='ok')
861 861 content.update(self.engine_info)
862 862 # check if requesting available IDs:
863 863 if queue in self.by_ident:
864 864 try:
865 865 raise KeyError("queue_id %r in use" % queue)
866 866 except:
867 867 content = error.wrap_exception()
868 868 self.log.error("queue_id %r in use", queue, exc_info=True)
869 869 elif heart in self.hearts: # need to check unique hearts?
870 870 try:
871 871 raise KeyError("heart_id %r in use" % heart)
872 872 except:
873 873 self.log.error("heart_id %r in use", heart, exc_info=True)
874 874 content = error.wrap_exception()
875 875 else:
876 876 for h, pack in self.incoming_registrations.iteritems():
877 877 if heart == h:
878 878 try:
879 879 raise KeyError("heart_id %r in use" % heart)
880 880 except:
881 881 self.log.error("heart_id %r in use", heart, exc_info=True)
882 882 content = error.wrap_exception()
883 883 break
884 884 elif queue == pack[1]:
885 885 try:
886 886 raise KeyError("queue_id %r in use" % queue)
887 887 except:
888 888 self.log.error("queue_id %r in use", queue, exc_info=True)
889 889 content = error.wrap_exception()
890 890 break
891 891
892 892 msg = self.session.send(self.query, "registration_reply",
893 893 content=content,
894 894 ident=reg)
895 895
896 896 if content['status'] == 'ok':
897 897 if heart in self.heartmonitor.hearts:
898 898 # already beating
899 899 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
900 900 self.finish_registration(heart)
901 901 else:
902 902 purge = lambda : self._purge_stalled_registration(heart)
903 903 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
904 904 dc.start()
905 905 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
906 906 else:
907 907 self.log.error("registration::registration %i failed: %r", eid, content['evalue'])
908 908 return eid
909 909
910 910 def unregister_engine(self, ident, msg):
911 911 """Unregister an engine that explicitly requested to leave."""
912 912 try:
913 913 eid = msg['content']['id']
914 914 except:
915 915 self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
916 916 return
917 917 self.log.info("registration::unregister_engine(%r)", eid)
918 918 # print (eid)
919 919 uuid = self.keytable[eid]
920 920 content=dict(id=eid, queue=uuid.decode('ascii'))
921 921 self.dead_engines.add(uuid)
922 922 # self.ids.remove(eid)
923 923 # uuid = self.keytable.pop(eid)
924 924 #
925 925 # ec = self.engines.pop(eid)
926 926 # self.hearts.pop(ec.heartbeat)
927 927 # self.by_ident.pop(ec.queue)
928 928 # self.completed.pop(eid)
929 929 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
930 930 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
931 931 dc.start()
932 932 ############## TODO: HANDLE IT ################
933 933
934 934 if self.notifier:
935 935 self.session.send(self.notifier, "unregistration_notification", content=content)
936 936
937 937 def _handle_stranded_msgs(self, eid, uuid):
938 938 """Handle messages known to be on an engine when the engine unregisters.
939 939
940 940 It is possible that this will fire prematurely - that is, an engine will
941 941 go down after completing a result, and the client will be notified
942 942 that the result failed and later receive the actual result.
943 943 """
944 944
945 945 outstanding = self.queues[eid]
946 946
947 947 for msg_id in outstanding:
948 948 self.pending.remove(msg_id)
949 949 self.all_completed.add(msg_id)
950 950 try:
951 951 raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
952 952 except:
953 953 content = error.wrap_exception()
954 954 # build a fake header:
955 955 header = {}
956 956 header['engine'] = uuid
957 957 header['date'] = datetime.now()
958 958 rec = dict(result_content=content, result_header=header, result_buffers=[])
959 959 rec['completed'] = header['date']
960 960 rec['engine_uuid'] = uuid
961 961 try:
962 962 self.db.update_record(msg_id, rec)
963 963 except Exception:
964 964 self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)
965 965
966 966
967 967 def finish_registration(self, heart):
968 968 """Second half of engine registration, called after our HeartMonitor
969 969 has received a beat from the Engine's Heart."""
970 970 try:
971 971 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
972 972 except KeyError:
973 973 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
974 974 return
975 975 self.log.info("registration::finished registering engine %i:%r", eid, queue)
976 976 if purge is not None:
977 977 purge.stop()
978 978 control = queue
979 979 self.ids.add(eid)
980 980 self.keytable[eid] = queue
981 981 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
982 982 control=control, heartbeat=heart)
983 983 self.by_ident[queue] = eid
984 984 self.queues[eid] = list()
985 985 self.tasks[eid] = list()
986 986 self.completed[eid] = list()
987 987 self.hearts[heart] = eid
988 988 content = dict(id=eid, queue=self.engines[eid].queue.decode('ascii'))
989 989 if self.notifier:
990 990 self.session.send(self.notifier, "registration_notification", content=content)
991 991 self.log.info("engine::Engine Connected: %i", eid)
992 992
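A hedged summary of the two-phase registration implemented above (message names are from this file; the timing is approximate):

# 1. engine -> hub : 'registration_request' {queue, heartbeat}
# 2. hub           : allocates eid, replies 'registration_reply' {id, engine_info},
#                    parks (eid, queue, reg, timer) in incoming_registrations
# 3. engine heart  : starts answering HeartMonitor pings
# 4. HeartMonitor  : fires handle_new_heart(heart)
# 5. hub           : finish_registration(heart) -- engine becomes usable and a
#                    'registration_notification' is published to clients
# If step 3 never happens, _purge_stalled_registration drops the pending entry
# after registration_timeout = max(5000, 2 * heartbeat period) ms.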
993 993 def _purge_stalled_registration(self, heart):
994 994 if heart in self.incoming_registrations:
995 995 eid = self.incoming_registrations.pop(heart)[0]
996 996 self.log.info("registration::purging stalled registration: %i", eid)
997 997 else:
998 998 pass
999 999
1000 1000 #-------------------------------------------------------------------------
1001 1001 # Client Requests
1002 1002 #-------------------------------------------------------------------------
1003 1003
1004 1004 def shutdown_request(self, client_id, msg):
1005 1005 """handle shutdown request."""
1006 1006 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1007 1007 # also notify other clients of shutdown
1008 1008 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1009 1009 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1010 1010 dc.start()
1011 1011
1012 1012 def _shutdown(self):
1013 1013 self.log.info("hub::hub shutting down.")
1014 1014 time.sleep(0.1)
1015 1015 sys.exit(0)
1016 1016
1017 1017
1018 1018 def check_load(self, client_id, msg):
1019 1019 content = msg['content']
1020 1020 try:
1021 1021 targets = content['targets']
1022 1022 targets = self._validate_targets(targets)
1023 1023 except:
1024 1024 content = error.wrap_exception()
1025 1025 self.session.send(self.query, "hub_error",
1026 1026 content=content, ident=client_id)
1027 1027 return
1028 1028
1029 1029 content = dict(status='ok')
1030 1030 # loads = {}
1031 1031 for t in targets:
1032 1032 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1033 1033 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1034 1034
1035 1035
1036 1036 def queue_status(self, client_id, msg):
1037 1037 """Return the Queue status of one or more targets.
1038 1038 if verbose: return the msg_ids
1039 1039 else: return len of each type.
1040 1040 keys: queue (pending MUX jobs)
1041 1041 tasks (pending Task jobs)
1042 1042 completed (finished jobs from both queues)"""
1043 1043 content = msg['content']
1044 1044 targets = content['targets']
1045 1045 try:
1046 1046 targets = self._validate_targets(targets)
1047 1047 except:
1048 1048 content = error.wrap_exception()
1049 1049 self.session.send(self.query, "hub_error",
1050 1050 content=content, ident=client_id)
1051 1051 return
1052 1052 verbose = content.get('verbose', False)
1053 1053 content = dict(status='ok')
1054 1054 for t in targets:
1055 1055 queue = self.queues[t]
1056 1056 completed = self.completed[t]
1057 1057 tasks = self.tasks[t]
1058 1058 if not verbose:
1059 1059 queue = len(queue)
1060 1060 completed = len(completed)
1061 1061 tasks = len(tasks)
1062 1062 content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1063 1063 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1064 1064 # print (content)
1065 1065 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1066 1066
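For illustration, a `queue_reply` content dict for targets [0, 1] with verbose=False looks roughly like this (with verbose=True the counts become msg_id lists):

content = {
    'status': 'ok',
    '0': {'queue': 2, 'completed': 10, 'tasks': 1},
    '1': {'queue': 0, 'completed': 12, 'tasks': 3},
    'unassigned': 4,
}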
1067 1067 def purge_results(self, client_id, msg):
1068 1068 """Purge results from memory. This method is more valuable before we move
1069 1069 to a DB based message storage mechanism."""
1070 1070 content = msg['content']
1071 1071 self.log.info("Dropping records with %s", content)
1072 1072 msg_ids = content.get('msg_ids', [])
1073 1073 reply = dict(status='ok')
1074 1074 if msg_ids == 'all':
1075 1075 try:
1076 1076 self.db.drop_matching_records(dict(completed={'$ne':None}))
1077 1077 except Exception:
1078 1078 reply = error.wrap_exception()
1079 1079 else:
1080 1080 pending = filter(lambda m: m in self.pending, msg_ids)
1081 1081 if pending:
1082 1082 try:
1083 1083 raise IndexError("msg pending: %r" % pending[0])
1084 1084 except:
1085 1085 reply = error.wrap_exception()
1086 1086 else:
1087 1087 try:
1088 1088 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1089 1089 except Exception:
1090 1090 reply = error.wrap_exception()
1091 1091
1092 1092 if reply['status'] == 'ok':
1093 1093 eids = content.get('engine_ids', [])
1094 1094 for eid in eids:
1095 1095 if eid not in self.engines:
1096 1096 try:
1097 1097 raise IndexError("No such engine: %i" % eid)
1098 1098 except:
1099 1099 reply = error.wrap_exception()
1100 1100 break
1101 1101 uid = self.engines[eid].queue
1102 1102 try:
1103 1103 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1104 1104 except Exception:
1105 1105 reply = error.wrap_exception()
1106 1106 break
1107 1107
1108 1108 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1109 1109
1110 1110 def resubmit_task(self, client_id, msg):
1111 1111 """Resubmit one or more tasks."""
1112 1112 def finish(reply):
1113 1113 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1114 1114
1115 1115 content = msg['content']
1116 1116 msg_ids = content['msg_ids']
1117 1117 reply = dict(status='ok')
1118 1118 try:
1119 1119 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1120 1120 'header', 'content', 'buffers'])
1121 1121 except Exception:
1122 1122 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1123 1123 return finish(error.wrap_exception())
1124 1124
1125 1125 # validate msg_ids
1126 1126 found_ids = [ rec['msg_id'] for rec in records ]
1127 1127 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1128 1128 if len(records) > len(msg_ids):
1129 1129 try:
1130 1130 raise RuntimeError("DB appears to be in an inconsistent state."
1131 1131 "More matching records were found than should exist")
1132 1132 except Exception:
1133 1133 return finish(error.wrap_exception())
1134 1134 elif len(records) < len(msg_ids):
1135 1135 missing = [ m for m in msg_ids if m not in found_ids ]
1136 1136 try:
1137 1137 raise KeyError("No such msg(s): %r" % missing)
1138 1138 except KeyError:
1139 1139 return finish(error.wrap_exception())
1140 1140 elif invalid_ids:
1141 1141 msg_id = invalid_ids[0]
1142 1142 try:
1143 1143 raise ValueError("Task %r appears to be inflight" % msg_id)
1144 1144 except Exception:
1145 1145 return finish(error.wrap_exception())
1146 1146
1147 1147 # clear the existing records
1148 1148 now = datetime.now()
1149 1149 rec = empty_record()
1150 1150 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1151 1151 rec['resubmitted'] = now
1152 1152 rec['queue'] = 'task'
1153 1153 rec['client_uuid'] = client_id[0]
1154 1154 try:
1155 1155 for msg_id in msg_ids:
1156 1156 self.all_completed.discard(msg_id)
1157 1157 self.db.update_record(msg_id, rec)
1158 1158 except Exception:
1159 1159 self.log.error('db::db error updating record', exc_info=True)
1160 1160 reply = error.wrap_exception()
1161 1161 else:
1162 1162 # send the messages
1163 1163 for rec in records:
1164 1164 header = rec['header']
1165 1165 # include resubmitted in header to prevent digest collision
1166 1166 header['resubmitted'] = now
1167 1167 msg = self.session.msg(header['msg_type'])
1168 1168 msg['content'] = rec['content']
1169 1169 msg['header'] = header
1170 1170 msg['header']['msg_id'] = rec['msg_id']
1171 1171 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1172 1172
1173 1173 finish(dict(status='ok'))
1174 1174
1175 1175
1176 1176 def _extract_record(self, rec):
1177 1177 """decompose a TaskRecord dict into subsection of reply for get_result"""
1178 1178 io_dict = {}
1179 1179 for key in 'pyin pyout pyerr stdout stderr'.split():
1180 1180 io_dict[key] = rec[key]
1181 1181 content = { 'result_content': rec['result_content'],
1182 1182 'header': rec['header'],
1183 1183 'result_header' : rec['result_header'],
1184 1184 'io' : io_dict,
1185 1185 }
1186 1186 if rec['result_buffers']:
1187 1187 buffers = map(bytes, rec['result_buffers'])
1188 1188 else:
1189 1189 buffers = []
1190 1190
1191 1191 return content, buffers
1192 1192
1193 1193 def get_results(self, client_id, msg):
1194 1194 """Get the result of 1 or more messages."""
1195 1195 content = msg['content']
1196 1196 msg_ids = sorted(set(content['msg_ids']))
1197 1197 statusonly = content.get('status_only', False)
1198 1198 pending = []
1199 1199 completed = []
1200 1200 content = dict(status='ok')
1201 1201 content['pending'] = pending
1202 1202 content['completed'] = completed
1203 1203 buffers = []
1204 1204 if not statusonly:
1205 1205 try:
1206 1206 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1207 1207 # turn match list into dict, for faster lookup
1208 1208 records = {}
1209 1209 for rec in matches:
1210 1210 records[rec['msg_id']] = rec
1211 1211 except Exception:
1212 1212 content = error.wrap_exception()
1213 1213 self.session.send(self.query, "result_reply", content=content,
1214 1214 parent=msg, ident=client_id)
1215 1215 return
1216 1216 else:
1217 1217 records = {}
1218 1218 for msg_id in msg_ids:
1219 1219 if msg_id in self.pending:
1220 1220 pending.append(msg_id)
1221 1221 elif msg_id in self.all_completed:
1222 1222 completed.append(msg_id)
1223 1223 if not statusonly:
1224 1224 c,bufs = self._extract_record(records[msg_id])
1225 1225 content[msg_id] = c
1226 1226 buffers.extend(bufs)
1227 1227 elif msg_id in records:
1228 1228 if records[msg_id]['completed']:
1229 1229 completed.append(msg_id)
1230 1230 c,bufs = self._extract_record(records[msg_id])
1231 1231 content[msg_id] = c
1232 1232 buffers.extend(bufs)
1233 1233 else:
1234 1234 pending.append(msg_id)
1235 1235 else:
1236 1236 try:
1237 1237 raise KeyError('No such message: '+msg_id)
1238 1238 except:
1239 1239 content = error.wrap_exception()
1240 1240 break
1241 1241 self.session.send(self.query, "result_reply", content=content,
1242 1242 parent=msg, ident=client_id,
1243 1243 buffers=buffers)
1244 1244
1245 1245 def get_history(self, client_id, msg):
1246 1246 """Get a list of all msg_ids in our DB records"""
1247 1247 try:
1248 1248 msg_ids = self.db.get_history()
1249 1249 except Exception as e:
1250 1250 content = error.wrap_exception()
1251 1251 else:
1252 1252 content = dict(status='ok', history=msg_ids)
1253 1253
1254 1254 self.session.send(self.query, "history_reply", content=content,
1255 1255 parent=msg, ident=client_id)
1256 1256
1257 1257 def db_query(self, client_id, msg):
1258 1258 """Perform a raw query on the task record database."""
1259 1259 content = msg['content']
1260 1260 query = content.get('query', {})
1261 1261 keys = content.get('keys', None)
1262 1262 buffers = []
1263 1263 empty = list()
1264 1264 try:
1265 1265 records = self.db.find_records(query, keys)
1266 1266 except Exception as e:
1267 1267 content = error.wrap_exception()
1268 1268 else:
1269 1269 # extract buffers from reply content:
1270 1270 if keys is not None:
1271 1271 buffer_lens = [] if 'buffers' in keys else None
1272 1272 result_buffer_lens = [] if 'result_buffers' in keys else None
1273 1273 else:
1274 1274 buffer_lens = []
1275 1275 result_buffer_lens = []
1276 1276
1277 1277 for rec in records:
1278 1278 # buffers may be None, so double check
1279 1279 if buffer_lens is not None:
1280 1280 b = rec.pop('buffers', empty) or empty
1281 1281 buffer_lens.append(len(b))
1282 1282 buffers.extend(b)
1283 1283 if result_buffer_lens is not None:
1284 1284 rb = rec.pop('result_buffers', empty) or empty
1285 1285 result_buffer_lens.append(len(rb))
1286 1286 buffers.extend(rb)
1287 1287 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1288 1288 result_buffer_lens=result_buffer_lens)
1289 1289 # self.log.debug (content)
1290 1290 self.session.send(self.query, "db_reply", content=content,
1291 1291 parent=msg, ident=client_id,
1292 1292 buffers=buffers)
1293 1293
IPython/parallel/controller/scheduler.py
@@ -1,726 +1,732 @@
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6
7 7 Authors:
8 8
9 9 * Min RK
10 10 """
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2010-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #----------------------------------------------------------------------
19 19 # Imports
20 20 #----------------------------------------------------------------------
21 21
22 22 from __future__ import print_function
23 23
24 24 import logging
25 25 import sys
26 26
27 27 from datetime import datetime, timedelta
28 28 from random import randint, random
29 29 from types import FunctionType
30 30
31 31 try:
32 32 import numpy
33 33 except ImportError:
34 34 numpy = None
35 35
36 36 import zmq
37 37 from zmq.eventloop import ioloop, zmqstream
38 38
39 39 # local imports
40 40 from IPython.external.decorator import decorator
41 41 from IPython.config.application import Application
42 42 from IPython.config.loader import Config
43 43 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
44 44
45 45 from IPython.parallel import error
46 46 from IPython.parallel.factory import SessionFactory
47 47 from IPython.parallel.util import connect_logger, local_logger, asbytes
48 48
49 49 from .dependency import Dependency
50 50
51 51 @decorator
52 52 def logged(f,self,*args,**kwargs):
53 53 # print ("#--------------------")
54 54 self.log.debug("scheduler::%s(*%s,**%s)", f.func_name, args, kwargs)
55 55 # print ("#--")
56 56 return f(self,*args, **kwargs)
57 57
58 58 #----------------------------------------------------------------------
59 59 # Chooser functions
60 60 #----------------------------------------------------------------------
61 61
62 62 def plainrandom(loads):
63 63 """Plain random pick."""
64 64 n = len(loads)
65 65 return randint(0,n-1)
66 66
67 67 def lru(loads):
68 68 """Always pick the front of the line.
69 69
70 70 The content of `loads` is ignored.
71 71
72 72 Assumes LRU ordering of loads, with oldest first.
73 73 """
74 74 return 0
75 75
76 76 def twobin(loads):
77 77 """Pick two at random, use the LRU of the two.
78 78
79 79 The content of loads is ignored.
80 80
81 81 Assumes LRU ordering of loads, with oldest first.
82 82 """
83 83 n = len(loads)
84 84 a = randint(0,n-1)
85 85 b = randint(0,n-1)
86 86 return min(a,b)
87 87
88 88 def weighted(loads):
89 89 """Pick two at random using inverse load as weight.
90 90
91 91 Return the less loaded of the two.
92 92 """
93 93 # weight 0 a million times more than 1:
94 94 weights = 1./(1e-6+numpy.array(loads))
95 95 sums = weights.cumsum()
96 96 t = sums[-1]
97 97 x = random()*t
98 98 y = random()*t
99 99 idx = 0
100 100 idy = 0
101 101 while sums[idx] < x:
102 102 idx += 1
103 103 while sums[idy] < y:
104 104 idy += 1
105 105 if weights[idy] > weights[idx]:
106 106 return idy
107 107 else:
108 108 return idx
109 109
110 110 def leastload(loads):
111 111 """Always choose the lowest load.
112 112
113 113 If the lowest load occurs more than once, the first
114 114 occurrence will be used. If loads has LRU ordering, this means
115 115 the LRU of those with the lowest load is chosen.
116 116 """
117 117 return loads.index(min(loads))
118 118
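A worked example of the chooser functions above on a sample load list (runs in this module's context; `weighted` requires numpy):

loads = [3, 0, 5, 1]   # outstanding tasks per engine, LRU-ordered (oldest first)

print(lru(loads))        # 0 -- always the head of the line
print(leastload(loads))  # 1 -- index of the minimum load
print(twobin(loads))     # min of two random indices, biased toward the front
print(weighted(loads))   # weights 1/(1e-6+load): the idle engine (load 0) wins
                         # on the order of a million times more often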
119 119 #---------------------------------------------------------------------
120 120 # Classes
121 121 #---------------------------------------------------------------------
122 122 # store empty default dependency:
123 123 MET = Dependency([])
124 124
125 125 class TaskScheduler(SessionFactory):
126 126 """Python TaskScheduler object.
127 127
128 128 This is the simplest object that supports msg_id based
129 129 DAG dependencies. *Only* task msg_ids are checked, not
130 130 msg_ids of jobs submitted via the MUX queue.
131 131
132 132 """
133 133
134 134 hwm = Integer(1, config=True,
135 135 help="""specify the High Water Mark (HWM) for the downstream
136 136 socket in the Task scheduler. This is the maximum number
137 137 of allowed outstanding tasks on each engine.
138 138
139 139 The default (1) means that only one task can be outstanding on each
140 140 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
141 141 engines continue to be assigned tasks while they are working,
142 142 effectively hiding network latency behind computation, but can result
143 143 in an imbalance of work when submitting many heterogeneous tasks all at
144 144 once. Any positive value greater than one is a compromise between the
145 145 two.
146 146
147 147 """
148 148 )
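Since `hwm` is `config=True`, it is normally set from an ipcontroller profile config file or the command line; a hedged sketch (assuming the standard IPython config-file preamble):

# in ipcontroller_config.py:
c = get_config()
c.TaskScheduler.hwm = 0    # no limit: hides latency, but can imbalance work
# c.TaskScheduler.hwm = 1  # default: at most one outstanding task per engine
# equivalent command line: ipcontroller --TaskScheduler.hwm=0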
149 149 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
150 150 'leastload', config=True, allow_none=False,
151 151 help="""select the task scheduler scheme [default: Python LRU]
152 152 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
153 153 )
154 154 def _scheme_name_changed(self, old, new):
155 155 self.log.debug("Using scheme %r"%new)
156 156 self.scheme = globals()[new]
157 157
158 158 # input arguments:
159 159 scheme = Instance(FunctionType) # function for determining the destination
160 160 def _scheme_default(self):
161 161 return leastload
162 162 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
163 163 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
164 164 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
165 165 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
166 166
167 167 # internals:
168 168 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
169 169 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
170 170 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
171 171 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
172 172 pending = Dict() # dict by engine_uuid of submitted tasks
173 173 completed = Dict() # dict by engine_uuid of completed tasks
174 174 failed = Dict() # dict by engine_uuid of failed tasks
175 175 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
176 176 clients = Dict() # dict by msg_id for who submitted the task
177 177 targets = List() # list of target IDENTs
178 178 loads = List() # list of engine loads
179 179 # full = Set() # set of IDENTs that have HWM outstanding tasks
180 180 all_completed = Set() # set of all completed tasks
181 181 all_failed = Set() # set of all failed tasks
182 182 all_done = Set() # set of all finished tasks=union(completed,failed)
183 183 all_ids = Set() # set of all submitted task IDs
184 184 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
185 185 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
186 186
187 187 ident = CBytes() # ZMQ identity. This should just be self.session.session
188 188 # but ensure Bytes
189 189 def _ident_default(self):
190 190 return self.session.bsession
191 191
192 192 def start(self):
193 193 self.engine_stream.on_recv(self.dispatch_result, copy=False)
194 self.client_stream.on_recv(self.dispatch_submission, copy=False)
195
194 196 self._notification_handlers = dict(
195 197 registration_notification = self._register_engine,
196 198 unregistration_notification = self._unregister_engine
197 199 )
198 200 self.notifier_stream.on_recv(self.dispatch_notification)
199 201 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # every 2s
200 202 self.auditor.start()
201 203 self.log.info("Scheduler started [%s]"%self.scheme_name)
202 204
203 205 def resume_receiving(self):
204 206 """Resume accepting jobs."""
205 207 self.client_stream.on_recv(self.dispatch_submission, copy=False)
206 208
207 209 def stop_receiving(self):
208 210 """Stop accepting jobs while there are no engines.
209 211 Incoming jobs are left in the ZMQ queue."""
210 212 self.client_stream.on_recv(None)
211 213
212 214 #-----------------------------------------------------------------------
213 215 # [Un]Registration Handling
214 216 #-----------------------------------------------------------------------
215 217
216 218 def dispatch_notification(self, msg):
217 219 """dispatch register/unregister events."""
218 220 try:
219 221 idents,msg = self.session.feed_identities(msg)
220 222 except ValueError:
221 223 self.log.warn("task::Invalid Message: %r",msg)
222 224 return
223 225 try:
224 226 msg = self.session.unserialize(msg)
225 227 except ValueError:
226 228 self.log.warn("task::Unauthorized message from: %r"%idents)
227 229 return
228 230
229 231 msg_type = msg['header']['msg_type']
230 232
231 233 handler = self._notification_handlers.get(msg_type, None)
232 234 if handler is None:
233 235 self.log.error("Unhandled message type: %r"%msg_type)
234 236 else:
235 237 try:
236 238 handler(asbytes(msg['content']['queue']))
237 239 except Exception:
238 240 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
239 241
240 242 def _register_engine(self, uid):
241 243 """New engine with ident `uid` became available."""
242 244 # head of the line:
243 245 self.targets.insert(0,uid)
244 246 self.loads.insert(0,0)
245 247
246 248 # initialize sets
247 249 self.completed[uid] = set()
248 250 self.failed[uid] = set()
249 251 self.pending[uid] = {}
250 if len(self.targets) == 1:
251 self.resume_receiving()
252
252 253 # rescan the graph:
253 254 self.update_graph(None)
254 255
255 256 def _unregister_engine(self, uid):
256 257 """Existing engine with ident `uid` became unavailable."""
257 258 if len(self.targets) == 1:
258 259 # this was our only engine
259 self.stop_receiving()
260 pass
260 261
261 262 # handle any potentially finished tasks:
262 263 self.engine_stream.flush()
263 264
264 265 # don't pop destinations, because they might be used later
265 266 # map(self.destinations.pop, self.completed.pop(uid))
266 267 # map(self.destinations.pop, self.failed.pop(uid))
267 268
268 269 # prevent this engine from receiving work
269 270 idx = self.targets.index(uid)
270 271 self.targets.pop(idx)
271 272 self.loads.pop(idx)
272 273
273 274 # wait 5 seconds before cleaning up pending jobs, since the results might
274 275 # still be incoming
275 276 if self.pending[uid]:
276 277 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
277 278 dc.start()
278 279 else:
279 280 self.completed.pop(uid)
280 281 self.failed.pop(uid)
281 282
282 283
283 284 def handle_stranded_tasks(self, engine):
284 285 """Deal with jobs resident in an engine that died."""
285 286 lost = self.pending[engine]
286 287 for msg_id in lost.keys():
287 288 if msg_id not in self.pending[engine]:
288 289 # prevent double-handling of messages
289 290 continue
290 291
291 292 raw_msg = lost[msg_id][0]
292 293 idents,msg = self.session.feed_identities(raw_msg, copy=False)
293 294 parent = self.session.unpack(msg[1].bytes)
294 295 idents = [engine, idents[0]]
295 296
296 297 # build fake error reply
297 298 try:
298 299 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
299 300 except:
300 301 content = error.wrap_exception()
301 302 msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
302 303 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
303 304 # and dispatch it
304 305 self.dispatch_result(raw_reply)
305 306
306 307 # finally scrub completed/failed lists
307 308 self.completed.pop(engine)
308 309 self.failed.pop(engine)
309 310
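The raise-and-wrap idiom above (also used in fail_unreachable below) exists because error.wrap_exception() serializes the currently *active* exception, so raising first gives it a real traceback to capture. A minimal sketch, assuming the usual error-dict it returns:

    from IPython.parallel import error
    try:
        raise error.EngineError("engine died")  # illustrative message
    except:
        content = error.wrap_exception()  # dict: status/ename/evalue/traceback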
310 311
311 312 #-----------------------------------------------------------------------
312 313 # Job Submission
313 314 #-----------------------------------------------------------------------
314 315 def dispatch_submission(self, raw_msg):
315 316 """Dispatch job submission to appropriate handlers."""
316 317 # ensure targets up to date:
317 318 self.notifier_stream.flush()
318 319 try:
319 320 idents, msg = self.session.feed_identities(raw_msg, copy=False)
320 321 msg = self.session.unserialize(msg, content=False, copy=False)
321 322 except Exception:
322 323 self.log.error("task::Invalid task msg: %r"%raw_msg, exc_info=True)
323 324 return
324 325
325 326
326 327 # send to monitor
327 328 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
328 329
329 330 header = msg['header']
330 331 msg_id = header['msg_id']
331 332 self.all_ids.add(msg_id)
332 333
333 334 # get targets as a set of bytes objects
334 335 # from a list of unicode objects
335 336 targets = header.get('targets', [])
336 337 targets = map(asbytes, targets)
337 338 targets = set(targets)
338 339
339 340 retries = header.get('retries', 0)
340 341 self.retries[msg_id] = retries
341 342
342 343 # time dependencies
343 344 after = header.get('after', None)
344 345 if after:
345 346 after = Dependency(after)
346 347 if after.all:
347 348 if after.success:
348 349 after = Dependency(after.difference(self.all_completed),
349 350 success=after.success,
350 351 failure=after.failure,
351 352 all=after.all,
352 353 )
353 354 if after.failure:
354 355 after = Dependency(after.difference(self.all_failed),
355 356 success=after.success,
356 357 failure=after.failure,
357 358 all=after.all,
358 359 )
359 360 if after.check(self.all_completed, self.all_failed):
360 361 # recast as empty set, if `after` already met,
361 362 # to prevent unnecessary set comparisons
362 363 after = MET
363 364 else:
364 365 after = MET
365 366
366 367 # location dependencies
367 368 follow = Dependency(header.get('follow', []))
368 369
369 370 # turn timeouts into datetime objects:
370 371 timeout = header.get('timeout', None)
371 372 if timeout:
372 373 # cast to float, because jsonlib returns floats as decimal.Decimal,
373 374 # which timedelta does not accept
374 375 timeout = datetime.now() + timedelta(0,float(timeout),0)
375 376
376 377 args = [raw_msg, targets, after, follow, timeout]
377 378
378 379 # validate and reduce dependencies:
379 380 for dep in after,follow:
380 381 if not dep: # empty dependency
381 382 continue
382 383 # check valid:
383 384 if msg_id in dep or dep.difference(self.all_ids):
384 385 self.depending[msg_id] = args
385 386 return self.fail_unreachable(msg_id, error.InvalidDependency)
386 387 # check if unreachable:
387 388 if dep.unreachable(self.all_completed, self.all_failed):
388 389 self.depending[msg_id] = args
389 390 return self.fail_unreachable(msg_id)
390 391
391 392 if after.check(self.all_completed, self.all_failed):
392 393 # time deps already met, try to run
393 394 if not self.maybe_run(msg_id, *args):
394 395 # can't run yet
395 396 if msg_id not in self.all_failed:
396 397 # could have failed as unreachable
397 398 self.save_unmet(msg_id, *args)
398 399 else:
399 400 self.save_unmet(msg_id, *args)
400 401
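For context, the after/follow/timeout/retries headers unpacked above originate client-side on a load-balanced view. A hedged sketch using the IPython.parallel client API of this era (the functions and values are illustrative):

    from IPython.parallel import Client
    rc = Client()
    lview = rc.load_balanced_view()
    ar1 = lview.apply(lambda: 'first')
    # run only after ar1 succeeds, give up after 30s, retry once:
    lview.set_flags(after=[ar1], timeout=30, retries=1)
    ar2 = lview.apply(lambda: 'second')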
401 402 def audit_timeouts(self):
402 403 """Audit all waiting tasks for expired timeouts."""
403 404 now = datetime.now()
404 405 for msg_id in self.depending.keys():
405 406 # must recheck, in case one failure cascaded to another:
406 407 if msg_id in self.depending:
407 408 raw_msg,targets,after,follow,timeout = self.depending[msg_id]
408 409 if timeout and timeout < now:
409 410 self.fail_unreachable(msg_id, error.TaskTimeout)
410 411
411 412 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
412 413 """a task has become unreachable, send a reply with an ImpossibleDependency
413 414 error."""
414 415 if msg_id not in self.depending:
415 416 self.log.error("msg %r already failed!", msg_id)
416 417 return
417 418 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
418 419 for mid in follow.union(after):
419 420 if mid in self.graph:
420 421 self.graph[mid].remove(msg_id)
421 422
422 423 # FIXME: unpacking a message I've already unpacked, but didn't save:
423 424 idents,msg = self.session.feed_identities(raw_msg, copy=False)
424 425 header = self.session.unpack(msg[1].bytes)
425 426
426 427 try:
427 428 raise why()
428 429 except:
429 430 content = error.wrap_exception()
430 431
431 432 self.all_done.add(msg_id)
432 433 self.all_failed.add(msg_id)
433 434
434 435 msg = self.session.send(self.client_stream, 'apply_reply', content,
435 436 parent=header, ident=idents)
436 437 self.session.send(self.mon_stream, msg, ident=[b'outtask']+idents)
437 438
438 439 self.update_graph(msg_id, success=False)
439 440
440 441 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
441 442 """check location dependencies, and run if they are met."""
443 self.log.debug("Attempting to assign task %s", msg_id)
444 if not self.targets:
445 # no engines, definitely can't run
446 return False
447
442 448 blacklist = self.blacklist.setdefault(msg_id, set())
443 449 if follow or targets or blacklist or self.hwm:
444 450 # we need a can_run filter
445 451 def can_run(idx):
446 452 # check hwm
447 453 if self.hwm and self.loads[idx] == self.hwm:
448 454 return False
449 455 target = self.targets[idx]
450 456 # check blacklist
451 457 if target in blacklist:
452 458 return False
453 459 # check targets
454 460 if targets and target not in targets:
455 461 return False
456 462 # check follow
457 463 return follow.check(self.completed[target], self.failed[target])
458 464
459 465 indices = filter(can_run, range(len(self.targets)))
460 466
461 467 if not indices:
462 468 # couldn't run
463 469 if follow.all:
464 470 # check follow for impossibility
465 471 dests = set()
466 472 relevant = set()
467 473 if follow.success:
468 474 relevant = self.all_completed
469 475 if follow.failure:
470 476 relevant = relevant.union(self.all_failed)
471 477 for m in follow.intersection(relevant):
472 478 dests.add(self.destinations[m])
473 479 if len(dests) > 1:
474 480 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
475 481 self.fail_unreachable(msg_id)
476 482 return False
477 483 if targets:
478 484 # check blacklist+targets for impossibility
479 485 targets.difference_update(blacklist)
480 486 if not targets or not targets.intersection(self.targets):
481 487 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
482 488 self.fail_unreachable(msg_id)
483 489 return False
484 490 return False
485 491 else:
486 492 indices = None
487 493
488 494 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
489 495 return True
490 496
491 497 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
492 498 """Save a message for later submission when its dependencies are met."""
493 499 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
494 500 # track the ids in follow or after, but not those already finished
495 501 for dep_id in after.union(follow).difference(self.all_done):
496 502 if dep_id not in self.graph:
497 503 self.graph[dep_id] = set()
498 504 self.graph[dep_id].add(msg_id)
499 505
500 506 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
501 507 """Submit a task to any of a subset of our targets."""
502 508 if indices:
503 509 loads = [self.loads[i] for i in indices]
504 510 else:
505 511 loads = self.loads
506 512 idx = self.scheme(loads)
507 513 if indices:
508 514 idx = indices[idx]
509 515 target = self.targets[idx]
510 516 # print (target, map(str, msg[:3]))
511 517 # send job to the engine
512 518 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
513 519 self.engine_stream.send_multipart(raw_msg, copy=False)
514 520 # update load
515 521 self.add_job(idx)
516 522 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
517 523 # notify Hub
518 524 content = dict(msg_id=msg_id, engine_id=target.decode('ascii'))
519 525 self.session.send(self.mon_stream, 'task_destination', content=content,
520 526 ident=[b'tracktask',self.ident])
521 527
522 528
523 529 #-----------------------------------------------------------------------
524 530 # Result Handling
525 531 #-----------------------------------------------------------------------
526 532 def dispatch_result(self, raw_msg):
527 533 """dispatch method for result replies"""
528 534 try:
529 535 idents,msg = self.session.feed_identities(raw_msg, copy=False)
530 536 msg = self.session.unserialize(msg, content=False, copy=False)
531 537 engine = idents[0]
532 538 try:
533 539 idx = self.targets.index(engine)
534 540 except ValueError:
535 541 pass # skip load-update for dead engines
536 542 else:
537 543 self.finish_job(idx)
538 544 except Exception:
539 545 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
540 546 return
541 547
542 548 header = msg['header']
543 549 parent = msg['parent_header']
544 550 if header.get('dependencies_met', True):
545 551 success = (header['status'] == 'ok')
546 552 msg_id = parent['msg_id']
547 553 retries = self.retries[msg_id]
548 554 if not success and retries > 0:
549 555 # failed
550 556 self.retries[msg_id] = retries - 1
551 557 self.handle_unmet_dependency(idents, parent)
552 558 else:
553 559 del self.retries[msg_id]
554 560 # relay to client and update graph
555 561 self.handle_result(idents, parent, raw_msg, success)
556 562 # send to Hub monitor
557 563 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
558 564 else:
559 565 self.handle_unmet_dependency(idents, parent)
560 566
561 567 def handle_result(self, idents, parent, raw_msg, success=True):
562 568 """handle a real task result, either success or failure"""
563 569 # first, relay result to client
564 570 engine = idents[0]
565 571 client = idents[1]
566 572 # swap_ids for XREP-XREP mirror
567 573 raw_msg[:2] = [client,engine]
568 574 # print (map(str, raw_msg[:4]))
569 575 self.client_stream.send_multipart(raw_msg, copy=False)
570 576 # now, update our data structures
571 577 msg_id = parent['msg_id']
572 578 self.blacklist.pop(msg_id, None)
573 579 self.pending[engine].pop(msg_id)
574 580 if success:
575 581 self.completed[engine].add(msg_id)
576 582 self.all_completed.add(msg_id)
577 583 else:
578 584 self.failed[engine].add(msg_id)
579 585 self.all_failed.add(msg_id)
580 586 self.all_done.add(msg_id)
581 587 self.destinations[msg_id] = engine
582 588
583 589 self.update_graph(msg_id, success)
584 590
585 591 def handle_unmet_dependency(self, idents, parent):
586 592 """handle an unmet dependency"""
587 593 engine = idents[0]
588 594 msg_id = parent['msg_id']
589 595
590 596 if msg_id not in self.blacklist:
591 597 self.blacklist[msg_id] = set()
592 598 self.blacklist[msg_id].add(engine)
593 599
594 600 args = self.pending[engine].pop(msg_id)
595 601 raw,targets,after,follow,timeout = args
596 602
597 603 if self.blacklist[msg_id] == targets:
598 604 self.depending[msg_id] = args
599 605 self.fail_unreachable(msg_id)
600 606 elif not self.maybe_run(msg_id, *args):
601 607 # resubmit failed
602 608 if msg_id not in self.all_failed:
603 609 # put it back in our dependency tree
604 610 self.save_unmet(msg_id, *args)
605 611
606 612 if self.hwm:
607 613 try:
608 614 idx = self.targets.index(engine)
609 615 except ValueError:
610 616 pass # skip load-update for dead engines
611 617 else:
612 618 if self.loads[idx] == self.hwm-1:
613 619 self.update_graph(None)
614 620
615 621
616 622
617 623 def update_graph(self, dep_id=None, success=True):
618 624 """dep_id just finished. Update our dependency
619 625 graph and submit any jobs that just became runnable.
620 626
621 627 Called with dep_id=None to update the entire graph for hwm, without finishing
622 628 a task.
623 629 """
624 630 # print ("\n\n***********")
625 631 # pprint (dep_id)
626 632 # pprint (self.graph)
627 633 # pprint (self.depending)
628 634 # pprint (self.all_completed)
629 635 # pprint (self.all_failed)
630 636 # print ("\n\n***********\n\n")
631 637 # update any jobs that depended on the dependency
632 638 jobs = self.graph.pop(dep_id, [])
633 639
634 640 # recheck *all* jobs if
635 641 # a) we have HWM and an engine just became no longer full
636 642 # or b) dep_id was given as None
637 643 if dep_id is None or (self.hwm and any(load == self.hwm-1 for load in self.loads)):
638 644 jobs = self.depending.keys()
639 645
640 646 for msg_id in jobs:
641 647 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
642 648
643 649 if after.unreachable(self.all_completed, self.all_failed)\
644 650 or follow.unreachable(self.all_completed, self.all_failed):
645 651 self.fail_unreachable(msg_id)
646 652
647 653 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
648 654 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
649 655
650 656 self.depending.pop(msg_id)
651 657 for mid in follow.union(after):
652 658 if mid in self.graph:
653 659 self.graph[mid].remove(msg_id)
654 660
655 661 #----------------------------------------------------------------------
656 662 # methods to be overridden by subclasses
657 663 #----------------------------------------------------------------------
658 664
659 665 def add_job(self, idx):
660 666 """Called after self.targets[idx] just got the job with header.
661 667 Override with subclasses. The default ordering is simple LRU.
662 668 The default loads are the number of outstanding jobs."""
663 669 self.loads[idx] += 1
664 670 for lis in (self.targets, self.loads):
665 671 lis.append(lis.pop(idx))
666 672
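The pop/append rotation is what makes the default ordering LRU: the engine that just received work moves to the back of targets, so ties in load are broken in favor of the least recently used engine. A worked illustration (hypothetical values):

    targets, loads = ['a', 'b', 'c'], [1, 0, 0]
    idx = 1  # engine 'b' just got a job
    loads[idx] += 1
    for lis in (targets, loads):
        lis.append(lis.pop(idx))
    # now targets == ['a', 'c', 'b'] and loads == [1, 0, 1]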
667 673
668 674 def finish_job(self, idx):
669 675 """Called after self.targets[idx] just finished a job.
670 676 Override with subclasses."""
671 677 self.loads[idx] -= 1
672 678
673 679
674 680
675 681 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
676 682 logname='root', log_url=None, loglevel=logging.DEBUG,
677 683 identity=b'task', in_thread=False):
678 684
679 685 ZMQStream = zmqstream.ZMQStream
680 686
681 687 if config:
682 688 # unwrap dict back into Config
683 689 config = Config(config)
684 690
685 691 if in_thread:
686 692 # use instance() to get the same Context/Loop as our parent
687 693 ctx = zmq.Context.instance()
688 694 loop = ioloop.IOLoop.instance()
689 695 else:
690 696 # in a process, don't use instance()
691 697 # for safety with multiprocessing
692 698 ctx = zmq.Context()
693 699 loop = ioloop.IOLoop()
694 700 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
695 701 ins.setsockopt(zmq.IDENTITY, identity)
696 702 ins.bind(in_addr)
697 703
698 704 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
699 705 outs.setsockopt(zmq.IDENTITY, identity)
700 706 outs.bind(out_addr)
701 707 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
702 708 mons.connect(mon_addr)
703 709 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
704 710 nots.setsockopt(zmq.SUBSCRIBE, b'')
705 711 nots.connect(not_addr)
706 712
707 713 # setup logging.
708 714 if in_thread:
709 715 log = Application.instance().log
710 716 else:
711 717 if log_url:
712 718 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
713 719 else:
714 720 log = local_logger(logname, loglevel)
715 721
716 722 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
717 723 mon_stream=mons, notifier_stream=nots,
718 724 loop=loop, log=log,
719 725 config=config)
720 726 scheduler.start()
721 727 if not in_thread:
722 728 try:
723 729 loop.start()
724 730 except KeyboardInterrupt:
725 731 scheduler.log.critical("Interrupted, exiting...")
726 732
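For reference, a minimal standalone invocation sketch (all four addresses are hypothetical; normally ipcontroller wires these up, and with in_thread=False the call blocks in loop.start()):

    launch_scheduler('tcp://127.0.0.1:5555',  # in_addr: client-facing ROUTER
                     'tcp://127.0.0.1:5556',  # out_addr: engine-facing ROUTER
                     'tcp://127.0.0.1:5557',  # mon_addr: monitor PUB
                     'tcp://127.0.0.1:5558',  # not_addr: Hub notification SUB
                     in_thread=False)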