##// END OF EJS Templates
fix null comparisons in sqlitedb backend...
MinRK -
Show More
@@ -1,408 +1,408 b''
1 """A TaskRecord backend using sqlite3
1 """A TaskRecord backend using sqlite3
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2011 The IPython Development Team
8 # Copyright (C) 2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 import json
14 import json
15 import os
15 import os
16 import cPickle as pickle
16 import cPickle as pickle
17 from datetime import datetime
17 from datetime import datetime
18
18
19 try:
19 try:
20 import sqlite3
20 import sqlite3
21 except ImportError:
21 except ImportError:
22 sqlite3 = None
22 sqlite3 = None
23
23
24 from zmq.eventloop import ioloop
24 from zmq.eventloop import ioloop
25
25
26 from IPython.utils.traitlets import Unicode, Instance, List, Dict
26 from IPython.utils.traitlets import Unicode, Instance, List, Dict
27 from .dictdb import BaseDB
27 from .dictdb import BaseDB
28 from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
28 from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
29
29
30 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
31 # SQLite operators, adapters, and converters
31 # SQLite operators, adapters, and converters
32 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
33
33
34 try:
34 try:
35 buffer
35 buffer
36 except NameError:
36 except NameError:
37 # py3k
37 # py3k
38 buffer = memoryview
38 buffer = memoryview
39
39
40 operators = {
40 operators = {
41 '$lt' : "<",
41 '$lt' : "<",
42 '$gt' : ">",
42 '$gt' : ">",
43 # null is handled weird with ==,!=
43 # null is handled weird with ==,!=
44 '$eq' : "=",
44 '$eq' : "=",
45 '$ne' : "!=",
45 '$ne' : "!=",
46 '$lte': "<=",
46 '$lte': "<=",
47 '$gte': ">=",
47 '$gte': ">=",
48 '$in' : ('=', ' OR '),
48 '$in' : ('=', ' OR '),
49 '$nin': ('!=', ' AND '),
49 '$nin': ('!=', ' AND '),
50 # '$all': None,
50 # '$all': None,
51 # '$mod': None,
51 # '$mod': None,
52 # '$exists' : None
52 # '$exists' : None
53 }
53 }
54 null_operators = {
54 null_operators = {
55 '=' : "IS NULL",
55 '=' : "IS NULL",
56 '!=' : "IS NOT NULL",
56 '!=' : "IS NOT NULL",
57 }
57 }
58
58
59 def _adapt_dict(d):
59 def _adapt_dict(d):
60 return json.dumps(d, default=date_default)
60 return json.dumps(d, default=date_default)
61
61
62 def _convert_dict(ds):
62 def _convert_dict(ds):
63 if ds is None:
63 if ds is None:
64 return ds
64 return ds
65 else:
65 else:
66 if isinstance(ds, bytes):
66 if isinstance(ds, bytes):
67 # If I understand the sqlite doc correctly, this will always be utf8
67 # If I understand the sqlite doc correctly, this will always be utf8
68 ds = ds.decode('utf8')
68 ds = ds.decode('utf8')
69 return extract_dates(json.loads(ds))
69 return extract_dates(json.loads(ds))
70
70
71 def _adapt_bufs(bufs):
71 def _adapt_bufs(bufs):
72 # this is *horrible*
72 # this is *horrible*
73 # copy buffers into single list and pickle it:
73 # copy buffers into single list and pickle it:
74 if bufs and isinstance(bufs[0], (bytes, buffer)):
74 if bufs and isinstance(bufs[0], (bytes, buffer)):
75 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
75 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
76 elif bufs:
76 elif bufs:
77 return bufs
77 return bufs
78 else:
78 else:
79 return None
79 return None
80
80
81 def _convert_bufs(bs):
81 def _convert_bufs(bs):
82 if bs is None:
82 if bs is None:
83 return []
83 return []
84 else:
84 else:
85 return pickle.loads(bytes(bs))
85 return pickle.loads(bytes(bs))
86
86
87 #-----------------------------------------------------------------------------
87 #-----------------------------------------------------------------------------
88 # SQLiteDB class
88 # SQLiteDB class
89 #-----------------------------------------------------------------------------
89 #-----------------------------------------------------------------------------
90
90
91 class SQLiteDB(BaseDB):
91 class SQLiteDB(BaseDB):
92 """SQLite3 TaskRecord backend."""
92 """SQLite3 TaskRecord backend."""
93
93
94 filename = Unicode('tasks.db', config=True,
94 filename = Unicode('tasks.db', config=True,
95 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
95 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
96 location = Unicode('', config=True,
96 location = Unicode('', config=True,
97 help="""The directory containing the sqlite task database. The default
97 help="""The directory containing the sqlite task database. The default
98 is to use the cluster_dir location.""")
98 is to use the cluster_dir location.""")
99 table = Unicode("", config=True,
99 table = Unicode("", config=True,
100 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
100 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
101 a new table will be created with the Hub's IDENT. Specifying the table will result
101 a new table will be created with the Hub's IDENT. Specifying the table will result
102 in tasks from previous sessions being available via Clients' db_query and
102 in tasks from previous sessions being available via Clients' db_query and
103 get_result methods.""")
103 get_result methods.""")
104
104
105 if sqlite3 is not None:
105 if sqlite3 is not None:
106 _db = Instance('sqlite3.Connection')
106 _db = Instance('sqlite3.Connection')
107 else:
107 else:
108 _db = None
108 _db = None
109 # the ordered list of column names
109 # the ordered list of column names
110 _keys = List(['msg_id' ,
110 _keys = List(['msg_id' ,
111 'header' ,
111 'header' ,
112 'content',
112 'content',
113 'buffers',
113 'buffers',
114 'submitted',
114 'submitted',
115 'client_uuid' ,
115 'client_uuid' ,
116 'engine_uuid' ,
116 'engine_uuid' ,
117 'started',
117 'started',
118 'completed',
118 'completed',
119 'resubmitted',
119 'resubmitted',
120 'result_header' ,
120 'result_header' ,
121 'result_content' ,
121 'result_content' ,
122 'result_buffers' ,
122 'result_buffers' ,
123 'queue' ,
123 'queue' ,
124 'pyin' ,
124 'pyin' ,
125 'pyout',
125 'pyout',
126 'pyerr',
126 'pyerr',
127 'stdout',
127 'stdout',
128 'stderr',
128 'stderr',
129 ])
129 ])
130 # sqlite datatypes for checking that db is current format
130 # sqlite datatypes for checking that db is current format
131 _types = Dict({'msg_id' : 'text' ,
131 _types = Dict({'msg_id' : 'text' ,
132 'header' : 'dict text',
132 'header' : 'dict text',
133 'content' : 'dict text',
133 'content' : 'dict text',
134 'buffers' : 'bufs blob',
134 'buffers' : 'bufs blob',
135 'submitted' : 'timestamp',
135 'submitted' : 'timestamp',
136 'client_uuid' : 'text',
136 'client_uuid' : 'text',
137 'engine_uuid' : 'text',
137 'engine_uuid' : 'text',
138 'started' : 'timestamp',
138 'started' : 'timestamp',
139 'completed' : 'timestamp',
139 'completed' : 'timestamp',
140 'resubmitted' : 'timestamp',
140 'resubmitted' : 'timestamp',
141 'result_header' : 'dict text',
141 'result_header' : 'dict text',
142 'result_content' : 'dict text',
142 'result_content' : 'dict text',
143 'result_buffers' : 'bufs blob',
143 'result_buffers' : 'bufs blob',
144 'queue' : 'text',
144 'queue' : 'text',
145 'pyin' : 'text',
145 'pyin' : 'text',
146 'pyout' : 'text',
146 'pyout' : 'text',
147 'pyerr' : 'text',
147 'pyerr' : 'text',
148 'stdout' : 'text',
148 'stdout' : 'text',
149 'stderr' : 'text',
149 'stderr' : 'text',
150 })
150 })
151
151
152 def __init__(self, **kwargs):
152 def __init__(self, **kwargs):
153 super(SQLiteDB, self).__init__(**kwargs)
153 super(SQLiteDB, self).__init__(**kwargs)
154 if sqlite3 is None:
154 if sqlite3 is None:
155 raise ImportError("SQLiteDB requires sqlite3")
155 raise ImportError("SQLiteDB requires sqlite3")
156 if not self.table:
156 if not self.table:
157 # use session, and prefix _, since starting with # is illegal
157 # use session, and prefix _, since starting with # is illegal
158 self.table = '_'+self.session.replace('-','_')
158 self.table = '_'+self.session.replace('-','_')
159 if not self.location:
159 if not self.location:
160 # get current profile
160 # get current profile
161 from IPython.core.application import BaseIPythonApplication
161 from IPython.core.application import BaseIPythonApplication
162 if BaseIPythonApplication.initialized():
162 if BaseIPythonApplication.initialized():
163 app = BaseIPythonApplication.instance()
163 app = BaseIPythonApplication.instance()
164 if app.profile_dir is not None:
164 if app.profile_dir is not None:
165 self.location = app.profile_dir.location
165 self.location = app.profile_dir.location
166 else:
166 else:
167 self.location = u'.'
167 self.location = u'.'
168 else:
168 else:
169 self.location = u'.'
169 self.location = u'.'
170 self._init_db()
170 self._init_db()
171
171
172 # register db commit as 2s periodic callback
172 # register db commit as 2s periodic callback
173 # to prevent clogging pipes
173 # to prevent clogging pipes
174 # assumes we are being run in a zmq ioloop app
174 # assumes we are being run in a zmq ioloop app
175 loop = ioloop.IOLoop.instance()
175 loop = ioloop.IOLoop.instance()
176 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
176 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
177 pc.start()
177 pc.start()
178
178
179 def _defaults(self, keys=None):
179 def _defaults(self, keys=None):
180 """create an empty record"""
180 """create an empty record"""
181 d = {}
181 d = {}
182 keys = self._keys if keys is None else keys
182 keys = self._keys if keys is None else keys
183 for key in keys:
183 for key in keys:
184 d[key] = None
184 d[key] = None
185 return d
185 return d
186
186
187 def _check_table(self):
187 def _check_table(self):
188 """Ensure that an incorrect table doesn't exist
188 """Ensure that an incorrect table doesn't exist
189
189
190 If a bad (old) table does exist, return False
190 If a bad (old) table does exist, return False
191 """
191 """
192 cursor = self._db.execute("PRAGMA table_info(%s)"%self.table)
192 cursor = self._db.execute("PRAGMA table_info(%s)"%self.table)
193 lines = cursor.fetchall()
193 lines = cursor.fetchall()
194 if not lines:
194 if not lines:
195 # table does not exist
195 # table does not exist
196 return True
196 return True
197 types = {}
197 types = {}
198 keys = []
198 keys = []
199 for line in lines:
199 for line in lines:
200 keys.append(line[1])
200 keys.append(line[1])
201 types[line[1]] = line[2]
201 types[line[1]] = line[2]
202 if self._keys != keys:
202 if self._keys != keys:
203 # key mismatch
203 # key mismatch
204 self.log.warn('keys mismatch')
204 self.log.warn('keys mismatch')
205 return False
205 return False
206 for key in self._keys:
206 for key in self._keys:
207 if types[key] != self._types[key]:
207 if types[key] != self._types[key]:
208 self.log.warn(
208 self.log.warn(
209 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
209 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
210 )
210 )
211 return False
211 return False
212 return True
212 return True
213
213
214 def _init_db(self):
214 def _init_db(self):
215 """Connect to the database and get new session number."""
215 """Connect to the database and get new session number."""
216 # register adapters
216 # register adapters
217 sqlite3.register_adapter(dict, _adapt_dict)
217 sqlite3.register_adapter(dict, _adapt_dict)
218 sqlite3.register_converter('dict', _convert_dict)
218 sqlite3.register_converter('dict', _convert_dict)
219 sqlite3.register_adapter(list, _adapt_bufs)
219 sqlite3.register_adapter(list, _adapt_bufs)
220 sqlite3.register_converter('bufs', _convert_bufs)
220 sqlite3.register_converter('bufs', _convert_bufs)
221 # connect to the db
221 # connect to the db
222 dbfile = os.path.join(self.location, self.filename)
222 dbfile = os.path.join(self.location, self.filename)
223 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
223 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
224 # isolation_level = None)#,
224 # isolation_level = None)#,
225 cached_statements=64)
225 cached_statements=64)
226 # print dir(self._db)
226 # print dir(self._db)
227 first_table = self.table
227 first_table = self.table
228 i=0
228 i=0
229 while not self._check_table():
229 while not self._check_table():
230 i+=1
230 i+=1
231 self.table = first_table+'_%i'%i
231 self.table = first_table+'_%i'%i
232 self.log.warn(
232 self.log.warn(
233 "Table %s exists and doesn't match db format, trying %s"%
233 "Table %s exists and doesn't match db format, trying %s"%
234 (first_table,self.table)
234 (first_table,self.table)
235 )
235 )
236
236
237 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
237 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
238 (msg_id text PRIMARY KEY,
238 (msg_id text PRIMARY KEY,
239 header dict text,
239 header dict text,
240 content dict text,
240 content dict text,
241 buffers bufs blob,
241 buffers bufs blob,
242 submitted timestamp,
242 submitted timestamp,
243 client_uuid text,
243 client_uuid text,
244 engine_uuid text,
244 engine_uuid text,
245 started timestamp,
245 started timestamp,
246 completed timestamp,
246 completed timestamp,
247 resubmitted timestamp,
247 resubmitted timestamp,
248 result_header dict text,
248 result_header dict text,
249 result_content dict text,
249 result_content dict text,
250 result_buffers bufs blob,
250 result_buffers bufs blob,
251 queue text,
251 queue text,
252 pyin text,
252 pyin text,
253 pyout text,
253 pyout text,
254 pyerr text,
254 pyerr text,
255 stdout text,
255 stdout text,
256 stderr text)
256 stderr text)
257 """%self.table)
257 """%self.table)
258 self._db.commit()
258 self._db.commit()
259
259
260 def _dict_to_list(self, d):
260 def _dict_to_list(self, d):
261 """turn a mongodb-style record dict into a list."""
261 """turn a mongodb-style record dict into a list."""
262
262
263 return [ d[key] for key in self._keys ]
263 return [ d[key] for key in self._keys ]
264
264
265 def _list_to_dict(self, line, keys=None):
265 def _list_to_dict(self, line, keys=None):
266 """Inverse of dict_to_list"""
266 """Inverse of dict_to_list"""
267 keys = self._keys if keys is None else keys
267 keys = self._keys if keys is None else keys
268 d = self._defaults(keys)
268 d = self._defaults(keys)
269 for key,value in zip(keys, line):
269 for key,value in zip(keys, line):
270 d[key] = value
270 d[key] = value
271
271
272 return d
272 return d
273
273
274 def _render_expression(self, check):
274 def _render_expression(self, check):
275 """Turn a mongodb-style search dict into an SQL query."""
275 """Turn a mongodb-style search dict into an SQL query."""
276 expressions = []
276 expressions = []
277 args = []
277 args = []
278
278
279 skeys = set(check.keys())
279 skeys = set(check.keys())
280 skeys.difference_update(set(self._keys))
280 skeys.difference_update(set(self._keys))
281 skeys.difference_update(set(['buffers', 'result_buffers']))
281 skeys.difference_update(set(['buffers', 'result_buffers']))
282 if skeys:
282 if skeys:
283 raise KeyError("Illegal testing key(s): %s"%skeys)
283 raise KeyError("Illegal testing key(s): %s"%skeys)
284
284
285 for name,sub_check in check.iteritems():
285 for name,sub_check in check.iteritems():
286 if isinstance(sub_check, dict):
286 if isinstance(sub_check, dict):
287 for test,value in sub_check.iteritems():
287 for test,value in sub_check.iteritems():
288 try:
288 try:
289 op = operators[test]
289 op = operators[test]
290 except KeyError:
290 except KeyError:
291 raise KeyError("Unsupported operator: %r"%test)
291 raise KeyError("Unsupported operator: %r"%test)
292 if isinstance(op, tuple):
292 if isinstance(op, tuple):
293 op, join = op
293 op, join = op
294
294
295 if value is None and op in null_operators:
295 if value is None and op in null_operators:
296 expr = "%s %s"%null_operators[op]
296 expr = "%s %s" % (name, null_operators[op])
297 else:
297 else:
298 expr = "%s %s ?"%(name, op)
298 expr = "%s %s ?"%(name, op)
299 if isinstance(value, (tuple,list)):
299 if isinstance(value, (tuple,list)):
300 if op in null_operators and any([v is None for v in value]):
300 if op in null_operators and any([v is None for v in value]):
301 # equality tests don't work with NULL
301 # equality tests don't work with NULL
302 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
302 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
303 expr = '( %s )'%( join.join([expr]*len(value)) )
303 expr = '( %s )'%( join.join([expr]*len(value)) )
304 args.extend(value)
304 args.extend(value)
305 else:
305 else:
306 args.append(value)
306 args.append(value)
307 expressions.append(expr)
307 expressions.append(expr)
308 else:
308 else:
309 # it's an equality check
309 # it's an equality check
310 if sub_check is None:
310 if sub_check is None:
311 expressions.append("%s IS NULL")
311 expressions.append("%s IS NULL" % name)
312 else:
312 else:
313 expressions.append("%s = ?"%name)
313 expressions.append("%s = ?"%name)
314 args.append(sub_check)
314 args.append(sub_check)
315
315
316 expr = " AND ".join(expressions)
316 expr = " AND ".join(expressions)
317 return expr, args
317 return expr, args
318
318
319 def add_record(self, msg_id, rec):
319 def add_record(self, msg_id, rec):
320 """Add a new Task Record, by msg_id."""
320 """Add a new Task Record, by msg_id."""
321 d = self._defaults()
321 d = self._defaults()
322 d.update(rec)
322 d.update(rec)
323 d['msg_id'] = msg_id
323 d['msg_id'] = msg_id
324 line = self._dict_to_list(d)
324 line = self._dict_to_list(d)
325 tups = '(%s)'%(','.join(['?']*len(line)))
325 tups = '(%s)'%(','.join(['?']*len(line)))
326 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
326 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
327 # self._db.commit()
327 # self._db.commit()
328
328
329 def get_record(self, msg_id):
329 def get_record(self, msg_id):
330 """Get a specific Task Record, by msg_id."""
330 """Get a specific Task Record, by msg_id."""
331 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
331 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
332 line = cursor.fetchone()
332 line = cursor.fetchone()
333 if line is None:
333 if line is None:
334 raise KeyError("No such msg: %r"%msg_id)
334 raise KeyError("No such msg: %r"%msg_id)
335 return self._list_to_dict(line)
335 return self._list_to_dict(line)
336
336
337 def update_record(self, msg_id, rec):
337 def update_record(self, msg_id, rec):
338 """Update the data in an existing record."""
338 """Update the data in an existing record."""
339 query = "UPDATE %s SET "%self.table
339 query = "UPDATE %s SET "%self.table
340 sets = []
340 sets = []
341 keys = sorted(rec.keys())
341 keys = sorted(rec.keys())
342 values = []
342 values = []
343 for key in keys:
343 for key in keys:
344 sets.append('%s = ?'%key)
344 sets.append('%s = ?'%key)
345 values.append(rec[key])
345 values.append(rec[key])
346 query += ', '.join(sets)
346 query += ', '.join(sets)
347 query += ' WHERE msg_id == ?'
347 query += ' WHERE msg_id == ?'
348 values.append(msg_id)
348 values.append(msg_id)
349 self._db.execute(query, values)
349 self._db.execute(query, values)
350 # self._db.commit()
350 # self._db.commit()
351
351
352 def drop_record(self, msg_id):
352 def drop_record(self, msg_id):
353 """Remove a record from the DB."""
353 """Remove a record from the DB."""
354 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
354 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
355 # self._db.commit()
355 # self._db.commit()
356
356
357 def drop_matching_records(self, check):
357 def drop_matching_records(self, check):
358 """Remove a record from the DB."""
358 """Remove a record from the DB."""
359 expr,args = self._render_expression(check)
359 expr,args = self._render_expression(check)
360 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
360 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
361 self._db.execute(query,args)
361 self._db.execute(query,args)
362 # self._db.commit()
362 # self._db.commit()
363
363
364 def find_records(self, check, keys=None):
364 def find_records(self, check, keys=None):
365 """Find records matching a query dict, optionally extracting subset of keys.
365 """Find records matching a query dict, optionally extracting subset of keys.
366
366
367 Returns list of matching records.
367 Returns list of matching records.
368
368
369 Parameters
369 Parameters
370 ----------
370 ----------
371
371
372 check: dict
372 check: dict
373 mongodb-style query argument
373 mongodb-style query argument
374 keys: list of strs [optional]
374 keys: list of strs [optional]
375 if specified, the subset of keys to extract. msg_id will *always* be
375 if specified, the subset of keys to extract. msg_id will *always* be
376 included.
376 included.
377 """
377 """
378 if keys:
378 if keys:
379 bad_keys = [ key for key in keys if key not in self._keys ]
379 bad_keys = [ key for key in keys if key not in self._keys ]
380 if bad_keys:
380 if bad_keys:
381 raise KeyError("Bad record key(s): %s"%bad_keys)
381 raise KeyError("Bad record key(s): %s"%bad_keys)
382
382
383 if keys:
383 if keys:
384 # ensure msg_id is present and first:
384 # ensure msg_id is present and first:
385 if 'msg_id' in keys:
385 if 'msg_id' in keys:
386 keys.remove('msg_id')
386 keys.remove('msg_id')
387 keys.insert(0, 'msg_id')
387 keys.insert(0, 'msg_id')
388 req = ', '.join(keys)
388 req = ', '.join(keys)
389 else:
389 else:
390 req = '*'
390 req = '*'
391 expr,args = self._render_expression(check)
391 expr,args = self._render_expression(check)
392 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
392 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
393 cursor = self._db.execute(query, args)
393 cursor = self._db.execute(query, args)
394 matches = cursor.fetchall()
394 matches = cursor.fetchall()
395 records = []
395 records = []
396 for line in matches:
396 for line in matches:
397 rec = self._list_to_dict(line, keys)
397 rec = self._list_to_dict(line, keys)
398 records.append(rec)
398 records.append(rec)
399 return records
399 return records
400
400
401 def get_history(self):
401 def get_history(self):
402 """get all msg_ids, ordered by time submitted."""
402 """get all msg_ids, ordered by time submitted."""
403 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
403 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
404 cursor = self._db.execute(query)
404 cursor = self._db.execute(query)
405 # will be a list of length 1 tuples
405 # will be a list of length 1 tuples
406 return [ tup[0] for tup in cursor.fetchall()]
406 return [ tup[0] for tup in cursor.fetchall()]
407
407
408 __all__ = ['SQLiteDB'] No newline at end of file
408 __all__ = ['SQLiteDB']
@@ -1,182 +1,194 b''
1 """Tests for db backends
1 """Tests for db backends
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import tempfile
21 import tempfile
22 import time
22 import time
23
23
24 from datetime import datetime, timedelta
24 from datetime import datetime, timedelta
25 from unittest import TestCase
25 from unittest import TestCase
26
26
27 from IPython.parallel import error
27 from IPython.parallel import error
28 from IPython.parallel.controller.dictdb import DictDB
28 from IPython.parallel.controller.dictdb import DictDB
29 from IPython.parallel.controller.sqlitedb import SQLiteDB
29 from IPython.parallel.controller.sqlitedb import SQLiteDB
30 from IPython.parallel.controller.hub import init_record, empty_record
30 from IPython.parallel.controller.hub import init_record, empty_record
31
31
32 from IPython.testing import decorators as dec
32 from IPython.testing import decorators as dec
33 from IPython.zmq.session import Session
33 from IPython.zmq.session import Session
34
34
35
35
36 #-------------------------------------------------------------------------------
36 #-------------------------------------------------------------------------------
37 # TestCases
37 # TestCases
38 #-------------------------------------------------------------------------------
38 #-------------------------------------------------------------------------------
39
39
40 class TestDictBackend(TestCase):
40 class TestDictBackend(TestCase):
41 def setUp(self):
41 def setUp(self):
42 self.session = Session()
42 self.session = Session()
43 self.db = self.create_db()
43 self.db = self.create_db()
44 self.load_records(16)
44 self.load_records(16)
45
45
46 def create_db(self):
46 def create_db(self):
47 return DictDB()
47 return DictDB()
48
48
49 def load_records(self, n=1):
49 def load_records(self, n=1):
50 """load n records for testing"""
50 """load n records for testing"""
51 #sleep 1/10 s, to ensure timestamp is different to previous calls
51 #sleep 1/10 s, to ensure timestamp is different to previous calls
52 time.sleep(0.1)
52 time.sleep(0.1)
53 msg_ids = []
53 msg_ids = []
54 for i in range(n):
54 for i in range(n):
55 msg = self.session.msg('apply_request', content=dict(a=5))
55 msg = self.session.msg('apply_request', content=dict(a=5))
56 msg['buffers'] = []
56 msg['buffers'] = []
57 rec = init_record(msg)
57 rec = init_record(msg)
58 msg_id = msg['header']['msg_id']
58 msg_id = msg['header']['msg_id']
59 msg_ids.append(msg_id)
59 msg_ids.append(msg_id)
60 self.db.add_record(msg_id, rec)
60 self.db.add_record(msg_id, rec)
61 return msg_ids
61 return msg_ids
62
62
63 def test_add_record(self):
63 def test_add_record(self):
64 before = self.db.get_history()
64 before = self.db.get_history()
65 self.load_records(5)
65 self.load_records(5)
66 after = self.db.get_history()
66 after = self.db.get_history()
67 self.assertEquals(len(after), len(before)+5)
67 self.assertEquals(len(after), len(before)+5)
68 self.assertEquals(after[:-5],before)
68 self.assertEquals(after[:-5],before)
69
69
70 def test_drop_record(self):
70 def test_drop_record(self):
71 msg_id = self.load_records()[-1]
71 msg_id = self.load_records()[-1]
72 rec = self.db.get_record(msg_id)
72 rec = self.db.get_record(msg_id)
73 self.db.drop_record(msg_id)
73 self.db.drop_record(msg_id)
74 self.assertRaises(KeyError,self.db.get_record, msg_id)
74 self.assertRaises(KeyError,self.db.get_record, msg_id)
75
75
76 def _round_to_millisecond(self, dt):
76 def _round_to_millisecond(self, dt):
77 """necessary because mongodb rounds microseconds"""
77 """necessary because mongodb rounds microseconds"""
78 micro = dt.microsecond
78 micro = dt.microsecond
79 extra = int(str(micro)[-3:])
79 extra = int(str(micro)[-3:])
80 return dt - timedelta(microseconds=extra)
80 return dt - timedelta(microseconds=extra)
81
81
82 def test_update_record(self):
82 def test_update_record(self):
83 now = self._round_to_millisecond(datetime.now())
83 now = self._round_to_millisecond(datetime.now())
84 #
84 #
85 msg_id = self.db.get_history()[-1]
85 msg_id = self.db.get_history()[-1]
86 rec1 = self.db.get_record(msg_id)
86 rec1 = self.db.get_record(msg_id)
87 data = {'stdout': 'hello there', 'completed' : now}
87 data = {'stdout': 'hello there', 'completed' : now}
88 self.db.update_record(msg_id, data)
88 self.db.update_record(msg_id, data)
89 rec2 = self.db.get_record(msg_id)
89 rec2 = self.db.get_record(msg_id)
90 self.assertEquals(rec2['stdout'], 'hello there')
90 self.assertEquals(rec2['stdout'], 'hello there')
91 self.assertEquals(rec2['completed'], now)
91 self.assertEquals(rec2['completed'], now)
92 rec1.update(data)
92 rec1.update(data)
93 self.assertEquals(rec1, rec2)
93 self.assertEquals(rec1, rec2)
94
94
95 # def test_update_record_bad(self):
95 # def test_update_record_bad(self):
96 # """test updating nonexistant records"""
96 # """test updating nonexistant records"""
97 # msg_id = str(uuid.uuid4())
97 # msg_id = str(uuid.uuid4())
98 # data = {'stdout': 'hello there'}
98 # data = {'stdout': 'hello there'}
99 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
99 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
100
100
101 def test_find_records_dt(self):
101 def test_find_records_dt(self):
102 """test finding records by date"""
102 """test finding records by date"""
103 hist = self.db.get_history()
103 hist = self.db.get_history()
104 middle = self.db.get_record(hist[len(hist)//2])
104 middle = self.db.get_record(hist[len(hist)//2])
105 tic = middle['submitted']
105 tic = middle['submitted']
106 before = self.db.find_records({'submitted' : {'$lt' : tic}})
106 before = self.db.find_records({'submitted' : {'$lt' : tic}})
107 after = self.db.find_records({'submitted' : {'$gte' : tic}})
107 after = self.db.find_records({'submitted' : {'$gte' : tic}})
108 self.assertEquals(len(before)+len(after),len(hist))
108 self.assertEquals(len(before)+len(after),len(hist))
109 for b in before:
109 for b in before:
110 self.assertTrue(b['submitted'] < tic)
110 self.assertTrue(b['submitted'] < tic)
111 for a in after:
111 for a in after:
112 self.assertTrue(a['submitted'] >= tic)
112 self.assertTrue(a['submitted'] >= tic)
113 same = self.db.find_records({'submitted' : tic})
113 same = self.db.find_records({'submitted' : tic})
114 for s in same:
114 for s in same:
115 self.assertTrue(s['submitted'] == tic)
115 self.assertTrue(s['submitted'] == tic)
116
116
117 def test_find_records_keys(self):
117 def test_find_records_keys(self):
118 """test extracting subset of record keys"""
118 """test extracting subset of record keys"""
119 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
119 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
120 for rec in found:
120 for rec in found:
121 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
121 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
122
122
123 def test_find_records_msg_id(self):
123 def test_find_records_msg_id(self):
124 """ensure msg_id is always in found records"""
124 """ensure msg_id is always in found records"""
125 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
125 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
126 for rec in found:
126 for rec in found:
127 self.assertTrue('msg_id' in rec.keys())
127 self.assertTrue('msg_id' in rec.keys())
128 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
128 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
129 for rec in found:
129 for rec in found:
130 self.assertTrue('msg_id' in rec.keys())
130 self.assertTrue('msg_id' in rec.keys())
131 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
131 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
132 for rec in found:
132 for rec in found:
133 self.assertTrue('msg_id' in rec.keys())
133 self.assertTrue('msg_id' in rec.keys())
134
134
135 def test_find_records_in(self):
135 def test_find_records_in(self):
136 """test finding records with '$in','$nin' operators"""
136 """test finding records with '$in','$nin' operators"""
137 hist = self.db.get_history()
137 hist = self.db.get_history()
138 even = hist[::2]
138 even = hist[::2]
139 odd = hist[1::2]
139 odd = hist[1::2]
140 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
140 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
141 found = [ r['msg_id'] for r in recs ]
141 found = [ r['msg_id'] for r in recs ]
142 self.assertEquals(set(even), set(found))
142 self.assertEquals(set(even), set(found))
143 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
143 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
144 found = [ r['msg_id'] for r in recs ]
144 found = [ r['msg_id'] for r in recs ]
145 self.assertEquals(set(odd), set(found))
145 self.assertEquals(set(odd), set(found))
146
146
147 def test_get_history(self):
147 def test_get_history(self):
148 msg_ids = self.db.get_history()
148 msg_ids = self.db.get_history()
149 latest = datetime(1984,1,1)
149 latest = datetime(1984,1,1)
150 for msg_id in msg_ids:
150 for msg_id in msg_ids:
151 rec = self.db.get_record(msg_id)
151 rec = self.db.get_record(msg_id)
152 newt = rec['submitted']
152 newt = rec['submitted']
153 self.assertTrue(newt >= latest)
153 self.assertTrue(newt >= latest)
154 latest = newt
154 latest = newt
155 msg_id = self.load_records(1)[-1]
155 msg_id = self.load_records(1)[-1]
156 self.assertEquals(self.db.get_history()[-1],msg_id)
156 self.assertEquals(self.db.get_history()[-1],msg_id)
157
157
158 def test_datetime(self):
158 def test_datetime(self):
159 """get/set timestamps with datetime objects"""
159 """get/set timestamps with datetime objects"""
160 msg_id = self.db.get_history()[-1]
160 msg_id = self.db.get_history()[-1]
161 rec = self.db.get_record(msg_id)
161 rec = self.db.get_record(msg_id)
162 self.assertTrue(isinstance(rec['submitted'], datetime))
162 self.assertTrue(isinstance(rec['submitted'], datetime))
163 self.db.update_record(msg_id, dict(completed=datetime.now()))
163 self.db.update_record(msg_id, dict(completed=datetime.now()))
164 rec = self.db.get_record(msg_id)
164 rec = self.db.get_record(msg_id)
165 self.assertTrue(isinstance(rec['completed'], datetime))
165 self.assertTrue(isinstance(rec['completed'], datetime))
166
166
167 def test_drop_matching(self):
167 def test_drop_matching(self):
168 msg_ids = self.load_records(10)
168 msg_ids = self.load_records(10)
169 query = {'msg_id' : {'$in':msg_ids}}
169 query = {'msg_id' : {'$in':msg_ids}}
170 self.db.drop_matching_records(query)
170 self.db.drop_matching_records(query)
171 recs = self.db.find_records(query)
171 recs = self.db.find_records(query)
172 self.assertEquals(len(recs), 0)
172 self.assertEquals(len(recs), 0)
173
173
174 def test_null(self):
175 """test None comparison queries"""
176 msg_ids = self.load_records(10)
177
178 query = {'msg_id' : None}
179 recs = self.db.find_records(query)
180 self.assertEquals(len(recs), 0)
181
182 query = {'msg_id' : {'$ne' : None}}
183 recs = self.db.find_records(query)
184 self.assertTrue(len(recs) >= 10)
185
174
186
175 class TestSQLiteBackend(TestDictBackend):
187 class TestSQLiteBackend(TestDictBackend):
176
188
177 @dec.skip_without('sqlite3')
189 @dec.skip_without('sqlite3')
178 def create_db(self):
190 def create_db(self):
179 return SQLiteDB(location=tempfile.gettempdir())
191 return SQLiteDB(location=tempfile.gettempdir())
180
192
181 def tearDown(self):
193 def tearDown(self):
182 self.db._db.close()
194 self.db._db.close()
General Comments 0
You need to be logged in to leave comments. Login now