##// END OF EJS Templates
avoid race condition when deleting/starting sessions...
Avoid race condition when deleting/starting sessions. JavaScript doesn't guarantee the order of AJAX requests, so we give `Session.delete` and `Kernel.kill` a callback signature. Changing the kernel type calls `Notebook.start_kernel`, which terminates the previous session (if defined) before starting the new one. A flag is stored to prevent multiple simultaneous attempts to start sessions; a concurrent attempt raises a SessionAlreadyStarting error, which prevents the spec_changed event from firing.

File last commit:

r17264:506c6033
r17649:fb1ac74c
Show More
hub.py
1449 lines | 53.5 KiB | text/x-python | PythonLexer
"""The IPython Controller Hub with 0MQ
This is the master object that handles connections from engines and clients,
and monitors traffic through the various queues.
"""
# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.
from __future__ import print_function
import json
import os
import sys
import time
from datetime import datetime
import zmq
from zmq.eventloop import ioloop
from zmq.eventloop.zmqstream import ZMQStream
# internal:
from IPython.utils.importstring import import_item
from IPython.utils.jsonutil import extract_dates
from IPython.utils.localinterfaces import localhost
from IPython.utils.py3compat import cast_bytes, unicode_type, iteritems
from IPython.utils.traitlets import (
HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
)
from IPython.parallel import error, util
from IPython.parallel.factory import RegistrationFactory
from IPython.kernel.zmq.session import SessionFactory
from .heartmonitor import HeartMonitor
#-----------------------------------------------------------------------------
# Code
#-----------------------------------------------------------------------------
def _passer(*args, **kwargs):
return
def _printer(*args, **kwargs):
print (args)
print (kwargs)
def empty_record():
    """Return a fresh record dict containing every record key.

    Every key maps to ``None`` except the captured output streams,
    ``stdout`` and ``stderr``, which start as empty strings.  A new dict
    is built on each call.
    """
    null_keys = (
        'msg_id', 'header', 'metadata', 'content', 'buffers', 'submitted',
        'client_uuid', 'engine_uuid', 'started', 'completed', 'resubmitted',
        'received', 'result_header', 'result_metadata', 'result_content',
        'result_buffers', 'queue', 'execute_input', 'execute_result', 'error',
    )
    blank = dict.fromkeys(null_keys)
    blank['stdout'] = ''
    blank['stderr'] = ''
    return blank
def init_record(msg):
    """Build a new TaskRecord dict from a request message.

    The request's msg_id, header, content, metadata and buffers are copied
    in, and ``submitted`` is taken from the header's ``date``.  All result
    and bookkeeping fields start as ``None``; ``stdout``/``stderr`` start
    as empty strings.
    """
    hdr = msg['header']
    record = dict(
        msg_id=hdr['msg_id'],
        header=hdr,
        content=msg['content'],
        metadata=msg['metadata'],
        buffers=msg['buffers'],
        submitted=hdr['date'],
    )
    for unset in ('client_uuid', 'engine_uuid', 'started', 'completed',
                  'resubmitted', 'received', 'result_header', 'result_metadata',
                  'result_content', 'result_buffers', 'queue', 'execute_input',
                  'execute_result', 'error'):
        record[unset] = None
    record['stdout'] = record['stderr'] = ''
    return record
class EngineConnector(HasTraits):
    """Bookkeeping record for one engine's registration with the Hub.

    Attributes are:
    id (int): engine ID
    uuid (unicode): engine UUID (its bytes form is also the heart identity)
    pending: set of msg_ids
    stallback: DelayedCallback that purges the registration if the engine's
        heart does not start beating; only set when the heart was not
        already beating at registration time (stopped on completion)
    """
    id = Integer(0)
    uuid = Unicode()
    pending = Set()  # NOTE(review): not used in the visible part of this file
    stallback = Instance(ioloop.DelayedCallback)
