From cd8465672be199996c6a4443a6b7e924a08f4e84 2011-04-08 00:38:13 From: MinRK Date: 2011-04-08 00:38:13 Subject: [PATCH] PendingResult->AsyncResult; match multiprocessing.AsyncResult api --- diff --git a/IPython/zmq/parallel/asyncresult.py b/IPython/zmq/parallel/asyncresult.py new file mode 100644 index 0000000..869c2c2 --- /dev/null +++ b/IPython/zmq/parallel/asyncresult.py @@ -0,0 +1,112 @@ +"""AsyncResult objects for the client""" +#----------------------------------------------------------------------------- +# Copyright (C) 2010 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#----------------------------------------------------------------------------- + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +import error + +#----------------------------------------------------------------------------- +# Classes +#----------------------------------------------------------------------------- + +class AsyncResult(object): + """Class for representing results of non-blocking calls. + + Provides the same interface as :py:class:`multiprocessing.AsyncResult`. + """ + def __init__(self, client, msg_ids): + self._client = client + self._msg_ids = msg_ids + self._ready = False + self._success = None + + def __repr__(self): + if self._ready: + return "<%s: finished>"%(self.__class__.__name__) + else: + return "<%s: %r>"%(self.__class__.__name__,self._msg_ids) + + + def _reconstruct_result(self, res): + """ + Override me in subclasses for turning a list of results + into the expected form. + """ + if len(res) == 1: + return res[0] + else: + return res + + def get(self, timeout=-1): + """Return the result when it arrives. + + If `timeout` is not ``None`` and the result does not arrive within + `timeout` seconds then ``TimeoutError`` is raised. If the + remote call raised an exception then that exception will be reraised + by get(). + """ + if not self.ready(): + self.wait(timeout) + + if self._ready: + if self._success: + return self._result + else: + raise self._exception + else: + raise error.TimeoutError("Result not ready.") + + def ready(self): + """Return whether the call has completed.""" + if not self._ready: + self.wait(0) + return self._ready + + def wait(self, timeout=-1): + """Wait until the result is available or until `timeout` seconds pass. + """ + if self._ready: + return + self._ready = self._client.barrier(self._msg_ids, timeout) + if self._ready: + try: + results = map(self._client.results.get, self._msg_ids) + results = error.collect_exceptions(results, 'get') + self._result = self._reconstruct_result(results) + except Exception, e: + self._exception = e + self._success = False + else: + self._success = True + + + def successful(self): + """Return whether the call completed without raising an exception. + + Will raise ``AssertionError`` if the result is not ready. + """ + assert self._ready + return self._success + +class AsyncMapResult(AsyncResult): + """Class for representing results of non-blocking gathers. + + This will properly reconstruct the gather. + """ + + def __init__(self, client, msg_ids, mapObject): + self._mapObject = mapObject + AsyncResult.__init__(self, client, msg_ids) + + def _reconstruct_result(self, res): + """Perform the gather on the actual results.""" + return self._mapObject.joinPartitions(res) + + diff --git a/IPython/zmq/parallel/client.py b/IPython/zmq/parallel/client.py index 41acc15..e23ac0a 100644 --- a/IPython/zmq/parallel/client.py +++ b/IPython/zmq/parallel/client.py @@ -29,7 +29,7 @@ from view import DirectView, LoadBalancedView from dependency import Dependency, depend, require import error import map as Map -from pendingresult import PendingResult,PendingMapResult +from asyncresult import AsyncResult, AsyncMapResult from remotefunction import remote,parallel,ParallelFunction,RemoteFunction #-------------------------------------------------------------------------- @@ -746,7 +746,7 @@ class Client(object): self.barrier(msg_id) return self._maybe_raise(self.results[msg_id]) else: - return PendingResult(self, [msg_id]) + return AsyncResult(self, [msg_id]) def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None, after=None, follow=None): @@ -776,7 +776,7 @@ class Client(object): if block: self.barrier(msg_ids) else: - return PendingResult(self, msg_ids) + return AsyncResult(self, msg_ids) if len(msg_ids) == 1: return self._maybe_raise(self.results[msg_ids[0]]) else: @@ -785,12 +785,24 @@ class Client(object): result[target] = self.results[mid] return error.collect_exceptions(result, f.__name__) + #-------------------------------------------------------------------------- + # Map and decorators + #-------------------------------------------------------------------------- + def map(self, f, *sequences): """Parallel version of builtin `map`, using all our engines.""" pf = ParallelFunction(self, f, block=self.block, bound=True, targets='all') return pf.map(*sequences) + def parallel(self, bound=True, targets='all', block=True): + """Decorator for making a ParallelFunction""" + return parallel(self, bound=bound, targets=targets, block=block) + + def remote(self, bound=True, targets='all', block=True): + """Decorator for making a RemoteFunction""" + return remote(self, bound=bound, targets=targets, block=block) + #-------------------------------------------------------------------------- # Data movement #-------------------------------------------------------------------------- @@ -831,7 +843,7 @@ class Client(object): else: mid = self.push({key: partition}, targets=engineid, block=False) msg_ids.append(mid) - r = PendingResult(self, msg_ids) + r = AsyncResult(self, msg_ids) if block: r.wait() return @@ -850,7 +862,7 @@ class Client(object): for index, engineid in enumerate(targets): msg_ids.append(self.pull(key, targets=engineid,block=False)) - r = PendingMapResult(self, msg_ids, mapObject) + r = AsyncMapResult(self, msg_ids, mapObject) if block: r.wait() return r.result @@ -1002,6 +1014,6 @@ __all__ = [ 'Client', 'ParallelFunction', 'DirectView', 'LoadBalancedView', - 'PendingResult', - 'PendingMapResult' + 'AsyncResult', + 'AsyncMapResult' ] diff --git a/IPython/zmq/parallel/error.py b/IPython/zmq/parallel/error.py index 14ac61e..1177978 100644 --- a/IPython/zmq/parallel/error.py +++ b/IPython/zmq/parallel/error.py @@ -145,6 +145,9 @@ class SecurityError(KernelError): class FileTimeoutError(KernelError): pass +class TimeoutError(KernelError): + pass + class RemoteError(KernelError): """Error raised elsewhere""" ename=None diff --git a/IPython/zmq/parallel/heartmonitor.py b/IPython/zmq/parallel/heartmonitor.py index c612b37..34dcf6f 100644 --- a/IPython/zmq/parallel/heartmonitor.py +++ b/IPython/zmq/parallel/heartmonitor.py @@ -9,7 +9,7 @@ import time import uuid import zmq -from zmq.devices import ProcessDevice +from zmq.devices import ProcessDevice,ThreadDevice from zmq.eventloop import ioloop, zmqstream #internal @@ -27,7 +27,7 @@ class Heart(object): device=None id=None def __init__(self, in_addr, out_addr, in_type=zmq.SUB, out_type=zmq.XREQ, heart_id=None): - self.device = ProcessDevice(zmq.FORWARDER, in_type, out_type) + self.device = ThreadDevice(zmq.FORWARDER, in_type, out_type) self.device.daemon=True self.device.connect_in(in_addr) self.device.connect_out(out_addr) diff --git a/IPython/zmq/parallel/pendingresult.py b/IPython/zmq/parallel/pendingresult.py deleted file mode 100644 index 8f3dfd1..0000000 --- a/IPython/zmq/parallel/pendingresult.py +++ /dev/null @@ -1,75 +0,0 @@ -"""PendingResult objects for the client""" -#----------------------------------------------------------------------------- -# Copyright (C) 2010 The IPython Development Team -# -# Distributed under the terms of the BSD License. The full license is in -# the file COPYING, distributed as part of this software. -#----------------------------------------------------------------------------- - -#----------------------------------------------------------------------------- -# Imports -#----------------------------------------------------------------------------- - -import error - -#----------------------------------------------------------------------------- -# Classes -#----------------------------------------------------------------------------- - -class PendingResult(object): - """Class for representing results of non-blocking calls.""" - def __init__(self, client, msg_ids): - self.client = client - self.msg_ids = msg_ids - self._result = None - self.done = False - - def __repr__(self): - if self.done: - return "<%s: finished>"%(self.__class__.__name__) - else: - return "<%s: %r>"%(self.__class__.__name__,self.msg_ids) - - @property - def result(self): - if self._result is not None: - return self._result - if not self.done: - self.wait(0) - if self.done: - results = map(self.client.results.get, self.msg_ids) - results = error.collect_exceptions(results, 'get_result') - self._result = self.reconstruct_result(results) - return self._result - else: - raise error.ResultNotCompleted - - def reconstruct_result(self, res): - """ - Override me in subclasses for turning a list of results - into the expected form. - """ - if len(res) == 1: - return res[0] - else: - return res - - def wait(self, timout=-1): - self.done = self.client.barrier(self.msg_ids) - return self.done - -class PendingMapResult(PendingResult): - """Class for representing results of non-blocking gathers. - - This will properly reconstruct the gather. - """ - - def __init__(self, client, msg_ids, mapObject): - self.mapObject = mapObject - PendingResult.__init__(self, client, msg_ids) - - def reconstruct_result(self, res): - """Perform the gather on the actual results.""" - return self.mapObject.joinPartitions(res) - - diff --git a/IPython/zmq/parallel/remotefunction.py b/IPython/zmq/parallel/remotefunction.py index d085b58..027acbc 100644 --- a/IPython/zmq/parallel/remotefunction.py +++ b/IPython/zmq/parallel/remotefunction.py @@ -11,7 +11,7 @@ #----------------------------------------------------------------------------- import map as Map -from pendingresult import PendingMapResult +from asyncresult import AsyncMapResult #----------------------------------------------------------------------------- # Decorators @@ -126,10 +126,10 @@ class ParallelFunction(RemoteFunction): f=self.func mid = self.client.apply(f, args=args, block=False, bound=self.bound, - targets=engineid).msg_ids[0] + targets=engineid)._msg_ids[0] msg_ids.append(mid) - r = PendingMapResult(self.client, msg_ids, self.mapObject) + r = AsyncMapResult(self.client, msg_ids, self.mapObject) if self.block: r.wait() return r.result