client.py
639 lines
| 21.9 KiB
| text/x-python
|
PythonLexer
|
r3539 | #!/usr/bin/env python | |
"""A semi-synchronous Client for the ZMQ controller""" | |||
import time | |||
|
r3540 | from pprint import pprint | |
|
r3539 | from IPython.external.decorator import decorator | |
import streamsession as ss | |||
import zmq | |||
|
r3548 | from zmq.eventloop import ioloop, zmqstream | |
|
r3539 | from remotenamespace import RemoteNamespace | |
from view import DirectView | |||
|
r3548 | from dependency import Dependency, depend, require | |
|
r3539 | ||
def _push(ns): | |||
globals().update(ns) | |||
def _pull(keys): | |||
g = globals() | |||
if isinstance(keys, (list,tuple)): | |||
return map(g.get, keys) | |||
else: | |||
return g.get(keys) | |||
def _clear(): | |||
globals().clear() | |||
def execute(code): | |||
exec code in globals() | |||
# decorators for methods: | |||
@decorator | |||
def spinfirst(f,self,*args,**kwargs): | |||
self.spin() | |||
return f(self, *args, **kwargs) | |||
@decorator | |||
def defaultblock(f, self, *args, **kwargs): | |||
block = kwargs.get('block',None) | |||
block = self.block if block is None else block | |||
saveblock = self.block | |||
self.block = block | |||
ret = f(self, *args, **kwargs) | |||
self.block = saveblock | |||
return ret | |||
|
r3540 | class AbortedTask(object): | |
def __init__(self, msg_id): | |||
self.msg_id = msg_id | |||
|
r3539 | # @decorator | |
# def checktargets(f): | |||
# @wraps(f) | |||
# def checked_method(self, *args, **kwargs): | |||
# self._build_targets(kwargs['targets']) | |||
# return f(self, *args, **kwargs) | |||
# return checked_method | |||
# class _ZMQEventLoopThread(threading.Thread): | |||
# | |||
# def __init__(self, loop): | |||
# self.loop = loop | |||
# threading.Thread.__init__(self) | |||
# | |||
# def run(self): | |||
# self.loop.start() | |||
# | |||
class Client(object): | |||
"""A semi-synchronous client to the IPython ZMQ controller | |||
Attributes | |||
---------- | |||
ids : set | |||
a set of engine IDs | |||
requesting the ids attribute always synchronizes | |||
the registration state. To request ids without synchronization, | |||
use _ids | |||
history : list of msg_ids | |||
a list of msg_ids, keeping track of all the execution | |||
messages you have submitted | |||
outstanding : set of msg_ids | |||
a set of msg_ids that have been submitted, but whose | |||
results have not been received | |||
results : dict | |||
a dict of all our results, keyed by msg_id | |||
block : bool | |||
determines default behavior when block not specified | |||
in execution methods | |||
Methods | |||
------- | |||
spin : flushes incoming results and registration state changes | |||
control methods spin, and requesting `ids` also ensures up to date | |||
barrier : wait on one or more msg_ids | |||
execution methods: apply/apply_bound/apply_to | |||
legacy: execute, run | |||
|
r3540 | query methods: queue_status, get_result | |
control methods: abort, kill | |||
|
r3539 | ||
""" | |||
_connected=False | |||
_engines=None | |||
registration_socket=None | |||
|
r3540 | query_socket=None | |
control_socket=None | |||
|
r3539 | notification_socket=None | |
queue_socket=None | |||
task_socket=None | |||
block = False | |||
outstanding=None | |||
results = None | |||
history = None | |||
|
r3540 | debug = False | |
|
r3539 | ||
|
r3540 | def __init__(self, addr, context=None, username=None, debug=False): | |
|
r3539 | if context is None: | |
context = zmq.Context() | |||
self.context = context | |||
self.addr = addr | |||
if username is None: | |||
self.session = ss.StreamSession() | |||
else: | |||
self.session = ss.StreamSession(username) | |||
self.registration_socket = self.context.socket(zmq.PAIR) | |||
self.registration_socket.setsockopt(zmq.IDENTITY, self.session.session) | |||
self.registration_socket.connect(addr) | |||
self._engines = {} | |||
self._ids = set() | |||
self.outstanding=set() | |||
self.results = {} | |||
self.history = [] | |||
|
r3540 | self.debug = debug | |
self.session.debug = debug | |||
|
r3539 | ||
self._notification_handlers = {'registration_notification' : self._register_engine, | |||
'unregistration_notification' : self._unregister_engine, | |||
} | |||
self._queue_handlers = {'execute_reply' : self._handle_execute_reply, | |||
'apply_reply' : self._handle_apply_reply} | |||
|
r3548 | self._connect() | |
|
r3539 | ||
@property | |||
def ids(self): | |||
self._flush_notifications() | |||
return self._ids | |||
def _update_engines(self, engines): | |||
for k,v in engines.iteritems(): | |||
eid = int(k) | |||
|
r3540 | self._engines[eid] = bytes(v) # force not unicode | |
|
r3539 | self._ids.add(eid) | |
def _build_targets(self, targets): | |||
if targets is None: | |||
targets = self._ids | |||
elif isinstance(targets, str): | |||
if targets.lower() == 'all': | |||
targets = self._ids | |||
else: | |||
raise TypeError("%r not valid str target, must be 'all'"%(targets)) | |||
elif isinstance(targets, int): | |||
targets = [targets] | |||
return [self._engines[t] for t in targets], list(targets) | |||
def _connect(self): | |||
"""setup all our socket connections to the controller""" | |||
if self._connected: | |||
return | |||
self._connected=True | |||
self.session.send(self.registration_socket, 'connection_request') | |||
|
r3540 | idents,msg = self.session.recv(self.registration_socket,mode=0) | |
if self.debug: | |||
pprint(msg) | |||
|
r3539 | msg = ss.Message(msg) | |
content = msg.content | |||
if content.status == 'ok': | |||
if content.queue: | |||
self.queue_socket = self.context.socket(zmq.PAIR) | |||
self.queue_socket.setsockopt(zmq.IDENTITY, self.session.session) | |||
self.queue_socket.connect(content.queue) | |||
if content.task: | |||
self.task_socket = self.context.socket(zmq.PAIR) | |||
self.task_socket.setsockopt(zmq.IDENTITY, self.session.session) | |||
self.task_socket.connect(content.task) | |||
if content.notification: | |||
self.notification_socket = self.context.socket(zmq.SUB) | |||
self.notification_socket.connect(content.notification) | |||
self.notification_socket.setsockopt(zmq.SUBSCRIBE, "") | |||
|
r3540 | if content.query: | |
self.query_socket = self.context.socket(zmq.PAIR) | |||
self.query_socket.setsockopt(zmq.IDENTITY, self.session.session) | |||
self.query_socket.connect(content.query) | |||
if content.control: | |||
self.control_socket = self.context.socket(zmq.PAIR) | |||
self.control_socket.setsockopt(zmq.IDENTITY, self.session.session) | |||
self.control_socket.connect(content.control) | |||
|
r3539 | self._update_engines(dict(content.engines)) | |
else: | |||
self._connected = False | |||
raise Exception("Failed to connect!") | |||
#### handlers and callbacks for incoming messages ####### | |||
def _register_engine(self, msg): | |||
content = msg['content'] | |||
eid = content['id'] | |||
d = {eid : content['queue']} | |||
self._update_engines(d) | |||
self._ids.add(int(eid)) | |||
def _unregister_engine(self, msg): | |||
# print 'unregister',msg | |||
content = msg['content'] | |||
eid = int(content['id']) | |||
if eid in self._ids: | |||
self._ids.remove(eid) | |||
self._engines.pop(eid) | |||
def _handle_execute_reply(self, msg): | |||
# msg_id = msg['msg_id'] | |||
parent = msg['parent_header'] | |||
msg_id = parent['msg_id'] | |||
if msg_id not in self.outstanding: | |||
print "got unknown result: %s"%msg_id | |||
else: | |||
self.outstanding.remove(msg_id) | |||
self.results[msg_id] = ss.unwrap_exception(msg['content']) | |||
def _handle_apply_reply(self, msg): | |||
|
r3540 | # pprint(msg) | |
|
r3539 | # msg_id = msg['msg_id'] | |
parent = msg['parent_header'] | |||
msg_id = parent['msg_id'] | |||
if msg_id not in self.outstanding: | |||
print "got unknown result: %s"%msg_id | |||
else: | |||
self.outstanding.remove(msg_id) | |||
content = msg['content'] | |||
if content['status'] == 'ok': | |||
self.results[msg_id] = ss.unserialize_object(msg['buffers']) | |||
|
r3540 | elif content['status'] == 'aborted': | |
self.results[msg_id] = AbortedTask(msg_id) | |||
elif content['status'] == 'resubmitted': | |||
pass # handle resubmission | |||
|
r3539 | else: | |
self.results[msg_id] = ss.unwrap_exception(content) | |||
def _flush_notifications(self): | |||
"flush incoming notifications of engine registrations" | |||
msg = self.session.recv(self.notification_socket, mode=zmq.NOBLOCK) | |||
while msg is not None: | |||
|
r3540 | if self.debug: | |
pprint(msg) | |||
|
r3539 | msg = msg[-1] | |
msg_type = msg['msg_type'] | |||
handler = self._notification_handlers.get(msg_type, None) | |||
if handler is None: | |||
raise Exception("Unhandled message type: %s"%msg.msg_type) | |||
else: | |||
handler(msg) | |||
msg = self.session.recv(self.notification_socket, mode=zmq.NOBLOCK) | |||
def _flush_results(self, sock): | |||
"flush incoming task or queue results" | |||
msg = self.session.recv(sock, mode=zmq.NOBLOCK) | |||
while msg is not None: | |||
|
r3540 | if self.debug: | |
pprint(msg) | |||
|
r3539 | msg = msg[-1] | |
msg_type = msg['msg_type'] | |||
handler = self._queue_handlers.get(msg_type, None) | |||
if handler is None: | |||
raise Exception("Unhandled message type: %s"%msg.msg_type) | |||
else: | |||
handler(msg) | |||
msg = self.session.recv(sock, mode=zmq.NOBLOCK) | |||
|
r3540 | def _flush_control(self, sock): | |
"flush incoming control replies" | |||
msg = self.session.recv(sock, mode=zmq.NOBLOCK) | |||
while msg is not None: | |||
if self.debug: | |||
pprint(msg) | |||
msg = self.session.recv(sock, mode=zmq.NOBLOCK) | |||
|
r3539 | ###### get/setitem ######## | |
def __getitem__(self, key): | |||
if isinstance(key, int): | |||
if key not in self.ids: | |||
raise IndexError("No such engine: %i"%key) | |||
return DirectView(self, key) | |||
if isinstance(key, slice): | |||
indices = range(len(self.ids))[key] | |||
ids = sorted(self._ids) | |||
key = [ ids[i] for i in indices ] | |||
# newkeys = sorted(self._ids)[thekeys[k]] | |||
if isinstance(key, (tuple, list, xrange)): | |||
_,targets = self._build_targets(list(key)) | |||
return DirectView(self, targets) | |||
else: | |||
raise TypeError("key by int/iterable of ints only, not %s"%(type(key))) | |||
############ begin real methods ############# | |||
def spin(self): | |||
"""flush incoming notifications and execution results.""" | |||
if self.notification_socket: | |||
self._flush_notifications() | |||
if self.queue_socket: | |||
self._flush_results(self.queue_socket) | |||
if self.task_socket: | |||
self._flush_results(self.task_socket) | |||
|
r3540 | if self.control_socket: | |
self._flush_control(self.control_socket) | |||
|
r3539 | ||
@spinfirst | |||
def queue_status(self, targets=None, verbose=False): | |||
"""fetch the status of engine queues | |||
Parameters | |||
---------- | |||
targets : int/str/list of ints/strs | |||
the engines on which to execute | |||
default : all | |||
verbose : bool | |||
|
r3540 | whether to return lengths only, or lists of ids for each element | |
|
r3539 | ||
""" | |||
targets = self._build_targets(targets)[1] | |||
content = dict(targets=targets) | |||
|
r3540 | self.session.send(self.query_socket, "queue_request", content=content) | |
idents,msg = self.session.recv(self.query_socket, 0) | |||
if self.debug: | |||
pprint(msg) | |||
|
r3539 | return msg['content'] | |
@spinfirst | |||
|
r3540 | @defaultblock | |
def clear(self, targets=None, block=None): | |||
|
r3539 | """clear the namespace in target(s)""" | |
|
r3540 | targets = self._build_targets(targets)[0] | |
print targets | |||
for t in targets: | |||
self.session.send(self.control_socket, 'clear_request', content={},ident=t) | |||
error = False | |||
if self.block: | |||
for i in range(len(targets)): | |||
idents,msg = self.session.recv(self.control_socket,0) | |||
if self.debug: | |||
pprint(msg) | |||
if msg['content']['status'] != 'ok': | |||
error = msg['content'] | |||
if error: | |||
return error | |||
|
r3539 | ||
@spinfirst | |||
|
r3540 | @defaultblock | |
def abort(self, msg_ids = None, targets=None, block=None): | |||
|
r3539 | """abort the Queues of target(s)""" | |
|
r3540 | targets = self._build_targets(targets)[0] | |
print targets | |||
if isinstance(msg_ids, basestring): | |||
msg_ids = [msg_ids] | |||
content = dict(msg_ids=msg_ids) | |||
for t in targets: | |||
self.session.send(self.control_socket, 'abort_request', | |||
content=content, ident=t) | |||
error = False | |||
if self.block: | |||
for i in range(len(targets)): | |||
idents,msg = self.session.recv(self.control_socket,0) | |||
if self.debug: | |||
pprint(msg) | |||
if msg['content']['status'] != 'ok': | |||
error = msg['content'] | |||
if error: | |||
return error | |||
|
r3539 | ||
|
r3540 | @spinfirst | |
@defaultblock | |||
def kill(self, targets=None, block=None): | |||
"""Terminates one or more engine processes.""" | |||
targets = self._build_targets(targets)[0] | |||
print targets | |||
for t in targets: | |||
self.session.send(self.control_socket, 'kill_request', content={},ident=t) | |||
error = False | |||
if self.block: | |||
for i in range(len(targets)): | |||
idents,msg = self.session.recv(self.control_socket,0) | |||
if self.debug: | |||
pprint(msg) | |||
if msg['content']['status'] != 'ok': | |||
error = msg['content'] | |||
if error: | |||
return error | |||
|
r3539 | @defaultblock | |
def execute(self, code, targets='all', block=None): | |||
"""executes `code` on `targets` in blocking or nonblocking manner. | |||
Parameters | |||
---------- | |||
code : str | |||
the code string to be executed | |||
targets : int/str/list of ints/strs | |||
the engines on which to execute | |||
default : all | |||
block : bool | |||
whether or not to wait until done | |||
""" | |||
# block = self.block if block is None else block | |||
# saveblock = self.block | |||
# self.block = block | |||
result = self.apply(execute, (code,), targets=targets, block=block, bound=True) | |||
# self.block = saveblock | |||
return result | |||
def run(self, code, block=None): | |||
"""runs `code` on an engine. | |||
Calls to this are load-balanced. | |||
Parameters | |||
---------- | |||
code : str | |||
the code string to be executed | |||
block : bool | |||
whether or not to wait until done | |||
""" | |||
result = self.apply(execute, (code,), targets=None, block=block, bound=False) | |||
return result | |||
|
r3548 | def _apply_balanced(self, f, args, kwargs, bound=True, block=None, | |
after=None, follow=None): | |||
|
r3539 | """the underlying method for applying functions in a load balanced | |
manner.""" | |||
block = block if block is not None else self.block | |||
bufs = ss.pack_apply_message(f,args,kwargs) | |||
content = dict(bound=bound) | |||
msg = self.session.send(self.task_socket, "apply_request", | |||
content=content, buffers=bufs) | |||
msg_id = msg['msg_id'] | |||
self.outstanding.add(msg_id) | |||
self.history.append(msg_id) | |||
if block: | |||
self.barrier(msg_id) | |||
return self.results[msg_id] | |||
else: | |||
return msg_id | |||
|
r3548 | def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None, | |
after=None, follow=None): | |||
|
r3539 | """Then underlying method for applying functions to specific engines.""" | |
|
r3548 | ||
|
r3539 | block = block if block is not None else self.block | |
|
r3548 | ||
|
r3539 | queues,targets = self._build_targets(targets) | |
|
r3540 | print queues | |
|
r3539 | bufs = ss.pack_apply_message(f,args,kwargs) | |
|
r3548 | if isinstance(after, Dependency): | |
after = after.as_dict() | |||
elif after is None: | |||
after = [] | |||
if isinstance(follow, Dependency): | |||
follow = follow.as_dict() | |||
elif follow is None: | |||
follow = [] | |||
subheader = dict(after=after, follow=follow) | |||
|
r3539 | content = dict(bound=bound) | |
msg_ids = [] | |||
for queue in queues: | |||
msg = self.session.send(self.queue_socket, "apply_request", | |||
|
r3548 | content=content, buffers=bufs,ident=queue, subheader=subheader) | |
|
r3539 | msg_id = msg['msg_id'] | |
self.outstanding.add(msg_id) | |||
self.history.append(msg_id) | |||
msg_ids.append(msg_id) | |||
if block: | |||
self.barrier(msg_ids) | |||
else: | |||
if len(msg_ids) == 1: | |||
return msg_ids[0] | |||
else: | |||
return msg_ids | |||
if len(msg_ids) == 1: | |||
return self.results[msg_ids[0]] | |||
else: | |||
result = {} | |||
for target,mid in zip(targets, msg_ids): | |||
result[target] = self.results[mid] | |||
return result | |||
|
r3548 | def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None, | |
after=None, follow=None): | |||
|
r3539 | """calls f(*args, **kwargs) on a remote engine(s), returning the result. | |
if self.block is False: | |||
returns msg_id or list of msg_ids | |||
else: | |||
returns actual result of f(*args, **kwargs) | |||
""" | |||
|
r3548 | # enforce types of f,args,kwrags | |
|
r3539 | args = args if args is not None else [] | |
kwargs = kwargs if kwargs is not None else {} | |||
|
r3548 | if not callable(f): | |
raise TypeError("f must be callable, not %s"%type(f)) | |||
|
r3540 | if not isinstance(args, (tuple, list)): | |
raise TypeError("args must be tuple or list, not %s"%type(args)) | |||
if not isinstance(kwargs, dict): | |||
raise TypeError("kwargs must be dict, not %s"%type(kwargs)) | |||
|
r3548 | ||
options = dict(bound=bound, block=block, after=after, follow=follow) | |||
|
r3539 | if targets is None: | |
|
r3548 | return self._apply_balanced(f, args, kwargs, **options) | |
|
r3539 | else: | |
|
r3548 | return self._apply_direct(f, args, kwargs, targets=targets, **options) | |
|
r3539 | ||
def push(self, ns, targets=None, block=None): | |||
"""push the contents of `ns` into the namespace on `target`""" | |||
if not isinstance(ns, dict): | |||
raise TypeError("Must be a dict, not %s"%type(ns)) | |||
result = self.apply(_push, (ns,), targets=targets, block=block,bound=True) | |||
return result | |||
@spinfirst | |||
def pull(self, keys, targets=None, block=True): | |||
"""pull objects from `target`'s namespace by `keys`""" | |||
result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True) | |||
return result | |||
def barrier(self, msg_ids=None, timeout=-1): | |||
"""waits on one or more `msg_ids`, for up to `timeout` seconds. | |||
Parameters | |||
---------- | |||
msg_ids : int, str, or list of ints and/or strs | |||
ints are indices to self.history | |||
strs are msg_ids | |||
default: wait on all outstanding messages | |||
timeout : float | |||
a time in seconds, after which to give up. | |||
default is -1, which means no timeout | |||
Returns | |||
------- | |||
True : when all msg_ids are done | |||
False : timeout reached, msg_ids still outstanding | |||
""" | |||
tic = time.time() | |||
if msg_ids is None: | |||
theids = self.outstanding | |||
else: | |||
if isinstance(msg_ids, (int, str)): | |||
msg_ids = [msg_ids] | |||
theids = set() | |||
for msg_id in msg_ids: | |||
if isinstance(msg_id, int): | |||
msg_id = self.history[msg_id] | |||
theids.add(msg_id) | |||
self.spin() | |||
while theids.intersection(self.outstanding): | |||
if timeout >= 0 and ( time.time()-tic ) > timeout: | |||
break | |||
time.sleep(1e-3) | |||
self.spin() | |||
return len(theids.intersection(self.outstanding)) == 0 | |||
@spinfirst | |||
def get_results(self, msg_ids,status_only=False): | |||
"""returns the result of the execute or task request with `msg_id`""" | |||
if not isinstance(msg_ids, (list,tuple)): | |||
msg_ids = [msg_ids] | |||
theids = [] | |||
for msg_id in msg_ids: | |||
if isinstance(msg_id, int): | |||
msg_id = self.history[msg_id] | |||
theids.append(msg_id) | |||
content = dict(msg_ids=theids, status_only=status_only) | |||
|
r3540 | msg = self.session.send(self.query_socket, "result_request", content=content) | |
zmq.select([self.query_socket], [], []) | |||
idents,msg = self.session.recv(self.query_socket, zmq.NOBLOCK) | |||
if self.debug: | |||
pprint(msg) | |||
|
r3539 | ||
# while True: | |||
# try: | |||
# except zmq.ZMQError: | |||
# time.sleep(1e-3) | |||
# continue | |||
# else: | |||
# break | |||
return msg['content'] | |||
|
r3548 | ||
class AsynClient(Client): | |||
"""An Asynchronous client, using the Tornado Event Loop""" | |||
io_loop = None | |||
queue_stream = None | |||
notifier_stream = None | |||
def __init__(self, addr, context=None, username=None, debug=False, io_loop=None): | |||
Client.__init__(self, addr, context, username, debug) | |||
if io_loop is None: | |||
io_loop = ioloop.IOLoop.instance() | |||
self.io_loop = io_loop | |||
self.queue_stream = zmqstream.ZMQStream(self.queue_socket, io_loop) | |||
self.control_stream = zmqstream.ZMQStream(self.control_socket, io_loop) | |||
self.task_stream = zmqstream.ZMQStream(self.task_socket, io_loop) | |||
self.notification_stream = zmqstream.ZMQStream(self.notification_socket, io_loop) | |||
|
r3539 | ||
|
r3548 | def spin(self): | |
for stream in (self.queue_stream, self.notifier_stream, | |||
self.task_stream, self.control_stream): | |||
stream.flush() | |||
|
r3539 |