# Short aliases for the built-in DB backends, mapped to their dotted import
# paths.  HubFactory.init_hub looks up ``db_class.lower()`` here and falls back
# to the configured value unchanged, so a full dotted path also works.
_db_shortcuts = {
    'sqlitedb' : 'IPython.parallel.controller.sqlitedb.SQLiteDB',
    'mongodb'  : 'IPython.parallel.controller.mongodb.MongoDB',
    'dictdb'   : 'IPython.parallel.controller.dictdb.DictDB',
    'nodb'     : 'IPython.parallel.controller.dictdb.NoDB',
}
class HubFactory(RegistrationFactory):
    """The Configurable for setting up a Hub.

    Holds the ports, interfaces and transports used by the controller, and
    builds the Hub (sockets, HeartMonitor and DB backend) in init_hub().
    """

    # port-pairs for monitoredqueues:
    hb = Tuple(Integer,Integer,config=True,
        help="""PUB/ROUTER Port pair for Engine heartbeats""")

    def _hb_default(self):
        return tuple(util.select_random_ports(2))

    mux = Tuple(Integer,Integer,config=True,
        help="""Client/Engine Port pair for MUX queue""")

    def _mux_default(self):
        return tuple(util.select_random_ports(2))

    task = Tuple(Integer,Integer,config=True,
        help="""Client/Engine Port pair for Task queue""")

    def _task_default(self):
        return tuple(util.select_random_ports(2))

    control = Tuple(Integer,Integer,config=True,
        help="""Client/Engine Port pair for Control queue""")

    def _control_default(self):
        return tuple(util.select_random_ports(2))

    iopub = Tuple(Integer,Integer,config=True,
        help="""Client/Engine Port pair for IOPub relay""")

    def _iopub_default(self):
        return tuple(util.select_random_ports(2))

    # single ports:
    mon_port = Integer(config=True,
        help="""Monitor (SUB) port for queue traffic""")

    def _mon_port_default(self):
        return util.select_random_ports(1)[0]

    notifier_port = Integer(config=True,
        help="""PUB port for sending engine status notifications""")

    def _notifier_port_default(self):
        return util.select_random_ports(1)[0]

    engine_ip = Unicode(config=True,
        help="IP on which to listen for engine connections. [default: loopback]")

    def _engine_ip_default(self):
        return localhost()

    engine_transport = Unicode('tcp', config=True,
        help="0MQ transport for engine connections. [default: tcp]")

    client_ip = Unicode(config=True,
        help="IP on which to listen for client connections. [default: loopback]")
    client_transport = Unicode('tcp', config=True,
        help="0MQ transport for client connections. [default : tcp]")

    monitor_ip = Unicode(config=True,
        help="IP on which to listen for monitor messages. [default: loopback]")
    monitor_transport = Unicode('tcp', config=True,
        help="0MQ transport for monitor messages. [default : tcp]")

    # client and monitor IPs default to the same loopback address as engines
    _client_ip_default = _monitor_ip_default = _engine_ip_default

    # full URL of the monitor SUB socket; rebuilt by _update_monitor_url
    monitor_url = Unicode('')

    db_class = DottedObjectName('NoDB',
        config=True, help="""The class to use for the DB backend
Options include:
SQLiteDB: SQLite
MongoDB : use MongoDB
DictDB : in-memory storage (fastest, but be mindful of memory growth of the Hub)
NoDB : disable database altogether (default)
""")

    registration_timeout = Integer(0, config=True,
        help="Engine registration timeout in seconds [default: max(30,"
            "10*heartmonitor.period)]" )

    def _registration_timeout_default(self):
        if self.heartmonitor is None:
            # early initialization, this value will be ignored
            return 0
        # heartmonitor period is in milliseconds, so 10x in seconds is .01
        return max(30, int(.01 * self.heartmonitor.period))

    # not configurable
    db = Instance('IPython.parallel.controller.dictdb.BaseDB')
    heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')

    def _ip_changed(self, name, old, new):
        """Change handler for ``ip``: apply it to the engine, client and monitor IPs."""
        self.engine_ip = new
        self.client_ip = new
        self.monitor_ip = new
        self._update_monitor_url()

    def _update_monitor_url(self):
        """Rebuild ``monitor_url`` from the monitor transport, IP and port."""
        self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)

    def _transport_changed(self, name, old, new):
        """Change handler for ``transport``: apply it to engine, client and monitor."""
        self.engine_transport = new
        self.client_transport = new
        self.monitor_transport = new
        self._update_monitor_url()

    def __init__(self, **kwargs):
        super(HubFactory, self).__init__(**kwargs)
        self._update_monitor_url()

    def construct(self):
        """Build the Hub (see init_hub)."""
        self.init_hub()

    def start(self):
        """Start the heartmonitor created by init_hub."""
        self.heartmonitor.start()
        self.log.info("Heartmonitor started")

    def client_url(self, channel):
        """return full zmq url for a named client channel"""
        return "%s://%s:%i" % (self.client_transport, self.client_ip, self.client_info[channel])

    def engine_url(self, channel):
        """return full zmq url for a named engine channel"""
        return "%s://%s:%i" % (self.engine_transport, self.engine_ip, self.engine_info[channel])

    def init_hub(self):
        """construct Hub object

        Builds the engine/client connection dicts, binds the registration,
        heartbeat, notification and monitor sockets, connects the DB backend
        and the resubmit stream, and stores the resulting Hub on ``self.hub``.
        """
        ctx = self.context
        loop = self.loop
        if 'TaskScheduler.scheme_name' in self.config:
            scheme = self.config.TaskScheduler.scheme_name
        else:
            from .scheduler import TaskScheduler
            scheme = TaskScheduler.scheme_name.get_default_value()

        # build connection dicts
        # (for each port pair, index 0 is the client side and 1 the engine side)
        self.engine_info = {
            'interface'     : "%s://%s" % (self.engine_transport, self.engine_ip),
            'registration'  : self.regport,
            'control'       : self.control[1],
            'mux'           : self.mux[1],
            'hb_ping'       : self.hb[0],
            'hb_pong'       : self.hb[1],
            'task'          : self.task[1],
            'iopub'         : self.iopub[1],
            }

        self.client_info = {
            'interface'     : "%s://%s" % (self.client_transport, self.client_ip),
            'registration'  : self.regport,
            'control'       : self.control[0],
            'mux'           : self.mux[0],
            'task'          : self.task[0],
            'task_scheme'   : scheme,
            'iopub'         : self.iopub[0],
            'notification'  : self.notifier_port,
            }

        self.log.debug("Hub engine addrs: %s", self.engine_info)
        self.log.debug("Hub client addrs: %s", self.client_info)

        # Registrar socket
        q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
        util.set_hwm(q, 0)
        client_reg_url = self.client_url('registration')
        engine_reg_url = self.engine_url('registration')
        q.bind(client_reg_url)
        self.log.info("Hub listening on %s for registration.", client_reg_url)
        # engine_info advertises engine_transport/engine_ip, which may differ
        # from the client side in IP *or* transport: bind the engine-facing
        # URL as well unless it is exactly the same endpoint.
        if engine_reg_url != client_reg_url:
            q.bind(engine_reg_url)
            self.log.info("Hub listening on %s for registration.", engine_reg_url)

        ### Engine connections ###

        # heartbeat
        hpub = ctx.socket(zmq.PUB)
        hpub.bind(self.engine_url('hb_ping'))
        hrep = ctx.socket(zmq.ROUTER)
        util.set_hwm(hrep, 0)
        hrep.bind(self.engine_url('hb_pong'))
        self.heartmonitor = HeartMonitor(loop=loop, parent=self, log=self.log,
                                pingstream=ZMQStream(hpub,loop),
                                pongstream=ZMQStream(hrep,loop)
                            )

        ### Client connections ###

        # Notifier socket
        n = ZMQStream(ctx.socket(zmq.PUB), loop)
        n.bind(self.client_url('notification'))

        ### build and launch the queues ###

        # monitor socket
        sub = ctx.socket(zmq.SUB)
        sub.setsockopt(zmq.SUBSCRIBE, b"")
        sub.bind(self.monitor_url)
        sub.bind('inproc://monitor')
        sub = ZMQStream(sub, loop)

        # connect the db
        db_class = _db_shortcuts.get(self.db_class.lower(), self.db_class)
        self.log.info('Hub using DB backend: %r', (db_class.split('.')[-1]))
        self.db = import_item(str(db_class))(session=self.session.session,
                                            parent=self, log=self.log)
        # NOTE(review): unexplained pause; presumably lets freshly bound
        # sockets settle before the Hub starts - confirm before removing.
        time.sleep(.25)

        # resubmit stream
        r = ZMQStream(ctx.socket(zmq.DEALER), loop)
        url = util.disambiguate_url(self.client_url('task'))
        r.connect(url)

        # convert seconds to msec
        registration_timeout = 1000*self.registration_timeout

        self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
                query=q, notifier=n, resubmit=r, db=self.db,
                engine_info=self.engine_info, client_info=self.client_info,
                log=self.log, registration_timeout=registration_timeout)
class Hub(SessionFactory):
    """The IPython Controller Hub with 0MQ connections

    Parameters
    ==========
    loop: zmq IOLoop instance
    session: Session object
    monitor: ZMQStream (SUB) receiving queue and IOPub monitor traffic
    query: ZMQStream for engine registration and client queries requests (ROUTER)
    heartmonitor: HeartMonitor object checking the pulse of the engines
    notifier: ZMQStream for broadcasting engine registration changes (PUB)
    resubmit: ZMQStream (DEALER, connected to the client task queue) used for
        resubmitting tasks; its replies are ignored
    db: DB backend object used to store task records
    engine_info: dict of zmq connection information for engines to connect
        to the queues.
    client_info: dict of zmq connection information for clients to connect
        to the queues.
    registration_timeout: delay in milliseconds used for purging stalled
        registrations and for handling stranded messages
    """

    engine_state_file = Unicode()  # JSON file for saving/restoring engines ('' disables saving)

    # internal data structures:
    ids=Set() # engine IDs
    keytable=Dict()  # engine id -> engine uuid
    by_ident=Dict()  # engine uuid (bytes) -> engine id
    engines=Dict()  # engine id -> EngineConnector
    clients=Dict()  # NOTE(review): not used in the visible part of this file
    hearts=Dict()  # heart identity (bytes) -> engine id
    pending=Set()  # msg_ids submitted but not yet completed
    queues=Dict() # pending MUX msg_ids keyed by engine_id
    tasks=Dict() # pending msg_ids submitted as tasks, keyed by engine_id
    completed=Dict() # completed msg_ids keyed by engine_id
    all_completed=Set() # all completed msg_ids, from any engine
    dead_engines=Set() # uuids of engines that have unregistered or died
    unassigned=Set() # set of task msg_ids not yet assigned a destination
    incoming_registrations=Dict()  # heart identity (bytes) -> EngineConnector awaiting a heartbeat
    registration_timeout=Integer()  # milliseconds (converted by HubFactory.init_hub)
    _idcounter=Integer(0)  # next engine id to hand out (ids are never reused)

    # objects from constructor:
    query=Instance(ZMQStream)
    monitor=Instance(ZMQStream)
    notifier=Instance(ZMQStream)
    resubmit=Instance(ZMQStream)
    heartmonitor=Instance(HeartMonitor)
    db=Instance(object)
    client_info=Dict()
    engine_info=Dict()
