General improvements to database backend...
MinRK
@@ -0,0 +1,182 b''
1 """Tests for db backends"""
2
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
5 #
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
9
10 #-------------------------------------------------------------------------------
11 # Imports
12 #-------------------------------------------------------------------------------
13
14
15 import tempfile
16 import time
17
18 import uuid
19
20 from datetime import datetime, timedelta
21 from random import choice, randint
22 from unittest import TestCase
23
24 from nose import SkipTest
25
26 from IPython.parallel import error, streamsession as ss
27 from IPython.parallel.controller.dictdb import DictDB
28 from IPython.parallel.controller.sqlitedb import SQLiteDB
29 from IPython.parallel.controller.hub import init_record, empty_record
30
31 #-------------------------------------------------------------------------------
32 # TestCases
33 #-------------------------------------------------------------------------------
34
35 class TestDictBackend(TestCase):
36 def setUp(self):
37 self.session = ss.StreamSession()
38 self.db = self.create_db()
39 self.load_records(16)
40
41 def create_db(self):
42 return DictDB()
43
44 def load_records(self, n=1):
45 """load n records for testing"""
46 # sleep 1/10 s, to ensure the timestamp differs from previous calls
47 time.sleep(0.1)
48 msg_ids = []
49 for i in range(n):
50 msg = self.session.msg('apply_request', content=dict(a=5))
51 msg['buffers'] = []
52 rec = init_record(msg)
53 msg_ids.append(msg['msg_id'])
54 self.db.add_record(msg['msg_id'], rec)
55 return msg_ids
56
57 def test_add_record(self):
58 before = self.db.get_history()
59 self.load_records(5)
60 after = self.db.get_history()
61 self.assertEquals(len(after), len(before)+5)
62 self.assertEquals(after[:-5],before)
63
64 def test_drop_record(self):
65 msg_id = self.load_records()[-1]
66 rec = self.db.get_record(msg_id)
67 self.db.drop_record(msg_id)
68 self.assertRaises(KeyError,self.db.get_record, msg_id)
69
70 def _round_to_millisecond(self, dt):
71 """necessary because mongodb rounds microseconds"""
72 micro = dt.microsecond
73 extra = int(str(micro)[-3:])
74 return dt - timedelta(microseconds=extra)
75
76 def test_update_record(self):
77 now = self._round_to_millisecond(datetime.now())
78 #
79 msg_id = self.db.get_history()[-1]
80 rec1 = self.db.get_record(msg_id)
81 data = {'stdout': 'hello there', 'completed' : now}
82 self.db.update_record(msg_id, data)
83 rec2 = self.db.get_record(msg_id)
84 self.assertEquals(rec2['stdout'], 'hello there')
85 self.assertEquals(rec2['completed'], now)
86 rec1.update(data)
87 self.assertEquals(rec1, rec2)
88
89 # def test_update_record_bad(self):
90 # """test updating nonexistant records"""
91 # msg_id = str(uuid.uuid4())
92 # data = {'stdout': 'hello there'}
93 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
94
95 def test_find_records_dt(self):
96 """test finding records by date"""
97 hist = self.db.get_history()
98 middle = self.db.get_record(hist[len(hist)/2])
99 tic = middle['submitted']
100 before = self.db.find_records({'submitted' : {'$lt' : tic}})
101 after = self.db.find_records({'submitted' : {'$gte' : tic}})
102 self.assertEquals(len(before)+len(after),len(hist))
103 for b in before:
104 self.assertTrue(b['submitted'] < tic)
105 for a in after:
106 self.assertTrue(a['submitted'] >= tic)
107 same = self.db.find_records({'submitted' : tic})
108 for s in same:
109 self.assertTrue(s['submitted'] == tic)
110
111 def test_find_records_keys(self):
112 """test extracting subset of record keys"""
113 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
114 for rec in found:
115 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
116
117 def test_find_records_msg_id(self):
118 """ensure msg_id is always in found records"""
119 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
120 for rec in found:
121 self.assertTrue('msg_id' in rec.keys())
122 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
123 for rec in found:
124 self.assertTrue('msg_id' in rec.keys())
125 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
126 for rec in found:
127 self.assertTrue('msg_id' in rec.keys())
128
129 def test_find_records_in(self):
130 """test finding records with '$in','$nin' operators"""
131 hist = self.db.get_history()
132 even = hist[::2]
133 odd = hist[1::2]
134 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
135 found = [ r['msg_id'] for r in recs ]
136 self.assertEquals(set(even), set(found))
137 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
138 found = [ r['msg_id'] for r in recs ]
139 self.assertEquals(set(odd), set(found))
140
141 def test_get_history(self):
142 msg_ids = self.db.get_history()
143 latest = datetime(1984,1,1)
144 for msg_id in msg_ids:
145 rec = self.db.get_record(msg_id)
146 newt = rec['submitted']
147 self.assertTrue(newt >= latest)
148 latest = newt
149 msg_id = self.load_records(1)[-1]
150 self.assertEquals(self.db.get_history()[-1],msg_id)
151
152 def test_datetime(self):
153 """get/set timestamps with datetime objects"""
154 msg_id = self.db.get_history()[-1]
155 rec = self.db.get_record(msg_id)
156 self.assertTrue(isinstance(rec['submitted'], datetime))
157 self.db.update_record(msg_id, dict(completed=datetime.now()))
158 rec = self.db.get_record(msg_id)
159 self.assertTrue(isinstance(rec['completed'], datetime))
160
161 class TestSQLiteBackend(TestDictBackend):
162 def create_db(self):
163 return SQLiteDB(location=tempfile.gettempdir())
164
165 def tearDown(self):
166 self.db._db.close()
167
168 # optional MongoDB test
169 try:
170 from IPython.parallel.controller.mongodb import MongoDB
171 except ImportError:
172 pass
173 else:
174 class TestMongoBackend(TestDictBackend):
175 def create_db(self):
176 try:
177 return MongoDB(database='iptestdb')
178 except Exception:
179 raise SkipTest("Couldn't connect to mongodb instance")
180
181 def tearDown(self):
182 self.db._connection.drop_database('iptestdb')
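These tests double as a reference for the new backend query API. A minimal standalone sketch (the hand-built record below is illustrative; real records come from init_record, and the operators are the ones exercised in the tests):

    from datetime import datetime
    from IPython.parallel.controller.dictdb import DictDB

    db = DictDB()
    # hand-built record for illustration; init_record(msg) builds the real schema
    db.add_record('abc', dict(msg_id='abc', submitted=datetime.now(),
                              completed=None, stdout=''))
    # mongodb-style comparison operators, as in test_find_records_dt:
    early = db.find_records({'submitted': {'$lt': datetime.now()}})
    # restrict the returned keys; 'msg_id' is always included:
    subset = db.find_records({'msg_id': {'$ne': ''}}, keys=['submitted'])
    assert set(subset[0].keys()) == set(['msg_id', 'submitted'])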
@@ -1215,5 +1215,78 b' class Client(HasTraits):'
1215 1215 if content['status'] != 'ok':
1216 1216 raise self._unwrap_exception(content)
1217 1217
1218 @spin_first
1219 def hub_history(self):
1220 """Get the Hub's history
1221
1222 Just like the Client, the Hub has a history, which is a list of msg_ids.
1223 This will contain the history of all clients, and, depending on configuration,
1224 may contain history across multiple cluster sessions.
1225
1226 Any msg_id returned here is a valid argument to `get_result`.
1227
1228 Returns
1229 -------
1230
1231 msg_ids : list of strs
1232 list of all msg_ids, ordered by task submission time.
1233 """
1234
1235 self.session.send(self._query_socket, "history_request", content={})
1236 idents, msg = self.session.recv(self._query_socket, 0)
1237
1238 if self.debug:
1239 pprint(msg)
1240 content = msg['content']
1241 if content['status'] != 'ok':
1242 raise self._unwrap_exception(content)
1243 else:
1244 return content['history']
1245
1246 @spin_first
1247 def db_query(self, query, keys=None):
1248 """Query the Hub's TaskRecord database
1249
1250 This will return a list of task record dicts that match `query`
1251
1252 Parameters
1253 ----------
1254
1255 query : mongodb query dict
1256 The search dict. See mongodb query docs for details.
1257 keys : list of strs [optional]
1258 The subset of keys to be returned. The default is to fetch everything.
1259 'msg_id' will *always* be included.
1260 """
1261 content = dict(query=query, keys=keys)
1262 self.session.send(self._query_socket, "db_request", content=content)
1263 idents, msg = self.session.recv(self._query_socket, 0)
1264 if self.debug:
1265 pprint(msg)
1266 content = msg['content']
1267 if content['status'] != 'ok':
1268 raise self._unwrap_exception(content)
1269
1270 records = content['records']
1271 buffer_lens = content['buffer_lens']
1272 result_buffer_lens = content['result_buffer_lens']
1273 buffers = msg['buffers']
1274 has_bufs = buffer_lens is not None
1275 has_rbufs = result_buffer_lens is not None
1276 for i,rec in enumerate(records):
1277 # relink buffers
1278 if has_bufs:
1279 blen = buffer_lens[i]
1280 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1281 if has_rbufs:
1282 blen = result_buffer_lens[i]
1283 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1284 # turn timestamps back into times
1285 for key in 'submitted started completed resubmitted'.split():
1286 maybedate = rec.get(key, None)
1287 if maybedate and util.ISO8601_RE.match(maybedate):
1288 rec[key] = datetime.strptime(maybedate, util.ISO8601)
1289
1290 return records
1218 1291
1219 1292 __all__ = [ 'Client' ]
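Together, hub_history and db_query give any client read access to the Hub's TaskRecord database. A usage sketch (assumes a running cluster reachable with the default connection info):

    from datetime import datetime, timedelta
    from IPython.parallel import Client

    rc = Client()
    # every msg_id the Hub has seen, oldest first:
    all_ids = rc.hub_history()
    # mongodb-style query; datetime values survive the JSON round-trip via
    # _date_default on the way out and util.extract_dates on the Hub side:
    hour_ago = datetime.now() - timedelta(hours=1)
    recent = rc.db_query({'submitted': {'$gte': hour_ago}},
                         keys=['msg_id', 'completed'])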
@@ -103,9 +103,9 b' class DictDB(BaseDB):'
103 103 return False
104 104 return True
105 105
106 def _match(self, check, id_only=True):
106 def _match(self, check):
107 107 """Find all the matches for a check dict."""
108 matches = {}
108 matches = []
109 109 tests = {}
110 110 for k,v in check.iteritems():
111 111 if isinstance(v, dict):
@@ -113,14 +113,18 b' class DictDB(BaseDB):'
113 113 else:
114 114 tests[k] = lambda o, v=v: o==v  # bind v now; a late-binding closure would test every key against the last v
115 115
116 for msg_id, rec in self._records.iteritems():
116 for rec in self._records.itervalues():
117 117 if self._match_one(rec, tests):
118 matches[msg_id] = rec
119 if id_only:
120 return matches.keys()
121 else:
122 return matches
123
118 matches.append(rec)
119 return matches
120
121 def _extract_subdict(self, rec, keys):
122 """extract subdict of keys"""
123 d = {}
124 d['msg_id'] = rec['msg_id']
125 for key in keys:
126 d[key] = rec[key]
127 return d
124 128
125 129 def add_record(self, msg_id, rec):
126 130 """Add a new Task Record, by msg_id."""
@@ -140,7 +144,7 b' class DictDB(BaseDB):'
140 144
141 145 def drop_matching_records(self, check):
142 146 """Remove a record from the DB."""
143 matches = self._match(check, id_only=True)
147 matches = self._match(check)
144 148 for m in matches:
145 149 del self._records[m]
146 150
@@ -149,7 +153,28 b' class DictDB(BaseDB):'
149 153 del self._records[msg_id]
150 154
151 155
152 def find_records(self, check, id_only=False):
153 """Find records matching a query dict."""
154 matches = self._match(check, id_only)
155 return matches No newline at end of file
156 def find_records(self, check, keys=None):
157 """Find records matching a query dict, optionally extracting subset of keys.
158
159 Returns list of matching records.
160
161 Parameters
162 ----------
163
164 check: dict
165 mongodb-style query argument
166 keys: list of strs [optional]
167 if specified, the subset of keys to extract. msg_id will *always* be
168 included.
169 """
170 matches = self._match(check)
171 if keys:
172 return [ self._extract_subdict(rec, keys) for rec in matches ]
173 else:
174 return matches
175
176
177 def get_history(self):
178 """get all msg_ids, ordered by time submitted."""
179 msg_ids = self._records.keys()
180 return sorted(msg_ids, key=lambda m: self._records[m]['submitted'])
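Note the interface change: find_records now returns a list of records instead of a dict keyed by msg_id, and id_only is replaced by the more general keys argument. A sketch of adapting a caller that wants the old shapes (the dict rebuild is exactly what Hub.get_results does below):

    matches = db.find_records({'msg_id': {'$ne': ''}})  # db as in the DictDB sketch above
    records = dict((rec['msg_id'], rec) for rec in matches)  # old find_records shape
    msg_ids = [rec['msg_id'] for rec in matches]             # old id_only=True shape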
@@ -27,9 +27,8 b' from zmq.eventloop.zmqstream import ZMQStream'
27 27 from IPython.utils.importstring import import_item
28 28 from IPython.utils.traitlets import HasTraits, Instance, Int, CStr, Str, Dict, Set, List, Bool
29 29
30 from IPython.parallel import error
30 from IPython.parallel import error, util
31 31 from IPython.parallel.factory import RegistrationFactory, LoggingFactory
32 from IPython.parallel.util import select_random_ports, validate_url_container, ISO8601
33 32
34 33 from .heartmonitor import HeartMonitor
35 34
@@ -76,7 +75,7 b' def init_record(msg):'
76 75 'header' : header,
77 76 'content': msg['content'],
78 77 'buffers': msg['buffers'],
79 'submitted': datetime.strptime(header['date'], ISO8601),
78 'submitted': datetime.strptime(header['date'], util.ISO8601),
80 79 'client_uuid' : None,
81 80 'engine_uuid' : None,
82 81 'started': None,
@@ -119,32 +118,32 b' class HubFactory(RegistrationFactory):'
119 118 # port-pairs for monitoredqueues:
120 119 hb = Instance(list, config=True)
121 120 def _hb_default(self):
122 return select_random_ports(2)
121 return util.select_random_ports(2)
123 122
124 123 mux = Instance(list, config=True)
125 124 def _mux_default(self):
126 return select_random_ports(2)
125 return util.select_random_ports(2)
127 126
128 127 task = Instance(list, config=True)
129 128 def _task_default(self):
130 return select_random_ports(2)
129 return util.select_random_ports(2)
131 130
132 131 control = Instance(list, config=True)
133 132 def _control_default(self):
134 return select_random_ports(2)
133 return util.select_random_ports(2)
135 134
136 135 iopub = Instance(list, config=True)
137 136 def _iopub_default(self):
138 return select_random_ports(2)
137 return util.select_random_ports(2)
139 138
140 139 # single ports:
141 140 mon_port = Instance(int, config=True)
142 141 def _mon_port_default(self):
143 return select_random_ports(1)[0]
142 return util.select_random_ports(1)[0]
144 143
145 144 notifier_port = Instance(int, config=True)
146 145 def _notifier_port_default(self):
147 return select_random_ports(1)[0]
146 return util.select_random_ports(1)[0]
148 147
149 148 ping = Int(1000, config=True) # ping frequency
150 149
@@ -344,11 +343,11 b' class Hub(LoggingFactory):'
344 343 # validate connection dicts:
345 344 for k,v in self.client_info.iteritems():
346 345 if k == 'task':
347 validate_url_container(v[1])
346 util.validate_url_container(v[1])
348 347 else:
349 validate_url_container(v)
350 # validate_url_container(self.client_info)
351 validate_url_container(self.engine_info)
348 util.validate_url_container(v)
349 # util.validate_url_container(self.client_info)
350 util.validate_url_container(self.engine_info)
352 351
353 352 # register our callbacks
354 353 self.query.on_recv(self.dispatch_query)
@@ -369,6 +368,8 b' class Hub(LoggingFactory):'
369 368
370 369 self.query_handlers = {'queue_request': self.queue_status,
371 370 'result_request': self.get_results,
371 'history_request': self.get_history,
372 'db_request': self.db_query,
372 373 'purge_request': self.purge_results,
373 374 'load_request': self.check_load,
374 375 'resubmit_request': self.resubmit_task,
@@ -606,10 +607,10 b' class Hub(LoggingFactory):'
606 607 return
607 608 # update record anyway, because the unregistration could have been premature
608 609 rheader = msg['header']
609 completed = datetime.strptime(rheader['date'], ISO8601)
610 completed = datetime.strptime(rheader['date'], util.ISO8601)
610 611 started = rheader.get('started', None)
611 612 if started is not None:
612 started = datetime.strptime(started, ISO8601)
613 started = datetime.strptime(started, util.ISO8601)
613 614 result = {
614 615 'result_header' : rheader,
615 616 'result_content': msg['content'],
@@ -618,7 +619,10 b' class Hub(LoggingFactory):'
618 619 }
619 620
620 621 result['result_buffers'] = msg['buffers']
621 self.db.update_record(msg_id, result)
622 try:
623 self.db.update_record(msg_id, result)
624 except Exception:
625 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
622 626
623 627
624 628 #--------------------- Task Queue Traffic ------------------------------
@@ -653,6 +657,8 b' class Hub(LoggingFactory):'
653 657 self.db.update_record(msg_id, record)
654 658 except KeyError:
655 659 self.db.add_record(msg_id, record)
660 except Exception:
661 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
656 662
657 663 def save_task_result(self, idents, msg):
658 664 """save the result of a completed task."""
@@ -685,10 +691,10 b' class Hub(LoggingFactory):'
685 691 self.completed[eid].append(msg_id)
686 692 if msg_id in self.tasks[eid]:
687 693 self.tasks[eid].remove(msg_id)
688 completed = datetime.strptime(header['date'], ISO8601)
694 completed = datetime.strptime(header['date'], util.ISO8601)
689 695 started = header.get('started', None)
690 696 if started is not None:
691 started = datetime.strptime(started, ISO8601)
697 started = datetime.strptime(started, util.ISO8601)
692 698 result = {
693 699 'result_header' : header,
694 700 'result_content': msg['content'],
@@ -698,7 +704,10 b' class Hub(LoggingFactory):'
698 704 }
699 705
700 706 result['result_buffers'] = msg['buffers']
701 self.db.update_record(msg_id, result)
707 try:
708 self.db.update_record(msg_id, result)
709 except Exception:
710 self.log.error("DB Error saving task result %r"%msg_id, exc_info=True)
702 711
703 712 else:
704 713 self.log.debug("task::unknown task %s finished"%msg_id)
@@ -723,7 +732,11 b' class Hub(LoggingFactory):'
723 732
724 733 self.tasks[eid].append(msg_id)
725 734 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
726 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
735 try:
736 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
737 except Exception:
738 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
739
727 740
728 741 def mia_task_request(self, idents, msg):
729 742 raise NotImplementedError
@@ -772,7 +785,10 b' class Hub(LoggingFactory):'
772 785 else:
773 786 d[msg_type] = content.get('data', '')
774 787
775 self.db.update_record(msg_id, d)
788 try:
789 self.db.update_record(msg_id, d)
790 except Exception:
791 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
776 792
777 793
778 794
@@ -904,11 +920,15 b' class Hub(LoggingFactory):'
904 920 # build a fake header:
905 921 header = {}
906 922 header['engine'] = uuid
907 header['date'] = datetime.now().strftime(ISO8601)
923 header['date'] = datetime.now()
908 924 rec = dict(result_content=content, result_header=header, result_buffers=[])
909 925 rec['completed'] = header['date']
910 926 rec['engine_uuid'] = uuid
911 self.db.update_record(msg_id, rec)
927 try:
928 self.db.update_record(msg_id, rec)
929 except Exception:
930 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
931
912 932
913 933 def finish_registration(self, heart):
914 934 """Second half of engine registration, called after our HeartMonitor
@@ -1017,7 +1037,10 b' class Hub(LoggingFactory):'
1017 1037 msg_ids = content.get('msg_ids', [])
1018 1038 reply = dict(status='ok')
1019 1039 if msg_ids == 'all':
1020 self.db.drop_matching_records(dict(completed={'$ne':None}))
1040 try:
1041 self.db.drop_matching_records(dict(completed={'$ne':None}))
1042 except Exception:
1043 reply = error.wrap_exception()
1021 1044 else:
1022 1045 for msg_id in msg_ids:
1023 1046 if msg_id in self.all_completed:
@@ -1044,7 +1067,11 b' class Hub(LoggingFactory):'
1044 1067 break
1045 1068 msg_ids = self.completed.pop(eid)
1046 1069 uid = self.engines[eid].queue
1047 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1070 try:
1071 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1072 except Exception:
1073 reply = error.wrap_exception()
1074 break
1048 1075
1049 1076 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1050 1077
@@ -1052,6 +1079,23 b' class Hub(LoggingFactory):'
1052 1079 """Resubmit a task."""
1053 1080 raise NotImplementedError
1054 1081
1082 def _extract_record(self, rec):
1083 """decompose a TaskRecord dict into subsection of reply for get_result"""
1084 io_dict = {}
1085 for key in 'pyin pyout pyerr stdout stderr'.split():
1086 io_dict[key] = rec[key]
1087 content = { 'result_content': rec['result_content'],
1088 'header': rec['header'],
1089 'result_header' : rec['result_header'],
1090 'io' : io_dict,
1091 }
1092 if rec['result_buffers']:
1093 buffers = map(str, rec['result_buffers'])
1094 else:
1095 buffers = []
1096
1097 return content, buffers
1098
1055 1099 def get_results(self, client_id, msg):
1056 1100 """Get the result of 1 or more messages."""
1057 1101 content = msg['content']
@@ -1064,25 +1108,28 b' class Hub(LoggingFactory):'
1064 1108 content['completed'] = completed
1065 1109 buffers = []
1066 1110 if not statusonly:
1067 content['results'] = {}
1068 records = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1111 try:
1112 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1113 # turn match list into dict, for faster lookup
1114 records = {}
1115 for rec in matches:
1116 records[rec['msg_id']] = rec
1117 except Exception:
1118 content = error.wrap_exception()
1119 self.session.send(self.query, "result_reply", content=content,
1120 parent=msg, ident=client_id)
1121 return
1122 else:
1123 records = {}
1069 1124 for msg_id in msg_ids:
1070 1125 if msg_id in self.pending:
1071 1126 pending.append(msg_id)
1072 elif msg_id in self.all_completed:
1127 elif msg_id in self.all_completed or msg_id in records:
1073 1128 completed.append(msg_id)
1074 1129 if not statusonly:
1075 rec = records[msg_id]
1076 io_dict = {}
1077 for key in 'pyin pyout pyerr stdout stderr'.split():
1078 io_dict[key] = rec[key]
1079 content[msg_id] = { 'result_content': rec['result_content'],
1080 'header': rec['header'],
1081 'result_header' : rec['result_header'],
1082 'io' : io_dict,
1083 }
1084 if rec['result_buffers']:
1085 buffers.extend(map(str, rec['result_buffers']))
1130 c,bufs = self._extract_record(records[msg_id])
1131 content[msg_id] = c
1132 buffers.extend(bufs)
1086 1133 else:
1087 1134 try:
1088 1135 raise KeyError('No such message: '+msg_id)
@@ -1093,3 +1140,54 b' class Hub(LoggingFactory):'
1093 1140 parent=msg, ident=client_id,
1094 1141 buffers=buffers)
1095 1142
1143 def get_history(self, client_id, msg):
1144 """Get a list of all msg_ids in our DB records"""
1145 try:
1146 msg_ids = self.db.get_history()
1147 except Exception:
1148 content = error.wrap_exception()
1149 else:
1150 content = dict(status='ok', history=msg_ids)
1151
1152 self.session.send(self.query, "history_reply", content=content,
1153 parent=msg, ident=client_id)
1154
1155 def db_query(self, client_id, msg):
1156 """Perform a raw query on the task record database."""
1157 content = msg['content']
1158 query = content.get('query', {})
1159 keys = content.get('keys', None)
1160 query = util.extract_dates(query)
1161 buffers = []
1162 empty = list()
1163
1164 try:
1165 records = self.db.find_records(query, keys)
1166 except Exception:
1167 content = error.wrap_exception()
1168 else:
1169 # extract buffers from reply content:
1170 if keys is not None:
1171 buffer_lens = [] if 'buffers' in keys else None
1172 result_buffer_lens = [] if 'result_buffers' in keys else None
1173 else:
1174 buffer_lens = []
1175 result_buffer_lens = []
1176
1177 for rec in records:
1178 # buffers may be None, so double check
1179 if buffer_lens is not None:
1180 b = rec.pop('buffers', empty) or empty
1181 buffer_lens.append(len(b))
1182 buffers.extend(b)
1183 if result_buffer_lens is not None:
1184 rb = rec.pop('result_buffers', empty) or empty
1185 result_buffer_lens.append(len(rb))
1186 buffers.extend(rb)
1187 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1188 result_buffer_lens=result_buffer_lens)
1189
1190 self.session.send(self.query, "db_reply", content=content,
1191 parent=msg, ident=client_id,
1192 buffers=buffers)
1193
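The buffer bookkeeping in db_query is the subtle part: each record's ZMQ buffers are flattened into a single frame list, and the per-record lengths let the client re-slice them in order. A pure-Python sketch of both halves, with a hypothetical two-record payload:

    records = [{'msg_id': 'a', 'buffers': ['b0', 'b1']},
               {'msg_id': 'b', 'buffers': []}]
    # sending side (Hub.db_query): flatten, recording per-record lengths
    buffer_lens, flat = [], []
    for rec in records:
        b = rec.pop('buffers')
        buffer_lens.append(len(b))
        flat.extend(b)
    # receiving side (Client.db_query): re-link in the same order
    for i, rec in enumerate(records):
        blen = buffer_lens[i]
        rec['buffers'], flat = flat[:blen], flat[blen:]
    assert records[0]['buffers'] == ['b0', 'b1'] and records[1]['buffers'] == []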
@@ -22,9 +22,9 b' from .dictdb import BaseDB'
22 22 class MongoDB(BaseDB):
23 23 """MongoDB TaskRecord backend."""
24 24
25 connection_args = List(config=True)
26 connection_kwargs = Dict(config=True)
27 database = CUnicode(config=True)
25 connection_args = List(config=True) # args passed to pymongo.Connection
26 connection_kwargs = Dict(config=True) # kwargs passed to pymongo.Connection
27 database = CUnicode(config=True) # name of the mongodb database
28 28 _table = Dict()
29 29
30 30 def __init__(self, **kwargs):
@@ -37,13 +37,14 b' class MongoDB(BaseDB):'
37 37
38 38 def _binary_buffers(self, rec):
39 39 for key in ('buffers', 'result_buffers'):
40 if key in rec:
40 if rec.get(key, None):
41 41 rec[key] = map(Binary, rec[key])
42 return rec
42 43
43 44 def add_record(self, msg_id, rec):
44 45 """Add a new Task Record, by msg_id."""
45 46 # print rec
46 rec = _binary_buffers(rec)
47 rec = self._binary_buffers(rec)
47 48 obj_id = self._records.insert(rec)
48 49 self._table[msg_id] = obj_id
49 50
@@ -53,7 +54,7 b' class MongoDB(BaseDB):'
53 54
54 55 def update_record(self, msg_id, rec):
55 56 """Update the data in an existing record."""
56 rec = _binary_buffers(rec)
57 rec = self._binary_buffers(rec)
57 58 obj_id = self._table[msg_id]
58 59 self._records.update({'_id':obj_id}, {'$set': rec})
59 60
@@ -66,15 +67,30 b' class MongoDB(BaseDB):'
66 67 obj_id = self._table.pop(msg_id)
67 68 self._records.remove(obj_id)
68 69
69 def find_records(self, check, id_only=False):
70 """Find records matching a query dict."""
71 matches = list(self._records.find(check))
72 if id_only:
73 return [ rec['msg_id'] for rec in matches ]
74 else:
75 data = {}
76 for rec in matches:
77 data[rec['msg_id']] = rec
78 return data
70 def find_records(self, check, keys=None):
71 """Find records matching a query dict, optionally extracting subset of keys.
72
73 Returns list of matching records.
74
75 Parameters
76 ----------
77
78 check: dict
79 mongodb-style query argument
80 keys: list of strs [optional]
81 if specified, the subset of keys to extract. msg_id will *always* be
82 included.
83 """
84 if keys and 'msg_id' not in keys:
85 keys.append('msg_id')
86 matches = list(self._records.find(check,keys))
87 for rec in matches:
88 rec.pop('_id')
89 return matches
90
91 def get_history(self):
92 """get all msg_ids, ordered by time submitted."""
93 cursor = self._records.find({},{'msg_id':1}).sort('submitted')
94 return [ rec['msg_id'] for rec in cursor ]
79 95
80 96
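A hedged MongoDB-backend sketch (assumes a local mongod and pymongo installed; the database name is illustrative):

    from IPython.parallel.controller.mongodb import MongoDB

    db = MongoDB(database='iptestdb')
    # keys pass straight through as pymongo's field selection; 'msg_id' is
    # appended if missing, and mongo's internal '_id' is stripped from matches:
    finished = db.find_records({'completed': {'$ne': None}}, keys=['completed'])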
@@ -24,7 +24,7 b' from IPython.parallel.util import ISO8601'
24 24 #-----------------------------------------------------------------------------
25 25
26 26 operators = {
27 '$lt' : lambda a,b: "%s < ?",
27 '$lt' : "<",
28 28 '$gt' : ">",
29 29 # null is handled weird with ==,!=
30 30 '$eq' : "IS",
@@ -124,10 +124,11 b' class SQLiteDB(BaseDB):'
124 124 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
125 125 pc.start()
126 126
127 def _defaults(self):
127 def _defaults(self, keys=None):
128 128 """create an empty record"""
129 129 d = {}
130 for key in self._keys:
130 keys = self._keys if keys is None else keys
131 for key in keys:
131 132 d[key] = None
132 133 return d
133 134
@@ -168,9 +169,6 b' class SQLiteDB(BaseDB):'
168 169 stdout text,
169 170 stderr text)
170 171 """%self.table)
171 # self._db.execute("""CREATE TABLE IF NOT EXISTS %s_buffers
172 # (msg_id text, result integer, buffer blob)
173 # """%self.table)
174 172 self._db.commit()
175 173
176 174 def _dict_to_list(self, d):
@@ -178,10 +176,11 b' class SQLiteDB(BaseDB):'
178 176
179 177 return [ d[key] for key in self._keys ]
180 178
181 def _list_to_dict(self, line):
179 def _list_to_dict(self, line, keys=None):
182 180 """Inverse of dict_to_list"""
183 d = self._defaults()
184 for key,value in zip(self._keys, line):
181 keys = self._keys if keys is None else keys
182 d = self._defaults(keys)
183 for key,value in zip(keys, line):
185 184 d[key] = value
186 185
187 186 return d
@@ -249,13 +248,14 b' class SQLiteDB(BaseDB):'
249 248 sets.append('%s = ?'%key)
250 249 values.append(rec[key])
251 250 query += ', '.join(sets)
252 query += ' WHERE msg_id == %r'%msg_id
251 query += ' WHERE msg_id == ?'
252 values.append(msg_id)
253 253 self._db.execute(query, values)
254 254 # self._db.commit()
255 255
256 256 def drop_record(self, msg_id):
257 257 """Remove a record from the DB."""
258 self._db.execute("""DELETE FROM %s WHERE mgs_id==?"""%self.table, (msg_id,))
258 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
259 259 # self._db.commit()
260 260
261 261 def drop_matching_records(self, check):
@@ -265,20 +265,48 b' class SQLiteDB(BaseDB):'
265 265 self._db.execute(query,args)
266 266 # self._db.commit()
267 267
268 def find_records(self, check, id_only=False):
269 """Find records matching a query dict."""
270 req = 'msg_id' if id_only else '*'
268 def find_records(self, check, keys=None):
269 """Find records matching a query dict, optionally extracting subset of keys.
270
271 Returns list of matching records.
272
273 Parameters
274 ----------
275
276 check: dict
277 mongodb-style query argument
278 keys: list of strs [optional]
279 if specified, the subset of keys to extract. msg_id will *always* be
280 included.
281 """
282 if keys:
283 bad_keys = [ key for key in keys if key not in self._keys ]
284 if bad_keys:
285 raise KeyError("Bad record key(s): %s"%bad_keys)
286
287 if keys:
288 # ensure msg_id is present and first:
289 if 'msg_id' in keys:
290 keys.remove('msg_id')
291 keys.insert(0, 'msg_id')
292 req = ', '.join(keys)
293 else:
294 req = '*'
271 295 expr,args = self._render_expression(check)
272 296 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
273 297 cursor = self._db.execute(query, args)
274 298 matches = cursor.fetchall()
275 if id_only:
276 return [ m[0] for m in matches ]
277 else:
278 records = {}
279 for line in matches:
280 rec = self._list_to_dict(line)
281 records[rec['msg_id']] = rec
282 return records
299 records = []
300 for line in matches:
301 rec = self._list_to_dict(line, keys)
302 records.append(rec)
303 return records
304
305 def get_history(self):
306 """get all msg_ids, ordered by time submitted."""
307 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
308 cursor = self._db.execute(query)
309 # will be a list of length 1 tuples
310 return [ tup[0] for tup in cursor.fetchall()]
283 311
284 312 __all__ = ['SQLiteDB'] No newline at end of file
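A matching SQLite sketch, assuming empty_record (imported by the tests above) fills in the full table schema that _dict_to_list expects:

    import tempfile
    from IPython.parallel.controller.hub import empty_record
    from IPython.parallel.controller.sqlitedb import SQLiteDB

    db = SQLiteDB(location=tempfile.gettempdir())
    rec = empty_record()
    rec['msg_id'] = 'abc'
    db.add_record('abc', rec)
    # keys are validated against the schema, and the check is rendered as a
    # parameterized WHERE clause ('$ne' maps to "IS NOT", per the operators dict):
    found = db.find_records({'msg_id': {'$ne': ''}}, keys=['msg_id', 'submitted'])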
@@ -28,6 +28,7 b' from zmq.eventloop.zmqstream import ZMQStream'
28 28 from .util import ISO8601
29 29
30 30 def squash_unicode(obj):
31 """coerce unicode back to bytestrings."""
31 32 if isinstance(obj,dict):
32 33 for key in obj.keys():
33 34 obj[key] = squash_unicode(obj[key])
@@ -40,7 +41,14 b' def squash_unicode(obj):'
40 41 obj = obj.encode('utf8')
41 42 return obj
42 43
43 json_packer = jsonapi.dumps
44 def _date_default(obj):
45 if isinstance(obj, datetime):
46 return obj.strftime(ISO8601)
47 else:
48 raise TypeError("%r is not JSON serializable"%obj)
49
50 _default_key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
51 json_packer = lambda obj: jsonapi.dumps(obj, **{_default_key:_date_default})
44 52 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
45 53
46 54 pickle_packer = lambda o: pickle.dumps(o,-1)
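A sketch of what the new default hook buys, using the stdlib json module as a stand-in for jsonapi (which only dispatches to whichever json implementation pyzmq found):

    import json
    from datetime import datetime

    ISO8601 = "%Y-%m-%dT%H:%M:%S.%f"

    def _date_default(obj):
        if isinstance(obj, datetime):
            return obj.strftime(ISO8601)
        raise TypeError("%r is not JSON serializable" % obj)

    wire = json.dumps({'date': datetime.now()}, default=_date_default)
    # the receiving side turns matching strings back into datetimes
    # with util.extract_dates (see util.py below)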
@@ -17,6 +17,7 b' import re'
17 17 import stat
18 18 import socket
19 19 import sys
20 from datetime import datetime
20 21 from signal import signal, SIGINT, SIGABRT, SIGTERM
21 22 try:
22 23 from signal import SIGKILL
@@ -41,6 +42,7 b' from IPython.zmq.log import EnginePUBHandler'
41 42
42 43 # globals
43 44 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
45 ISO8601_RE=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")
44 46
45 47 #-----------------------------------------------------------------------------
46 48 # Classes
@@ -99,6 +101,18 b' class ReverseDict(dict):'
99 101 # Functions
100 102 #-----------------------------------------------------------------------------
101 103
104 def extract_dates(obj):
105 """extract ISO8601 dates from unpacked JSON"""
106 if isinstance(obj, dict):
107 for k,v in obj.iteritems():
108 obj[k] = extract_dates(v)
109 elif isinstance(obj, list):
110 obj = [ extract_dates(o) for o in obj ]
111 elif isinstance(obj, basestring):
112 if ISO8601_RE.match(obj):
113 obj = datetime.strptime(obj, ISO8601)
114 return obj
115
102 116 def validate_url(url):
103 117 """validate a url for zeromq"""
104 118 if not isinstance(url, basestring):
@@ -460,3 +474,4 b' def local_logger(logname, loglevel=logging.DEBUG):'
460 474 handler.setLevel(loglevel)
461 475 logger.addHandler(handler)
462 476 logger.setLevel(loglevel)
477
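A round-trip sketch pairing extract_dates with the date-aware packer above (import path as used by hub.py):

    from datetime import datetime
    from IPython.parallel import util

    packed = {'submitted': datetime.now().strftime(util.ISO8601),
              'note': 'non-matching strings pass through untouched'}
    recovered = util.extract_dates(packed)
    assert isinstance(recovered['submitted'], datetime)
    assert recovered['note'] == 'non-matching strings pass through untouched'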