##// END OF EJS Templates
dictdb queries should [shallow] copy records...
MinRK -
Show More
@@ -1,216 +1,216 b''
1 1 """A Task logger that presents our DB interface,
2 2 but exists entirely in memory and implemented with dicts.
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7
8 8
9 9 TaskRecords are dicts of the form:
10 10 {
11 11 'msg_id' : str(uuid),
12 12 'client_uuid' : str(uuid),
13 13 'engine_uuid' : str(uuid) or None,
14 14 'header' : dict(header),
15 15 'content': dict(content),
16 16 'buffers': list(buffers),
17 17 'submitted': datetime,
18 18 'started': datetime or None,
19 19 'completed': datetime or None,
20 20 'resubmitted': datetime or None,
21 21 'result_header' : dict(header) or None,
22 22 'result_content' : dict(content) or None,
23 23 'result_buffers' : list(buffers) or None,
24 24 }
25 25 With this info, many of the special categories of tasks can be defined by query:
26 26
27 27 pending: completed is None
28 28 client's outstanding: client_uuid = uuid && completed is None
29 29 MIA: arrived is None (and completed is None)
30 30 etc.
31 31
32 32 EngineRecords are dicts of the form:
33 33 {
34 34 'eid' : int(id),
35 35 'uuid': str(uuid)
36 36 }
37 37 This may be extended, but this is currently the minimum required.
38 38
39 39 We support a subset of mongodb operators:
40 40 $lt,$gt,$lte,$gte,$ne,$in,$nin,$all,$mod,$exists
41 41 """
42 42 #-----------------------------------------------------------------------------
43 43 # Copyright (C) 2010-2011 The IPython Development Team
44 44 #
45 45 # Distributed under the terms of the BSD License. The full license is in
46 46 # the file COPYING, distributed as part of this software.
47 47 #-----------------------------------------------------------------------------
48 48
49
49 from copy import copy
50 50 from datetime import datetime
51 51
52 52 from IPython.config.configurable import LoggingConfigurable
53 53
54 54 from IPython.utils.traitlets import Dict, Unicode, Instance
55 55
# mongodb-style operator name -> binary predicate.
# Each predicate is called as f(record_value, query_value) by
# CompositeFilter / DictDB._match_one.
filters = {
    '$lt' : lambda a,b: a < b,
    # fixed: was `b > a`, which tested record < query (identical to $lt)
    '$gt' : lambda a,b: a > b,
    '$eq' : lambda a,b: a == b,
    '$ne' : lambda a,b: a != b,
    '$lte': lambda a,b: a <= b,
    '$gte': lambda a,b: a >= b,
    '$in' : lambda a,b: a in b,
    '$nin': lambda a,b: a not in b,
    '$all': lambda a,b: all([ a in bb for bb in b ]),
    '$mod': lambda a,b: a%b[0] == b[1],
    '$exists' : lambda a,b: (b and a is not None) or (a is None and not b)
}
69 69
70 70
class CompositeFilter(object):
    """Combine several mongodb-style operator tests on a single value.

    Built from a dict such as ``{'$gt': 0, '$lt': 10}``; calling the
    instance returns True only when the value passes every operator test.
    """

    def __init__(self, dikt):
        self.tests = []
        self.values = []
        for op, expected in dikt.iteritems():
            # look up the operator's predicate once, at construction time
            self.tests.append(filters[op])
            self.values.append(expected)

    def __call__(self, value):
        # short-circuits on the first failing predicate, like the
        # equivalent explicit loop
        pairs = zip(self.tests, self.values)
        return all(test(value, expected) for test, expected in pairs)
86 86
class BaseDB(LoggingConfigurable):
    """Empty Parent class so traitlets work on DB.

    Concrete backends (e.g. DictDB) subclass this, inheriting the shared
    configurable traits below.
    """
    # base configurable traits:
    # NOTE(review): presumably the id of the connected Session — confirm with Hub
    session = Unicode("")
