@@ -0,0 +1,182 @@
+"""Tests for db backends"""
+
+#-------------------------------------------------------------------------------
+#  Copyright (C) 2011  The IPython Development Team
+#
+#  Distributed under the terms of the BSD License.  The full license is in
+#  the file COPYING, distributed as part of this software.
+#-------------------------------------------------------------------------------
+
+#-------------------------------------------------------------------------------
+# Imports
+#-------------------------------------------------------------------------------
+
+
+import tempfile
+import time
+
+import uuid
+
+from datetime import datetime, timedelta
+from random import choice, randint
+from unittest import TestCase
+
+from nose import SkipTest
+
+from IPython.parallel import error, streamsession as ss
+from IPython.parallel.controller.dictdb import DictDB
+from IPython.parallel.controller.sqlitedb import SQLiteDB
+from IPython.parallel.controller.hub import init_record, empty_record
+
+#-------------------------------------------------------------------------------
+# TestCases
+#-------------------------------------------------------------------------------
+
+class TestDictBackend(TestCase):
+    def setUp(self):
+        self.session = ss.StreamSession()
+        self.db = self.create_db()
+        self.load_records(16)
+
+    def create_db(self):
+        return DictDB()
+
+    def load_records(self, n=1):
+        """load n records for testing"""
+        # sleep 1/10 s, to ensure timestamp is different from previous calls
+        time.sleep(0.1)
+        msg_ids = []
+        for i in range(n):
+            msg = self.session.msg('apply_request', content=dict(a=5))
+            msg['buffers'] = []
+            rec = init_record(msg)
+            msg_ids.append(msg['msg_id'])
+            self.db.add_record(msg['msg_id'], rec)
+        return msg_ids
+
+    def test_add_record(self):
+        before = self.db.get_history()
+        self.load_records(5)
+        after = self.db.get_history()
+        self.assertEquals(len(after), len(before)+5)
+        self.assertEquals(after[:-5],before)
+
+    def test_drop_record(self):
+        msg_id = self.load_records()[-1]
+        rec = self.db.get_record(msg_id)
+        self.db.drop_record(msg_id)
+        self.assertRaises(KeyError,self.db.get_record, msg_id)
+
+    def _round_to_millisecond(self, dt):
+        """necessary because mongodb rounds microseconds"""
+        micro = dt.microsecond
+        extra = int(str(micro)[-3:])
+        return dt - timedelta(microseconds=extra)
+
+    def test_update_record(self):
+        now = self._round_to_millisecond(datetime.now())
+        #
+        msg_id = self.db.get_history()[-1]
+        rec1 = self.db.get_record(msg_id)
+        data = {'stdout': 'hello there', 'completed' : now}
+        self.db.update_record(msg_id, data)
+        rec2 = self.db.get_record(msg_id)
+        self.assertEquals(rec2['stdout'], 'hello there')
+        self.assertEquals(rec2['completed'], now)
+        rec1.update(data)
+        self.assertEquals(rec1, rec2)
+
+    # def test_update_record_bad(self):
+    #     """test updating nonexistent records"""
+    #     msg_id = str(uuid.uuid4())
+    #     data = {'stdout': 'hello there'}
+    #     self.assertRaises(KeyError, self.db.update_record, msg_id, data)
+
+    def test_find_records_dt(self):
+        """test finding records by date"""
+        hist = self.db.get_history()
+        middle = self.db.get_record(hist[len(hist)/2])
+        tic = middle['submitted']
+        before = self.db.find_records({'submitted' : {'$lt' : tic}})
+        after = self.db.find_records({'submitted' : {'$gte' : tic}})
+        self.assertEquals(len(before)+len(after),len(hist))
+        for b in before:
+            self.assertTrue(b['submitted'] < tic)
+        for a in after:
+            self.assertTrue(a['submitted'] >= tic)
+        same = self.db.find_records({'submitted' : tic})
+        for s in same:
+            self.assertTrue(s['submitted'] == tic)
+
+    def test_find_records_keys(self):
+        """test extracting subset of record keys"""
+        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
+        for rec in found:
+            self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
+
+    def test_find_records_msg_id(self):
+        """ensure msg_id is always in found records"""
+        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
+        for rec in found:
+            self.assertTrue('msg_id' in rec.keys())
+        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
+        for rec in found:
+            self.assertTrue('msg_id' in rec.keys())
+        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
+        for rec in found:
+            self.assertTrue('msg_id' in rec.keys())
+
+    def test_find_records_in(self):
+        """test finding records with '$in','$nin' operators"""
+        hist = self.db.get_history()
+        even = hist[::2]
+        odd = hist[1::2]
+        recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
+        found = [ r['msg_id'] for r in recs ]
+        self.assertEquals(set(even), set(found))
+        recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
+        found = [ r['msg_id'] for r in recs ]
+        self.assertEquals(set(odd), set(found))
+
+    def test_get_history(self):
+        msg_ids = self.db.get_history()
+        latest = datetime(1984,1,1)
+        for msg_id in msg_ids:
+            rec = self.db.get_record(msg_id)
+            newt = rec['submitted']
+            self.assertTrue(newt >= latest)
+            latest = newt
+        msg_id = self.load_records(1)[-1]
+        self.assertEquals(self.db.get_history()[-1],msg_id)
+
+    def test_datetime(self):
+        """get/set timestamps with datetime objects"""
+        msg_id = self.db.get_history()[-1]
+        rec = self.db.get_record(msg_id)
+        self.assertTrue(isinstance(rec['submitted'], datetime))
+        self.db.update_record(msg_id, dict(completed=datetime.now()))
+        rec = self.db.get_record(msg_id)
+        self.assertTrue(isinstance(rec['completed'], datetime))
+
+class TestSQLiteBackend(TestDictBackend):
+    def create_db(self):
+        return SQLiteDB(location=tempfile.gettempdir())
+
+    def tearDown(self):
+        self.db._db.close()
+
+# optional MongoDB test
+try:
+    from IPython.parallel.controller.mongodb import MongoDB
+except ImportError:
+    pass
+else:
+    class TestMongoBackend(TestDictBackend):
+        def create_db(self):
+            try:
+                return MongoDB(database='iptestdb')
+            except Exception:
+                raise SkipTest("Couldn't connect to mongodb instance")
+
+        def tearDown(self):
+            self.db._connection.drop_database('iptestdb')
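
Note on the test design: backends are parametrized purely by subclassing and overriding create_db(), so the whole suite can be pointed at any compatible backend. A sketch of extending it (class name and import are hypothetical, for illustration only):

    # hypothetical: run the full suite against some other backend
    class TestMyBackend(TestDictBackend):
        def create_db(self):
            from mypackage.mydb import MyDB   # hypothetical backend
            return MyDB()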
@@ -1215,5 +1215,78 @@ class Client(HasTraits):
         if content['status'] != 'ok':
             raise self._unwrap_exception(content)
 
+    @spin_first
+    def hub_history(self):
+        """Get the Hub's history
+        
+        Just like the Client, the Hub has a history, which is a list of msg_ids.
+        This will contain the history of all clients, and, depending on configuration,
+        may contain history across multiple cluster sessions.
+        
+        Any msg_id returned here is a valid argument to `get_result`.
+        
+        Returns
+        -------
+        
+        msg_ids : list of strs
+            list of all msg_ids, ordered by task submission time.
+        """
+        
+        self.session.send(self._query_socket, "history_request", content={})
+        idents, msg = self.session.recv(self._query_socket, 0)
+        
+        if self.debug:
+            pprint(msg)
+        content = msg['content']
+        if content['status'] != 'ok':
+            raise self._unwrap_exception(content)
+        else:
+            return content['history']
+    
+    @spin_first
+    def db_query(self, query, keys=None):
+        """Query the Hub's TaskRecord database
+        
+        This will return a list of task record dicts that match `query`
+        
+        Parameters
+        ----------
+        
+        query : mongodb query dict
+            The search dict. See mongodb query docs for details.
+        keys : list of strs [optional]
+            The subset of keys to be returned.  The default is to fetch everything.
+            'msg_id' will *always* be included.
+        """
+        content = dict(query=query, keys=keys)
+        self.session.send(self._query_socket, "db_request", content=content)
+        idents, msg = self.session.recv(self._query_socket, 0)
+        if self.debug:
+            pprint(msg)
+        content = msg['content']
+        if content['status'] != 'ok':
+            raise self._unwrap_exception(content)
+        
+        records = content['records']
+        buffer_lens = content['buffer_lens']
+        result_buffer_lens = content['result_buffer_lens']
+        buffers = msg['buffers']
+        has_bufs = buffer_lens is not None
+        has_rbufs = result_buffer_lens is not None
+        for i,rec in enumerate(records):
+            # relink buffers
+            if has_bufs:
+                blen = buffer_lens[i]
+                rec['buffers'], buffers = buffers[:blen],buffers[blen:]
+            if has_rbufs:
+                blen = result_buffer_lens[i]
+                rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
+            # turn timestamps back into times
+            for key in 'submitted started completed resubmitted'.split():
+                maybedate = rec.get(key, None)
+                if maybedate and util.ISO8601_RE.match(maybedate):
+                    rec[key] = datetime.strptime(maybedate, util.ISO8601)
+        
+        return records
 
 __all__ = [ 'Client' ]
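
Note: a quick usage sketch of the two new client methods, assuming a running cluster with default connection settings (the time-window query itself is illustrative, not part of this diff):

    from datetime import datetime, timedelta
    from IPython.parallel import Client

    rc = Client()                 # assumes an ipcluster is already up

    all_ids = rc.hub_history()    # every msg_id the Hub knows, oldest first

    # records completed in the last hour; restricting `keys` keeps the
    # reply small, and 'msg_id' is always included in each record
    cutoff = datetime.now() - timedelta(hours=1)
    recs = rc.db_query({'completed': {'$gte': cutoff}},
                       keys=['started', 'completed'])
    durations = [ rec['completed'] - rec['started'] for rec in recs ]

The datetime in the query survives the wire because the json packer now serializes dates (streamsession.py below) and the Hub calls util.extract_dates on the query dict before handing it to the backend.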
@@ -103,9 +103,9 @@ class DictDB(BaseDB):
                 return False
         return True
 
-    def _match(self, check, id_only=True):
+    def _match(self, check):
         """Find all the matches for a check dict."""
-        matches = {}
+        matches = []
         tests = {}
         for k,v in check.iteritems():
             if isinstance(v, dict):
@@ -113,14 +113,18 @@ class DictDB(BaseDB):
             else:
                 tests[k] = lambda o: o==v
 
-        for msg_id, rec in self._records.iteritems():
+        for rec in self._records.itervalues():
             if self._match_one(rec, tests):
-                matches[msg_id] = rec
-        if id_only:
-            return matches.keys()
-        else:
-            return matches
-
+                matches.append(rec)
+        return matches
+
+    def _extract_subdict(self, rec, keys):
+        """extract subdict of keys"""
+        d = {}
+        d['msg_id'] = rec['msg_id']
+        for key in keys:
+            d[key] = rec[key]
+        return d
 
     def add_record(self, msg_id, rec):
         """Add a new Task Record, by msg_id."""
@@ -140,7 +144,7 @@ class DictDB(BaseDB):
 
     def drop_matching_records(self, check):
         """Remove a record from the DB."""
-        matches = self._match(check, id_only=True)
+        matches = self._match(check)
         for m in matches:
             del self._records[m]
 
@@ -149,7 +153,28 @@ class DictDB(BaseDB):
         del self._records[msg_id]
 
 
-    def find_records(self, check, id_only=False):
-        """Find records matching a query dict."""
-        matches = self._match(check, id_only)
-        return matches
\ No newline at end of file
+    def find_records(self, check, keys=None):
+        """Find records matching a query dict, optionally extracting subset of keys.
+        
+        Returns list of matching records.
+        
+        Parameters
+        ----------
+        
+        check: dict
+            mongodb-style query argument
+        keys: list of strs [optional]
+            if specified, the subset of keys to extract.  msg_id will *always* be
+            included.
+        """
+        matches = self._match(check)
+        if keys:
+            return [ self._extract_subdict(rec, keys) for rec in matches ]
+        else:
+            return matches
+    
+    
+    def get_history(self):
+        """get all msg_ids, ordered by time submitted."""
+        msg_ids = self._records.keys()
+        return sorted(msg_ids, key=lambda m: self._records[m]['submitted'])
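
Note: a minimal sketch of the new find_records/get_history contract against the in-memory backend (record contents depend on what was added via add_record):

    from IPython.parallel.controller.dictdb import DictDB

    db = DictDB()
    # ... records added via db.add_record(msg_id, rec) ...

    # full record dicts for every task that matches the check
    matches = db.find_records({'msg_id': {'$ne': ''}})

    # same query, but each dict is trimmed to the requested keys;
    # 'msg_id' is always carried along by _extract_subdict
    timings = db.find_records({'msg_id': {'$ne': ''}},
                              keys=['submitted', 'completed'])

    # all msg_ids, sorted by the 'submitted' timestamp
    history = db.get_history()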
@@ -27,9 +27,8 @@ from zmq.eventloop.zmqstream import ZMQStream
 from IPython.utils.importstring import import_item
 from IPython.utils.traitlets import HasTraits, Instance, Int, CStr, Str, Dict, Set, List, Bool
 
-from IPython.parallel import error
+from IPython.parallel import error, util
 from IPython.parallel.factory import RegistrationFactory, LoggingFactory
-from IPython.parallel.util import select_random_ports, validate_url_container, ISO8601
 
 from .heartmonitor import HeartMonitor
 
@@ -76,7 +75,7 @@ def init_record(msg):
         'header' : header,
         'content': msg['content'],
         'buffers': msg['buffers'],
-        'submitted': datetime.strptime(header['date'], ISO8601),
+        'submitted': datetime.strptime(header['date'], util.ISO8601),
         'client_uuid' : None,
         'engine_uuid' : None,
         'started': None,
@@ -119,32 +118,32 @@ class HubFactory(RegistrationFactory):
     # port-pairs for monitoredqueues:
     hb = Instance(list, config=True)
     def _hb_default(self):
-        return select_random_ports(2)
+        return util.select_random_ports(2)
 
     mux = Instance(list, config=True)
     def _mux_default(self):
-        return select_random_ports(2)
+        return util.select_random_ports(2)
 
     task = Instance(list, config=True)
     def _task_default(self):
-        return select_random_ports(2)
+        return util.select_random_ports(2)
 
     control = Instance(list, config=True)
     def _control_default(self):
-        return select_random_ports(2)
+        return util.select_random_ports(2)
 
     iopub = Instance(list, config=True)
     def _iopub_default(self):
-        return select_random_ports(2)
+        return util.select_random_ports(2)
 
     # single ports:
     mon_port = Instance(int, config=True)
     def _mon_port_default(self):
-        return select_random_ports(1)[0]
+        return util.select_random_ports(1)[0]
 
     notifier_port = Instance(int, config=True)
     def _notifier_port_default(self):
-        return select_random_ports(1)[0]
+        return util.select_random_ports(1)[0]
 
     ping = Int(1000, config=True) # ping frequency
 
@@ -344,11 +343,11 @@ class Hub(LoggingFactory):
         # validate connection dicts:
         for k,v in self.client_info.iteritems():
             if k == 'task':
-                validate_url_container(v[1])
+                util.validate_url_container(v[1])
             else:
-                validate_url_container(v)
-        # validate_url_container(self.client_info)
-        validate_url_container(self.engine_info)
+                util.validate_url_container(v)
+        # util.validate_url_container(self.client_info)
+        util.validate_url_container(self.engine_info)
 
         # register our callbacks
         self.query.on_recv(self.dispatch_query)
@@ -369,6 +368,8 @@ class Hub(LoggingFactory):
 
         self.query_handlers = {'queue_request': self.queue_status,
                                 'result_request': self.get_results,
+                                'history_request': self.get_history,
+                                'db_request': self.db_query,
                                 'purge_request': self.purge_results,
                                 'load_request': self.check_load,
                                 'resubmit_request': self.resubmit_task,
@@ -606,10 +607,10 @@ class Hub(LoggingFactory):
             return
         # update record anyway, because the unregistration could have been premature
         rheader = msg['header']
-        completed = datetime.strptime(rheader['date'], ISO8601)
+        completed = datetime.strptime(rheader['date'], util.ISO8601)
         started = rheader.get('started', None)
         if started is not None:
-            started = datetime.strptime(started, ISO8601)
+            started = datetime.strptime(started, util.ISO8601)
         result = {
             'result_header' : rheader,
             'result_content': msg['content'],
@@ -618,7 +619,10 @@ class Hub(LoggingFactory):
         }
 
         result['result_buffers'] = msg['buffers']
-        self.db.update_record(msg_id, result)
+        try:
+            self.db.update_record(msg_id, result)
+        except Exception:
+            self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
 
 
     #--------------------- Task Queue Traffic ------------------------------
@@ -653,6 +657,8 @@ class Hub(LoggingFactory):
             self.db.update_record(msg_id, record)
         except KeyError:
             self.db.add_record(msg_id, record)
+        except Exception:
+            self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
 
     def save_task_result(self, idents, msg):
         """save the result of a completed task."""
@@ -685,10 +691,10 @@ class Hub(LoggingFactory):
             self.completed[eid].append(msg_id)
             if msg_id in self.tasks[eid]:
                 self.tasks[eid].remove(msg_id)
-            completed = datetime.strptime(header['date'], ISO8601)
+            completed = datetime.strptime(header['date'], util.ISO8601)
             started = header.get('started', None)
             if started is not None:
-                started = datetime.strptime(started, ISO8601)
+                started = datetime.strptime(started, util.ISO8601)
             result = {
                 'result_header' : header,
                 'result_content': msg['content'],
@@ -698,7 +704,10 @@ class Hub(LoggingFactory):
             }
 
             result['result_buffers'] = msg['buffers']
-            self.db.update_record(msg_id, result)
+            try:
+                self.db.update_record(msg_id, result)
+            except Exception:
+                self.log.error("DB Error saving task result %r"%msg_id, exc_info=True)
 
         else:
             self.log.debug("task::unknown task %s finished"%msg_id)
@@ -723,7 +732,11 @@ class Hub(LoggingFactory):
 
         self.tasks[eid].append(msg_id)
         # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
-        self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
+        try:
+            self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
+        except Exception:
+            self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
+
 
     def mia_task_request(self, idents, msg):
         raise NotImplementedError
@@ -772,7 +785,10 @@ class Hub(LoggingFactory):
         else:
             d[msg_type] = content.get('data', '')
 
-        self.db.update_record(msg_id, d)
+        try:
+            self.db.update_record(msg_id, d)
+        except Exception:
+            self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
 
 
 
@@ -904,11 +920,15 @@ class Hub(LoggingFactory):
             # build a fake header:
             header = {}
             header['engine'] = uuid
-            header['date'] = datetime.now().strftime(ISO8601)
+            header['date'] = datetime.now()
             rec = dict(result_content=content, result_header=header, result_buffers=[])
             rec['completed'] = header['date']
             rec['engine_uuid'] = uuid
-            self.db.update_record(msg_id, rec)
+            try:
+                self.db.update_record(msg_id, rec)
+            except Exception:
+                self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
+
 
     def finish_registration(self, heart):
         """Second half of engine registration, called after our HeartMonitor
@@ -1017,7 +1037,10 @@ class Hub(LoggingFactory):
         msg_ids = content.get('msg_ids', [])
         reply = dict(status='ok')
         if msg_ids == 'all':
-            self.db.drop_matching_records(dict(completed={'$ne':None}))
+            try:
+                self.db.drop_matching_records(dict(completed={'$ne':None}))
+            except Exception:
+                reply = error.wrap_exception()
         else:
             for msg_id in msg_ids:
                 if msg_id in self.all_completed:
@@ -1044,7 +1067,11 @@ class Hub(LoggingFactory):
                     break
                 msg_ids = self.completed.pop(eid)
                 uid = self.engines[eid].queue
-                self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
+                try:
+                    self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
+                except Exception:
+                    reply = error.wrap_exception()
+                    break
 
         self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
 
@@ -1052,6 +1079,23 @@ class Hub(LoggingFactory):
         """Resubmit a task."""
         raise NotImplementedError
 
+    def _extract_record(self, rec):
+        """decompose a TaskRecord dict into subsection of reply for get_result"""
+        io_dict = {}
+        for key in 'pyin pyout pyerr stdout stderr'.split():
+            io_dict[key] = rec[key]
+        content = { 'result_content': rec['result_content'],
+                    'header': rec['header'],
+                    'result_header' : rec['result_header'],
+                    'io' : io_dict,
+                  }
+        if rec['result_buffers']:
+            buffers = map(str, rec['result_buffers'])
+        else:
+            buffers = []
+        
+        return content, buffers
+    
     def get_results(self, client_id, msg):
         """Get the result of 1 or more messages."""
         content = msg['content']
@@ -1064,25 +1108,28 @@ class Hub(LoggingFactory):
         content['completed'] = completed
         buffers = []
         if not statusonly:
-            content['results'] = {}
-            records = self.db.find_records(dict(msg_id={'$in':msg_ids}))
+            try:
+                matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
+                # turn match list into dict, for faster lookup
+                records = {}
+                for rec in matches:
+                    records[rec['msg_id']] = rec
+            except Exception:
+                content = error.wrap_exception()
+                self.session.send(self.query, "result_reply", content=content,
+                                    parent=msg, ident=client_id)
+                return
+        else:
+            records = {}
         for msg_id in msg_ids:
             if msg_id in self.pending:
                 pending.append(msg_id)
-            elif msg_id in self.all_completed:
+            elif msg_id in self.all_completed or msg_id in records:
                 completed.append(msg_id)
                 if not statusonly:
-                    rec = records[msg_id]
-                    io_dict = {}
-                    for key in 'pyin pyout pyerr stdout stderr'.split():
-                        io_dict[key] = rec[key]
-                    content[msg_id] = { 'result_content': rec['result_content'],
-                                        'header': rec['header'],
-                                        'result_header' : rec['result_header'],
-                                        'io' : io_dict,
-                                      }
-                    if rec['result_buffers']:
-                        buffers.extend(map(str, rec['result_buffers']))
+                    c,bufs = self._extract_record(records[msg_id])
+                    content[msg_id] = c
+                    buffers.extend(bufs)
             else:
                 try:
                     raise KeyError('No such message: '+msg_id)
@@ -1093,3 +1140,54 @@ class Hub(LoggingFactory):
                                         parent=msg, ident=client_id,
                                         buffers=buffers)
 
+    def get_history(self, client_id, msg):
+        """Get a list of all msg_ids in our DB records"""
+        try:
+            msg_ids = self.db.get_history()
+        except Exception as e:
+            content = error.wrap_exception()
+        else:
+            content = dict(status='ok', history=msg_ids)
+        
+        self.session.send(self.query, "history_reply", content=content,
+                          parent=msg, ident=client_id)
+    
+    def db_query(self, client_id, msg):
+        """Perform a raw query on the task record database."""
+        content = msg['content']
+        query = content.get('query', {})
+        keys = content.get('keys', None)
+        query = util.extract_dates(query)
+        buffers = []
+        empty = list()
+        
+        try:
+            records = self.db.find_records(query, keys)
+        except Exception as e:
+            content = error.wrap_exception()
+        else:
+            # extract buffers from reply content:
+            if keys is not None:
+                buffer_lens = [] if 'buffers' in keys else None
+                result_buffer_lens = [] if 'result_buffers' in keys else None
+            else:
+                buffer_lens = []
+                result_buffer_lens = []
+            
+            for rec in records:
+                # buffers may be None, so double check
+                if buffer_lens is not None:
+                    b = rec.pop('buffers', empty) or empty
+                    buffer_lens.append(len(b))
+                    buffers.extend(b)
+                if result_buffer_lens is not None:
+                    rb = rec.pop('result_buffers', empty) or empty
+                    result_buffer_lens.append(len(rb))
+                    buffers.extend(rb)
+            content = dict(status='ok', records=records, buffer_lens=buffer_lens,
+                           result_buffer_lens=result_buffer_lens)
+        
+        self.session.send(self.query, "db_reply", content=content,
+                          parent=msg, ident=client_id,
+                          buffers=buffers)
+
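
Note: the buffer framing in db_query above is the half that Client.db_query inverts. A self-contained sketch of the flatten/relink round trip, with plain lists standing in for zmq message frames (the data values are made up for illustration):

    # hub side: pop each record's buffer list into one flat frame list,
    # remembering how many frames each record contributed
    records = [{'msg_id': 'a', 'buffers': ['b0', 'b1']},
               {'msg_id': 'b', 'buffers': []}]
    buffer_lens, buffers = [], []
    for rec in records:
        b = rec.pop('buffers', []) or []
        buffer_lens.append(len(b))
        buffers.extend(b)

    # client side: slice the flat frame list back onto each record
    for i, rec in enumerate(records):
        blen = buffer_lens[i]
        rec['buffers'], buffers = buffers[:blen], buffers[blen:]

    assert records[0]['buffers'] == ['b0', 'b1']
    assert records[1]['buffers'] == []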
@@ -22,9 +22,9 @@ from .dictdb import BaseDB
 class MongoDB(BaseDB):
     """MongoDB TaskRecord backend."""
 
-    connection_args = List(config=True)
-    connection_kwargs = Dict(config=True)
-    database = CUnicode(config=True)
+    connection_args = List(config=True) # args passed to pymongo.Connection
+    connection_kwargs = Dict(config=True) # kwargs passed to pymongo.Connection
+    database = CUnicode(config=True) # name of the mongodb database
     _table = Dict()
 
     def __init__(self, **kwargs):
@@ -37,13 +37,14 @@ class MongoDB(BaseDB):
 
     def _binary_buffers(self, rec):
         for key in ('buffers', 'result_buffers'):
-            if key in rec:
+            if rec.get(key, None):
                 rec[key] = map(Binary, rec[key])
+        return rec
 
     def add_record(self, msg_id, rec):
         """Add a new Task Record, by msg_id."""
         # print rec
-        rec = _binary_buffers(rec)
+        rec = self._binary_buffers(rec)
         obj_id = self._records.insert(rec)
         self._table[msg_id] = obj_id
 
53 |
|
54 | |||
54 | def update_record(self, msg_id, rec): |
|
55 | def update_record(self, msg_id, rec): | |
55 | """Update the data in an existing record.""" |
|
56 | """Update the data in an existing record.""" | |
56 | rec = _binary_buffers(rec) |
|
57 | rec = self._binary_buffers(rec) | |
57 | obj_id = self._table[msg_id] |
|
58 | obj_id = self._table[msg_id] | |
58 | self._records.update({'_id':obj_id}, {'$set': rec}) |
|
59 | self._records.update({'_id':obj_id}, {'$set': rec}) | |
59 |
|
60 | |||
@@ -66,15 +67,30 @@ class MongoDB(BaseDB):
         obj_id = self._table.pop(msg_id)
         self._records.remove(obj_id)
 
-    def find_records(self, check, id_only=False):
-        """Find records matching a query dict."""
-        matches = list(self._records.find(check))
-        if id_only:
-            return [ rec['msg_id'] for rec in matches ]
-        else:
-            data = {}
-            for rec in matches:
-                data[rec['msg_id']] = rec
-            return data
+    def find_records(self, check, keys=None):
+        """Find records matching a query dict, optionally extracting subset of keys.
+        
+        Returns list of matching records.
+        
+        Parameters
+        ----------
+        
+        check: dict
+            mongodb-style query argument
+        keys: list of strs [optional]
+            if specified, the subset of keys to extract.  msg_id will *always* be
+            included.
+        """
+        if keys and 'msg_id' not in keys:
+            keys.append('msg_id')
+        matches = list(self._records.find(check,keys))
+        for rec in matches:
+            rec.pop('_id')
+        return matches
+    
+    def get_history(self):
+        """get all msg_ids, ordered by time submitted."""
+        cursor = self._records.find({},{'msg_id':1}).sort('submitted')
+        return [ rec['msg_id'] for rec in cursor ]
 
 
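
Note: the new rec.get(key, None) guard in _binary_buffers matters because records routinely carry buffers=None or an empty list, and map(Binary, None) would raise. A standalone sketch of the guard's behavior, with pymongo's Binary stubbed by str since only the truthiness check is at issue:

    def _binary_buffers(rec, Binary=str):  # Binary stubbed for illustration
        for key in ('buffers', 'result_buffers'):
            if rec.get(key, None):   # skips missing keys, None, and []
                rec[key] = map(Binary, rec[key])
        return rec

    assert _binary_buffers({'buffers': None})['buffers'] is None
    assert _binary_buffers({'result_buffers': ['x']})['result_buffers'] == ['x']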
@@ -24,7 +24,7 @@ from IPython.parallel.util import ISO8601
 #-----------------------------------------------------------------------------
 
 operators = {
- '$lt' :
+ '$lt' : "<",
 '$gt' : ">",
 # null is handled weird with ==,!=
 '$eq' : "IS",
@@ -124,10 +124,11 @@ class SQLiteDB(BaseDB):
         pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
         pc.start()
 
-    def _defaults(self):
+    def _defaults(self, keys=None):
         """create an empty record"""
         d = {}
-        for key in self._keys:
+        keys = self._keys if keys is None else keys
+        for key in keys:
             d[key] = None
         return d
 
@@ -168,9 +169,6 @@ class SQLiteDB(BaseDB):
                 stdout text,
                 stderr text)
                 """%self.table)
-        # self._db.execute("""CREATE TABLE IF NOT EXISTS %s_buffers
-        #             (msg_id text, result integer, buffer blob)
-        #             """%self.table)
         self._db.commit()
 
     def _dict_to_list(self, d):
@@ -178,10 +176,11 @@ class SQLiteDB(BaseDB):
 
         return [ d[key] for key in self._keys ]
 
-    def _list_to_dict(self, line):
+    def _list_to_dict(self, line, keys=None):
         """Inverse of dict_to_list"""
-        d = self._defaults()
-        for key,value in zip(self._keys, line):
+        keys = self._keys if keys is None else keys
+        d = self._defaults(keys)
+        for key,value in zip(keys, line):
             d[key] = value
 
         return d
@@ -249,13 +248,14 @@ class SQLiteDB(BaseDB):
             sets.append('%s = ?'%key)
             values.append(rec[key])
         query += ', '.join(sets)
-        query += ' WHERE msg_id == %r'%msg_id
+        query += ' WHERE msg_id == ?'
+        values.append(msg_id)
         self._db.execute(query, values)
         # self._db.commit()
 
     def drop_record(self, msg_id):
         """Remove a record from the DB."""
-        self._db.execute("""DELETE FROM %s WHERE msg_id==%r"""%(self.table, msg_id))
+        self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
         # self._db.commit()
 
     def drop_matching_records(self, check):
@@ -265,20 +265,48 @@ class SQLiteDB(BaseDB):
         self._db.execute(query,args)
         # self._db.commit()
 
-    def find_records(self, check, id_only=False):
-        """Find records matching a query dict."""
-        req = 'msg_id' if id_only else '*'
+    def find_records(self, check, keys=None):
+        """Find records matching a query dict, optionally extracting subset of keys.
+        
+        Returns list of matching records.
+        
+        Parameters
+        ----------
+        
+        check: dict
+            mongodb-style query argument
+        keys: list of strs [optional]
+            if specified, the subset of keys to extract.  msg_id will *always* be
+            included.
+        """
+        if keys:
+            bad_keys = [ key for key in keys if key not in self._keys ]
+            if bad_keys:
+                raise KeyError("Bad record key(s): %s"%bad_keys)
+        
+        if keys:
+            # ensure msg_id is present and first:
+            if 'msg_id' in keys:
+                keys.remove('msg_id')
+            keys.insert(0, 'msg_id')
+            req = ', '.join(keys)
+        else:
+            req = '*'
         expr,args = self._render_expression(check)
         query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
         cursor = self._db.execute(query, args)
         matches = cursor.fetchall()
-        if id_only:
-            return [ m[0] for m in matches ]
-        else:
-            records = {}
-            for line in matches:
-                rec = self._list_to_dict(line)
-                records[rec['msg_id']] = rec
-            return records
+        records = []
+        for line in matches:
+            rec = self._list_to_dict(line, keys)
+            records.append(rec)
+        return records
+    
+    def get_history(self):
+        """get all msg_ids, ordered by time submitted."""
+        query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
+        cursor = self._db.execute(query)
+        # will be a list of length 1 tuples
+        return [ tup[0] for tup in cursor.fetchall()]
 
-__all__ = ['SQLiteDB']
\ No newline at end of file
+__all__ = ['SQLiteDB']
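
Note: replacing %r interpolation with bound ? parameters in update_record/drop_record is the right fix; interpolation breaks on msg_ids containing quotes and sidesteps sqlite3's escaping. Table names cannot be bound, which is why %s interpolation of self.table remains. A minimal standalone illustration:

    import sqlite3

    db = sqlite3.connect(':memory:')
    db.execute('CREATE TABLE tasks (msg_id text, submitted text)')
    db.execute('INSERT INTO tasks VALUES (?, ?)', ('abc', '2011-01-01'))

    table = 'tasks'   # identifiers are interpolated, values are bound
    query = 'SELECT msg_id FROM %s WHERE msg_id == ?' % table
    assert db.execute(query, ('abc',)).fetchall() == [('abc',)]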
@@ -28,6 +28,7 @@ from zmq.eventloop.zmqstream import ZMQStream
 from .util import ISO8601
 
 def squash_unicode(obj):
+    """coerce unicode back to bytestrings."""
     if isinstance(obj,dict):
         for key in obj.keys():
             obj[key] = squash_unicode(obj[key])
@@ -40,7 +41,14 @@ def squash_unicode(obj):
         obj = obj.encode('utf8')
     return obj
 
-json_packer = jsonapi.dumps
+def _date_default(obj):
+    if isinstance(obj, datetime):
+        return obj.strftime(ISO8601)
+    else:
+        raise TypeError("%r is not JSON serializable"%obj)
+
+_default_key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
+json_packer = lambda obj: jsonapi.dumps(obj, **{_default_key:_date_default})
 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
 
 pickle_packer = lambda o: pickle.dumps(o,-1)
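
Note: with the default-handler hookup above, datetime objects survive packing as ISO8601 strings that extract_dates (added to util.py below) can recognize and reverse. A sketch of the round trip, using the stdlib json module in place of jsonapi:

    import json, re
    from datetime import datetime

    ISO8601 = "%Y-%m-%dT%H:%M:%S.%f"
    ISO8601_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")

    def _date_default(obj):
        if isinstance(obj, datetime):
            return obj.strftime(ISO8601)
        raise TypeError("%r is not JSON serializable" % obj)

    packed = json.dumps({'submitted': datetime.now()}, default=_date_default)
    unpacked = json.loads(packed)
    # the string form is recoverable by util.extract_dates
    assert ISO8601_RE.match(unpacked['submitted'])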
@@ -17,6 +17,7 @@ import re
 import stat
 import socket
 import sys
+from datetime import datetime
 from signal import signal, SIGINT, SIGABRT, SIGTERM
 try:
     from signal import SIGKILL
@@ -41,6 +42,7 @@ from IPython.zmq.log import EnginePUBHandler
 
 # globals
 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
+ISO8601_RE=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")
 
 #-----------------------------------------------------------------------------
 # Classes
@@ -99,6 +101,18 @@ class ReverseDict(dict):
 # Functions
 #-----------------------------------------------------------------------------
 
+def extract_dates(obj):
+    """extract ISO8601 dates from unpacked JSON"""
+    if isinstance(obj, dict):
+        for k,v in obj.iteritems():
+            obj[k] = extract_dates(v)
+    elif isinstance(obj, list):
+        obj = [ extract_dates(o) for o in obj ]
+    elif isinstance(obj, basestring):
+        if ISO8601_RE.match(obj):
+            obj = datetime.strptime(obj, ISO8601)
+    return obj
+
 def validate_url(url):
     """validate a url for zeromq"""
     if not isinstance(url, basestring):
@@ -460,3 +474,4 @@ def local_logger(logname, loglevel=logging.DEBUG):
     handler.setLevel(loglevel)
     logger.addHandler(handler)
     logger.setLevel(loglevel)
+
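
Note: extract_dates recurses through whatever arrives in a db_request query dict, so datetime-valued comparisons round-trip transparently. A quick illustration (the cutoff string is made up):

    from datetime import datetime
    from IPython.parallel.util import extract_dates

    query = {'completed': {'$gte': '2011-02-01T12:00:00.000000'}}
    query = extract_dates(query)
    assert isinstance(query['completed']['$gte'], datetime)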