##// END OF EJS Templates
add Client.resubmit for re-running tasks...
MinRK -
Show More
@@ -1041,6 +1041,68 b' class Client(HasTraits):'
1041 1041 ar.wait()
1042 1042
1043 1043 return ar
1044
1045 @spin_first
1046 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1047 """Resubmit one or more tasks.
1048
1049 in-flight tasks may not be resubmitted.
1050
1051 Parameters
1052 ----------
1053
1054 indices_or_msg_ids : integer history index, str msg_id, or list of either
1055 The indices or msg_ids of indices to be retrieved
1056
1057 block : bool
1058 Whether to wait for the result to be done
1059
1060 Returns
1061 -------
1062
1063 AsyncHubResult
1064 A subclass of AsyncResult that retrieves results from the Hub
1065
1066 """
1067 block = self.block if block is None else block
1068 if indices_or_msg_ids is None:
1069 indices_or_msg_ids = -1
1070
1071 if not isinstance(indices_or_msg_ids, (list,tuple)):
1072 indices_or_msg_ids = [indices_or_msg_ids]
1073
1074 theids = []
1075 for id in indices_or_msg_ids:
1076 if isinstance(id, int):
1077 id = self.history[id]
1078 if not isinstance(id, str):
1079 raise TypeError("indices must be str or int, not %r"%id)
1080 theids.append(id)
1081
1082 for msg_id in theids:
1083 self.outstanding.discard(msg_id)
1084 if msg_id in self.history:
1085 self.history.remove(msg_id)
1086 self.results.pop(msg_id, None)
1087 self.metadata.pop(msg_id, None)
1088 content = dict(msg_ids = theids)
1089
1090 self.session.send(self._query_socket, 'resubmit_request', content)
1091
1092 zmq.select([self._query_socket], [], [])
1093 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1094 if self.debug:
1095 pprint(msg)
1096 content = msg['content']
1097 if content['status'] != 'ok':
1098 raise self._unwrap_exception(content)
1099
1100 ar = AsyncHubResult(self, msg_ids=theids)
1101
1102 if block:
1103 ar.wait()
1104
1105 return ar
1044 1106
1045 1107 @spin_first
1046 1108 def result_status(self, msg_ids, status_only=True):
@@ -268,8 +268,15 b' class HubFactory(RegistrationFactory):'
268 268 }
269 269 self.log.debug("Hub engine addrs: %s"%self.engine_info)
270 270 self.log.debug("Hub client addrs: %s"%self.client_info)
271
272 # resubmit stream
273 r = ZMQStream(ctx.socket(zmq.XREQ), loop)
274 url = util.disambiguate_url(self.client_info['task'][-1])
275 r.setsockopt(zmq.IDENTITY, self.session.session)
276 r.connect(url)
277
271 278 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
272 query=q, notifier=n, db=self.db,
279 query=q, notifier=n, resubmit=r, db=self.db,
273 280 engine_info=self.engine_info, client_info=self.client_info,
274 281 logname=self.log.name)
275 282
@@ -315,8 +322,9 b' class Hub(LoggingFactory):'
315 322 loop=Instance(ioloop.IOLoop)
316 323 query=Instance(ZMQStream)
317 324 monitor=Instance(ZMQStream)
318 heartmonitor=Instance(HeartMonitor)
319 325 notifier=Instance(ZMQStream)
326 resubmit=Instance(ZMQStream)
327 heartmonitor=Instance(HeartMonitor)
320 328 db=Instance(object)
321 329 client_info=Dict()
322 330 engine_info=Dict()
@@ -379,6 +387,9 b' class Hub(LoggingFactory):'
379 387 'connection_request': self.connection_request,
380 388 }
381 389
390 # ignore resubmit replies
391 self.resubmit.on_recv(lambda msg: None, copy=False)
392
382 393 self.log.info("hub::created hub")
383 394
384 395 @property
@@ -452,31 +463,31 b' class Hub(LoggingFactory):'
452 463 def dispatch_monitor_traffic(self, msg):
453 464 """all ME and Task queue messages come through here, as well as
454 465 IOPub traffic."""
455 self.log.debug("monitor traffic: %s"%msg[:2])
466 self.log.debug("monitor traffic: %r"%msg[:2])
456 467 switch = msg[0]
457 468 idents, msg = self.session.feed_identities(msg[1:])
458 469 if not idents:
459 self.log.error("Bad Monitor Message: %s"%msg)
470 self.log.error("Bad Monitor Message: %r"%msg)
460 471 return
461 472 handler = self.monitor_handlers.get(switch, None)
462 473 if handler is not None:
463 474 handler(idents, msg)
464 475 else:
465 self.log.error("Invalid monitor topic: %s"%switch)
476 self.log.error("Invalid monitor topic: %r"%switch)
466 477
467 478
468 479 def dispatch_query(self, msg):
469 480 """Route registration requests and queries from clients."""
470 481 idents, msg = self.session.feed_identities(msg)
471 482 if not idents:
472 self.log.error("Bad Query Message: %s"%msg)
483 self.log.error("Bad Query Message: %r"%msg)
473 484 return
474 485 client_id = idents[0]
475 486 try:
476 487 msg = self.session.unpack_message(msg, content=True)
477 488 except:
478 489 content = error.wrap_exception()
479 self.log.error("Bad Query Message: %s"%msg, exc_info=True)
490 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
480 491 self.session.send(self.query, "hub_error", ident=client_id,
481 492 content=content)
482 493 return
@@ -484,16 +495,17 b' class Hub(LoggingFactory):'
484 495 # print client_id, header, parent, content
485 496 #switch on message type:
486 497 msg_type = msg['msg_type']
487 self.log.info("client::client %s requested %s"%(client_id, msg_type))
498 self.log.info("client::client %r requested %r"%(client_id, msg_type))
488 499 handler = self.query_handlers.get(msg_type, None)
489 500 try:
490 assert handler is not None, "Bad Message Type: %s"%msg_type
501 assert handler is not None, "Bad Message Type: %r"%msg_type
491 502 except:
492 503 content = error.wrap_exception()
493 self.log.error("Bad Message Type: %s"%msg_type, exc_info=True)
504 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
494 505 self.session.send(self.query, "hub_error", ident=client_id,
495 506 content=content)
496 507 return
508
497 509 else:
498 510 handler(idents, msg)
499 511
@@ -560,9 +572,9 b' class Hub(LoggingFactory):'
560 572 # it's posible iopub arrived first:
561 573 existing = self.db.get_record(msg_id)
562 574 for key,evalue in existing.iteritems():
563 rvalue = record[key]
575 rvalue = record.get(key, None)
564 576 if evalue and rvalue and evalue != rvalue:
565 self.log.error("conflicting initial state for record: %s:%s <> %s"%(msg_id, rvalue, evalue))
577 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
566 578 elif evalue and not rvalue:
567 579 record[key] = evalue
568 580 self.db.update_record(msg_id, record)
@@ -648,10 +660,22 b' class Hub(LoggingFactory):'
648 660 try:
649 661 # it's posible iopub arrived first:
650 662 existing = self.db.get_record(msg_id)
663 if existing['resubmitted']:
664 for key in ('submitted', 'client_uuid', 'buffers'):
665 # don't clobber these keys on resubmit
666 # submitted and client_uuid should be different
667 # and buffers might be big, and shouldn't have changed
668 record.pop(key)
669 # still check content,header which should not change
670 # but are not expensive to compare as buffers
671
651 672 for key,evalue in existing.iteritems():
652 rvalue = record[key]
673 if key.endswith('buffers'):
674 # don't compare buffers
675 continue
676 rvalue = record.get(key, None)
653 677 if evalue and rvalue and evalue != rvalue:
654 self.log.error("conflicting initial state for record: %s:%s <> %s"%(msg_id, rvalue, evalue))
678 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
655 679 elif evalue and not rvalue:
656 680 record[key] = evalue
657 681 self.db.update_record(msg_id, record)
@@ -1075,9 +1099,68 b' class Hub(LoggingFactory):'
1075 1099
1076 1100 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1077 1101
1078 def resubmit_task(self, client_id, msg, buffers):
1079 """Resubmit a task."""
1080 raise NotImplementedError
1102 def resubmit_task(self, client_id, msg):
1103 """Resubmit one or more tasks."""
1104 def finish(reply):
1105 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1106
1107 content = msg['content']
1108 msg_ids = content['msg_ids']
1109 reply = dict(status='ok')
1110 try:
1111 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1112 'header', 'content', 'buffers'])
1113 except Exception:
1114 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1115 return finish(error.wrap_exception())
1116
1117 # validate msg_ids
1118 found_ids = [ rec['msg_id'] for rec in records ]
1119 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1120 if len(records) > len(msg_ids):
1121 try:
1122 raise RuntimeError("DB appears to be in an inconsistent state."
1123 "More matching records were found than should exist")
1124 except Exception:
1125 return finish(error.wrap_exception())
1126 elif len(records) < len(msg_ids):
1127 missing = [ m for m in msg_ids if m not in found_ids ]
1128 try:
1129 raise KeyError("No such msg(s): %s"%missing)
1130 except KeyError:
1131 return finish(error.wrap_exception())
1132 elif invalid_ids:
1133 msg_id = invalid_ids[0]
1134 try:
1135 raise ValueError("Task %r appears to be inflight"%(msg_id))
1136 except Exception:
1137 return finish(error.wrap_exception())
1138
1139 # clear the existing records
1140 rec = empty_record()
1141 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1142 rec['resubmitted'] = datetime.now()
1143 rec['queue'] = 'task'
1144 rec['client_uuid'] = client_id[0]
1145 try:
1146 for msg_id in msg_ids:
1147 self.all_completed.discard(msg_id)
1148 self.db.update_record(msg_id, rec)
1149 except Exception:
1150 self.log.error('db::db error upating record', exc_info=True)
1151 reply = error.wrap_exception()
1152 else:
1153 # send the messages
1154 for rec in records:
1155 header = rec['header']
1156 msg = self.session.msg(header['msg_type'])
1157 msg['content'] = rec['content']
1158 msg['header'] = header
1159 msg['msg_id'] = rec['msg_id']
1160 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1161
1162 finish(dict(status='ok'))
1163
1081 1164
1082 1165 def _extract_record(self, rec):
1083 1166 """decompose a TaskRecord dict into subsection of reply for get_result"""
@@ -1124,12 +1207,20 b' class Hub(LoggingFactory):'
1124 1207 for msg_id in msg_ids:
1125 1208 if msg_id in self.pending:
1126 1209 pending.append(msg_id)
1127 elif msg_id in self.all_completed or msg_id in records:
1210 elif msg_id in self.all_completed:
1128 1211 completed.append(msg_id)
1129 1212 if not statusonly:
1130 1213 c,bufs = self._extract_record(records[msg_id])
1131 1214 content[msg_id] = c
1132 1215 buffers.extend(bufs)
1216 elif msg_id in records:
1217 if rec['completed']:
1218 completed.append(msg_id)
1219 c,bufs = self._extract_record(records[msg_id])
1220 content[msg_id] = c
1221 buffers.extend(bufs)
1222 else:
1223 pending.append(msg_id)
1133 1224 else:
1134 1225 try:
1135 1226 raise KeyError('No such message: '+msg_id)
@@ -186,6 +186,9 b' class StreamSession(object):'
186 186 elif isinstance(content, bytes):
187 187 # content is already packed, as in a relayed message
188 188 pass
189 elif isinstance(content, unicode):
190 # should be bytes, but JSON often spits out unicode
191 content = content.encode('utf8')
189 192 else:
190 193 raise TypeError("Content incorrect type: %s"%type(content))
191 194
@@ -212,3 +212,26 b' class TestClient(ClusterTestCase):'
212 212 time.sleep(0.25)
213 213 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
214 214
215 def test_resubmit(self):
216 def f():
217 import random
218 return random.random()
219 v = self.client.load_balanced_view()
220 ar = v.apply_async(f)
221 r1 = ar.get(1)
222 ahr = self.client.resubmit(ar.msg_ids)
223 r2 = ahr.get(1)
224 self.assertFalse(r1 == r2)
225
226 def test_resubmit_inflight(self):
227 """ensure ValueError on resubmit of inflight task"""
228 v = self.client.load_balanced_view()
229 ar = v.apply_async(time.sleep,1)
230 # give the message a chance to arrive
231 time.sleep(0.2)
232 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
233 ar.get(2)
234
235 def test_resubmit_badkey(self):
236 """ensure KeyError on resubmit of nonexistant task"""
237 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
@@ -36,7 +36,7 b' class TestLoadBalancedView(ClusterTestCase):'
36 36 """test graceful handling of engine death (balanced)"""
37 37 # self.add_engines(1)
38 38 ar = self.view.apply_async(crash)
39 self.assertRaisesRemote(error.EngineError, ar.get)
39 self.assertRaisesRemote(error.EngineError, ar.get, 10)
40 40 eid = ar.engine_id
41 41 tic = time.time()
42 42 while eid in self.client.ids and time.time()-tic < 5:
General Comments 0
You need to be logged in to leave comments. Login now