From fcee637ffd259420ecd6f5128b1133ee995f7cfb 2011-04-08 00:38:06 From: MinRK Date: 2011-04-08 00:38:06 Subject: [PATCH] prep newparallel for rebase This mainly involves checking out files @ 568f2f43077a78ee65f86c28b9d9ab63fc2b279f, to allow for cleaner application of changes after that point, where there are no longer name conflicts. --- diff --git a/IPython/zmq/log.py b/IPython/zmq/log.py new file mode 100644 index 0000000..797cd28 --- /dev/null +++ b/IPython/zmq/log.py @@ -0,0 +1,27 @@ +import logging +from logging import INFO, DEBUG, WARN, ERROR, FATAL + +import zmq +from zmq.log.handlers import PUBHandler + +class EnginePUBHandler(PUBHandler): + """A simple PUBHandler subclass that sets root_topic""" + engine=None + + def __init__(self, engine, *args, **kwargs): + PUBHandler.__init__(self,*args, **kwargs) + self.engine = engine + + @property + def root_topic(self): + """this is a property, in case the handler is created + before the engine gets registered with an id""" + if isinstance(getattr(self.engine, 'id', None), int): + return "engine.%i"%self.engine.id + else: + return "engine" + + +logger = logging.getLogger('ipzmq') +logger.setLevel(logging.DEBUG) + diff --git a/IPython/zmq/newserialized.py b/IPython/zmq/newserialized.py new file mode 100644 index 0000000..07577a7 --- /dev/null +++ b/IPython/zmq/newserialized.py @@ -0,0 +1,167 @@ +# encoding: utf-8 +# -*- test-case-name: IPython.kernel.test.test_newserialized -*- + +"""Refactored serialization classes and interfaces.""" + +__docformat__ = "restructuredtext en" + +# Tell nose to skip this module +__test__ = {} + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import cPickle as pickle + +# from twisted.python import components +# from zope.interface import Interface, implements + +try: + import numpy +except ImportError: + pass + +from IPython.kernel.error import SerializationError + +#----------------------------------------------------------------------------- +# Classes and functions +#----------------------------------------------------------------------------- + +class ISerialized: + + def getData(): + """""" + + def getDataSize(units=10.0**6): + """""" + + def getTypeDescriptor(): + """""" + + def getMetadata(): + """""" + + +class IUnSerialized: + + def getObject(): + """""" + +class Serialized(object): + + # implements(ISerialized) + + def __init__(self, data, typeDescriptor, metadata={}): + self.data = data + self.typeDescriptor = typeDescriptor + self.metadata = metadata + + def getData(self): + return self.data + + def getDataSize(self, units=10.0**6): + return len(self.data)/units + + def getTypeDescriptor(self): + return self.typeDescriptor + + def getMetadata(self): + return self.metadata + + +class UnSerialized(object): + + # implements(IUnSerialized) + + def __init__(self, obj): + self.obj = obj + + def getObject(self): + return self.obj + + +class SerializeIt(object): + + # implements(ISerialized) + + def __init__(self, unSerialized): + self.data = None + self.obj = unSerialized.getObject() + if globals().has_key('numpy') and isinstance(self.obj, numpy.ndarray): + if len(self.obj) == 0: # length 0 arrays can't be reconstructed + raise SerializationError("You cannot send a length 0 array") + self.obj = numpy.ascontiguousarray(self.obj, dtype=None) + self.typeDescriptor = 'ndarray' + self.metadata = {'shape':self.obj.shape, + 
'dtype':self.obj.dtype.str} + elif isinstance(self.obj, str): + self.typeDescriptor = 'bytes' + self.metadata = {} + elif isinstance(self.obj, buffer): + self.typeDescriptor = 'buffer' + self.metadata = {} + else: + self.typeDescriptor = 'pickle' + self.metadata = {} + self._generateData() + + def _generateData(self): + if self.typeDescriptor == 'ndarray': + self.data = numpy.getbuffer(self.obj) + elif self.typeDescriptor in ('bytes', 'buffer'): + self.data = self.obj + elif self.typeDescriptor == 'pickle': + self.data = pickle.dumps(self.obj, 2) + else: + raise SerializationError("Really wierd serialization error.") + del self.obj + + def getData(self): + return self.data + + def getDataSize(self, units=10.0**6): + return 1.0*len(self.data)/units + + def getTypeDescriptor(self): + return self.typeDescriptor + + def getMetadata(self): + return self.metadata + + +class UnSerializeIt(UnSerialized): + + # implements(IUnSerialized) + + def __init__(self, serialized): + self.serialized = serialized + + def getObject(self): + typeDescriptor = self.serialized.getTypeDescriptor() + if globals().has_key('numpy') and typeDescriptor == 'ndarray': + result = numpy.frombuffer(self.serialized.getData(), dtype = self.serialized.metadata['dtype']) + result.shape = self.serialized.metadata['shape'] + # This is a hack to make the array writable. We are working with + # the numpy folks to address this issue. 
+ result = result.copy() + elif typeDescriptor == 'pickle': + result = pickle.loads(self.serialized.getData()) + elif typeDescriptor in ('bytes', 'buffer'): + result = self.serialized.getData() + else: + raise SerializationError("Really wierd serialization error.") + return result + +def serialize(obj): + return SerializeIt(UnSerialized(obj)) + +def unserialize(serialized): + return UnSerializeIt(serialized).getObject() diff --git a/IPython/zmq/parallel/__init__.py b/IPython/zmq/parallel/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/IPython/zmq/parallel/__init__.py diff --git a/IPython/zmq/parallel/client.py b/IPython/zmq/parallel/client.py new file mode 100644 index 0000000..cb2d75b --- /dev/null +++ b/IPython/zmq/parallel/client.py @@ -0,0 +1,562 @@ +#!/usr/bin/env python +"""A semi-synchronous Client for the ZMQ controller""" + +import time +import threading + +from functools import wraps + +from IPython.external.decorator import decorator + +import streamsession as ss +import zmq + +from remotenamespace import RemoteNamespace +from view import DirectView + +def _push(ns): + globals().update(ns) + +def _pull(keys): + g = globals() + if isinstance(keys, (list,tuple)): + return map(g.get, keys) + else: + return g.get(keys) + +def _clear(): + globals().clear() + +def execute(code): + exec code in globals() + +# decorators for methods: +@decorator +def spinfirst(f,self,*args,**kwargs): + self.spin() + return f(self, *args, **kwargs) + +@decorator +def defaultblock(f, self, *args, **kwargs): + block = kwargs.get('block',None) + block = self.block if block is None else block + saveblock = self.block + self.block = block + ret = f(self, *args, **kwargs) + self.block = saveblock + return ret + + +# @decorator +# def checktargets(f): +# @wraps(f) +# def checked_method(self, *args, **kwargs): +# self._build_targets(kwargs['targets']) +# return f(self, *args, **kwargs) +# return checked_method + + +# class 
_ZMQEventLoopThread(threading.Thread): +# +# def __init__(self, loop): +# self.loop = loop +# threading.Thread.__init__(self) +# +# def run(self): +# self.loop.start() +# +class Client(object): + """A semi-synchronous client to the IPython ZMQ controller + + Attributes + ---------- + ids : set + a set of engine IDs + requesting the ids attribute always synchronizes + the registration state. To request ids without synchronization, + use _ids + + history : list of msg_ids + a list of msg_ids, keeping track of all the execution + messages you have submitted + + outstanding : set of msg_ids + a set of msg_ids that have been submitted, but whose + results have not been received + + results : dict + a dict of all our results, keyed by msg_id + + block : bool + determines default behavior when block not specified + in execution methods + + Methods + ------- + spin : flushes incoming results and registration state changes + control methods spin, and requesting `ids` also ensures up to date + + barrier : wait on one or more msg_ids + + execution methods: apply/apply_bound/apply_to + legacy: execute, run + + control methods: queue_status, get_result + + """ + + + _connected=False + _engines=None + registration_socket=None + controller_socket=None + notification_socket=None + queue_socket=None + task_socket=None + block = False + outstanding=None + results = None + history = None + + def __init__(self, addr, context=None, username=None): + if context is None: + context = zmq.Context() + self.context = context + self.addr = addr + if username is None: + self.session = ss.StreamSession() + else: + self.session = ss.StreamSession(username) + self.registration_socket = self.context.socket(zmq.PAIR) + self.registration_socket.setsockopt(zmq.IDENTITY, self.session.session) + self.registration_socket.connect(addr) + self._engines = {} + self._ids = set() + self.outstanding=set() + self.results = {} + self.history = [] + self._connect() + + self._notification_handlers = 
{'registration_notification' : self._register_engine, + 'unregistration_notification' : self._unregister_engine, + } + self._queue_handlers = {'execute_reply' : self._handle_execute_reply, + 'apply_reply' : self._handle_apply_reply} + + + @property + def ids(self): + self._flush_notifications() + return self._ids + + def _update_engines(self, engines): + for k,v in engines.iteritems(): + eid = int(k) + self._engines[eid] = v + self._ids.add(eid) + + def _build_targets(self, targets): + if targets is None: + targets = self._ids + elif isinstance(targets, str): + if targets.lower() == 'all': + targets = self._ids + else: + raise TypeError("%r not valid str target, must be 'all'"%(targets)) + elif isinstance(targets, int): + targets = [targets] + return [self._engines[t] for t in targets], list(targets) + + def _connect(self): + """setup all our socket connections to the controller""" + if self._connected: + return + self._connected=True + self.session.send(self.registration_socket, 'connection_request') + msg = self.session.recv(self.registration_socket,mode=0)[-1] + msg = ss.Message(msg) + content = msg.content + if content.status == 'ok': + if content.queue: + self.queue_socket = self.context.socket(zmq.PAIR) + self.queue_socket.setsockopt(zmq.IDENTITY, self.session.session) + self.queue_socket.connect(content.queue) + if content.task: + self.task_socket = self.context.socket(zmq.PAIR) + self.task_socket.setsockopt(zmq.IDENTITY, self.session.session) + self.task_socket.connect(content.task) + if content.notification: + self.notification_socket = self.context.socket(zmq.SUB) + self.notification_socket.connect(content.notification) + self.notification_socket.setsockopt(zmq.SUBSCRIBE, "") + if content.controller: + self.controller_socket = self.context.socket(zmq.PAIR) + self.controller_socket.setsockopt(zmq.IDENTITY, self.session.session) + self.controller_socket.connect(content.controller) + self._update_engines(dict(content.engines)) + + else: + self._connected = 
False + raise Exception("Failed to connect!") + + #### handlers and callbacks for incoming messages ####### + def _register_engine(self, msg): + content = msg['content'] + eid = content['id'] + d = {eid : content['queue']} + self._update_engines(d) + self._ids.add(int(eid)) + + def _unregister_engine(self, msg): + # print 'unregister',msg + content = msg['content'] + eid = int(content['id']) + if eid in self._ids: + self._ids.remove(eid) + self._engines.pop(eid) + + def _handle_execute_reply(self, msg): + # msg_id = msg['msg_id'] + parent = msg['parent_header'] + msg_id = parent['msg_id'] + if msg_id not in self.outstanding: + print "got unknown result: %s"%msg_id + else: + self.outstanding.remove(msg_id) + self.results[msg_id] = ss.unwrap_exception(msg['content']) + + def _handle_apply_reply(self, msg): + # print msg + # msg_id = msg['msg_id'] + parent = msg['parent_header'] + msg_id = parent['msg_id'] + if msg_id not in self.outstanding: + print "got unknown result: %s"%msg_id + else: + self.outstanding.remove(msg_id) + content = msg['content'] + if content['status'] == 'ok': + self.results[msg_id] = ss.unserialize_object(msg['buffers']) + else: + + self.results[msg_id] = ss.unwrap_exception(content) + + def _flush_notifications(self): + "flush incoming notifications of engine registrations" + msg = self.session.recv(self.notification_socket, mode=zmq.NOBLOCK) + while msg is not None: + msg = msg[-1] + msg_type = msg['msg_type'] + handler = self._notification_handlers.get(msg_type, None) + if handler is None: + raise Exception("Unhandled message type: %s"%msg.msg_type) + else: + handler(msg) + msg = self.session.recv(self.notification_socket, mode=zmq.NOBLOCK) + + def _flush_results(self, sock): + "flush incoming task or queue results" + msg = self.session.recv(sock, mode=zmq.NOBLOCK) + while msg is not None: + msg = msg[-1] + msg_type = msg['msg_type'] + handler = self._queue_handlers.get(msg_type, None) + if handler is None: + raise Exception("Unhandled message 
type: %s"%msg.msg_type) + else: + handler(msg) + msg = self.session.recv(sock, mode=zmq.NOBLOCK) + + ###### get/setitem ######## + + def __getitem__(self, key): + if isinstance(key, int): + if key not in self.ids: + raise IndexError("No such engine: %i"%key) + return DirectView(self, key) + + if isinstance(key, slice): + indices = range(len(self.ids))[key] + ids = sorted(self._ids) + key = [ ids[i] for i in indices ] + # newkeys = sorted(self._ids)[thekeys[k]] + + if isinstance(key, (tuple, list, xrange)): + _,targets = self._build_targets(list(key)) + return DirectView(self, targets) + else: + raise TypeError("key by int/iterable of ints only, not %s"%(type(key))) + + ############ begin real methods ############# + + def spin(self): + """flush incoming notifications and execution results.""" + if self.notification_socket: + self._flush_notifications() + if self.queue_socket: + self._flush_results(self.queue_socket) + if self.task_socket: + self._flush_results(self.task_socket) + + @spinfirst + def queue_status(self, targets=None, verbose=False): + """fetch the status of engine queues + + Parameters + ---------- + targets : int/str/list of ints/strs + the engines on which to execute + default : all + verbose : bool + whether to return + + """ + targets = self._build_targets(targets)[1] + content = dict(targets=targets) + self.session.send(self.controller_socket, "queue_request", content=content) + idents,msg = self.session.recv(self.controller_socket, 0) + return msg['content'] + + @spinfirst + def clear(self, targets=None): + """clear the namespace in target(s)""" + pass + + @spinfirst + def abort(self, targets=None): + """abort the Queues of target(s)""" + pass + + @defaultblock + def execute(self, code, targets='all', block=None): + """executes `code` on `targets` in blocking or nonblocking manner. 
+ + Parameters + ---------- + code : str + the code string to be executed + targets : int/str/list of ints/strs + the engines on which to execute + default : all + block : bool + whether or not to wait until done + """ + # block = self.block if block is None else block + # saveblock = self.block + # self.block = block + result = self.apply(execute, (code,), targets=targets, block=block, bound=True) + # self.block = saveblock + return result + + def run(self, code, block=None): + """runs `code` on an engine. + + Calls to this are load-balanced. + + Parameters + ---------- + code : str + the code string to be executed + block : bool + whether or not to wait until done + + """ + result = self.apply(execute, (code,), targets=None, block=block, bound=False) + return result + + # a = time.time() + # content = dict(code=code) + # b = time.time() + # msg = self.session.send(self.task_socket, 'execute_request', + # content=content) + # c = time.time() + # msg_id = msg['msg_id'] + # self.outstanding.add(msg_id) + # self.history.append(msg_id) + # d = time.time() + # if block: + # self.barrier(msg_id) + # return self.results[msg_id] + # else: + # return msg_id + + def _apply_balanced(self, f, args, kwargs, bound=True, block=None): + """the underlying method for applying functions in a load balanced + manner.""" + block = block if block is not None else self.block + + bufs = ss.pack_apply_message(f,args,kwargs) + content = dict(bound=bound) + msg = self.session.send(self.task_socket, "apply_request", + content=content, buffers=bufs) + msg_id = msg['msg_id'] + self.outstanding.add(msg_id) + self.history.append(msg_id) + if block: + self.barrier(msg_id) + return self.results[msg_id] + else: + return msg_id + + def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None): + """Then underlying method for applying functions to specific engines.""" + block = block if block is not None else self.block + queues,targets = self._build_targets(targets) + + bufs = 
ss.pack_apply_message(f,args,kwargs) + content = dict(bound=bound) + msg_ids = [] + for queue in queues: + msg = self.session.send(self.queue_socket, "apply_request", + content=content, buffers=bufs,ident=queue) + msg_id = msg['msg_id'] + self.outstanding.add(msg_id) + self.history.append(msg_id) + msg_ids.append(msg_id) + if block: + self.barrier(msg_ids) + else: + if len(msg_ids) == 1: + return msg_ids[0] + else: + return msg_ids + if len(msg_ids) == 1: + return self.results[msg_ids[0]] + else: + result = {} + for target,mid in zip(targets, msg_ids): + result[target] = self.results[mid] + return result + + def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None): + """calls f(*args, **kwargs) on a remote engine(s), returning the result. + + if self.block is False: + returns msg_id or list of msg_ids + else: + returns actual result of f(*args, **kwargs) + """ + args = args if args is not None else [] + kwargs = kwargs if kwargs is not None else {} + if targets is None: + return self._apply_balanced(f,args,kwargs,bound=bound, block=block) + else: + return self._apply_direct(f, args, kwargs, + bound=bound,block=block, targets=targets) + + # def apply_bound(self, f, *args, **kwargs): + # """calls f(*args, **kwargs) on a remote engine. This does get + # executed in an engine's namespace. The controller selects the + # target engine via 0MQ XREQ load balancing. + # + # if self.block is False: + # returns msg_id + # else: + # returns actual result of f(*args, **kwargs) + # """ + # return self._apply(f, args, kwargs, bound=True) + # + # + # def apply_to(self, targets, f, *args, **kwargs): + # """calls f(*args, **kwargs) on a specific engine. + # + # if self.block is False: + # returns msg_id + # else: + # returns actual result of f(*args, **kwargs) + # + # The target's namespace is not used here. + # Use apply_bound_to() to access target's globals. 
+ # """ + # return self._apply_to(False, targets, f, args, kwargs) + # + # def apply_bound_to(self, targets, f, *args, **kwargs): + # """calls f(*args, **kwargs) on a specific engine. + # + # if self.block is False: + # returns msg_id + # else: + # returns actual result of f(*args, **kwargs) + # + # This method has access to the target's globals + # + # """ + # return self._apply_to(f, args, kwargs) + # + def push(self, ns, targets=None, block=None): + """push the contents of `ns` into the namespace on `target`""" + if not isinstance(ns, dict): + raise TypeError("Must be a dict, not %s"%type(ns)) + result = self.apply(_push, (ns,), targets=targets, block=block,bound=True) + return result + + @spinfirst + def pull(self, keys, targets=None, block=True): + """pull objects from `target`'s namespace by `keys`""" + + result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True) + return result + + def barrier(self, msg_ids=None, timeout=-1): + """waits on one or more `msg_ids`, for up to `timeout` seconds. + + Parameters + ---------- + msg_ids : int, str, or list of ints and/or strs + ints are indices to self.history + strs are msg_ids + default: wait on all outstanding messages + timeout : float + a time in seconds, after which to give up. 
+ default is -1, which means no timeout + + Returns + ------- + True : when all msg_ids are done + False : timeout reached, msg_ids still outstanding + """ + tic = time.time() + if msg_ids is None: + theids = self.outstanding + else: + if isinstance(msg_ids, (int, str)): + msg_ids = [msg_ids] + theids = set() + for msg_id in msg_ids: + if isinstance(msg_id, int): + msg_id = self.history[msg_id] + theids.add(msg_id) + self.spin() + while theids.intersection(self.outstanding): + if timeout >= 0 and ( time.time()-tic ) > timeout: + break + time.sleep(1e-3) + self.spin() + return len(theids.intersection(self.outstanding)) == 0 + + @spinfirst + def get_results(self, msg_ids,status_only=False): + """returns the result of the execute or task request with `msg_id`""" + if not isinstance(msg_ids, (list,tuple)): + msg_ids = [msg_ids] + theids = [] + for msg_id in msg_ids: + if isinstance(msg_id, int): + msg_id = self.history[msg_id] + theids.append(msg_id) + + content = dict(msg_ids=theids, status_only=status_only) + msg = self.session.send(self.controller_socket, "result_request", content=content) + zmq.select([self.controller_socket], [], []) + idents,msg = self.session.recv(self.controller_socket, zmq.NOBLOCK) + + # while True: + # try: + # except zmq.ZMQError: + # time.sleep(1e-3) + # continue + # else: + # break + return msg['content'] + + \ No newline at end of file diff --git a/IPython/zmq/parallel/controller.py b/IPython/zmq/parallel/controller.py new file mode 100644 index 0000000..7c93f32 --- /dev/null +++ b/IPython/zmq/parallel/controller.py @@ -0,0 +1,770 @@ +#!/usr/bin/env python +# encoding: utf-8 + +"""The IPython Controller with 0MQ +This is the master object that handles connections from engines, clients, and +""" +#----------------------------------------------------------------------------- +# Copyright (C) 2008-2009 The IPython Development Team +# +# Distributed under the terms of the BSD License. 
The full license is in +# the file COPYING, distributed as part of this software. +#----------------------------------------------------------------------------- + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- +from datetime import datetime + +import zmq +from zmq.eventloop import zmqstream, ioloop +import uuid + +# internal: +from streamsession import Message, wrap_exception # default_unpacker as unpack, default_packer as pack +from IPython.zmq.log import logger # a Logger object + +# from messages import json # use the same import switches + +#----------------------------------------------------------------------------- +# Code +#----------------------------------------------------------------------------- + +class ReverseDict(dict): + """simple double-keyed subset of dict methods.""" + + def __init__(self, *args, **kwargs): + dict.__init__(self, *args, **kwargs) + self.reverse = dict() + for key, value in self.iteritems(): + self.reverse[value] = key + + def __getitem__(self, key): + try: + return dict.__getitem__(self, key) + except KeyError: + return self.reverse[key] + + def __setitem__(self, key, value): + if key in self.reverse: + raise KeyError("Can't have key %r on both sides!"%key) + dict.__setitem__(self, key, value) + self.reverse[value] = key + + def pop(self, key): + value = dict.pop(self, key) + self.d1.pop(value) + return value + + +class EngineConnector(object): + """A simple object for accessing the various zmq connections of an object. + Attributes are: + id (int): engine ID + uuid (str): uuid (unused?) 
+ queue (str): identity of queue's XREQ socket + registration (str): identity of registration XREQ socket + heartbeat (str): identity of heartbeat XREQ socket + """ + id=0 + queue=None + control=None + registration=None + heartbeat=None + pending=None + + def __init__(self, id, queue, registration, control, heartbeat=None): + logger.info("engine::Engine Connected: %i"%id) + self.id = id + self.queue = queue + self.registration = registration + self.control = control + self.heartbeat = heartbeat + +class Controller(object): + """The IPython Controller with 0MQ connections + + Parameters + ========== + loop: zmq IOLoop instance + session: StreamSession object + context: zmq context for creating new connections (?) + registrar: ZMQStream for engine registration requests (XREP) + clientele: ZMQStream for client connections (XREP) + not used for jobs, only query/control commands + queue: ZMQStream for monitoring the command queue (SUB) + heartbeat: HeartMonitor object checking the pulse of the engines + db_stream: connection to db for out of memory logging of commands + NotImplemented + queue_addr: zmq connection address of the XREP socket for the queue + hb_addr: zmq connection address of the PUB socket for heartbeats + task_addr: zmq connection address of the XREQ socket for task queue + """ + # internal data structures: + ids=None # engine IDs + keytable=None + engines=None + clients=None + hearts=None + pending=None + results=None + tasks=None + completed=None + mia=None + incoming_registrations=None + registration_timeout=None + + #objects from constructor: + loop=None + registrar=None + clientelle=None + queue=None + heartbeat=None + notifier=None + db=None + client_addr=None + engine_addrs=None + + + def __init__(self, loop, session, queue, registrar, heartbeat, clientele, notifier, db, engine_addrs, client_addrs): + """ + # universal: + loop: IOLoop for creating future connections + session: streamsession for sending serialized data + # engine: + queue: 
ZMQStream for monitoring queue messages + registrar: ZMQStream for engine registration + heartbeat: HeartMonitor object for tracking engines + # client: + clientele: ZMQStream for client connections + # extra: + db: ZMQStream for db connection (NotImplemented) + engine_addrs: zmq address/protocol dict for engine connections + client_addrs: zmq address/protocol dict for client connections + """ + self.ids = set() + self.keytable={} + self.incoming_registrations={} + self.engines = {} + self.by_ident = {} + self.clients = {} + self.hearts = {} + self.mia = set() + + # self.sockets = {} + self.loop = loop + self.session = session + self.registrar = registrar + self.clientele = clientele + self.queue = queue + self.heartbeat = heartbeat + self.notifier = notifier + self.db = db + + self.client_addrs = client_addrs + assert isinstance(client_addrs['queue'], str) + # self.hb_addrs = hb_addrs + self.engine_addrs = engine_addrs + assert isinstance(engine_addrs['queue'], str) + assert len(engine_addrs['heartbeat']) == 2 + + + # register our callbacks + self.registrar.on_recv(self.dispatch_register_request) + self.clientele.on_recv(self.dispatch_client_msg) + self.queue.on_recv(self.dispatch_queue_traffic) + + if heartbeat is not None: + heartbeat.add_heart_failure_handler(self.handle_heart_failure) + heartbeat.add_new_heart_handler(self.handle_new_heart) + + if self.db is not None: + self.db.on_recv(self.dispatch_db) + + self.client_handlers = {'queue_request': self.queue_status, + 'result_request': self.get_results, + 'purge_request': self.purge_results, + 'resubmit_request': self.resubmit_task, + } + + self.registrar_handlers = {'registration_request' : self.register_engine, + 'unregistration_request' : self.unregister_engine, + 'connection_request': self.connection_request, + + } + # + # this is the stuff that will move to DB: + self.results = {} # completed results + self.pending = {} # pending messages, keyed by msg_id + self.queues = {} # pending msg_ids keyed by 
engine_id + self.tasks = {} # pending msg_ids submitted as tasks, keyed by client_id + self.completed = {} # completed msg_ids keyed by engine_id + self.registration_timeout = max(5000, 2*self.heartbeat.period) + + logger.info("controller::created controller") + + def _new_id(self): + """gemerate a new ID""" + newid = 0 + incoming = [id[0] for id in self.incoming_registrations.itervalues()] + # print newid, self.ids, self.incoming_registrations + while newid in self.ids or newid in incoming: + newid += 1 + return newid + + + #----------------------------------------------------------------------------- + # message validation + #----------------------------------------------------------------------------- + def _validate_targets(self, targets): + """turn any valid targets argument into a list of integer ids""" + if targets is None: + # default to all + targets = self.ids + + if isinstance(targets, (int,str,unicode)): + # only one target specified + targets = [targets] + _targets = [] + for t in targets: + # map raw identities to ids + if isinstance(t, (str,unicode)): + t = self.by_ident.get(t, t) + _targets.append(t) + targets = _targets + bad_targets = [ t for t in targets if t not in self.ids ] + if bad_targets: + raise IndexError("No Such Engine: %r"%bad_targets) + if not targets: + raise IndexError("No Engines Registered") + return targets + + def _validate_client_msg(self, msg): + """validates and unpacks headers of a message. 
Returns False if invalid, + (ident, header, parent, content)""" + client_id = msg[0] + try: + msg = self.session.unpack_message(msg[1:], content=True) + except: + logger.error("client::Invalid Message %s"%msg) + return False + + msg_type = msg.get('msg_type', None) + if msg_type is None: + return False + header = msg.get('header') + # session doesn't handle split content for now: + return client_id, msg + + + #----------------------------------------------------------------------------- + # dispatch methods (1 per socket) + #----------------------------------------------------------------------------- + + def dispatch_register_request(self, msg): + """""" + logger.debug("registration::dispatch_register_request(%s)"%msg) + idents,msg = self.session.feed_identities(msg) + print idents,msg, len(msg) + try: + msg = self.session.unpack_message(msg,content=True) + except Exception, e: + logger.error("registration::got bad registration message: %s"%msg) + raise e + return + + msg_type = msg['msg_type'] + content = msg['content'] + + handler = self.registrar_handlers.get(msg_type, None) + if handler is None: + logger.error("registration::got bad registration message: %s"%msg) + else: + handler(idents, msg) + + def dispatch_queue_traffic(self, msg): + """all ME and Task queue messages come through here""" + logger.debug("queue traffic: %s"%msg[:2]) + switch = msg[0] + idents, msg = self.session.feed_identities(msg[1:]) + if switch == 'in': + self.save_queue_request(idents, msg) + elif switch == 'out': + self.save_queue_result(idents, msg) + elif switch == 'intask': + self.save_task_request(idents, msg) + elif switch == 'outtask': + self.save_task_result(idents, msg) + elif switch == 'tracktask': + self.save_task_destination(idents, msg) + else: + logger.error("Invalid message topic: %s"%switch) + + + def dispatch_client_msg(self, msg): + """Route messages from clients""" + idents, msg = self.session.feed_identities(msg) + client_id = idents[0] + try: + msg = 
self.session.unpack_message(msg, content=True) + except: + content = wrap_exception() + logger.error("Bad Client Message: %s"%msg) + self.session.send(self.clientele, "controller_error", ident=client_id, + content=content) + return + + # print client_id, header, parent, content + #switch on message type: + msg_type = msg['msg_type'] + logger.info("client:: client %s requested %s"%(client_id, msg_type)) + handler = self.client_handlers.get(msg_type, None) + try: + assert handler is not None, "Bad Message Type: %s"%msg_type + except: + content = wrap_exception() + logger.error("Bad Message Type: %s"%msg_type) + self.session.send(self.clientele, "controller_error", ident=client_id, + content=content) + return + else: + handler(client_id, msg) + + def dispatch_db(self, msg): + """""" + raise NotImplementedError + + #--------------------------------------------------------------------------- + # handler methods (1 per event) + #--------------------------------------------------------------------------- + + #----------------------- Heartbeat -------------------------------------- + + def handle_new_heart(self, heart): + """handler to attach to heartbeater. + Called when a new heart starts to beat. + Triggers completion of registration.""" + logger.debug("heartbeat::handle_new_heart(%r)"%heart) + if heart not in self.incoming_registrations: + logger.info("heartbeat::ignoring new heart: %r"%heart) + else: + self.finish_registration(heart) + + + def handle_heart_failure(self, heart): + """handler to attach to heartbeater. + called when a previously registered heart fails to respond to beat request. 
+ triggers unregistration""" + logger.debug("heartbeat::handle_heart_failure(%r)"%heart) + eid = self.hearts.get(heart, None) + if eid is None: + logger.info("heartbeat::ignoring heart failure %r"%heart) + else: + self.unregister_engine(heart, dict(content=dict(id=eid))) + + #----------------------- MUX Queue Traffic ------------------------------ + + def save_queue_request(self, idents, msg): + queue_id, client_id = idents[:2] + + try: + msg = self.session.unpack_message(msg, content=False) + except: + logger.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg)) + return + + eid = self.by_ident.get(queue_id, None) + if eid is None: + logger.error("queue::target %r not registered"%queue_id) + logger.debug("queue:: valid are: %s"%(self.by_ident.keys())) + return + + header = msg['header'] + msg_id = header['msg_id'] + info = dict(submit=datetime.now(), + received=None, + engine=(eid, queue_id)) + self.pending[msg_id] = ( msg, info ) + self.queues[eid][0].append(msg_id) + + def save_queue_result(self, idents, msg): + client_id, queue_id = idents[:2] + + try: + msg = self.session.unpack_message(msg, content=False) + except: + logger.error("queue::engine %r sent invalid message to %r: %s"%( + queue_id,client_id, msg)) + return + + eid = self.by_ident.get(queue_id, None) + if eid is None: + logger.error("queue::unknown engine %r is sending a reply: "%queue_id) + logger.debug("queue:: %s"%msg[2:]) + return + + parent = msg['parent_header'] + if not parent: + return + msg_id = parent['msg_id'] + self.results[msg_id] = msg + if msg_id in self.pending: + self.pending.pop(msg_id) + self.queues[eid][0].remove(msg_id) + self.completed[eid].append(msg_id) + else: + logger.debug("queue:: unknown msg finished %s"%msg_id) + + #--------------------- Task Queue Traffic ------------------------------ + + def save_task_request(self, idents, msg): + client_id = idents[0] + + try: + msg = self.session.unpack_message(msg, content=False) + except: + 
logger.error("task::client %r sent invalid task message: %s"%( + client_id, msg)) + return + + header = msg['header'] + msg_id = header['msg_id'] + self.mia.add(msg_id) + self.pending[msg_id] = msg + if not self.tasks.has_key(client_id): + self.tasks[client_id] = [] + self.tasks[client_id].append(msg_id) + + def save_task_result(self, idents, msg): + client_id = idents[0] + try: + msg = self.session.unpack_message(msg, content=False) + except: + logger.error("task::invalid task result message send to %r: %s"%( + client_id, msg)) + return + + parent = msg['parent_header'] + if not parent: + # print msg + # logger.warn("") + return + msg_id = parent['msg_id'] + self.results[msg_id] = msg + if msg_id in self.pending: + self.pending.pop(msg_id) + if msg_id in self.mia: + self.mia.remove(msg_id) + else: + logger.debug("task:: unknown task %s finished"%msg_id) + + def save_task_destination(self, idents, msg): + try: + msg = self.session.unpack_message(msg, content=True) + except: + logger.error("task::invalid task tracking message") + return + content = msg['content'] + print content + msg_id = content['msg_id'] + engine_uuid = content['engine_id'] + for eid,queue_id in self.keytable.iteritems(): + if queue_id == engine_uuid: + break + + logger.info("task:: task %s arrived on %s"%(msg_id, eid)) + if msg_id in self.mia: + self.mia.remove(msg_id) + else: + logger.debug("task::task %s not listed as MIA?!"%(msg_id)) + self.tasks[engine_uuid].append(msg_id) + + def mia_task_request(self, idents, msg): + client_id = idents[0] + content = dict(mia=self.mia,status='ok') + self.session.send('mia_reply', content=content, idents=client_id) + + + + #-------------------- Registration ----------------------------- + + def connection_request(self, client_id, msg): + """reply with connection addresses for clients""" + logger.info("client::client %s connected"%client_id) + content = dict(status='ok') + content.update(self.client_addrs) + jsonable = {} + for k,v in 
self.keytable.iteritems(): + jsonable[str(k)] = v + content['engines'] = jsonable + self.session.send(self.registrar, 'connection_reply', content, parent=msg, ident=client_id) + + def register_engine(self, reg, msg): + """register an engine""" + content = msg['content'] + try: + queue = content['queue'] + except KeyError: + logger.error("registration::queue not specified") + return + heart = content.get('heartbeat', None) + """register a new engine, and create the socket(s) necessary""" + eid = self._new_id() + # print (eid, queue, reg, heart) + + logger.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart)) + + content = dict(id=eid,status='ok') + content.update(self.engine_addrs) + # check if requesting available IDs: + if queue in self.by_ident: + content = {'status': 'error', 'reason': "queue_id %r in use"%queue} + elif heart in self.hearts: # need to check unique hearts? + content = {'status': 'error', 'reason': "heart_id %r in use"%heart} + else: + for h, pack in self.incoming_registrations.iteritems(): + if heart == h: + content = {'status': 'error', 'reason': "heart_id %r in use"%heart} + break + elif queue == pack[1]: + content = {'status': 'error', 'reason': "queue_id %r in use"%queue} + break + + msg = self.session.send(self.registrar, "registration_reply", + content=content, + ident=reg) + + if content['status'] == 'ok': + if heart in self.heartbeat.hearts: + # already beating + self.incoming_registrations[heart] = (eid,queue,reg,None) + self.finish_registration(heart) + else: + purge = lambda : self._purge_stalled_registration(heart) + dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop) + dc.start() + self.incoming_registrations[heart] = (eid,queue,reg,dc) + else: + logger.error("registration::registration %i failed: %s"%(eid, content['reason'])) + return eid + + def unregister_engine(self, ident, msg): + try: + eid = msg['content']['id'] + except: + logger.error("registration::bad engine id for 
unregistration: %s"%ident) + return + logger.info("registration::unregister_engine(%s)"%eid) + content=dict(id=eid, queue=self.engines[eid].queue) + self.ids.remove(eid) + self.keytable.pop(eid) + ec = self.engines.pop(eid) + self.hearts.pop(ec.heartbeat) + self.by_ident.pop(ec.queue) + self.completed.pop(eid) + for msg_id in self.queues.pop(eid)[0]: + msg = self.pending.pop(msg_id) + ############## TODO: HANDLE IT ################ + + if self.notifier: + self.session.send(self.notifier, "unregistration_notification", content=content) + + def finish_registration(self, heart): + try: + (eid,queue,reg,purge) = self.incoming_registrations.pop(heart) + except KeyError: + logger.error("registration::tried to finish nonexistant registration") + return + logger.info("registration::finished registering engine %i:%r"%(eid,queue)) + if purge is not None: + purge.stop() + control = queue + self.ids.add(eid) + self.keytable[eid] = queue + self.engines[eid] = EngineConnector(eid, queue, reg, control, heart) + self.by_ident[queue] = eid + self.queues[eid] = ([],[]) + self.completed[eid] = list() + self.hearts[heart] = eid + content = dict(id=eid, queue=self.engines[eid].queue) + if self.notifier: + self.session.send(self.notifier, "registration_notification", content=content) + + def _purge_stalled_registration(self, heart): + if heart in self.incoming_registrations: + eid = self.incoming_registrations.pop(heart)[0] + logger.info("registration::purging stalled registration: %i"%eid) + else: + pass + + #------------------- Client Requests ------------------------------- + + def check_load(self, client_id, msg): + content = msg['content'] + try: + targets = content['targets'] + targets = self._validate_targets(targets) + except: + content = wrap_exception() + self.session.send(self.clientele, "controller_error", + content=content, ident=client_id) + return + + content = dict(status='ok') + # loads = {} + for t in targets: + content[str(t)] = len(self.queues[t]) + 
self.session.send(self.clientele, "load_reply", content=content, ident=client_id) + + + def queue_status(self, client_id, msg): + """handle queue_status request""" + content = msg['content'] + targets = content['targets'] + try: + targets = self._validate_targets(targets) + except: + content = wrap_exception() + self.session.send(self.clientele, "controller_error", + content=content, ident=client_id) + return + verbose = msg.get('verbose', False) + content = dict() + for t in targets: + queue = self.queues[t] + completed = self.completed[t] + if not verbose: + queue = len(queue) + completed = len(completed) + content[str(t)] = {'queue': queue, 'completed': completed } + # pending + self.session.send(self.clientele, "queue_reply", content=content, ident=client_id) + + def job_status(self, client_id, msg): + """handle queue_status request""" + content = msg['content'] + msg_ids = content['msg_ids'] + try: + targets = self._validate_targets(targets) + except: + content = wrap_exception() + self.session.send(self.clientele, "controller_error", + content=content, ident=client_id) + return + verbose = msg.get('verbose', False) + content = dict() + for t in targets: + queue = self.queues[t] + completed = self.completed[t] + if not verbose: + queue = len(queue) + completed = len(completed) + content[str(t)] = {'queue': queue, 'completed': completed } + # pending + self.session.send(self.clientele, "queue_reply", content=content, ident=client_id) + + def purge_results(self, client_id, msg): + content = msg['content'] + msg_ids = content.get('msg_ids', []) + reply = dict(status='ok') + if msg_ids == 'all': + self.results = {} + else: + for msg_id in msg_ids: + if msg_id in self.results: + self.results.pop(msg_id) + else: + if msg_id in self.pending: + reply = dict(status='error', reason="msg pending: %r"%msg_id) + else: + reply = dict(status='error', reason="No such msg: %r"%msg_id) + break + eids = content.get('engine_ids', []) + for eid in eids: + if eid not in 
self.engines: + reply = dict(status='error', reason="No such engine: %i"%eid) + break + msg_ids = self.completed.pop(eid) + for msg_id in msg_ids: + self.results.pop(msg_id) + + self.sesison.send(self.clientele, 'purge_reply', content=reply, ident=client_id) + + def resubmit_task(self, client_id, msg, buffers): + content = msg['content'] + header = msg['header'] + + + msg_ids = content.get('msg_ids', []) + reply = dict(status='ok') + if msg_ids == 'all': + self.results = {} + else: + for msg_id in msg_ids: + if msg_id in self.results: + self.results.pop(msg_id) + else: + if msg_id in self.pending: + reply = dict(status='error', reason="msg pending: %r"%msg_id) + else: + reply = dict(status='error', reason="No such msg: %r"%msg_id) + break + eids = content.get('engine_ids', []) + for eid in eids: + if eid not in self.engines: + reply = dict(status='error', reason="No such engine: %i"%eid) + break + msg_ids = self.completed.pop(eid) + for msg_id in msg_ids: + self.results.pop(msg_id) + + self.sesison.send(self.clientele, 'purge_reply', content=reply, ident=client_id) + + def get_results(self, client_id, msg): + """get the result of 1 or more messages""" + content = msg['content'] + msg_ids = set(content['msg_ids']) + statusonly = content.get('status_only', False) + pending = [] + completed = [] + content = dict(status='ok') + content['pending'] = pending + content['completed'] = completed + for msg_id in msg_ids: + if msg_id in self.pending: + pending.append(msg_id) + elif msg_id in self.results: + completed.append(msg_id) + if not statusonly: + content[msg_id] = self.results[msg_id]['content'] + else: + content = dict(status='error') + content['reason'] = 'no such message: '+msg_id + break + self.session.send(self.clientele, "result_reply", content=content, + parent=msg, ident=client_id) + + + +############ OLD METHODS for Python Relay Controller ################### + def _validate_engine_msg(self, msg): + """validates and unpacks headers of a message. 
Returns False if invalid, + (ident, message)""" + ident = msg[0] + try: + msg = self.session.unpack_message(msg[1:], content=False) + except: + logger.error("engine.%s::Invalid Message %s"%(ident, msg)) + return False + + try: + eid = msg.header.username + assert self.engines.has_key(eid) + except: + logger.error("engine::Invalid Engine ID %s"%(ident)) + return False + + return eid, msg + + + \ No newline at end of file diff --git a/IPython/zmq/parallel/engine.py b/IPython/zmq/parallel/engine.py new file mode 100644 index 0000000..2a6b781 --- /dev/null +++ b/IPython/zmq/parallel/engine.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python +"""A simple engine that talks to a controller over 0MQ. +it handles registration, etc. and launches a kernel +connected to the Controller's queue(s). +""" +import sys +import time +import traceback +import uuid + +import zmq +from zmq.eventloop import ioloop, zmqstream + +from streamsession import Message, StreamSession +from client import Client +import streamkernel as kernel +import heartmonitor +# import taskthread +# from log import logger + + +def printer(*msg): + print msg + +class Engine(object): + """IPython engine""" + + id=None + context=None + loop=None + session=None + queue_id=None + control_id=None + heart_id=None + registrar=None + heart=None + kernel=None + + def __init__(self, context, loop, session, registrar, client, queue_id=None, heart_id=None): + self.context = context + self.loop = loop + self.session = session + self.registrar = registrar + self.client = client + self.queue_id = queue_id or str(uuid.uuid4()) + self.heart_id = heart_id or self.queue_id + self.registrar.on_send(printer) + + def register(self): + + content = dict(queue=self.queue_id, heartbeat=self.heart_id) + self.registrar.on_recv(self.complete_registration) + self.session.send(self.registrar, "registration_request",content=content) + + def complete_registration(self, msg): + # print msg + idents,msg = self.session.feed_identities(msg) + msg = 
Message(self.session.unpack_message(msg)) + if msg.content.status == 'ok': + self.session.username = str(msg.content.id) + queue_addr = msg.content.queue + if queue_addr: + queue = self.context.socket(zmq.PAIR) + queue.setsockopt(zmq.IDENTITY, self.queue_id) + queue.connect(str(queue_addr)) + self.queue = zmqstream.ZMQStream(queue, self.loop) + + control_addr = msg.content.control + if control_addr: + control = self.context.socket(zmq.PAIR) + control.setsockopt(zmq.IDENTITY, self.queue_id) + control.connect(str(control_addr)) + self.control = zmqstream.ZMQStream(control, self.loop) + + task_addr = msg.content.task + print task_addr + if task_addr: + # task as stream: + task = self.context.socket(zmq.PAIR) + task.connect(str(task_addr)) + self.task_stream = zmqstream.ZMQStream(task, self.loop) + # TaskThread: + # mon_addr = msg.content.monitor + # task = taskthread.TaskThread(zmq.PAIR, zmq.PUB, self.queue_id) + # task.connect_in(str(task_addr)) + # task.connect_out(str(mon_addr)) + # self.task_stream = taskthread.QueueStream(*task.queues) + # task.start() + + hbs = msg.content.heartbeat + self.heart = heartmonitor.Heart(*map(str, hbs), heart_id=self.heart_id) + self.heart.start() + # ioloop.DelayedCallback(self.heart.start, 1000, self.loop).start() + # placeholder for now: + pub = self.context.socket(zmq.PUB) + pub = zmqstream.ZMQStream(pub, self.loop) + # create and start the kernel + self.kernel = kernel.Kernel(self.session, self.control, self.queue, pub, self.task_stream, self.client) + self.kernel.start() + else: + # logger.error("Registration Failed: %s"%msg) + raise Exception("Registration Failed: %s"%msg) + + # logger.info("engine::completed registration with id %s"%self.session.username) + + print msg + + def unregister(self): + self.session.send(self.registrar, "unregistration_request", content=dict(id=int(self.session.username))) + time.sleep(1) + sys.exit(0) + + def start(self): + print "registering" + self.register() + + +if __name__ == '__main__': + + 
loop = ioloop.IOLoop.instance() + session = StreamSession() + ctx = zmq.Context() + + ip = '127.0.0.1' + reg_port = 10101 + connection = ('tcp://%s' % ip) + ':%i' + reg_conn = connection % reg_port + print reg_conn + print >>sys.__stdout__, "Starting the engine..." + + reg = ctx.socket(zmq.PAIR) + reg.connect(reg_conn) + reg = zmqstream.ZMQStream(reg, loop) + client = Client(reg_conn) + if len(sys.argv) > 1: + queue_id=sys.argv[1] + else: + queue_id = None + + e = Engine(ctx, loop, session, reg, client, queue_id) + dc = ioloop.DelayedCallback(e.start, 500, loop) + dc.start() + loop.start() \ No newline at end of file diff --git a/IPython/zmq/parallel/heartmonitor.py b/IPython/zmq/parallel/heartmonitor.py new file mode 100644 index 0000000..8db6203 --- /dev/null +++ b/IPython/zmq/parallel/heartmonitor.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python +""" +A multi-heart Heartbeat system using PUB and XREP sockets. pings are sent out on the PUB, +and hearts are tracked based on their XREQ identities. +""" + +import time +import uuid + +import zmq +from zmq.devices import ProcessDevice +from zmq.eventloop import ioloop, zmqstream + +#internal +from IPython.zmq.log import logger + +class Heart(object): + """A basic heart object for responding to a HeartMonitor. + This is a simple wrapper with defaults for the most common + Device model for responding to heartbeats. + + It simply builds a threadsafe zmq.FORWARDER Device, defaulting to using + SUB/XREQ for in/out. 
+ + You can specify the XREQ's IDENTITY via the optional heart_id argument.""" + device=None + id=None + def __init__(self, in_addr, out_addr, in_type=zmq.SUB, out_type=zmq.XREQ, heart_id=None): + self.device = ProcessDevice(zmq.FORWARDER, in_type, out_type) + self.device.connect_in(in_addr) + self.device.connect_out(out_addr) + if in_type == zmq.SUB: + self.device.setsockopt_in(zmq.SUBSCRIBE, "") + if heart_id is None: + heart_id = str(uuid.uuid4()) + self.device.setsockopt_out(zmq.IDENTITY, heart_id) + self.id = heart_id + + def start(self): + return self.device.start() + +class HeartMonitor(object): + """A basic HeartMonitor class + pingstream: a PUB stream + pongstream: an XREP stream + period: the period of the heartbeat in milliseconds""" + loop=None + pingstream=None + pongstream=None + period=None + hearts=None + on_probation=None + last_ping=None + + def __init__(self, loop, pingstream, pongstream, period=1000): + self.loop = loop + self.period = period + + self.pingstream = pingstream + self.pongstream = pongstream + self.pongstream.on_recv(self.handle_pong) + + self.hearts = set() + self.responses = set() + self.on_probation = set() + self.lifetime = 0 + self.tic = time.time() + + self._new_handlers = set() + self._failure_handlers = set() + + def start(self): + self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop) + self.caller.start() + + def add_new_heart_handler(self, handler): + """add a new handler for new hearts""" + logger.debug("heartbeat::new_heart_handler: %s"%handler) + self._new_handlers.add(handler) + + def add_heart_failure_handler(self, handler): + """add a new handler for heart failure""" + logger.debug("heartbeat::new heart failure handler: %s"%handler) + self._failure_handlers.add(handler) + + def _flush(self): + """override IOLoop triggers""" + while True: + try: + msg = self.pongstream.socket.recv_multipart(zmq.NOBLOCK) + logger.warn("IOLoop triggered beat with incoming heartbeat waiting to be handled") + except 
zmq.ZMQError: + return + else: + self.handle_pong(msg) + # print '.' + + + def beat(self): + self._flush() + self.last_ping = self.lifetime + + toc = time.time() + self.lifetime += toc-self.tic + self.tic = toc + logger.debug("heartbeat::%s"%self.lifetime) + goodhearts = self.hearts.intersection(self.responses) + missed_beats = self.hearts.difference(goodhearts) + heartfailures = self.on_probation.intersection(missed_beats) + newhearts = self.responses.difference(goodhearts) + map(self.handle_new_heart, newhearts) + map(self.handle_heart_failure, heartfailures) + self.on_probation = missed_beats.intersection(self.hearts) + self.responses = set() + # print self.on_probation, self.hearts + # logger.debug("heartbeat::beat %.3f, %i beating hearts"%(self.lifetime, len(self.hearts))) + self.pingstream.send(str(self.lifetime)) + + def handle_new_heart(self, heart): + if self._new_handlers: + for handler in self._new_handlers: + handler(heart) + else: + logger.info("heartbeat::yay, got new heart %s!"%heart) + self.hearts.add(heart) + + def handle_heart_failure(self, heart): + if self._failure_handlers: + for handler in self._failure_handlers: + try: + handler(heart) + except Exception, e: + print e + logger.error("heartbeat::Bad Handler! 
%s"%handler) + pass + else: + logger.info("heartbeat::Heart %s failed :("%heart) + self.hearts.remove(heart) + + + def handle_pong(self, msg): + "a heart just beat" + if msg[1] == str(self.lifetime): + delta = time.time()-self.tic + logger.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta)) + self.responses.add(msg[0]) + elif msg[1] == str(self.last_ping): + delta = time.time()-self.tic + (self.lifetime-self.last_ping) + logger.warn("heartbeat::heart %r missed a beat, and took %.2f ms to respond"%(msg[0], 1000*delta)) + self.responses.add(msg[0]) + else: + logger.warn("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)"% + (msg[1],self.lifetime)) + + +if __name__ == '__main__': + loop = ioloop.IOLoop.instance() + context = zmq.Context() + pub = context.socket(zmq.PUB) + pub.bind('tcp://127.0.0.1:5555') + xrep = context.socket(zmq.XREP) + xrep.bind('tcp://127.0.0.1:5556') + + outstream = zmqstream.ZMQStream(pub, loop) + instream = zmqstream.ZMQStream(xrep, loop) + + hb = HeartMonitor(loop, outstream, instream) + + loop.start() diff --git a/IPython/zmq/parallel/remotenamespace.py b/IPython/zmq/parallel/remotenamespace.py new file mode 100644 index 0000000..315e538 --- /dev/null +++ b/IPython/zmq/parallel/remotenamespace.py @@ -0,0 +1,95 @@ +"""RemoteNamespace object, for dict style interaction with a remote +execution kernel.""" + +from functools import wraps +from IPython.external.decorator import decorator + +def _clear(): + globals().clear() + +@decorator +def spinfirst(f): + @wraps(f) + def spun_method(self, *args, **kwargs): + self.spin() + return f(self, *args, **kwargs) + return spun_method + +@decorator +def myblock(f, self, *args, **kwargs): + block = self.client.block + self.client.block = self.block + ret = f(self, *args, **kwargs) + self.client.block = block + return ret + +class RemoteNamespace(object): + """A RemoteNamespace object, providing dictionary + access to an engine via an IPython.zmq.client object. 
+ + + """ + client = None + queue = None + id = None + block = False + + def __init__(self, client, id): + self.client = client + self.id = id + self.block = client.block # initial state is same as client + + def __repr__(self): + return ""%self.id + + @myblock + def apply(self, f, *args, **kwargs): + """call f(*args, **kwargs) in remote namespace + + This method has no access to the user namespace""" + return self.client.apply_to(self.id, f, *args, **kwargs) + + @myblock + def apply_bound(self, f, *args, **kwargs): + """call `f(*args, **kwargs)` in remote namespace. + + `f` will have access to the user namespace as globals().""" + return self.client.apply_bound_to(self.id, f, *args, **kwargs) + + @myblock + def update(self, ns): + """update remote namespace with dict `ns`""" + return self.client.push(self.id, ns, self.block) + + def get(self, key_s): + """get object(s) by `key_s` from remote namespace + will return one object if it is a key. + It also takes a list of keys, and will return a list of objects.""" + return self.client.pull(self.id, key_s, self.block) + + push = update + pull = get + + def __getitem__(self, key): + return self.get(key) + + def __setitem__(self,key,value): + self.update({key:value}) + + def clear(self): + """clear the remote namespace""" + return self.client.apply_bound_to(self.id, _clear) + + @decorator + def withme(self, toapply): + """for use as a decorator, this turns a function into + one that executes remotely.""" + @wraps(toapply) + def applied(self, *args, **kwargs): + return self.apply_bound(self, toapply, *args, **kwargs) + return applied + + + + + diff --git a/IPython/zmq/parallel/streamkernel.py b/IPython/zmq/parallel/streamkernel.py new file mode 100755 index 0000000..25b8d9b --- /dev/null +++ b/IPython/zmq/parallel/streamkernel.py @@ -0,0 +1,482 @@ +#!/usr/bin/env python +""" +Kernel adapted from kernel.py to use ZMQ Streams +""" + +import __builtin__ +import sys +import time +import traceback +from signal import SIGTERM, 
SIGKILL + +from code import CommandCompiler + +import zmq +from zmq.eventloop import ioloop, zmqstream + +from streamsession import StreamSession, Message, extract_header, serialize_object,\ + unpack_apply_message +from IPython.zmq.completer import KernelCompleter + +class OutStream(object): + """A file like object that publishes the stream to a 0MQ PUB socket.""" + + def __init__(self, session, pub_socket, name, max_buffer=200): + self.session = session + self.pub_socket = pub_socket + self.name = name + self._buffer = [] + self._buffer_len = 0 + self.max_buffer = max_buffer + self.parent_header = {} + + def set_parent(self, parent): + self.parent_header = extract_header(parent) + + def close(self): + self.pub_socket = None + + def flush(self): + if self.pub_socket is None: + raise ValueError(u'I/O operation on closed file') + else: + if self._buffer: + data = ''.join(self._buffer) + content = {u'name':self.name, u'data':data} + # msg = self.session.msg(u'stream', content=content, + # parent=self.parent_header) + msg = self.session.send(self.pub_socket, u'stream', content=content, parent=self.parent_header) + # print>>sys.__stdout__, Message(msg) + # self.pub_socket.send_json(msg) + self._buffer_len = 0 + self._buffer = [] + + def isattr(self): + return False + + def next(self): + raise IOError('Read not supported on a write only stream.') + + def read(self, size=None): + raise IOError('Read not supported on a write only stream.') + + readline=read + + def write(self, s): + if self.pub_socket is None: + raise ValueError('I/O operation on closed file') + else: + self._buffer.append(s) + self._buffer_len += len(s) + self._maybe_send() + + def _maybe_send(self): + if '\n' in self._buffer[-1]: + self.flush() + if self._buffer_len > self.max_buffer: + self.flush() + + def writelines(self, sequence): + if self.pub_socket is None: + raise ValueError('I/O operation on closed file') + else: + for s in sequence: + self.write(s) + + +class DisplayHook(object): + + def 
__init__(self, session, pub_socket): + self.session = session + self.pub_socket = pub_socket + self.parent_header = {} + + def __call__(self, obj): + if obj is None: + return + + __builtin__._ = obj + # msg = self.session.msg(u'pyout', {u'data':repr(obj)}, + # parent=self.parent_header) + # self.pub_socket.send_json(msg) + self.session.send(self.pub_socket, u'pyout', content={u'data':repr(obj)}, parent=self.parent_header) + + def set_parent(self, parent): + self.parent_header = extract_header(parent) + + +class RawInput(object): + + def __init__(self, session, socket): + self.session = session + self.socket = socket + + def __call__(self, prompt=None): + msg = self.session.msg(u'raw_input') + self.socket.send_json(msg) + while True: + try: + reply = self.socket.recv_json(zmq.NOBLOCK) + except zmq.ZMQError, e: + if e.errno == zmq.EAGAIN: + pass + else: + raise + else: + break + return reply[u'content'][u'data'] + + +class Kernel(object): + + def __init__(self, session, control_stream, reply_stream, pub_stream, + task_stream=None, client=None): + self.session = session + self.control_stream = control_stream + self.reply_stream = reply_stream + self.task_stream = task_stream + self.pub_stream = pub_stream + self.client = client + self.user_ns = {} + self.history = [] + self.compiler = CommandCompiler() + self.completer = KernelCompleter(self.user_ns) + self.aborted = set() + + # Build dict of handlers for message types + self.queue_handlers = {} + self.control_handlers = {} + for msg_type in ['execute_request', 'complete_request', 'apply_request']: + self.queue_handlers[msg_type] = getattr(self, msg_type) + + for msg_type in ['kill_request', 'abort_request']: + self.control_handlers[msg_type] = getattr(self, msg_type) + + #-------------------- control handlers ----------------------------- + + def abort_queue(self, stream): + while True: + try: + msg = self.session.recv(stream, zmq.NOBLOCK,content=True) + except zmq.ZMQError, e: + if e.errno == zmq.EAGAIN: + break + 
else: + return + else: + if msg is None: + return + else: + idents,msg = msg + + # assert self.reply_socketly_socket.rcvmore(), "Unexpected missing message part." + # msg = self.reply_socket.recv_json() + print>>sys.__stdout__, "Aborting:" + print>>sys.__stdout__, Message(msg) + msg_type = msg['msg_type'] + reply_type = msg_type.split('_')[0] + '_reply' + # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg) + # self.reply_socket.send(ident,zmq.SNDMORE) + # self.reply_socket.send_json(reply_msg) + reply_msg = self.session.send(stream, reply_type, + content={'status' : 'aborted'}, parent=msg, ident=idents) + print>>sys.__stdout__, Message(reply_msg) + # We need to wait a bit for requests to come in. This can probably + # be set shorter for true asynchronous clients. + time.sleep(0.05) + + def abort_request(self, stream, ident, parent): + msg_ids = parent['content'].get('msg_ids', None) + if not msg_ids: + self.abort_queue(self.task_stream) + self.abort_queue(self.reply_stream) + for mid in msg_ids: + self.aborted.add(mid) + + content = dict(status='ok') + self.session.send(stream, 'abort_reply', content=content, parent=parent, + ident=ident) + + def kill_request(self, stream, idents, parent): + self.abort_queue(self.reply_stream) + if self.task_stream: + self.abort_queue(self.task_stream) + msg = self.session.send(stream, 'kill_reply', ident=idents, parent=parent, + content = dict(status='ok')) + # we can know that a message is done if we *don't* use streams, but + # use a socket directly with MessageTracker + time.sleep(1) + os.kill(os.getpid(), SIGTERM) + time.sleep(.25) + os.kill(os.getpid(), SIGKILL) + + def dispatch_control(self, msg): + idents,msg = self.session.feed_identities(msg, copy=False) + msg = self.session.unpack_message(msg, content=True, copy=False) + + header = msg['header'] + msg_id = header['msg_id'] + + handler = self.control_handlers.get(msg['msg_type'], None) + if handler is None: + print >> sys.__stderr__, "UNKNOWN CONTROL 
MESSAGE TYPE:", msg + else: + handler(stream, idents, msg) + + def flush_control(self): + while any(zmq.select([self.control_socket],[],[],1e-4)): + try: + msg = self.control_socket.recv_multipart(zmq.NOBLOCK, copy=False) + except zmq.ZMQError, e: + if e.errno != zmq.EAGAIN: + raise e + return + else: + self.dispatch_control(msg) + + + #-------------------- queue helpers ------------------------------ + + def check_dependencies(self, dependencies): + if not dependencies: + return True + if len(dependencies) == 2 and dependencies[0] in 'any all'.split(): + anyorall = dependencies[0] + dependencies = dependencies[1] + else: + anyorall = 'all' + results = self.client.get_results(dependencies,status_only=True) + if results['status'] != 'ok': + return False + + if anyorall == 'any': + if not results['completed']: + return False + else: + if results['pending']: + return False + + return True + + #-------------------- queue handlers ----------------------------- + + def execute_request(self, stream, ident, parent): + try: + code = parent[u'content'][u'code'] + except: + print>>sys.__stderr__, "Got bad msg: " + print>>sys.__stderr__, Message(parent) + return + # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent) + # self.pub_stream.send(pyin_msg) + self.session.send(self.pub_stream, u'pyin', {u'code':code},parent=parent) + try: + comp_code = self.compiler(code, '') + # allow for not overriding displayhook + if hasattr(sys.displayhook, 'set_parent'): + sys.displayhook.set_parent(parent) + exec comp_code in self.user_ns, self.user_ns + except: + # result = u'error' + etype, evalue, tb = sys.exc_info() + tb = traceback.format_exception(etype, evalue, tb) + exc_content = { + u'status' : u'error', + u'traceback' : tb, + u'etype' : unicode(etype), + u'evalue' : unicode(evalue) + } + # exc_msg = self.session.msg(u'pyerr', exc_content, parent) + self.session.send(self.pub_stream, u'pyerr', exc_content, parent=parent) + reply_content = exc_content + else: + 
reply_content = {'status' : 'ok'} + # reply_msg = self.session.msg(u'execute_reply', reply_content, parent) + # self.reply_socket.send(ident, zmq.SNDMORE) + # self.reply_socket.send_json(reply_msg) + reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent, ident=ident) + # print>>sys.__stdout__, Message(reply_msg) + if reply_msg['content']['status'] == u'error': + self.abort_queue() + + def complete_request(self, stream, ident, parent): + matches = {'matches' : self.complete(parent), + 'status' : 'ok'} + completion_msg = self.session.send(stream, 'complete_reply', + matches, parent, ident) + # print >> sys.__stdout__, completion_msg + + def complete(self, msg): + return self.completer.complete(msg.content.line, msg.content.text) + + def apply_request(self, stream, ident, parent): + try: + content = parent[u'content'] + bufs = parent[u'buffers'] + msg_id = parent['header']['msg_id'] + bound = content.get('bound', False) + except: + print>>sys.__stderr__, "Got bad msg: " + print>>sys.__stderr__, Message(parent) + return + # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent) + # self.pub_stream.send(pyin_msg) + # self.session.send(self.pub_stream, u'pyin', {u'code':code},parent=parent) + try: + # allow for not overriding displayhook + if hasattr(sys.displayhook, 'set_parent'): + sys.displayhook.set_parent(parent) + # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns + if bound: + working = self.user_ns + suffix = str(msg_id).replace("-","") + prefix = "_" + + else: + working = dict() + suffix = prefix = "" + f,args,kwargs = unpack_apply_message(bufs, working, copy=False) + # if f.fun + fname = prefix+f.func_name.strip('<>')+suffix + argname = prefix+"args"+suffix + kwargname = prefix+"kwargs"+suffix + resultname = prefix+"result"+suffix + + ns = { fname : f, argname : args, kwargname : kwargs } + # print ns + working.update(ns) + code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname) + exec code in working, 
working + result = working.get(resultname) + # clear the namespace + if bound: + for key in ns.iterkeys(): + self.user_ns.pop(key) + else: + del working + + packed_result,buf = serialize_object(result) + result_buf = [packed_result]+buf + except: + result = u'error' + etype, evalue, tb = sys.exc_info() + tb = traceback.format_exception(etype, evalue, tb) + exc_content = { + u'status' : u'error', + u'traceback' : tb, + u'etype' : unicode(etype), + u'evalue' : unicode(evalue) + } + # exc_msg = self.session.msg(u'pyerr', exc_content, parent) + self.session.send(self.pub_stream, u'pyerr', exc_content, parent=parent) + reply_content = exc_content + result_buf = [] + else: + reply_content = {'status' : 'ok'} + # reply_msg = self.session.msg(u'execute_reply', reply_content, parent) + # self.reply_socket.send(ident, zmq.SNDMORE) + # self.reply_socket.send_json(reply_msg) + reply_msg = self.session.send(stream, u'apply_reply', reply_content, parent=parent, ident=ident,buffers=result_buf) + # print>>sys.__stdout__, Message(reply_msg) + if reply_msg['content']['status'] == u'error': + self.abort_queue() + + def dispatch_queue(self, stream, msg): + self.flush_control() + idents,msg = self.session.feed_identities(msg, copy=False) + msg = self.session.unpack_message(msg, content=True, copy=False) + + header = msg['header'] + msg_id = header['msg_id'] + dependencies = header.get('dependencies', []) + + if self.check_aborted(msg_id): + return self.abort_reply(stream, msg) + if not self.check_dependencies(dependencies): + return self.unmet_dependencies(stream, msg) + + handler = self.queue_handlers.get(msg['msg_type'], None) + if handler is None: + print >> sys.__stderr__, "UNKNOWN MESSAGE TYPE:", msg + else: + handler(stream, idents, msg) + + def start(self): + #### stream mode: + if self.control_stream: + self.control_stream.on_recv(self.dispatch_control, copy=False) + if self.reply_stream: + self.reply_stream.on_recv(lambda msg: + self.dispatch_queue(self.reply_stream, msg), 
copy=False) + if self.task_stream: + self.task_stream.on_recv(lambda msg: + self.dispatch_queue(self.task_stream, msg), copy=False) + + #### while True mode: + # while True: + # idle = True + # try: + # msg = self.reply_stream.socket.recv_multipart( + # zmq.NOBLOCK, copy=False) + # except zmq.ZMQError, e: + # if e.errno != zmq.EAGAIN: + # raise e + # else: + # idle=False + # self.dispatch_queue(self.reply_stream, msg) + # + # if not self.task_stream.empty(): + # idle=False + # msg = self.task_stream.recv_multipart() + # self.dispatch_queue(self.task_stream, msg) + # if idle: + # # don't busywait + # time.sleep(1e-3) + + +def main(): + raise Exception("Don't run me anymore") + loop = ioloop.IOLoop.instance() + c = zmq.Context() + + ip = '127.0.0.1' + port_base = 5575 + connection = ('tcp://%s' % ip) + ':%i' + rep_conn = connection % port_base + pub_conn = connection % (port_base+1) + + print >>sys.__stdout__, "Starting the kernel..." + # print >>sys.__stdout__, "XREQ Channel:", rep_conn + # print >>sys.__stdout__, "PUB Channel:", pub_conn + + session = StreamSession(username=u'kernel') + + reply_socket = c.socket(zmq.XREQ) + reply_socket.connect(rep_conn) + + pub_socket = c.socket(zmq.PUB) + pub_socket.connect(pub_conn) + + stdout = OutStream(session, pub_socket, u'stdout') + stderr = OutStream(session, pub_socket, u'stderr') + sys.stdout = stdout + sys.stderr = stderr + + display_hook = DisplayHook(session, pub_socket) + sys.displayhook = display_hook + reply_stream = zmqstream.ZMQStream(reply_socket,loop) + pub_stream = zmqstream.ZMQStream(pub_socket,loop) + kernel = Kernel(session, reply_stream, pub_stream) + + # For debugging convenience, put sleep and a string in the namespace, so we + # have them every time we start. + kernel.user_ns['sleep'] = time.sleep + kernel.user_ns['s'] = 'Test string' + + print >>sys.__stdout__, "Use Ctrl-\\ (NOT Ctrl-C!) to terminate." 
+ kernel.start() + loop.start() + + +if __name__ == '__main__': + main() diff --git a/IPython/zmq/parallel/streamsession.py b/IPython/zmq/parallel/streamsession.py new file mode 100644 index 0000000..710a605 --- /dev/null +++ b/IPython/zmq/parallel/streamsession.py @@ -0,0 +1,443 @@ +#!/usr/bin/env python +"""edited session.py to work with streams, and move msg_type to the header +""" + + +import os +import sys +import traceback +import pprint +import uuid + +import zmq +from zmq.utils import jsonapi +from zmq.eventloop.zmqstream import ZMQStream + +from IPython.zmq.pickleutil import can, uncan, canSequence, uncanSequence +from IPython.zmq.newserialized import serialize, unserialize + +try: + import cPickle + pickle = cPickle +except: + cPickle = None + import pickle + +# packer priority: jsonlib[2], cPickle, simplejson/json, pickle +json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__ +if json_name in ('jsonlib', 'jsonlib2'): + use_json = True +elif json_name: + if cPickle is None: + use_json = True + else: + use_json = False +else: + use_json = False + +if use_json: + default_packer = jsonapi.dumps + default_unpacker = jsonapi.loads +else: + default_packer = lambda o: pickle.dumps(o,-1) + default_unpacker = pickle.loads + + +DELIM="" + +def wrap_exception(): + etype, evalue, tb = sys.exc_info() + tb = traceback.format_exception(etype, evalue, tb) + exc_content = { + u'status' : u'error', + u'traceback' : tb, + u'etype' : unicode(etype), + u'evalue' : unicode(evalue) + } + return exc_content + +class KernelError(Exception): + pass + +def unwrap_exception(content): + err = KernelError(content['etype'], content['evalue']) + err.evalue = content['evalue'] + err.etype = content['etype'] + err.traceback = ''.join(content['traceback']) + return err + + +class Message(object): + """A simple message object that maps dict keys to attributes. 
+ + A Message can be created from a dict and a dict from a Message instance + simply by calling dict(msg_obj).""" + + def __init__(self, msg_dict): + dct = self.__dict__ + for k, v in dict(msg_dict).iteritems(): + if isinstance(v, dict): + v = Message(v) + dct[k] = v + + # Having this iterator lets dict(msg_obj) work out of the box. + def __iter__(self): + return iter(self.__dict__.iteritems()) + + def __repr__(self): + return repr(self.__dict__) + + def __str__(self): + return pprint.pformat(self.__dict__) + + def __contains__(self, k): + return k in self.__dict__ + + def __getitem__(self, k): + return self.__dict__[k] + + +def msg_header(msg_id, msg_type, username, session): + return locals() + # return { + # 'msg_id' : msg_id, + # 'msg_type': msg_type, + # 'username' : username, + # 'session' : session + # } + + +def extract_header(msg_or_header): + """Given a message or header, return the header.""" + if not msg_or_header: + return {} + try: + # See if msg_or_header is the entire message. + h = msg_or_header['header'] + except KeyError: + try: + # See if msg_or_header is just the header + h = msg_or_header['msg_id'] + except KeyError: + raise + else: + h = msg_or_header + if not isinstance(h, dict): + h = dict(h) + return h + +def rekey(dikt): + """rekey a dict that has been forced to use str keys where there should be + ints by json. This belongs in the jsonutil added by fperez.""" + for k in dikt.iterkeys(): + if isinstance(k, str): + ik=fk=None + try: + ik = int(k) + except ValueError: + try: + fk = float(k) + except ValueError: + continue + if ik is not None: + nk = ik + else: + nk = fk + if nk in dikt: + raise KeyError("already have key %r"%nk) + dikt[nk] = dikt.pop(k) + return dikt + +def serialize_object(obj, threshold=64e-6): + """serialize an object into a list of sendable buffers. 
+ + Returns: (pmd, bufs) + where pmd is the pickled metadata wrapper, and bufs + is a list of data buffers""" + # threshold is 100 B + databuffers = [] + if isinstance(obj, (list, tuple)): + clist = canSequence(obj) + slist = map(serialize, clist) + for s in slist: + if s.getDataSize() > threshold: + databuffers.append(s.getData()) + s.data = None + return pickle.dumps(slist,-1), databuffers + elif isinstance(obj, dict): + sobj = {} + for k in sorted(obj.iterkeys()): + s = serialize(can(obj[k])) + if s.getDataSize() > threshold: + databuffers.append(s.getData()) + s.data = None + sobj[k] = s + return pickle.dumps(sobj,-1),databuffers + else: + s = serialize(can(obj)) + if s.getDataSize() > threshold: + databuffers.append(s.getData()) + s.data = None + return pickle.dumps(s,-1),databuffers + + +def unserialize_object(bufs): + """reconstruct an object serialized by serialize_object from data buffers""" + bufs = list(bufs) + sobj = pickle.loads(bufs.pop(0)) + if isinstance(sobj, (list, tuple)): + for s in sobj: + if s.data is None: + s.data = bufs.pop(0) + return uncanSequence(map(unserialize, sobj)) + elif isinstance(sobj, dict): + newobj = {} + for k in sorted(sobj.iterkeys()): + s = sobj[k] + if s.data is None: + s.data = bufs.pop(0) + newobj[k] = uncan(unserialize(s)) + return newobj + else: + if sobj.data is None: + sobj.data = bufs.pop(0) + return uncan(unserialize(sobj)) + +def pack_apply_message(f, args, kwargs, threshold=64e-6): + """pack up a function, args, and kwargs to be sent over the wire + as a series of buffers. 
Any object whose data is larger than `threshold` + will not have their data copied (currently only numpy arrays support zero-copy)""" + msg = [pickle.dumps(can(f),-1)] + databuffers = [] # for large objects + sargs, bufs = serialize_object(args,threshold) + msg.append(sargs) + databuffers.extend(bufs) + skwargs, bufs = serialize_object(kwargs,threshold) + msg.append(skwargs) + databuffers.extend(bufs) + msg.extend(databuffers) + return msg + +def unpack_apply_message(bufs, g=None, copy=True): + """unpack f,args,kwargs from buffers packed by pack_apply_message() + Returns: original f,args,kwargs""" + bufs = list(bufs) # allow us to pop + assert len(bufs) >= 3, "not enough buffers!" + if not copy: + for i in range(3): + bufs[i] = bufs[i].bytes + cf = pickle.loads(bufs.pop(0)) + sargs = list(pickle.loads(bufs.pop(0))) + skwargs = dict(pickle.loads(bufs.pop(0))) + # print sargs, skwargs + f = cf.getFunction(g) + for sa in sargs: + if sa.data is None: + m = bufs.pop(0) + if sa.getTypeDescriptor() in ('buffer', 'ndarray'): + if copy: + sa.data = buffer(m) + else: + sa.data = m.buffer + else: + if copy: + sa.data = m + else: + sa.data = m.bytes + + args = uncanSequence(map(unserialize, sargs), g) + kwargs = {} + for k in sorted(skwargs.iterkeys()): + sa = skwargs[k] + if sa.data is None: + sa.data = bufs.pop(0) + kwargs[k] = uncan(unserialize(sa), g) + + return f,args,kwargs + +class StreamSession(object): + """tweaked version of IPython.zmq.session.Session, for development in Parallel""" + + def __init__(self, username=None, session=None, packer=None, unpacker=None): + if username is None: + username = os.environ.get('USER','username') + self.username = username + if session is None: + self.session = str(uuid.uuid4()) + else: + self.session = session + self.msg_id = str(uuid.uuid4()) + if packer is None: + self.pack = default_packer + else: + if not callable(packer): + raise TypeError("packer must be callable, not %s"%type(packer)) + self.pack = packer + + if unpacker is 
None: + self.unpack = default_unpacker + else: + if not callable(unpacker): + raise TypeError("unpacker must be callable, not %s"%type(unpacker)) + self.unpack = unpacker + + self.none = self.pack({}) + + def msg_header(self, msg_type): + h = msg_header(self.msg_id, msg_type, self.username, self.session) + self.msg_id = str(uuid.uuid4()) + return h + + def msg(self, msg_type, content=None, parent=None, subheader=None): + msg = {} + msg['header'] = self.msg_header(msg_type) + msg['msg_id'] = msg['header']['msg_id'] + msg['parent_header'] = {} if parent is None else extract_header(parent) + msg['msg_type'] = msg_type + msg['content'] = {} if content is None else content + sub = {} if subheader is None else subheader + msg['header'].update(sub) + return msg + + def send(self, stream, msg_type, content=None, buffers=None, parent=None, subheader=None, ident=None): + """send a message via stream""" + msg = self.msg(msg_type, content, parent, subheader) + buffers = [] if buffers is None else buffers + to_send = [] + if isinstance(ident, list): + # accept list of idents + to_send.extend(ident) + elif ident is not None: + to_send.append(ident) + to_send.append(DELIM) + to_send.append(self.pack(msg['header'])) + to_send.append(self.pack(msg['parent_header'])) + # if parent is None: + # to_send.append(self.none) + # else: + # to_send.append(self.pack(dict(parent))) + if content is None: + content = self.none + elif isinstance(content, dict): + content = self.pack(content) + elif isinstance(content, str): + # content is already packed, as in a relayed message + pass + else: + raise TypeError("Content incorrect type: %s"%type(content)) + to_send.append(content) + flag = 0 + if buffers: + flag = zmq.SNDMORE + stream.send_multipart(to_send, flag, copy=False) + for b in buffers[:-1]: + stream.send(b, flag, copy=False) + if buffers: + stream.send(buffers[-1], copy=False) + omsg = Message(msg) + return omsg + + def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True): + 
"""receives and unpacks a message + returns [idents], msg""" + if isinstance(socket, ZMQStream): + socket = socket.socket + try: + msg = socket.recv_multipart(mode) + except zmq.ZMQError, e: + if e.errno == zmq.EAGAIN: + # We can convert EAGAIN to None as we know in this case + # recv_json won't return None. + return None + else: + raise + # return an actual Message object + # determine the number of idents by trying to unpack them. + # this is terrible: + idents, msg = self.feed_identities(msg, copy) + try: + return idents, self.unpack_message(msg, content=content, copy=copy) + except Exception, e: + print idents, msg + # TODO: handle it + raise e + + def feed_identities(self, msg, copy=True): + """This is a completely horrible thing, but it strips the zmq + ident prefixes off of a message. It will break if any identities + are unpackable by self.unpack.""" + msg = list(msg) + idents = [] + while len(msg) > 3: + if copy: + s = msg[0] + else: + s = msg[0].bytes + if s == DELIM: + msg.pop(0) + break + else: + idents.append(s) + msg.pop(0) + + return idents, msg + + def unpack_message(self, msg, content=True, copy=True): + """return a message object from the format + sent by self.send. 
+ + parameters: + + content : bool (True) + whether to unpack the content dict (True), + or leave it serialized (False) + + copy : bool (True) + whether to return the bytes (True), + or the non-copying Message object in each place (False) + + """ + if not len(msg) >= 3: + raise TypeError("malformed message, must have at least 3 elements") + message = {} + if not copy: + for i in range(3): + msg[i] = msg[i].bytes + message['header'] = self.unpack(msg[0]) + message['msg_type'] = message['header']['msg_type'] + message['parent_header'] = self.unpack(msg[1]) + if content: + message['content'] = self.unpack(msg[2]) + else: + message['content'] = msg[2] + + # message['buffers'] = msg[3:] + # else: + # message['header'] = self.unpack(msg[0].bytes) + # message['msg_type'] = message['header']['msg_type'] + # message['parent_header'] = self.unpack(msg[1].bytes) + # if content: + # message['content'] = self.unpack(msg[2].bytes) + # else: + # message['content'] = msg[2].bytes + + message['buffers'] = msg[3:]# [ m.buffer for m in msg[3:] ] + return message + + + +def test_msg2obj(): + am = dict(x=1) + ao = Message(am) + assert ao.x == am['x'] + + am['y'] = dict(z=1) + ao = Message(am) + assert ao.y.z == am['y']['z'] + + k1, k2 = 'y', 'z' + assert ao[k1][k2] == am[k1][k2] + + am2 = dict(ao) + assert am['x'] == am2['x'] + assert am['y']['z'] == am2['y']['z'] diff --git a/IPython/zmq/parallel/view.py b/IPython/zmq/parallel/view.py new file mode 100644 index 0000000..b6d23ec --- /dev/null +++ b/IPython/zmq/parallel/view.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python +"""Views""" + +from IPython.external.decorator import decorator + + +@decorator +def myblock(f, self, *args, **kwargs): + block = self.client.block + self.client.block = self.block + ret = f(self, *args, **kwargs) + self.client.block = block + return ret + +class View(object): + """Base View class""" + _targets = None + block=None + + def __init__(self, client, targets): + self.client = client + self._targets = targets + 
self.block = client.block + + def __repr__(self): + strtargets = str(self._targets) + if len(strtargets) > 16: + strtargets = strtargets[:12]+'...]' + return "<%s %s>"%(self.__class__.__name__, strtargets) + + @property + def results(self): + return self.client.results + + @property + def targets(self): + return self._targets + + @targets.setter + def targets(self, value): + raise TypeError("Cannot set my targets argument after construction!") + + def apply(self, f, *args, **kwargs): + """calls f(*args, **kwargs) on remote engines, returning the result. + + This method does not involve the engine's namespace. + + if self.block is False: + returns msg_id + else: + returns actual result of f(*args, **kwargs) + """ + return self.client.apply(f, args, kwargs, block=self.block, targets=self.targets, bound=False) + + def apply_async(self, f, *args, **kwargs): + """calls f(*args, **kwargs) on remote engines in a nonblocking manner. + + This method does not involve the engine's namespace. + + returns msg_id + """ + return self.client.apply(f,args,kwargs, block=False, targets=self.targets, bound=False) + + def apply_sync(self, f, *args, **kwargs): + """calls f(*args, **kwargs) on remote engines in a blocking manner, + returning the result. + + This method does not involve the engine's namespace. + + returns: actual result of f(*args, **kwargs) + """ + return self.client.apply(f,args,kwargs, block=True, targets=self.targets, bound=False) + + def apply_bound(self, f, *args, **kwargs): + """calls f(*args, **kwargs) bound to engine namespace(s). + + if self.block is False: + returns msg_id + else: + returns actual result of f(*args, **kwargs) + + This method has access to the targets' globals + + """ + return self.client.apply(f, args, kwargs, block=self.block, targets=self.targets, bound=True) + + def apply_async_bound(self, f, *args, **kwargs): + """calls f(*args, **kwargs) bound to engine namespace(s) + in a nonblocking manner. 
+ + returns: msg_id + + This method has access to the targets' globals + + """ + return self.client.apply(f, args, kwargs, block=False, targets=self.targets, bound=True) + + def apply_sync_bound(self, f, *args, **kwargs): + """calls f(*args, **kwargs) bound to engine namespace(s), waiting for the result. + + returns: actual result of f(*args, **kwargs) + + This method has access to the targets' globals + + """ + return self.client.apply(f, args, kwargs, block=False, targets=self.targets, bound=True) + + +class DirectView(View): + """Direct Multiplexer View""" + + def update(self, ns): + """update remote namespace with dict `ns`""" + return self.client.push(ns, targets=self.targets, block=self.block) + + def get(self, key_s): + """get object(s) by `key_s` from remote namespace + will return one object if it is a key. + It also takes a list of keys, and will return a list of objects.""" + # block = block if block is not None else self.block + return self.client.pull(key_s, block=self.block, targets=self.targets) + + push = update + pull = get + + def __getitem__(self, key): + return self.get(key) + + def __setitem__(self,key,value): + self.update({key:value}) + + def clear(self): + """clear the remote namespace""" + return self.client.clear(targets=self.targets,block=self.block) + + def abort(self): + return self.client.abort(targets=self.targets,block=self.block) + +class LoadBalancedView(View): + _targets=None + \ No newline at end of file diff --git a/IPython/zmq/pickleutil.py b/IPython/zmq/pickleutil.py new file mode 100644 index 0000000..0191056 --- /dev/null +++ b/IPython/zmq/pickleutil.py @@ -0,0 +1,95 @@ +# encoding: utf-8 + +"""Pickle related utilities. Perhaps this should be called 'can'.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. 
The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +from types import FunctionType + +# contents of codeutil should either be in here, or codeutil belongs in IPython/util +from IPython.kernel import codeutil + +class CannedObject(object): + pass + +class CannedFunction(CannedObject): + + def __init__(self, f): + self._checkType(f) + self.code = f.func_code + + def _checkType(self, obj): + assert isinstance(obj, FunctionType), "Not a function type" + + def getFunction(self, g=None): + if g is None: + g = globals() + newFunc = FunctionType(self.code, g) + return newFunc + +def can(obj): + if isinstance(obj, FunctionType): + return CannedFunction(obj) + elif isinstance(obj,dict): + return canDict(obj) + elif isinstance(obj, (list,tuple)): + return canSequence(obj) + else: + return obj + +def canDict(obj): + if isinstance(obj, dict): + newobj = {} + for k, v in obj.iteritems(): + newobj[k] = can(v) + return newobj + else: + return obj + +def canSequence(obj): + if isinstance(obj, (list, tuple)): + t = type(obj) + return t([can(i) for i in obj]) + else: + return obj + +def uncan(obj, g=None): + if isinstance(obj, CannedFunction): + return obj.getFunction(g) + elif isinstance(obj,dict): + return uncanDict(obj) + elif isinstance(obj, (list,tuple)): + return uncanSequence(obj) + else: + return obj + +def uncanDict(obj, g=None): + if isinstance(obj, dict): + newobj = {} + for k, v in obj.iteritems(): + newobj[k] = uncan(v,g) + return newobj + else: + return obj + +def uncanSequence(obj, g=None): + if isinstance(obj, (list, tuple)): + t = type(obj) + return t([uncan(i,g) for i in obj]) + else: + return obj + + +def rebindFunctionGlobals(f, glbls): + return FunctionType(f.func_code, glbls) 
diff --git a/IPython/zmq/taskthread.py b/IPython/zmq/taskthread.py new file mode 100644 index 0000000..b1824ea --- /dev/null +++ b/IPython/zmq/taskthread.py @@ -0,0 +1,100 @@ +"""Thread for popping Tasks from zmq to Python Queue""" + + +import time +from threading import Thread + +try: + from queue import Queue +except: + from Queue import Queue + +import zmq +from zmq.core.poll import _poll as poll +from zmq.devices import ThreadDevice +from IPython.zmq import streamsession as ss + + +class QueueStream(object): + def __init__(self, in_queue, out_queue): + self.in_queue = in_queue + self.out_queue = out_queue + + def send_multipart(self, *args, **kwargs): + while self.out_queue.full(): + time.sleep(1e-3) + self.out_queue.put(('send_multipart', args, kwargs)) + + def send(self, *args, **kwargs): + while self.out_queue.full(): + time.sleep(1e-3) + self.out_queue.put(('send', args, kwargs)) + + def recv_multipart(self): + return self.in_queue.get() + + def empty(self): + return self.in_queue.empty() + +class TaskThread(ThreadDevice): + """Class for popping Tasks from C-ZMQ->Python Queue""" + max_qsize = 100 + in_socket = None + out_socket = None + # queue = None + + def __init__(self, queue_type, mon_type, engine_id, max_qsize=100): + ThreadDevice.__init__(self, 0, queue_type, mon_type) + self.session = ss.StreamSession(username='TaskNotifier[%s]'%engine_id) + self.engine_id = engine_id + self.in_queue = Queue(max_qsize) + self.out_queue = Queue(max_qsize) + self.max_qsize = max_qsize + + @property + def queues(self): + return self.in_queue, self.out_queue + + @property + def can_recv(self): + # print self.in_queue.full(), poll((self.queue_socket, zmq.POLLIN),1e-3) + return (not self.in_queue.full()) and poll([(self.queue_socket, zmq.POLLIN)], 1e-3 ) + + @property + def can_send(self): + return not self.out_queue.empty() + + def run(self): + print 'running' + self.queue_socket,self.mon_socket = self._setup_sockets() + print 'setup' + + while True: + while not 
self.can_send and not self.can_recv: + # print 'idle' + # nothing to do, wait + time.sleep(1e-3) + while self.can_send: + # flush out queue + print 'flushing...' + meth, args, kwargs = self.out_queue.get() + getattr(self.queue_socket, meth)(*args, **kwargs) + print 'flushed' + + if self.can_recv: + print 'recving' + # get another job from zmq + msg = self.queue_socket.recv_multipart(0, copy=False) + # put it in the Queue + self.in_queue.put(msg) + idents,msg = self.session.feed_identities(msg, copy=False) + msg = self.session.unpack_message(msg, content=False, copy=False) + # notify the Controller that we got it + self.mon_socket.send('tracktask', zmq.SNDMORE) + header = msg['header'] + msg_id = header['msg_id'] + content = dict(engine_id=self.engine_id, msg_id = msg_id) + self.session.send(self.mon_socket, 'task_receipt', content=content) + print 'recvd' + + \ No newline at end of file diff --git a/IPython/zmq/tests/test_controller.py b/IPython/zmq/tests/test_controller.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/IPython/zmq/tests/test_controller.py diff --git a/IPython/zmq/tests/test_newserialized.py b/IPython/zmq/tests/test_newserialized.py new file mode 100644 index 0000000..035bba1 --- /dev/null +++ b/IPython/zmq/tests/test_newserialized.py @@ -0,0 +1,4 @@ + +from unittest import TestCase +from zmq.tests import BaseZMQTest + diff --git a/IPython/zmq/tests/test_streamsession.py b/IPython/zmq/tests/test_streamsession.py new file mode 100755 index 0000000..041101b --- /dev/null +++ b/IPython/zmq/tests/test_streamsession.py @@ -0,0 +1,82 @@ + +import os +import uuid +import zmq + +from zmq.tests import BaseZMQTestCase + +from IPython.zmq.tests import SessionTestCase +from IPython.zmq import streamsession as ss + +class SessionTestCase(BaseZMQTestCase): + + def setUp(self): + BaseZMQTestCase.setUp(self) + self.session = ss.StreamSession() + +class TestSession(SessionTestCase): + + def test_msg(self): + """message format""" + msg = 
self.session.msg('execute') + thekeys = set('header msg_id parent_header msg_type content'.split()) + s = set(msg.keys()) + self.assertEquals(s, thekeys) + self.assertTrue(isinstance(msg['content'],dict)) + self.assertTrue(isinstance(msg['header'],dict)) + self.assertTrue(isinstance(msg['parent_header'],dict)) + self.assertEquals(msg['msg_type'], 'execute') + + + + def test_args(self): + """initialization arguments for StreamSession""" + s = ss.StreamSession() + self.assertTrue(s.pack is ss.default_packer) + self.assertTrue(s.unpack is ss.default_unpacker) + self.assertEquals(s.username, os.environ.get('USER', 'username')) + + s = ss.StreamSession(username=None) + self.assertEquals(s.username, os.environ.get('USER', 'username')) + + self.assertRaises(TypeError, ss.StreamSession, packer='hi') + self.assertRaises(TypeError, ss.StreamSession, unpacker='hi') + u = str(uuid.uuid4()) + s = ss.StreamSession(username='carrot', session=u) + self.assertEquals(s.session, u) + self.assertEquals(s.username, 'carrot') + + + def test_rekey(self): + """rekeying dict around json str keys""" + d = {'0': uuid.uuid4(), 0:uuid.uuid4()} + self.assertRaises(KeyError, ss.rekey, d) + + d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()} + d2 = {0:d['0'],1:d[1],'asdf':d['asdf']} + rd = ss.rekey(d) + self.assertEquals(d2,rd) + + d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()} + d2 = {1.5:d['1.5'],1:d['1']} + rd = ss.rekey(d) + self.assertEquals(d2,rd) + + d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()} + self.assertRaises(KeyError, ss.rekey, d) + + def test_unique_msg_ids(self): + """test that messages receive unique ids""" + ids = set() + for i in range(2**12): + h = self.session.msg_header('test') + msg_id = h['msg_id'] + self.assertTrue(msg_id not in ids) + ids.add(msg_id) + + def test_feed_identities(self): + """scrub the front for zmq IDENTITIES""" + theids = "engine client other".split() + content = dict(code='whoda',stuff=object()) + themsg = self.session.msg('execute',content=content) + 
pmsg = theids diff --git a/docs/source/development/index.txt b/docs/source/development/index.txt index 42a78fe..328c673 100644 --- a/docs/source/development/index.txt +++ b/docs/source/development/index.txt @@ -23,3 +23,4 @@ IPython developer's guide ipgraph.txt ipython_qt.txt ipythonzmq.txt + parallelzmq.txt diff --git a/examples/zmqontroller/config.py b/examples/zmqontroller/config.py new file mode 100644 index 0000000..f609476 --- /dev/null +++ b/examples/zmqontroller/config.py @@ -0,0 +1,23 @@ +"""setup the ports""" +config = { + 'interface': 'tcp://127.0.0.1', + 'regport': 10101, + 'heartport': 10102, + + 'cqueueport': 10211, + 'equeueport': 10111, + + 'ctaskport': 10221, + 'etaskport': 10121, + + 'ccontrolport': 10231, + 'econtrolport': 10131, + + 'clientport': 10201, + 'notifierport': 10202, + + 'logport': 20201 +} + + + diff --git a/examples/zmqontroller/controller.py b/examples/zmqontroller/controller.py new file mode 100644 index 0000000..05064f7 --- /dev/null +++ b/examples/zmqontroller/controller.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python +"""A script to launch a controller with all its queues and connect it to a logger""" + +import time +import logging + +import zmq +from zmq.devices import ProcessMonitoredQueue, ThreadMonitoredQueue +from zmq.eventloop import ioloop +from zmq.eventloop.zmqstream import ZMQStream +from zmq.log import handlers + +from IPython.zmq import log +from IPython.zmq.parallel import controller, heartmonitor, streamsession as session + + + + +def setup(): + """setup a basic controller and open client,registrar, and logging ports. 
def setup():
    """Build a controller with its queues, heartbeat, and log publisher.

    Reads port numbers from config.py (execfile'd into globals), binds all
    engine- and client-facing sockets, launches the three monitored queue
    devices in subprocesses, starts the heartbeat monitor, and wires
    everything into a Controller.  Returns the IOLoop to be started by the
    caller.
    """
    ctx = zmq.Context(1)
    loop = ioloop.IOLoop.instance()

    # port config -- execfile injects a `config` dict into module globals
    execfile('config.py', globals())
    iface = config['interface']
    logport = config['logport']
    rport = config['regport']
    cport = config['clientport']
    cqport = config['cqueueport']
    eqport = config['equeueport']
    ctport = config['ctaskport']
    etport = config['etaskport']
    ccport = config['ccontrolport']
    ecport = config['econtrolport']
    hport = config['heartport']
    nport = config['notifierport']

    # setup logging: publish controller log records over zmq so a separate
    # logwatcher process can display them
    lsock = ctx.socket(zmq.PUB)
    lsock.connect('%s:%i'%(iface,logport))
    handler = handlers.PUBHandler(lsock)
    handler.setLevel(logging.DEBUG)
    handler.root_topic = "controller"
    log.logger.addHandler(handler)
    # give the PUB/SUB pair a moment to connect before messages flow
    time.sleep(.5)

    ### Engine connections ###

    # Engine registrar socket: engines announce themselves here
    reg = ZMQStream(ctx.socket(zmq.XREP), loop)
    reg.bind("%s:%i"%(iface, rport))

    # heartbeat: PUB pings out on hport, XREP collects replies on hport+1
    hpub = ctx.socket(zmq.PUB)
    hpub.bind("%s:%i"%(iface, hport))
    hrep = ctx.socket(zmq.XREP)
    hrep.bind("%s:%i"%(iface, hport+1))

    # 2500 is the heartbeat period -- presumably milliseconds; TODO confirm
    # against HeartMonitor's signature
    hb = heartmonitor.HeartMonitor(loop, ZMQStream(hpub,loop), ZMQStream(hrep,loop),2500)
    hb.start()

    ### Client connections ###
    # Clientele socket: direct client requests to the controller
    c = ZMQStream(ctx.socket(zmq.XREP), loop)
    c.bind("%s:%i"%(iface, cport))

    # notification publisher: broadcasts engine (un)registration events
    n = ZMQStream(ctx.socket(zmq.PUB), loop)
    n.bind("%s:%i"%(iface, nport))

    thesession = session.StreamSession(username="controller")

    # build and launch the queue monitor: all three queue devices publish
    # their traffic to this SUB socket (bound on a random port)
    sub = ctx.socket(zmq.SUB)
    sub.setsockopt(zmq.SUBSCRIBE, "")
    monport = sub.bind_to_random_port(iface)
    sub = ZMQStream(sub, loop)

    # Multiplexer Queue (in a Process): client requests -> engines by identity
    q = ProcessMonitoredQueue(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
    q.bind_in("%s:%i"%(iface, cqport))
    q.bind_out("%s:%i"%(iface, eqport))
    q.connect_mon("%s:%i"%(iface, monport))
    q.daemon=True
    q.start()

    # Control Queue (in a Process): same shape, separate channel for control
    q = ProcessMonitoredQueue(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
    q.bind_in("%s:%i"%(iface, ccport))
    q.bind_out("%s:%i"%(iface, ecport))
    q.connect_mon("%s:%i"%(iface, monport))
    q.daemon=True
    q.start()

    # Task Queue (in a Process): XREQ on the engine side load-balances tasks
    q = ProcessMonitoredQueue(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
    q.bind_in("%s:%i"%(iface, ctport))
    q.bind_out("%s:%i"%(iface, etport))
    q.connect_mon("%s:%i"%(iface, monport))
    q.daemon=True
    q.start()

    # let the queue subprocesses come up before the controller starts
    time.sleep(.25)

    # build connection dicts handed to engines and clients at registration
    engine_addrs = {
        'control' : "%s:%i"%(iface, ecport),
        'queue': "%s:%i"%(iface, eqport),
        'heartbeat': ("%s:%i"%(iface, hport), "%s:%i"%(iface, hport+1)),
        'task' : "%s:%i"%(iface, etport),
        'monitor' : "%s:%i"%(iface, monport),
        }

    client_addrs = {
        'control' : "%s:%i"%(iface, ccport),
        'controller': "%s:%i"%(iface, cport),
        'queue': "%s:%i"%(iface, cqport),
        'task' : "%s:%i"%(iface, ctport),
        'notification': "%s:%i"%(iface, nport)
        }
    # keep a reference so the Controller isn't garbage collected; it hooks
    # itself into the loop via the streams passed here
    con = controller.Controller(loop, thesession, sub, reg, hb, c, n, None, engine_addrs, client_addrs)

    return loop


if __name__ == '__main__':
    loop = setup()
    loop.start()
+ print "#########################" + print count, time.time()-counter['tic'] + +def unpack_and_print(msg): + global msg_counter + msg_counter += 1 + print msg + try: + msg = thesession.unpack_message(msg[-3:]) + except Exception, e: + print e + # pass + print msg + + +ctx = zmq.Context() + +loop = ioloop.IOLoop() +sock = ctx.socket(zmq.XREQ) +queue = ZMQStream(ctx.socket(zmq.XREQ), loop) +client = ZMQStream(sock, loop) +client.on_send(poit) +def check_engines(msg): + # client.on_recv(unpack_and_print) + queue.on_recv(count) + idents = msg[:-3] + msg = thesession.unpack_message(msg[-3:]) + msg = Message(msg) + print msg + queue.connect(str(msg.content.queue)) + engines = dict(msg.content.engines) + # global tic + N=max_messages + if engines: + tic = time.time() + counter['tic']= tic + for i in xrange(N/len(engines)): + for eid,key in engines.iteritems(): + thesession.send(queue, "execute_request", dict(code='id=%i'%(int(eid)+i)),ident=str(key)) + toc = time.time() + print "#####################################" + print N, toc-tic + print "#####################################" + + + + +client.on_recv(check_engines) + +sock.connect('tcp://127.0.0.1:10102') +sock.setsockopt(zmq.IDENTITY, thesession.username) +# stream = ZMQStream() +# header = dict(msg_id = uuid.uuid4().bytes, msg_type='relay', id=0) +parent = dict(targets=2) +# content = "GARBAGE" +thesession.send(client, "connection_request") + +# send_message(client, (header, content)) +# print thesession.recv(client, 0) + +loop.start() diff --git a/examples/zmqontroller/logwatcher.py b/examples/zmqontroller/logwatcher.py new file mode 100644 index 0000000..a65bf80 --- /dev/null +++ b/examples/zmqontroller/logwatcher.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python +"""A simple log process that prints messages incoming from""" + +# +# Copyright (c) 2010 Min Ragan-Kelley +# +# This file is part of pyzmq. 
+# +# pyzmq is free software; you can redistribute it and/or modify it under +# the terms of the Lesser GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# pyzmq is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Lesser GNU General Public License for more details. +# +# You should have received a copy of the Lesser GNU General Public License +# along with this program. If not, see . + +import zmq +logport = 20201 +def main(topics, addrs): + + context = zmq.Context() + socket = context.socket(zmq.SUB) + for topic in topics: + socket.setsockopt(zmq.SUBSCRIBE, topic) + if addrs: + for addr in addrs: + print "Connecting to: ", addr + socket.connect(addr) + else: + socket.bind('tcp://127.0.0.1:%i'%logport) + + while True: + # topic = socket.recv() + # print topic + topic, msg = socket.recv_multipart() + # msg = socket.recv_pyobj() + print "%s | %s " % (topic, msg), + +if __name__ == '__main__': + import sys + topics = [] + addrs = [] + for arg in sys.argv[1:]: + if '://' in arg: + addrs.append(arg) + else: + topics.append(arg) + if not topics: + # default to everything + topics = [''] + if len(addrs) < 1: + print "binding instead of connecting" + # addrs = ['tcp://127.0.0.1:%i'%p for p in range(logport,logport+10)] + # print "usage: display.py
[
...]" + # raise SystemExit + + main(topics, addrs)