def __init__(self, **kwargs):
    """Create the Hub and attach its handlers to the streams.

    Keyword arguments are the traits declared on the class; those used here:

    # universal:
    loop: IOLoop for creating future connections
    session: Session used to unserialize and send messages
    # engine:
    monitor: ZMQStream receiving queue/IOPub monitor traffic
    query: ZMQStream for engine+client registration and client requests
    heartmonitor: HeartMonitor object for tracking engines
    # extra:
    resubmit: ZMQStream used for resubmitting tasks (replies are dropped)
    db: DB backend storing task records
    engine_info: zmq address/protocol dict for engine connections
    client_info: zmq address/protocol dict for client connections
    """
    super(Hub, self).__init__(**kwargs)

    # stream and heartbeat callbacks
    self.query.on_recv(self.dispatch_query)
    self.monitor.on_recv(self.dispatch_monitor_traffic)
    hm = self.heartmonitor
    hm.add_heart_failure_handler(self.handle_heart_failure)
    hm.add_new_heart_handler(self.handle_new_heart)

    # monitor topic (bytes) -> handler(idents, msg)
    self.monitor_handlers = {
        b'in': self.save_queue_request,
        b'out': self.save_queue_result,
        b'intask': self.save_task_request,
        b'outtask': self.save_task_result,
        b'tracktask': self.save_task_destination,
        b'incontrol': _passer,
        b'outcontrol': _passer,
        b'iopub': self.save_iopub_message,
    }

    # query msg_type -> handler(idents, msg)
    self.query_handlers = dict(
        queue_request=self.queue_status,
        result_request=self.get_results,
        history_request=self.get_history,
        db_request=self.db_query,
        purge_request=self.purge_results,
        load_request=self.check_load,
        resubmit_request=self.resubmit_task,
        shutdown_request=self.shutdown_request,
        registration_request=self.register_engine,
        unregistration_request=self.unregister_engine,
        connection_request=self.connection_request,
    )

    # replies to resubmitted tasks are not needed: drop them
    self.resubmit.on_recv(lambda msg: None, copy=False)

    self.log.info("hub::created hub")
@property
def _next_id(self):
    """Generate a new engine ID.

    IDs are no longer reused: each access returns the current counter
    value and increments it, so reading this property has a side effect.
    """
    newid = self._idcounter
    self._idcounter += 1
    return newid
#-----------------------------------------------------------------------------
# message validation
#-----------------------------------------------------------------------------
def _validate_targets(self, targets):
    """Turn any valid targets argument into a list of integer ids.

    ``targets`` may be None (all registered engines), a single engine id or
    identity, or an iterable of them.  Identities may be text or bytes and
    are mapped to ids through ``by_ident``; anything else is kept as-is.

    Returns ``self.ids`` itself when targets is None, otherwise a list.
    Raises IndexError if any target is not a registered engine id, or if the
    resulting list is empty.
    """
    if targets is None:
        # default to all
        return self.ids

    # a single id or identity: wrap it.  bytes must be listed explicitly,
    # otherwise on Python 3 a bytes identity is iterated as a list of ints.
    if isinstance(targets, (int, bytes, str, unicode_type)):
        targets = [targets]

    _targets = []
    for t in targets:
        # map raw identities (text or bytes) to ids
        if isinstance(t, (bytes, str, unicode_type)):
            t = self.by_ident.get(cast_bytes(t), t)
        _targets.append(t)
    targets = _targets

    bad_targets = [ t for t in targets if t not in self.ids ]
    if bad_targets:
        raise IndexError("No Such Engine: %r" % bad_targets)
    if not targets:
        raise IndexError("No Engines Registered")
    return targets
#-----------------------------------------------------------------------------
# dispatch methods (1 per stream)
#-----------------------------------------------------------------------------
@util.log_errors
def dispatch_monitor_traffic(self, msg):
    """Route one monitor message to the handler registered for its topic.

    All ME and Task queue messages come through here, as well as IOPub
    traffic.  ``msg[0]`` is the topic; the identity prefix of the remaining
    frames is split off and passed to the handler as ``idents``.
    """
    topic = msg[0]
    self.log.debug("monitor traffic: %r", topic)
    try:
        idents, msg = self.session.feed_identities(msg[1:])
    except ValueError:
        idents = []
    if not idents:
        self.log.error("Monitor message without topic: %r", msg)
        return
    handler = self.monitor_handlers.get(topic)
    if handler is None:
        self.log.error("Unrecognized monitor topic: %r", topic)
        return
    handler(idents, msg)
@util.log_errors
def dispatch_query(self, msg):
    """Route registration requests and queries from clients.

    Sends a ``hub_error`` reply if the message cannot be unserialized or its
    msg_type has no handler; otherwise calls the handler registered in
    ``self.query_handlers`` with ``(idents, msg)``.
    """
    try:
        idents, msg = self.session.feed_identities(msg)
    except ValueError:
        idents = []
    if not idents:
        self.log.error("Bad Query Message: %r", msg)
        return
    client_id = idents[0]
    try:
        msg = self.session.unserialize(msg, content=True)
    except Exception:
        content = error.wrap_exception()
        self.log.error("Bad Query Message: %r", msg, exc_info=True)
        self.session.send(self.query, "hub_error", ident=client_id,
                content=content)
        return

    # switch on message type:
    msg_type = msg['header']['msg_type']
    self.log.info("client::client %r requested %r", client_id, msg_type)
    handler = self.query_handlers.get(msg_type, None)
    if handler is None:
        # Raise explicitly instead of using ``assert`` (stripped under
        # ``python -O``, which led to calling None).  wrap_exception() needs
        # an active exception; AssertionError keeps the error name clients saw.
        try:
            raise AssertionError("Bad Message Type: %r" % msg_type)
        except AssertionError:
            content = error.wrap_exception()
            self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
            self.session.send(self.query, "hub_error", ident=client_id,
                    content=content)
        return
    handler(idents, msg)
def dispatch_db(self, msg):
    """Placeholder dispatcher for a db stream; always raises NotImplementedError."""
    raise NotImplementedError
