diff --git a/IPython/zmq/parallel/asyncresult.py b/IPython/zmq/parallel/asyncresult.py index 0031ce1..2ff03c1 100644 --- a/IPython/zmq/parallel/asyncresult.py +++ b/IPython/zmq/parallel/asyncresult.py @@ -21,9 +21,10 @@ class AsyncResult(object): Provides the same interface as :py:class:`multiprocessing.AsyncResult`. """ - def __init__(self, client, msg_ids): + def __init__(self, client, msg_ids, targets=None): self._client = client self.msg_ids = msg_ids + self._targets=targets self._ready = False self._success = None @@ -41,6 +42,8 @@ class AsyncResult(object): """ if len(res) == 1: return res[0] + elif self.targets is not None: + return dict(zip(self._targets, res)) else: return res diff --git a/IPython/zmq/parallel/client.py b/IPython/zmq/parallel/client.py index 3b9df47..9eb3f22 100644 --- a/IPython/zmq/parallel/client.py +++ b/IPython/zmq/parallel/client.py @@ -632,7 +632,7 @@ class Client(object): whether or not to wait until done """ - result = self.apply(execute, (code,), targets=None, block=block, bound=False) + result = self.apply(_execute, (code,), targets=None, block=block, bound=False) return result def _maybe_raise(self, result): @@ -721,6 +721,18 @@ class Client(object): if not isinstance(kwargs, dict): raise TypeError("kwargs must be dict, not %s"%type(kwargs)) + if isinstance(after, Dependency): + after = after.as_dict() + elif isinstance(after, AsyncResult): + after=after.msg_ids + elif after is None: + after = [] + if isinstance(follow, Dependency): + follow = follow.as_dict() + elif isinstance(follow, AsyncResult): + follow=follow.msg_ids + elif follow is None: + follow = [] options = dict(bound=bound, block=block, after=after, follow=follow) if targets is None: @@ -732,18 +744,11 @@ class Client(object): after=None, follow=None): """The underlying method for applying functions in a load balanced manner, via the task queue.""" - if isinstance(after, Dependency): - after = after.as_dict() - elif after is None: - after = [] - if isinstance(follow, Dependency): - follow = follow.as_dict() - elif follow is None: - follow = [] - subheader = dict(after=after, follow=follow) + subheader = dict(after=after, follow=follow) bufs = ss.pack_apply_message(f,args,kwargs) content = dict(bound=bound) + msg = self.session.send(self._task_socket, "apply_request", content=content, buffers=bufs, subheader=subheader) msg_id = msg['msg_id'] @@ -761,17 +766,11 @@ class Client(object): via the MUX queue.""" queues,targets = self._build_targets(targets) - bufs = ss.pack_apply_message(f,args,kwargs) - if isinstance(after, Dependency): - after = after.as_dict() - elif after is None: - after = [] - if isinstance(follow, Dependency): - follow = follow.as_dict() - elif follow is None: - follow = [] + subheader = dict(after=after, follow=follow) content = dict(bound=bound) + bufs = ss.pack_apply_message(f,args,kwargs) + msg_ids = [] for queue in queues: msg = self.session.send(self._mux_socket, "apply_request", @@ -783,7 +782,7 @@ class Client(object): if block: self.barrier(msg_ids) else: - return AsyncResult(self, msg_ids) + return AsyncResult(self, msg_ids, targets=targets) if len(msg_ids) == 1: return self._maybe_raise(self.results[msg_ids[0]]) else: @@ -850,7 +849,7 @@ class Client(object): else: r = self.push({key: partition}, targets=engineid, block=False) msg_ids.extend(r.msg_ids) - r = AsyncResult(self, msg_ids) + r = AsyncResult(self, msg_ids,targets) if block: return r.get() else: diff --git a/IPython/zmq/parallel/view.py b/IPython/zmq/parallel/view.py index 305c794..22cd46e 100644 --- a/IPython/zmq/parallel/view.py +++ b/IPython/zmq/parallel/view.py @@ -263,21 +263,19 @@ class DirectView(View): Partition a Python sequence and send the partitions to a set of engines. """ block = block if block is not None else self.block - if targets is None: - targets = self.targets + targets = targets if targets is not None else self.targets return self.client.scatter(key, seq, dist=dist, flatten=flatten, targets=targets, block=block) @sync_results @save_ids - def gather(self, key, dist='b', targets=None, block=True): + def gather(self, key, dist='b', targets=None, block=None): """ Gather a partitioned sequence on a set of engines as a single local seq. """ block = block if block is not None else self.block - if targets is None: - targets = self.targets + targets = targets if targets is not None else self.targets return self.client.gather(key, dist=dist, targets=targets, block=block)