@@ -0,0 +1,182 @@
+"""Tests for db backends"""
+
+#-------------------------------------------------------------------------------
+#  Copyright (C) 2011  The IPython Development Team
+#
+#  Distributed under the terms of the BSD License.  The full license is in
+#  the file COPYING, distributed as part of this software.
+#-------------------------------------------------------------------------------
+
+#-------------------------------------------------------------------------------
+# Imports
+#-------------------------------------------------------------------------------
+
+
+import tempfile
+import time
+
+import uuid
+
+from datetime import datetime, timedelta
+from random import choice, randint
+from unittest import TestCase
+
+from nose import SkipTest
+
+from IPython.parallel import error, streamsession as ss
+from IPython.parallel.controller.dictdb import DictDB
+from IPython.parallel.controller.sqlitedb import SQLiteDB
+from IPython.parallel.controller.hub import init_record, empty_record
+
+#-------------------------------------------------------------------------------
+# TestCases
+#-------------------------------------------------------------------------------
+
+class TestDictBackend(TestCase):
+    def setUp(self):
+        self.session = ss.StreamSession()
+        self.db = self.create_db()
+        self.load_records(16)
+
+    def create_db(self):
+        return DictDB()
+
+    def load_records(self, n=1):
+        """load n records for testing"""
+        # sleep 1/10 s, to ensure timestamp differs from previous calls
+        time.sleep(0.1)
+        msg_ids = []
+        for i in range(n):
+            msg = self.session.msg('apply_request', content=dict(a=5))
+            msg['buffers'] = []
+            rec = init_record(msg)
+            msg_ids.append(msg['msg_id'])
+            self.db.add_record(msg['msg_id'], rec)
+        return msg_ids
+
+    def test_add_record(self):
+        before = self.db.get_history()
+        self.load_records(5)
+        after = self.db.get_history()
+        self.assertEquals(len(after), len(before)+5)
+        self.assertEquals(after[:-5], before)
+
+    def test_drop_record(self):
+        msg_id = self.load_records()[-1]
+        rec = self.db.get_record(msg_id)
+        self.db.drop_record(msg_id)
+        self.assertRaises(KeyError, self.db.get_record, msg_id)
+
+    def _round_to_millisecond(self, dt):
+        """necessary because mongodb rounds microseconds"""
+        micro = dt.microsecond
+        extra = int(str(micro)[-3:])
+        return dt - timedelta(microseconds=extra)
+
+    def test_update_record(self):
+        now = self._round_to_millisecond(datetime.now())
+        #
+        msg_id = self.db.get_history()[-1]
+        rec1 = self.db.get_record(msg_id)
+        data = {'stdout': 'hello there', 'completed': now}
+        self.db.update_record(msg_id, data)
+        rec2 = self.db.get_record(msg_id)
+        self.assertEquals(rec2['stdout'], 'hello there')
+        self.assertEquals(rec2['completed'], now)
+        rec1.update(data)
+        self.assertEquals(rec1, rec2)
+
+    # def test_update_record_bad(self):
+    #     """test updating nonexistent records"""
+    #     msg_id = str(uuid.uuid4())
+    #     data = {'stdout': 'hello there'}
+    #     self.assertRaises(KeyError, self.db.update_record, msg_id, data)
+
+    def test_find_records_dt(self):
+        """test finding records by date"""
+        hist = self.db.get_history()
+        middle = self.db.get_record(hist[len(hist)/2])
+        tic = middle['submitted']
+        before = self.db.find_records({'submitted': {'$lt': tic}})
+        after = self.db.find_records({'submitted': {'$gte': tic}})
+        self.assertEquals(len(before)+len(after), len(hist))
+        for b in before:
+            self.assertTrue(b['submitted'] < tic)
+        for a in after:
+            self.assertTrue(a['submitted'] >= tic)
+        same = self.db.find_records({'submitted': tic})
+        for s in same:
+            self.assertTrue(s['submitted'] == tic)
+
+    def test_find_records_keys(self):
+        """test extracting subset of record keys"""
+        found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=['submitted', 'completed'])
+        for rec in found:
+            self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
+
+    def test_find_records_msg_id(self):
+        """ensure msg_id is always in found records"""
+        found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=['submitted', 'completed'])
+        for rec in found:
+            self.assertTrue('msg_id' in rec.keys())
+        found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=['submitted'])
+        for rec in found:
+            self.assertTrue('msg_id' in rec.keys())
+        found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=['msg_id'])
+        for rec in found:
+            self.assertTrue('msg_id' in rec.keys())
+
+    def test_find_records_in(self):
+        """test finding records with '$in','$nin' operators"""
+        hist = self.db.get_history()
+        even = hist[::2]
+        odd = hist[1::2]
+        recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
+        found = [ r['msg_id'] for r in recs ]
+        self.assertEquals(set(even), set(found))
+        recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
+        found = [ r['msg_id'] for r in recs ]
+        self.assertEquals(set(odd), set(found))
+
+    def test_get_history(self):
+        msg_ids = self.db.get_history()
+        latest = datetime(1984, 1, 1)
+        for msg_id in msg_ids:
+            rec = self.db.get_record(msg_id)
+            newt = rec['submitted']
+            self.assertTrue(newt >= latest)
+            latest = newt
+        msg_id = self.load_records(1)[-1]
+        self.assertEquals(self.db.get_history()[-1], msg_id)
+
+    def test_datetime(self):
+        """get/set timestamps with datetime objects"""
+        msg_id = self.db.get_history()[-1]
+        rec = self.db.get_record(msg_id)
+        self.assertTrue(isinstance(rec['submitted'], datetime))
+        self.db.update_record(msg_id, dict(completed=datetime.now()))
+        rec = self.db.get_record(msg_id)
+        self.assertTrue(isinstance(rec['completed'], datetime))
+
+class TestSQLiteBackend(TestDictBackend):
+    def create_db(self):
+        return SQLiteDB(location=tempfile.gettempdir())
+
+    def tearDown(self):
+        self.db._db.close()
+
+# optional MongoDB test
+try:
+    from IPython.parallel.controller.mongodb import MongoDB
+except ImportError:
+    pass
+else:
+    class TestMongoBackend(TestDictBackend):
+        def create_db(self):
+            try:
+                return MongoDB(database='iptestdb')
+            except Exception:
+                raise SkipTest("Couldn't connect to mongodb instance")
+
+        def tearDown(self):
+            self.db._connection.drop_database('iptestdb')
@@ -1215,5 +1215,78 @@ class Client(HasTraits):
         if content['status'] != 'ok':
             raise self._unwrap_exception(content)
 
+    @spin_first
+    def hub_history(self):
+        """Get the Hub's history
+
+        Just like the Client, the Hub has a history, which is a list of msg_ids.
+        This will contain the history of all clients, and, depending on configuration,
+        may contain history across multiple cluster sessions.
+
+        Any msg_id returned here is a valid argument to `get_result`.
+
+        Returns
+        -------
+
+        msg_ids : list of strs
+            list of all msg_ids, ordered by task submission time.
+        """
+
+        self.session.send(self._query_socket, "history_request", content={})
+        idents, msg = self.session.recv(self._query_socket, 0)
+
+        if self.debug:
+            pprint(msg)
+        content = msg['content']
+        if content['status'] != 'ok':
+            raise self._unwrap_exception(content)
+        else:
+            return content['history']
+
+    @spin_first
+    def db_query(self, query, keys=None):
+        """Query the Hub's TaskRecord database
+
+        This will return a list of task record dicts that match `query`.
+
+        Parameters
+        ----------
+
+        query : mongodb query dict
+            The search dict. See mongodb query docs for details.
+        keys : list of strs [optional]
+            The subset of keys to be returned.  The default is to fetch everything.
+            'msg_id' will *always* be included.
+        """
+        content = dict(query=query, keys=keys)
+        self.session.send(self._query_socket, "db_request", content=content)
+        idents, msg = self.session.recv(self._query_socket, 0)
+        if self.debug:
+            pprint(msg)
+        content = msg['content']
+        if content['status'] != 'ok':
+            raise self._unwrap_exception(content)
+
+        records = content['records']
+        buffer_lens = content['buffer_lens']
+        result_buffer_lens = content['result_buffer_lens']
+        buffers = msg['buffers']
+        has_bufs = buffer_lens is not None
+        has_rbufs = result_buffer_lens is not None
+        for i,rec in enumerate(records):
+            # relink buffers
+            if has_bufs:
+                blen = buffer_lens[i]
+                rec['buffers'], buffers = buffers[:blen], buffers[blen:]
+            if has_rbufs:
+                blen = result_buffer_lens[i]
+                rec['result_buffers'], buffers = buffers[:blen], buffers[blen:]
+            # turn timestamps back into times
+            for key in 'submitted started completed resubmitted'.split():
+                maybedate = rec.get(key, None)
+                if maybedate and util.ISO8601_RE.match(maybedate):
+                    rec[key] = datetime.strptime(maybedate, util.ISO8601)
+
+        return records
 
 __all__ = [ 'Client' ]
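Taken together, the two new methods give clients read-only access to the Hub's task database. A minimal usage sketch (assumes a running cluster; everything here besides hub_history and db_query is illustrative):

    from IPython.parallel import Client

    rc = Client()
    # every msg_id the Hub has recorded, ordered by submission time
    all_ids = rc.hub_history()
    # mongodb-style query: finished tasks, fetching only the timestamp keys
    recs = rc.db_query({'completed': {'$ne': None}},
                       keys=['submitted', 'completed'])
    for rec in recs:
        # timestamps come back as datetime objects, so arithmetic works
        print rec['msg_id'], rec['completed'] - rec['submitted']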
@@ -103,9 +103,9 @@ class DictDB(BaseDB):
             return False
         return True
 
-    def _match(self, check, id_only=True):
+    def _match(self, check):
         """Find all the matches for a check dict."""
-        matches = {}
+        matches = []
         tests = {}
         for k,v in check.iteritems():
             if isinstance(v, dict):
@@ -113,14 +113,18 @@ class DictDB(BaseDB):
             else:
                 tests[k] = lambda o: o==v
 
-        for msg_id, rec in self._records.iteritems():
+        for rec in self._records.itervalues():
             if self._match_one(rec, tests):
-                matches[msg_id] = rec
-        if id_only:
-            return matches.keys()
-        else:
-            return matches
+                matches.append(rec)
+        return matches
 
+    def _extract_subdict(self, rec, keys):
+        """extract subdict of keys"""
+        d = {}
+        d['msg_id'] = rec['msg_id']
+        for key in keys:
+            d[key] = rec[key]
+        return d
 
     def add_record(self, msg_id, rec):
         """Add a new Task Record, by msg_id."""
@@ -140,7 +144,7 @@ class DictDB(BaseDB):
 
     def drop_matching_records(self, check):
         """Remove a record from the DB."""
-        matches = self._match(check, id_only=True)
+        matches = self._match(check)
         for m in matches:
             del self._records[m]
 
@@ -149,7 +153,28 @@ class DictDB(BaseDB):
         del self._records[msg_id]
 
 
-    def find_records(self, check, id_only=True):
-        """Find records matching a query dict."""
-        matches = self._match(check, id_only)
-        return matches
\ No newline at end of file
+    def find_records(self, check, keys=None):
+        """Find records matching a query dict, optionally extracting subset of keys.
+
+        Returns list of matching records.
+
+        Parameters
+        ----------
+
+        check: dict
+            mongodb-style query argument
+        keys: list of strs [optional]
+            if specified, the subset of keys to extract.  msg_id will *always* be
+            included.
+        """
+        matches = self._match(check)
+        if keys:
+            return [ self._extract_subdict(rec, keys) for rec in matches ]
+        else:
+            return matches
+
+
+    def get_history(self):
+        """get all msg_ids, ordered by time submitted."""
+        msg_ids = self._records.keys()
+        return sorted(msg_ids, key=lambda m: self._records[m]['submitted'])
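A small sketch of what the new keys argument changes (assumes, as the tests above do via init_record, that DictDB accepts a plain dict as a record):

    from datetime import datetime
    from IPython.parallel.controller.dictdb import DictDB

    db = DictDB()
    db.add_record('m1', dict(msg_id='m1', submitted=datetime.now(), completed=None))
    full = db.find_records({'completed': None})                     # full record dicts
    slim = db.find_records({'completed': None}, keys=['submitted'])
    # each slim record contains exactly msg_id plus the requested keys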
@@ -27,9 +27,8 @@ from zmq.eventloop.zmqstream import ZMQStream
 from IPython.utils.importstring import import_item
 from IPython.utils.traitlets import HasTraits, Instance, Int, CStr, Str, Dict, Set, List, Bool
 
-from IPython.parallel import error
+from IPython.parallel import error, util
 from IPython.parallel.factory import RegistrationFactory, LoggingFactory
-from IPython.parallel.util import select_random_ports, validate_url_container, ISO8601
 
 from .heartmonitor import HeartMonitor
 
@@ -76,7 +75,7 @@ def init_record(msg):
         'header' : header,
         'content': msg['content'],
         'buffers': msg['buffers'],
-        'submitted': datetime.strptime(header['date'], ISO8601),
+        'submitted': datetime.strptime(header['date'], util.ISO8601),
         'client_uuid' : None,
         'engine_uuid' : None,
         'started': None,
@@ -119,32 +118,32 @@ class HubFactory(RegistrationFactory):
     # port-pairs for monitoredqueues:
     hb = Instance(list, config=True)
     def _hb_default(self):
-        return select_random_ports(2)
+        return util.select_random_ports(2)
 
     mux = Instance(list, config=True)
     def _mux_default(self):
-        return select_random_ports(2)
+        return util.select_random_ports(2)
 
     task = Instance(list, config=True)
    def _task_default(self):
-        return select_random_ports(2)
+        return util.select_random_ports(2)
 
     control = Instance(list, config=True)
     def _control_default(self):
-        return select_random_ports(2)
+        return util.select_random_ports(2)
 
     iopub = Instance(list, config=True)
     def _iopub_default(self):
-        return select_random_ports(2)
+        return util.select_random_ports(2)
 
     # single ports:
     mon_port = Instance(int, config=True)
     def _mon_port_default(self):
-        return select_random_ports(1)[0]
+        return util.select_random_ports(1)[0]
 
     notifier_port = Instance(int, config=True)
     def _notifier_port_default(self):
-        return select_random_ports(1)[0]
+        return util.select_random_ports(1)[0]
 
     ping = Int(1000, config=True) # ping frequency
 
@@ -344,11 +343,11 @@ class Hub(LoggingFactory):
         # validate connection dicts:
         for k,v in self.client_info.iteritems():
             if k == 'task':
-                validate_url_container(v[1])
+                util.validate_url_container(v[1])
             else:
-                validate_url_container(v)
-        # validate_url_container(self.client_info)
-        validate_url_container(self.engine_info)
+                util.validate_url_container(v)
+        # util.validate_url_container(self.client_info)
+        util.validate_url_container(self.engine_info)
 
         # register our callbacks
         self.query.on_recv(self.dispatch_query)
@@ -369,6 +368,8 @@ class Hub(LoggingFactory):
 
         self.query_handlers = {'queue_request': self.queue_status,
                             'result_request': self.get_results,
+                            'history_request': self.get_history,
+                            'db_request': self.db_query,
                             'purge_request': self.purge_results,
                             'load_request': self.check_load,
                             'resubmit_request': self.resubmit_task,
@@ -606,10 +607,10 @@ class Hub(LoggingFactory):
             return
         # update record anyway, because the unregistration could have been premature
         rheader = msg['header']
-        completed = datetime.strptime(rheader['date'], ISO8601)
+        completed = datetime.strptime(rheader['date'], util.ISO8601)
         started = rheader.get('started', None)
         if started is not None:
-            started = datetime.strptime(started, ISO8601)
+            started = datetime.strptime(started, util.ISO8601)
         result = {
             'result_header' : rheader,
             'result_content': msg['content'],
@@ -618,7 +619,10 @@ class Hub(LoggingFactory):
         }
 
         result['result_buffers'] = msg['buffers']
-        self.db.update_record(msg_id, result)
+        try:
+            self.db.update_record(msg_id, result)
+        except Exception:
+            self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
 
 
     #--------------------- Task Queue Traffic ------------------------------
@@ -653,6 +657,8 @@ class Hub(LoggingFactory):
             self.db.update_record(msg_id, record)
         except KeyError:
             self.db.add_record(msg_id, record)
+        except Exception:
+            self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
 
     def save_task_result(self, idents, msg):
         """save the result of a completed task."""
@@ -685,10 +691,10 @@ class Hub(LoggingFactory):
             self.completed[eid].append(msg_id)
             if msg_id in self.tasks[eid]:
                 self.tasks[eid].remove(msg_id)
-            completed = datetime.strptime(header['date'], ISO8601)
+            completed = datetime.strptime(header['date'], util.ISO8601)
             started = header.get('started', None)
             if started is not None:
-                started = datetime.strptime(started, ISO8601)
+                started = datetime.strptime(started, util.ISO8601)
             result = {
                 'result_header' : header,
                 'result_content': msg['content'],
@@ -698,7 +704,10 @@ class Hub(LoggingFactory):
             }
 
             result['result_buffers'] = msg['buffers']
-            self.db.update_record(msg_id, result)
+            try:
+                self.db.update_record(msg_id, result)
+            except Exception:
+                self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
 
         else:
             self.log.debug("task::unknown task %s finished"%msg_id)
@@ -723,7 +732,11 @@ class Hub(LoggingFactory):
 
         self.tasks[eid].append(msg_id)
         # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
-        self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
+        try:
+            self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
+        except Exception:
+            self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
+
 
     def mia_task_request(self, idents, msg):
         raise NotImplementedError
@@ -772,7 +785,10 @@ class Hub(LoggingFactory):
         else:
             d[msg_type] = content.get('data', '')
 
-        self.db.update_record(msg_id, d)
+        try:
+            self.db.update_record(msg_id, d)
+        except Exception:
+            self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
 
 
 
@@ -904,11 +920,15 @@ class Hub(LoggingFactory):
         # build a fake header:
         header = {}
         header['engine'] = uuid
-        header['date'] = datetime.now().strftime(ISO8601)
+        header['date'] = datetime.now()
         rec = dict(result_content=content, result_header=header, result_buffers=[])
         rec['completed'] = header['date']
         rec['engine_uuid'] = uuid
-        self.db.update_record(msg_id, rec)
+        try:
+            self.db.update_record(msg_id, rec)
+        except Exception:
+            self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
+
 
     def finish_registration(self, heart):
         """Second half of engine registration, called after our HeartMonitor
@@ -1017,7 +1037,10 @@ class Hub(LoggingFactory):
         msg_ids = content.get('msg_ids', [])
         reply = dict(status='ok')
         if msg_ids == 'all':
-            self.db.drop_matching_records(dict(completed={'$ne':None}))
+            try:
+                self.db.drop_matching_records(dict(completed={'$ne':None}))
+            except Exception:
+                reply = error.wrap_exception()
         else:
             for msg_id in msg_ids:
                 if msg_id in self.all_completed:
@@ -1044,7 +1067,11 @@ class Hub(LoggingFactory):
                     break
                 msg_ids = self.completed.pop(eid)
                 uid = self.engines[eid].queue
-                self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
+                try:
+                    self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
+                except Exception:
+                    reply = error.wrap_exception()
+                    break
 
         self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
 
@@ -1052,6 +1079,23 @@ class Hub(LoggingFactory):
         """Resubmit a task."""
         raise NotImplementedError
 
+    def _extract_record(self, rec):
+        """decompose a TaskRecord dict into subsection of reply for get_result"""
+        io_dict = {}
+        for key in 'pyin pyout pyerr stdout stderr'.split():
+            io_dict[key] = rec[key]
+        content = { 'result_content': rec['result_content'],
+                    'header': rec['header'],
+                    'result_header' : rec['result_header'],
+                    'io' : io_dict,
+                  }
+        if rec['result_buffers']:
+            buffers = map(str, rec['result_buffers'])
+        else:
+            buffers = []
+
+        return content, buffers
+
     def get_results(self, client_id, msg):
         """Get the result of 1 or more messages."""
         content = msg['content']
@@ -1064,25 +1108,28 @@ class Hub(LoggingFactory):
         content['completed'] = completed
         buffers = []
         if not statusonly:
-            content['results'] = {}
-            records = self.db.find_records(dict(msg_id={'$in':msg_ids}))
+            try:
+                matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
+                # turn match list into dict, for faster lookup
+                records = {}
+                for rec in matches:
+                    records[rec['msg_id']] = rec
+            except Exception:
+                content = error.wrap_exception()
+                self.session.send(self.query, "result_reply", content=content,
+                                    parent=msg, ident=client_id)
+                return
+        else:
+            records = {}
         for msg_id in msg_ids:
             if msg_id in self.pending:
                 pending.append(msg_id)
-            elif msg_id in self.all_completed:
+            elif msg_id in self.all_completed or msg_id in records:
                 completed.append(msg_id)
                 if not statusonly:
-                    rec = records[msg_id]
-                    io_dict = {}
-                    for key in 'pyin pyout pyerr stdout stderr'.split():
-                        io_dict[key] = rec[key]
-                    content[msg_id] = { 'result_content': rec['result_content'],
-                                        'header': rec['header'],
-                                        'result_header' : rec['result_header'],
-                                        'io' : io_dict,
-                                      }
-                    if rec['result_buffers']:
-                        buffers.extend(map(str, rec['result_buffers']))
+                    c,bufs = self._extract_record(records[msg_id])
+                    content[msg_id] = c
+                    buffers.extend(bufs)
             else:
                 try:
                     raise KeyError('No such message: '+msg_id)
@@ -1093,3 +1140,54 @@ class Hub(LoggingFactory):
                             parent=msg, ident=client_id,
                             buffers=buffers)
 
+    def get_history(self, client_id, msg):
+        """Get a list of all msg_ids in our DB records"""
+        try:
+            msg_ids = self.db.get_history()
+        except Exception as e:
+            content = error.wrap_exception()
+        else:
+            content = dict(status='ok', history=msg_ids)
+
+        self.session.send(self.query, "history_reply", content=content,
+                            parent=msg, ident=client_id)
+
+    def db_query(self, client_id, msg):
+        """Perform a raw query on the task record database."""
+        content = msg['content']
+        query = content.get('query', {})
+        keys = content.get('keys', None)
+        query = util.extract_dates(query)
+        buffers = []
+        empty = list()
+
+        try:
+            records = self.db.find_records(query, keys)
+        except Exception as e:
+            content = error.wrap_exception()
+        else:
+            # extract buffers from reply content:
+            if keys is not None:
+                buffer_lens = [] if 'buffers' in keys else None
+                result_buffer_lens = [] if 'result_buffers' in keys else None
+            else:
+                buffer_lens = []
+                result_buffer_lens = []
+
+            for rec in records:
+                # buffers may be None, so double check
+                if buffer_lens is not None:
+                    b = rec.pop('buffers', empty) or empty
+                    buffer_lens.append(len(b))
+                    buffers.extend(b)
+                if result_buffer_lens is not None:
+                    rb = rec.pop('result_buffers', empty) or empty
+                    result_buffer_lens.append(len(rb))
+                    buffers.extend(rb)
+            content = dict(status='ok', records=records, buffer_lens=buffer_lens,
+                                    result_buffer_lens=result_buffer_lens)
+
+        self.session.send(self.query, "db_reply", content=content,
+                            parent=msg, ident=client_id,
+                            buffers=buffers)
+
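On the wire, a successful db_reply keeps the JSON content serializable by shipping raw buffers separately; roughly (field names from the code above, values illustrative):

    content = {
        'status': 'ok',
        'records': [...],              # record dicts, with buffer keys popped
        'buffer_lens': [2, 0],         # raw 'buffers' count per record
        'result_buffer_lens': [1, 1],  # raw 'result_buffers' count per record
    }
    # msg['buffers'] then carries the concatenated raw buffers in record
    # order, which Client.db_query reslices using the two length lists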
@@ -22,9 +22,9 @@ from .dictdb import BaseDB
 class MongoDB(BaseDB):
     """MongoDB TaskRecord backend."""
 
-    connection_args = List(config=True)
-    connection_kwargs = Dict(config=True)
-    database = CUnicode(config=True)
+    connection_args = List(config=True) # args passed to pymongo.Connection
+    connection_kwargs = Dict(config=True) # kwargs passed to pymongo.Connection
+    database = CUnicode(config=True) # name of the mongodb database
     _table = Dict()
 
     def __init__(self, **kwargs):
@@ -37,13 +37,14 @@ class MongoDB(BaseDB):
 
     def _binary_buffers(self, rec):
         for key in ('buffers', 'result_buffers'):
-            if key in rec:
+            if rec.get(key, None):
                 rec[key] = map(Binary, rec[key])
+        return rec
 
     def add_record(self, msg_id, rec):
         """Add a new Task Record, by msg_id."""
         # print rec
-        rec = _binary_buffers(rec)
+        rec = self._binary_buffers(rec)
         obj_id = self._records.insert(rec)
         self._table[msg_id] = obj_id
 
@@ -53,7 +54,7 @@ class MongoDB(BaseDB):
 
     def update_record(self, msg_id, rec):
         """Update the data in an existing record."""
-        rec = _binary_buffers(rec)
+        rec = self._binary_buffers(rec)
         obj_id = self._table[msg_id]
         self._records.update({'_id':obj_id}, {'$set': rec})
 
@@ -66,15 +67,30 @@ class MongoDB(BaseDB):
         obj_id = self._table.pop(msg_id)
         self._records.remove(obj_id)
 
-    def find_records(self, check, id_only=False):
-        """Find records matching a query dict."""
-        matches = list(self._records.find(check))
-        if id_only:
-            return [ rec['msg_id'] for rec in matches ]
-        else:
-            data = {}
-            for rec in matches:
-                data[rec['msg_id']] = rec
-            return data
+    def find_records(self, check, keys=None):
+        """Find records matching a query dict, optionally extracting subset of keys.
+
+        Returns list of matching records.
+
+        Parameters
+        ----------
+
+        check: dict
+            mongodb-style query argument
+        keys: list of strs [optional]
+            if specified, the subset of keys to extract.  msg_id will *always* be
+            included.
+        """
+        if keys and 'msg_id' not in keys:
+            keys.append('msg_id')
+        matches = list(self._records.find(check,keys))
+        for rec in matches:
+            rec.pop('_id')
+        return matches
+
+    def get_history(self):
+        """get all msg_ids, ordered by time submitted."""
+        cursor = self._records.find({},{'msg_id':1}).sort('submitted')
+        return [ rec['msg_id'] for rec in cursor ]
 
 
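Here keys is handed straight to pymongo's field-selection argument of find(), which is why 'msg_id' is appended to the list before querying and the synthetic '_id' field is popped from each match afterwards. A usage sketch (assumes a reachable mongod; the query is illustrative):

    from IPython.parallel.controller.mongodb import MongoDB

    db = MongoDB(database='iptestdb')
    recs = db.find_records({'completed': {'$ne': None}}, keys=['completed'])
    ids = db.get_history()   # all msg_ids, ordered by 'submitted'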
@@ -24,7 +24,7 @@ from IPython.parallel.util import ISO8601
 #-----------------------------------------------------------------------------
 
 operators = {
-    '$lt' : '<',
+    '$lt' : "<",
     '$gt' : ">",
     # null is handled weird with ==,!=
     '$eq' : "IS",
@@ -124,10 +124,11 @@ class SQLiteDB(BaseDB):
         pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
         pc.start()
 
-    def _defaults(self):
+    def _defaults(self, keys=None):
         """create an empty record"""
         d = {}
-        for key in self._keys:
+        keys = self._keys if keys is None else keys
+        for key in keys:
             d[key] = None
         return d
 
@@ -168,9 +169,6 @@ class SQLiteDB(BaseDB):
                 stdout text,
                 stderr text)
                 """%self.table)
-        # self._db.execute("""CREATE TABLE IF NOT EXISTS %s_buffers
-        #             (msg_id text, result integer, buffer blob)
-        #             """%self.table)
         self._db.commit()
 
     def _dict_to_list(self, d):
@@ -178,10 +176,11 @@ class SQLiteDB(BaseDB):
 
         return [ d[key] for key in self._keys ]
 
-    def _list_to_dict(self, line):
+    def _list_to_dict(self, line, keys=None):
         """Inverse of dict_to_list"""
-        d = self._defaults()
-        for key,value in zip(self._keys, line):
+        keys = self._keys if keys is None else keys
+        d = self._defaults(keys)
+        for key,value in zip(keys, line):
             d[key] = value
 
         return d
@@ -249,13 +248,14 @@ class SQLiteDB(BaseDB):
             sets.append('%s = ?'%key)
             values.append(rec[key])
         query += ', '.join(sets)
-        query += ' WHERE msg_id == %r'%msg_id
+        query += ' WHERE msg_id == ?'
+        values.append(msg_id)
         self._db.execute(query, values)
         # self._db.commit()
 
     def drop_record(self, msg_id):
         """Remove a record from the DB."""
-        self._db.execute("""DELETE FROM %s WHERE msg_id==%r"""%(self.table, msg_id))
+        self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
         # self._db.commit()
 
     def drop_matching_records(self, check):
@@ -265,20 +265,48 @@ class SQLiteDB(BaseDB):
         self._db.execute(query,args)
         # self._db.commit()
 
-    def find_records(self, check, id_only=False):
-        """Find records matching a query dict."""
-        req = 'msg_id' if id_only else '*'
+    def find_records(self, check, keys=None):
+        """Find records matching a query dict, optionally extracting subset of keys.
+
+        Returns list of matching records.
+
+        Parameters
+        ----------
+
+        check: dict
+            mongodb-style query argument
+        keys: list of strs [optional]
+            if specified, the subset of keys to extract.  msg_id will *always* be
+            included.
+        """
+        if keys:
+            bad_keys = [ key for key in keys if key not in self._keys ]
+            if bad_keys:
+                raise KeyError("Bad record key(s): %s"%bad_keys)
+
+        if keys:
+            # ensure msg_id is present and first:
+            if 'msg_id' in keys:
+                keys.remove('msg_id')
+            keys.insert(0, 'msg_id')
+            req = ', '.join(keys)
+        else:
+            req = '*'
         expr,args = self._render_expression(check)
         query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
         cursor = self._db.execute(query, args)
         matches = cursor.fetchall()
-        if id_only:
-            return [ m[0] for m in matches ]
-        else:
-            records = {}
-            for line in matches:
-                rec = self._list_to_dict(line)
-                records[rec['msg_id']] = rec
-            return records
+        records = []
+        for line in matches:
+            rec = self._list_to_dict(line, keys)
+            records.append(rec)
+        return records
+
+    def get_history(self):
+        """get all msg_ids, ordered by time submitted."""
+        query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
+        cursor = self._db.execute(query)
+        # will be a list of length 1 tuples
+        return [ tup[0] for tup in cursor.fetchall()]
 
 __all__ = ['SQLiteDB']
\ No newline at end of file
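The SQLite backend reaches the same interface by rendering the mongodb-style check into a parameterized WHERE clause (via _render_expression and the operators table above). A usage sketch; the generated SQL shown is an approximation:

    import tempfile
    from IPython.parallel.controller.sqlitedb import SQLiteDB

    db = SQLiteDB(location=tempfile.gettempdir())
    # translates to roughly:
    #   SELECT msg_id, completed FROM <table> WHERE completed IS NOT NULL
    recs = db.find_records({'completed': {'$ne': None}}, keys=['completed'])
    ids = db.get_history()   # SELECT msg_id ... ORDER BY submitted ASC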
@@ -28,6 +28,7 @@ from zmq.eventloop.zmqstream import ZMQStream
 from .util import ISO8601
 
 def squash_unicode(obj):
+    """coerce unicode back to bytestrings."""
     if isinstance(obj,dict):
         for key in obj.keys():
             obj[key] = squash_unicode(obj[key])
@@ -40,7 +41,14 @@ def squash_unicode(obj):
         obj = obj.encode('utf8')
     return obj
 
-json_packer = jsonapi.dumps
+def _date_default(obj):
+    if isinstance(obj, datetime):
+        return obj.strftime(ISO8601)
+    else:
+        raise TypeError("%r is not JSON serializable"%obj)
+
+_default_key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
+json_packer = lambda obj: jsonapi.dumps(obj, **{_default_key:_date_default})
 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
 
 pickle_packer = lambda o: pickle.dumps(o,-1)
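The _default_key indirection is needed because jsonlib spells its fallback hook on_unknown, while the stdlib json and simplejson spell it default. With the stdlib backend, the new packer behaves like this sketch:

    import json
    from datetime import datetime

    ISO8601 = "%Y-%m-%dT%H:%M:%S.%f"
    # datetimes are serialized as ISO8601 strings instead of raising TypeError
    packed = json.dumps({'date': datetime.now()},
                        default=lambda obj: obj.strftime(ISO8601))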
@@ -17,6 +17,7 @@ import re
 import stat
 import socket
 import sys
+from datetime import datetime
 from signal import signal, SIGINT, SIGABRT, SIGTERM
 try:
     from signal import SIGKILL
@@ -41,6 +42,7 @@ from IPython.zmq.log import EnginePUBHandler
 
 # globals
 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
+ISO8601_RE=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")
 
 #-----------------------------------------------------------------------------
 # Classes
@@ -99,6 +101,18 @@ class ReverseDict(dict):
 # Functions
 #-----------------------------------------------------------------------------
 
+def extract_dates(obj):
+    """extract ISO8601 dates from unpacked JSON"""
+    if isinstance(obj, dict):
+        for k,v in obj.iteritems():
+            obj[k] = extract_dates(v)
+    elif isinstance(obj, list):
+        obj = [ extract_dates(o) for o in obj ]
+    elif isinstance(obj, basestring):
+        if ISO8601_RE.match(obj):
+            obj = datetime.strptime(obj, ISO8601)
+    return obj
+
 def validate_url(url):
     """validate a url for zeromq"""
     if not isinstance(url, basestring):
@@ -460,3 +474,4 @@ def local_logger(logname, loglevel=logging.DEBUG):
     handler.setLevel(loglevel)
     logger.addHandler(handler)
     logger.setLevel(loglevel)
+
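extract_dates is the inverse of the packer change in streamsession above: after unpacking, any string matching ISO8601_RE becomes a datetime again. Restated standalone for clarity (Python 2, matching the source):

    import json, re
    from datetime import datetime

    ISO8601 = "%Y-%m-%dT%H:%M:%S.%f"
    ISO8601_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")

    def extract_dates(obj):
        """recursively convert ISO8601 strings back into datetimes"""
        if isinstance(obj, dict):
            for k, v in obj.iteritems():
                obj[k] = extract_dates(v)
        elif isinstance(obj, list):
            obj = [ extract_dates(o) for o in obj ]
        elif isinstance(obj, basestring) and ISO8601_RE.match(obj):
            obj = datetime.strptime(obj, ISO8601)
        return obj

    # round trip: datetime -> ISO8601 string -> datetime
    wire = json.dumps({'submitted': datetime.now()},
                      default=lambda o: o.strftime(ISO8601))
    rec = extract_dates(json.loads(wire))
    assert isinstance(rec['submitted'], datetime)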