#---------------------------------------------------------------------------
# handler methods (1 per event)
#---------------------------------------------------------------------------
#----------------------- Heartbeat --------------------------------------
def handle_new_heart(self, heart):
    """Heartbeat handler, called when a new heart starts to beat.

    Completes the pending registration for this heart, if there is one;
    otherwise the heart is ignored.
    """
    self.log.debug("heartbeat::handle_new_heart(%r)", heart)
    if heart in self.incoming_registrations:
        self.finish_registration(heart)
        return
    self.log.info("heartbeat::ignoring new heart: %r", heart)
def handle_heart_failure(self, heart):
    """Heartbeat handler, called when a registered heart stops responding.

    Unregisters the corresponding engine, unless the heart does not belong
    to a known engine or that engine is already marked dead.
    """
    self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
    eid = self.hearts.get(heart, None)
    # check before touching self.engines: an unknown heart (eid None) used
    # to raise KeyError here instead of being ignored
    if eid is None or self.keytable[eid] in self.dead_engines:
        self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
        return
    uuid = self.engines[eid].uuid
    self.unregister_engine(heart, dict(content=dict(id=eid, queue=uuid)))
#----------------------- MUX Queue Traffic ------------------------------
def save_queue_request(self, idents, msg):
    """Record a request submitted to an engine's MUX queue.

    ``idents`` starts with (engine identity, client identity).  The request
    is written to the db, merged with any record IOPub traffic may already
    have created, and is then tracked in ``pending`` and the engine's queue.
    DB errors are logged and do not prevent tracking.
    """
    if len(idents) < 2:
        self.log.error("invalid identity prefix: %r", idents)
        return
    queue_id, client_id = idents[:2]
    try:
        msg = self.session.unserialize(msg)
    except Exception:
        self.log.error("queue::client %r sent invalid message to %r: %r", client_id, queue_id, msg, exc_info=True)
        return

    eid = self.by_ident.get(queue_id, None)
    if eid is None:
        self.log.error("queue::target %r not registered", queue_id)
        self.log.debug("queue:: valid are: %r", self.by_ident.keys())
        return
    record = init_record(msg)
    msg_id = record['msg_id']
    self.log.info("queue::client %r submitted request %r to %s", client_id, msg_id, eid)
    # Unicode in records
    record['engine_uuid'] = queue_id.decode('ascii')
    record['client_uuid'] = msg['header']['session']
    record['queue'] = 'mux'

    try:
        # it's possible iopub arrived first:
        existing = self.db.get_record(msg_id)
    except KeyError:
        try:
            self.db.add_record(msg_id, record)
        except Exception:
            self.log.error("DB Error adding record %r", msg_id, exc_info=True)
    except Exception:
        # don't let a db failure stop us tracking the request
        # (mirrors save_task_request)
        self.log.error("DB Error saving queue request %r", msg_id, exc_info=True)
    else:
        for key, evalue in iteritems(existing):
            rvalue = record.get(key, None)
            if evalue and rvalue and evalue != rvalue:
                self.log.warning("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
            elif evalue and not rvalue:
                record[key] = evalue
        try:
            self.db.update_record(msg_id, record)
        except Exception:
            self.log.error("DB Error updating record %r", msg_id, exc_info=True)

    self.pending.add(msg_id)
    self.queues[eid].append(msg_id)
def save_queue_result(self, idents, msg):
    """Record the reply to a MUX queue request.

    ``idents`` starts with (client identity, engine identity).  A request
    still in ``pending`` is moved to the engine's completed list.  The db
    record is updated both for such requests and for ones already in
    ``all_completed`` (the engine's unregistration may have been premature);
    replies to msg_ids the Hub has not seen are logged and dropped.
    """
    if len(idents) < 2:
        self.log.error("invalid identity prefix: %r", idents)
        return
    client_id, queue_id = idents[:2]
    try:
        msg = self.session.unserialize(msg)
    except Exception:
        self.log.error("queue::engine %r sent invalid message to %r: %r",
                queue_id, client_id, msg, exc_info=True)
        return

    eid = self.by_ident.get(queue_id, None)
    if eid is None:
        self.log.error("queue::unknown engine %r is sending a reply: ", queue_id)
        return

    parent = msg['parent_header']
    if not parent:
        return
    msg_id = parent['msg_id']

    if msg_id in self.pending:
        self.pending.remove(msg_id)
        self.all_completed.add(msg_id)
        self.queues[eid].remove(msg_id)
        self.completed[eid].append(msg_id)
        self.log.info("queue::request %r completed on %s", msg_id, eid)
    elif msg_id not in self.all_completed:
        # possibly a result from an engine that died before delivering it
        self.log.warn("queue:: unknown msg finished %r", msg_id)
        return

    # update record anyway, because the unregistration could have been premature
    rheader = msg['header']
    md = msg['metadata']
    completed = rheader['date']
    started = extract_dates(md.get('started', None))
    result = dict(
        result_header=rheader,
        result_metadata=md,
        result_content=msg['content'],
        received=datetime.now(),
        started=started,
        completed=completed,
    )
    result['result_buffers'] = msg['buffers']
    try:
        self.db.update_record(msg_id, result)
    except Exception:
        self.log.error("DB Error updating record %r", msg_id, exc_info=True)
#--------------------- Task Queue Traffic ------------------------------
def save_task_request(self, idents, msg):
    """Save the submission of a task.

    ``idents[0]`` is the client identity.  The msg_id is added to
    ``pending`` and ``unassigned`` before the db is touched, so the task is
    tracked even if the db write fails.  If a record already exists (IOPub
    may have arrived first, or this is a resubmission) it is merged with
    the new record and updated; otherwise a new record is added.  DB errors
    are logged, not raised.
    """
    client_id = idents[0]

    try:
        msg = self.session.unserialize(msg)
    except Exception:
        self.log.error("task::client %r sent invalid task message: %r",
                client_id, msg, exc_info=True)
        return
    record = init_record(msg)

    record['client_uuid'] = msg['header']['session']
    record['queue'] = 'task'
    header = msg['header']
    msg_id = header['msg_id']
    self.pending.add(msg_id)
    self.unassigned.add(msg_id)
    try:
        # it's posible iopub arrived first:
        existing = self.db.get_record(msg_id)
        # NOTE(review): a KeyError raised below (e.g. a stored record lacking
        # 'resubmitted') is also caught by ``except KeyError`` and leads to
        # add_record being attempted for an existing msg_id.
        if existing['resubmitted']:
            for key in ('submitted', 'client_uuid', 'buffers'):
                # don't clobber these keys on resubmit
                # submitted and client_uuid should be different
                # and buffers might be big, and shouldn't have changed
                record.pop(key)
                # still check content,header which should not change
                # but are not expensive to compare as buffers

        # merge: keep existing values where the new record has none, and
        # warn when both are set but disagree
        for key,evalue in iteritems(existing):
            if key.endswith('buffers'):
                # don't compare buffers
                continue
            rvalue = record.get(key, None)
            if evalue and rvalue and evalue != rvalue:
                self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
            elif evalue and not rvalue:
                record[key] = evalue
        try:
            self.db.update_record(msg_id, record)
        except Exception:
            self.log.error("DB Error updating record %r", msg_id, exc_info=True)
    except KeyError:
        # no existing record: create one
        try:
            self.db.add_record(msg_id, record)
        except Exception:
            self.log.error("DB Error adding record %r", msg_id, exc_info=True)
    except Exception:
        self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
def save_task_result(self, idents, msg):
    """Save the result of a completed task.

    ``idents[0]`` is the client identity.  If the task is still pending it
    is marked complete: it is appended to the engine's ``completed`` list
    (unless its status is 'aborted'), removed from the engine's ``tasks``
    list, and its db record is updated with the result.  Results for tasks
    that are not pending are only logged at debug level.
    """
    client_id = idents[0]
    try:
        msg = self.session.unserialize(msg)
    except Exception:
        self.log.error("task::invalid task result message sent to %r: %r",
                client_id, msg, exc_info=True)
        return

    parent = msg['parent_header']
    if not parent:
        self.log.warning("Task %r had no parent!", msg)
        return
    msg_id = parent['msg_id']
    self.unassigned.discard(msg_id)

    header = msg['header']
    md = msg['metadata']
    engine_uuid = md.get('engine', u'')
    eid = self.by_ident.get(cast_bytes(engine_uuid), None)
    status = md.get('status', None)

    if msg_id not in self.pending:
        self.log.debug("task::unknown task %r finished", msg_id)
        return

    self.log.info("task::task %r finished on %s", msg_id, eid)
    self.pending.remove(msg_id)
    self.all_completed.add(msg_id)
    if eid is not None:
        if status != 'aborted':
            self.completed[eid].append(msg_id)
        if msg_id in self.tasks[eid]:
            self.tasks[eid].remove(msg_id)
    completed = header['date']
    started = extract_dates(md.get('started', None))
    result = {
        'result_header' : header,
        'result_metadata': md,
        'result_content': msg['content'],
        'started' : started,
        'completed' : completed,
        'received' : datetime.now(),
        'engine_uuid': engine_uuid,
    }
    result['result_buffers'] = msg['buffers']
    try:
        self.db.update_record(msg_id, result)
    except Exception:
        self.log.error("DB Error saving task result %r", msg_id, exc_info=True)
def save_task_destination(self, idents, msg):
    """Record which engine a task was assigned to.

    The message content carries ``msg_id`` and ``engine_id`` (the engine's
    uuid).  The task is removed from ``unassigned``, appended to that
    engine's ``tasks`` list, and the db record's ``engine_uuid`` is updated.
    Assignments to engines the Hub does not know are logged and dropped.
    """
    try:
        msg = self.session.unserialize(msg, content=True)
    except Exception:
        self.log.error("task::invalid task tracking message", exc_info=True)
        return
    content = msg['content']
    msg_id = content['msg_id']
    engine_uuid = content['engine_id']
    eid = self.by_ident.get(cast_bytes(engine_uuid), None)
    if eid is None:
        # previously an unguarded KeyError
        self.log.error("task::task %r assigned to unknown engine %r", msg_id, engine_uuid)
        return

    self.log.info("task::task %r arrived on %r", msg_id, eid)
    self.unassigned.discard(msg_id)
    self.tasks[eid].append(msg_id)
    try:
        self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
    except Exception:
        self.log.error("DB Error saving task destination %r", msg_id, exc_info=True)
def mia_task_request(self, idents, msg):
    """Placeholder handler for 'missing in action' task queries.

    Not implemented: always raises NotImplementedError.
    """
    raise NotImplementedError
#--------------------- IOPub Traffic ------------------------------
def save_iopub_message(self, topics, msg):
    """Save an IOPub message into the db record of its parent request.

    Stream output is appended to the record's text under the stream name
    (e.g. stdout/stderr).  ``error``, ``execute_input``, ``display_data``
    and ``execute_result`` content is stored under its own key.  ``status``
    and ``data_pub`` messages are not stored.  If the parent request has no
    record yet (IOPub can arrive first), a new record is created.
    """
    try:
        msg = self.session.unserialize(msg, content=True)
    except Exception:
        self.log.error("iopub::invalid IOPub message", exc_info=True)
        return

    parent = msg['parent_header']
    if not parent:
        self.log.debug("iopub::IOPub message lacks parent: %r", msg)
        return
    msg_id = parent['msg_id']
    msg_type = msg['header']['msg_type']
    content = msg['content']

    # ensure msg_id is in db
    try:
        rec = self.db.get_record(msg_id)
    except KeyError:
        rec = None

    d = {}
    if msg_type == 'stream':
        name = content['name']
        # a stored stream value may be missing or None: treat it as empty
        s = (rec.get(name) if rec is not None else None) or ''
        # older kernels send the payload as 'data'; msg spec v5 uses 'text'
        text = content['data'] if 'data' in content else content.get('text', '')
        d[name] = s + text
    elif msg_type == 'error':
        d['error'] = content
    elif msg_type == 'execute_input':
        d['execute_input'] = content['code']
    elif msg_type in ('display_data', 'execute_result'):
        d[msg_type] = content
    elif msg_type == 'status':
        pass
    elif msg_type == 'data_pub':
        self.log.info("ignored data_pub message for %s", msg_id)
    else:
        self.log.warning("unhandled iopub msg_type: %r", msg_type)

    if not d:
        return

    if rec is None:
        # new record
        rec = empty_record()
        rec['msg_id'] = msg_id
        rec.update(d)
        d = rec
        update_record = self.db.add_record
    else:
        update_record = self.db.update_record

    try:
        update_record(msg_id, d)
    except Exception:
        self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)
