##// END OF EJS Templates
improved client.get_results() behavior
MinRK -
Show More
@@ -0,0 +1,35 b''
1 """some generic utilities"""
2
3 class ReverseDict(dict):
4 """simple double-keyed subset of dict methods."""
5
6 def __init__(self, *args, **kwargs):
7 dict.__init__(self, *args, **kwargs)
8 self._reverse = dict()
9 for key, value in self.iteritems():
10 self._reverse[value] = key
11
12 def __getitem__(self, key):
13 try:
14 return dict.__getitem__(self, key)
15 except KeyError:
16 return self._reverse[key]
17
18 def __setitem__(self, key, value):
19 if key in self._reverse:
20 raise KeyError("Can't have key %r on both sides!"%key)
21 dict.__setitem__(self, key, value)
22 self._reverse[value] = key
23
24 def pop(self, key):
25 value = dict.pop(self, key)
26 self.d1.pop(value)
27 return value
28
29 def get(self, key, default=None):
30 try:
31 return self[key]
32 except KeyError:
33 return default
34
35
@@ -40,7 +40,7 b' class AsyncResult(object):'
40 Override me in subclasses for turning a list of results
40 Override me in subclasses for turning a list of results
41 into the expected form.
41 into the expected form.
42 """
42 """
43 if len(res) == 1:
43 if len(self.msg_ids) == 1:
44 return res[0]
44 return res[0]
45 else:
45 else:
46 return res
46 return res
@@ -14,6 +14,7 b' import os'
14 import time
14 import time
15 from getpass import getpass
15 from getpass import getpass
16 from pprint import pprint
16 from pprint import pprint
17 from datetime import datetime
17
18
18 import zmq
19 import zmq
19 from zmq.eventloop import ioloop, zmqstream
20 from zmq.eventloop import ioloop, zmqstream
@@ -29,6 +30,7 b' import error'
29 import map as Map
30 import map as Map
30 from asyncresult import AsyncResult, AsyncMapResult
31 from asyncresult import AsyncResult, AsyncMapResult
31 from remotefunction import remote,parallel,ParallelFunction,RemoteFunction
32 from remotefunction import remote,parallel,ParallelFunction,RemoteFunction
33 from util import ReverseDict
32
34
33 #--------------------------------------------------------------------------
35 #--------------------------------------------------------------------------
34 # helpers for implementing old MEC API via client.apply
36 # helpers for implementing old MEC API via client.apply
@@ -83,6 +85,11 b' def defaultblock(f, self, *args, **kwargs):'
83 self.block = saveblock
85 self.block = saveblock
84 return ret
86 return ret
85
87
88
89 #--------------------------------------------------------------------------
90 # Classes
91 #--------------------------------------------------------------------------
92
86 class AbortedTask(object):
93 class AbortedTask(object):
87 """A basic wrapper object describing an aborted task."""
94 """A basic wrapper object describing an aborted task."""
88 def __init__(self, msg_id):
95 def __init__(self, msg_id):
@@ -233,10 +240,11 b' class Client(object):'
233 tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
240 tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
234 else:
241 else:
235 self._registration_socket.connect(addr)
242 self._registration_socket.connect(addr)
236 self._engines = {}
243 self._engines = ReverseDict()
237 self._ids = set()
244 self._ids = set()
238 self.outstanding=set()
245 self.outstanding=set()
239 self.results = {}
246 self.results = {}
247 self.metadata = {}
240 self.history = []
248 self.history = []
241 self.debug = debug
249 self.debug = debug
242 self.session.debug = debug
250 self.session.debug = debug
@@ -342,9 +350,27 b' class Client(object):'
342 if eid in self._ids:
350 if eid in self._ids:
343 self._ids.remove(eid)
351 self._ids.remove(eid)
344 self._engines.pop(eid)
352 self._engines.pop(eid)
345
353 #
354 def _build_metadata(self, header, parent, content):
355 md = {'msg_id' : parent['msg_id'],
356 'submitted' : datetime.strptime(parent['date'], ss.ISO8601),
357 'started' : datetime.strptime(header['started'], ss.ISO8601),
358 'completed' : datetime.strptime(header['date'], ss.ISO8601),
359 'received' : datetime.now(),
360 'engine_uuid' : header['engine'],
361 'engine_id' : self._engines.get(header['engine'], None),
362 'follow' : parent['follow'],
363 'after' : parent['after'],
364 'status' : content['status']
365 }
366 return md
367
346 def _handle_execute_reply(self, msg):
368 def _handle_execute_reply(self, msg):
347 """Save the reply to an execute_request into our results."""
369 """Save the reply to an execute_request into our results.
370
371 execute messages are never actually used. apply is used instead.
372 """
373
348 parent = msg['parent_header']
374 parent = msg['parent_header']
349 msg_id = parent['msg_id']
375 msg_id = parent['msg_id']
350 if msg_id not in self.outstanding:
376 if msg_id not in self.outstanding:
@@ -362,8 +388,12 b' class Client(object):'
362 else:
388 else:
363 self.outstanding.remove(msg_id)
389 self.outstanding.remove(msg_id)
364 content = msg['content']
390 content = msg['content']
391 header = msg['header']
392
393 self.metadata[msg_id] = self._build_metadata(header, parent, content)
394
365 if content['status'] == 'ok':
395 if content['status'] == 'ok':
366 self.results[msg_id] = ss.unserialize_object(msg['buffers'])
396 self.results[msg_id] = ss.unserialize_object(msg['buffers'])[0]
367 elif content['status'] == 'aborted':
397 elif content['status'] == 'aborted':
368 self.results[msg_id] = error.AbortedTask(msg_id)
398 self.results[msg_id] = error.AbortedTask(msg_id)
369 elif content['status'] == 'resubmitted':
399 elif content['status'] == 'resubmitted':
@@ -372,10 +402,8 b' class Client(object):'
372 else:
402 else:
373 e = ss.unwrap_exception(content)
403 e = ss.unwrap_exception(content)
374 e_uuid = e.engine_info['engineid']
404 e_uuid = e.engine_info['engineid']
375 for k,v in self._engines.iteritems():
405 eid = self._engines[e_uuid]
376 if v == e_uuid:
406 e.engine_info['engineid'] = eid
377 e.engine_info['engineid'] = k
378 break
379 self.results[msg_id] = e
407 self.results[msg_id] = e
380
408
381 def _flush_notifications(self):
409 def _flush_notifications(self):
@@ -882,6 +910,13 b' class Client(object):'
882 status_only : bool (default: False)
910 status_only : bool (default: False)
883 if False:
911 if False:
884 return the actual results
912 return the actual results
913
914 Returns
915 -------
916
917 results : dict
918 There will always be the keys 'pending' and 'completed', which will
919 be lists of msg_ids.
885 """
920 """
886 if not isinstance(msg_ids, (list,tuple)):
921 if not isinstance(msg_ids, (list,tuple)):
887 msg_ids = [msg_ids]
922 msg_ids = [msg_ids]
@@ -895,11 +930,12 b' class Client(object):'
895
930
896 completed = []
931 completed = []
897 local_results = {}
932 local_results = {}
898 for msg_id in list(theids):
933 # temporarily disable local shortcut
899 if msg_id in self.results:
934 # for msg_id in list(theids):
900 completed.append(msg_id)
935 # if msg_id in self.results:
901 local_results[msg_id] = self.results[msg_id]
936 # completed.append(msg_id)
902 theids.remove(msg_id)
937 # local_results[msg_id] = self.results[msg_id]
938 # theids.remove(msg_id)
903
939
904 if theids: # some not locally cached
940 if theids: # some not locally cached
905 content = dict(msg_ids=theids, status_only=status_only)
941 content = dict(msg_ids=theids, status_only=status_only)
@@ -911,16 +947,40 b' class Client(object):'
911 content = msg['content']
947 content = msg['content']
912 if content['status'] != 'ok':
948 if content['status'] != 'ok':
913 raise ss.unwrap_exception(content)
949 raise ss.unwrap_exception(content)
950 buffers = msg['buffers']
914 else:
951 else:
915 content = dict(completed=[],pending=[])
952 content = dict(completed=[],pending=[])
916 if not status_only:
953
917 # load cached results into result:
954 content['completed'].extend(completed)
918 content['completed'].extend(completed)
955
919 content.update(local_results)
956 if status_only:
920 # update cache with results:
957 return content
921 for msg_id in msg_ids:
958
922 if msg_id in content['completed']:
959 failures = []
923 self.results[msg_id] = content[msg_id]
960 # load cached results into result:
961 content.update(local_results)
962 # update cache with results:
963 for msg_id in sorted(theids):
964 if msg_id in content['completed']:
965 rec = content[msg_id]
966 parent = rec['header']
967 header = rec['result_header']
968 rcontent = rec['result_content']
969 if isinstance(rcontent, str):
970 rcontent = self.session.unpack(rcontent)
971
972 self.metadata[msg_id] = self._build_metadata(header, parent, rcontent)
973
974 if rcontent['status'] == 'ok':
975 res,buffers = ss.unserialize_object(buffers)
976 else:
977 res = ss.unwrap_exception(rcontent)
978 failures.append(res)
979
980 self.results[msg_id] = res
981 content[msg_id] = res
982
983 error.collect_exceptions(failures, "get_results")
924 return content
984 return content
925
985
926 @spinfirst
986 @spinfirst
@@ -945,7 +1005,7 b' class Client(object):'
945 status = content.pop('status')
1005 status = content.pop('status')
946 if status != 'ok':
1006 if status != 'ok':
947 raise ss.unwrap_exception(content)
1007 raise ss.unwrap_exception(content)
948 return content
1008 return ss.rekey(content)
949
1009
950 @spinfirst
1010 @spinfirst
951 def purge_results(self, msg_ids=[], targets=[]):
1011 def purge_results(self, msg_ids=[], targets=[]):
@@ -47,33 +47,6 b' else:'
47 def _passer(*args, **kwargs):
47 def _passer(*args, **kwargs):
48 return
48 return
49
49
50 class ReverseDict(dict):
51 """simple double-keyed subset of dict methods."""
52
53 def __init__(self, *args, **kwargs):
54 dict.__init__(self, *args, **kwargs)
55 self.reverse = dict()
56 for key, value in self.iteritems():
57 self.reverse[value] = key
58
59 def __getitem__(self, key):
60 try:
61 return dict.__getitem__(self, key)
62 except KeyError:
63 return self.reverse[key]
64
65 def __setitem__(self, key, value):
66 if key in self.reverse:
67 raise KeyError("Can't have key %r on both sides!"%key)
68 dict.__setitem__(self, key, value)
69 self.reverse[value] = key
70
71 def pop(self, key):
72 value = dict.pop(self, key)
73 self.d1.pop(value)
74 return value
75
76
77 def init_record(msg):
50 def init_record(msg):
78 """return an empty TaskRecord dict, with all keys initialized with None."""
51 """return an empty TaskRecord dict, with all keys initialized with None."""
79 header = msg['header']
52 header = msg['header']
@@ -484,6 +457,8 b' class Controller(object):'
484 }
457 }
485 if MongoDB is not None and isinstance(self.db, MongoDB):
458 if MongoDB is not None and isinstance(self.db, MongoDB):
486 result['result_buffers'] = map(Binary, msg['buffers'])
459 result['result_buffers'] = map(Binary, msg['buffers'])
460 else:
461 result['result_buffers'] = msg['buffers']
487 self.db.update_record(msg_id, result)
462 self.db.update_record(msg_id, result)
488 else:
463 else:
489 logger.debug("queue:: unknown msg finished %s"%msg_id)
464 logger.debug("queue:: unknown msg finished %s"%msg_id)
@@ -552,6 +527,8 b' class Controller(object):'
552 }
527 }
553 if MongoDB is not None and isinstance(self.db, MongoDB):
528 if MongoDB is not None and isinstance(self.db, MongoDB):
554 result['result_buffers'] = map(Binary, msg['buffers'])
529 result['result_buffers'] = map(Binary, msg['buffers'])
530 else:
531 result['result_buffers'] = msg['buffers']
555 self.db.update_record(msg_id, result)
532 self.db.update_record(msg_id, result)
556
533
557 else:
534 else:
@@ -831,14 +808,16 b' class Controller(object):'
831 def get_results(self, client_id, msg):
808 def get_results(self, client_id, msg):
832 """Get the result of 1 or more messages."""
809 """Get the result of 1 or more messages."""
833 content = msg['content']
810 content = msg['content']
834 msg_ids = set(content['msg_ids'])
811 msg_ids = sorted(set(content['msg_ids']))
835 statusonly = content.get('status_only', False)
812 statusonly = content.get('status_only', False)
836 pending = []
813 pending = []
837 completed = []
814 completed = []
838 content = dict(status='ok')
815 content = dict(status='ok')
839 content['pending'] = pending
816 content['pending'] = pending
840 content['completed'] = completed
817 content['completed'] = completed
818 buffers = []
841 if not statusonly:
819 if not statusonly:
820 content['results'] = {}
842 records = self.db.find_records(dict(msg_id={'$in':msg_ids}))
821 records = self.db.find_records(dict(msg_id={'$in':msg_ids}))
843 for msg_id in msg_ids:
822 for msg_id in msg_ids:
844 if msg_id in self.pending:
823 if msg_id in self.pending:
@@ -846,7 +825,12 b' class Controller(object):'
846 elif msg_id in self.all_completed:
825 elif msg_id in self.all_completed:
847 completed.append(msg_id)
826 completed.append(msg_id)
848 if not statusonly:
827 if not statusonly:
849 content[msg_id] = records[msg_id]['result_content']
828 rec = records[msg_id]
829 content[msg_id] = { 'result_content': rec['result_content'],
830 'header': rec['header'],
831 'result_header' : rec['result_header'],
832 }
833 buffers.extend(map(str, rec['result_buffers']))
850 else:
834 else:
851 try:
835 try:
852 raise KeyError('No such message: '+msg_id)
836 raise KeyError('No such message: '+msg_id)
@@ -854,7 +838,8 b' class Controller(object):'
854 content = wrap_exception()
838 content = wrap_exception()
855 break
839 break
856 self.session.send(self.clientele, "result_reply", content=content,
840 self.session.send(self.clientele, "result_reply", content=content,
857 parent=msg, ident=client_id)
841 parent=msg, ident=client_id,
842 buffers=buffers)
858
843
859
844
860 #-------------------------------------------------------------------------
845 #-------------------------------------------------------------------------
@@ -35,7 +35,7 b' class MongoDB(object):'
35 def update_record(self, msg_id, rec):
35 def update_record(self, msg_id, rec):
36 """Update the data in an existing record."""
36 """Update the data in an existing record."""
37 obj_id = self._table[msg_id]
37 obj_id = self._table[msg_id]
38 self._records.update({'_id':obj_id}, rec)
38 self._records.update({'_id':obj_id}, {'$set': rec})
39
39
40 def drop_matching_records(self, check):
40 def drop_matching_records(self, check):
41 """Remove a record from the DB."""
41 """Remove a record from the DB."""
@@ -50,7 +50,11 b' class MongoDB(object):'
50 """Find records matching a query dict."""
50 """Find records matching a query dict."""
51 matches = list(self._records.find(check))
51 matches = list(self._records.find(check))
52 if id_only:
52 if id_only:
53 matches = [ rec['msg_id'] for rec in matches ]
53 return [ rec['msg_id'] for rec in matches ]
54 return matches
54 else:
55 data = {}
56 for rec in matches:
57 data[rec['msg_id']] = rec
58 return data
55
59
56
60
@@ -126,10 +126,10 b' class ParallelFunction(RemoteFunction):'
126 f=self.func
126 f=self.func
127 mid = self.client.apply(f, args=args, block=False,
127 mid = self.client.apply(f, args=args, block=False,
128 bound=self.bound,
128 bound=self.bound,
129 targets=engineid)._msg_ids[0]
129 targets=engineid).msg_ids[0]
130 msg_ids.append(mid)
130 msg_ids.append(mid)
131
131
132 r = AsyncMapResult(self.client, msg_ids, self.mapObject)
132 r = AsyncMapResult(self.client, msg_ids, self.mapObject, fname=self.func.__name__)
133 if self.block:
133 if self.block:
134 r.wait()
134 r.wait()
135 return r.result
135 return r.result
@@ -208,7 +208,7 b' def unserialize_object(bufs):'
208 for s in sobj:
208 for s in sobj:
209 if s.data is None:
209 if s.data is None:
210 s.data = bufs.pop(0)
210 s.data = bufs.pop(0)
211 return uncanSequence(map(unserialize, sobj))
211 return uncanSequence(map(unserialize, sobj)), bufs
212 elif isinstance(sobj, dict):
212 elif isinstance(sobj, dict):
213 newobj = {}
213 newobj = {}
214 for k in sorted(sobj.iterkeys()):
214 for k in sorted(sobj.iterkeys()):
@@ -216,11 +216,11 b' def unserialize_object(bufs):'
216 if s.data is None:
216 if s.data is None:
217 s.data = bufs.pop(0)
217 s.data = bufs.pop(0)
218 newobj[k] = uncan(unserialize(s))
218 newobj[k] = uncan(unserialize(s))
219 return newobj
219 return newobj, bufs
220 else:
220 else:
221 if sobj.data is None:
221 if sobj.data is None:
222 sobj.data = bufs.pop(0)
222 sobj.data = bufs.pop(0)
223 return uncan(unserialize(sobj))
223 return uncan(unserialize(sobj)), bufs
224
224
225 def pack_apply_message(f, args, kwargs, threshold=64e-6):
225 def pack_apply_message(f, args, kwargs, threshold=64e-6):
226 """pack up a function, args, and kwargs to be sent over the wire
226 """pack up a function, args, and kwargs to be sent over the wire
@@ -183,6 +183,8 b' class View(object):'
183 """Parallel version of builtin `map`, using this view's engines."""
183 """Parallel version of builtin `map`, using this view's engines."""
184 if isinstance(self.targets, int):
184 if isinstance(self.targets, int):
185 targets = [self.targets]
185 targets = [self.targets]
186 else:
187 targets = self.targets
186 pf = ParallelFunction(self.client, f, block=self.block,
188 pf = ParallelFunction(self.client, f, block=self.block,
187 bound=True, targets=targets)
189 bound=True, targets=targets)
188 return pf.map(*sequences)
190 return pf.map(*sequences)
General Comments 0
You need to be logged in to leave comments. Login now