##// END OF EJS Templates
improved client.get_results() behavior
MinRK -
Show More
@@ -0,0 +1,35 b''
1 """some generic utilities"""
2
3 class ReverseDict(dict):
4 """simple double-keyed subset of dict methods."""
5
6 def __init__(self, *args, **kwargs):
7 dict.__init__(self, *args, **kwargs)
8 self._reverse = dict()
9 for key, value in self.iteritems():
10 self._reverse[value] = key
11
12 def __getitem__(self, key):
13 try:
14 return dict.__getitem__(self, key)
15 except KeyError:
16 return self._reverse[key]
17
18 def __setitem__(self, key, value):
19 if key in self._reverse:
20 raise KeyError("Can't have key %r on both sides!"%key)
21 dict.__setitem__(self, key, value)
22 self._reverse[value] = key
23
24 def pop(self, key):
25 value = dict.pop(self, key)
26 self.d1.pop(value)
27 return value
28
29 def get(self, key, default=None):
30 try:
31 return self[key]
32 except KeyError:
33 return default
34
35
@@ -40,7 +40,7 b' class AsyncResult(object):'
40 40 Override me in subclasses for turning a list of results
41 41 into the expected form.
42 42 """
43 if len(res) == 1:
43 if len(self.msg_ids) == 1:
44 44 return res[0]
45 45 else:
46 46 return res
@@ -14,6 +14,7 b' import os'
14 14 import time
15 15 from getpass import getpass
16 16 from pprint import pprint
17 from datetime import datetime
17 18
18 19 import zmq
19 20 from zmq.eventloop import ioloop, zmqstream
@@ -29,6 +30,7 b' import error'
29 30 import map as Map
30 31 from asyncresult import AsyncResult, AsyncMapResult
31 32 from remotefunction import remote,parallel,ParallelFunction,RemoteFunction
33 from util import ReverseDict
32 34
33 35 #--------------------------------------------------------------------------
34 36 # helpers for implementing old MEC API via client.apply
@@ -83,6 +85,11 b' def defaultblock(f, self, *args, **kwargs):'
83 85 self.block = saveblock
84 86 return ret
85 87
88
89 #--------------------------------------------------------------------------
90 # Classes
91 #--------------------------------------------------------------------------
92
86 93 class AbortedTask(object):
87 94 """A basic wrapper object describing an aborted task."""
88 95 def __init__(self, msg_id):
@@ -233,10 +240,11 b' class Client(object):'
233 240 tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
234 241 else:
235 242 self._registration_socket.connect(addr)
236 self._engines = {}
243 self._engines = ReverseDict()
237 244 self._ids = set()
238 245 self.outstanding=set()
239 246 self.results = {}
247 self.metadata = {}
240 248 self.history = []
241 249 self.debug = debug
242 250 self.session.debug = debug
@@ -342,9 +350,27 b' class Client(object):'
342 350 if eid in self._ids:
343 351 self._ids.remove(eid)
344 352 self._engines.pop(eid)
345
353 #
354 def _build_metadata(self, header, parent, content):
355 md = {'msg_id' : parent['msg_id'],
356 'submitted' : datetime.strptime(parent['date'], ss.ISO8601),
357 'started' : datetime.strptime(header['started'], ss.ISO8601),
358 'completed' : datetime.strptime(header['date'], ss.ISO8601),
359 'received' : datetime.now(),
360 'engine_uuid' : header['engine'],
361 'engine_id' : self._engines.get(header['engine'], None),
362 'follow' : parent['follow'],
363 'after' : parent['after'],
364 'status' : content['status']
365 }
366 return md
367
346 368 def _handle_execute_reply(self, msg):
347 """Save the reply to an execute_request into our results."""
369 """Save the reply to an execute_request into our results.
370
371 execute messages are never actually used. apply is used instead.
372 """
373
348 374 parent = msg['parent_header']
349 375 msg_id = parent['msg_id']
350 376 if msg_id not in self.outstanding:
@@ -362,8 +388,12 b' class Client(object):'
362 388 else:
363 389 self.outstanding.remove(msg_id)
364 390 content = msg['content']
391 header = msg['header']
392
393 self.metadata[msg_id] = self._build_metadata(header, parent, content)
394
365 395 if content['status'] == 'ok':
366 self.results[msg_id] = ss.unserialize_object(msg['buffers'])
396 self.results[msg_id] = ss.unserialize_object(msg['buffers'])[0]
367 397 elif content['status'] == 'aborted':
368 398 self.results[msg_id] = error.AbortedTask(msg_id)
369 399 elif content['status'] == 'resubmitted':
@@ -372,10 +402,8 b' class Client(object):'
372 402 else:
373 403 e = ss.unwrap_exception(content)
374 404 e_uuid = e.engine_info['engineid']
375 for k,v in self._engines.iteritems():
376 if v == e_uuid:
377 e.engine_info['engineid'] = k
378 break
405 eid = self._engines[e_uuid]
406 e.engine_info['engineid'] = eid
379 407 self.results[msg_id] = e
380 408
381 409 def _flush_notifications(self):
@@ -882,6 +910,13 b' class Client(object):'
882 910 status_only : bool (default: False)
883 911 if False:
884 912 return the actual results
913
914 Returns
915 -------
916
917 results : dict
918 There will always be the keys 'pending' and 'completed', which will
919 be lists of msg_ids.
885 920 """
886 921 if not isinstance(msg_ids, (list,tuple)):
887 922 msg_ids = [msg_ids]
@@ -895,11 +930,12 b' class Client(object):'
895 930
896 931 completed = []
897 932 local_results = {}
898 for msg_id in list(theids):
899 if msg_id in self.results:
900 completed.append(msg_id)
901 local_results[msg_id] = self.results[msg_id]
902 theids.remove(msg_id)
933 # temporarily disable local shortcut
934 # for msg_id in list(theids):
935 # if msg_id in self.results:
936 # completed.append(msg_id)
937 # local_results[msg_id] = self.results[msg_id]
938 # theids.remove(msg_id)
903 939
904 940 if theids: # some not locally cached
905 941 content = dict(msg_ids=theids, status_only=status_only)
@@ -911,16 +947,40 b' class Client(object):'
911 947 content = msg['content']
912 948 if content['status'] != 'ok':
913 949 raise ss.unwrap_exception(content)
950 buffers = msg['buffers']
914 951 else:
915 952 content = dict(completed=[],pending=[])
916 if not status_only:
917 # load cached results into result:
918 content['completed'].extend(completed)
919 content.update(local_results)
920 # update cache with results:
921 for msg_id in msg_ids:
922 if msg_id in content['completed']:
923 self.results[msg_id] = content[msg_id]
953
954 content['completed'].extend(completed)
955
956 if status_only:
957 return content
958
959 failures = []
960 # load cached results into result:
961 content.update(local_results)
962 # update cache with results:
963 for msg_id in sorted(theids):
964 if msg_id in content['completed']:
965 rec = content[msg_id]
966 parent = rec['header']
967 header = rec['result_header']
968 rcontent = rec['result_content']
969 if isinstance(rcontent, str):
970 rcontent = self.session.unpack(rcontent)
971
972 self.metadata[msg_id] = self._build_metadata(header, parent, rcontent)
973
974 if rcontent['status'] == 'ok':
975 res,buffers = ss.unserialize_object(buffers)
976 else:
977 res = ss.unwrap_exception(rcontent)
978 failures.append(res)
979
980 self.results[msg_id] = res
981 content[msg_id] = res
982
983 error.collect_exceptions(failures, "get_results")
924 984 return content
925 985
926 986 @spinfirst
@@ -945,7 +1005,7 b' class Client(object):'
945 1005 status = content.pop('status')
946 1006 if status != 'ok':
947 1007 raise ss.unwrap_exception(content)
948 return content
1008 return ss.rekey(content)
949 1009
950 1010 @spinfirst
951 1011 def purge_results(self, msg_ids=[], targets=[]):
@@ -47,33 +47,6 b' else:'
47 47 def _passer(*args, **kwargs):
48 48 return
49 49
50 class ReverseDict(dict):
51 """simple double-keyed subset of dict methods."""
52
53 def __init__(self, *args, **kwargs):
54 dict.__init__(self, *args, **kwargs)
55 self.reverse = dict()
56 for key, value in self.iteritems():
57 self.reverse[value] = key
58
59 def __getitem__(self, key):
60 try:
61 return dict.__getitem__(self, key)
62 except KeyError:
63 return self.reverse[key]
64
65 def __setitem__(self, key, value):
66 if key in self.reverse:
67 raise KeyError("Can't have key %r on both sides!"%key)
68 dict.__setitem__(self, key, value)
69 self.reverse[value] = key
70
71 def pop(self, key):
72 value = dict.pop(self, key)
73 self.d1.pop(value)
74 return value
75
76
77 50 def init_record(msg):
78 51 """return an empty TaskRecord dict, with all keys initialized with None."""
79 52 header = msg['header']
@@ -484,6 +457,8 b' class Controller(object):'
484 457 }
485 458 if MongoDB is not None and isinstance(self.db, MongoDB):
486 459 result['result_buffers'] = map(Binary, msg['buffers'])
460 else:
461 result['result_buffers'] = msg['buffers']
487 462 self.db.update_record(msg_id, result)
488 463 else:
489 464 logger.debug("queue:: unknown msg finished %s"%msg_id)
@@ -552,6 +527,8 b' class Controller(object):'
552 527 }
553 528 if MongoDB is not None and isinstance(self.db, MongoDB):
554 529 result['result_buffers'] = map(Binary, msg['buffers'])
530 else:
531 result['result_buffers'] = msg['buffers']
555 532 self.db.update_record(msg_id, result)
556 533
557 534 else:
@@ -831,14 +808,16 b' class Controller(object):'
831 808 def get_results(self, client_id, msg):
832 809 """Get the result of 1 or more messages."""
833 810 content = msg['content']
834 msg_ids = set(content['msg_ids'])
811 msg_ids = sorted(set(content['msg_ids']))
835 812 statusonly = content.get('status_only', False)
836 813 pending = []
837 814 completed = []
838 815 content = dict(status='ok')
839 816 content['pending'] = pending
840 817 content['completed'] = completed
818 buffers = []
841 819 if not statusonly:
820 content['results'] = {}
842 821 records = self.db.find_records(dict(msg_id={'$in':msg_ids}))
843 822 for msg_id in msg_ids:
844 823 if msg_id in self.pending:
@@ -846,7 +825,12 b' class Controller(object):'
846 825 elif msg_id in self.all_completed:
847 826 completed.append(msg_id)
848 827 if not statusonly:
849 content[msg_id] = records[msg_id]['result_content']
828 rec = records[msg_id]
829 content[msg_id] = { 'result_content': rec['result_content'],
830 'header': rec['header'],
831 'result_header' : rec['result_header'],
832 }
833 buffers.extend(map(str, rec['result_buffers']))
850 834 else:
851 835 try:
852 836 raise KeyError('No such message: '+msg_id)
@@ -854,7 +838,8 b' class Controller(object):'
854 838 content = wrap_exception()
855 839 break
856 840 self.session.send(self.clientele, "result_reply", content=content,
857 parent=msg, ident=client_id)
841 parent=msg, ident=client_id,
842 buffers=buffers)
858 843
859 844
860 845 #-------------------------------------------------------------------------
@@ -35,7 +35,7 b' class MongoDB(object):'
35 35 def update_record(self, msg_id, rec):
36 36 """Update the data in an existing record."""
37 37 obj_id = self._table[msg_id]
38 self._records.update({'_id':obj_id}, rec)
38 self._records.update({'_id':obj_id}, {'$set': rec})
39 39
40 40 def drop_matching_records(self, check):
41 41 """Remove a record from the DB."""
@@ -50,7 +50,11 b' class MongoDB(object):'
50 50 """Find records matching a query dict."""
51 51 matches = list(self._records.find(check))
52 52 if id_only:
53 matches = [ rec['msg_id'] for rec in matches ]
54 return matches
53 return [ rec['msg_id'] for rec in matches ]
54 else:
55 data = {}
56 for rec in matches:
57 data[rec['msg_id']] = rec
58 return data
55 59
56 60
@@ -126,10 +126,10 b' class ParallelFunction(RemoteFunction):'
126 126 f=self.func
127 127 mid = self.client.apply(f, args=args, block=False,
128 128 bound=self.bound,
129 targets=engineid)._msg_ids[0]
129 targets=engineid).msg_ids[0]
130 130 msg_ids.append(mid)
131 131
132 r = AsyncMapResult(self.client, msg_ids, self.mapObject)
132 r = AsyncMapResult(self.client, msg_ids, self.mapObject, fname=self.func.__name__)
133 133 if self.block:
134 134 r.wait()
135 135 return r.result
@@ -208,7 +208,7 b' def unserialize_object(bufs):'
208 208 for s in sobj:
209 209 if s.data is None:
210 210 s.data = bufs.pop(0)
211 return uncanSequence(map(unserialize, sobj))
211 return uncanSequence(map(unserialize, sobj)), bufs
212 212 elif isinstance(sobj, dict):
213 213 newobj = {}
214 214 for k in sorted(sobj.iterkeys()):
@@ -216,11 +216,11 b' def unserialize_object(bufs):'
216 216 if s.data is None:
217 217 s.data = bufs.pop(0)
218 218 newobj[k] = uncan(unserialize(s))
219 return newobj
219 return newobj, bufs
220 220 else:
221 221 if sobj.data is None:
222 222 sobj.data = bufs.pop(0)
223 return uncan(unserialize(sobj))
223 return uncan(unserialize(sobj)), bufs
224 224
225 225 def pack_apply_message(f, args, kwargs, threshold=64e-6):
226 226 """pack up a function, args, and kwargs to be sent over the wire
@@ -183,6 +183,8 b' class View(object):'
183 183 """Parallel version of builtin `map`, using this view's engines."""
184 184 if isinstance(self.targets, int):
185 185 targets = [self.targets]
186 else:
187 targets = self.targets
186 188 pf = ParallelFunction(self.client, f, block=self.block,
187 189 bound=True, targets=targets)
188 190 return pf.map(*sequences)
General Comments 0
You need to be logged in to leave comments. Login now