#-------------------------------------------------------------------------
# Registration requests
#-------------------------------------------------------------------------
def connection_request(self, client_id, msg):
    """Reply with connection addresses for clients.

    The ``connection_reply`` content maps each live engine id (as a string)
    to its uuid; engines in ``dead_engines`` are left out.
    """
    self.log.info("client::client %r connected", client_id)
    live = {str(eid): uuid for eid, uuid in iteritems(self.keytable)
            if uuid not in self.dead_engines}
    content = dict(status='ok', engines=live)
    self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
def register_engine(self, reg, msg):
    """Register a new engine (first half of registration).

    ``reg`` is the requester's identity prefix; ``msg['content']['uuid']`` is
    the engine's uuid, whose bytes form is also its heart identity.  A
    ``registration_reply`` is always sent.  On success the engine becomes an
    incoming registration.  It is completed right away if its heart is
    already beating; otherwise a DelayedCallback purges it after
    ``registration_timeout`` ms unless the heart starts beating first.

    Returns the engine id allocated for the request (also when the uuid is
    rejected as already in use), or None if the request has no uuid.
    """
    content = msg['content']
    try:
        uuid = content['uuid']
    except KeyError:
        self.log.error("registration::queue not specified", exc_info=True)
        return

    eid = self._next_id
    heart = cast_bytes(uuid)

    self.log.debug("registration::register_engine(%i, %r)", eid, uuid)

    content = dict(id=eid,status='ok',hb_period=self.heartmonitor.period)
    # check if requesting available IDs:
    if heart in self.by_ident:
        content = self._registration_conflict("uuid %r in use", uuid)
    else:
        for h, ec in iteritems(self.incoming_registrations):
            # keys are bytes heart ids: compare with the bytes form, since
            # the text uuid never equals them on Python 3
            if heart == h:
                content = self._registration_conflict("heart_id %r in use", uuid)
                break
            elif uuid == ec.uuid:
                content = self._registration_conflict("uuid %r in use", uuid)
                break

    self.session.send(self.query, "registration_reply",
            content=content,
            ident=reg)

    if content['status'] == 'ok':
        if heart in self.heartmonitor.hearts:
            # already beating
            self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid)
            self.finish_registration(heart)
        else:
            purge = lambda : self._purge_stalled_registration(heart)
            dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
            dc.start()
            self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid,stallback=dc)
    else:
        self.log.error("registration::registration %i failed: %r", eid, content['evalue'])

    return eid

