From 0c043a69bbbbe986b0ed6c602afbec7d8e386782 2011-05-17 21:27:10 From: MinRK Date: 2011-05-17 21:27:10 Subject: [PATCH] add Client.resubmit for re-running tasks closes gh-411 * allow `content` in session.serialize to be a unicode object, because mongo+JSON cannot be relied upon to produce encoded bytes. --- diff --git a/IPython/parallel/client/client.py b/IPython/parallel/client/client.py index de810a2..ef2f246 100644 --- a/IPython/parallel/client/client.py +++ b/IPython/parallel/client/client.py @@ -1041,6 +1041,68 @@ class Client(HasTraits): ar.wait() return ar + + @spin_first + def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None): + """Resubmit one or more tasks. + + in-flight tasks may not be resubmitted. + + Parameters + ---------- + + indices_or_msg_ids : integer history index, str msg_id, or list of either + The indices or msg_ids of indices to be retrieved + + block : bool + Whether to wait for the result to be done + + Returns + ------- + + AsyncHubResult + A subclass of AsyncResult that retrieves results from the Hub + + """ + block = self.block if block is None else block + if indices_or_msg_ids is None: + indices_or_msg_ids = -1 + + if not isinstance(indices_or_msg_ids, (list,tuple)): + indices_or_msg_ids = [indices_or_msg_ids] + + theids = [] + for id in indices_or_msg_ids: + if isinstance(id, int): + id = self.history[id] + if not isinstance(id, str): + raise TypeError("indices must be str or int, not %r"%id) + theids.append(id) + + for msg_id in theids: + self.outstanding.discard(msg_id) + if msg_id in self.history: + self.history.remove(msg_id) + self.results.pop(msg_id, None) + self.metadata.pop(msg_id, None) + content = dict(msg_ids = theids) + + self.session.send(self._query_socket, 'resubmit_request', content) + + zmq.select([self._query_socket], [], []) + idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK) + if self.debug: + pprint(msg) + content = msg['content'] + if content['status'] != 'ok': + raise 
self._unwrap_exception(content) + + ar = AsyncHubResult(self, msg_ids=theids) + + if block: + ar.wait() + + return ar @spin_first def result_status(self, msg_ids, status_only=True): diff --git a/IPython/parallel/controller/hub.py b/IPython/parallel/controller/hub.py index 9b2ffcb..9136ef2 100755 --- a/IPython/parallel/controller/hub.py +++ b/IPython/parallel/controller/hub.py @@ -268,8 +268,15 @@ class HubFactory(RegistrationFactory): } self.log.debug("Hub engine addrs: %s"%self.engine_info) self.log.debug("Hub client addrs: %s"%self.client_info) + + # resubmit stream + r = ZMQStream(ctx.socket(zmq.XREQ), loop) + url = util.disambiguate_url(self.client_info['task'][-1]) + r.setsockopt(zmq.IDENTITY, self.session.session) + r.connect(url) + self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor, - query=q, notifier=n, db=self.db, + query=q, notifier=n, resubmit=r, db=self.db, engine_info=self.engine_info, client_info=self.client_info, logname=self.log.name) @@ -315,8 +322,9 @@ class Hub(LoggingFactory): loop=Instance(ioloop.IOLoop) query=Instance(ZMQStream) monitor=Instance(ZMQStream) - heartmonitor=Instance(HeartMonitor) notifier=Instance(ZMQStream) + resubmit=Instance(ZMQStream) + heartmonitor=Instance(HeartMonitor) db=Instance(object) client_info=Dict() engine_info=Dict() @@ -379,6 +387,9 @@ class Hub(LoggingFactory): 'connection_request': self.connection_request, } + # ignore resubmit replies + self.resubmit.on_recv(lambda msg: None, copy=False) + self.log.info("hub::created hub") @property @@ -452,31 +463,31 @@ class Hub(LoggingFactory): def dispatch_monitor_traffic(self, msg): """all ME and Task queue messages come through here, as well as IOPub traffic.""" - self.log.debug("monitor traffic: %s"%msg[:2]) + self.log.debug("monitor traffic: %r"%msg[:2]) switch = msg[0] idents, msg = self.session.feed_identities(msg[1:]) if not idents: - self.log.error("Bad Monitor Message: %s"%msg) + self.log.error("Bad Monitor Message: 
%r"%msg) return handler = self.monitor_handlers.get(switch, None) if handler is not None: handler(idents, msg) else: - self.log.error("Invalid monitor topic: %s"%switch) + self.log.error("Invalid monitor topic: %r"%switch) def dispatch_query(self, msg): """Route registration requests and queries from clients.""" idents, msg = self.session.feed_identities(msg) if not idents: - self.log.error("Bad Query Message: %s"%msg) + self.log.error("Bad Query Message: %r"%msg) return client_id = idents[0] try: msg = self.session.unpack_message(msg, content=True) except: content = error.wrap_exception() - self.log.error("Bad Query Message: %s"%msg, exc_info=True) + self.log.error("Bad Query Message: %r"%msg, exc_info=True) self.session.send(self.query, "hub_error", ident=client_id, content=content) return @@ -484,16 +495,17 @@ class Hub(LoggingFactory): # print client_id, header, parent, content #switch on message type: msg_type = msg['msg_type'] - self.log.info("client::client %s requested %s"%(client_id, msg_type)) + self.log.info("client::client %r requested %r"%(client_id, msg_type)) handler = self.query_handlers.get(msg_type, None) try: - assert handler is not None, "Bad Message Type: %s"%msg_type + assert handler is not None, "Bad Message Type: %r"%msg_type except: content = error.wrap_exception() - self.log.error("Bad Message Type: %s"%msg_type, exc_info=True) + self.log.error("Bad Message Type: %r"%msg_type, exc_info=True) self.session.send(self.query, "hub_error", ident=client_id, content=content) return + else: handler(idents, msg) @@ -560,9 +572,9 @@ class Hub(LoggingFactory): # it's posible iopub arrived first: existing = self.db.get_record(msg_id) for key,evalue in existing.iteritems(): - rvalue = record[key] + rvalue = record.get(key, None) if evalue and rvalue and evalue != rvalue: - self.log.error("conflicting initial state for record: %s:%s <> %s"%(msg_id, rvalue, evalue)) + self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, 
key, evalue)) elif evalue and not rvalue: record[key] = evalue self.db.update_record(msg_id, record) @@ -648,10 +660,22 @@ class Hub(LoggingFactory): try: # it's posible iopub arrived first: existing = self.db.get_record(msg_id) + if existing['resubmitted']: + for key in ('submitted', 'client_uuid', 'buffers'): + # don't clobber these keys on resubmit + # submitted and client_uuid should be different + # and buffers might be big, and shouldn't have changed + record.pop(key) + # still check content,header which should not change + # but are not expensive to compare as buffers + for key,evalue in existing.iteritems(): - rvalue = record[key] + if key.endswith('buffers'): + # don't compare buffers + continue + rvalue = record.get(key, None) if evalue and rvalue and evalue != rvalue: - self.log.error("conflicting initial state for record: %s:%s <> %s"%(msg_id, rvalue, evalue)) + self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue)) elif evalue and not rvalue: record[key] = evalue self.db.update_record(msg_id, record) @@ -1075,9 +1099,68 @@ class Hub(LoggingFactory): self.session.send(self.query, 'purge_reply', content=reply, ident=client_id) - def resubmit_task(self, client_id, msg, buffers): - """Resubmit a task.""" - raise NotImplementedError + def resubmit_task(self, client_id, msg): + """Resubmit one or more tasks.""" + def finish(reply): + self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id) + + content = msg['content'] + msg_ids = content['msg_ids'] + reply = dict(status='ok') + try: + records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[ + 'header', 'content', 'buffers']) + except Exception: + self.log.error('db::db error finding tasks to resubmit', exc_info=True) + return finish(error.wrap_exception()) + + # validate msg_ids + found_ids = [ rec['msg_id'] for rec in records ] + invalid_ids = filter(lambda m: m in self.pending, found_ids) + if len(records) > len(msg_ids): + 
try: + raise RuntimeError("DB appears to be in an inconsistent state." + "More matching records were found than should exist") + except Exception: + return finish(error.wrap_exception()) + elif len(records) < len(msg_ids): + missing = [ m for m in msg_ids if m not in found_ids ] + try: + raise KeyError("No such msg(s): %s"%missing) + except KeyError: + return finish(error.wrap_exception()) + elif invalid_ids: + msg_id = invalid_ids[0] + try: + raise ValueError("Task %r appears to be inflight"%(msg_id)) + except Exception: + return finish(error.wrap_exception()) + + # clear the existing records + rec = empty_record() + map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted']) + rec['resubmitted'] = datetime.now() + rec['queue'] = 'task' + rec['client_uuid'] = client_id[0] + try: + for msg_id in msg_ids: + self.all_completed.discard(msg_id) + self.db.update_record(msg_id, rec) + except Exception: + self.log.error('db::db error upating record', exc_info=True) + reply = error.wrap_exception() + else: + # send the messages + for rec in records: + header = rec['header'] + msg = self.session.msg(header['msg_type']) + msg['content'] = rec['content'] + msg['header'] = header + msg['msg_id'] = rec['msg_id'] + self.session.send(self.resubmit, msg, buffers=rec['buffers']) + + finish(reply) # reply is status='ok', or the wrapped db error if update_record failed + def _extract_record(self, rec): """decompose a TaskRecord dict into subsection of reply for get_result""" @@ -1124,12 +1207,20 @@ class Hub(LoggingFactory): for msg_id in msg_ids: if msg_id in self.pending: pending.append(msg_id) - elif msg_id in self.all_completed or msg_id in records: + elif msg_id in self.all_completed: completed.append(msg_id) if not statusonly: c,bufs = self._extract_record(records[msg_id]) content[msg_id] = c buffers.extend(bufs) + elif msg_id in records: + if records[msg_id]['completed']: # check this msg_id's own record, not a stray `rec` + completed.append(msg_id) + c,bufs = self._extract_record(records[msg_id]) + content[msg_id] = c + buffers.extend(bufs) + else: + pending.append(msg_id) else: try: 
raise KeyError('No such message: '+msg_id) diff --git a/IPython/parallel/streamsession.py b/IPython/parallel/streamsession.py index 3d1b9cf..1b7f5ca 100644 --- a/IPython/parallel/streamsession.py +++ b/IPython/parallel/streamsession.py @@ -186,6 +186,9 @@ class StreamSession(object): elif isinstance(content, bytes): # content is already packed, as in a relayed message pass + elif isinstance(content, unicode): + # should be bytes, but JSON often spits out unicode + content = content.encode('utf8') else: raise TypeError("Content incorrect type: %s"%type(content)) diff --git a/IPython/parallel/tests/test_client.py b/IPython/parallel/tests/test_client.py index ba8c372..4d7d42d 100644 --- a/IPython/parallel/tests/test_client.py +++ b/IPython/parallel/tests/test_client.py @@ -212,3 +212,26 @@ class TestClient(ClusterTestCase): time.sleep(0.25) self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids) + def test_resubmit(self): + def f(): + import random + return random.random() + v = self.client.load_balanced_view() + ar = v.apply_async(f) + r1 = ar.get(1) + ahr = self.client.resubmit(ar.msg_ids) + r2 = ahr.get(1) + self.assertFalse(r1 == r2) + + def test_resubmit_inflight(self): + """ensure ValueError on resubmit of inflight task""" + v = self.client.load_balanced_view() + ar = v.apply_async(time.sleep,1) + # give the message a chance to arrive + time.sleep(0.2) + self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids) + ar.get(2) + + def test_resubmit_badkey(self): + """ensure KeyError on resubmit of nonexistant task""" + self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid']) diff --git a/IPython/parallel/tests/test_lbview.py b/IPython/parallel/tests/test_lbview.py index 14a8211..9695a2e 100644 --- a/IPython/parallel/tests/test_lbview.py +++ b/IPython/parallel/tests/test_lbview.py @@ -36,7 +36,7 @@ class TestLoadBalancedView(ClusterTestCase): """test graceful handling of engine death (balanced)""" # self.add_engines(1) ar = 
self.view.apply_async(crash) - self.assertRaisesRemote(error.EngineError, ar.get) + self.assertRaisesRemote(error.EngineError, ar.get, 10) eid = ar.engine_id tic = time.time() while eid in self.client.ids and time.time()-tic < 5: