Show More
@@ -1043,6 +1043,68 b' class Client(HasTraits):' | |||||
1043 | return ar |
|
1043 | return ar | |
1044 |
|
1044 | |||
1045 | @spin_first |
|
1045 | @spin_first | |
|
1046 | def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None): | |||
|
1047 | """Resubmit one or more tasks. | |||
|
1048 | ||||
|
1049 | in-flight tasks may not be resubmitted. | |||
|
1050 | ||||
|
1051 | Parameters | |||
|
1052 | ---------- | |||
|
1053 | ||||
|
1054 | indices_or_msg_ids : integer history index, str msg_id, or list of either | |||
|
1055 | The indices or msg_ids of the tasks to be resubmitted | |||
|
1056 | ||||
|
1057 | block : bool | |||
|
1058 | Whether to wait for the result to be done | |||
|
1059 | ||||
|
1060 | Returns | |||
|
1061 | ------- | |||
|
1062 | ||||
|
1063 | AsyncHubResult | |||
|
1064 | A subclass of AsyncResult that retrieves results from the Hub | |||
|
1065 | ||||
|
1066 | """ | |||
|
1067 | block = self.block if block is None else block | |||
|
1068 | if indices_or_msg_ids is None: | |||
|
1069 | indices_or_msg_ids = -1 | |||
|
1070 | ||||
|
1071 | if not isinstance(indices_or_msg_ids, (list,tuple)): | |||
|
1072 | indices_or_msg_ids = [indices_or_msg_ids] | |||
|
1073 | ||||
|
1074 | theids = [] | |||
|
1075 | for id in indices_or_msg_ids: | |||
|
1076 | if isinstance(id, int): | |||
|
1077 | id = self.history[id] | |||
|
1078 | if not isinstance(id, str): | |||
|
1079 | raise TypeError("indices must be str or int, not %r"%id) | |||
|
1080 | theids.append(id) | |||
|
1081 | ||||
|
1082 | for msg_id in theids: | |||
|
1083 | self.outstanding.discard(msg_id) | |||
|
1084 | if msg_id in self.history: | |||
|
1085 | self.history.remove(msg_id) | |||
|
1086 | self.results.pop(msg_id, None) | |||
|
1087 | self.metadata.pop(msg_id, None) | |||
|
1088 | content = dict(msg_ids = theids) | |||
|
1089 | ||||
|
1090 | self.session.send(self._query_socket, 'resubmit_request', content) | |||
|
1091 | ||||
|
1092 | zmq.select([self._query_socket], [], []) | |||
|
1093 | idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK) | |||
|
1094 | if self.debug: | |||
|
1095 | pprint(msg) | |||
|
1096 | content = msg['content'] | |||
|
1097 | if content['status'] != 'ok': | |||
|
1098 | raise self._unwrap_exception(content) | |||
|
1099 | ||||
|
1100 | ar = AsyncHubResult(self, msg_ids=theids) | |||
|
1101 | ||||
|
1102 | if block: | |||
|
1103 | ar.wait() | |||
|
1104 | ||||
|
1105 | return ar | |||
|
1106 | ||||
|
1107 | @spin_first | |||
1046 | def result_status(self, msg_ids, status_only=True): |
|
1108 | def result_status(self, msg_ids, status_only=True): | |
1047 | """Check on the status of the result(s) of the apply request with `msg_ids`. |
|
1109 | """Check on the status of the result(s) of the apply request with `msg_ids`. | |
1048 |
|
1110 |
@@ -268,8 +268,15 b' class HubFactory(RegistrationFactory):' | |||||
268 | } |
|
268 | } | |
269 | self.log.debug("Hub engine addrs: %s"%self.engine_info) |
|
269 | self.log.debug("Hub engine addrs: %s"%self.engine_info) | |
270 | self.log.debug("Hub client addrs: %s"%self.client_info) |
|
270 | self.log.debug("Hub client addrs: %s"%self.client_info) | |
|
271 | ||||
|
272 | # resubmit stream | |||
|
273 | r = ZMQStream(ctx.socket(zmq.XREQ), loop) | |||
|
274 | url = util.disambiguate_url(self.client_info['task'][-1]) | |||
|
275 | r.setsockopt(zmq.IDENTITY, self.session.session) | |||
|
276 | r.connect(url) | |||
|
277 | ||||
271 | self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor, |
|
278 | self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor, | |
272 | query=q, notifier=n, db=self.db, |
|
279 | query=q, notifier=n, resubmit=r, db=self.db, | |
273 | engine_info=self.engine_info, client_info=self.client_info, |
|
280 | engine_info=self.engine_info, client_info=self.client_info, | |
274 | logname=self.log.name) |
|
281 | logname=self.log.name) | |
275 |
|
282 | |||
@@ -315,8 +322,9 b' class Hub(LoggingFactory):' | |||||
315 | loop=Instance(ioloop.IOLoop) |
|
322 | loop=Instance(ioloop.IOLoop) | |
316 | query=Instance(ZMQStream) |
|
323 | query=Instance(ZMQStream) | |
317 | monitor=Instance(ZMQStream) |
|
324 | monitor=Instance(ZMQStream) | |
318 | heartmonitor=Instance(HeartMonitor) |
|
|||
319 | notifier=Instance(ZMQStream) |
|
325 | notifier=Instance(ZMQStream) | |
|
326 | resubmit=Instance(ZMQStream) | |||
|
327 | heartmonitor=Instance(HeartMonitor) | |||
320 | db=Instance(object) |
|
328 | db=Instance(object) | |
321 | client_info=Dict() |
|
329 | client_info=Dict() | |
322 | engine_info=Dict() |
|
330 | engine_info=Dict() | |
@@ -379,6 +387,9 b' class Hub(LoggingFactory):' | |||||
379 | 'connection_request': self.connection_request, |
|
387 | 'connection_request': self.connection_request, | |
380 | } |
|
388 | } | |
381 |
|
389 | |||
|
390 | # ignore resubmit replies | |||
|
391 | self.resubmit.on_recv(lambda msg: None, copy=False) | |||
|
392 | ||||
382 | self.log.info("hub::created hub") |
|
393 | self.log.info("hub::created hub") | |
383 |
|
394 | |||
384 | @property |
|
395 | @property | |
@@ -452,31 +463,31 b' class Hub(LoggingFactory):' | |||||
452 | def dispatch_monitor_traffic(self, msg): |
|
463 | def dispatch_monitor_traffic(self, msg): | |
453 | """all ME and Task queue messages come through here, as well as |
|
464 | """all ME and Task queue messages come through here, as well as | |
454 | IOPub traffic.""" |
|
465 | IOPub traffic.""" | |
455 |
self.log.debug("monitor traffic: % |
|
466 | self.log.debug("monitor traffic: %r"%msg[:2]) | |
456 | switch = msg[0] |
|
467 | switch = msg[0] | |
457 | idents, msg = self.session.feed_identities(msg[1:]) |
|
468 | idents, msg = self.session.feed_identities(msg[1:]) | |
458 | if not idents: |
|
469 | if not idents: | |
459 |
self.log.error("Bad Monitor Message: % |
|
470 | self.log.error("Bad Monitor Message: %r"%msg) | |
460 | return |
|
471 | return | |
461 | handler = self.monitor_handlers.get(switch, None) |
|
472 | handler = self.monitor_handlers.get(switch, None) | |
462 | if handler is not None: |
|
473 | if handler is not None: | |
463 | handler(idents, msg) |
|
474 | handler(idents, msg) | |
464 | else: |
|
475 | else: | |
465 |
self.log.error("Invalid monitor topic: % |
|
476 | self.log.error("Invalid monitor topic: %r"%switch) | |
466 |
|
477 | |||
467 |
|
478 | |||
468 | def dispatch_query(self, msg): |
|
479 | def dispatch_query(self, msg): | |
469 | """Route registration requests and queries from clients.""" |
|
480 | """Route registration requests and queries from clients.""" | |
470 | idents, msg = self.session.feed_identities(msg) |
|
481 | idents, msg = self.session.feed_identities(msg) | |
471 | if not idents: |
|
482 | if not idents: | |
472 |
self.log.error("Bad Query Message: % |
|
483 | self.log.error("Bad Query Message: %r"%msg) | |
473 | return |
|
484 | return | |
474 | client_id = idents[0] |
|
485 | client_id = idents[0] | |
475 | try: |
|
486 | try: | |
476 | msg = self.session.unpack_message(msg, content=True) |
|
487 | msg = self.session.unpack_message(msg, content=True) | |
477 | except: |
|
488 | except: | |
478 | content = error.wrap_exception() |
|
489 | content = error.wrap_exception() | |
479 |
self.log.error("Bad Query Message: % |
|
490 | self.log.error("Bad Query Message: %r"%msg, exc_info=True) | |
480 | self.session.send(self.query, "hub_error", ident=client_id, |
|
491 | self.session.send(self.query, "hub_error", ident=client_id, | |
481 | content=content) |
|
492 | content=content) | |
482 | return |
|
493 | return | |
@@ -484,16 +495,17 b' class Hub(LoggingFactory):' | |||||
484 | # print client_id, header, parent, content |
|
495 | # print client_id, header, parent, content | |
485 | #switch on message type: |
|
496 | #switch on message type: | |
486 | msg_type = msg['msg_type'] |
|
497 | msg_type = msg['msg_type'] | |
487 |
self.log.info("client::client % |
|
498 | self.log.info("client::client %r requested %r"%(client_id, msg_type)) | |
488 | handler = self.query_handlers.get(msg_type, None) |
|
499 | handler = self.query_handlers.get(msg_type, None) | |
489 | try: |
|
500 | try: | |
490 |
assert handler is not None, "Bad Message Type: % |
|
501 | assert handler is not None, "Bad Message Type: %r"%msg_type | |
491 | except: |
|
502 | except: | |
492 | content = error.wrap_exception() |
|
503 | content = error.wrap_exception() | |
493 |
self.log.error("Bad Message Type: % |
|
504 | self.log.error("Bad Message Type: %r"%msg_type, exc_info=True) | |
494 | self.session.send(self.query, "hub_error", ident=client_id, |
|
505 | self.session.send(self.query, "hub_error", ident=client_id, | |
495 | content=content) |
|
506 | content=content) | |
496 | return |
|
507 | return | |
|
508 | ||||
497 | else: |
|
509 | else: | |
498 | handler(idents, msg) |
|
510 | handler(idents, msg) | |
499 |
|
511 | |||
@@ -560,9 +572,9 b' class Hub(LoggingFactory):' | |||||
560 | # it's possible iopub arrived first: |
|
572 | # it's possible iopub arrived first: | |
561 | existing = self.db.get_record(msg_id) |
|
573 | existing = self.db.get_record(msg_id) | |
562 | for key,evalue in existing.iteritems(): |
|
574 | for key,evalue in existing.iteritems(): | |
563 |
rvalue = record |
|
575 | rvalue = record.get(key, None) | |
564 | if evalue and rvalue and evalue != rvalue: |
|
576 | if evalue and rvalue and evalue != rvalue: | |
565 |
self.log. |
|
577 | self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue)) | |
566 | elif evalue and not rvalue: |
|
578 | elif evalue and not rvalue: | |
567 | record[key] = evalue |
|
579 | record[key] = evalue | |
568 | self.db.update_record(msg_id, record) |
|
580 | self.db.update_record(msg_id, record) | |
@@ -648,10 +660,22 b' class Hub(LoggingFactory):' | |||||
648 | try: |
|
660 | try: | |
649 | # it's possible iopub arrived first: |
|
661 | # it's possible iopub arrived first: | |
650 | existing = self.db.get_record(msg_id) |
|
662 | existing = self.db.get_record(msg_id) | |
|
663 | if existing['resubmitted']: | |||
|
664 | for key in ('submitted', 'client_uuid', 'buffers'): | |||
|
665 | # don't clobber these keys on resubmit | |||
|
666 | # submitted and client_uuid should be different | |||
|
667 | # and buffers might be big, and shouldn't have changed | |||
|
668 | record.pop(key) | |||
|
669 | # still check content,header which should not change | |||
|
670 | # but are not expensive to compare as buffers | |||
|
671 | ||||
651 | for key,evalue in existing.iteritems(): |
|
672 | for key,evalue in existing.iteritems(): | |
652 | rvalue = record[key] |
|
673 | if key.endswith('buffers'): | |
|
674 | # don't compare buffers | |||
|
675 | continue | |||
|
676 | rvalue = record.get(key, None) | |||
653 | if evalue and rvalue and evalue != rvalue: |
|
677 | if evalue and rvalue and evalue != rvalue: | |
654 |
self.log. |
|
678 | self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue)) | |
655 | elif evalue and not rvalue: |
|
679 | elif evalue and not rvalue: | |
656 | record[key] = evalue |
|
680 | record[key] = evalue | |
657 | self.db.update_record(msg_id, record) |
|
681 | self.db.update_record(msg_id, record) | |
@@ -1075,9 +1099,68 b' class Hub(LoggingFactory):' | |||||
1075 |
|
1099 | |||
1076 | self.session.send(self.query, 'purge_reply', content=reply, ident=client_id) |
|
1100 | self.session.send(self.query, 'purge_reply', content=reply, ident=client_id) | |
1077 |
|
1101 | |||
1078 |
def resubmit_task(self, client_id, msg |
|
1102 | def resubmit_task(self, client_id, msg): | |
1079 |
"""Resubmit |
|
1103 | """Resubmit one or more tasks.""" | |
1080 | raise NotImplementedError |
|
1104 | def finish(reply): | |
|
1105 | self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id) | |||
|
1106 | ||||
|
1107 | content = msg['content'] | |||
|
1108 | msg_ids = content['msg_ids'] | |||
|
1109 | reply = dict(status='ok') | |||
|
1110 | try: | |||
|
1111 | records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[ | |||
|
1112 | 'header', 'content', 'buffers']) | |||
|
1113 | except Exception: | |||
|
1114 | self.log.error('db::db error finding tasks to resubmit', exc_info=True) | |||
|
1115 | return finish(error.wrap_exception()) | |||
|
1116 | ||||
|
1117 | # validate msg_ids | |||
|
1118 | found_ids = [ rec['msg_id'] for rec in records ] | |||
|
1119 | invalid_ids = filter(lambda m: m in self.pending, found_ids) | |||
|
1120 | if len(records) > len(msg_ids): | |||
|
1121 | try: | |||
|
1122 | raise RuntimeError("DB appears to be in an inconsistent state." | |||
|
1123 | "More matching records were found than should exist") | |||
|
1124 | except Exception: | |||
|
1125 | return finish(error.wrap_exception()) | |||
|
1126 | elif len(records) < len(msg_ids): | |||
|
1127 | missing = [ m for m in msg_ids if m not in found_ids ] | |||
|
1128 | try: | |||
|
1129 | raise KeyError("No such msg(s): %s"%missing) | |||
|
1130 | except KeyError: | |||
|
1131 | return finish(error.wrap_exception()) | |||
|
1132 | elif invalid_ids: | |||
|
1133 | msg_id = invalid_ids[0] | |||
|
1134 | try: | |||
|
1135 | raise ValueError("Task %r appears to be inflight"%(msg_id)) | |||
|
1136 | except Exception: | |||
|
1137 | return finish(error.wrap_exception()) | |||
|
1138 | ||||
|
1139 | # clear the existing records | |||
|
1140 | rec = empty_record() | |||
|
1141 | map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted']) | |||
|
1142 | rec['resubmitted'] = datetime.now() | |||
|
1143 | rec['queue'] = 'task' | |||
|
1144 | rec['client_uuid'] = client_id[0] | |||
|
1145 | try: | |||
|
1146 | for msg_id in msg_ids: | |||
|
1147 | self.all_completed.discard(msg_id) | |||
|
1148 | self.db.update_record(msg_id, rec) | |||
|
1149 | except Exception: | |||
|
1150 | self.log.error('db::db error upating record', exc_info=True) | |||
|
1151 | reply = error.wrap_exception() | |||
|
1152 | else: | |||
|
1153 | # send the messages | |||
|
1154 | for rec in records: | |||
|
1155 | header = rec['header'] | |||
|
1156 | msg = self.session.msg(header['msg_type']) | |||
|
1157 | msg['content'] = rec['content'] | |||
|
1158 | msg['header'] = header | |||
|
1159 | msg['msg_id'] = rec['msg_id'] | |||
|
1160 | self.session.send(self.resubmit, msg, buffers=rec['buffers']) | |||
|
1161 | ||||
|
1162 | finish(dict(status='ok')) | |||
|
1163 | ||||
1081 |
|
1164 | |||
1082 | def _extract_record(self, rec): |
|
1165 | def _extract_record(self, rec): | |
1083 | """decompose a TaskRecord dict into subsection of reply for get_result""" |
|
1166 | """decompose a TaskRecord dict into subsection of reply for get_result""" | |
@@ -1124,12 +1207,20 b' class Hub(LoggingFactory):' | |||||
1124 | for msg_id in msg_ids: |
|
1207 | for msg_id in msg_ids: | |
1125 | if msg_id in self.pending: |
|
1208 | if msg_id in self.pending: | |
1126 | pending.append(msg_id) |
|
1209 | pending.append(msg_id) | |
1127 |
elif msg_id in self.all_completed |
|
1210 | elif msg_id in self.all_completed: | |
1128 | completed.append(msg_id) |
|
1211 | completed.append(msg_id) | |
1129 | if not statusonly: |
|
1212 | if not statusonly: | |
1130 | c,bufs = self._extract_record(records[msg_id]) |
|
1213 | c,bufs = self._extract_record(records[msg_id]) | |
1131 | content[msg_id] = c |
|
1214 | content[msg_id] = c | |
1132 | buffers.extend(bufs) |
|
1215 | buffers.extend(bufs) | |
|
1216 | elif msg_id in records: | |||
|
1217 | if rec['completed']: | |||
|
1218 | completed.append(msg_id) | |||
|
1219 | c,bufs = self._extract_record(records[msg_id]) | |||
|
1220 | content[msg_id] = c | |||
|
1221 | buffers.extend(bufs) | |||
|
1222 | else: | |||
|
1223 | pending.append(msg_id) | |||
1133 | else: |
|
1224 | else: | |
1134 | try: |
|
1225 | try: | |
1135 | raise KeyError('No such message: '+msg_id) |
|
1226 | raise KeyError('No such message: '+msg_id) |
@@ -186,6 +186,9 b' class StreamSession(object):' | |||||
186 | elif isinstance(content, bytes): |
|
186 | elif isinstance(content, bytes): | |
187 | # content is already packed, as in a relayed message |
|
187 | # content is already packed, as in a relayed message | |
188 | pass |
|
188 | pass | |
|
189 | elif isinstance(content, unicode): | |||
|
190 | # should be bytes, but JSON often spits out unicode | |||
|
191 | content = content.encode('utf8') | |||
189 | else: |
|
192 | else: | |
190 | raise TypeError("Content incorrect type: %s"%type(content)) |
|
193 | raise TypeError("Content incorrect type: %s"%type(content)) | |
191 |
|
194 |
@@ -212,3 +212,26 b' class TestClient(ClusterTestCase):' | |||||
212 | time.sleep(0.25) |
|
212 | time.sleep(0.25) | |
213 | self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids) |
|
213 | self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids) | |
214 |
|
214 | |||
|
215 | def test_resubmit(self): | |||
|
216 | def f(): | |||
|
217 | import random | |||
|
218 | return random.random() | |||
|
219 | v = self.client.load_balanced_view() | |||
|
220 | ar = v.apply_async(f) | |||
|
221 | r1 = ar.get(1) | |||
|
222 | ahr = self.client.resubmit(ar.msg_ids) | |||
|
223 | r2 = ahr.get(1) | |||
|
224 | self.assertFalse(r1 == r2) | |||
|
225 | ||||
|
226 | def test_resubmit_inflight(self): | |||
|
227 | """ensure ValueError on resubmit of inflight task""" | |||
|
228 | v = self.client.load_balanced_view() | |||
|
229 | ar = v.apply_async(time.sleep,1) | |||
|
230 | # give the message a chance to arrive | |||
|
231 | time.sleep(0.2) | |||
|
232 | self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids) | |||
|
233 | ar.get(2) | |||
|
234 | ||||
|
235 | def test_resubmit_badkey(self): | |||
|
236 | """ensure KeyError on resubmit of nonexistent task""" | |||
|
237 | self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid']) |
@@ -36,7 +36,7 b' class TestLoadBalancedView(ClusterTestCase):' | |||||
36 | """test graceful handling of engine death (balanced)""" |
|
36 | """test graceful handling of engine death (balanced)""" | |
37 | # self.add_engines(1) |
|
37 | # self.add_engines(1) | |
38 | ar = self.view.apply_async(crash) |
|
38 | ar = self.view.apply_async(crash) | |
39 | self.assertRaisesRemote(error.EngineError, ar.get) |
|
39 | self.assertRaisesRemote(error.EngineError, ar.get, 10) | |
40 | eid = ar.engine_id |
|
40 | eid = ar.engine_id | |
41 | tic = time.time() |
|
41 | tic = time.time() | |
42 | while eid in self.client.ids and time.time()-tic < 5: |
|
42 | while eid in self.client.ids and time.time()-tic < 5: |
General Comments 0
You need to be logged in to leave comments.
Login now