91 91
class DictDB(BaseDB):
    """Basic in-memory dict-based object for saving Task Records.

    This is the first object to present the DB interface
    for logging tasks out of memory.

    The interface is based on MongoDB, so adding a MongoDB
    backend should be straightforward.
    """

    # msg_id -> TaskRecord dict; the in-memory backing store
    _records = Dict()

    def _match_one(self, rec, tests):
        """Check if a specific record matches tests.

        `tests` maps record keys to callables; a missing key is tested as None.
        """
        for key, test in tests.iteritems():
            if not test(rec.get(key, None)):
                return False
        return True

    def _match(self, check):
        """Find all the matches for a check dict.

        Returns a list of *shallow copies* of the matching records, so
        callers that mutate a result (e.g. popping buffers) cannot corrupt
        the stored record.
        """
        matches = []
        tests = {}
        for k, v in check.iteritems():
            if isinstance(v, dict):
                tests[k] = CompositeFilter(v)
            else:
                # bind v as a default argument: a bare `lambda o: o==v`
                # late-binds v, so every plain-value key would be compared
                # against the *last* value seen in this loop
                tests[k] = lambda o, v=v: o == v

        for rec in self._records.itervalues():
            if self._match_one(rec, tests):
                matches.append(copy(rec))
        return matches

    def _extract_subdict(self, rec, keys):
        """extract subdict of keys; msg_id is always included"""
        d = {}
        d['msg_id'] = rec['msg_id']
        for key in keys:
            d[key] = rec[key]
        return d

    def add_record(self, msg_id, rec):
        """Add a new Task Record, by msg_id.

        Raises KeyError if the msg_id is already present.
        """
        # `in` test instead of deprecated dict.has_key()
        if msg_id in self._records:
            raise KeyError("Already have msg_id %r"%(msg_id))
        self._records[msg_id] = rec

    def get_record(self, msg_id):
        """Get a specific Task Record, by msg_id.

        Returns a shallow copy, so mutating the result does not affect the
        stored record.  Raises KeyError for an unknown msg_id.
        """
        if msg_id not in self._records:
            raise KeyError("No such msg_id %r"%(msg_id))
        return copy(self._records[msg_id])

    def update_record(self, msg_id, rec):
        """Update the data in an existing record.

        Raises KeyError (from the dict lookup) if msg_id is unknown.
        """
        self._records[msg_id].update(rec)

    def drop_matching_records(self, check):
        """Remove all records matching a query dict from the DB."""
        matches = self._match(check)
        for m in matches:
            del self._records[m['msg_id']]

    def drop_record(self, msg_id):
        """Remove a record from the DB."""
        del self._records[msg_id]

    def find_records(self, check, keys=None):
        """Find records matching a query dict, optionally extracting subset of keys.

        Returns a list of matching records (shallow copies).

        Parameters
        ----------

        check: dict
            mongodb-style query argument
        keys: list of strs [optional]
            if specified, the subset of keys to extract. msg_id will *always* be
            included.
        """
        matches = self._match(check)
        if keys:
            return [ self._extract_subdict(rec, keys) for rec in matches ]
        else:
            return matches

    def get_history(self):
        """get all msg_ids, ordered by time submitted."""
        msg_ids = self._records.keys()
        return sorted(msg_ids, key=lambda m: self._records[m]['submitted'])
186 186
class NoDB(DictDB):
    """A blackhole db backend that actually stores no information.

    Provides the full DB interface, but raises KeyErrors on any
    method that tries to access the records. This can be used to
    minimize the memory footprint of the Hub when its record-keeping
    functionality is not required.
    """

    def add_record(self, msg_id, record):
        # silently discard the record
        pass

    def get_record(self, msg_id):
        # record access is unsupported by design
        raise KeyError("NoDB does not support record access")

    def update_record(self, msg_id, record):
        # nothing stored, so nothing to update
        pass

    def drop_matching_records(self, check):
        # nothing stored, so nothing to drop
        pass

    def drop_record(self, msg_id):
        # nothing stored, so nothing to drop
        pass

    def find_records(self, check, keys=None):
        # queries are unsupported by design
        raise KeyError("NoDB does not store information")

    def get_history(self):
        # history is unsupported by design
        raise KeyError("NoDB does not store information")
