##// END OF EJS Templates
fix message when trying new table due to mismatch in SQLiteDB...
MinRK -
Show More
@@ -1,411 +1,412 b''
1 """A TaskRecord backend using sqlite3
1 """A TaskRecord backend using sqlite3
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2011 The IPython Development Team
8 # Copyright (C) 2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 import json
14 import json
15 import os
15 import os
16 import cPickle as pickle
16 import cPickle as pickle
17 from datetime import datetime
17 from datetime import datetime
18
18
19 try:
19 try:
20 import sqlite3
20 import sqlite3
21 except ImportError:
21 except ImportError:
22 sqlite3 = None
22 sqlite3 = None
23
23
24 from zmq.eventloop import ioloop
24 from zmq.eventloop import ioloop
25
25
26 from IPython.utils.traitlets import Unicode, Instance, List, Dict
26 from IPython.utils.traitlets import Unicode, Instance, List, Dict
27 from .dictdb import BaseDB
27 from .dictdb import BaseDB
28 from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
28 from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
29
29
30 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
31 # SQLite operators, adapters, and converters
31 # SQLite operators, adapters, and converters
32 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
33
33
34 try:
34 try:
35 buffer
35 buffer
36 except NameError:
36 except NameError:
37 # py3k
37 # py3k
38 buffer = memoryview
38 buffer = memoryview
39
39
40 operators = {
40 operators = {
41 '$lt' : "<",
41 '$lt' : "<",
42 '$gt' : ">",
42 '$gt' : ">",
43 # null is handled weird with ==,!=
43 # null is handled weird with ==,!=
44 '$eq' : "=",
44 '$eq' : "=",
45 '$ne' : "!=",
45 '$ne' : "!=",
46 '$lte': "<=",
46 '$lte': "<=",
47 '$gte': ">=",
47 '$gte': ">=",
48 '$in' : ('=', ' OR '),
48 '$in' : ('=', ' OR '),
49 '$nin': ('!=', ' AND '),
49 '$nin': ('!=', ' AND '),
50 # '$all': None,
50 # '$all': None,
51 # '$mod': None,
51 # '$mod': None,
52 # '$exists' : None
52 # '$exists' : None
53 }
53 }
54 null_operators = {
54 null_operators = {
55 '=' : "IS NULL",
55 '=' : "IS NULL",
56 '!=' : "IS NOT NULL",
56 '!=' : "IS NOT NULL",
57 }
57 }
58
58
59 def _adapt_dict(d):
59 def _adapt_dict(d):
60 return json.dumps(d, default=date_default)
60 return json.dumps(d, default=date_default)
61
61
62 def _convert_dict(ds):
62 def _convert_dict(ds):
63 if ds is None:
63 if ds is None:
64 return ds
64 return ds
65 else:
65 else:
66 if isinstance(ds, bytes):
66 if isinstance(ds, bytes):
67 # If I understand the sqlite doc correctly, this will always be utf8
67 # If I understand the sqlite doc correctly, this will always be utf8
68 ds = ds.decode('utf8')
68 ds = ds.decode('utf8')
69 return extract_dates(json.loads(ds))
69 return extract_dates(json.loads(ds))
70
70
71 def _adapt_bufs(bufs):
71 def _adapt_bufs(bufs):
72 # this is *horrible*
72 # this is *horrible*
73 # copy buffers into single list and pickle it:
73 # copy buffers into single list and pickle it:
74 if bufs and isinstance(bufs[0], (bytes, buffer)):
74 if bufs and isinstance(bufs[0], (bytes, buffer)):
75 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
75 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
76 elif bufs:
76 elif bufs:
77 return bufs
77 return bufs
78 else:
78 else:
79 return None
79 return None
80
80
81 def _convert_bufs(bs):
81 def _convert_bufs(bs):
82 if bs is None:
82 if bs is None:
83 return []
83 return []
84 else:
84 else:
85 return pickle.loads(bytes(bs))
85 return pickle.loads(bytes(bs))
86
86
87 #-----------------------------------------------------------------------------
87 #-----------------------------------------------------------------------------
88 # SQLiteDB class
88 # SQLiteDB class
89 #-----------------------------------------------------------------------------
89 #-----------------------------------------------------------------------------
90
90
91 class SQLiteDB(BaseDB):
91 class SQLiteDB(BaseDB):
92 """SQLite3 TaskRecord backend."""
92 """SQLite3 TaskRecord backend."""
93
93
94 filename = Unicode('tasks.db', config=True,
94 filename = Unicode('tasks.db', config=True,
95 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
95 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
96 location = Unicode('', config=True,
96 location = Unicode('', config=True,
97 help="""The directory containing the sqlite task database. The default
97 help="""The directory containing the sqlite task database. The default
98 is to use the cluster_dir location.""")
98 is to use the cluster_dir location.""")
99 table = Unicode("", config=True,
99 table = Unicode("", config=True,
100 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
100 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
101 a new table will be created with the Hub's IDENT. Specifying the table will result
101 a new table will be created with the Hub's IDENT. Specifying the table will result
102 in tasks from previous sessions being available via Clients' db_query and
102 in tasks from previous sessions being available via Clients' db_query and
103 get_result methods.""")
103 get_result methods.""")
104
104
105 if sqlite3 is not None:
105 if sqlite3 is not None:
106 _db = Instance('sqlite3.Connection')
106 _db = Instance('sqlite3.Connection')
107 else:
107 else:
108 _db = None
108 _db = None
109 # the ordered list of column names
109 # the ordered list of column names
110 _keys = List(['msg_id' ,
110 _keys = List(['msg_id' ,
111 'header' ,
111 'header' ,
112 'content',
112 'content',
113 'buffers',
113 'buffers',
114 'submitted',
114 'submitted',
115 'client_uuid' ,
115 'client_uuid' ,
116 'engine_uuid' ,
116 'engine_uuid' ,
117 'started',
117 'started',
118 'completed',
118 'completed',
119 'resubmitted',
119 'resubmitted',
120 'received',
120 'received',
121 'result_header' ,
121 'result_header' ,
122 'result_content' ,
122 'result_content' ,
123 'result_buffers' ,
123 'result_buffers' ,
124 'queue' ,
124 'queue' ,
125 'pyin' ,
125 'pyin' ,
126 'pyout',
126 'pyout',
127 'pyerr',
127 'pyerr',
128 'stdout',
128 'stdout',
129 'stderr',
129 'stderr',
130 ])
130 ])
131 # sqlite datatypes for checking that db is current format
131 # sqlite datatypes for checking that db is current format
132 _types = Dict({'msg_id' : 'text' ,
132 _types = Dict({'msg_id' : 'text' ,
133 'header' : 'dict text',
133 'header' : 'dict text',
134 'content' : 'dict text',
134 'content' : 'dict text',
135 'buffers' : 'bufs blob',
135 'buffers' : 'bufs blob',
136 'submitted' : 'timestamp',
136 'submitted' : 'timestamp',
137 'client_uuid' : 'text',
137 'client_uuid' : 'text',
138 'engine_uuid' : 'text',
138 'engine_uuid' : 'text',
139 'started' : 'timestamp',
139 'started' : 'timestamp',
140 'completed' : 'timestamp',
140 'completed' : 'timestamp',
141 'resubmitted' : 'timestamp',
141 'resubmitted' : 'timestamp',
142 'received' : 'timestamp',
142 'received' : 'timestamp',
143 'result_header' : 'dict text',
143 'result_header' : 'dict text',
144 'result_content' : 'dict text',
144 'result_content' : 'dict text',
145 'result_buffers' : 'bufs blob',
145 'result_buffers' : 'bufs blob',
146 'queue' : 'text',
146 'queue' : 'text',
147 'pyin' : 'text',
147 'pyin' : 'text',
148 'pyout' : 'text',
148 'pyout' : 'text',
149 'pyerr' : 'text',
149 'pyerr' : 'text',
150 'stdout' : 'text',
150 'stdout' : 'text',
151 'stderr' : 'text',
151 'stderr' : 'text',
152 })
152 })
153
153
154 def __init__(self, **kwargs):
154 def __init__(self, **kwargs):
155 super(SQLiteDB, self).__init__(**kwargs)
155 super(SQLiteDB, self).__init__(**kwargs)
156 if sqlite3 is None:
156 if sqlite3 is None:
157 raise ImportError("SQLiteDB requires sqlite3")
157 raise ImportError("SQLiteDB requires sqlite3")
158 if not self.table:
158 if not self.table:
159 # use session, and prefix _, since starting with # is illegal
159 # use session, and prefix _, since starting with # is illegal
160 self.table = '_'+self.session.replace('-','_')
160 self.table = '_'+self.session.replace('-','_')
161 if not self.location:
161 if not self.location:
162 # get current profile
162 # get current profile
163 from IPython.core.application import BaseIPythonApplication
163 from IPython.core.application import BaseIPythonApplication
164 if BaseIPythonApplication.initialized():
164 if BaseIPythonApplication.initialized():
165 app = BaseIPythonApplication.instance()
165 app = BaseIPythonApplication.instance()
166 if app.profile_dir is not None:
166 if app.profile_dir is not None:
167 self.location = app.profile_dir.location
167 self.location = app.profile_dir.location
168 else:
168 else:
169 self.location = u'.'
169 self.location = u'.'
170 else:
170 else:
171 self.location = u'.'
171 self.location = u'.'
172 self._init_db()
172 self._init_db()
173
173
174 # register db commit as 2s periodic callback
174 # register db commit as 2s periodic callback
175 # to prevent clogging pipes
175 # to prevent clogging pipes
176 # assumes we are being run in a zmq ioloop app
176 # assumes we are being run in a zmq ioloop app
177 loop = ioloop.IOLoop.instance()
177 loop = ioloop.IOLoop.instance()
178 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
178 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
179 pc.start()
179 pc.start()
180
180
181 def _defaults(self, keys=None):
181 def _defaults(self, keys=None):
182 """create an empty record"""
182 """create an empty record"""
183 d = {}
183 d = {}
184 keys = self._keys if keys is None else keys
184 keys = self._keys if keys is None else keys
185 for key in keys:
185 for key in keys:
186 d[key] = None
186 d[key] = None
187 return d
187 return d
188
188
189 def _check_table(self):
189 def _check_table(self):
190 """Ensure that an incorrect table doesn't exist
190 """Ensure that an incorrect table doesn't exist
191
191
192 If a bad (old) table does exist, return False
192 If a bad (old) table does exist, return False
193 """
193 """
194 cursor = self._db.execute("PRAGMA table_info(%s)"%self.table)
194 cursor = self._db.execute("PRAGMA table_info(%s)"%self.table)
195 lines = cursor.fetchall()
195 lines = cursor.fetchall()
196 if not lines:
196 if not lines:
197 # table does not exist
197 # table does not exist
198 return True
198 return True
199 types = {}
199 types = {}
200 keys = []
200 keys = []
201 for line in lines:
201 for line in lines:
202 keys.append(line[1])
202 keys.append(line[1])
203 types[line[1]] = line[2]
203 types[line[1]] = line[2]
204 if self._keys != keys:
204 if self._keys != keys:
205 # key mismatch
205 # key mismatch
206 self.log.warn('keys mismatch')
206 self.log.warn('keys mismatch')
207 return False
207 return False
208 for key in self._keys:
208 for key in self._keys:
209 if types[key] != self._types[key]:
209 if types[key] != self._types[key]:
210 self.log.warn(
210 self.log.warn(
211 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
211 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
212 )
212 )
213 return False
213 return False
214 return True
214 return True
215
215
216 def _init_db(self):
216 def _init_db(self):
217 """Connect to the database and get new session number."""
217 """Connect to the database and get new session number."""
218 # register adapters
218 # register adapters
219 sqlite3.register_adapter(dict, _adapt_dict)
219 sqlite3.register_adapter(dict, _adapt_dict)
220 sqlite3.register_converter('dict', _convert_dict)
220 sqlite3.register_converter('dict', _convert_dict)
221 sqlite3.register_adapter(list, _adapt_bufs)
221 sqlite3.register_adapter(list, _adapt_bufs)
222 sqlite3.register_converter('bufs', _convert_bufs)
222 sqlite3.register_converter('bufs', _convert_bufs)
223 # connect to the db
223 # connect to the db
224 dbfile = os.path.join(self.location, self.filename)
224 dbfile = os.path.join(self.location, self.filename)
225 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
225 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
226 # isolation_level = None)#,
226 # isolation_level = None)#,
227 cached_statements=64)
227 cached_statements=64)
228 # print dir(self._db)
228 # print dir(self._db)
229 first_table = self.table
229 first_table = previous_table = self.table
230 i=0
230 i=0
231 while not self._check_table():
231 while not self._check_table():
232 i+=1
232 i+=1
233 self.table = first_table+'_%i'%i
233 self.table = first_table+'_%i'%i
234 self.log.warn(
234 self.log.warn(
235 "Table %s exists and doesn't match db format, trying %s"%
235 "Table %s exists and doesn't match db format, trying %s"%
236 (first_table,self.table)
236 (previous_table, self.table)
237 )
237 )
238 previous_table = self.table
238
239
239 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
240 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
240 (msg_id text PRIMARY KEY,
241 (msg_id text PRIMARY KEY,
241 header dict text,
242 header dict text,
242 content dict text,
243 content dict text,
243 buffers bufs blob,
244 buffers bufs blob,
244 submitted timestamp,
245 submitted timestamp,
245 client_uuid text,
246 client_uuid text,
246 engine_uuid text,
247 engine_uuid text,
247 started timestamp,
248 started timestamp,
248 completed timestamp,
249 completed timestamp,
249 resubmitted timestamp,
250 resubmitted timestamp,
250 received timestamp,
251 received timestamp,
251 result_header dict text,
252 result_header dict text,
252 result_content dict text,
253 result_content dict text,
253 result_buffers bufs blob,
254 result_buffers bufs blob,
254 queue text,
255 queue text,
255 pyin text,
256 pyin text,
256 pyout text,
257 pyout text,
257 pyerr text,
258 pyerr text,
258 stdout text,
259 stdout text,
259 stderr text)
260 stderr text)
260 """%self.table)
261 """%self.table)
261 self._db.commit()
262 self._db.commit()
262
263
263 def _dict_to_list(self, d):
264 def _dict_to_list(self, d):
264 """turn a mongodb-style record dict into a list."""
265 """turn a mongodb-style record dict into a list."""
265
266
266 return [ d[key] for key in self._keys ]
267 return [ d[key] for key in self._keys ]
267
268
268 def _list_to_dict(self, line, keys=None):
269 def _list_to_dict(self, line, keys=None):
269 """Inverse of dict_to_list"""
270 """Inverse of dict_to_list"""
270 keys = self._keys if keys is None else keys
271 keys = self._keys if keys is None else keys
271 d = self._defaults(keys)
272 d = self._defaults(keys)
272 for key,value in zip(keys, line):
273 for key,value in zip(keys, line):
273 d[key] = value
274 d[key] = value
274
275
275 return d
276 return d
276
277
277 def _render_expression(self, check):
278 def _render_expression(self, check):
278 """Turn a mongodb-style search dict into an SQL query."""
279 """Turn a mongodb-style search dict into an SQL query."""
279 expressions = []
280 expressions = []
280 args = []
281 args = []
281
282
282 skeys = set(check.keys())
283 skeys = set(check.keys())
283 skeys.difference_update(set(self._keys))
284 skeys.difference_update(set(self._keys))
284 skeys.difference_update(set(['buffers', 'result_buffers']))
285 skeys.difference_update(set(['buffers', 'result_buffers']))
285 if skeys:
286 if skeys:
286 raise KeyError("Illegal testing key(s): %s"%skeys)
287 raise KeyError("Illegal testing key(s): %s"%skeys)
287
288
288 for name,sub_check in check.iteritems():
289 for name,sub_check in check.iteritems():
289 if isinstance(sub_check, dict):
290 if isinstance(sub_check, dict):
290 for test,value in sub_check.iteritems():
291 for test,value in sub_check.iteritems():
291 try:
292 try:
292 op = operators[test]
293 op = operators[test]
293 except KeyError:
294 except KeyError:
294 raise KeyError("Unsupported operator: %r"%test)
295 raise KeyError("Unsupported operator: %r"%test)
295 if isinstance(op, tuple):
296 if isinstance(op, tuple):
296 op, join = op
297 op, join = op
297
298
298 if value is None and op in null_operators:
299 if value is None and op in null_operators:
299 expr = "%s %s" % (name, null_operators[op])
300 expr = "%s %s" % (name, null_operators[op])
300 else:
301 else:
301 expr = "%s %s ?"%(name, op)
302 expr = "%s %s ?"%(name, op)
302 if isinstance(value, (tuple,list)):
303 if isinstance(value, (tuple,list)):
303 if op in null_operators and any([v is None for v in value]):
304 if op in null_operators and any([v is None for v in value]):
304 # equality tests don't work with NULL
305 # equality tests don't work with NULL
305 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
306 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
306 expr = '( %s )'%( join.join([expr]*len(value)) )
307 expr = '( %s )'%( join.join([expr]*len(value)) )
307 args.extend(value)
308 args.extend(value)
308 else:
309 else:
309 args.append(value)
310 args.append(value)
310 expressions.append(expr)
311 expressions.append(expr)
311 else:
312 else:
312 # it's an equality check
313 # it's an equality check
313 if sub_check is None:
314 if sub_check is None:
314 expressions.append("%s IS NULL" % name)
315 expressions.append("%s IS NULL" % name)
315 else:
316 else:
316 expressions.append("%s = ?"%name)
317 expressions.append("%s = ?"%name)
317 args.append(sub_check)
318 args.append(sub_check)
318
319
319 expr = " AND ".join(expressions)
320 expr = " AND ".join(expressions)
320 return expr, args
321 return expr, args
321
322
322 def add_record(self, msg_id, rec):
323 def add_record(self, msg_id, rec):
323 """Add a new Task Record, by msg_id."""
324 """Add a new Task Record, by msg_id."""
324 d = self._defaults()
325 d = self._defaults()
325 d.update(rec)
326 d.update(rec)
326 d['msg_id'] = msg_id
327 d['msg_id'] = msg_id
327 line = self._dict_to_list(d)
328 line = self._dict_to_list(d)
328 tups = '(%s)'%(','.join(['?']*len(line)))
329 tups = '(%s)'%(','.join(['?']*len(line)))
329 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
330 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
330 # self._db.commit()
331 # self._db.commit()
331
332
332 def get_record(self, msg_id):
333 def get_record(self, msg_id):
333 """Get a specific Task Record, by msg_id."""
334 """Get a specific Task Record, by msg_id."""
334 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
335 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
335 line = cursor.fetchone()
336 line = cursor.fetchone()
336 if line is None:
337 if line is None:
337 raise KeyError("No such msg: %r"%msg_id)
338 raise KeyError("No such msg: %r"%msg_id)
338 return self._list_to_dict(line)
339 return self._list_to_dict(line)
339
340
340 def update_record(self, msg_id, rec):
341 def update_record(self, msg_id, rec):
341 """Update the data in an existing record."""
342 """Update the data in an existing record."""
342 query = "UPDATE %s SET "%self.table
343 query = "UPDATE %s SET "%self.table
343 sets = []
344 sets = []
344 keys = sorted(rec.keys())
345 keys = sorted(rec.keys())
345 values = []
346 values = []
346 for key in keys:
347 for key in keys:
347 sets.append('%s = ?'%key)
348 sets.append('%s = ?'%key)
348 values.append(rec[key])
349 values.append(rec[key])
349 query += ', '.join(sets)
350 query += ', '.join(sets)
350 query += ' WHERE msg_id == ?'
351 query += ' WHERE msg_id == ?'
351 values.append(msg_id)
352 values.append(msg_id)
352 self._db.execute(query, values)
353 self._db.execute(query, values)
353 # self._db.commit()
354 # self._db.commit()
354
355
355 def drop_record(self, msg_id):
356 def drop_record(self, msg_id):
356 """Remove a record from the DB."""
357 """Remove a record from the DB."""
357 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
358 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
358 # self._db.commit()
359 # self._db.commit()
359
360
360 def drop_matching_records(self, check):
361 def drop_matching_records(self, check):
361 """Remove a record from the DB."""
362 """Remove a record from the DB."""
362 expr,args = self._render_expression(check)
363 expr,args = self._render_expression(check)
363 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
364 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
364 self._db.execute(query,args)
365 self._db.execute(query,args)
365 # self._db.commit()
366 # self._db.commit()
366
367
367 def find_records(self, check, keys=None):
368 def find_records(self, check, keys=None):
368 """Find records matching a query dict, optionally extracting subset of keys.
369 """Find records matching a query dict, optionally extracting subset of keys.
369
370
370 Returns list of matching records.
371 Returns list of matching records.
371
372
372 Parameters
373 Parameters
373 ----------
374 ----------
374
375
375 check: dict
376 check: dict
376 mongodb-style query argument
377 mongodb-style query argument
377 keys: list of strs [optional]
378 keys: list of strs [optional]
378 if specified, the subset of keys to extract. msg_id will *always* be
379 if specified, the subset of keys to extract. msg_id will *always* be
379 included.
380 included.
380 """
381 """
381 if keys:
382 if keys:
382 bad_keys = [ key for key in keys if key not in self._keys ]
383 bad_keys = [ key for key in keys if key not in self._keys ]
383 if bad_keys:
384 if bad_keys:
384 raise KeyError("Bad record key(s): %s"%bad_keys)
385 raise KeyError("Bad record key(s): %s"%bad_keys)
385
386
386 if keys:
387 if keys:
387 # ensure msg_id is present and first:
388 # ensure msg_id is present and first:
388 if 'msg_id' in keys:
389 if 'msg_id' in keys:
389 keys.remove('msg_id')
390 keys.remove('msg_id')
390 keys.insert(0, 'msg_id')
391 keys.insert(0, 'msg_id')
391 req = ', '.join(keys)
392 req = ', '.join(keys)
392 else:
393 else:
393 req = '*'
394 req = '*'
394 expr,args = self._render_expression(check)
395 expr,args = self._render_expression(check)
395 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
396 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
396 cursor = self._db.execute(query, args)
397 cursor = self._db.execute(query, args)
397 matches = cursor.fetchall()
398 matches = cursor.fetchall()
398 records = []
399 records = []
399 for line in matches:
400 for line in matches:
400 rec = self._list_to_dict(line, keys)
401 rec = self._list_to_dict(line, keys)
401 records.append(rec)
402 records.append(rec)
402 return records
403 return records
403
404
404 def get_history(self):
405 def get_history(self):
405 """get all msg_ids, ordered by time submitted."""
406 """get all msg_ids, ordered by time submitted."""
406 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
407 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
407 cursor = self._db.execute(query)
408 cursor = self._db.execute(query)
408 # will be a list of length 1 tuples
409 # will be a list of length 1 tuples
409 return [ tup[0] for tup in cursor.fetchall()]
410 return [ tup[0] for tup in cursor.fetchall()]
410
411
411 __all__ = ['SQLiteDB'] No newline at end of file
412 __all__ = ['SQLiteDB']
@@ -1,240 +1,243 b''
1 """Tests for db backends
1 """Tests for db backends
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import logging
21 import os
22 import os
22 import tempfile
23 import tempfile
23 import time
24 import time
24
25
25 from datetime import datetime, timedelta
26 from datetime import datetime, timedelta
26 from unittest import TestCase
27 from unittest import TestCase
27
28
28 from IPython.parallel import error
29 from IPython.parallel import error
29 from IPython.parallel.controller.dictdb import DictDB
30 from IPython.parallel.controller.dictdb import DictDB
30 from IPython.parallel.controller.sqlitedb import SQLiteDB
31 from IPython.parallel.controller.sqlitedb import SQLiteDB
31 from IPython.parallel.controller.hub import init_record, empty_record
32 from IPython.parallel.controller.hub import init_record, empty_record
32
33
33 from IPython.testing import decorators as dec
34 from IPython.testing import decorators as dec
34 from IPython.zmq.session import Session
35 from IPython.zmq.session import Session
35
36
36
37
37 #-------------------------------------------------------------------------------
38 #-------------------------------------------------------------------------------
38 # TestCases
39 # TestCases
39 #-------------------------------------------------------------------------------
40 #-------------------------------------------------------------------------------
40
41
41
42
42 def setup():
43 def setup():
43 global temp_db
44 global temp_db
44 temp_db = tempfile.NamedTemporaryFile(suffix='.db').name
45 temp_db = tempfile.NamedTemporaryFile(suffix='.db').name
45
46
46
47
47 class TestDictBackend(TestCase):
48 class TestDictBackend(TestCase):
48 def setUp(self):
49 def setUp(self):
49 self.session = Session()
50 self.session = Session()
50 self.db = self.create_db()
51 self.db = self.create_db()
51 self.load_records(16)
52 self.load_records(16)
52
53
53 def create_db(self):
54 def create_db(self):
54 return DictDB()
55 return DictDB()
55
56
56 def load_records(self, n=1):
57 def load_records(self, n=1):
57 """load n records for testing"""
58 """load n records for testing"""
58 #sleep 1/10 s, to ensure timestamp is different to previous calls
59 #sleep 1/10 s, to ensure timestamp is different to previous calls
59 time.sleep(0.1)
60 time.sleep(0.1)
60 msg_ids = []
61 msg_ids = []
61 for i in range(n):
62 for i in range(n):
62 msg = self.session.msg('apply_request', content=dict(a=5))
63 msg = self.session.msg('apply_request', content=dict(a=5))
63 msg['buffers'] = []
64 msg['buffers'] = []
64 rec = init_record(msg)
65 rec = init_record(msg)
65 msg_id = msg['header']['msg_id']
66 msg_id = msg['header']['msg_id']
66 msg_ids.append(msg_id)
67 msg_ids.append(msg_id)
67 self.db.add_record(msg_id, rec)
68 self.db.add_record(msg_id, rec)
68 return msg_ids
69 return msg_ids
69
70
70 def test_add_record(self):
71 def test_add_record(self):
71 before = self.db.get_history()
72 before = self.db.get_history()
72 self.load_records(5)
73 self.load_records(5)
73 after = self.db.get_history()
74 after = self.db.get_history()
74 self.assertEquals(len(after), len(before)+5)
75 self.assertEquals(len(after), len(before)+5)
75 self.assertEquals(after[:-5],before)
76 self.assertEquals(after[:-5],before)
76
77
77 def test_drop_record(self):
78 def test_drop_record(self):
78 msg_id = self.load_records()[-1]
79 msg_id = self.load_records()[-1]
79 rec = self.db.get_record(msg_id)
80 rec = self.db.get_record(msg_id)
80 self.db.drop_record(msg_id)
81 self.db.drop_record(msg_id)
81 self.assertRaises(KeyError,self.db.get_record, msg_id)
82 self.assertRaises(KeyError,self.db.get_record, msg_id)
82
83
83 def _round_to_millisecond(self, dt):
84 def _round_to_millisecond(self, dt):
84 """necessary because mongodb rounds microseconds"""
85 """necessary because mongodb rounds microseconds"""
85 micro = dt.microsecond
86 micro = dt.microsecond
86 extra = int(str(micro)[-3:])
87 extra = int(str(micro)[-3:])
87 return dt - timedelta(microseconds=extra)
88 return dt - timedelta(microseconds=extra)
88
89
89 def test_update_record(self):
90 def test_update_record(self):
90 now = self._round_to_millisecond(datetime.now())
91 now = self._round_to_millisecond(datetime.now())
91 #
92 #
92 msg_id = self.db.get_history()[-1]
93 msg_id = self.db.get_history()[-1]
93 rec1 = self.db.get_record(msg_id)
94 rec1 = self.db.get_record(msg_id)
94 data = {'stdout': 'hello there', 'completed' : now}
95 data = {'stdout': 'hello there', 'completed' : now}
95 self.db.update_record(msg_id, data)
96 self.db.update_record(msg_id, data)
96 rec2 = self.db.get_record(msg_id)
97 rec2 = self.db.get_record(msg_id)
97 self.assertEquals(rec2['stdout'], 'hello there')
98 self.assertEquals(rec2['stdout'], 'hello there')
98 self.assertEquals(rec2['completed'], now)
99 self.assertEquals(rec2['completed'], now)
99 rec1.update(data)
100 rec1.update(data)
100 self.assertEquals(rec1, rec2)
101 self.assertEquals(rec1, rec2)
101
102
102 # def test_update_record_bad(self):
103 # def test_update_record_bad(self):
103 # """test updating nonexistant records"""
104 # """test updating nonexistant records"""
104 # msg_id = str(uuid.uuid4())
105 # msg_id = str(uuid.uuid4())
105 # data = {'stdout': 'hello there'}
106 # data = {'stdout': 'hello there'}
106 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
107 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
107
108
108 def test_find_records_dt(self):
109 def test_find_records_dt(self):
109 """test finding records by date"""
110 """test finding records by date"""
110 hist = self.db.get_history()
111 hist = self.db.get_history()
111 middle = self.db.get_record(hist[len(hist)//2])
112 middle = self.db.get_record(hist[len(hist)//2])
112 tic = middle['submitted']
113 tic = middle['submitted']
113 before = self.db.find_records({'submitted' : {'$lt' : tic}})
114 before = self.db.find_records({'submitted' : {'$lt' : tic}})
114 after = self.db.find_records({'submitted' : {'$gte' : tic}})
115 after = self.db.find_records({'submitted' : {'$gte' : tic}})
115 self.assertEquals(len(before)+len(after),len(hist))
116 self.assertEquals(len(before)+len(after),len(hist))
116 for b in before:
117 for b in before:
117 self.assertTrue(b['submitted'] < tic)
118 self.assertTrue(b['submitted'] < tic)
118 for a in after:
119 for a in after:
119 self.assertTrue(a['submitted'] >= tic)
120 self.assertTrue(a['submitted'] >= tic)
120 same = self.db.find_records({'submitted' : tic})
121 same = self.db.find_records({'submitted' : tic})
121 for s in same:
122 for s in same:
122 self.assertTrue(s['submitted'] == tic)
123 self.assertTrue(s['submitted'] == tic)
123
124
124 def test_find_records_keys(self):
125 def test_find_records_keys(self):
125 """test extracting subset of record keys"""
126 """test extracting subset of record keys"""
126 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
127 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
127 for rec in found:
128 for rec in found:
128 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
129 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
129
130
130 def test_find_records_msg_id(self):
131 def test_find_records_msg_id(self):
131 """ensure msg_id is always in found records"""
132 """ensure msg_id is always in found records"""
132 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
133 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
133 for rec in found:
134 for rec in found:
134 self.assertTrue('msg_id' in rec.keys())
135 self.assertTrue('msg_id' in rec.keys())
135 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
136 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
136 for rec in found:
137 for rec in found:
137 self.assertTrue('msg_id' in rec.keys())
138 self.assertTrue('msg_id' in rec.keys())
138 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
139 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
139 for rec in found:
140 for rec in found:
140 self.assertTrue('msg_id' in rec.keys())
141 self.assertTrue('msg_id' in rec.keys())
141
142
142 def test_find_records_in(self):
143 def test_find_records_in(self):
143 """test finding records with '$in','$nin' operators"""
144 """test finding records with '$in','$nin' operators"""
144 hist = self.db.get_history()
145 hist = self.db.get_history()
145 even = hist[::2]
146 even = hist[::2]
146 odd = hist[1::2]
147 odd = hist[1::2]
147 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
148 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
148 found = [ r['msg_id'] for r in recs ]
149 found = [ r['msg_id'] for r in recs ]
149 self.assertEquals(set(even), set(found))
150 self.assertEquals(set(even), set(found))
150 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
151 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
151 found = [ r['msg_id'] for r in recs ]
152 found = [ r['msg_id'] for r in recs ]
152 self.assertEquals(set(odd), set(found))
153 self.assertEquals(set(odd), set(found))
153
154
154 def test_get_history(self):
155 def test_get_history(self):
155 msg_ids = self.db.get_history()
156 msg_ids = self.db.get_history()
156 latest = datetime(1984,1,1)
157 latest = datetime(1984,1,1)
157 for msg_id in msg_ids:
158 for msg_id in msg_ids:
158 rec = self.db.get_record(msg_id)
159 rec = self.db.get_record(msg_id)
159 newt = rec['submitted']
160 newt = rec['submitted']
160 self.assertTrue(newt >= latest)
161 self.assertTrue(newt >= latest)
161 latest = newt
162 latest = newt
162 msg_id = self.load_records(1)[-1]
163 msg_id = self.load_records(1)[-1]
163 self.assertEquals(self.db.get_history()[-1],msg_id)
164 self.assertEquals(self.db.get_history()[-1],msg_id)
164
165
165 def test_datetime(self):
166 def test_datetime(self):
166 """get/set timestamps with datetime objects"""
167 """get/set timestamps with datetime objects"""
167 msg_id = self.db.get_history()[-1]
168 msg_id = self.db.get_history()[-1]
168 rec = self.db.get_record(msg_id)
169 rec = self.db.get_record(msg_id)
169 self.assertTrue(isinstance(rec['submitted'], datetime))
170 self.assertTrue(isinstance(rec['submitted'], datetime))
170 self.db.update_record(msg_id, dict(completed=datetime.now()))
171 self.db.update_record(msg_id, dict(completed=datetime.now()))
171 rec = self.db.get_record(msg_id)
172 rec = self.db.get_record(msg_id)
172 self.assertTrue(isinstance(rec['completed'], datetime))
173 self.assertTrue(isinstance(rec['completed'], datetime))
173
174
174 def test_drop_matching(self):
175 def test_drop_matching(self):
175 msg_ids = self.load_records(10)
176 msg_ids = self.load_records(10)
176 query = {'msg_id' : {'$in':msg_ids}}
177 query = {'msg_id' : {'$in':msg_ids}}
177 self.db.drop_matching_records(query)
178 self.db.drop_matching_records(query)
178 recs = self.db.find_records(query)
179 recs = self.db.find_records(query)
179 self.assertEquals(len(recs), 0)
180 self.assertEquals(len(recs), 0)
180
181
181 def test_null(self):
182 def test_null(self):
182 """test None comparison queries"""
183 """test None comparison queries"""
183 msg_ids = self.load_records(10)
184 msg_ids = self.load_records(10)
184
185
185 query = {'msg_id' : None}
186 query = {'msg_id' : None}
186 recs = self.db.find_records(query)
187 recs = self.db.find_records(query)
187 self.assertEquals(len(recs), 0)
188 self.assertEquals(len(recs), 0)
188
189
189 query = {'msg_id' : {'$ne' : None}}
190 query = {'msg_id' : {'$ne' : None}}
190 recs = self.db.find_records(query)
191 recs = self.db.find_records(query)
191 self.assertTrue(len(recs) >= 10)
192 self.assertTrue(len(recs) >= 10)
192
193
193 def test_pop_safe_get(self):
194 def test_pop_safe_get(self):
194 """editing query results shouldn't affect record [get]"""
195 """editing query results shouldn't affect record [get]"""
195 msg_id = self.db.get_history()[-1]
196 msg_id = self.db.get_history()[-1]
196 rec = self.db.get_record(msg_id)
197 rec = self.db.get_record(msg_id)
197 rec.pop('buffers')
198 rec.pop('buffers')
198 rec['garbage'] = 'hello'
199 rec['garbage'] = 'hello'
199 rec2 = self.db.get_record(msg_id)
200 rec2 = self.db.get_record(msg_id)
200 self.assertTrue('buffers' in rec2)
201 self.assertTrue('buffers' in rec2)
201 self.assertFalse('garbage' in rec2)
202 self.assertFalse('garbage' in rec2)
202
203
203 def test_pop_safe_find(self):
204 def test_pop_safe_find(self):
204 """editing query results shouldn't affect record [find]"""
205 """editing query results shouldn't affect record [find]"""
205 msg_id = self.db.get_history()[-1]
206 msg_id = self.db.get_history()[-1]
206 rec = self.db.find_records({'msg_id' : msg_id})[0]
207 rec = self.db.find_records({'msg_id' : msg_id})[0]
207 rec.pop('buffers')
208 rec.pop('buffers')
208 rec['garbage'] = 'hello'
209 rec['garbage'] = 'hello'
209 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
210 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
210 self.assertTrue('buffers' in rec2)
211 self.assertTrue('buffers' in rec2)
211 self.assertFalse('garbage' in rec2)
212 self.assertFalse('garbage' in rec2)
212
213
213 def test_pop_safe_find_keys(self):
214 def test_pop_safe_find_keys(self):
214 """editing query results shouldn't affect record [find+keys]"""
215 """editing query results shouldn't affect record [find+keys]"""
215 msg_id = self.db.get_history()[-1]
216 msg_id = self.db.get_history()[-1]
216 rec = self.db.find_records({'msg_id' : msg_id}, keys=['buffers'])[0]
217 rec = self.db.find_records({'msg_id' : msg_id}, keys=['buffers'])[0]
217 rec.pop('buffers')
218 rec.pop('buffers')
218 rec['garbage'] = 'hello'
219 rec['garbage'] = 'hello'
219 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
220 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
220 self.assertTrue('buffers' in rec2)
221 self.assertTrue('buffers' in rec2)
221 self.assertFalse('garbage' in rec2)
222 self.assertFalse('garbage' in rec2)
222
223
223
224
224 class TestSQLiteBackend(TestDictBackend):
225 class TestSQLiteBackend(TestDictBackend):
225
226
226 @dec.skip_without('sqlite3')
227 @dec.skip_without('sqlite3')
227 def create_db(self):
228 def create_db(self):
228 location, fname = os.path.split(temp_db)
229 location, fname = os.path.split(temp_db)
229 return SQLiteDB(location=location, fname=fname)
230 log = logging.getLogger('test')
231 log.setLevel(logging.CRITICAL)
232 return SQLiteDB(location=location, fname=fname, log=log)
230
233
231 def tearDown(self):
234 def tearDown(self):
232 self.db._db.close()
235 self.db._db.close()
233
236
234
237
235 def teardown():
238 def teardown():
236 """cleanup task db file after all tests have run"""
239 """cleanup task db file after all tests have run"""
237 try:
240 try:
238 os.remove(temp_db)
241 os.remove(temp_db)
239 except:
242 except:
240 pass
243 pass
General Comments 0
You need to be logged in to leave comments. Login now