diff --git a/IPython/zmq/parallel/asyncresult.py b/IPython/zmq/parallel/asyncresult.py index 0a76c04..500ca85 100644 --- a/IPython/zmq/parallel/asyncresult.py +++ b/IPython/zmq/parallel/asyncresult.py @@ -34,13 +34,17 @@ class AsyncResult(object): """ msg_ids = None + _targets = None + _tracker = None - def __init__(self, client, msg_ids, fname='unknown'): + def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None): self._client = client if isinstance(msg_ids, basestring): msg_ids = [msg_ids] self.msg_ids = msg_ids self._fname=fname + self._targets = targets + self._tracker = tracker self._ready = False self._success = None self._single_result = len(msg_ids) == 1 @@ -169,6 +173,19 @@ class AsyncResult(object): def __dict__(self): return self.get_dict(0) + + def abort(self): + """abort my tasks.""" + assert not self.ready(), "Can't abort, I am already done!" + return self._client.abort(self.msg_ids, targets=self._targets, block=True) + + @property + def sent(self): + """check whether my messages have been sent""" + if self._tracker is None: + return True + else: + return self._tracker.done #------------------------------------- # dict-access diff --git a/IPython/zmq/parallel/client.py b/IPython/zmq/parallel/client.py index efdcfd5..b0ac11b 100644 --- a/IPython/zmq/parallel/client.py +++ b/IPython/zmq/parallel/client.py @@ -356,6 +356,9 @@ class Client(HasTraits): 'apply_reply' : self._handle_apply_reply} self._connect(sshserver, ssh_kwargs) + def __del__(self): + """cleanup sockets, but _not_ context.""" + self.close() def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir): if ipython_dir is None: @@ -387,7 +390,8 @@ class Client(HasTraits): return snames = filter(lambda n: n.endswith('socket'), dir(self)) for socket in map(lambda name: getattr(self, name), snames): - socket.close() + if isinstance(socket, zmq.Socket) and not socket.closed: + socket.close() self._closed = True def _update_engines(self, engines): @@ -550,7 +554,6 @@ 
class Client(HasTraits): outstanding = self._outstanding_dict[uuid] for msg_id in list(outstanding): - print msg_id if msg_id in self.results: # we already continue @@ -796,7 +799,7 @@ class Client(HasTraits): if msg['content']['status'] != 'ok': error = self._unwrap_exception(msg['content']) if error: - return error + raise error @spinfirst @@ -840,7 +843,7 @@ class Client(HasTraits): if msg['content']['status'] != 'ok': error = self._unwrap_exception(msg['content']) if error: - return error + raise error @spinfirst @defaultblock @@ -945,7 +948,8 @@ class Client(HasTraits): @defaultblock def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None, balanced=None, - after=None, follow=None, timeout=None): + after=None, follow=None, timeout=None, + track=False): """Call `f(*args, **kwargs)` on a remote engine(s), returning the result. This is the central execution command for the client. @@ -1003,6 +1007,9 @@ class Client(HasTraits): Specify an amount of time (in seconds) for the scheduler to wait for dependencies to be met before failing with a DependencyTimeout. + track : bool + whether to track non-copying sends. + [default False] after,follow,timeout only used if `balanced=True`. @@ -1044,7 +1051,7 @@ class Client(HasTraits): if not isinstance(kwargs, dict): raise TypeError("kwargs must be dict, not %s"%type(kwargs)) - options = dict(bound=bound, block=block, targets=targets) + options = dict(bound=bound, block=block, targets=targets, track=track) if balanced: return self._apply_balanced(f, args, kwargs, timeout=timeout, @@ -1057,7 +1064,7 @@ class Client(HasTraits): return self._apply_direct(f, args, kwargs, **options) def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None, - after=None, follow=None, timeout=None): + after=None, follow=None, timeout=None, track=None): """call f(*args, **kwargs) remotely in a load-balanced manner. This is a private method, see `apply` for details. 
@@ -1065,7 +1072,7 @@ class Client(HasTraits): """ loc = locals() - for name in ('bound', 'block'): + for name in ('bound', 'block', 'track'): assert loc[name] is not None, "kwarg %r must be specified!"%name if self._task_socket is None: @@ -1101,13 +1108,13 @@ class Client(HasTraits): content = dict(bound=bound) msg = self.session.send(self._task_socket, "apply_request", - content=content, buffers=bufs, subheader=subheader) + content=content, buffers=bufs, subheader=subheader, track=track) msg_id = msg['msg_id'] self.outstanding.add(msg_id) self.history.append(msg_id) self.metadata[msg_id]['submitted'] = datetime.now() - - ar = AsyncResult(self, [msg_id], fname=f.__name__) + tracker = None if track is False else msg['tracker'] + ar = AsyncResult(self, [msg_id], fname=f.__name__, targets=targets, tracker=tracker) if block: try: return ar.get() @@ -1116,7 +1123,8 @@ class Client(HasTraits): else: return ar - def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None): + def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None, + track=None): """Then underlying method for applying functions to specific engines via the MUX queue. @@ -1124,7 +1132,7 @@ class Client(HasTraits): Not to be called directly! 
""" loc = locals() - for name in ('bound', 'block', 'targets'): + for name in ('bound', 'block', 'targets', 'track'): assert loc[name] is not None, "kwarg %r must be specified!"%name idents,targets = self._build_targets(targets) @@ -1134,15 +1142,22 @@ class Client(HasTraits): bufs = util.pack_apply_message(f,args,kwargs) msg_ids = [] + trackers = [] for ident in idents: msg = self.session.send(self._mux_socket, "apply_request", - content=content, buffers=bufs, ident=ident, subheader=subheader) + content=content, buffers=bufs, ident=ident, subheader=subheader, + track=track) + if track: + trackers.append(msg['tracker']) msg_id = msg['msg_id'] self.outstanding.add(msg_id) self._outstanding_dict[ident].add(msg_id) self.history.append(msg_id) msg_ids.append(msg_id) - ar = AsyncResult(self, msg_ids, fname=f.__name__) + + tracker = None if track is False else zmq.MessageTracker(*trackers) + ar = AsyncResult(self, msg_ids, fname=f.__name__, targets=targets, tracker=tracker) + if block: try: return ar.get() @@ -1230,11 +1245,11 @@ class Client(HasTraits): #-------------------------------------------------------------------------- @defaultblock - def push(self, ns, targets='all', block=None): + def push(self, ns, targets='all', block=None, track=False): """Push the contents of `ns` into the namespace on `target`""" if not isinstance(ns, dict): raise TypeError("Must be a dict, not %s"%type(ns)) - result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False) + result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False, track=track) if not block: return result @@ -1251,7 +1266,7 @@ class Client(HasTraits): return result @defaultblock - def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None): + def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None, track=False): """ Partition a Python sequence and send the partitions to a set of engines. 
""" @@ -1259,16 +1274,25 @@ class Client(HasTraits): mapObject = Map.dists[dist]() nparts = len(targets) msg_ids = [] + trackers = [] for index, engineid in enumerate(targets): partition = mapObject.getPartition(seq, index, nparts) if flatten and len(partition) == 1: - r = self.push({key: partition[0]}, targets=engineid, block=False) + r = self.push({key: partition[0]}, targets=engineid, block=False, track=track) else: - r = self.push({key: partition}, targets=engineid, block=False) + r = self.push({key: partition}, targets=engineid, block=False, track=track) msg_ids.extend(r.msg_ids) - r = AsyncResult(self, msg_ids, fname='scatter') + if track: + trackers.append(r._tracker) + + if track: + tracker = zmq.MessageTracker(*trackers) + else: + tracker = None + + r = AsyncResult(self, msg_ids, fname='scatter', targets=targets, tracker=tracker) if block: - r.get() + r.wait() else: return r diff --git a/IPython/zmq/parallel/streamsession.py b/IPython/zmq/parallel/streamsession.py index e41dd75..5da452d 100644 --- a/IPython/zmq/parallel/streamsession.py +++ b/IPython/zmq/parallel/streamsession.py @@ -179,7 +179,7 @@ class StreamSession(object): return header.get('key', None) == self.key - def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None): + def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None, track=False): """Build and send a message via stream or socket. Parameters @@ -191,13 +191,34 @@ class StreamSession(object): Normally, msg_or_type will be a msg_type unless a message is being sent more than once. 
+ content : dict or None + the content of the message (ignored if msg_or_type is a message) + buffers : list or None + the already-serialized buffers to be appended to the message + parent : Message or dict or None + the parent or parent header describing the parent of this message + subheader : dict or None + extra header keys for this message's header + ident : bytes or list of bytes + the zmq.IDENTITY routing path + track : bool + whether to track. Only for use with Sockets, because ZMQStream objects cannot track messages. + Returns ------- - (msg,sent) : tuple - msg : Message - the nice wrapped dict-like object containing the headers + msg : message dict + the constructed message + msg['tracker'] : MessageTracker or None + the result of the final send call is stored under the 'tracker' key of the returned message: + a MessageTracker if track=True, otherwise None """ + + if not isinstance(stream, (zmq.Socket, ZMQStream)): + raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream)) + elif track and isinstance(stream, ZMQStream): + raise TypeError("ZMQStream cannot track messages") + if isinstance(msg_or_type, (Message, dict)): # we got a Message, not a msg_type # don't build a new Message @@ -205,6 +226,7 @@ class StreamSession(object): content = msg['content'] else: msg = self.msg(msg_or_type, content, parent, subheader) + buffers = [] if buffers is None else buffers to_send = [] if isinstance(ident, list): @@ -222,7 +244,7 @@ class StreamSession(object): content = self.none elif isinstance(content, dict): content = self.pack(content) - elif isinstance(content, str): + elif isinstance(content, bytes): # content is already packed, as in a relayed message pass else: @@ -231,16 +253,29 @@ class StreamSession(object): flag = 0 if buffers: flag = zmq.SNDMORE - stream.send_multipart(to_send, flag, copy=False) + _track = False + else: + _track=track + if track: + tracker = stream.send_multipart(to_send, flag, copy=False, track=_track) + else: + 
tracker = stream.send_multipart(to_send, flag, copy=False) for b in buffers[:-1]: stream.send(b, flag, copy=False) if buffers: - stream.send(buffers[-1], copy=False) + if track: + tracker = stream.send(buffers[-1], copy=False, track=track) + else: + tracker = stream.send(buffers[-1], copy=False) + # omsg = Message(msg) if self.debug: pprint.pprint(msg) pprint.pprint(to_send) pprint.pprint(buffers) + + msg['tracker'] = tracker + return msg def send_raw(self, stream, msg, flags=0, copy=True, ident=None): @@ -250,7 +285,7 @@ class StreamSession(object): ---------- msg : list of sendable buffers""" to_send = [] - if isinstance(ident, str): + if isinstance(ident, bytes): ident = [ident] if ident is not None: to_send.extend(ident) diff --git a/IPython/zmq/parallel/tests/__init__.py b/IPython/zmq/parallel/tests/__init__.py index 27f491b..e9617a3 100644 --- a/IPython/zmq/parallel/tests/__init__.py +++ b/IPython/zmq/parallel/tests/__init__.py @@ -1,24 +1,26 @@ """toplevel setup/teardown for parallel tests.""" +import tempfile import time -from subprocess import Popen, PIPE +from subprocess import Popen, PIPE, STDOUT from IPython.zmq.parallel.ipcluster import launch_process from IPython.zmq.parallel.entry_point import select_random_ports processes = [] +blackhole = tempfile.TemporaryFile() # nose setup/teardown def setup(): - cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=PIPE, stdin=PIPE, stderr=PIPE) + cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=blackhole, stderr=STDOUT) processes.append(cp) time.sleep(.5) add_engine() - time.sleep(3) + time.sleep(2) def add_engine(profile='iptest'): - ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=PIPE, stdin=PIPE, stderr=PIPE) + ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=blackhole, stderr=STDOUT) # ep.start() processes.append(ep) return ep diff --git a/IPython/zmq/parallel/tests/clienttest.py 
b/IPython/zmq/parallel/tests/clienttest.py index ce55404..d70ee58 100644 --- a/IPython/zmq/parallel/tests/clienttest.py +++ b/IPython/zmq/parallel/tests/clienttest.py @@ -88,7 +88,9 @@ class ClusterTestCase(BaseZMQTestCase): self.base_engine_count=len(self.client.ids) self.engines=[] - # def tearDown(self): + def tearDown(self): + self.client.close() + BaseZMQTestCase.tearDown(self) # [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ] # [ e.wait() for e in self.engines ] # while len(self.client.ids) > self.base_engine_count: diff --git a/IPython/zmq/parallel/tests/test_client.py b/IPython/zmq/parallel/tests/test_client.py index 27f2863..89cae51 100644 --- a/IPython/zmq/parallel/tests/test_client.py +++ b/IPython/zmq/parallel/tests/test_client.py @@ -2,6 +2,7 @@ import time from tempfile import mktemp import nose.tools as nt +import zmq from IPython.zmq.parallel import client as clientmod from IPython.zmq.parallel import error @@ -18,10 +19,9 @@ class TestClient(ClusterTestCase): self.assertEquals(len(self.client.ids), n+3) self.assertTrue - def test_segfault(self): - """test graceful handling of engine death""" + def test_segfault_task(self): + """test graceful handling of engine death (balanced)""" self.add_engines(1) - eid = self.client.ids[-1] ar = self.client.apply(segfault, block=False) self.assertRaisesRemote(error.EngineError, ar.get) eid = ar.engine_id @@ -29,6 +29,17 @@ class TestClient(ClusterTestCase): time.sleep(.01) self.client.spin() + def test_segfault_mux(self): + """test graceful handling of engine death (direct)""" + self.add_engines(1) + eid = self.client.ids[-1] + ar = self.client[eid].apply_async(segfault) + self.assertRaisesRemote(error.EngineError, ar.get) + eid = ar.engine_id + while eid in self.client.ids: + time.sleep(.01) + self.client.spin() + def test_view_indexing(self): """test index access for views""" self.add_engines(2) @@ -91,13 +102,14 @@ class TestClient(ClusterTestCase): def test_push_pull(self): 
"""test pushing and pulling""" data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'}) + t = self.client.ids[-1] self.add_engines(2) push = self.client.push pull = self.client.pull self.client.block=True nengines = len(self.client) - push({'data':data}, targets=0) - d = pull('data', targets=0) + push({'data':data}, targets=t) + d = pull('data', targets=t) self.assertEquals(d, data) push({'data':data}) d = pull('data') @@ -119,15 +131,16 @@ class TestClient(ClusterTestCase): return 2.0*x self.add_engines(4) + t = self.client.ids[-1] self.client.block=True push = self.client.push pull = self.client.pull execute = self.client.execute - push({'testf':testf}, targets=0) - r = pull('testf', targets=0) + push({'testf':testf}, targets=t) + r = pull('testf', targets=t) self.assertEqual(r(1.0), testf(1.0)) - execute('r = testf(10)', targets=0) - r = pull('r', targets=0) + execute('r = testf(10)', targets=t) + r = pull('r', targets=t) self.assertEquals(r, testf(10)) ar = push({'testf':testf}, block=False) ar.get() @@ -135,8 +148,8 @@ class TestClient(ClusterTestCase): rlist = ar.get() for r in rlist: self.assertEqual(r(1.0), testf(1.0)) - execute("def g(x): return x*x", targets=0) - r = pull(('testf','g'),targets=0) + execute("def g(x): return x*x", targets=t) + r = pull(('testf','g'),targets=t) self.assertEquals((r[0](10),r[1](10)), (testf(10), 100)) def test_push_function_globals(self): @@ -173,7 +186,7 @@ class TestClient(ClusterTestCase): ids.remove(ids[-1]) self.assertNotEquals(ids, self.client._ids) - def test_arun_newline(self): + def test_run_newline(self): """test that run appends newline to files""" tmpfile = mktemp() with open(tmpfile, 'w') as f: @@ -184,4 +197,56 @@ class TestClient(ClusterTestCase): v.run(tmpfile, block=True) self.assertEquals(v.apply_sync_bound(lambda : g()), 5) - \ No newline at end of file + def test_apply_tracked(self): + """test tracking for apply""" + # self.add_engines(1) + t = self.client.ids[-1] + self.client.block=False + def 
echo(n=1024*1024, **kwargs): + return self.client.apply(lambda x: x, args=('x'*n,), targets=t, **kwargs) + ar = echo(1) + self.assertTrue(ar._tracker is None) + self.assertTrue(ar.sent) + ar = echo(track=True) + self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker)) + self.assertEquals(ar.sent, ar._tracker.done) + ar._tracker.wait() + self.assertTrue(ar.sent) + + def test_push_tracked(self): + t = self.client.ids[-1] + ns = dict(x='x'*1024*1024) + ar = self.client.push(ns, targets=t, block=False) + self.assertTrue(ar._tracker is None) + self.assertTrue(ar.sent) + + ar = self.client.push(ns, targets=t, block=False, track=True) + self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker)) + self.assertEquals(ar.sent, ar._tracker.done) + ar._tracker.wait() + self.assertTrue(ar.sent) + ar.get() + + def test_scatter_tracked(self): + t = self.client.ids + x='x'*1024*1024 + ar = self.client.scatter('x', x, targets=t, block=False) + self.assertTrue(ar._tracker is None) + self.assertTrue(ar.sent) + + ar = self.client.scatter('x', x, targets=t, block=False, track=True) + self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker)) + self.assertEquals(ar.sent, ar._tracker.done) + ar._tracker.wait() + self.assertTrue(ar.sent) + ar.get() + + def test_remote_reference(self): + v = self.client[-1] + v['a'] = 123 + ra = clientmod.Reference('a') + b = v.apply_sync_bound(lambda x: x, ra) + self.assertEquals(b, 123) + self.assertRaisesRemote(NameError, v.apply_sync, lambda x: x, ra) + + diff --git a/IPython/zmq/parallel/tests/test_streamsession.py b/IPython/zmq/parallel/tests/test_streamsession.py index 643b53f..7a2b896 100644 --- a/IPython/zmq/parallel/tests/test_streamsession.py +++ b/IPython/zmq/parallel/tests/test_streamsession.py @@ -4,7 +4,7 @@ import uuid import zmq from zmq.tests import BaseZMQTestCase - +from zmq.eventloop.zmqstream import ZMQStream # from IPython.zmq.tests import SessionTestCase from IPython.zmq.parallel import streamsession as ss @@ -31,7 +31,7 @@ 
class TestSession(SessionTestCase): def test_args(self): """initialization arguments for StreamSession""" - s = ss.StreamSession() + s = self.session self.assertTrue(s.pack is ss.default_packer) self.assertTrue(s.unpack is ss.default_unpacker) self.assertEquals(s.username, os.environ.get('USER', 'username')) @@ -46,7 +46,24 @@ class TestSession(SessionTestCase): self.assertEquals(s.session, u) self.assertEquals(s.username, 'carrot') - + def test_tracking(self): + """test tracking messages""" + a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + s = self.session + stream = ZMQStream(a) + msg = s.send(a, 'hello', track=False) + self.assertTrue(msg['tracker'] is None) + msg = s.send(a, 'hello', track=True) + self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker)) + M = zmq.Message(b'hi there', track=True) + msg = s.send(a, 'hello', buffers=[M], track=True) + t = msg['tracker'] + self.assertTrue(isinstance(t, zmq.MessageTracker)) + self.assertRaises(zmq.NotDone, t.wait, .1) + del M + t.wait(1) # raises zmq.NotDone if the message is still not done + + # def test_rekey(self): # """rekeying dict around json str keys""" # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}