From 6f48516f6a0054fb1f58dfa4bf424af7d3643e8b 2011-04-08 00:38:14 From: MinRK Date: 2011-04-08 00:38:14 Subject: [PATCH] improved client.get_results() behavior --- diff --git a/IPython/zmq/parallel/asyncresult.py b/IPython/zmq/parallel/asyncresult.py index 6328c8a..999e8fc 100644 --- a/IPython/zmq/parallel/asyncresult.py +++ b/IPython/zmq/parallel/asyncresult.py @@ -40,7 +40,7 @@ class AsyncResult(object): Override me in subclasses for turning a list of results into the expected form. """ - if len(res) == 1: + if len(self.msg_ids) == 1: return res[0] else: return res diff --git a/IPython/zmq/parallel/client.py b/IPython/zmq/parallel/client.py index 0762f01..cfd857b 100644 --- a/IPython/zmq/parallel/client.py +++ b/IPython/zmq/parallel/client.py @@ -14,6 +14,7 @@ import os import time from getpass import getpass from pprint import pprint +from datetime import datetime import zmq from zmq.eventloop import ioloop, zmqstream @@ -29,6 +30,7 @@ import error import map as Map from asyncresult import AsyncResult, AsyncMapResult from remotefunction import remote,parallel,ParallelFunction,RemoteFunction +from util import ReverseDict #-------------------------------------------------------------------------- # helpers for implementing old MEC API via client.apply @@ -83,6 +85,11 @@ def defaultblock(f, self, *args, **kwargs): self.block = saveblock return ret + +#-------------------------------------------------------------------------- +# Classes +#-------------------------------------------------------------------------- + class AbortedTask(object): """A basic wrapper object describing an aborted task.""" def __init__(self, msg_id): @@ -233,10 +240,11 @@ class Client(object): tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs) else: self._registration_socket.connect(addr) - self._engines = {} + self._engines = ReverseDict() self._ids = set() self.outstanding=set() self.results = {} + self.metadata = {} self.history = [] self.debug = debug
self.session.debug = debug @@ -342,9 +350,27 @@ class Client(object): if eid in self._ids: self._ids.remove(eid) self._engines.pop(eid) - + # + def _build_metadata(self, header, parent, content): + md = {'msg_id' : parent['msg_id'], + 'submitted' : datetime.strptime(parent['date'], ss.ISO8601), + 'started' : datetime.strptime(header['started'], ss.ISO8601), + 'completed' : datetime.strptime(header['date'], ss.ISO8601), + 'received' : datetime.now(), + 'engine_uuid' : header['engine'], + 'engine_id' : self._engines.get(header['engine'], None), + 'follow' : parent['follow'], + 'after' : parent['after'], + 'status' : content['status'] + } + return md + def _handle_execute_reply(self, msg): - """Save the reply to an execute_request into our results.""" + """Save the reply to an execute_request into our results. + + execute messages are never actually used. apply is used instead. + """ + parent = msg['parent_header'] msg_id = parent['msg_id'] if msg_id not in self.outstanding: @@ -362,8 +388,12 @@ class Client(object): else: self.outstanding.remove(msg_id) content = msg['content'] + header = msg['header'] + + self.metadata[msg_id] = self._build_metadata(header, parent, content) + if content['status'] == 'ok': - self.results[msg_id] = ss.unserialize_object(msg['buffers']) + self.results[msg_id] = ss.unserialize_object(msg['buffers'])[0] elif content['status'] == 'aborted': self.results[msg_id] = error.AbortedTask(msg_id) elif content['status'] == 'resubmitted': @@ -372,10 +402,8 @@ class Client(object): else: e = ss.unwrap_exception(content) e_uuid = e.engine_info['engineid'] - for k,v in self._engines.iteritems(): - if v == e_uuid: - e.engine_info['engineid'] = k - break + eid = self._engines[e_uuid] + e.engine_info['engineid'] = eid self.results[msg_id] = e def _flush_notifications(self): @@ -882,6 +910,13 @@ class Client(object): status_only : bool (default: False) if False: return the actual results + + Returns + ------- + + results : dict + There will always be the
keys 'pending' and 'completed', which will + be lists of msg_ids. """ if not isinstance(msg_ids, (list,tuple)): msg_ids = [msg_ids] @@ -895,11 +930,12 @@ class Client(object): completed = [] local_results = {} - for msg_id in list(theids): - if msg_id in self.results: - completed.append(msg_id) - local_results[msg_id] = self.results[msg_id] - theids.remove(msg_id) + # temporarily disable local shortcut + # for msg_id in list(theids): + # if msg_id in self.results: + # completed.append(msg_id) + # local_results[msg_id] = self.results[msg_id] + # theids.remove(msg_id) if theids: # some not locally cached content = dict(msg_ids=theids, status_only=status_only) @@ -911,16 +947,40 @@ class Client(object): content = msg['content'] if content['status'] != 'ok': raise ss.unwrap_exception(content) + buffers = msg['buffers'] else: content = dict(completed=[],pending=[]) - if not status_only: - # load cached results into result: - content['completed'].extend(completed) - content.update(local_results) - # update cache with results: - for msg_id in msg_ids: - if msg_id in content['completed']: - self.results[msg_id] = content[msg_id] + + content['completed'].extend(completed) + + if status_only: + return content + + failures = [] + # load cached results into result: + content.update(local_results) + # update cache with results: + for msg_id in sorted(theids): + if msg_id in content['completed']: + rec = content[msg_id] + parent = rec['header'] + header = rec['result_header'] + rcontent = rec['result_content'] + if isinstance(rcontent, str): + rcontent = self.session.unpack(rcontent) + + self.metadata[msg_id] = self._build_metadata(header, parent, rcontent) + + if rcontent['status'] == 'ok': + res,buffers = ss.unserialize_object(buffers) + else: + res = ss.unwrap_exception(rcontent) + failures.append(res) + + self.results[msg_id] = res + content[msg_id] = res + + error.collect_exceptions(failures, "get_results") return content @spinfirst @@ -945,7 +1005,7 @@ class
Client(object): status = content.pop('status') if status != 'ok': raise ss.unwrap_exception(content) - return content + return ss.rekey(content) @spinfirst def purge_results(self, msg_ids=[], targets=[]): diff --git a/IPython/zmq/parallel/controller.py b/IPython/zmq/parallel/controller.py index ab0fba3..717f348 100755 --- a/IPython/zmq/parallel/controller.py +++ b/IPython/zmq/parallel/controller.py @@ -47,33 +47,6 @@ else: def _passer(*args, **kwargs): return -class ReverseDict(dict): - """simple double-keyed subset of dict methods.""" - - def __init__(self, *args, **kwargs): - dict.__init__(self, *args, **kwargs) - self.reverse = dict() - for key, value in self.iteritems(): - self.reverse[value] = key - - def __getitem__(self, key): - try: - return dict.__getitem__(self, key) - except KeyError: - return self.reverse[key] - - def __setitem__(self, key, value): - if key in self.reverse: - raise KeyError("Can't have key %r on both sides!"%key) - dict.__setitem__(self, key, value) - self.reverse[value] = key - - def pop(self, key): - value = dict.pop(self, key) - self.d1.pop(value) - return value - - def init_record(msg): """return an empty TaskRecord dict, with all keys initialized with None.""" header = msg['header'] @@ -484,6 +457,8 @@ class Controller(object): } if MongoDB is not None and isinstance(self.db, MongoDB): result['result_buffers'] = map(Binary, msg['buffers']) + else: + result['result_buffers'] = msg['buffers'] self.db.update_record(msg_id, result) else: logger.debug("queue:: unknown msg finished %s"%msg_id) @@ -552,6 +527,8 @@ class Controller(object): } if MongoDB is not None and isinstance(self.db, MongoDB): result['result_buffers'] = map(Binary, msg['buffers']) + else: + result['result_buffers'] = msg['buffers'] self.db.update_record(msg_id, result) else: @@ -831,14 +808,16 @@ class Controller(object): def get_results(self, client_id, msg): """Get the result of 1 or more messages.""" content = msg['content'] - msg_ids = set(content['msg_ids']) +
msg_ids = sorted(set(content['msg_ids'])) statusonly = content.get('status_only', False) pending = [] completed = [] content = dict(status='ok') content['pending'] = pending content['completed'] = completed + buffers = [] if not statusonly: + content['results'] = {} records = self.db.find_records(dict(msg_id={'$in':msg_ids})) for msg_id in msg_ids: if msg_id in self.pending: @@ -846,7 +825,12 @@ class Controller(object): elif msg_id in self.all_completed: completed.append(msg_id) if not statusonly: - content[msg_id] = records[msg_id]['result_content'] + rec = records[msg_id] + content[msg_id] = { 'result_content': rec['result_content'], + 'header': rec['header'], + 'result_header' : rec['result_header'], + } + buffers.extend(map(str, rec['result_buffers'])) else: try: raise KeyError('No such message: '+msg_id) @@ -854,7 +838,8 @@ class Controller(object): content = wrap_exception() break self.session.send(self.clientele, "result_reply", content=content, - parent=msg, ident=client_id) + parent=msg, ident=client_id, + buffers=buffers) #------------------------------------------------------------------------- diff --git a/IPython/zmq/parallel/mongodb.py b/IPython/zmq/parallel/mongodb.py index 3361881..61e7836 100644 --- a/IPython/zmq/parallel/mongodb.py +++ b/IPython/zmq/parallel/mongodb.py @@ -35,7 +35,7 @@ class MongoDB(object): def update_record(self, msg_id, rec): """Update the data in an existing record.""" obj_id = self._table[msg_id] - self._records.update({'_id':obj_id}, rec) + self._records.update({'_id':obj_id}, {'$set': rec}) def drop_matching_records(self, check): """Remove a record from the DB.""" @@ -50,7 +50,11 @@ class MongoDB(object): """Find records matching a query dict.""" matches = list(self._records.find(check)) if id_only: - matches = [ rec['msg_id'] for rec in matches ] - return matches + return [ rec['msg_id'] for rec in matches ] + else: + data = {} + for rec in matches: + data[rec['msg_id']] = rec + return data diff --git
a/IPython/zmq/parallel/remotefunction.py b/IPython/zmq/parallel/remotefunction.py index 027acbc..4035f1a 100644 --- a/IPython/zmq/parallel/remotefunction.py +++ b/IPython/zmq/parallel/remotefunction.py @@ -126,10 +126,10 @@ class ParallelFunction(RemoteFunction): f=self.func mid = self.client.apply(f, args=args, block=False, bound=self.bound, - targets=engineid)._msg_ids[0] + targets=engineid).msg_ids[0] msg_ids.append(mid) - r = AsyncMapResult(self.client, msg_ids, self.mapObject) + r = AsyncMapResult(self.client, msg_ids, self.mapObject, fname=self.func.__name__) if self.block: r.wait() return r.result diff --git a/IPython/zmq/parallel/streamsession.py b/IPython/zmq/parallel/streamsession.py index af3e0c1..58aba18 100644 --- a/IPython/zmq/parallel/streamsession.py +++ b/IPython/zmq/parallel/streamsession.py @@ -208,7 +208,7 @@ def unserialize_object(bufs): for s in sobj: if s.data is None: s.data = bufs.pop(0) - return uncanSequence(map(unserialize, sobj)) + return uncanSequence(map(unserialize, sobj)), bufs elif isinstance(sobj, dict): newobj = {} for k in sorted(sobj.iterkeys()): @@ -216,11 +216,11 @@ def unserialize_object(bufs): if s.data is None: s.data = bufs.pop(0) newobj[k] = uncan(unserialize(s)) - return newobj + return newobj, bufs else: if sobj.data is None: sobj.data = bufs.pop(0) - return uncan(unserialize(sobj)) + return uncan(unserialize(sobj)), bufs def pack_apply_message(f, args, kwargs, threshold=64e-6): """pack up a function, args, and kwargs to be sent over the wire diff --git a/IPython/zmq/parallel/util.py b/IPython/zmq/parallel/util.py new file mode 100644 index 0000000..67dfe26 --- /dev/null +++ b/IPython/zmq/parallel/util.py @@ -0,0 +1,35 @@ +"""some generic utilities""" + +class ReverseDict(dict): + """simple double-keyed subset of dict methods.""" + + def __init__(self, *args, **kwargs): + dict.__init__(self, *args, **kwargs) + self._reverse = dict() + for key, value in self.iteritems(): + self._reverse[value] = key + + def
__getitem__(self, key): + try: + return dict.__getitem__(self, key) + except KeyError: + return self._reverse[key] + + def __setitem__(self, key, value): + if key in self._reverse: + raise KeyError("Can't have key %r on both sides!"%key) + dict.__setitem__(self, key, value) + self._reverse[value] = key + + def pop(self, key): + value = dict.pop(self, key) + self._reverse.pop(value) + return value + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + diff --git a/IPython/zmq/parallel/view.py b/IPython/zmq/parallel/view.py index 22cd46e..541faaa 100644 --- a/IPython/zmq/parallel/view.py +++ b/IPython/zmq/parallel/view.py @@ -183,6 +183,8 @@ class View(object): """Parallel version of builtin `map`, using this view's engines.""" if isinstance(self.targets, int): targets = [self.targets] + else: + targets = self.targets pf = ParallelFunction(self.client, f, block=self.block, bound=True, targets=targets) return pf.map(*sequences)