def _registration_conflict(self, fmt, uuid):
    """Log an id-in-use registration error and return wrapped error content."""
    # wrap_exception() reports the active exception, so raise one first
    try:
        raise KeyError(fmt % uuid)
    except KeyError:
        self.log.error(fmt, uuid, exc_info=True)
        return error.wrap_exception()
def unregister_engine(self, ident, msg):
    """Unregister an engine that explicitly requested to leave.

    Also called by handle_heart_failure.  The engine's uuid is added to
    ``dead_engines``; its entries in the other engine tables are kept.
    Messages still queued on the engine are failed by _handle_stranded_msgs
    after ``registration_timeout`` ms.  The engine state file is re-saved and
    an ``unregistration_notification`` is broadcast if there is a notifier.
    Unknown or already-dead engines are logged and ignored.
    """
    try:
        eid = msg['content']['id']
    except (KeyError, TypeError):
        self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
        return
    self.log.info("registration::unregister_engine(%r)", eid)

    uuid = self.keytable.get(eid)
    if uuid is None:
        self.log.error("registration::unregistration of unknown engine %r", eid)
        return
    if uuid in self.dead_engines:
        # already unregistered: don't notify or handle stranded msgs twice
        self.log.info("registration::engine %r already unregistered", eid)
        return

    content=dict(id=eid, uuid=uuid)
    self.dead_engines.add(uuid)

    handleit = lambda : self._handle_stranded_msgs(eid, uuid)
    dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
    dc.start()

    self._save_engine_state()

    if self.notifier:
        self.session.send(self.notifier, "unregistration_notification", content=content)
def _handle_stranded_msgs(self, eid, uuid):
    """Handle messages known to be on an engine when the engine unregisters.

    Each msg_id in the engine's queue that is still pending is marked
    complete, and its db record gets an EngineError result.  Msg_ids that
    are no longer pending are skipped, so running this twice is harmless.

    It is possible that this will fire prematurely - that is, an engine will
    go down after completing a result, and the client will be notified
    that the result failed and later receive the actual result.
    """
    outstanding = self.queues[eid]

    for msg_id in outstanding:
        if msg_id not in self.pending:
            # already resolved (e.g. by an earlier call for this engine);
            # set.remove would raise KeyError
            continue
        self.pending.remove(msg_id)
        self.all_completed.add(msg_id)
        try:
            raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
        except error.EngineError:
            content = error.wrap_exception()
        # build a fake header:
        header = {}
        header['engine'] = uuid
        header['date'] = datetime.now()
        rec = dict(result_content=content, result_header=header, result_buffers=[])
        rec['completed'] = header['date']
        rec['engine_uuid'] = uuid
        try:
            self.db.update_record(msg_id, rec)
        except Exception:
            self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)
def finish_registration(self, heart):
    """Second half of engine registration, called after our HeartMonitor
    has received a beat from the Engine's Heart.

    Moves the EngineConnector out of ``incoming_registrations`` into the
    engine tables, stops its stall timer, sends a registration notification
    (when there is a notifier) and saves the engine state.
    """
    try:
        ec = self.incoming_registrations.pop(heart)
    except KeyError:
        self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
        return
    self.log.info("registration::finished registering engine %i:%s", ec.id, ec.uuid)
    if ec.stallback is not None:
        ec.stallback.stop()
    eid = ec.id
    self.ids.add(eid)
    self.keytable[eid] = ec.uuid
    self.engines[eid] = ec
    self.by_ident[cast_bytes(ec.uuid)] = ec.id
    # fresh per-engine message lists
    for table in (self.queues, self.tasks, self.completed):
        table[eid] = list()
    self.hearts[heart] = eid
    notice = dict(id=eid, uuid=self.engines[eid].uuid)
    if self.notifier:
        self.session.send(self.notifier, "registration_notification", content=notice)
    self.log.info("engine::Engine Connected: %i", eid)

    self._save_engine_state()
def _purge_stalled_registration(self, heart):
    """Drop a pending registration whose heart never started beating.

    Does nothing if the registration is no longer pending.
    """
    stalled = self.incoming_registrations.pop(heart, None)
    if stalled is not None:
        self.log.info("registration::purging stalled registration: %i", stalled.id)
#-------------------------------------------------------------------------
# Engine State
#-------------------------------------------------------------------------
def _cleanup_engine_state_file(self):
    """cleanup engine state mapping

    Removes ``engine_state_file`` if it exists; a failure to remove it is
    logged rather than raised.
    """
    if os.path.exists(self.engine_state_file):
        self.log.debug("cleaning up engine state: %s", self.engine_state_file)
        try:
            os.remove(self.engine_state_file)
        except (IOError, OSError):
            # os.remove raises OSError, which is not an IOError on Python 2
            self.log.error("Couldn't cleanup file: %s", self.engine_state_file, exc_info=True)
def _save_engine_state(self):
    """Write the live-engine mapping to ``engine_state_file`` as JSON.

    Skipped when no state file is configured.  The JSON object has
    ``engines`` (engine id -> uuid, excluding dead engines) and ``next_id``
    (the current id counter).
    """
    path = self.engine_state_file
    if not path:
        return
    self.log.debug("save engine state to %s" % path)
    live = {}
    for eid, ec in iteritems(self.engines):
        if ec.uuid in self.dead_engines:
            continue
        live[eid] = ec.uuid
    state = dict(engines=live, next_id=self._idcounter)
    with open(path, 'w') as f:
        json.dump(state, f)
def _load_engine_state(self):
    """Restore engine registrations from the JSON engine state file.

    Every saved engine is registered again through ``finish_registration``,
    with its heart pre-marked as current and beating. ``self._idcounter`` is
    then restored from the saved ``next_id``. Client notifications are
    suppressed while loading, and the notifier is restored even if loading
    fails. Does nothing if no state file is configured or it does not exist.
    """
    path = self.engine_state_file
    # same guard as _save_engine_state for an unset path
    if not path or not os.path.exists(path):
        return
    self.log.info("loading engine state from %s" % path)
    with open(path) as f:
        state = json.load(f)

    save_notifier = self.notifier
    # finish_registration only notifies when self.notifier is truthy
    self.notifier = None
    try:
        for eid, uuid in iteritems(state['engines']):
            heart = uuid.encode('ascii')
            # start with this heart as current and beating:
            self.heartmonitor.responses.add(heart)
            self.heartmonitor.hearts.add(heart)
            self.incoming_registrations[heart] = EngineConnector(id=int(eid), uuid=uuid)
            self.finish_registration(heart)
    finally:
        # never leave notifications disabled, even if a saved entry is bad
        self.notifier = save_notifier
    self._idcounter = state['next_id']
