remove a few dangling asbytes from rebase
MinRK
@@ -1,1401 +1,1401 b''
1 1 """The IPython Controller Hub with 0MQ
2 2 This is the master object that handles connections from engines and clients,
3 3 and monitors traffic through the various queues.
4 4
5 5 Authors:
6 6
7 7 * Min RK
8 8 """
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (C) 2010-2011 The IPython Development Team
11 11 #
12 12 # Distributed under the terms of the BSD License. The full license is in
13 13 # the file COPYING, distributed as part of this software.
14 14 #-----------------------------------------------------------------------------
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Imports
18 18 #-----------------------------------------------------------------------------
19 19 from __future__ import print_function
20 20
21 21 import json
22 22 import os
23 23 import sys
24 24 import time
25 25 from datetime import datetime
26 26
27 27 import zmq
28 28 from zmq.eventloop import ioloop
29 29 from zmq.eventloop.zmqstream import ZMQStream
30 30
31 31 # internal:
32 32 from IPython.utils.importstring import import_item
33 33 from IPython.utils.py3compat import cast_bytes
34 34 from IPython.utils.traitlets import (
35 35 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
36 36 )
37 37
38 38 from IPython.parallel import error, util
39 39 from IPython.parallel.factory import RegistrationFactory
40 40
41 41 from IPython.zmq.session import SessionFactory
42 42
43 43 from .heartmonitor import HeartMonitor
44 44
45 45 #-----------------------------------------------------------------------------
46 46 # Code
47 47 #-----------------------------------------------------------------------------
48 48
49 49 def _passer(*args, **kwargs):
50 50 return
51 51
52 52 def _printer(*args, **kwargs):
53 53 print (args)
54 54 print (kwargs)
55 55
56 56 def empty_record():
57 57 """Return an empty dict with all record keys."""
58 58 return {
59 59 'msg_id' : None,
60 60 'header' : None,
61 61 'content': None,
62 62 'buffers': None,
63 63 'submitted': None,
64 64 'client_uuid' : None,
65 65 'engine_uuid' : None,
66 66 'started': None,
67 67 'completed': None,
68 68 'resubmitted': None,
69 69 'received': None,
70 70 'result_header' : None,
71 71 'result_content' : None,
72 72 'result_buffers' : None,
73 73 'queue' : None,
74 74 'pyin' : None,
75 75 'pyout': None,
76 76 'pyerr': None,
77 77 'stdout': '',
78 78 'stderr': '',
79 79 }
80 80
81 81 def init_record(msg):
82 82 """Initialize a TaskRecord based on a request."""
83 83 header = msg['header']
84 84 return {
85 85 'msg_id' : header['msg_id'],
86 86 'header' : header,
87 87 'content': msg['content'],
88 88 'buffers': msg['buffers'],
89 89 'submitted': header['date'],
90 90 'client_uuid' : None,
91 91 'engine_uuid' : None,
92 92 'started': None,
93 93 'completed': None,
94 94 'resubmitted': None,
95 95 'received': None,
96 96 'result_header' : None,
97 97 'result_content' : None,
98 98 'result_buffers' : None,
99 99 'queue' : None,
100 100 'pyin' : None,
101 101 'pyout': None,
102 102 'pyerr': None,
103 103 'stdout': '',
104 104 'stderr': '',
105 105 }
106 106
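# Illustrative sketch (not part of this changeset): given a minimal message
# of the shape produced by Session.unserialize, init_record extracts the
# identifying fields and leaves the result/IO fields empty:
#
#     msg = {'header': {'msg_id': 'abc', 'date': datetime.now()},
#            'content': {}, 'buffers': []}
#     rec = init_record(msg)
#     rec['msg_id'], rec['queue'], rec['stdout']  # -> ('abc', None, '')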
107 107
108 108 class EngineConnector(HasTraits):
109 109 """A simple object for accessing the various zmq connections of an object.
110 110 Attributes are:
111 111 id (int): engine ID
112 112 uuid (unicode): engine UUID
113 113 pending: set of msg_ids
114 114 stallback: DelayedCallback for stalled registration
115 115 """
116 116
117 117 id = Integer(0)
118 118 uuid = Unicode()
119 119 pending = Set()
120 120 stallback = Instance(ioloop.DelayedCallback)
121 121
122 122
123 123 _db_shortcuts = {
124 124 'sqlitedb' : 'IPython.parallel.controller.sqlitedb.SQLiteDB',
125 125 'mongodb' : 'IPython.parallel.controller.mongodb.MongoDB',
126 126 'dictdb' : 'IPython.parallel.controller.dictdb.DictDB',
127 127 'nodb' : 'IPython.parallel.controller.dictdb.NoDB',
128 128 }
129 129
130 130 class HubFactory(RegistrationFactory):
131 131 """The Configurable for setting up a Hub."""
132 132
133 133 # port-pairs for monitoredqueues:
134 134 hb = Tuple(Integer,Integer,config=True,
135 135 help="""PUB/ROUTER Port pair for Engine heartbeats""")
136 136 def _hb_default(self):
137 137 return tuple(util.select_random_ports(2))
138 138
139 139 mux = Tuple(Integer,Integer,config=True,
140 140 help="""Client/Engine Port pair for MUX queue""")
141 141
142 142 def _mux_default(self):
143 143 return tuple(util.select_random_ports(2))
144 144
145 145 task = Tuple(Integer,Integer,config=True,
146 146 help="""Client/Engine Port pair for Task queue""")
147 147 def _task_default(self):
148 148 return tuple(util.select_random_ports(2))
149 149
150 150 control = Tuple(Integer,Integer,config=True,
151 151 help="""Client/Engine Port pair for Control queue""")
152 152
153 153 def _control_default(self):
154 154 return tuple(util.select_random_ports(2))
155 155
156 156 iopub = Tuple(Integer,Integer,config=True,
157 157 help="""Client/Engine Port pair for IOPub relay""")
158 158
159 159 def _iopub_default(self):
160 160 return tuple(util.select_random_ports(2))
161 161
162 162 # single ports:
163 163 mon_port = Integer(config=True,
164 164 help="""Monitor (SUB) port for queue traffic""")
165 165
166 166 def _mon_port_default(self):
167 167 return util.select_random_ports(1)[0]
168 168
169 169 notifier_port = Integer(config=True,
170 170 help="""PUB port for sending engine status notifications""")
171 171
172 172 def _notifier_port_default(self):
173 173 return util.select_random_ports(1)[0]
174 174
175 175 engine_ip = Unicode('127.0.0.1', config=True,
176 176 help="IP on which to listen for engine connections. [default: loopback]")
177 177 engine_transport = Unicode('tcp', config=True,
178 178 help="0MQ transport for engine connections. [default: tcp]")
179 179
180 180 client_ip = Unicode('127.0.0.1', config=True,
181 181 help="IP on which to listen for client connections. [default: loopback]")
182 182 client_transport = Unicode('tcp', config=True,
183 183 help="0MQ transport for client connections. [default : tcp]")
184 184
185 185 monitor_ip = Unicode('127.0.0.1', config=True,
186 186 help="IP on which to listen for monitor messages. [default: loopback]")
187 187 monitor_transport = Unicode('tcp', config=True,
188 188 help="0MQ transport for monitor messages. [default : tcp]")
189 189
190 190 monitor_url = Unicode('')
191 191
192 192 db_class = DottedObjectName('NoDB',
193 193 config=True, help="""The class to use for the DB backend
194 194
195 195 Options include:
196 196
197 197 SQLiteDB: SQLite
198 198 MongoDB : use MongoDB
199 199 DictDB : in-memory storage (fastest, but be mindful of memory growth of the Hub)
200 200 NoDB : disable database altogether (default)
201 201
202 202 """)
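# A hedged usage sketch: db_class accepts either a case-insensitive shortcut
# from _db_shortcuts above or a full dotted path, so in an ipcontroller config
# file either spelling below should resolve to the same backend:
#
#     c.HubFactory.db_class = 'SQLiteDB'
#     c.HubFactory.db_class = 'IPython.parallel.controller.sqlitedb.SQLiteDB'
#
# (resolution happens in init_hub via
# _db_shortcuts.get(self.db_class.lower(), self.db_class))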
203 203
204 204 # not configurable
205 205 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
206 206 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
207 207
208 208 def _ip_changed(self, name, old, new):
209 209 self.engine_ip = new
210 210 self.client_ip = new
211 211 self.monitor_ip = new
212 212 self._update_monitor_url()
213 213
214 214 def _update_monitor_url(self):
215 215 self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)
216 216
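# For example, with the loopback defaults above and a hypothetical
# mon_port of 12345, _update_monitor_url yields:
#
#     monitor_url == 'tcp://127.0.0.1:12345'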
217 217 def _transport_changed(self, name, old, new):
218 218 self.engine_transport = new
219 219 self.client_transport = new
220 220 self.monitor_transport = new
221 221 self._update_monitor_url()
222 222
223 223 def __init__(self, **kwargs):
224 224 super(HubFactory, self).__init__(**kwargs)
225 225 self._update_monitor_url()
226 226
227 227
228 228 def construct(self):
229 229 self.init_hub()
230 230
231 231 def start(self):
232 232 self.heartmonitor.start()
233 233 self.log.info("Heartmonitor started")
234 234
235 235 def client_url(self, channel):
236 236 """return full zmq url for a named client channel"""
237 237 return "%s://%s:%i" % (self.client_transport, self.client_ip, self.client_info[channel])
238 238
239 239 def engine_url(self, channel):
240 240 """return full zmq url for a named engine channel"""
241 241 return "%s://%s:%i" % (self.engine_transport, self.engine_ip, self.engine_info[channel])
242 242
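# Illustrative example (hypothetical port): with client_transport='tcp',
# client_ip='127.0.0.1', and client_info['task'] == 5555,
#
#     self.client_url('task')  # -> 'tcp://127.0.0.1:5555'
#
# engine_url performs the same lookup against engine_info.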
243 243 def init_hub(self):
244 244 """construct Hub object"""
245 245
246 246 ctx = self.context
247 247 loop = self.loop
248 248
249 249 try:
250 250 scheme = self.config.TaskScheduler.scheme_name
251 251 except AttributeError:
252 252 from .scheduler import TaskScheduler
253 253 scheme = TaskScheduler.scheme_name.get_default_value()
254 254
255 255 # build connection dicts
256 256 engine = self.engine_info = {
257 257 'interface' : "%s://%s" % (self.engine_transport, self.engine_ip),
258 258 'registration' : self.regport,
259 259 'control' : self.control[1],
260 260 'mux' : self.mux[1],
261 261 'hb_ping' : self.hb[0],
262 262 'hb_pong' : self.hb[1],
263 263 'task' : self.task[1],
264 264 'iopub' : self.iopub[1],
265 265 }
266 266
267 267 client = self.client_info = {
268 268 'interface' : "%s://%s" % (self.client_transport, self.client_ip),
269 269 'registration' : self.regport,
270 270 'control' : self.control[0],
271 271 'mux' : self.mux[0],
272 272 'task' : self.task[0],
273 273 'task_scheme' : scheme,
274 274 'iopub' : self.iopub[0],
275 275 'notification' : self.notifier_port,
276 276 }
277 277
278 278 self.log.debug("Hub engine addrs: %s", self.engine_info)
279 279 self.log.debug("Hub client addrs: %s", self.client_info)
280 280
281 281 # Registrar socket
282 282 q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
283 283 q.bind(self.client_url('registration'))
284 284 self.log.info("Hub listening on %s for registration.", self.client_url('registration'))
285 285 if self.client_ip != self.engine_ip:
286 286 q.bind(self.engine_url('registration'))
287 287 self.log.info("Hub listening on %s for registration.", self.engine_url('registration'))
288 288
289 289 ### Engine connections ###
290 290
291 291 # heartbeat
292 292 hpub = ctx.socket(zmq.PUB)
293 293 hpub.bind(self.engine_url('hb_ping'))
294 294 hrep = ctx.socket(zmq.ROUTER)
295 295 hrep.bind(self.engine_url('hb_pong'))
296 296 self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
297 297 pingstream=ZMQStream(hpub,loop),
298 298 pongstream=ZMQStream(hrep,loop)
299 299 )
300 300
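# Heartbeat sketch (descriptive, not part of this changeset): HeartMonitor
# periodically PUBs a ping on hb_ping; each engine's heart echoes it back on
# the hb_pong ROUTER socket prefixed with its identity, which is how the hub
# notices new hearts and hearts that stop beating.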
301 301 ### Client connections ###
302 302
303 303 # Notifier socket
304 304 n = ZMQStream(ctx.socket(zmq.PUB), loop)
305 305 n.bind(self.client_url('notification'))
306 306
307 307 ### build and launch the queues ###
308 308
309 309 # monitor socket
310 310 sub = ctx.socket(zmq.SUB)
311 311 sub.setsockopt(zmq.SUBSCRIBE, b"")
312 312 sub.bind(self.monitor_url)
313 313 sub.bind('inproc://monitor')
314 314 sub = ZMQStream(sub, loop)
315 315
316 316 # connect the db
317 317 db_class = _db_shortcuts.get(self.db_class.lower(), self.db_class)
318 318 self.log.info('Hub using DB backend: %r', (db_class.split('.')[-1]))
319 319 self.db = import_item(str(db_class))(session=self.session.session,
320 320 config=self.config, log=self.log)
321 321 time.sleep(.25)
322 322
323 323 # resubmit stream
324 324 r = ZMQStream(ctx.socket(zmq.DEALER), loop)
325 325 url = util.disambiguate_url(self.client_url('task'))
326 326 r.connect(url)
327 327
328 328 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
329 329 query=q, notifier=n, resubmit=r, db=self.db,
330 330 engine_info=self.engine_info, client_info=self.client_info,
331 331 log=self.log)
332 332
333 333
334 334 class Hub(SessionFactory):
335 335 """The IPython Controller Hub with 0MQ connections
336 336
337 337 Parameters
338 338 ==========
339 339 loop: zmq IOLoop instance
340 340 session: Session object
341 341 <removed> context: zmq context for creating new connections (?)
342 342 queue: ZMQStream for monitoring the command queue (SUB)
343 343 query: ZMQStream for engine registration and client query requests (ROUTER)
344 344 heartbeat: HeartMonitor object checking the pulse of the engines
345 345 notifier: ZMQStream for broadcasting engine registration changes (PUB)
346 346 db: connection to db for out of memory logging of commands
347 347 NotImplemented
348 348 engine_info: dict of zmq connection information for engines to connect
349 349 to the queues.
350 350 client_info: dict of zmq connection information for clients to connect
351 351 to the queues.
352 352 """
353 353
354 354 engine_state_file = Unicode()
355 355
356 356 # internal data structures:
357 357 ids=Set() # engine IDs
358 358 keytable=Dict()
359 359 by_ident=Dict()
360 360 engines=Dict()
361 361 clients=Dict()
362 362 hearts=Dict()
363 363 pending=Set()
364 364 queues=Dict() # pending msg_ids keyed by engine_id
365 365 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
366 366 completed=Dict() # completed msg_ids keyed by engine_id
367 367 all_completed=Set() # set of all completed msg_ids
368 368 dead_engines=Set() # set of uuids of dead engines
369 369 unassigned=Set() # set of task msg_ids not yet assigned a destination
370 370 incoming_registrations=Dict()
371 371 registration_timeout=Integer()
372 372 _idcounter=Integer(0)
373 373
374 374 # objects from constructor:
375 375 query=Instance(ZMQStream)
376 376 monitor=Instance(ZMQStream)
377 377 notifier=Instance(ZMQStream)
378 378 resubmit=Instance(ZMQStream)
379 379 heartmonitor=Instance(HeartMonitor)
380 380 db=Instance(object)
381 381 client_info=Dict()
382 382 engine_info=Dict()
383 383
384 384
385 385 def __init__(self, **kwargs):
386 386 """
387 387 # universal:
388 388 loop: IOLoop for creating future connections
389 389 session: streamsession for sending serialized data
390 390 # engine:
391 391 queue: ZMQStream for monitoring queue messages
392 392 query: ZMQStream for engine+client registration and client requests
393 393 heartbeat: HeartMonitor object for tracking engines
394 394 # extra:
395 395 db: ZMQStream for db connection (NotImplemented)
396 396 engine_info: zmq address/protocol dict for engine connections
397 397 client_info: zmq address/protocol dict for client connections
398 398 """
399 399
400 400 super(Hub, self).__init__(**kwargs)
401 401 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
402 402
403 403 # register our callbacks
404 404 self.query.on_recv(self.dispatch_query)
405 405 self.monitor.on_recv(self.dispatch_monitor_traffic)
406 406
407 407 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
408 408 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
409 409
410 410 self.monitor_handlers = {b'in' : self.save_queue_request,
411 411 b'out': self.save_queue_result,
412 412 b'intask': self.save_task_request,
413 413 b'outtask': self.save_task_result,
414 414 b'tracktask': self.save_task_destination,
415 415 b'incontrol': _passer,
416 416 b'outcontrol': _passer,
417 417 b'iopub': self.save_iopub_message,
418 418 }
419 419
420 420 self.query_handlers = {'queue_request': self.queue_status,
421 421 'result_request': self.get_results,
422 422 'history_request': self.get_history,
423 423 'db_request': self.db_query,
424 424 'purge_request': self.purge_results,
425 425 'load_request': self.check_load,
426 426 'resubmit_request': self.resubmit_task,
427 427 'shutdown_request': self.shutdown_request,
428 428 'registration_request' : self.register_engine,
429 429 'unregistration_request' : self.unregister_engine,
430 430 'connection_request': self.connection_request,
431 431 }
432 432
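# Dispatch sketch: dispatch_query (below) looks up msg['header']['msg_type']
# in the table above, so e.g. a client message of type 'queue_request' is
# routed to self.queue_status, and an unrecognized type gets a hub_error reply.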
433 433 # ignore resubmit replies
434 434 self.resubmit.on_recv(lambda msg: None, copy=False)
435 435
436 436 self.log.info("hub::created hub")
437 437
438 438 @property
439 439 def _next_id(self):
440 440 """generate a new ID.
441 441
442 442 No longer reuse old ids, just count from 0."""
443 443 newid = self._idcounter
444 444 self._idcounter += 1
445 445 return newid
446 446 # newid = 0
447 447 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
448 448 # # print newid, self.ids, self.incoming_registrations
449 449 # while newid in self.ids or newid in incoming:
450 450 # newid += 1
451 451 # return newid
452 452
453 453 #-----------------------------------------------------------------------------
454 454 # message validation
455 455 #-----------------------------------------------------------------------------
456 456
457 457 def _validate_targets(self, targets):
458 458 """turn any valid targets argument into a list of integer ids"""
459 459 if targets is None:
460 460 # default to all
461 461 return self.ids
462 462
463 463 if isinstance(targets, (int,str,unicode)):
464 464 # only one target specified
465 465 targets = [targets]
466 466 _targets = []
467 467 for t in targets:
468 468 # map raw identities to ids
469 469 if isinstance(t, (str,unicode)):
470 470 t = self.by_ident.get(cast_bytes(t), t)
471 471 _targets.append(t)
472 472 targets = _targets
473 473 bad_targets = [ t for t in targets if t not in self.ids ]
474 474 if bad_targets:
475 475 raise IndexError("No Such Engine: %r" % bad_targets)
476 476 if not targets:
477 477 raise IndexError("No Engines Registered")
478 478 return targets
479 479
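# A sketch of the normalization above (hypothetical engine 0 registered
# with uuid 'abc'):
#
#     hub._validate_targets(None)     # -> self.ids (all registered engines)
#     hub._validate_targets('abc')    # -> [0], uuid mapped via by_ident
#     hub._validate_targets([0, 99])  # raises IndexError("No Such Engine: [99]")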
480 480 #-----------------------------------------------------------------------------
481 481 # dispatch methods (1 per stream)
482 482 #-----------------------------------------------------------------------------
483 483
484 484
485 485 @util.log_errors
486 486 def dispatch_monitor_traffic(self, msg):
487 487 """all ME and Task queue messages come through here, as well as
488 488 IOPub traffic."""
489 489 self.log.debug("monitor traffic: %r", msg[0])
490 490 switch = msg[0]
491 491 try:
492 492 idents, msg = self.session.feed_identities(msg[1:])
493 493 except ValueError:
494 494 idents=[]
495 495 if not idents:
496 496 self.log.error("Monitor message without topic: %r", msg)
497 497 return
498 498 handler = self.monitor_handlers.get(switch, None)
499 499 if handler is not None:
500 500 handler(idents, msg)
501 501 else:
502 502 self.log.error("Unrecognized monitor topic: %r", switch)
503 503
504 504
505 505 @util.log_errors
506 506 def dispatch_query(self, msg):
507 507 """Route registration requests and queries from clients."""
508 508 try:
509 509 idents, msg = self.session.feed_identities(msg)
510 510 except ValueError:
511 511 idents = []
512 512 if not idents:
513 513 self.log.error("Bad Query Message: %r", msg)
514 514 return
515 515 client_id = idents[0]
516 516 try:
517 517 msg = self.session.unserialize(msg, content=True)
518 518 except Exception:
519 519 content = error.wrap_exception()
520 520 self.log.error("Bad Query Message: %r", msg, exc_info=True)
521 521 self.session.send(self.query, "hub_error", ident=client_id,
522 522 content=content)
523 523 return
524 524 # print client_id, header, parent, content
525 525 #switch on message type:
526 526 msg_type = msg['header']['msg_type']
527 527 self.log.info("client::client %r requested %r", client_id, msg_type)
528 528 handler = self.query_handlers.get(msg_type, None)
529 529 try:
530 530 assert handler is not None, "Bad Message Type: %r" % msg_type
531 531 except:
532 532 content = error.wrap_exception()
533 533 self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
534 534 self.session.send(self.query, "hub_error", ident=client_id,
535 535 content=content)
536 536 return
537 537
538 538 else:
539 539 handler(idents, msg)
540 540
541 541 def dispatch_db(self, msg):
542 542 """"""
543 543 raise NotImplementedError
544 544
545 545 #---------------------------------------------------------------------------
546 546 # handler methods (1 per event)
547 547 #---------------------------------------------------------------------------
548 548
549 549 #----------------------- Heartbeat --------------------------------------
550 550
551 551 def handle_new_heart(self, heart):
552 552 """handler to attach to heartbeater.
553 553 Called when a new heart starts to beat.
554 554 Triggers completion of registration."""
555 555 self.log.debug("heartbeat::handle_new_heart(%r)", heart)
556 556 if heart not in self.incoming_registrations:
557 557 self.log.info("heartbeat::ignoring new heart: %r", heart)
558 558 else:
559 559 self.finish_registration(heart)
560 560
561 561
562 562 def handle_heart_failure(self, heart):
563 563 """handler to attach to heartbeater.
564 564 called when a previously registered heart fails to respond to beat request.
565 565 triggers unregistration"""
566 566 self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
567 567 eid = self.hearts.get(heart, None)
568 568 if eid is None or self.keytable[eid] in self.dead_engines:
569 569 self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
570 570 else:
571 571 uuid = self.engines[eid].uuid
572 572 self.unregister_engine(heart, dict(content=dict(id=eid, queue=uuid)))
573 573
574 574 #----------------------- MUX Queue Traffic ------------------------------
575 575
576 576 def save_queue_request(self, idents, msg):
577 577 if len(idents) < 2:
578 578 self.log.error("invalid identity prefix: %r", idents)
579 579 return
580 580 queue_id, client_id = idents[:2]
581 581 try:
582 582 msg = self.session.unserialize(msg)
583 583 except Exception:
584 584 self.log.error("queue::client %r sent invalid message to %r: %r", client_id, queue_id, msg, exc_info=True)
585 585 return
586 586
587 587 eid = self.by_ident.get(queue_id, None)
588 588 if eid is None:
589 589 self.log.error("queue::target %r not registered", queue_id)
590 590 self.log.debug("queue:: valid are: %r", self.by_ident.keys())
591 591 return
592 592 record = init_record(msg)
593 593 msg_id = record['msg_id']
594 594 self.log.info("queue::client %r submitted request %r to %s", client_id, msg_id, eid)
595 595 # Unicode in records
596 596 record['engine_uuid'] = queue_id.decode('ascii')
597 597 record['client_uuid'] = msg['header']['session']
598 598 record['queue'] = 'mux'
599 599
600 600 try:
601 601 # it's possible iopub arrived first:
602 602 existing = self.db.get_record(msg_id)
603 603 for key,evalue in existing.iteritems():
604 604 rvalue = record.get(key, None)
605 605 if evalue and rvalue and evalue != rvalue:
606 606 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
607 607 elif evalue and not rvalue:
608 608 record[key] = evalue
609 609 try:
610 610 self.db.update_record(msg_id, record)
611 611 except Exception:
612 612 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
613 613 except KeyError:
614 614 try:
615 615 self.db.add_record(msg_id, record)
616 616 except Exception:
617 617 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
618 618
619 619
620 620 self.pending.add(msg_id)
621 621 self.queues[eid].append(msg_id)
622 622
623 623 def save_queue_result(self, idents, msg):
624 624 if len(idents) < 2:
625 625 self.log.error("invalid identity prefix: %r", idents)
626 626 return
627 627
628 628 client_id, queue_id = idents[:2]
629 629 try:
630 630 msg = self.session.unserialize(msg)
631 631 except Exception:
632 632 self.log.error("queue::engine %r sent invalid message to %r: %r",
633 633 queue_id, client_id, msg, exc_info=True)
634 634 return
635 635
636 636 eid = self.by_ident.get(queue_id, None)
637 637 if eid is None:
638 638 self.log.error("queue::unknown engine %r is sending a reply", queue_id)
639 639 return
640 640
641 641 parent = msg['parent_header']
642 642 if not parent:
643 643 return
644 644 msg_id = parent['msg_id']
645 645 if msg_id in self.pending:
646 646 self.pending.remove(msg_id)
647 647 self.all_completed.add(msg_id)
648 648 self.queues[eid].remove(msg_id)
649 649 self.completed[eid].append(msg_id)
650 650 self.log.info("queue::request %r completed on %s", msg_id, eid)
651 651 elif msg_id not in self.all_completed:
652 652 # it could be a result from a dead engine that died before delivering the
653 653 # result
654 654 self.log.warn("queue:: unknown msg finished %r", msg_id)
655 655 return
656 656 # update record anyway, because the unregistration could have been premature
657 657 rheader = msg['header']
658 658 completed = rheader['date']
659 659 started = rheader.get('started', None)
660 660 result = {
661 661 'result_header' : rheader,
662 662 'result_content': msg['content'],
663 663 'received': datetime.now(),
664 664 'started' : started,
665 665 'completed' : completed
666 666 }
667 667
668 668 result['result_buffers'] = msg['buffers']
669 669 try:
670 670 self.db.update_record(msg_id, result)
671 671 except Exception:
672 672 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
673 673
674 674
675 675 #--------------------- Task Queue Traffic ------------------------------
676 676
677 677 def save_task_request(self, idents, msg):
678 678 """Save the submission of a task."""
679 679 client_id = idents[0]
680 680
681 681 try:
682 682 msg = self.session.unserialize(msg)
683 683 except Exception:
684 684 self.log.error("task::client %r sent invalid task message: %r",
685 685 client_id, msg, exc_info=True)
686 686 return
687 687 record = init_record(msg)
688 688
689 689 record['client_uuid'] = msg['header']['session']
690 690 record['queue'] = 'task'
691 691 header = msg['header']
692 692 msg_id = header['msg_id']
693 693 self.pending.add(msg_id)
694 694 self.unassigned.add(msg_id)
695 695 try:
696 696 # it's possible iopub arrived first:
697 697 existing = self.db.get_record(msg_id)
698 698 if existing['resubmitted']:
699 699 for key in ('submitted', 'client_uuid', 'buffers'):
700 700 # don't clobber these keys on resubmit
701 701 # submitted and client_uuid should be different
702 702 # and buffers might be big, and shouldn't have changed
703 703 record.pop(key)
704 704 # still check content and header, which should not change
705 705 # and are cheap to compare, unlike buffers
706 706
707 707 for key,evalue in existing.iteritems():
708 708 if key.endswith('buffers'):
709 709 # don't compare buffers
710 710 continue
711 711 rvalue = record.get(key, None)
712 712 if evalue and rvalue and evalue != rvalue:
713 713 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
714 714 elif evalue and not rvalue:
715 715 record[key] = evalue
716 716 try:
717 717 self.db.update_record(msg_id, record)
718 718 except Exception:
719 719 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
720 720 except KeyError:
721 721 try:
722 722 self.db.add_record(msg_id, record)
723 723 except Exception:
724 724 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
725 725 except Exception:
726 726 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
727 727
728 728 def save_task_result(self, idents, msg):
729 729 """save the result of a completed task."""
730 730 client_id = idents[0]
731 731 try:
732 732 msg = self.session.unserialize(msg)
733 733 except Exception:
734 734 self.log.error("task::invalid task result message sent to %r: %r",
735 735 client_id, msg, exc_info=True)
736 736 return
737 737
738 738 parent = msg['parent_header']
739 739 if not parent:
740 740 # print msg
741 741 self.log.warn("Task %r had no parent!", msg)
742 742 return
743 743 msg_id = parent['msg_id']
744 744 if msg_id in self.unassigned:
745 745 self.unassigned.remove(msg_id)
746 746
747 747 header = msg['header']
748 748 engine_uuid = header.get('engine', u'')
749 749 eid = self.by_ident.get(cast_bytes(engine_uuid), None)
750 750
751 751 status = header.get('status', None)
752 752
753 753 if msg_id in self.pending:
754 754 self.log.info("task::task %r finished on %s", msg_id, eid)
755 755 self.pending.remove(msg_id)
756 756 self.all_completed.add(msg_id)
757 757 if eid is not None:
758 758 if status != 'aborted':
759 759 self.completed[eid].append(msg_id)
760 760 if msg_id in self.tasks[eid]:
761 761 self.tasks[eid].remove(msg_id)
762 762 completed = header['date']
763 763 started = header.get('started', None)
764 764 result = {
765 765 'result_header' : header,
766 766 'result_content': msg['content'],
767 767 'started' : started,
768 768 'completed' : completed,
769 769 'received' : datetime.now(),
770 770 'engine_uuid': engine_uuid,
771 771 }
772 772
773 773 result['result_buffers'] = msg['buffers']
774 774 try:
775 775 self.db.update_record(msg_id, result)
776 776 except Exception:
777 777 self.log.error("DB Error saving task result %r", msg_id, exc_info=True)
778 778
779 779 else:
780 780 self.log.debug("task::unknown task %r finished", msg_id)
781 781
782 782 def save_task_destination(self, idents, msg):
783 783 try:
784 784 msg = self.session.unserialize(msg, content=True)
785 785 except Exception:
786 786 self.log.error("task::invalid task tracking message", exc_info=True)
787 787 return
788 788 content = msg['content']
789 789 # print (content)
790 790 msg_id = content['msg_id']
791 791 engine_uuid = content['engine_id']
792 792 eid = self.by_ident[cast_bytes(engine_uuid)]
793 793
794 794 self.log.info("task::task %r arrived on %r", msg_id, eid)
795 795 if msg_id in self.unassigned:
796 796 self.unassigned.remove(msg_id)
797 797 # else:
798 798 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
799 799
800 800 self.tasks[eid].append(msg_id)
801 801 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
802 802 try:
803 803 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
804 804 except Exception:
805 805 self.log.error("DB Error saving task destination %r", msg_id, exc_info=True)
806 806
807 807
808 808 def mia_task_request(self, idents, msg):
809 809 raise NotImplementedError
810 810 client_id = idents[0]
811 811 # content = dict(mia=self.mia,status='ok')
812 812 # self.session.send('mia_reply', content=content, idents=client_id)
813 813
814 814
815 815 #--------------------- IOPub Traffic ------------------------------
816 816
817 817 def save_iopub_message(self, topics, msg):
818 818 """save an iopub message into the db"""
819 819 # print (topics)
820 820 try:
821 821 msg = self.session.unserialize(msg, content=True)
822 822 except Exception:
823 823 self.log.error("iopub::invalid IOPub message", exc_info=True)
824 824 return
825 825
826 826 parent = msg['parent_header']
827 827 if not parent:
828 828 self.log.warn("iopub::IOPub message lacks parent: %r", msg)
829 829 return
830 830 msg_id = parent['msg_id']
831 831 msg_type = msg['header']['msg_type']
832 832 content = msg['content']
833 833
834 834 # ensure msg_id is in db
835 835 try:
836 836 rec = self.db.get_record(msg_id)
837 837 except KeyError:
838 838 rec = empty_record()
839 839 rec['msg_id'] = msg_id
840 840 self.db.add_record(msg_id, rec)
841 841 # stream
842 842 d = {}
843 843 if msg_type == 'stream':
844 844 name = content['name']
845 845 s = rec[name] or ''
846 846 d[name] = s + content['data']
847 847
848 848 elif msg_type == 'pyerr':
849 849 d['pyerr'] = content
850 850 elif msg_type == 'pyin':
851 851 d['pyin'] = content['code']
852 852 elif msg_type in ('display_data', 'pyout'):
853 853 d[msg_type] = content
854 854 elif msg_type == 'status':
855 855 pass
856 856 else:
857 857 self.log.warn("unhandled iopub msg_type: %r", msg_type)
858 858
859 859 if not d:
860 860 return
861 861
862 862 try:
863 863 self.db.update_record(msg_id, d)
864 864 except Exception:
865 865 self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)
866 866
867 867
868 868
869 869 #-------------------------------------------------------------------------
870 870 # Registration requests
871 871 #-------------------------------------------------------------------------
872 872
873 873 def connection_request(self, client_id, msg):
874 874 """Reply with connection addresses for clients."""
875 875 self.log.info("client::client %r connected", client_id)
876 876 content = dict(status='ok')
877 877 jsonable = {}
878 878 for k,v in self.keytable.iteritems():
879 879 if v not in self.dead_engines:
880 880 jsonable[str(k)] = v
881 881 content['engines'] = jsonable
882 882 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
883 883
884 884 def register_engine(self, reg, msg):
885 885 """Register a new engine."""
886 886 content = msg['content']
887 887 try:
888 888 uuid = content['uuid']
889 889 except KeyError:
890 890 self.log.error("registration::queue not specified", exc_info=True)
891 891 return
892 892
893 893 eid = self._next_id
894 894
895 895 self.log.debug("registration::register_engine(%i, %r)", eid, uuid)
896 896
897 897 content = dict(id=eid,status='ok')
898 898 # check if requesting available IDs:
899 899 if uuid in self.by_ident:
900 900 try:
901 901 raise KeyError("uuid %r in use" % uuid)
902 902 except:
903 903 content = error.wrap_exception()
904 904 self.log.error("uuid %r in use", uuid, exc_info=True)
905 905 else:
906 906 for h, ec in self.incoming_registrations.iteritems():
907 907 if uuid == h:
908 908 try:
909 909 raise KeyError("heart_id %r in use" % uuid)
910 910 except:
911 911 self.log.error("heart_id %r in use", uuid, exc_info=True)
912 912 content = error.wrap_exception()
913 913 break
914 914 elif uuid == ec.uuid:
915 915 try:
916 916 raise KeyError("uuid %r in use" % uuid)
917 917 except:
918 918 self.log.error("uuid %r in use", uuid, exc_info=True)
919 919 content = error.wrap_exception()
920 920 break
921 921
922 922 msg = self.session.send(self.query, "registration_reply",
923 923 content=content,
924 924 ident=reg)
925 925
926 heart = util.asbytes(uuid)
926 heart = cast_bytes(uuid)
927 927
928 928 if content['status'] == 'ok':
929 929 if heart in self.heartmonitor.hearts:
930 930 # already beating
931 931 self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid)
932 932 self.finish_registration(heart)
933 933 else:
934 934 purge = lambda : self._purge_stalled_registration(heart)
935 935 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
936 936 dc.start()
937 937 self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid,stallback=dc)
938 938 else:
939 939 self.log.error("registration::registration %i failed: %r", eid, content['evalue'])
940 940
941 941 return eid
942 942
943 943 def unregister_engine(self, ident, msg):
944 944 """Unregister an engine that explicitly requested to leave."""
945 945 try:
946 946 eid = msg['content']['id']
947 947 except:
948 948 self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
949 949 return
950 950 self.log.info("registration::unregister_engine(%r)", eid)
951 951 # print (eid)
952 952 uuid = self.keytable[eid]
953 953 content=dict(id=eid, uuid=uuid)
954 954 self.dead_engines.add(uuid)
955 955 # self.ids.remove(eid)
956 956 # uuid = self.keytable.pop(eid)
957 957 #
958 958 # ec = self.engines.pop(eid)
959 959 # self.hearts.pop(ec.heartbeat)
960 960 # self.by_ident.pop(ec.queue)
961 961 # self.completed.pop(eid)
962 962 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
963 963 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
964 964 dc.start()
965 965 ############## TODO: HANDLE IT ################
966 966
967 967 self._save_engine_state()
968 968
969 969 if self.notifier:
970 970 self.session.send(self.notifier, "unregistration_notification", content=content)
971 971
972 972 def _handle_stranded_msgs(self, eid, uuid):
973 973 """Handle messages known to be on an engine when the engine unregisters.
974 974
975 975 It is possible that this will fire prematurely - that is, an engine will
976 976 go down after completing a result, and the client will be notified
977 977 that the result failed and later receive the actual result.
978 978 """
979 979
980 980 outstanding = self.queues[eid]
981 981
982 982 for msg_id in outstanding:
983 983 self.pending.remove(msg_id)
984 984 self.all_completed.add(msg_id)
985 985 try:
986 986 raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
987 987 except:
988 988 content = error.wrap_exception()
989 989 # build a fake header:
990 990 header = {}
991 991 header['engine'] = uuid
992 992 header['date'] = datetime.now()
993 993 rec = dict(result_content=content, result_header=header, result_buffers=[])
994 994 rec['completed'] = header['date']
995 995 rec['engine_uuid'] = uuid
996 996 try:
997 997 self.db.update_record(msg_id, rec)
998 998 except Exception:
999 999 self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)
1000 1000
1001 1001
1002 1002 def finish_registration(self, heart):
1003 1003 """Second half of engine registration, called after our HeartMonitor
1004 1004 has received a beat from the Engine's Heart."""
1005 1005 try:
1006 1006 ec = self.incoming_registrations.pop(heart)
1007 1007 except KeyError:
1008 1008 self.log.error("registration::tried to finish nonexistent registration", exc_info=True)
1009 1009 return
1010 1010 self.log.info("registration::finished registering engine %i:%s", ec.id, ec.uuid)
1011 1011 if ec.stallback is not None:
1012 1012 ec.stallback.stop()
1013 1013 eid = ec.id
1014 1014 self.ids.add(eid)
1015 1015 self.keytable[eid] = ec.uuid
1016 1016 self.engines[eid] = ec
1017 1017 self.by_ident[ec.uuid] = ec.id
1018 1018 self.queues[eid] = list()
1019 1019 self.tasks[eid] = list()
1020 1020 self.completed[eid] = list()
1021 1021 self.hearts[heart] = eid
1022 1022 content = dict(id=eid, uuid=self.engines[eid].uuid)
1023 1023 if self.notifier:
1024 1024 self.session.send(self.notifier, "registration_notification", content=content)
1025 1025 self.log.info("engine::Engine Connected: %i", eid)
1026 1026
1027 1027 self._save_engine_state()
1028 1028
1029 1029 def _purge_stalled_registration(self, heart):
1030 1030 if heart in self.incoming_registrations:
1031 1031 ec = self.incoming_registrations.pop(heart)
1032 1032 self.log.info("registration::purging stalled registration: %i", ec.id)
1033 1033 else:
1034 1034 pass
1035 1035
1036 1036 #-------------------------------------------------------------------------
1037 1037 # Engine State
1038 1038 #-------------------------------------------------------------------------
1039 1039
1040 1040
1041 1041 def _cleanup_engine_state_file(self):
1042 1042 """cleanup engine state mapping"""
1043 1043
1044 1044 if os.path.exists(self.engine_state_file):
1045 1045 self.log.debug("cleaning up engine state: %s", self.engine_state_file)
1046 1046 try:
1047 1047 os.remove(self.engine_state_file)
1048 1048 except IOError:
1049 1049 self.log.error("Couldn't cleanup file: %s", self.engine_state_file, exc_info=True)
1050 1050
1051 1051
1052 1052 def _save_engine_state(self):
1053 1053 """save engine mapping to JSON file"""
1054 1054 if not self.engine_state_file:
1055 1055 return
1056 1056 self.log.debug("save engine state to %s", self.engine_state_file)
1057 1057 state = {}
1058 1058 engines = {}
1059 1059 for eid, ec in self.engines.iteritems():
1060 1060 if ec.uuid not in self.dead_engines:
1061 1061 engines[eid] = ec.uuid
1062 1062
1063 1063 state['engines'] = engines
1064 1064
1065 1065 state['next_id'] = self._idcounter
1066 1066
1067 1067 with open(self.engine_state_file, 'w') as f:
1068 1068 json.dump(state, f)
1069 1069
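# The resulting file is a small JSON document mapping engine ids to uuids
# plus the id counter, e.g. (illustrative uuids):
#
#     {"engines": {"0": "a1b2c3", "1": "d4e5f6"}, "next_id": 2}
#
# _load_engine_state below consumes exactly this shape (note JSON turns the
# integer engine ids into strings, hence the int(eid) on load).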
1070 1070
1071 1071 def _load_engine_state(self):
1072 1072 """load engine mapping from JSON file"""
1073 1073 if not os.path.exists(self.engine_state_file):
1074 1074 return
1075 1075
1076 1076 self.log.info("loading engine state from %s", self.engine_state_file)
1077 1077
1078 1078 with open(self.engine_state_file) as f:
1079 1079 state = json.load(f)
1080 1080
1081 1081 save_notifier = self.notifier
1082 1082 self.notifier = None
1083 1083 for eid, uuid in state['engines'].iteritems():
1084 1084 heart = uuid.encode('ascii')
1085 1085 # start with this heart as current and beating:
1086 1086 self.heartmonitor.responses.add(heart)
1087 1087 self.heartmonitor.hearts.add(heart)
1088 1088
1089 1089 self.incoming_registrations[heart] = EngineConnector(id=int(eid), uuid=uuid)
1090 1090 self.finish_registration(heart)
1091 1091
1092 1092 self.notifier = save_notifier
1093 1093
1094 1094 self._idcounter = state['next_id']
1095 1095
1096 1096 #-------------------------------------------------------------------------
1097 1097 # Client Requests
1098 1098 #-------------------------------------------------------------------------
1099 1099
1100 1100 def shutdown_request(self, client_id, msg):
1101 1101 """handle shutdown request."""
1102 1102 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1103 1103 # also notify other clients of shutdown
1104 1104 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1105 1105 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1106 1106 dc.start()
1107 1107
1108 1108 def _shutdown(self):
1109 1109 self.log.info("hub::hub shutting down.")
1110 1110 time.sleep(0.1)
1111 1111 sys.exit(0)
1112 1112
1113 1113
1114 1114 def check_load(self, client_id, msg):
1115 1115 content = msg['content']
1116 1116 try:
1117 1117 targets = content['targets']
1118 1118 targets = self._validate_targets(targets)
1119 1119 except:
1120 1120 content = error.wrap_exception()
1121 1121 self.session.send(self.query, "hub_error",
1122 1122 content=content, ident=client_id)
1123 1123 return
1124 1124
1125 1125 content = dict(status='ok')
1126 1126 # loads = {}
1127 1127 for t in targets:
1128 1128 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1129 1129 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1130 1130
1131 1131
1132 1132 def queue_status(self, client_id, msg):
1133 1133 """Return the Queue status of one or more targets.
1134 1134 if verbose: return the msg_ids
1135 1135 else: return len of each type.
1136 1136 keys: queue (pending MUX jobs)
1137 1137 tasks (pending Task jobs)
1138 1138 completed (finished jobs from both queues)"""
1139 1139 content = msg['content']
1140 1140 targets = content['targets']
1141 1141 try:
1142 1142 targets = self._validate_targets(targets)
1143 1143 except:
1144 1144 content = error.wrap_exception()
1145 1145 self.session.send(self.query, "hub_error",
1146 1146 content=content, ident=client_id)
1147 1147 return
1148 1148 verbose = content.get('verbose', False)
1149 1149 content = dict(status='ok')
1150 1150 for t in targets:
1151 1151 queue = self.queues[t]
1152 1152 completed = self.completed[t]
1153 1153 tasks = self.tasks[t]
1154 1154 if not verbose:
1155 1155 queue = len(queue)
1156 1156 completed = len(completed)
1157 1157 tasks = len(tasks)
1158 1158 content[str(t)] = {'queue': queue, 'completed': completed, 'tasks': tasks}
1159 1159 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1160 1160 # print (content)
1161 1161 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1162 1162
1163 1163 def purge_results(self, client_id, msg):
1164 1164 """Purge results from memory. This method is more valuable before we move
1165 1165 to a DB based message storage mechanism."""
1166 1166 content = msg['content']
1167 1167 self.log.info("Dropping records with %s", content)
1168 1168 msg_ids = content.get('msg_ids', [])
1169 1169 reply = dict(status='ok')
1170 1170 if msg_ids == 'all':
1171 1171 try:
1172 1172 self.db.drop_matching_records(dict(completed={'$ne':None}))
1173 1173 except Exception:
1174 1174 reply = error.wrap_exception()
1175 1175 else:
1176 1176 pending = filter(lambda m: m in self.pending, msg_ids)
1177 1177 if pending:
1178 1178 try:
1179 1179 raise IndexError("msg pending: %r" % pending[0])
1180 1180 except:
1181 1181 reply = error.wrap_exception()
1182 1182 else:
1183 1183 try:
1184 1184 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1185 1185 except Exception:
1186 1186 reply = error.wrap_exception()
1187 1187
1188 1188 if reply['status'] == 'ok':
1189 1189 eids = content.get('engine_ids', [])
1190 1190 for eid in eids:
1191 1191 if eid not in self.engines:
1192 1192 try:
1193 1193 raise IndexError("No such engine: %i" % eid)
1194 1194 except:
1195 1195 reply = error.wrap_exception()
1196 1196 break
1197 1197 uid = self.engines[eid].uuid
1198 1198 try:
1199 1199 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1200 1200 except Exception:
1201 1201 reply = error.wrap_exception()
1202 1202 break
1203 1203
1204 1204 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1205 1205
1206 1206 def resubmit_task(self, client_id, msg):
1207 1207 """Resubmit one or more tasks."""
1208 1208 def finish(reply):
1209 1209 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1210 1210
1211 1211 content = msg['content']
1212 1212 msg_ids = content['msg_ids']
1213 1213 reply = dict(status='ok')
1214 1214 try:
1215 1215 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1216 1216 'header', 'content', 'buffers'])
1217 1217 except Exception:
1218 1218 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1219 1219 return finish(error.wrap_exception())
1220 1220
1221 1221 # validate msg_ids
1222 1222 found_ids = [ rec['msg_id'] for rec in records ]
1223 1223 pending_ids = [ msg_id for msg_id in found_ids if msg_id in self.pending ]
1224 1224 if len(records) > len(msg_ids):
1225 1225 try:
1226 1226 raise RuntimeError("DB appears to be in an inconsistent state. "
1227 1227 "More matching records were found than should exist")
1228 1228 except Exception:
1229 1229 return finish(error.wrap_exception())
1230 1230 elif len(records) < len(msg_ids):
1231 1231 missing = [ m for m in msg_ids if m not in found_ids ]
1232 1232 try:
1233 1233 raise KeyError("No such msg(s): %r" % missing)
1234 1234 except KeyError:
1235 1235 return finish(error.wrap_exception())
1236 1236 elif pending_ids:
1237 1237 pass
1238 1238 # no need to raise on resubmit of pending task, now that we
1239 1239 # resubmit under new ID, but do we want to raise anyway?
1240 1240 # msg_id = invalid_ids[0]
1241 1241 # try:
1242 1242 # raise ValueError("Task(s) %r appears to be inflight" % )
1243 1243 # except Exception:
1244 1244 # return finish(error.wrap_exception())
1245 1245
1246 1246 # mapping of original IDs to resubmitted IDs
1247 1247 resubmitted = {}
1248 1248
1249 1249 # send the messages
1250 1250 for rec in records:
1251 1251 header = rec['header']
1252 1252 msg = self.session.msg(header['msg_type'], parent=header)
1253 1253 msg_id = msg['msg_id']
1254 1254 msg['content'] = rec['content']
1255 1255
1256 1256 # use the old header, but update msg_id and timestamp
1257 1257 fresh = msg['header']
1258 1258 header['msg_id'] = fresh['msg_id']
1259 1259 header['date'] = fresh['date']
1260 1260 msg['header'] = header
1261 1261
1262 1262 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1263 1263
1264 1264 resubmitted[rec['msg_id']] = msg_id
1265 1265 self.pending.add(msg_id)
1266 1266 msg['buffers'] = rec['buffers']
1267 1267 try:
1268 1268 self.db.add_record(msg_id, init_record(msg))
1269 1269 except Exception:
1270 1270 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1271 1271 return finish(error.wrap_exception())
1272 1272
1273 1273 finish(dict(status='ok', resubmitted=resubmitted))
1274 1274
1275 1275 # store the new IDs in the Task DB
1276 1276 for msg_id, resubmit_id in resubmitted.iteritems():
1277 1277 try:
1278 1278 self.db.update_record(msg_id, {'resubmitted' : resubmit_id})
1279 1279 except Exception:
1280 1280 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1281 1281
1282 1282
1283 1283 def _extract_record(self, rec):
1284 1284 """decompose a TaskRecord dict into subsection of reply for get_result"""
1285 1285 io_dict = {}
1286 1286 for key in ('pyin', 'pyout', 'pyerr', 'stdout', 'stderr'):
1287 1287 io_dict[key] = rec[key]
1288 1288 content = { 'result_content': rec['result_content'],
1289 1289 'header': rec['header'],
1290 1290 'result_header' : rec['result_header'],
1291 1291 'received' : rec['received'],
1292 1292 'io' : io_dict,
1293 1293 }
1294 1294 if rec['result_buffers']:
1295 1295 buffers = map(bytes, rec['result_buffers'])
1296 1296 else:
1297 1297 buffers = []
1298 1298
1299 1299 return content, buffers
1300 1300
1301 1301 def get_results(self, client_id, msg):
1302 1302 """Get the result of 1 or more messages."""
1303 1303 content = msg['content']
1304 1304 msg_ids = sorted(set(content['msg_ids']))
1305 1305 statusonly = content.get('status_only', False)
1306 1306 pending = []
1307 1307 completed = []
1308 1308 content = dict(status='ok')
1309 1309 content['pending'] = pending
1310 1310 content['completed'] = completed
1311 1311 buffers = []
1312 1312 if not statusonly:
1313 1313 try:
1314 1314 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1315 1315 # turn match list into dict, for faster lookup
1316 1316 records = {}
1317 1317 for rec in matches:
1318 1318 records[rec['msg_id']] = rec
1319 1319 except Exception:
1320 1320 content = error.wrap_exception()
1321 1321 self.session.send(self.query, "result_reply", content=content,
1322 1322 parent=msg, ident=client_id)
1323 1323 return
1324 1324 else:
1325 1325 records = {}
1326 1326 for msg_id in msg_ids:
1327 1327 if msg_id in self.pending:
1328 1328 pending.append(msg_id)
1329 1329 elif msg_id in self.all_completed:
1330 1330 completed.append(msg_id)
1331 1331 if not statusonly:
1332 1332 c,bufs = self._extract_record(records[msg_id])
1333 1333 content[msg_id] = c
1334 1334 buffers.extend(bufs)
1335 1335 elif msg_id in records:
1336 1336 if records[msg_id]['completed']:
1337 1337 completed.append(msg_id)
1338 1338 c,bufs = self._extract_record(records[msg_id])
1339 1339 content[msg_id] = c
1340 1340 buffers.extend(bufs)
1341 1341 else:
1342 1342 pending.append(msg_id)
1343 1343 else:
1344 1344 try:
1345 1345 raise KeyError('No such message: '+msg_id)
1346 1346 except:
1347 1347 content = error.wrap_exception()
1348 1348 break
1349 1349 self.session.send(self.query, "result_reply", content=content,
1350 1350 parent=msg, ident=client_id,
1351 1351 buffers=buffers)
1352 1352
1353 1353 def get_history(self, client_id, msg):
1354 1354 """Get a list of all msg_ids in our DB records"""
1355 1355 try:
1356 1356 msg_ids = self.db.get_history()
1357 1357 except Exception as e:
1358 1358 content = error.wrap_exception()
1359 1359 else:
1360 1360 content = dict(status='ok', history=msg_ids)
1361 1361
1362 1362 self.session.send(self.query, "history_reply", content=content,
1363 1363 parent=msg, ident=client_id)
1364 1364
1365 1365 def db_query(self, client_id, msg):
1366 1366 """Perform a raw query on the task record database."""
1367 1367 content = msg['content']
1368 1368 query = content.get('query', {})
1369 1369 keys = content.get('keys', None)
1370 1370 buffers = []
1371 1371 empty = list()
1372 1372 try:
1373 1373 records = self.db.find_records(query, keys)
1374 1374 except Exception as e:
1375 1375 content = error.wrap_exception()
1376 1376 else:
1377 1377 # extract buffers from reply content:
1378 1378 if keys is not None:
1379 1379 buffer_lens = [] if 'buffers' in keys else None
1380 1380 result_buffer_lens = [] if 'result_buffers' in keys else None
1381 1381 else:
1382 1382 buffer_lens = None
1383 1383 result_buffer_lens = None
1384 1384
1385 1385 for rec in records:
1386 1386 # buffers may be None, so double check
1387 1387 b = rec.pop('buffers', empty) or empty
1388 1388 if buffer_lens is not None:
1389 1389 buffer_lens.append(len(b))
1390 1390 buffers.extend(b)
1391 1391 rb = rec.pop('result_buffers', empty) or empty
1392 1392 if result_buffer_lens is not None:
1393 1393 result_buffer_lens.append(len(rb))
1394 1394 buffers.extend(rb)
1395 1395 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1396 1396 result_buffer_lens=result_buffer_lens)
1397 1397 # self.log.debug (content)
1398 1398 self.session.send(self.query, "db_reply", content=content,
1399 1399 parent=msg, ident=client_id,
1400 1400 buffers=buffers)
1401 1401
@@ -1,794 +1,794 b''
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6
7 7 Authors:
8 8
9 9 * Min RK
10 10 """
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2010-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #----------------------------------------------------------------------
19 19 # Imports
20 20 #----------------------------------------------------------------------
21 21
22 22 from __future__ import print_function
23 23
24 24 import logging
25 25 import sys
26 26 import time
27 27
28 28 from datetime import datetime, timedelta
29 29 from random import randint, random
30 30 from types import FunctionType
31 31
32 32 try:
33 33 import numpy
34 34 except ImportError:
35 35 numpy = None
36 36
37 37 import zmq
38 38 from zmq.eventloop import ioloop, zmqstream
39 39
40 40 # local imports
41 41 from IPython.external.decorator import decorator
42 42 from IPython.config.application import Application
43 43 from IPython.config.loader import Config
44 44 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
45 45 from IPython.utils.py3compat import cast_bytes
46 46
47 47 from IPython.parallel import error, util
48 48 from IPython.parallel.factory import SessionFactory
49 49 from IPython.parallel.util import connect_logger, local_logger
50 50
51 51 from .dependency import Dependency
52 52
53 53 @decorator
54 54 def logged(f,self,*args,**kwargs):
55 55 # print ("#--------------------")
56 56 self.log.debug("scheduler::%s(*%s,**%s)", f.func_name, args, kwargs)
57 57 # print ("#--")
58 58 return f(self,*args, **kwargs)
59 59
60 60 #----------------------------------------------------------------------
61 61 # Chooser functions
62 62 #----------------------------------------------------------------------
63 63
64 64 def plainrandom(loads):
65 65 """Plain random pick."""
66 66 n = len(loads)
67 67 return randint(0,n-1)
68 68
69 69 def lru(loads):
70 70 """Always pick the front of the line.
71 71
72 72 The content of `loads` is ignored.
73 73
74 74 Assumes LRU ordering of loads, with oldest first.
75 75 """
76 76 return 0
77 77
78 78 def twobin(loads):
79 79 """Pick two at random, use the LRU of the two.
80 80
81 81 The content of loads is ignored.
82 82
83 83 Assumes LRU ordering of loads, with oldest first.
84 84 """
85 85 n = len(loads)
86 86 a = randint(0,n-1)
87 87 b = randint(0,n-1)
88 88 return min(a,b)
89 89
90 90 def weighted(loads):
91 91 """Pick two at random using inverse load as weight.
92 92
93 93 Return the less loaded of the two.
94 94 """
95 95 # weight 0 a million times more than 1:
96 96 weights = 1./(1e-6+numpy.array(loads))
97 97 sums = weights.cumsum()
98 98 t = sums[-1]
99 99 x = random()*t
100 100 y = random()*t
101 101 idx = 0
102 102 idy = 0
103 103 while sums[idx] < x:
104 104 idx += 1
105 105 while sums[idy] < y:
106 106 idy += 1
107 107 if weights[idy] > weights[idx]:
108 108 return idy
109 109 else:
110 110 return idx
111 111
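# Worked example (sketch): for loads == [0, 9], weights ~= [1e6, 0.111],
# so sums ~= [1e6, 1e6 + 0.111] and a uniform draw on [0, t) almost always
# lands in the first bin; the idle engine wins nearly every time, while
# busier engines are still chosen occasionally.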
112 112 def leastload(loads):
113 113 """Always choose the lowest load.
114 114
115 115 If the lowest load occurs more than once, the first
116 116 occurrence will be used. If loads has LRU ordering, this means
117 117 the LRU of those with the lowest load is chosen.
118 118 """
119 119 return loads.index(min(loads))
120 120
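# For example, leastload([2, 0, 1, 0]) == 1: engines 1 and 3 share the
# minimum load, and list.index returns the first (LRU-oldest) occurrence.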
121 121 #---------------------------------------------------------------------
122 122 # Classes
123 123 #---------------------------------------------------------------------
124 124
125 125
126 126 # store empty default dependency:
127 127 MET = Dependency([])
128 128
129 129
130 130 class Job(object):
131 131 """Simple container for a job"""
132 132 def __init__(self, msg_id, raw_msg, idents, msg, header, targets, after, follow, timeout):
133 133 self.msg_id = msg_id
134 134 self.raw_msg = raw_msg
135 135 self.idents = idents
136 136 self.msg = msg
137 137 self.header = header
138 138 self.targets = targets
139 139 self.after = after
140 140 self.follow = follow
141 141 self.timeout = timeout
142 142
143 143
144 144 self.timestamp = time.time()
145 145 self.blacklist = set()
146 146
147 147 @property
148 148 def dependents(self):
149 149 return self.follow.union(self.after)
150 150
151 151 class TaskScheduler(SessionFactory):
152 152 """Python TaskScheduler object.
153 153
154 154 This is the simplest object that supports msg_id based
155 155 DAG dependencies. *Only* task msg_ids are checked, not
156 156 msg_ids of jobs submitted via the MUX queue.
157 157
158 158 """
159 159
160 160 hwm = Integer(1, config=True,
161 161 help="""specify the High Water Mark (HWM) for the downstream
162 162 socket in the Task scheduler. This is the maximum number
163 163 of allowed outstanding tasks on each engine.
164 164
165 165 The default (1) means that only one task can be outstanding on each
166 166 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
167 167 engines continue to be assigned tasks while they are working,
168 168 effectively hiding network latency behind computation, but can result
169 169 in an imbalance of work when submitting many heterogeneous tasks all at
170 170 once. Any value greater than one is a compromise between the
171 171 two.
172 172
173 173 """
174 174 )
175 175 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
176 176 'leastload', config=True, allow_none=False,
177 177 help="""select the task scheduler scheme [default: Python LRU]
178 178 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin', 'leastload'"""
179 179 )
180 180 def _scheme_name_changed(self, old, new):
181 181 self.log.debug("Using scheme %r"%new)
182 182 self.scheme = globals()[new]
183 183
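    # A hedged configuration sketch (standard traitlets config syntax, e.g. in
    # an ipcontroller_config.py profile file, where `c` is the config object):
    #
    #     c.TaskScheduler.hwm = 0                 # no limit on outstanding tasks
    #     c.TaskScheduler.scheme_name = 'twobin'  # pick-two load balancing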
184 184 # input arguments:
185 185 scheme = Instance(FunctionType) # function for determining the destination
186 186 def _scheme_default(self):
187 187 return leastload
188 188 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
189 189 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
190 190 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
191 191 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
192 192 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
193 193
194 194 # internals:
195 195 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
196 196 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
197 197 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
198 198 depending = Dict() # dict by msg_id of Jobs
199 199 pending = Dict() # dict by engine_uuid of submitted tasks
200 200 completed = Dict() # dict by engine_uuid of completed tasks
201 201 failed = Dict() # dict by engine_uuid of failed tasks
202 202 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
203 203 clients = Dict() # dict by msg_id for who submitted the task
204 204 targets = List() # list of target IDENTs
205 205 loads = List() # list of engine loads
206 206 # full = Set() # set of IDENTs that have HWM outstanding tasks
207 207 all_completed = Set() # set of all completed tasks
208 208 all_failed = Set() # set of all failed tasks
209 209 all_done = Set() # set of all finished tasks=union(completed,failed)
210 210 all_ids = Set() # set of all submitted task IDs
211 211
212 212 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
213 213
214 214 ident = CBytes() # ZMQ identity. This should just be self.session.session
215 215 # but ensure Bytes
216 216 def _ident_default(self):
217 217 return self.session.bsession
218 218
219 219 def start(self):
220 220 self.query_stream.on_recv(self.dispatch_query_reply)
221 221 self.session.send(self.query_stream, "connection_request", {})
222 222
223 223 self.engine_stream.on_recv(self.dispatch_result, copy=False)
224 224 self.client_stream.on_recv(self.dispatch_submission, copy=False)
225 225
226 226 self._notification_handlers = dict(
227 227 registration_notification = self._register_engine,
228 228 unregistration_notification = self._unregister_engine
229 229 )
230 230 self.notifier_stream.on_recv(self.dispatch_notification)
231 231 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # every 2 seconds
232 232 self.auditor.start()
233 233 self.log.info("Scheduler started [%s]"%self.scheme_name)
234 234
235 235 def resume_receiving(self):
236 236 """Resume accepting jobs."""
237 237 self.client_stream.on_recv(self.dispatch_submission, copy=False)
238 238
239 239 def stop_receiving(self):
240 240 """Stop accepting jobs while there are no engines.
241 241 Leave them in the ZMQ queue."""
242 242 self.client_stream.on_recv(None)
243 243
244 244 #-----------------------------------------------------------------------
245 245 # [Un]Registration Handling
246 246 #-----------------------------------------------------------------------
247 247
248 248
249 249 def dispatch_query_reply(self, msg):
250 250 """handle reply to our initial connection request"""
251 251 try:
252 252 idents,msg = self.session.feed_identities(msg)
253 253 except ValueError:
254 254 self.log.warn("task::Invalid Message: %r",msg)
255 255 return
256 256 try:
257 257 msg = self.session.unserialize(msg)
258 258 except ValueError:
259 259 self.log.warn("task::Unauthorized message from: %r"%idents)
260 260 return
261 261
262 262 content = msg['content']
263 263 for uuid in content.get('engines', {}).values():
264 self._register_engine(asbytes(uuid))
264 self._register_engine(cast_bytes(uuid))
265 265
266 266
267 267 @util.log_errors
268 268 def dispatch_notification(self, msg):
269 269 """dispatch register/unregister events."""
270 270 try:
271 271 idents,msg = self.session.feed_identities(msg)
272 272 except ValueError:
273 273 self.log.warn("task::Invalid Message: %r",msg)
274 274 return
275 275 try:
276 276 msg = self.session.unserialize(msg)
277 277 except ValueError:
278 278 self.log.warn("task::Unauthorized message from: %r"%idents)
279 279 return
280 280
281 281 msg_type = msg['header']['msg_type']
282 282
283 283 handler = self._notification_handlers.get(msg_type, None)
284 284 if handler is None:
285 285 self.log.error("Unhandled message type: %r"%msg_type)
286 286 else:
287 287 try:
288 288 handler(cast_bytes(msg['content']['uuid']))
289 289 except Exception:
290 290 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
291 291
292 292 def _register_engine(self, uid):
293 293 """New engine with ident `uid` became available."""
294 294 # head of the line:
295 295 self.targets.insert(0,uid)
296 296 self.loads.insert(0,0)
297 297
298 298 # initialize sets
299 299 self.completed[uid] = set()
300 300 self.failed[uid] = set()
301 301 self.pending[uid] = {}
302 302
303 303 # rescan the graph:
304 304 self.update_graph(None)
305 305
306 306 def _unregister_engine(self, uid):
307 307 """Existing engine with ident `uid` became unavailable."""
308 308 if len(self.targets) == 1:
309 309 # this was our only engine
310 310 pass
311 311
312 312 # handle any potentially finished tasks:
313 313 self.engine_stream.flush()
314 314
315 315 # don't pop destinations, because they might be used later
316 316 # map(self.destinations.pop, self.completed.pop(uid))
317 317 # map(self.destinations.pop, self.failed.pop(uid))
318 318
319 319 # prevent this engine from receiving work
320 320 idx = self.targets.index(uid)
321 321 self.targets.pop(idx)
322 322 self.loads.pop(idx)
323 323
324 324 # wait 5 seconds before cleaning up pending jobs, since the results might
325 325 # still be incoming
326 326 if self.pending[uid]:
327 327 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
328 328 dc.start()
329 329 else:
330 330 self.completed.pop(uid)
331 331 self.failed.pop(uid)
332 332
333 333
334 334 def handle_stranded_tasks(self, engine):
335 335 """Deal with jobs resident in an engine that died."""
336 336 lost = self.pending[engine]
337 337 for msg_id in lost.keys():
338 338 if msg_id not in self.pending[engine]:
339 339 # prevent double-handling of messages
340 340 continue
341 341
342 342 raw_msg = lost[msg_id].raw_msg
343 343 idents,msg = self.session.feed_identities(raw_msg, copy=False)
344 344 parent = self.session.unpack(msg[1].bytes)
345 345 idents = [engine, idents[0]]
346 346
347 347 # build fake error reply
348 348 try:
349 349 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
350 350 except:
351 351 content = error.wrap_exception()
352 352 # build fake header
353 353 header = dict(
354 354 status='error',
355 355 engine=engine,
356 356 date=datetime.now(),
357 357 )
358 358 msg = self.session.msg('apply_reply', content, parent=parent, subheader=header)
359 359 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
360 360 # and dispatch it
361 361 self.dispatch_result(raw_reply)
362 362
363 363 # finally scrub completed/failed lists
364 364 self.completed.pop(engine)
365 365 self.failed.pop(engine)
366 366
367 367
368 368 #-----------------------------------------------------------------------
369 369 # Job Submission
370 370 #-----------------------------------------------------------------------
371 371
372 372
373 373 @util.log_errors
374 374 def dispatch_submission(self, raw_msg):
375 375 """Dispatch job submission to appropriate handlers."""
376 376 # ensure targets up to date:
377 377 self.notifier_stream.flush()
378 378 try:
379 379 idents, msg = self.session.feed_identities(raw_msg, copy=False)
380 380 msg = self.session.unserialize(msg, content=False, copy=False)
381 381 except Exception:
382 382 self.log.error("task::Invalid task msg: %r"%raw_msg, exc_info=True)
383 383 return
384 384
385 385
386 386 # send to monitor
387 387 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
388 388
389 389 header = msg['header']
390 390 msg_id = header['msg_id']
391 391 self.all_ids.add(msg_id)
392 392
393 393 # get targets as a set of bytes objects
394 394 # from a list of unicode objects
395 395 targets = header.get('targets', [])
396 396 targets = map(cast_bytes, targets)
397 397 targets = set(targets)
398 398
399 399 retries = header.get('retries', 0)
400 400 self.retries[msg_id] = retries
401 401
402 402 # time dependencies
403 403 after = header.get('after', None)
404 404 if after:
405 405 after = Dependency(after)
406 406 if after.all:
407 407 if after.success:
408 408 after = Dependency(after.difference(self.all_completed),
409 409 success=after.success,
410 410 failure=after.failure,
411 411 all=after.all,
412 412 )
413 413 if after.failure:
414 414 after = Dependency(after.difference(self.all_failed),
415 415 success=after.success,
416 416 failure=after.failure,
417 417 all=after.all,
418 418 )
419 419 if after.check(self.all_completed, self.all_failed):
420 420 # recast as empty set, if `after` already met,
421 421 # to prevent unnecessary set comparisons
422 422 after = MET
423 423 else:
424 424 after = MET
425 425
426 426 # location dependencies
427 427 follow = Dependency(header.get('follow', []))
428 428
429 429 # turn timeouts into datetime objects:
430 430 timeout = header.get('timeout', None)
431 431 if timeout:
432 432 # cast to float, because jsonlib returns floats as decimal.Decimal,
433 433 # which timedelta does not accept
434 434 timeout = datetime.now() + timedelta(0,float(timeout),0)
435 435
436 436 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
437 437 header=header, targets=targets, after=after, follow=follow,
438 438 timeout=timeout,
439 439 )
440 440
441 441 # validate and reduce dependencies:
442 442 for dep in after,follow:
443 443 if not dep: # empty dependency
444 444 continue
445 445 # check valid:
446 446 if msg_id in dep or dep.difference(self.all_ids):
447 447 self.depending[msg_id] = job
448 448 return self.fail_unreachable(msg_id, error.InvalidDependency)
449 449 # check if unreachable:
450 450 if dep.unreachable(self.all_completed, self.all_failed):
451 451 self.depending[msg_id] = job
452 452 return self.fail_unreachable(msg_id)
453 453
454 454 if after.check(self.all_completed, self.all_failed):
455 455 # time deps already met, try to run
456 456 if not self.maybe_run(job):
457 457 # can't run yet
458 458 if msg_id not in self.all_failed:
459 459 # could have failed as unreachable
460 460 self.save_unmet(job)
461 461 else:
462 462 self.save_unmet(job)
463 463
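    # For orientation, a sketch of the scheduler-relevant header fields
    # consumed by dispatch_submission above (values are illustrative):
    #
    #     header = {'msg_id'  : '...',          # unique task id
    #               'targets' : ['uuid', ...],  # restrict to these engines
    #               'after'   : [...],          # time dependencies (msg_ids)
    #               'follow'  : [...],          # location dependencies
    #               'timeout' : 30.0,           # seconds until TaskTimeout
    #               'retries' : 2}              # resubmission budget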
464 464 def audit_timeouts(self):
465 465 """Audit all waiting tasks for expired timeouts."""
466 466 now = datetime.now()
467 467 for msg_id in self.depending.keys():
468 468 # must recheck, in case one failure cascaded to another:
469 469 if msg_id in self.depending:
470 470 job = self.depending[msg_id]
471 471 if job.timeout and job.timeout < now:
472 472 self.fail_unreachable(msg_id, error.TaskTimeout)
473 473
474 474 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
475 475 """a task has become unreachable, send a reply with an ImpossibleDependency
476 476 error."""
477 477 if msg_id not in self.depending:
478 478 self.log.error("msg %r already failed!", msg_id)
479 479 return
480 480 job = self.depending.pop(msg_id)
481 481 for mid in job.dependents:
482 482 if mid in self.graph:
483 483 self.graph[mid].remove(msg_id)
484 484
485 485 try:
486 486 raise why()
487 487 except:
488 488 content = error.wrap_exception()
489 489
490 490 self.all_done.add(msg_id)
491 491 self.all_failed.add(msg_id)
492 492
493 493 msg = self.session.send(self.client_stream, 'apply_reply', content,
494 494 parent=job.header, ident=job.idents)
495 495 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
496 496
497 497 self.update_graph(msg_id, success=False)
498 498
499 499 def maybe_run(self, job):
500 500 """check location dependencies, and run if they are met."""
501 501 msg_id = job.msg_id
502 502 self.log.debug("Attempting to assign task %s", msg_id)
503 503 if not self.targets:
504 504 # no engines, definitely can't run
505 505 return False
506 506
507 507 if job.follow or job.targets or job.blacklist or self.hwm:
508 508 # we need a can_run filter
509 509 def can_run(idx):
510 510 # check hwm
511 511 if self.hwm and self.loads[idx] == self.hwm:
512 512 return False
513 513 target = self.targets[idx]
514 514 # check blacklist
515 515 if target in job.blacklist:
516 516 return False
517 517 # check targets
518 518 if job.targets and target not in job.targets:
519 519 return False
520 520 # check follow
521 521 return job.follow.check(self.completed[target], self.failed[target])
522 522
523 523 indices = filter(can_run, range(len(self.targets)))
524 524
525 525 if not indices:
526 526 # couldn't run
527 527 if job.follow.all:
528 528 # check follow for impossibility
529 529 dests = set()
530 530 relevant = set()
531 531 if job.follow.success:
532 532 relevant = self.all_completed
533 533 if job.follow.failure:
534 534 relevant = relevant.union(self.all_failed)
535 535 for m in job.follow.intersection(relevant):
536 536 dests.add(self.destinations[m])
537 537 if len(dests) > 1:
538 538 self.depending[msg_id] = job
539 539 self.fail_unreachable(msg_id)
540 540 return False
541 541 if job.targets:
542 542 # check blacklist+targets for impossibility
543 543 job.targets.difference_update(job.blacklist)
544 544 if not job.targets or not job.targets.intersection(self.targets):
545 545 self.depending[msg_id] = job
546 546 self.fail_unreachable(msg_id)
547 547 return False
548 548 return False
549 549 else:
550 550 indices = None
551 551
552 552 self.submit_task(job, indices)
553 553 return True
554 554
555 555 def save_unmet(self, job):
556 556 """Save a message for later submission when its dependencies are met."""
557 557 msg_id = job.msg_id
558 558 self.depending[msg_id] = job
559 559 # track the ids in follow or after, but not those already finished
560 560 for dep_id in job.after.union(job.follow).difference(self.all_done):
561 561 if dep_id not in self.graph:
562 562 self.graph[dep_id] = set()
563 563 self.graph[dep_id].add(msg_id)
564 564
565 565 def submit_task(self, job, indices=None):
566 566 """Submit a task to any of a subset of our targets."""
567 567 if indices:
568 568 loads = [self.loads[i] for i in indices]
569 569 else:
570 570 loads = self.loads
571 571 idx = self.scheme(loads)
572 572 if indices:
573 573 idx = indices[idx]
574 574 target = self.targets[idx]
575 575 # print (target, map(str, msg[:3]))
576 576 # send job to the engine
577 577 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
578 578 self.engine_stream.send_multipart(job.raw_msg, copy=False)
579 579 # update load
580 580 self.add_job(idx)
581 581 self.pending[target][job.msg_id] = job
582 582 # notify Hub
583 583 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
584 584 self.session.send(self.mon_stream, 'task_destination', content=content,
585 585 ident=[b'tracktask',self.ident])
586 586
587 587
588 588 #-----------------------------------------------------------------------
589 589 # Result Handling
590 590 #-----------------------------------------------------------------------
591 591
592 592
593 593 @util.log_errors
594 594 def dispatch_result(self, raw_msg):
595 595 """dispatch method for result replies"""
596 596 try:
597 597 idents,msg = self.session.feed_identities(raw_msg, copy=False)
598 598 msg = self.session.unserialize(msg, content=False, copy=False)
599 599 engine = idents[0]
600 600 try:
601 601 idx = self.targets.index(engine)
602 602 except ValueError:
603 603 pass # skip load-update for dead engines
604 604 else:
605 605 self.finish_job(idx)
606 606 except Exception:
607 607 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
608 608 return
609 609
610 610 header = msg['header']
611 611 parent = msg['parent_header']
612 612 if header.get('dependencies_met', True):
613 613 success = (header['status'] == 'ok')
614 614 msg_id = parent['msg_id']
615 615 retries = self.retries[msg_id]
616 616 if not success and retries > 0:
617 617 # failed
618 618 self.retries[msg_id] = retries - 1
619 619 self.handle_unmet_dependency(idents, parent)
620 620 else:
621 621 del self.retries[msg_id]
622 622 # relay to client and update graph
623 623 self.handle_result(idents, parent, raw_msg, success)
624 624 # send to Hub monitor
625 625 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
626 626 else:
627 627 self.handle_unmet_dependency(idents, parent)
628 628
629 629 def handle_result(self, idents, parent, raw_msg, success=True):
630 630 """handle a real task result, either success or failure"""
631 631 # first, relay result to client
632 632 engine = idents[0]
633 633 client = idents[1]
634 634 # swap_ids for ROUTER-ROUTER mirror
635 635 raw_msg[:2] = [client,engine]
636 636 # print (map(str, raw_msg[:4]))
637 637 self.client_stream.send_multipart(raw_msg, copy=False)
638 638 # now, update our data structures
639 639 msg_id = parent['msg_id']
640 640 self.pending[engine].pop(msg_id)
641 641 if success:
642 642 self.completed[engine].add(msg_id)
643 643 self.all_completed.add(msg_id)
644 644 else:
645 645 self.failed[engine].add(msg_id)
646 646 self.all_failed.add(msg_id)
647 647 self.all_done.add(msg_id)
648 648 self.destinations[msg_id] = engine
649 649
650 650 self.update_graph(msg_id, success)
651 651
652 652 def handle_unmet_dependency(self, idents, parent):
653 653 """handle an unmet dependency"""
654 654 engine = idents[0]
655 655 msg_id = parent['msg_id']
656 656
657 657 job = self.pending[engine].pop(msg_id)
658 658 job.blacklist.add(engine)
659 659
660 660 if job.blacklist == job.targets:
661 661 self.depending[msg_id] = job
662 662 self.fail_unreachable(msg_id)
663 663 elif not self.maybe_run(job):
664 664 # resubmit failed
665 665 if msg_id not in self.all_failed:
666 666 # put it back in our dependency tree
667 667 self.save_unmet(job)
668 668
669 669 if self.hwm:
670 670 try:
671 671 idx = self.targets.index(engine)
672 672 except ValueError:
673 673 pass # skip load-update for dead engines
674 674 else:
675 675 if self.loads[idx] == self.hwm-1:
676 676 self.update_graph(None)
677 677
678 678
679 679
680 680 def update_graph(self, dep_id=None, success=True):
681 681 """dep_id just finished. Update our dependency
682 682 graph and submit any jobs that just became runnable.
683 683
684 684 Called with dep_id=None to update entire graph for hwm, but without finishing
685 685 a task.
686 686 """
687 687 # print ("\n\n***********")
688 688 # pprint (dep_id)
689 689 # pprint (self.graph)
690 690 # pprint (self.depending)
691 691 # pprint (self.all_completed)
692 692 # pprint (self.all_failed)
693 693 # print ("\n\n***********\n\n")
694 694 # update any jobs that depended on the dependency
695 695 jobs = self.graph.pop(dep_id, [])
696 696
697 697 # recheck *all* jobs if
698 698 # a) we have HWM and an engine that was full just freed a slot
699 699 # or b) dep_id was given as None
700 700
701 701 if dep_id is None or (self.hwm and any(load == self.hwm - 1 for load in self.loads)):
702 702 jobs = self.depending.keys()
703 703
704 704 for msg_id in sorted(jobs, key=lambda msg_id: self.depending[msg_id].timestamp):
705 705 job = self.depending[msg_id]
706 706
707 707 if job.after.unreachable(self.all_completed, self.all_failed)\
708 708 or job.follow.unreachable(self.all_completed, self.all_failed):
709 709 self.fail_unreachable(msg_id)
710 710
711 711 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
712 712 if self.maybe_run(job):
713 713
714 714 self.depending.pop(msg_id)
715 715 for mid in job.dependents:
716 716 if mid in self.graph:
717 717 self.graph[mid].remove(msg_id)
718 718
719 719 #----------------------------------------------------------------------
720 720 # methods to be overridden by subclasses
721 721 #----------------------------------------------------------------------
722 722
723 723 def add_job(self, idx):
724 724 """Called after self.targets[idx] just got the job with header.
725 725 Override in subclasses. The default ordering is simple LRU.
726 726 The default loads are the number of outstanding jobs."""
727 727 self.loads[idx] += 1
728 728 for lis in (self.targets, self.loads):
729 729 lis.append(lis.pop(idx))
730 730
731 731
732 732 def finish_job(self, idx):
733 733 """Called after self.targets[idx] just finished a job.
734 734 Override in subclasses."""
735 735 self.loads[idx] -= 1
736 736
737 737
738 738
739 739 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
740 740 logname='root', log_url=None, loglevel=logging.DEBUG,
741 741 identity=b'task', in_thread=False):
742 742
743 743 ZMQStream = zmqstream.ZMQStream
744 744
745 745 if config:
746 746 # unwrap dict back into Config
747 747 config = Config(config)
748 748
749 749 if in_thread:
750 750 # use instance() to get the same Context/Loop as our parent
751 751 ctx = zmq.Context.instance()
752 752 loop = ioloop.IOLoop.instance()
753 753 else:
754 754 # in a process, don't use instance()
755 755 # for safety with multiprocessing
756 756 ctx = zmq.Context()
757 757 loop = ioloop.IOLoop()
758 758 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
759 759 ins.setsockopt(zmq.IDENTITY, identity + b'_in')
760 760 ins.bind(in_addr)
761 761
762 762 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
763 763 outs.setsockopt(zmq.IDENTITY, identity + b'_out')
764 764 outs.bind(out_addr)
765 765 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
766 766 mons.connect(mon_addr)
767 767 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
768 768 nots.setsockopt(zmq.SUBSCRIBE, b'')
769 769 nots.connect(not_addr)
770 770
771 771 querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
772 772 querys.connect(reg_addr)
773 773
774 774 # setup logging.
775 775 if in_thread:
776 776 log = Application.instance().log
777 777 else:
778 778 if log_url:
779 779 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
780 780 else:
781 781 log = local_logger(logname, loglevel)
782 782
783 783 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
784 784 mon_stream=mons, notifier_stream=nots,
785 785 query_stream=querys,
786 786 loop=loop, log=log,
787 787 config=config)
788 788 scheduler.start()
789 789 if not in_thread:
790 790 try:
791 791 loop.start()
792 792 except KeyboardInterrupt:
793 793 scheduler.log.critical("Interrupted, exiting...")
794 794