216 216
@@ -1,319 +1,331 b''
1 1 """Tests for parallel client.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import time
22 22 from datetime import datetime
23 23 from tempfile import mktemp
24 24
25 25 import zmq
26 26
27 27 from IPython.parallel.client import client as clientmod
28 28 from IPython.parallel import error
29 29 from IPython.parallel import AsyncResult, AsyncHubResult
30 30 from IPython.parallel import LoadBalancedView, DirectView
31 31
32 32 from clienttest import ClusterTestCase, segfault, wait, add_engines
33 33
def setup():
    """Module-level test setup: ensure 4 engines total are running
    before any test in this module executes."""
    add_engines(4, total=True)
class TestClient(ClusterTestCase):
    """Integration tests for IPython.parallel.client.Client against a live cluster."""

    def test_ids(self):
        """client.ids grows as engines are added"""
        n = len(self.client.ids)
        self.add_engines(2)
        self.assertEquals(len(self.client.ids), n+2)

    def test_view_indexing(self):
        """test index access for views"""
        self.minimum_engines(4)
        targets = self.client._build_targets('all')[-1]
        v = self.client[:]
        self.assertEquals(v.targets, targets)
        # single-int index
        t = self.client.ids[2]
        v = self.client[t]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, t)
        # list index
        t = self.client.ids[2:4]
        v = self.client[t]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, t)
        # slices with step / negative bounds
        v = self.client[::2]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[::2])
        v = self.client[1::3]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[1::3])
        v = self.client[:-3]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[:-3])
        v = self.client[-1]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[-1])
        # non-int/slice keys are rejected
        self.assertRaises(TypeError, lambda : self.client[None])

    def test_lbview_targets(self):
        """test load_balanced_view targets"""
        v = self.client.load_balanced_view()
        self.assertEquals(v.targets, None)
        v = self.client.load_balanced_view(-1)
        self.assertEquals(v.targets, [self.client.ids[-1]])
        v = self.client.load_balanced_view('all')
        self.assertEquals(v.targets, None)

    def test_dview_targets(self):
        """test direct_view targets"""
        v = self.client.direct_view()
        self.assertEquals(v.targets, 'all')
        v = self.client.direct_view('all')
        self.assertEquals(v.targets, 'all')
        v = self.client.direct_view(-1)
        self.assertEquals(v.targets, self.client.ids[-1])

    def test_lazy_all_targets(self):
        """test lazy evaluation of rc.direct_view('all')"""
        v = self.client.direct_view()
        self.assertEquals(v.targets, 'all')

        def double(x):
            return x*2
        seq = range(100)
        ref = [ double(x) for x in seq ]

        # add some engines, which should be used
        self.add_engines(1)
        n1 = len(self.client.ids)

        # simple apply
        r = v.apply_sync(lambda : 1)
        self.assertEquals(r, [1] * n1)

        # map goes through remotefunction
        r = v.map_sync(double, seq)
        self.assertEquals(r, ref)

        # add a couple more engines, and try again
        self.add_engines(2)
        n2 = len(self.client.ids)
        self.assertNotEquals(n2, n1)

        # apply
        r = v.apply_sync(lambda : 1)
        self.assertEquals(r, [1] * n2)

        # map
        r = v.map_sync(double, seq)
        self.assertEquals(r, ref)

    def test_targets(self):
        """test various valid targets arguments"""
        build = self.client._build_targets
        ids = self.client.ids
        idents,targets = build(None)
        self.assertEquals(ids, targets)

    def test_clear(self):
        """test clear behavior"""
        self.minimum_engines(2)
        v = self.client[:]
        v.block=True
        v.push(dict(a=5))
        v.pull('a')
        id0 = self.client.ids[-1]
        # clear only the last engine; others still have 'a'
        self.client.clear(targets=id0, block=True)
        a = self.client[:-1].get('a')
        self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
        # clear all engines; 'a' gone everywhere
        self.client.clear(block=True)
        for i in self.client.ids:
            self.assertRaisesRemote(NameError, self.client[i].get, 'a')

    def test_get_result(self):
        """test getting results from the Hub."""
        # submit from a second client so self.client must fetch via the Hub
        c = clientmod.Client(profile='iptest')
        t = c.ids[-1]
        ar = c[t].apply_async(wait, 1)
        # give the monitor time to notice the message
        time.sleep(.25)
        ahr = self.client.get_result(ar.msg_ids)
        self.assertTrue(isinstance(ahr, AsyncHubResult))
        self.assertEquals(ahr.get(), ar.get())
        # second fetch: result is now local, so not an AsyncHubResult
        ar2 = self.client.get_result(ar.msg_ids)
        self.assertFalse(isinstance(ar2, AsyncHubResult))
        c.close()

    def test_ids_list(self):
        """test client.ids"""
        ids = self.client.ids
        self.assertEquals(ids, self.client._ids)
        # .ids must be a copy, not the internal list itself
        self.assertFalse(ids is self.client._ids)
        ids.remove(ids[-1])
        self.assertNotEquals(ids, self.client._ids)

    def test_queue_status(self):
        """queue_status returns per-engine dicts plus an 'unassigned' entry"""
        ids = self.client.ids
        id0 = ids[0]
        qs = self.client.queue_status(targets=id0)
        self.assertTrue(isinstance(qs, dict))
        self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
        allqs = self.client.queue_status()
        self.assertTrue(isinstance(allqs, dict))
        intkeys = list(allqs.keys())
        intkeys.remove('unassigned')
        self.assertEquals(sorted(intkeys), sorted(self.client.ids))
        unassigned = allqs.pop('unassigned')
        for eid,qs in allqs.items():
            self.assertTrue(isinstance(qs, dict))
            self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])

    def test_shutdown(self):
        """shutting down an engine removes it from client.ids"""
        ids = self.client.ids
        id0 = ids[0]
        self.client.shutdown(id0, block=True)
        # poll until the Hub notices the engine is gone
        while id0 in self.client.ids:
            time.sleep(0.1)
            self.client.spin()

        self.assertRaises(IndexError, lambda : self.client[id0])

    def test_result_status(self):
        pass
        # to be written

    def test_db_query_dt(self):
        """test db query by date"""
        hist = self.client.hub_history()
        middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
        tic = middle['submitted']
        before = self.client.db_query({'submitted' : {'$lt' : tic}})
        after = self.client.db_query({'submitted' : {'$gte' : tic}})
        # $lt/$gte partition the history exactly
        self.assertEquals(len(before)+len(after),len(hist))
        for b in before:
            self.assertTrue(b['submitted'] < tic)
        for a in after:
            self.assertTrue(a['submitted'] >= tic)
        same = self.client.db_query({'submitted' : tic})
        for s in same:
            self.assertTrue(s['submitted'] == tic)

    def test_db_query_keys(self):
        """test extracting subset of record keys"""
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
        for rec in found:
            self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))

    def test_db_query_default_keys(self):
        """default db_query excludes buffers"""
        found = self.client.db_query({'msg_id': {'$ne' : ''}})
        for rec in found:
            keys = set(rec.keys())
            self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
            self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)

    def test_db_query_msg_id(self):
        """ensure msg_id is always in db queries"""
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())

    def test_db_query_get_result(self):
        """pop in db_query shouldn't pop from result itself"""
        self.client[:].apply_sync(lambda : 1)
        found = self.client.db_query({'msg_id': {'$ne' : ''}})
        rc2 = clientmod.Client(profile='iptest')
        # If this bug is not fixed, this call will hang:
        ar = rc2.get_result(self.client.history[-1])
        ar.wait(2)
        self.assertTrue(ar.ready())
        ar.get()
        rc2.close()

    def test_db_query_in(self):
        """test db query with '$in','$nin' operators"""
        hist = self.client.hub_history()
        even = hist[::2]
        odd = hist[1::2]
        recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
        found = [ r['msg_id'] for r in recs ]
        self.assertEquals(set(even), set(found))
        recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
        found = [ r['msg_id'] for r in recs ]
        self.assertEquals(set(odd), set(found))

    def test_hub_history(self):
        """hub_history is ordered by submission time and tracks new tasks"""
        hist = self.client.hub_history()
        recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
        recdict = {}
        for rec in recs:
            recdict[rec['msg_id']] = rec

        # verify history is sorted by 'submitted'
        latest = datetime(1984,1,1)
        for msg_id in hist:
            rec = recdict[msg_id]
            newt = rec['submitted']
            self.assertTrue(newt >= latest)
            latest = newt
        ar = self.client[-1].apply_async(lambda : 1)
        ar.get()
        # give the Hub a chance to notice the new task
        time.sleep(0.25)
        self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)

    def test_resubmit(self):
        """resubmitting a task re-runs it (random result differs)"""
        def f():
            import random
            return random.random()
        v = self.client.load_balanced_view()
        ar = v.apply_async(f)
        r1 = ar.get(1)
        # give the Hub a chance to notice:
        time.sleep(0.5)
        ahr = self.client.resubmit(ar.msg_ids)
        r2 = ahr.get(1)
        self.assertFalse(r1 == r2)

    def test_resubmit_inflight(self):
        """ensure ValueError on resubmit of inflight task"""
        v = self.client.load_balanced_view()
        ar = v.apply_async(time.sleep,1)
        # give the message a chance to arrive
        time.sleep(0.2)
        self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
        ar.get(2)

    def test_resubmit_badkey(self):
        """ensure KeyError on resubmit of nonexistant task"""
        self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])

    def test_purge_results(self):
        """purging one result shortens hub_history by one"""
        # ensure there are some tasks
        for i in range(5):
            self.client[:].apply_sync(lambda : 1)
        # Wait for the Hub to realise the result is done:
        # This prevents a race condition, where we
        # might purge a result the Hub still thinks is pending.
        time.sleep(0.1)
        rc2 = clientmod.Client(profile='iptest')
        hist = self.client.hub_history()
        ahr = rc2.get_result([hist[-1]])
        ahr.wait(10)
        self.client.purge_results(hist[-1])
        newhist = self.client.hub_history()
        self.assertEquals(len(newhist)+1,len(hist))
        rc2.spin()
        rc2.close()

    def test_purge_all_results(self):
        """purge_results('all') empties hub_history"""
        self.client.purge_results('all')
        hist = self.client.hub_history()
        self.assertEquals(len(hist), 0)
