##// END OF EJS Templates
quote table name...
MinRK -
Show More
@@ -1,418 +1,418 b''
1 1 """A TaskRecord backend using sqlite3
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 import json
15 15 import os
16 16 import cPickle as pickle
17 17 from datetime import datetime
18 18
19 19 try:
20 20 import sqlite3
21 21 except ImportError:
22 22 sqlite3 = None
23 23
24 24 from zmq.eventloop import ioloop
25 25
26 26 from IPython.utils.traitlets import Unicode, Instance, List, Dict
27 27 from .dictdb import BaseDB
28 28 from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
29 29
30 30 #-----------------------------------------------------------------------------
31 31 # SQLite operators, adapters, and converters
32 32 #-----------------------------------------------------------------------------
33 33
34 34 try:
35 35 buffer
36 36 except NameError:
37 37 # py3k
38 38 buffer = memoryview
39 39
40 40 operators = {
41 41 '$lt' : "<",
42 42 '$gt' : ">",
43 43 # null is handled weird with ==,!=
44 44 '$eq' : "=",
45 45 '$ne' : "!=",
46 46 '$lte': "<=",
47 47 '$gte': ">=",
48 48 '$in' : ('=', ' OR '),
49 49 '$nin': ('!=', ' AND '),
50 50 # '$all': None,
51 51 # '$mod': None,
52 52 # '$exists' : None
53 53 }
54 54 null_operators = {
55 55 '=' : "IS NULL",
56 56 '!=' : "IS NOT NULL",
57 57 }
58 58
59 59 def _adapt_dict(d):
60 60 return json.dumps(d, default=date_default)
61 61
62 62 def _convert_dict(ds):
63 63 if ds is None:
64 64 return ds
65 65 else:
66 66 if isinstance(ds, bytes):
67 67 # If I understand the sqlite doc correctly, this will always be utf8
68 68 ds = ds.decode('utf8')
69 69 return extract_dates(json.loads(ds))
70 70
71 71 def _adapt_bufs(bufs):
72 72 # this is *horrible*
73 73 # copy buffers into single list and pickle it:
74 74 if bufs and isinstance(bufs[0], (bytes, buffer)):
75 75 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
76 76 elif bufs:
77 77 return bufs
78 78 else:
79 79 return None
80 80
81 81 def _convert_bufs(bs):
82 82 if bs is None:
83 83 return []
84 84 else:
85 85 return pickle.loads(bytes(bs))
86 86
87 87 #-----------------------------------------------------------------------------
88 88 # SQLiteDB class
89 89 #-----------------------------------------------------------------------------
90 90
91 91 class SQLiteDB(BaseDB):
92 92 """SQLite3 TaskRecord backend."""
93 93
94 94 filename = Unicode('tasks.db', config=True,
95 95 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
96 96 location = Unicode('', config=True,
97 97 help="""The directory containing the sqlite task database. The default
98 98 is to use the cluster_dir location.""")
99 99 table = Unicode("ipython-tasks", config=True,
100 100 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
101 101 a new table will be created with the Hub's IDENT. Specifying the table will result
102 102 in tasks from previous sessions being available via Clients' db_query and
103 103 get_result methods.""")
104 104
105 105 if sqlite3 is not None:
106 106 _db = Instance('sqlite3.Connection')
107 107 else:
108 108 _db = None
109 109 # the ordered list of column names
110 110 _keys = List(['msg_id' ,
111 111 'header' ,
112 112 'metadata',
113 113 'content',
114 114 'buffers',
115 115 'submitted',
116 116 'client_uuid' ,
117 117 'engine_uuid' ,
118 118 'started',
119 119 'completed',
120 120 'resubmitted',
121 121 'received',
122 122 'result_header' ,
123 123 'result_metadata',
124 124 'result_content' ,
125 125 'result_buffers' ,
126 126 'queue' ,
127 127 'pyin' ,
128 128 'pyout',
129 129 'pyerr',
130 130 'stdout',
131 131 'stderr',
132 132 ])
133 133 # sqlite datatypes for checking that db is current format
134 134 _types = Dict({'msg_id' : 'text' ,
135 135 'header' : 'dict text',
136 136 'metadata' : 'dict text',
137 137 'content' : 'dict text',
138 138 'buffers' : 'bufs blob',
139 139 'submitted' : 'timestamp',
140 140 'client_uuid' : 'text',
141 141 'engine_uuid' : 'text',
142 142 'started' : 'timestamp',
143 143 'completed' : 'timestamp',
144 144 'resubmitted' : 'text',
145 145 'received' : 'timestamp',
146 146 'result_header' : 'dict text',
147 147 'result_metadata' : 'dict text',
148 148 'result_content' : 'dict text',
149 149 'result_buffers' : 'bufs blob',
150 150 'queue' : 'text',
151 151 'pyin' : 'text',
152 152 'pyout' : 'text',
153 153 'pyerr' : 'text',
154 154 'stdout' : 'text',
155 155 'stderr' : 'text',
156 156 })
157 157
158 158 def __init__(self, **kwargs):
159 159 super(SQLiteDB, self).__init__(**kwargs)
160 160 if sqlite3 is None:
161 161 raise ImportError("SQLiteDB requires sqlite3")
162 162 if not self.table:
163 163 # use session, and prefix _, since starting with # is illegal
164 164 self.table = '_'+self.session.replace('-','_')
165 165 if not self.location:
166 166 # get current profile
167 167 from IPython.core.application import BaseIPythonApplication
168 168 if BaseIPythonApplication.initialized():
169 169 app = BaseIPythonApplication.instance()
170 170 if app.profile_dir is not None:
171 171 self.location = app.profile_dir.location
172 172 else:
173 173 self.location = u'.'
174 174 else:
175 175 self.location = u'.'
176 176 self._init_db()
177 177
178 178 # register db commit as 2s periodic callback
179 179 # to prevent clogging pipes
180 180 # assumes we are being run in a zmq ioloop app
181 181 loop = ioloop.IOLoop.instance()
182 182 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
183 183 pc.start()
184 184
185 185 def _defaults(self, keys=None):
186 186 """create an empty record"""
187 187 d = {}
188 188 keys = self._keys if keys is None else keys
189 189 for key in keys:
190 190 d[key] = None
191 191 return d
192 192
193 193 def _check_table(self):
194 194 """Ensure that an incorrect table doesn't exist
195 195
196 196 If a bad (old) table does exist, return False
197 197 """
198 cursor = self._db.execute("PRAGMA table_info(%s)"%self.table)
198 cursor = self._db.execute("PRAGMA table_info('%s')"%self.table)
199 199 lines = cursor.fetchall()
200 200 if not lines:
201 201 # table does not exist
202 202 return True
203 203 types = {}
204 204 keys = []
205 205 for line in lines:
206 206 keys.append(line[1])
207 207 types[line[1]] = line[2]
208 208 if self._keys != keys:
209 209 # key mismatch
210 210 self.log.warn('keys mismatch')
211 211 return False
212 212 for key in self._keys:
213 213 if types[key] != self._types[key]:
214 214 self.log.warn(
215 215 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
216 216 )
217 217 return False
218 218 return True
219 219
220 220 def _init_db(self):
221 221 """Connect to the database and get new session number."""
222 222 # register adapters
223 223 sqlite3.register_adapter(dict, _adapt_dict)
224 224 sqlite3.register_converter('dict', _convert_dict)
225 225 sqlite3.register_adapter(list, _adapt_bufs)
226 226 sqlite3.register_converter('bufs', _convert_bufs)
227 227 # connect to the db
228 228 dbfile = os.path.join(self.location, self.filename)
229 229 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
230 230 # isolation_level = None)#,
231 231 cached_statements=64)
232 232 # print dir(self._db)
233 233 first_table = previous_table = self.table
234 234 i=0
235 235 while not self._check_table():
236 236 i+=1
237 237 self.table = first_table+'_%i'%i
238 238 self.log.warn(
239 239 "Table %s exists and doesn't match db format, trying %s"%
240 240 (previous_table, self.table)
241 241 )
242 242 previous_table = self.table
243 243
244 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
244 self._db.execute("""CREATE TABLE IF NOT EXISTS '%s'
245 245 (msg_id text PRIMARY KEY,
246 246 header dict text,
247 247 metadata dict text,
248 248 content dict text,
249 249 buffers bufs blob,
250 250 submitted timestamp,
251 251 client_uuid text,
252 252 engine_uuid text,
253 253 started timestamp,
254 254 completed timestamp,
255 255 resubmitted text,
256 256 received timestamp,
257 257 result_header dict text,
258 258 result_metadata dict text,
259 259 result_content dict text,
260 260 result_buffers bufs blob,
261 261 queue text,
262 262 pyin text,
263 263 pyout text,
264 264 pyerr text,
265 265 stdout text,
266 266 stderr text)
267 267 """%self.table)
268 268 self._db.commit()
269 269
270 270 def _dict_to_list(self, d):
271 271 """turn a mongodb-style record dict into a list."""
272 272
273 273 return [ d[key] for key in self._keys ]
274 274
275 275 def _list_to_dict(self, line, keys=None):
276 276 """Inverse of dict_to_list"""
277 277 keys = self._keys if keys is None else keys
278 278 d = self._defaults(keys)
279 279 for key,value in zip(keys, line):
280 280 d[key] = value
281 281
282 282 return d
283 283
284 284 def _render_expression(self, check):
285 285 """Turn a mongodb-style search dict into an SQL query."""
286 286 expressions = []
287 287 args = []
288 288
289 289 skeys = set(check.keys())
290 290 skeys.difference_update(set(self._keys))
291 291 skeys.difference_update(set(['buffers', 'result_buffers']))
292 292 if skeys:
293 293 raise KeyError("Illegal testing key(s): %s"%skeys)
294 294
295 295 for name,sub_check in check.iteritems():
296 296 if isinstance(sub_check, dict):
297 297 for test,value in sub_check.iteritems():
298 298 try:
299 299 op = operators[test]
300 300 except KeyError:
301 301 raise KeyError("Unsupported operator: %r"%test)
302 302 if isinstance(op, tuple):
303 303 op, join = op
304 304
305 305 if value is None and op in null_operators:
306 306 expr = "%s %s" % (name, null_operators[op])
307 307 else:
308 308 expr = "%s %s ?"%(name, op)
309 309 if isinstance(value, (tuple,list)):
310 310 if op in null_operators and any([v is None for v in value]):
311 311 # equality tests don't work with NULL
312 312 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
313 313 expr = '( %s )'%( join.join([expr]*len(value)) )
314 314 args.extend(value)
315 315 else:
316 316 args.append(value)
317 317 expressions.append(expr)
318 318 else:
319 319 # it's an equality check
320 320 if sub_check is None:
321 321 expressions.append("%s IS NULL" % name)
322 322 else:
323 323 expressions.append("%s = ?"%name)
324 324 args.append(sub_check)
325 325
326 326 expr = " AND ".join(expressions)
327 327 return expr, args
328 328
329 329 def add_record(self, msg_id, rec):
330 330 """Add a new Task Record, by msg_id."""
331 331 d = self._defaults()
332 332 d.update(rec)
333 333 d['msg_id'] = msg_id
334 334 line = self._dict_to_list(d)
335 335 tups = '(%s)'%(','.join(['?']*len(line)))
336 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
336 self._db.execute("INSERT INTO '%s' VALUES %s"%(self.table, tups), line)
337 337 # self._db.commit()
338 338
339 339 def get_record(self, msg_id):
340 340 """Get a specific Task Record, by msg_id."""
341 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
341 cursor = self._db.execute("""SELECT * FROM '%s' WHERE msg_id==?"""%self.table, (msg_id,))
342 342 line = cursor.fetchone()
343 343 if line is None:
344 344 raise KeyError("No such msg: %r"%msg_id)
345 345 return self._list_to_dict(line)
346 346
347 347 def update_record(self, msg_id, rec):
348 348 """Update the data in an existing record."""
349 query = "UPDATE %s SET "%self.table
349 query = "UPDATE '%s' SET "%self.table
350 350 sets = []
351 351 keys = sorted(rec.keys())
352 352 values = []
353 353 for key in keys:
354 354 sets.append('%s = ?'%key)
355 355 values.append(rec[key])
356 356 query += ', '.join(sets)
357 357 query += ' WHERE msg_id == ?'
358 358 values.append(msg_id)
359 359 self._db.execute(query, values)
360 360 # self._db.commit()
361 361
362 362 def drop_record(self, msg_id):
363 363 """Remove a record from the DB."""
364 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
364 self._db.execute("""DELETE FROM '%s' WHERE msg_id==?"""%self.table, (msg_id,))
365 365 # self._db.commit()
366 366
367 367 def drop_matching_records(self, check):
368 368 """Remove a record from the DB."""
369 369 expr,args = self._render_expression(check)
370 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
370 query = "DELETE FROM '%s' WHERE %s"%(self.table, expr)
371 371 self._db.execute(query,args)
372 372 # self._db.commit()
373 373
374 374 def find_records(self, check, keys=None):
375 375 """Find records matching a query dict, optionally extracting subset of keys.
376 376
377 377 Returns list of matching records.
378 378
379 379 Parameters
380 380 ----------
381 381
382 382 check: dict
383 383 mongodb-style query argument
384 384 keys: list of strs [optional]
385 385 if specified, the subset of keys to extract. msg_id will *always* be
386 386 included.
387 387 """
388 388 if keys:
389 389 bad_keys = [ key for key in keys if key not in self._keys ]
390 390 if bad_keys:
391 391 raise KeyError("Bad record key(s): %s"%bad_keys)
392 392
393 393 if keys:
394 394 # ensure msg_id is present and first:
395 395 if 'msg_id' in keys:
396 396 keys.remove('msg_id')
397 397 keys.insert(0, 'msg_id')
398 398 req = ', '.join(keys)
399 399 else:
400 400 req = '*'
401 401 expr,args = self._render_expression(check)
402 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
402 query = """SELECT %s FROM '%s' WHERE %s"""%(req, self.table, expr)
403 403 cursor = self._db.execute(query, args)
404 404 matches = cursor.fetchall()
405 405 records = []
406 406 for line in matches:
407 407 rec = self._list_to_dict(line, keys)
408 408 records.append(rec)
409 409 return records
410 410
411 411 def get_history(self):
412 412 """get all msg_ids, ordered by time submitted."""
413 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
413 query = """SELECT msg_id FROM '%s' ORDER by submitted ASC"""%self.table
414 414 cursor = self._db.execute(query)
415 415 # will be a list of length 1 tuples
416 416 return [ tup[0] for tup in cursor.fetchall()]
417 417
418 418 __all__ = ['SQLiteDB'] No newline at end of file
General Comments 0
You need to be logged in to leave comments. Login now