##// END OF EJS Templates
fix null comparisons in sqlitedb backend...
MinRK -
Show More
@@ -1,408 +1,408 b''
1 1 """A TaskRecord backend using sqlite3
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 import json
15 15 import os
16 16 import cPickle as pickle
17 17 from datetime import datetime
18 18
19 19 try:
20 20 import sqlite3
21 21 except ImportError:
22 22 sqlite3 = None
23 23
24 24 from zmq.eventloop import ioloop
25 25
26 26 from IPython.utils.traitlets import Unicode, Instance, List, Dict
27 27 from .dictdb import BaseDB
28 28 from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
29 29
30 30 #-----------------------------------------------------------------------------
31 31 # SQLite operators, adapters, and converters
32 32 #-----------------------------------------------------------------------------
33 33
34 34 try:
35 35 buffer
36 36 except NameError:
37 37 # py3k
38 38 buffer = memoryview
39 39
40 40 operators = {
41 41 '$lt' : "<",
42 42 '$gt' : ">",
43 43 # null is handled weird with ==,!=
44 44 '$eq' : "=",
45 45 '$ne' : "!=",
46 46 '$lte': "<=",
47 47 '$gte': ">=",
48 48 '$in' : ('=', ' OR '),
49 49 '$nin': ('!=', ' AND '),
50 50 # '$all': None,
51 51 # '$mod': None,
52 52 # '$exists' : None
53 53 }
54 54 null_operators = {
55 55 '=' : "IS NULL",
56 56 '!=' : "IS NOT NULL",
57 57 }
58 58
59 59 def _adapt_dict(d):
60 60 return json.dumps(d, default=date_default)
61 61
62 62 def _convert_dict(ds):
63 63 if ds is None:
64 64 return ds
65 65 else:
66 66 if isinstance(ds, bytes):
67 67 # If I understand the sqlite doc correctly, this will always be utf8
68 68 ds = ds.decode('utf8')
69 69 return extract_dates(json.loads(ds))
70 70
71 71 def _adapt_bufs(bufs):
72 72 # this is *horrible*
73 73 # copy buffers into single list and pickle it:
74 74 if bufs and isinstance(bufs[0], (bytes, buffer)):
75 75 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
76 76 elif bufs:
77 77 return bufs
78 78 else:
79 79 return None
80 80
81 81 def _convert_bufs(bs):
82 82 if bs is None:
83 83 return []
84 84 else:
85 85 return pickle.loads(bytes(bs))
86 86
87 87 #-----------------------------------------------------------------------------
88 88 # SQLiteDB class
89 89 #-----------------------------------------------------------------------------
90 90
91 91 class SQLiteDB(BaseDB):
92 92 """SQLite3 TaskRecord backend."""
93 93
94 94 filename = Unicode('tasks.db', config=True,
95 95 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
96 96 location = Unicode('', config=True,
97 97 help="""The directory containing the sqlite task database. The default
98 98 is to use the cluster_dir location.""")
99 99 table = Unicode("", config=True,
100 100 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
101 101 a new table will be created with the Hub's IDENT. Specifying the table will result
102 102 in tasks from previous sessions being available via Clients' db_query and
103 103 get_result methods.""")
104 104
105 105 if sqlite3 is not None:
106 106 _db = Instance('sqlite3.Connection')
107 107 else:
108 108 _db = None
109 109 # the ordered list of column names
110 110 _keys = List(['msg_id' ,
111 111 'header' ,
112 112 'content',
113 113 'buffers',
114 114 'submitted',
115 115 'client_uuid' ,
116 116 'engine_uuid' ,
117 117 'started',
118 118 'completed',
119 119 'resubmitted',
120 120 'result_header' ,
121 121 'result_content' ,
122 122 'result_buffers' ,
123 123 'queue' ,
124 124 'pyin' ,
125 125 'pyout',
126 126 'pyerr',
127 127 'stdout',
128 128 'stderr',
129 129 ])
130 130 # sqlite datatypes for checking that db is current format
131 131 _types = Dict({'msg_id' : 'text' ,
132 132 'header' : 'dict text',
133 133 'content' : 'dict text',
134 134 'buffers' : 'bufs blob',
135 135 'submitted' : 'timestamp',
136 136 'client_uuid' : 'text',
137 137 'engine_uuid' : 'text',
138 138 'started' : 'timestamp',
139 139 'completed' : 'timestamp',
140 140 'resubmitted' : 'timestamp',
141 141 'result_header' : 'dict text',
142 142 'result_content' : 'dict text',
143 143 'result_buffers' : 'bufs blob',
144 144 'queue' : 'text',
145 145 'pyin' : 'text',
146 146 'pyout' : 'text',
147 147 'pyerr' : 'text',
148 148 'stdout' : 'text',
149 149 'stderr' : 'text',
150 150 })
151 151
152 152 def __init__(self, **kwargs):
153 153 super(SQLiteDB, self).__init__(**kwargs)
154 154 if sqlite3 is None:
155 155 raise ImportError("SQLiteDB requires sqlite3")
156 156 if not self.table:
157 157 # use session, and prefix _, since starting with # is illegal
158 158 self.table = '_'+self.session.replace('-','_')
159 159 if not self.location:
160 160 # get current profile
161 161 from IPython.core.application import BaseIPythonApplication
162 162 if BaseIPythonApplication.initialized():
163 163 app = BaseIPythonApplication.instance()
164 164 if app.profile_dir is not None:
165 165 self.location = app.profile_dir.location
166 166 else:
167 167 self.location = u'.'
168 168 else:
169 169 self.location = u'.'
170 170 self._init_db()
171 171
172 172 # register db commit as 2s periodic callback
173 173 # to prevent clogging pipes
174 174 # assumes we are being run in a zmq ioloop app
175 175 loop = ioloop.IOLoop.instance()
176 176 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
177 177 pc.start()
178 178
179 179 def _defaults(self, keys=None):
180 180 """create an empty record"""
181 181 d = {}
182 182 keys = self._keys if keys is None else keys
183 183 for key in keys:
184 184 d[key] = None
185 185 return d
186 186
187 187 def _check_table(self):
188 188 """Ensure that an incorrect table doesn't exist
189 189
190 190 If a bad (old) table does exist, return False
191 191 """
192 192 cursor = self._db.execute("PRAGMA table_info(%s)"%self.table)
193 193 lines = cursor.fetchall()
194 194 if not lines:
195 195 # table does not exist
196 196 return True
197 197 types = {}
198 198 keys = []
199 199 for line in lines:
200 200 keys.append(line[1])
201 201 types[line[1]] = line[2]
202 202 if self._keys != keys:
203 203 # key mismatch
204 204 self.log.warn('keys mismatch')
205 205 return False
206 206 for key in self._keys:
207 207 if types[key] != self._types[key]:
208 208 self.log.warn(
209 209 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
210 210 )
211 211 return False
212 212 return True
213 213
214 214 def _init_db(self):
215 215 """Connect to the database and get new session number."""
216 216 # register adapters
217 217 sqlite3.register_adapter(dict, _adapt_dict)
218 218 sqlite3.register_converter('dict', _convert_dict)
219 219 sqlite3.register_adapter(list, _adapt_bufs)
220 220 sqlite3.register_converter('bufs', _convert_bufs)
221 221 # connect to the db
222 222 dbfile = os.path.join(self.location, self.filename)
223 223 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
224 224 # isolation_level = None)#,
225 225 cached_statements=64)
226 226 # print dir(self._db)
227 227 first_table = self.table
228 228 i=0
229 229 while not self._check_table():
230 230 i+=1
231 231 self.table = first_table+'_%i'%i
232 232 self.log.warn(
233 233 "Table %s exists and doesn't match db format, trying %s"%
234 234 (first_table,self.table)
235 235 )
236 236
237 237 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
238 238 (msg_id text PRIMARY KEY,
239 239 header dict text,
240 240 content dict text,
241 241 buffers bufs blob,
242 242 submitted timestamp,
243 243 client_uuid text,
244 244 engine_uuid text,
245 245 started timestamp,
246 246 completed timestamp,
247 247 resubmitted timestamp,
248 248 result_header dict text,
249 249 result_content dict text,
250 250 result_buffers bufs blob,
251 251 queue text,
252 252 pyin text,
253 253 pyout text,
254 254 pyerr text,
255 255 stdout text,
256 256 stderr text)
257 257 """%self.table)
258 258 self._db.commit()
259 259
260 260 def _dict_to_list(self, d):
261 261 """turn a mongodb-style record dict into a list."""
262 262
263 263 return [ d[key] for key in self._keys ]
264 264
265 265 def _list_to_dict(self, line, keys=None):
266 266 """Inverse of dict_to_list"""
267 267 keys = self._keys if keys is None else keys
268 268 d = self._defaults(keys)
269 269 for key,value in zip(keys, line):
270 270 d[key] = value
271 271
272 272 return d
273 273
274 274 def _render_expression(self, check):
275 275 """Turn a mongodb-style search dict into an SQL query."""
276 276 expressions = []
277 277 args = []
278 278
279 279 skeys = set(check.keys())
280 280 skeys.difference_update(set(self._keys))
281 281 skeys.difference_update(set(['buffers', 'result_buffers']))
282 282 if skeys:
283 283 raise KeyError("Illegal testing key(s): %s"%skeys)
284 284
285 285 for name,sub_check in check.iteritems():
286 286 if isinstance(sub_check, dict):
287 287 for test,value in sub_check.iteritems():
288 288 try:
289 289 op = operators[test]
290 290 except KeyError:
291 291 raise KeyError("Unsupported operator: %r"%test)
292 292 if isinstance(op, tuple):
293 293 op, join = op
294 294
295 295 if value is None and op in null_operators:
296 expr = "%s %s"%null_operators[op]
296 expr = "%s %s" % (name, null_operators[op])
297 297 else:
298 298 expr = "%s %s ?"%(name, op)
299 299 if isinstance(value, (tuple,list)):
300 300 if op in null_operators and any([v is None for v in value]):
301 301 # equality tests don't work with NULL
302 302 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
303 303 expr = '( %s )'%( join.join([expr]*len(value)) )
304 304 args.extend(value)
305 305 else:
306 306 args.append(value)
307 307 expressions.append(expr)
308 308 else:
309 309 # it's an equality check
310 310 if sub_check is None:
311 expressions.append("%s IS NULL")
311 expressions.append("%s IS NULL" % name)
312 312 else:
313 313 expressions.append("%s = ?"%name)
314 314 args.append(sub_check)
315 315
316 316 expr = " AND ".join(expressions)
317 317 return expr, args
318 318
319 319 def add_record(self, msg_id, rec):
320 320 """Add a new Task Record, by msg_id."""
321 321 d = self._defaults()
322 322 d.update(rec)
323 323 d['msg_id'] = msg_id
324 324 line = self._dict_to_list(d)
325 325 tups = '(%s)'%(','.join(['?']*len(line)))
326 326 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
327 327 # self._db.commit()
328 328
329 329 def get_record(self, msg_id):
330 330 """Get a specific Task Record, by msg_id."""
331 331 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
332 332 line = cursor.fetchone()
333 333 if line is None:
334 334 raise KeyError("No such msg: %r"%msg_id)
335 335 return self._list_to_dict(line)
336 336
337 337 def update_record(self, msg_id, rec):
338 338 """Update the data in an existing record."""
339 339 query = "UPDATE %s SET "%self.table
340 340 sets = []
341 341 keys = sorted(rec.keys())
342 342 values = []
343 343 for key in keys:
344 344 sets.append('%s = ?'%key)
345 345 values.append(rec[key])
346 346 query += ', '.join(sets)
347 347 query += ' WHERE msg_id == ?'
348 348 values.append(msg_id)
349 349 self._db.execute(query, values)
350 350 # self._db.commit()
351 351
352 352 def drop_record(self, msg_id):
353 353 """Remove a record from the DB."""
354 354 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
355 355 # self._db.commit()
356 356
357 357 def drop_matching_records(self, check):
358 358 """Remove a record from the DB."""
359 359 expr,args = self._render_expression(check)
360 360 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
361 361 self._db.execute(query,args)
362 362 # self._db.commit()
363 363
364 364 def find_records(self, check, keys=None):
365 365 """Find records matching a query dict, optionally extracting subset of keys.
366 366
367 367 Returns list of matching records.
368 368
369 369 Parameters
370 370 ----------
371 371
372 372 check: dict
373 373 mongodb-style query argument
374 374 keys: list of strs [optional]
375 375 if specified, the subset of keys to extract. msg_id will *always* be
376 376 included.
377 377 """
378 378 if keys:
379 379 bad_keys = [ key for key in keys if key not in self._keys ]
380 380 if bad_keys:
381 381 raise KeyError("Bad record key(s): %s"%bad_keys)
382 382
383 383 if keys:
384 384 # ensure msg_id is present and first:
385 385 if 'msg_id' in keys:
386 386 keys.remove('msg_id')
387 387 keys.insert(0, 'msg_id')
388 388 req = ', '.join(keys)
389 389 else:
390 390 req = '*'
391 391 expr,args = self._render_expression(check)
392 392 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
393 393 cursor = self._db.execute(query, args)
394 394 matches = cursor.fetchall()
395 395 records = []
396 396 for line in matches:
397 397 rec = self._list_to_dict(line, keys)
398 398 records.append(rec)
399 399 return records
400 400
401 401 def get_history(self):
402 402 """get all msg_ids, ordered by time submitted."""
403 403 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
404 404 cursor = self._db.execute(query)
405 405 # will be a list of length 1 tuples
406 406 return [ tup[0] for tup in cursor.fetchall()]
407 407
408 408 __all__ = ['SQLiteDB'] No newline at end of file
@@ -1,182 +1,194 b''
1 1 """Tests for db backends
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import tempfile
22 22 import time
23 23
24 24 from datetime import datetime, timedelta
25 25 from unittest import TestCase
26 26
27 27 from IPython.parallel import error
28 28 from IPython.parallel.controller.dictdb import DictDB
29 29 from IPython.parallel.controller.sqlitedb import SQLiteDB
30 30 from IPython.parallel.controller.hub import init_record, empty_record
31 31
32 32 from IPython.testing import decorators as dec
33 33 from IPython.zmq.session import Session
34 34
35 35
36 36 #-------------------------------------------------------------------------------
37 37 # TestCases
38 38 #-------------------------------------------------------------------------------
39 39
40 40 class TestDictBackend(TestCase):
41 41 def setUp(self):
42 42 self.session = Session()
43 43 self.db = self.create_db()
44 44 self.load_records(16)
45 45
46 46 def create_db(self):
47 47 return DictDB()
48 48
49 49 def load_records(self, n=1):
50 50 """load n records for testing"""
51 51 #sleep 1/10 s, to ensure timestamp is different to previous calls
52 52 time.sleep(0.1)
53 53 msg_ids = []
54 54 for i in range(n):
55 55 msg = self.session.msg('apply_request', content=dict(a=5))
56 56 msg['buffers'] = []
57 57 rec = init_record(msg)
58 58 msg_id = msg['header']['msg_id']
59 59 msg_ids.append(msg_id)
60 60 self.db.add_record(msg_id, rec)
61 61 return msg_ids
62 62
63 63 def test_add_record(self):
64 64 before = self.db.get_history()
65 65 self.load_records(5)
66 66 after = self.db.get_history()
67 67 self.assertEquals(len(after), len(before)+5)
68 68 self.assertEquals(after[:-5],before)
69 69
70 70 def test_drop_record(self):
71 71 msg_id = self.load_records()[-1]
72 72 rec = self.db.get_record(msg_id)
73 73 self.db.drop_record(msg_id)
74 74 self.assertRaises(KeyError,self.db.get_record, msg_id)
75 75
76 76 def _round_to_millisecond(self, dt):
77 77 """necessary because mongodb rounds microseconds"""
78 78 micro = dt.microsecond
79 79 extra = int(str(micro)[-3:])
80 80 return dt - timedelta(microseconds=extra)
81 81
82 82 def test_update_record(self):
83 83 now = self._round_to_millisecond(datetime.now())
84 84 #
85 85 msg_id = self.db.get_history()[-1]
86 86 rec1 = self.db.get_record(msg_id)
87 87 data = {'stdout': 'hello there', 'completed' : now}
88 88 self.db.update_record(msg_id, data)
89 89 rec2 = self.db.get_record(msg_id)
90 90 self.assertEquals(rec2['stdout'], 'hello there')
91 91 self.assertEquals(rec2['completed'], now)
92 92 rec1.update(data)
93 93 self.assertEquals(rec1, rec2)
94 94
95 95 # def test_update_record_bad(self):
96 96 # """test updating nonexistant records"""
97 97 # msg_id = str(uuid.uuid4())
98 98 # data = {'stdout': 'hello there'}
99 99 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
100 100
101 101 def test_find_records_dt(self):
102 102 """test finding records by date"""
103 103 hist = self.db.get_history()
104 104 middle = self.db.get_record(hist[len(hist)//2])
105 105 tic = middle['submitted']
106 106 before = self.db.find_records({'submitted' : {'$lt' : tic}})
107 107 after = self.db.find_records({'submitted' : {'$gte' : tic}})
108 108 self.assertEquals(len(before)+len(after),len(hist))
109 109 for b in before:
110 110 self.assertTrue(b['submitted'] < tic)
111 111 for a in after:
112 112 self.assertTrue(a['submitted'] >= tic)
113 113 same = self.db.find_records({'submitted' : tic})
114 114 for s in same:
115 115 self.assertTrue(s['submitted'] == tic)
116 116
117 117 def test_find_records_keys(self):
118 118 """test extracting subset of record keys"""
119 119 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
120 120 for rec in found:
121 121 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
122 122
123 123 def test_find_records_msg_id(self):
124 124 """ensure msg_id is always in found records"""
125 125 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
126 126 for rec in found:
127 127 self.assertTrue('msg_id' in rec.keys())
128 128 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
129 129 for rec in found:
130 130 self.assertTrue('msg_id' in rec.keys())
131 131 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
132 132 for rec in found:
133 133 self.assertTrue('msg_id' in rec.keys())
134 134
135 135 def test_find_records_in(self):
136 136 """test finding records with '$in','$nin' operators"""
137 137 hist = self.db.get_history()
138 138 even = hist[::2]
139 139 odd = hist[1::2]
140 140 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
141 141 found = [ r['msg_id'] for r in recs ]
142 142 self.assertEquals(set(even), set(found))
143 143 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
144 144 found = [ r['msg_id'] for r in recs ]
145 145 self.assertEquals(set(odd), set(found))
146 146
147 147 def test_get_history(self):
148 148 msg_ids = self.db.get_history()
149 149 latest = datetime(1984,1,1)
150 150 for msg_id in msg_ids:
151 151 rec = self.db.get_record(msg_id)
152 152 newt = rec['submitted']
153 153 self.assertTrue(newt >= latest)
154 154 latest = newt
155 155 msg_id = self.load_records(1)[-1]
156 156 self.assertEquals(self.db.get_history()[-1],msg_id)
157 157
158 158 def test_datetime(self):
159 159 """get/set timestamps with datetime objects"""
160 160 msg_id = self.db.get_history()[-1]
161 161 rec = self.db.get_record(msg_id)
162 162 self.assertTrue(isinstance(rec['submitted'], datetime))
163 163 self.db.update_record(msg_id, dict(completed=datetime.now()))
164 164 rec = self.db.get_record(msg_id)
165 165 self.assertTrue(isinstance(rec['completed'], datetime))
166 166
167 167 def test_drop_matching(self):
168 168 msg_ids = self.load_records(10)
169 169 query = {'msg_id' : {'$in':msg_ids}}
170 170 self.db.drop_matching_records(query)
171 171 recs = self.db.find_records(query)
172 172 self.assertEquals(len(recs), 0)
173
174 def test_null(self):
175 """test None comparison queries"""
176 msg_ids = self.load_records(10)
177
178 query = {'msg_id' : None}
179 recs = self.db.find_records(query)
180 self.assertEquals(len(recs), 0)
181
182 query = {'msg_id' : {'$ne' : None}}
183 recs = self.db.find_records(query)
184 self.assertTrue(len(recs) >= 10)
173 185
174 186
175 187 class TestSQLiteBackend(TestDictBackend):
176 188
177 189 @dec.skip_without('sqlite3')
178 190 def create_db(self):
179 191 return SQLiteDB(location=tempfile.gettempdir())
180 192
181 193 def tearDown(self):
182 194 self.db._db.close()
General Comments 0
You need to be logged in to leave comments. Login now