319 331
@@ -1,194 +1,224 b''
1 1 """Tests for db backends
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import tempfile
22 22 import time
23 23
24 24 from datetime import datetime, timedelta
25 25 from unittest import TestCase
26 26
27 27 from IPython.parallel import error
28 28 from IPython.parallel.controller.dictdb import DictDB
29 29 from IPython.parallel.controller.sqlitedb import SQLiteDB
30 30 from IPython.parallel.controller.hub import init_record, empty_record
31 31
32 32 from IPython.testing import decorators as dec
33 33 from IPython.zmq.session import Session
34 34
35 35
36 36 #-------------------------------------------------------------------------------
37 37 # TestCases
38 38 #-------------------------------------------------------------------------------
39 39
class TestDictBackend(TestCase):
    """Test suite for the DictDB backend; subclasses override create_db()
    to run the identical suite against other backends."""

    def setUp(self):
        self.session = Session()
        self.db = self.create_db()
        # start every test with 16 pre-loaded records
        self.load_records(16)

    def create_db(self):
        # overridden by subclasses to exercise other backends
        return DictDB()

    def load_records(self, n=1):
        """load n records for testing"""
        #sleep 1/10 s, to ensure timestamp is different to previous calls
        time.sleep(0.1)
        msg_ids = []
        for i in range(n):
            msg = self.session.msg('apply_request', content=dict(a=5))
            msg['buffers'] = []
            rec = init_record(msg)
            msg_id = msg['header']['msg_id']
            msg_ids.append(msg_id)
            self.db.add_record(msg_id, rec)
        return msg_ids

    def test_add_record(self):
        """adding records extends history, preserving the prefix"""
        before = self.db.get_history()
        self.load_records(5)
        after = self.db.get_history()
        self.assertEquals(len(after), len(before)+5)
        self.assertEquals(after[:-5],before)

    def test_drop_record(self):
        """dropped records can no longer be fetched"""
        msg_id = self.load_records()[-1]
        rec = self.db.get_record(msg_id)
        self.db.drop_record(msg_id)
        self.assertRaises(KeyError,self.db.get_record, msg_id)

    def _round_to_millisecond(self, dt):
        """necessary because mongodb rounds microseconds"""
        micro = dt.microsecond
        extra = int(str(micro)[-3:])
        return dt - timedelta(microseconds=extra)

    def test_update_record(self):
        """updated fields are reflected in subsequent get_record calls"""
        now = self._round_to_millisecond(datetime.now())
        #
        msg_id = self.db.get_history()[-1]
        rec1 = self.db.get_record(msg_id)
        data = {'stdout': 'hello there', 'completed' : now}
        self.db.update_record(msg_id, data)
        rec2 = self.db.get_record(msg_id)
        self.assertEquals(rec2['stdout'], 'hello there')
        self.assertEquals(rec2['completed'], now)
        rec1.update(data)
        self.assertEquals(rec1, rec2)

# def test_update_record_bad(self):
# """test updating nonexistant records"""
# msg_id = str(uuid.uuid4())
# data = {'stdout': 'hello there'}
# self.assertRaises(KeyError, self.db.update_record, msg_id, data)

    def test_find_records_dt(self):
        """test finding records by date"""
        hist = self.db.get_history()
        middle = self.db.get_record(hist[len(hist)//2])
        tic = middle['submitted']
        before = self.db.find_records({'submitted' : {'$lt' : tic}})
        after = self.db.find_records({'submitted' : {'$gte' : tic}})
        # $lt/$gte partition the history exactly
        self.assertEquals(len(before)+len(after),len(hist))
        for b in before:
            self.assertTrue(b['submitted'] < tic)
        for a in after:
            self.assertTrue(a['submitted'] >= tic)
        same = self.db.find_records({'submitted' : tic})
        for s in same:
            self.assertTrue(s['submitted'] == tic)

    def test_find_records_keys(self):
        """test extracting subset of record keys"""
        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
        for rec in found:
            self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))

    def test_find_records_msg_id(self):
        """ensure msg_id is always in found records"""
        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())

    def test_find_records_in(self):
        """test finding records with '$in','$nin' operators"""
        hist = self.db.get_history()
        even = hist[::2]
        odd = hist[1::2]
        recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
        found = [ r['msg_id'] for r in recs ]
        self.assertEquals(set(even), set(found))
        recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
        found = [ r['msg_id'] for r in recs ]
        self.assertEquals(set(odd), set(found))

    def test_get_history(self):
        """history is ordered by submission time and tracks new records"""
        msg_ids = self.db.get_history()
        latest = datetime(1984,1,1)
        for msg_id in msg_ids:
            rec = self.db.get_record(msg_id)
            newt = rec['submitted']
            self.assertTrue(newt >= latest)
            latest = newt
        msg_id = self.load_records(1)[-1]
        self.assertEquals(self.db.get_history()[-1],msg_id)

    def test_datetime(self):
        """get/set timestamps with datetime objects"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        self.assertTrue(isinstance(rec['submitted'], datetime))
        self.db.update_record(msg_id, dict(completed=datetime.now()))
        rec = self.db.get_record(msg_id)
        self.assertTrue(isinstance(rec['completed'], datetime))

    def test_drop_matching(self):
        """drop_matching_records removes everything a query matches"""
        msg_ids = self.load_records(10)
        query = {'msg_id' : {'$in':msg_ids}}
        self.db.drop_matching_records(query)
        recs = self.db.find_records(query)
        self.assertEquals(len(recs), 0)

    def test_null(self):
        """test None comparison queries"""
        msg_ids = self.load_records(10)

        query = {'msg_id' : None}
        recs = self.db.find_records(query)
        self.assertEquals(len(recs), 0)

        query = {'msg_id' : {'$ne' : None}}
        recs = self.db.find_records(query)
        self.assertTrue(len(recs) >= 10)

    def test_pop_safe_get(self):
        """editing query results shouldn't affect record [get]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        # mutate the returned record...
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        # ...and verify the stored record is untouched
        rec2 = self.db.get_record(msg_id)
        self.assertTrue('buffers' in rec2)
        self.assertFalse('garbage' in rec2)

    def test_pop_safe_find(self):
        """editing query results shouldn't affect record [find]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.find_records({'msg_id' : msg_id})[0]
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec2 = self.db.find_records({'msg_id' : msg_id})[0]
        self.assertTrue('buffers' in rec2)
        self.assertFalse('garbage' in rec2)

    def test_pop_safe_find_keys(self):
        """editing query results shouldn't affect record [find+keys]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.find_records({'msg_id' : msg_id}, keys=['buffers'])[0]
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec2 = self.db.find_records({'msg_id' : msg_id})[0]
        self.assertTrue('buffers' in rec2)
        self.assertFalse('garbage' in rec2)
185 215
186 216
class TestSQLiteBackend(TestDictBackend):
    """Run the full TestDictBackend suite against the SQLite backend."""

    @dec.skip_without('sqlite3')
    def create_db(self):
        # keep the db file in the system temp directory
        return SQLiteDB(location=tempfile.gettempdir())

    def tearDown(self):
        # close the underlying sqlite connection between tests
        self.db._db.close()
General Comments 0
You need to be logged in to leave comments. Login now