#-------------------------------------------------------------------------
# Client Requests
#-------------------------------------------------------------------------
def shutdown_request(self, client_id, msg):
    """handle shutdown request.

    Replies ``shutdown_reply`` to the requesting client and broadcasts
    ``shutdown_notice`` on the notifier. It then schedules ``self._shutdown``
    on ``self.loop`` through a DelayedCallback with a 1000 ms delay.
    """
    reply = {'status': 'ok'}
    self.session.send(self.query, 'shutdown_reply', content=reply, ident=client_id)

    # also notify other clients of shutdown
    notice = {'status': 'ok'}
    self.session.send(self.notifier, 'shutdown_notice', content=notice)

    # NOTE(review): the delay presumably gives the messages above time to be
    # delivered before the process exits -- confirm.
    delayed_exit = ioloop.DelayedCallback(self._shutdown, 1000, self.loop)
    delayed_exit.start()
def _shutdown(self):
    """Log the shutdown, pause 0.1 s, then exit the process with status 0.

    Raises SystemExit(0), exactly as ``sys.exit(0)`` would.
    """
    self.log.info("hub::hub shutting down.")
    time.sleep(0.1)
    raise SystemExit(0)
def check_load(self, client_id, msg):
    """Reply with the load of each requested target engine.

    The load of an engine is the number of its pending queue jobs plus its
    pending task jobs. On success a ``load_reply`` is sent whose content maps
    ``str(engine_id)`` to that load (plus ``status='ok'``). If ``targets`` is
    missing or fails validation, a wrapped exception is sent as ``hub_error``.
    """
    content = msg['content']
    try:
        targets = self._validate_targets(content['targets'])
    except Exception:
        content = error.wrap_exception()
        self.session.send(self.query, "hub_error",
                content=content, ident=client_id)
        return
    content = dict(status='ok')
    for t in targets:
        # str(), not bytes(): on Python 3 bytes(3) is b'\x00\x00\x00', and
        # bytes keys cannot be JSON-serialized. Same keys as queue_status.
        content[str(t)] = len(self.queues[t]) + len(self.tasks[t])
    self.session.send(self.query, "load_reply", content=content, ident=client_id)
def queue_status(self, client_id, msg):
    """Return the Queue status of one or more targets.

    If verbose, return the msg_ids, else return len of each type.

    Keys, per ``str(engine_id)``:

    * queue (pending MUX jobs)
    * tasks (pending Task jobs)
    * completed (finished jobs from both queues)

    plus a top-level ``unassigned`` entry built from ``self.unassigned``.
    A missing or invalid ``targets`` field produces a ``hub_error`` reply.
    """
    content = msg['content']
    try:
        # read 'targets' inside the try (as check_load does) so a malformed
        # request gets an error reply instead of raising in the handler
        targets = self._validate_targets(content['targets'])
    except Exception:
        content = error.wrap_exception()
        self.session.send(self.query, "hub_error",
                content=content, ident=client_id)
        return
    verbose = content.get('verbose', False)
    content = dict(status='ok')
    for t in targets:
        queue = self.queues[t]
        completed = self.completed[t]
        tasks = self.tasks[t]
        if not verbose:
            queue = len(queue)
            completed = len(completed)
            tasks = len(tasks)
        content[str(t)] = {'queue': queue, 'completed': completed, 'tasks': tasks}
    content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
    self.session.send(self.query, "queue_reply", content=content, ident=client_id)
def purge_results(self, client_id, msg):
    """Purge results from memory. This method is more valuable before we move
    to a DB based message storage mechanism.

    ``content['msg_ids']`` is either ``'all'`` or a list of msg_ids:

    * ``'all'`` drops every record whose ``completed`` is not None.
    * A list drops exactly those records. Naming a msg_id still in
      ``self.pending`` is an error.

    If that step succeeded, the completed records of each engine in
    ``content['engine_ids']`` are dropped too, stopping at the first failure.
    A ``purge_reply`` is sent containing ``{'status': 'ok'}`` or the wrapped
    exception of the failure.
    """
    content = msg['content']
    self.log.info("Dropping records with %s", content)

    def drop(query):
        # One DB drop: a wrapped-exception reply on failure, None on success.
        try:
            self.db.drop_matching_records(query)
        except Exception:
            failure = error.wrap_exception()
            self.log.exception("Error dropping records")
            return failure
        return None

    msg_ids = content.get('msg_ids', [])
    reply = dict(status='ok')
    if msg_ids == 'all':
        failure = drop(dict(completed={'$ne': None}))
    else:
        in_flight = [m for m in msg_ids if m in self.pending]
        if not in_flight:
            failure = drop(dict(msg_id={'$in': msg_ids}))
        else:
            # raise-and-catch so error.wrap_exception() has an exception to wrap
            try:
                raise IndexError("msg pending: %r" % in_flight[0])
            except:
                failure = error.wrap_exception()
                self.log.exception("Error dropping records")
    if failure is not None:
        reply = failure

    if reply['status'] == 'ok':
        for eid in content.get('engine_ids', []):
            if eid not in self.engines:
                try:
                    raise IndexError("No such engine: %i" % eid)
                except:
                    reply = error.wrap_exception()
                    self.log.exception("Error dropping records")
                break
            failure = drop(dict(engine_uuid=self.engines[eid].uuid, completed={'$ne': None}))
            if failure is not None:
                reply = failure
                break

    self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
