diff --git a/IPython/zmq/parallel/asyncresult.py b/IPython/zmq/parallel/asyncresult.py
index 999e8fc..e96b75b 100644
--- a/IPython/zmq/parallel/asyncresult.py
+++ b/IPython/zmq/parallel/asyncresult.py
@@ -10,12 +10,21 @@
 # Imports
 #-----------------------------------------------------------------------------
 
+from IPython.external.decorator import decorator
 import error
 
 #-----------------------------------------------------------------------------
 # Classes
 #-----------------------------------------------------------------------------
 
+@decorator
+def check_ready(f, self, *args, **kwargs):
+    """Call wait(0) to sync state, and raise TimeoutError if the result is not ready."""
+    self.wait(0)
+    if not self._ready:
+        raise error.TimeoutError("result not ready")
+    return f(self, *args, **kwargs)
+
 class AsyncResult(object):
     """Class for representing results of non-blocking calls.
 
@@ -79,6 +88,7 @@ class AsyncResult(object):
         if self._ready:
             try:
                 results = map(self._client.results.get, self.msg_ids)
+                self._result = results
                 results = error.collect_exceptions(results, self._fname)
                 self._result = self._reconstruct_result(results)
             except Exception, e:
@@ -86,6 +96,8 @@ class AsyncResult(object):
                 self._success = False
             else:
                 self._success = True
+            finally:
+                self._metadata = map(self._client.metadata.get, self.msg_ids)
 
 
     def successful(self):
@@ -95,7 +107,84 @@ class AsyncResult(object):
         """
         assert self._ready
         return self._success
+
+    #----------------------------------------------------------------
+    # Extra methods not in mp.pool.AsyncResult
+    #----------------------------------------------------------------
+
+    def get_dict(self, timeout=-1):
+        """Get the results as a dict, keyed by engine_id.
+
+        Raises ValueError if more than one job ran on the same engine,
+        since the results could then not be keyed unambiguously.
+        """
+        results = self.get(timeout)
+        engine_ids = [md['engine_id'] for md in self._metadata]
+        # tally jobs per engine in one pass, rather than list.count() per item
+        counts = {}
+        for eid in engine_ids:
+            counts[eid] = counts.get(eid, 0) + 1
+        # guard: with no jobs at all there is nothing to check
+        if counts:
+            busiest = max(counts, key=counts.get)
+            if counts[busiest] > 1:
+                raise ValueError("Cannot build dict, %i jobs ran on engine #%i"%(
+                    counts[busiest], busiest))
+
+        return dict(zip(engine_ids,results))
+
+    @property
+    @check_ready
+    def result(self):
+        """The stored result; raises TimeoutError if it is not ready yet."""
+        return self._result
+
+    @property
+    @check_ready
+    def metadata(self):
+        """The metadata gathered for each msg_id; raises TimeoutError if not ready."""
+        return self._metadata
+
+    @property
+    @check_ready
+    def result_dict(self):
+        """The results as a dict keyed by engine_id (see get_dict)."""
+        return self.get_dict(0)
+
+    #-------------------------------------
+    # dict-access
+    #-------------------------------------
+
+    @check_ready
+    def __getitem__(self, key):
+        """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
+        """
+        if isinstance(key, int):
+            return error.collect_exceptions([self._result[key]], self._fname)[0]
+        elif isinstance(key, slice):
+            return error.collect_exceptions(self._result[key], self._fname)
+        elif isinstance(key, basestring):
+            return [ md[key] for md in self._metadata ]
+        else:
+            raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
+
+    def __getattr__(self, key):
+        """getattr maps to getitem for convenient access to metadata.
+
+        Raises AttributeError, rather than TimeoutError, when the result is
+        not ready, so that getattr(obj, name, default) falls back cleanly.
+        """
+        # __getattr__ only runs after normal lookup has failed.  Private names
+        # are rejected up front: looking up self._ready or self._metadata here
+        # while they are still unset would re-enter __getattr__ forever.
+        if not key.startswith('_'):
+            self.wait(0)
+            if self._ready and self._metadata and key in self._metadata[0]:
+                return self.__getitem__(key)
+        raise AttributeError("%r object has no attribute %r"%(
+            self.__class__.__name__, key))
+
 
 class AsyncMapResult(AsyncResult):
     """Class for representing results of non-blocking gathers.
 
@@ -111,3 +200,4 @@ class AsyncMapResult(AsyncResult):
         return self._mapObject.joinPartitions(res)
 
 
+__all__ = ['AsyncResult', 'AsyncMapResult']
diff --git a/IPython/zmq/parallel/error.py b/IPython/zmq/parallel/error.py
index 1177978..add5aac 100644
--- a/IPython/zmq/parallel/error.py
+++ b/IPython/zmq/parallel/error.py
@@ -250,7 +250,7 @@ class CompositeError(KernelError):
             et,ev,tb = sys.exc_info()
 
 
-def collect_exceptions(rdict_or_list, method):
+def collect_exceptions(rdict_or_list, method='unspecified'):
     """check a result dict for errors, and raise CompositeError if any exist.
     Passthrough otherwise."""
     elist = []