General improvements to database backend...
MinRK
@@ -0,0 +1,182 @@
+"""Tests for db backends"""
+
+#-------------------------------------------------------------------------------
+# Copyright (C) 2011 The IPython Development Team
+#
+# Distributed under the terms of the BSD License. The full license is in
+# the file COPYING, distributed as part of this software.
+#-------------------------------------------------------------------------------
+
+#-------------------------------------------------------------------------------
+# Imports
+#-------------------------------------------------------------------------------
+
+
+import tempfile
+import time
+
+import uuid
+
+from datetime import datetime, timedelta
+from random import choice, randint
+from unittest import TestCase
+
+from nose import SkipTest
+
+from IPython.parallel import error, streamsession as ss
+from IPython.parallel.controller.dictdb import DictDB
+from IPython.parallel.controller.sqlitedb import SQLiteDB
+from IPython.parallel.controller.hub import init_record, empty_record
+
+#-------------------------------------------------------------------------------
+# TestCases
+#-------------------------------------------------------------------------------
+
+class TestDictBackend(TestCase):
+    def setUp(self):
+        self.session = ss.StreamSession()
+        self.db = self.create_db()
+        self.load_records(16)
+
+    def create_db(self):
+        return DictDB()
+
+    def load_records(self, n=1):
+        """load n records for testing"""
+        # sleep 1/10 s, to ensure timestamps differ from previous calls
+        time.sleep(0.1)
+        msg_ids = []
+        for i in range(n):
+            msg = self.session.msg('apply_request', content=dict(a=5))
+            msg['buffers'] = []
+            rec = init_record(msg)
+            msg_ids.append(msg['msg_id'])
+            self.db.add_record(msg['msg_id'], rec)
+        return msg_ids
+
+    def test_add_record(self):
+        before = self.db.get_history()
+        self.load_records(5)
+        after = self.db.get_history()
+        self.assertEquals(len(after), len(before)+5)
+        self.assertEquals(after[:-5],before)
+
+    def test_drop_record(self):
+        msg_id = self.load_records()[-1]
+        rec = self.db.get_record(msg_id)
+        self.db.drop_record(msg_id)
+        self.assertRaises(KeyError,self.db.get_record, msg_id)
+
+    def _round_to_millisecond(self, dt):
+        """necessary because mongodb rounds microseconds"""
+        micro = dt.microsecond
+        extra = int(str(micro)[-3:])
+        return dt - timedelta(microseconds=extra)
+
+    def test_update_record(self):
+        now = self._round_to_millisecond(datetime.now())
+        #
+        msg_id = self.db.get_history()[-1]
+        rec1 = self.db.get_record(msg_id)
+        data = {'stdout': 'hello there', 'completed' : now}
+        self.db.update_record(msg_id, data)
+        rec2 = self.db.get_record(msg_id)
+        self.assertEquals(rec2['stdout'], 'hello there')
+        self.assertEquals(rec2['completed'], now)
+        rec1.update(data)
+        self.assertEquals(rec1, rec2)
+
+    # def test_update_record_bad(self):
+    #     """test updating nonexistent records"""
+    #     msg_id = str(uuid.uuid4())
+    #     data = {'stdout': 'hello there'}
+    #     self.assertRaises(KeyError, self.db.update_record, msg_id, data)
+
+    def test_find_records_dt(self):
+        """test finding records by date"""
+        hist = self.db.get_history()
+        middle = self.db.get_record(hist[len(hist)/2])
+        tic = middle['submitted']
+        before = self.db.find_records({'submitted' : {'$lt' : tic}})
+        after = self.db.find_records({'submitted' : {'$gte' : tic}})
+        self.assertEquals(len(before)+len(after),len(hist))
+        for b in before:
+            self.assertTrue(b['submitted'] < tic)
+        for a in after:
+            self.assertTrue(a['submitted'] >= tic)
+        same = self.db.find_records({'submitted' : tic})
+        for s in same:
+            self.assertTrue(s['submitted'] == tic)
+
+    def test_find_records_keys(self):
+        """test extracting subset of record keys"""
+        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
+        for rec in found:
+            self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
+
+    def test_find_records_msg_id(self):
+        """ensure msg_id is always in found records"""
+        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
+        for rec in found:
+            self.assertTrue('msg_id' in rec.keys())
+        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
+        for rec in found:
+            self.assertTrue('msg_id' in rec.keys())
+        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
+        for rec in found:
+            self.assertTrue('msg_id' in rec.keys())
+
+    def test_find_records_in(self):
+        """test finding records with '$in','$nin' operators"""
+        hist = self.db.get_history()
+        even = hist[::2]
+        odd = hist[1::2]
+        recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
+        found = [ r['msg_id'] for r in recs ]
+        self.assertEquals(set(even), set(found))
+        recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
+        found = [ r['msg_id'] for r in recs ]
+        self.assertEquals(set(odd), set(found))
+
+    def test_get_history(self):
+        msg_ids = self.db.get_history()
+        latest = datetime(1984,1,1)
+        for msg_id in msg_ids:
+            rec = self.db.get_record(msg_id)
+            newt = rec['submitted']
+            self.assertTrue(newt >= latest)
+            latest = newt
+        msg_id = self.load_records(1)[-1]
+        self.assertEquals(self.db.get_history()[-1],msg_id)
+
+    def test_datetime(self):
+        """get/set timestamps with datetime objects"""
+        msg_id = self.db.get_history()[-1]
+        rec = self.db.get_record(msg_id)
+        self.assertTrue(isinstance(rec['submitted'], datetime))
+        self.db.update_record(msg_id, dict(completed=datetime.now()))
+        rec = self.db.get_record(msg_id)
+        self.assertTrue(isinstance(rec['completed'], datetime))
+
+class TestSQLiteBackend(TestDictBackend):
+    def create_db(self):
+        return SQLiteDB(location=tempfile.gettempdir())
+
+    def tearDown(self):
+        self.db._db.close()
+
+# optional MongoDB test
+try:
+    from IPython.parallel.controller.mongodb import MongoDB
+except ImportError:
+    pass
+else:
+    class TestMongoBackend(TestDictBackend):
+        def create_db(self):
+            try:
+                return MongoDB(database='iptestdb')
+            except Exception:
+                raise SkipTest("Couldn't connect to mongodb instance")
+
+        def tearDown(self):
+            self.db._connection.drop_database('iptestdb')
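As a quick orientation for reviewers, here is a minimal standalone sketch of the query semantics these tests exercise, run directly against DictDB. The records below are hand-rolled stand-ins for the full TaskRecord dicts that init_record() produces; only the keys queried are filled in.

    from datetime import datetime, timedelta
    from IPython.parallel.controller.dictdb import DictDB

    db = DictDB()
    now = datetime.now()
    for i in range(4):
        # a real TaskRecord carries the full field set from init_record()
        db.add_record(str(i), dict(msg_id=str(i), completed=None,
                                   submitted=now + timedelta(seconds=i)))

    # mongodb-style comparison operators, as in test_find_records_dt:
    recent = db.find_records({'submitted': {'$gt': now + timedelta(seconds=1)}})
    assert len(recent) == 2

    # '$in' membership plus key extraction, as in test_find_records_in:
    pair = db.find_records({'msg_id': {'$in': ['0', '3']}}, keys=['submitted'])
    assert sorted(r['msg_id'] for r in pair) == ['0', '3']  # msg_id always included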
@@ -1215,5 +1215,78 @@ class Client(HasTraits):
         if content['status'] != 'ok':
             raise self._unwrap_exception(content)
 
+    @spin_first
+    def hub_history(self):
+        """Get the Hub's history
+
+        Just like the Client, the Hub has a history, which is a list of msg_ids.
+        This will contain the history of all clients, and, depending on configuration,
+        may contain history across multiple cluster sessions.
+
+        Any msg_id returned here is a valid argument to `get_result`.
+
+        Returns
+        -------
+
+        msg_ids : list of strs
+            list of all msg_ids, ordered by task submission time.
+        """
+
+        self.session.send(self._query_socket, "history_request", content={})
+        idents, msg = self.session.recv(self._query_socket, 0)
+
+        if self.debug:
+            pprint(msg)
+        content = msg['content']
+        if content['status'] != 'ok':
+            raise self._unwrap_exception(content)
+        else:
+            return content['history']
+
+    @spin_first
+    def db_query(self, query, keys=None):
+        """Query the Hub's TaskRecord database
+
+        This will return a list of task record dicts that match `query`.
+
+        Parameters
+        ----------
+
+        query : mongodb query dict
+            The search dict. See mongodb query docs for details.
+        keys : list of strs [optional]
+            The subset of keys to be returned. The default is to fetch everything.
+            'msg_id' will *always* be included.
+        """
+        content = dict(query=query, keys=keys)
+        self.session.send(self._query_socket, "db_request", content=content)
+        idents, msg = self.session.recv(self._query_socket, 0)
+        if self.debug:
+            pprint(msg)
+        content = msg['content']
+        if content['status'] != 'ok':
+            raise self._unwrap_exception(content)
+
+        records = content['records']
+        buffer_lens = content['buffer_lens']
+        result_buffer_lens = content['result_buffer_lens']
+        buffers = msg['buffers']
+        has_bufs = buffer_lens is not None
+        has_rbufs = result_buffer_lens is not None
+        for i,rec in enumerate(records):
+            # relink buffers
+            if has_bufs:
+                blen = buffer_lens[i]
+                rec['buffers'], buffers = buffers[:blen], buffers[blen:]
+            if has_rbufs:
+                blen = result_buffer_lens[i]
+                rec['result_buffers'], buffers = buffers[:blen], buffers[blen:]
+            # turn timestamps back into times
+            for key in 'submitted started completed resubmitted'.split():
+                maybedate = rec.get(key, None)
+                if maybedate and util.ISO8601_RE.match(maybedate):
+                    rec[key] = datetime.strptime(maybedate, util.ISO8601)
+
+        return records
 
 __all__ = [ 'Client' ]
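On the client side, the two new methods combine into a workflow like the following sketch; it assumes a running cluster and a connectable Client. Datetimes in the query survive the wire thanks to the _date_default/extract_dates changes later in this diff.

    from datetime import datetime, timedelta
    from IPython.parallel import Client

    rc = Client()

    # every msg_id the Hub knows about, oldest first:
    hist = rc.hub_history()

    # task records submitted in the last hour, fetching only two keys
    # (msg_id is always included in the result):
    recent = rc.db_query({'submitted': {'$gte': datetime.now() - timedelta(hours=1)}},
                         keys=['msg_id', 'completed'])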
@@ -103,9 +103,9 @@ class DictDB(BaseDB):
                 return False
         return True
 
-    def _match(self, check, id_only=True):
+    def _match(self, check):
         """Find all the matches for a check dict."""
-        matches = {}
+        matches = []
         tests = {}
         for k,v in check.iteritems():
             if isinstance(v, dict):
@@ -113,14 +113,18 @@ class DictDB(BaseDB):
             else:
                 tests[k] = lambda o: o==v
 
-        for msg_id, rec in self._records.iteritems():
+        for rec in self._records.itervalues():
             if self._match_one(rec, tests):
-                matches[msg_id] = rec
-        if id_only:
-            return matches.keys()
-        else:
-            return matches
-
+                matches.append(rec)
+        return matches
+
+    def _extract_subdict(self, rec, keys):
+        """extract subdict of keys"""
+        d = {}
+        d['msg_id'] = rec['msg_id']
+        for key in keys:
+            d[key] = rec[key]
+        return d
 
     def add_record(self, msg_id, rec):
         """Add a new Task Record, by msg_id."""
@@ -140,7 +144,7 @@ class DictDB(BaseDB):
 
     def drop_matching_records(self, check):
         """Remove all records matching a query dict."""
-        matches = self._match(check, id_only=True)
+        matches = self._match(check)
         for m in matches:
-            del self._records[m]
+            del self._records[m['msg_id']]
 
@@ -149,7 +153,28 @@ class DictDB(BaseDB):
         del self._records[msg_id]
 
 
-    def find_records(self, check, id_only=False):
-        """Find records matching a query dict."""
-        matches = self._match(check, id_only)
-        return matches
\ No newline at end of file
+    def find_records(self, check, keys=None):
+        """Find records matching a query dict, optionally extracting subset of keys.
+
+        Returns list of matching records.
+
+        Parameters
+        ----------
+
+        check: dict
+            mongodb-style query argument
+        keys: list of strs [optional]
+            if specified, the subset of keys to extract.  msg_id will *always* be
+            included.
+        """
+        matches = self._match(check)
+        if keys:
+            return [ self._extract_subdict(rec, keys) for rec in matches ]
+        else:
+            return matches
+
+
+    def get_history(self):
+        """get all msg_ids, ordered by time submitted."""
+        msg_ids = self._records.keys()
+        return sorted(msg_ids, key=lambda m: self._records[m]['submitted'])
@@ -27,9 +27,8 @@ from zmq.eventloop.zmqstream import ZMQStream
 from IPython.utils.importstring import import_item
 from IPython.utils.traitlets import HasTraits, Instance, Int, CStr, Str, Dict, Set, List, Bool
 
-from IPython.parallel import error
+from IPython.parallel import error, util
 from IPython.parallel.factory import RegistrationFactory, LoggingFactory
-from IPython.parallel.util import select_random_ports, validate_url_container, ISO8601
 
 from .heartmonitor import HeartMonitor
 
@@ -76,7 +75,7 @@ def init_record(msg):
         'header' : header,
         'content': msg['content'],
         'buffers': msg['buffers'],
-        'submitted': datetime.strptime(header['date'], ISO8601),
+        'submitted': datetime.strptime(header['date'], util.ISO8601),
         'client_uuid' : None,
         'engine_uuid' : None,
         'started': None,
@@ -119,32 +118,32 @@ class HubFactory(RegistrationFactory):
     # port-pairs for monitoredqueues:
     hb = Instance(list, config=True)
     def _hb_default(self):
-        return select_random_ports(2)
+        return util.select_random_ports(2)
 
     mux = Instance(list, config=True)
     def _mux_default(self):
-        return select_random_ports(2)
+        return util.select_random_ports(2)
 
     task = Instance(list, config=True)
     def _task_default(self):
-        return select_random_ports(2)
+        return util.select_random_ports(2)
 
     control = Instance(list, config=True)
     def _control_default(self):
-        return select_random_ports(2)
+        return util.select_random_ports(2)
 
     iopub = Instance(list, config=True)
     def _iopub_default(self):
-        return select_random_ports(2)
+        return util.select_random_ports(2)
 
     # single ports:
     mon_port = Instance(int, config=True)
     def _mon_port_default(self):
-        return select_random_ports(1)[0]
+        return util.select_random_ports(1)[0]
 
     notifier_port = Instance(int, config=True)
     def _notifier_port_default(self):
-        return select_random_ports(1)[0]
+        return util.select_random_ports(1)[0]
 
     ping = Int(1000, config=True) # ping frequency
 
@@ -344,11 +343,11 @@ class Hub(LoggingFactory):
         # validate connection dicts:
         for k,v in self.client_info.iteritems():
             if k == 'task':
-                validate_url_container(v[1])
+                util.validate_url_container(v[1])
             else:
-                validate_url_container(v)
-        # validate_url_container(self.client_info)
-        validate_url_container(self.engine_info)
+                util.validate_url_container(v)
+        # util.validate_url_container(self.client_info)
+        util.validate_url_container(self.engine_info)
 
         # register our callbacks
         self.query.on_recv(self.dispatch_query)
@@ -369,6 +368,8 @@ class Hub(LoggingFactory):
 
         self.query_handlers = {'queue_request': self.queue_status,
                                 'result_request': self.get_results,
+                                'history_request': self.get_history,
+                                'db_request': self.db_query,
                                 'purge_request': self.purge_results,
                                 'load_request': self.check_load,
                                 'resubmit_request': self.resubmit_task,
@@ -606,10 +607,10 @@ class Hub(LoggingFactory):
             return
         # update record anyway, because the unregistration could have been premature
         rheader = msg['header']
-        completed = datetime.strptime(rheader['date'], ISO8601)
+        completed = datetime.strptime(rheader['date'], util.ISO8601)
         started = rheader.get('started', None)
         if started is not None:
-            started = datetime.strptime(started, ISO8601)
+            started = datetime.strptime(started, util.ISO8601)
         result = {
             'result_header' : rheader,
             'result_content': msg['content'],
@@ -618,7 +619,10 @@ class Hub(LoggingFactory):
         }
 
         result['result_buffers'] = msg['buffers']
-        self.db.update_record(msg_id, result)
+        try:
+            self.db.update_record(msg_id, result)
+        except Exception:
+            self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
 
 
     #--------------------- Task Queue Traffic ------------------------------
@@ -653,6 +657,8 @@ class Hub(LoggingFactory):
             self.db.update_record(msg_id, record)
         except KeyError:
             self.db.add_record(msg_id, record)
+        except Exception:
+            self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
 
     def save_task_result(self, idents, msg):
         """save the result of a completed task."""
@@ -685,10 +691,10 @@ class Hub(LoggingFactory):
             self.completed[eid].append(msg_id)
             if msg_id in self.tasks[eid]:
                 self.tasks[eid].remove(msg_id)
-            completed = datetime.strptime(header['date'], ISO8601)
+            completed = datetime.strptime(header['date'], util.ISO8601)
             started = header.get('started', None)
             if started is not None:
-                started = datetime.strptime(started, ISO8601)
+                started = datetime.strptime(started, util.ISO8601)
             result = {
                 'result_header' : header,
                 'result_content': msg['content'],
@@ -698,7 +704,10 @@ class Hub(LoggingFactory):
             }
 
             result['result_buffers'] = msg['buffers']
-            self.db.update_record(msg_id, result)
+            try:
+                self.db.update_record(msg_id, result)
+            except Exception:
+                self.log.error("DB Error saving task result %r"%msg_id, exc_info=True)
 
         else:
             self.log.debug("task::unknown task %s finished"%msg_id)
@@ -723,7 +732,11 @@ class Hub(LoggingFactory):
 
         self.tasks[eid].append(msg_id)
         # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
-        self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
+        try:
+            self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
+        except Exception:
+            self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
+
 
     def mia_task_request(self, idents, msg):
         raise NotImplementedError
@@ -772,7 +785,10 @@ class Hub(LoggingFactory):
         else:
             d[msg_type] = content.get('data', '')
 
-        self.db.update_record(msg_id, d)
+        try:
+            self.db.update_record(msg_id, d)
+        except Exception:
+            self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
 
 
 
@@ -904,11 +920,15 @@ class Hub(LoggingFactory):
             # build a fake header:
             header = {}
             header['engine'] = uuid
-            header['date'] = datetime.now().strftime(ISO8601)
+            header['date'] = datetime.now()
             rec = dict(result_content=content, result_header=header, result_buffers=[])
             rec['completed'] = header['date']
             rec['engine_uuid'] = uuid
-            self.db.update_record(msg_id, rec)
+            try:
+                self.db.update_record(msg_id, rec)
+            except Exception:
+                self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
+
 
     def finish_registration(self, heart):
         """Second half of engine registration, called after our HeartMonitor
@@ -1017,7 +1037,10 @@ class Hub(LoggingFactory):
         msg_ids = content.get('msg_ids', [])
         reply = dict(status='ok')
         if msg_ids == 'all':
-            self.db.drop_matching_records(dict(completed={'$ne':None}))
+            try:
+                self.db.drop_matching_records(dict(completed={'$ne':None}))
+            except Exception:
+                reply = error.wrap_exception()
         else:
             for msg_id in msg_ids:
                 if msg_id in self.all_completed:
@@ -1044,7 +1067,11 @@ class Hub(LoggingFactory):
                     break
                 msg_ids = self.completed.pop(eid)
                 uid = self.engines[eid].queue
-                self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
+                try:
+                    self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
+                except Exception:
+                    reply = error.wrap_exception()
+                    break
 
         self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
 
@@ -1052,6 +1079,23 @@ class Hub(LoggingFactory):
         """Resubmit a task."""
         raise NotImplementedError
 
+    def _extract_record(self, rec):
+        """decompose a TaskRecord dict into subsection of reply for get_result"""
+        io_dict = {}
+        for key in 'pyin pyout pyerr stdout stderr'.split():
+            io_dict[key] = rec[key]
+        content = { 'result_content': rec['result_content'],
+                    'header': rec['header'],
+                    'result_header' : rec['result_header'],
+                    'io' : io_dict,
+                  }
+        if rec['result_buffers']:
+            buffers = map(str, rec['result_buffers'])
+        else:
+            buffers = []
+
+        return content, buffers
+
     def get_results(self, client_id, msg):
         """Get the result of 1 or more messages."""
         content = msg['content']
@@ -1064,25 +1108,28 @@ class Hub(LoggingFactory):
         content['completed'] = completed
         buffers = []
         if not statusonly:
-            content['results'] = {}
-            records = self.db.find_records(dict(msg_id={'$in':msg_ids}))
+            try:
+                matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
+                # turn match list into dict, for faster lookup
+                records = {}
+                for rec in matches:
+                    records[rec['msg_id']] = rec
+            except Exception:
+                content = error.wrap_exception()
+                self.session.send(self.query, "result_reply", content=content,
+                                    parent=msg, ident=client_id)
+                return
+        else:
+            records = {}
         for msg_id in msg_ids:
             if msg_id in self.pending:
                 pending.append(msg_id)
-            elif msg_id in self.all_completed:
+            elif msg_id in self.all_completed or msg_id in records:
                 completed.append(msg_id)
                 if not statusonly:
-                    rec = records[msg_id]
-                    io_dict = {}
-                    for key in 'pyin pyout pyerr stdout stderr'.split():
-                        io_dict[key] = rec[key]
-                    content[msg_id] = { 'result_content': rec['result_content'],
-                                        'header': rec['header'],
-                                        'result_header' : rec['result_header'],
-                                        'io' : io_dict,
-                                      }
-                    if rec['result_buffers']:
-                        buffers.extend(map(str, rec['result_buffers']))
+                    c,bufs = self._extract_record(records[msg_id])
+                    content[msg_id] = c
+                    buffers.extend(bufs)
             else:
                 try:
                     raise KeyError('No such message: '+msg_id)
@@ -1093,3 +1140,54 @@ class Hub(LoggingFactory):
                             parent=msg, ident=client_id,
                             buffers=buffers)
 
+    def get_history(self, client_id, msg):
+        """Get a list of all msg_ids in our DB records"""
+        try:
+            msg_ids = self.db.get_history()
+        except Exception as e:
+            content = error.wrap_exception()
+        else:
+            content = dict(status='ok', history=msg_ids)
+
+        self.session.send(self.query, "history_reply", content=content,
+                            parent=msg, ident=client_id)
+
+    def db_query(self, client_id, msg):
+        """Perform a raw query on the task record database."""
+        content = msg['content']
+        query = content.get('query', {})
+        keys = content.get('keys', None)
+        query = util.extract_dates(query)
+        buffers = []
+        empty = list()
+
+        try:
+            records = self.db.find_records(query, keys)
+        except Exception as e:
+            content = error.wrap_exception()
+        else:
+            # extract buffers from reply content:
+            if keys is not None:
+                buffer_lens = [] if 'buffers' in keys else None
+                result_buffer_lens = [] if 'result_buffers' in keys else None
+            else:
+                buffer_lens = []
+                result_buffer_lens = []
+
+            for rec in records:
+                # buffers may be None, so double check
+                if buffer_lens is not None:
+                    b = rec.pop('buffers', empty) or empty
+                    buffer_lens.append(len(b))
+                    buffers.extend(b)
+                if result_buffer_lens is not None:
+                    rb = rec.pop('result_buffers', empty) or empty
+                    result_buffer_lens.append(len(rb))
+                    buffers.extend(rb)
+            content = dict(status='ok', records=records, buffer_lens=buffer_lens,
+                            result_buffer_lens=result_buffer_lens)
+
+        self.session.send(self.query, "db_reply", content=content,
+                            parent=msg, ident=client_id,
+                            buffers=buffers)
+
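The subtle part of the db_request/db_reply exchange is the buffer framing: each record's buffers are concatenated into one flat list, with per-record lengths sent alongside so the client can slice them back apart. The same logic as Client.db_query above, in miniature:

    records = [{'msg_id': 'a'}, {'msg_id': 'b'}]
    buffer_lens = [2, 1]              # record 'a' contributed 2 buffers, 'b' 1
    buffers = ['a0', 'a1', 'b0']      # flat list, in record order

    for i, rec in enumerate(records):
        blen = buffer_lens[i]
        rec['buffers'], buffers = buffers[:blen], buffers[blen:]

    assert records[0]['buffers'] == ['a0', 'a1']
    assert records[1]['buffers'] == ['b0']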
@@ -22,9 +22,9 @@ from .dictdb import BaseDB
 class MongoDB(BaseDB):
     """MongoDB TaskRecord backend."""
 
-    connection_args = List(config=True)
-    connection_kwargs = Dict(config=True)
-    database = CUnicode(config=True)
+    connection_args = List(config=True) # args passed to pymongo.Connection
+    connection_kwargs = Dict(config=True) # kwargs passed to pymongo.Connection
+    database = CUnicode(config=True) # name of the mongodb database
     _table = Dict()
 
     def __init__(self, **kwargs):
@@ -37,13 +37,14 @@ class MongoDB(BaseDB):
 
     def _binary_buffers(self, rec):
         for key in ('buffers', 'result_buffers'):
-            if key in rec:
+            if rec.get(key, None):
                 rec[key] = map(Binary, rec[key])
+        return rec
 
     def add_record(self, msg_id, rec):
         """Add a new Task Record, by msg_id."""
         # print rec
-        rec = _binary_buffers(rec)
+        rec = self._binary_buffers(rec)
         obj_id = self._records.insert(rec)
         self._table[msg_id] = obj_id
 
@@ -53,7 +54,7 @@ class MongoDB(BaseDB):
 
     def update_record(self, msg_id, rec):
         """Update the data in an existing record."""
-        rec = _binary_buffers(rec)
+        rec = self._binary_buffers(rec)
         obj_id = self._table[msg_id]
         self._records.update({'_id':obj_id}, {'$set': rec})
 
@@ -66,15 +67,30 @@ class MongoDB(BaseDB):
         obj_id = self._table.pop(msg_id)
         self._records.remove(obj_id)
 
-    def find_records(self, check, id_only=False):
-        """Find records matching a query dict."""
-        matches = list(self._records.find(check))
-        if id_only:
-            return [ rec['msg_id'] for rec in matches ]
-        else:
-            data = {}
-            for rec in matches:
-                data[rec['msg_id']] = rec
-            return data
+    def find_records(self, check, keys=None):
+        """Find records matching a query dict, optionally extracting subset of keys.
+
+        Returns list of matching records.
+
+        Parameters
+        ----------
+
+        check: dict
+            mongodb-style query argument
+        keys: list of strs [optional]
+            if specified, the subset of keys to extract.  msg_id will *always* be
+            included.
+        """
+        if keys and 'msg_id' not in keys:
+            keys.append('msg_id')
+        matches = list(self._records.find(check,keys))
+        for rec in matches:
+            rec.pop('_id')
+        return matches
+
+    def get_history(self):
+        """get all msg_ids, ordered by time submitted."""
+        cursor = self._records.find({},{'msg_id':1}).sort('submitted')
+        return [ rec['msg_id'] for rec in cursor ]
 
 
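For context, the three traits commented above are meant to be set from config. A hypothetical ipcontroller_config.py fragment might look like the following; the attribute names are taken from this diff, but the db_class selection mechanism is assumed rather than shown here:

    c = get_config()

    c.HubFactory.db_class = 'IPython.parallel.controller.mongodb.MongoDB'
    c.MongoDB.database = 'ipythondb'                     # name of the mongodb database
    c.MongoDB.connection_kwargs = {'host': 'localhost'}  # passed to pymongo.Connection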
@@ -24,7 +24,7 @@ from IPython.parallel.util import ISO8601
 #-----------------------------------------------------------------------------
 
 operators = {
-    '$lt' : lambda a,b: "%s < ?",
+    '$lt' : "<",
     '$gt' : ">",
     # null is handled weird with ==,!=
     '$eq' : "IS",
@@ -124,10 +124,11 @@ class SQLiteDB(BaseDB):
         pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
         pc.start()
 
-    def _defaults(self):
+    def _defaults(self, keys=None):
         """create an empty record"""
         d = {}
-        for key in self._keys:
+        keys = self._keys if keys is None else keys
+        for key in keys:
             d[key] = None
         return d
 
@@ -168,9 +169,6 @@ class SQLiteDB(BaseDB):
                 stdout text,
                 stderr text)
                 """%self.table)
-        # self._db.execute("""CREATE TABLE IF NOT EXISTS %s_buffers
-        #             (msg_id text, result integer, buffer blob)
-        #             """%self.table)
         self._db.commit()
 
     def _dict_to_list(self, d):
176 def _dict_to_list(self, d):
174 def _dict_to_list(self, d):
@@ -178,10 +176,11 b' class SQLiteDB(BaseDB):'
178
176
179 return [ d[key] for key in self._keys ]
177 return [ d[key] for key in self._keys ]
180
178
181 def _list_to_dict(self, line):
179 def _list_to_dict(self, line, keys=None):
182 """Inverse of dict_to_list"""
180 """Inverse of dict_to_list"""
183 d = self._defaults()
181 keys = self._keys if keys is None else keys
184 for key,value in zip(self._keys, line):
182 d = self._defaults(keys)
183 for key,value in zip(keys, line):
185 d[key] = value
184 d[key] = value
186
185
187 return d
186 return d
@@ -249,13 +248,14 @@ class SQLiteDB(BaseDB):
             sets.append('%s = ?'%key)
             values.append(rec[key])
         query += ', '.join(sets)
-        query += ' WHERE msg_id == %r'%msg_id
+        query += ' WHERE msg_id == ?'
+        values.append(msg_id)
         self._db.execute(query, values)
         # self._db.commit()
 
     def drop_record(self, msg_id):
         """Remove a record from the DB."""
-        self._db.execute("""DELETE FROM %s WHERE mgs_id==?"""%self.table, (msg_id,))
+        self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
         # self._db.commit()
 
     def drop_matching_records(self, check):
@@ -265,20 +265,48 @@ class SQLiteDB(BaseDB):
         self._db.execute(query,args)
         # self._db.commit()
 
-    def find_records(self, check, id_only=False):
-        """Find records matching a query dict."""
-        req = 'msg_id' if id_only else '*'
+    def find_records(self, check, keys=None):
+        """Find records matching a query dict, optionally extracting subset of keys.
+
+        Returns list of matching records.
+
+        Parameters
+        ----------
+
+        check: dict
+            mongodb-style query argument
+        keys: list of strs [optional]
+            if specified, the subset of keys to extract.  msg_id will *always* be
+            included.
+        """
+        if keys:
+            bad_keys = [ key for key in keys if key not in self._keys ]
+            if bad_keys:
+                raise KeyError("Bad record key(s): %s"%bad_keys)
+
+        if keys:
+            # ensure msg_id is present and first:
+            if 'msg_id' in keys:
+                keys.remove('msg_id')
+            keys.insert(0, 'msg_id')
+            req = ', '.join(keys)
+        else:
+            req = '*'
         expr,args = self._render_expression(check)
         query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
         cursor = self._db.execute(query, args)
         matches = cursor.fetchall()
-        if id_only:
-            return [ m[0] for m in matches ]
-        else:
-            records = {}
-            for line in matches:
-                rec = self._list_to_dict(line)
-                records[rec['msg_id']] = rec
-            return records
+        records = []
+        for line in matches:
+            rec = self._list_to_dict(line, keys)
+            records.append(rec)
+        return records
+
+    def get_history(self):
+        """get all msg_ids, ordered by time submitted."""
+        query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
+        cursor = self._db.execute(query)
+        # will be a list of length 1 tuples
+        return [ tup[0] for tup in cursor.fetchall()]
 
-__all__ = ['SQLiteDB']
\ No newline at end of file
+__all__ = ['SQLiteDB']
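A quick sanity check of the fixed '$lt' operator together with the new keys and get_history support, against a throwaway database (a sketch; the location and query here are arbitrary):

    import tempfile
    from datetime import datetime

    from IPython.parallel.controller.sqlitedb import SQLiteDB

    db = SQLiteDB(location=tempfile.gettempdir())
    # renders to roughly: SELECT msg_id, completed FROM <table> WHERE submitted < ?
    older = db.find_records({'submitted': {'$lt': datetime.now()}},
                            keys=['completed'])
    order = db.get_history()  # msg_ids ordered by their 'submitted' column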
@@ -28,6 +28,7 @@ from zmq.eventloop.zmqstream import ZMQStream
 from .util import ISO8601
 
 def squash_unicode(obj):
+    """coerce unicode back to bytestrings."""
     if isinstance(obj,dict):
         for key in obj.keys():
             obj[key] = squash_unicode(obj[key])
40 obj = obj.encode('utf8')
41 obj = obj.encode('utf8')
41 return obj
42 return obj
42
43
43 json_packer = jsonapi.dumps
44 def _date_default(obj):
45 if isinstance(obj, datetime):
46 return obj.strftime(ISO8601)
47 else:
48 raise TypeError("%r is not JSON serializable"%obj)
49
50 _default_key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
51 json_packer = lambda obj: jsonapi.dumps(obj, **{_default_key:_date_default})
44 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
52 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
45
53
46 pickle_packer = lambda o: pickle.dumps(o,-1)
54 pickle_packer = lambda o: pickle.dumps(o,-1)
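These packer changes pair with extract_dates in util.py (next file): datetimes are flattened to ISO8601 strings on the way out and can be recovered after unpacking. A round-trip sketch, assuming the module paths used elsewhere in this diff:

    from datetime import datetime

    from IPython.parallel.streamsession import json_packer, json_unpacker
    from IPython.parallel.util import extract_dates

    wire = json_packer({'submitted': datetime.now(), 'a': 5})
    data = extract_dates(json_unpacker(wire))   # ISO8601 strings -> datetime
    assert isinstance(data['submitted'], datetime)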
@@ -17,6 +17,7 @@ import re
 import stat
 import socket
 import sys
+from datetime import datetime
 from signal import signal, SIGINT, SIGABRT, SIGTERM
 try:
     from signal import SIGKILL
@@ -41,6 +42,7 @@ from IPython.zmq.log import EnginePUBHandler
 
 # globals
 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
+ISO8601_RE=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")
 
 #-----------------------------------------------------------------------------
 # Classes
@@ -99,6 +101,18 @@ class ReverseDict(dict):
 # Functions
 #-----------------------------------------------------------------------------
 
+def extract_dates(obj):
+    """extract ISO8601 dates from unpacked JSON"""
+    if isinstance(obj, dict):
+        for k,v in obj.iteritems():
+            obj[k] = extract_dates(v)
+    elif isinstance(obj, list):
+        obj = [ extract_dates(o) for o in obj ]
+    elif isinstance(obj, basestring):
+        if ISO8601_RE.match(obj):
+            obj = datetime.strptime(obj, ISO8601)
+    return obj
+
 def validate_url(url):
     """validate a url for zeromq"""
     if not isinstance(url, basestring):
@@ -460,3 +474,4 @@ def local_logger(logname, loglevel=logging.DEBUG):
     handler.setLevel(loglevel)
     logger.addHandler(handler)
     logger.setLevel(loglevel)
+