def resubmit_task(self, client_id, msg):
    """Resubmit one or more tasks.

    For each msg_id in ``msg['content']['msg_ids']``, the stored request
    (header, content, buffers) is loaded from the DB and sent again on the
    ``resubmit`` stream under a fresh msg_id and date. The rest of the
    original header is kept, and each new msg_id is added to ``pending`` and
    to the DB. The client gets a ``resubmit_reply``: either
    ``{'status': 'ok', 'resubmitted': {original_id: new_id}}`` or a wrapped
    exception if lookup, validation or a DB insert fails. After replying,
    each original record is updated with its ``resubmitted`` id.
    """
    def finish(reply):
        self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)

    content = msg['content']
    msg_ids = content['msg_ids']
    try:
        records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
            'header', 'content', 'buffers'])
    except Exception:
        self.log.error('db::db error finding tasks to resubmit', exc_info=True)
        return finish(error.wrap_exception())

    # validate msg_ids; a set keeps the missing-id scan below linear
    found_ids = set(rec['msg_id'] for rec in records)
    if len(records) > len(msg_ids):
        try:
            raise RuntimeError("DB appears to be in an inconsistent state. "
                "More matching records were found than should exist")
        except Exception:
            self.log.exception("Failed to resubmit task")
            return finish(error.wrap_exception())
    elif len(records) < len(msg_ids):
        missing = [m for m in msg_ids if m not in found_ids]
        try:
            raise KeyError("No such msg(s): %r" % missing)
        except KeyError:
            self.log.exception("Failed to resubmit task")
            return finish(error.wrap_exception())
    # Still-pending tasks are deliberately not rejected: the resubmission
    # goes out under a new msg_id.

    # mapping of original IDs to resubmitted IDs
    resubmitted = {}

    # send the messages
    for rec in records:
        header = rec['header']
        new_msg = self.session.msg(header['msg_type'], parent=header)
        msg_id = new_msg['msg_id']
        new_msg['content'] = rec['content']

        # use the old header, but update msg_id and timestamp
        fresh = new_msg['header']
        header['msg_id'] = fresh['msg_id']
        header['date'] = fresh['date']
        new_msg['header'] = header

        self.session.send(self.resubmit, new_msg, buffers=rec['buffers'])
        resubmitted[rec['msg_id']] = msg_id
        self.pending.add(msg_id)
        new_msg['buffers'] = rec['buffers']
        try:
            self.db.add_record(msg_id, init_record(new_msg))
        except Exception:
            self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
            return finish(error.wrap_exception())

    finish(dict(status='ok', resubmitted=resubmitted))

    # store the new IDs in the Task DB
    for msg_id, resubmit_id in iteritems(resubmitted):
        try:
            self.db.update_record(msg_id, {'resubmitted' : resubmit_id})
        except Exception:
            self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
def _extract_record(self, rec):
    """Split a TaskRecord dict into the ``(content, buffers)`` pair used by get_results.

    ``content`` carries the request ``header`` and ``metadata``; the
    ``result_metadata``, ``result_header`` and ``result_content``; the
    ``received`` field; and an ``io`` dict of the ``execute_input``,
    ``execute_result``, ``error``, ``stdout`` and ``stderr`` fields.
    ``buffers`` is ``rec['result_buffers']`` converted to bytes, or an
    empty list when that field is empty or None.
    """
    io_keys = ('execute_input', 'execute_result', 'error', 'stdout', 'stderr')
    io = dict((key, rec[key]) for key in io_keys)
    content = dict(
        header=rec['header'],
        metadata=rec['metadata'],
        result_metadata=rec['result_metadata'],
        result_header=rec['result_header'],
        result_content=rec['result_content'],
        received=rec['received'],
        io=io,
    )
    raw_buffers = rec['result_buffers']
    buffers = [bytes(b) for b in raw_buffers] if raw_buffers else []
    return content, buffers
def get_results(self, client_id, msg):
    """Get the result of 1 or more messages.

    Replies with ``result_reply`` whose content has ``pending`` and
    ``completed`` lists of msg_ids, checked in sorted order. Unless
    ``status_only`` is set, each completed msg_id also gets an entry holding
    its extracted record, and the record's result buffers are attached to
    the reply. A msg_id that cannot be found replaces the whole reply with a
    wrapped KeyError. This includes an id that is completed in memory but
    whose DB record is gone, when records are being returned.
    """
    content = msg['content']
    msg_ids = sorted(set(content['msg_ids']))
    statusonly = content.get('status_only', False)
    pending = []
    completed = []
    content = dict(status='ok')
    content['pending'] = pending
    content['completed'] = completed
    buffers = []
    records = {}
    if not statusonly:
        try:
            matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
            # turn match list into dict, for faster lookup
            # (a dict comprehension-style build leaves no stray loop variable)
            records = dict((rec['msg_id'], rec) for rec in matches)
        except Exception:
            content = error.wrap_exception()
            self.log.exception("Failed to get results")
            self.session.send(self.query, "result_reply", content=content,
                                                parent=msg, ident=client_id)
            return
    for msg_id in msg_ids:
        if msg_id in self.pending:
            pending.append(msg_id)
        elif msg_id in self.all_completed and (statusonly or msg_id in records):
            completed.append(msg_id)
            if not statusonly:
                c, bufs = self._extract_record(records[msg_id])
                content[msg_id] = c
                buffers.extend(bufs)
        elif msg_id in records:
            # look at THIS message's record, not whichever record was last
            # seen while building `records`
            rec = records[msg_id]
            if rec['completed']:
                completed.append(msg_id)
                c, bufs = self._extract_record(rec)
                content[msg_id] = c
                buffers.extend(bufs)
            else:
                pending.append(msg_id)
        else:
            # raise-and-catch so error.wrap_exception() has an exception to wrap
            try:
                raise KeyError('No such message: '+msg_id)
            except Exception:
                content = error.wrap_exception()
            break
    self.session.send(self.query, "result_reply", content=content,
                                        parent=msg, ident=client_id,
                                        buffers=buffers)
def get_history(self, client_id, msg):
    """Send a ``history_reply`` listing every msg_id in the DB.

    If the DB lookup fails, the reply carries the wrapped exception instead.
    """
    try:
        history = self.db.get_history()
    except Exception:
        reply = error.wrap_exception()
        self.log.exception("Failed to get history")
    else:
        reply = {'status': 'ok', 'history': history}
    self.session.send(self.query, "history_reply", content=reply,
                      parent=msg, ident=client_id)
def db_query(self, client_id, msg):
    """Perform a raw query on the task record database.

    Runs ``content['query']`` (after ``extract_dates``) with the optional
    ``content['keys']`` and replies with ``db_reply``. Each record's
    ``buffers`` and ``result_buffers`` are popped and appended to the reply's
    message buffers, record by record. When ``keys`` explicitly includes
    ``'buffers'`` / ``'result_buffers'``, the reply's ``buffer_lens`` /
    ``result_buffer_lens`` give the per-record counts; otherwise they are None.
    """
    # NOTE(review): with keys=None the buffers are still attached but no
    # lengths are sent, so a client cannot split them -- confirm intended.
    content = msg['content']
    query = extract_dates(content.get('query', {}))
    keys = content.get('keys', None)
    buffers = []
    try:
        records = self.db.find_records(query, keys)
    except Exception:
        content = error.wrap_exception()
        self.log.exception("DB query failed")
    else:
        want_bufs = keys is not None and 'buffers' in keys
        want_rbufs = keys is not None and 'result_buffers' in keys
        buffer_lens = [] if want_bufs else None
        result_buffer_lens = [] if want_rbufs else None
        for rec in records:
            # either buffer field may be absent or None; treat both as empty
            request_bufs = rec.pop('buffers', None) or []
            if buffer_lens is not None:
                buffer_lens.append(len(request_bufs))
            buffers.extend(request_bufs)

            reply_bufs = rec.pop('result_buffers', None) or []
            if result_buffer_lens is not None:
                result_buffer_lens.append(len(reply_bufs))
            buffers.extend(reply_bufs)

        content = dict(status='ok', records=records, buffer_lens=buffer_lens,
                       result_buffer_lens=result_buffer_lens)
    self.session.send(self.query, "db_reply", content=content,
                      parent=msg, ident=client_id,
                      buffers=buffers)