r4018 | """A TaskRecord backend using sqlite3 | |
Authors: | |||
* Min RK | |||
""" | |||
MinRK
|
#-----------------------------------------------------------------------------
#  Copyright (C) 2011  The IPython Development Team
#
#  Distributed under the terms of the BSD License.  The full license is in
#  the file COPYING, distributed as part of this software.
#-----------------------------------------------------------------------------

import json
import os
import cPickle as pickle
from datetime import datetime

import sqlite3

from zmq.eventloop import ioloop

from IPython.utils.traitlets import Unicode, Instance, List, Dict

from .dictdb import BaseDB

from IPython.utils.jsonutil import date_default, extract_dates, squash_dates

#-----------------------------------------------------------------------------
# SQLite operators, adapters, and converters
#-----------------------------------------------------------------------------

try:
    buffer
except NameError:
    # py3k
    buffer = memoryview

operators = {
    '$lt' : "<",
    '$gt' : ">",
    # null is handled weird with ==,!=
    '$eq' : "=",
    '$ne' : "!=",
    '$lte': "<=",
    '$gte': ">=",
    '$in' : ('=', ' OR '),
    '$nin': ('!=', ' AND '),
    # '$all': None,
    # '$mod': None,
    # '$exists' : None
}

null_operators = {
    '='  : "IS NULL",
    '!=' : "IS NOT NULL",
}
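
# For illustration (the mapping is applied in SQLiteDB._render_expression below):
# a mongodb-style test such as
#     {'completed' : {'$gte' : some_datetime}}   renders to "completed >= ?"
# tuple-valued operators expand once per element, joined by their second item:
#     {'queue' : {'$in' : ['mux', 'task']}}      renders to "( queue = ? OR queue = ? )"
# and None-valued tests fall back to null_operators, since "= NULL" never
# matches anything in SQLite:
#     {'completed' : {'$ne' : None}}             renders to "completed IS NOT NULL"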

def _adapt_dict(d):
    return json.dumps(d, default=date_default)

def _convert_dict(ds):
    if ds is None:
        return ds
    else:
        if isinstance(ds, bytes):
            # If I understand the sqlite doc correctly, this will always be utf8
            ds = ds.decode('utf8')
        return extract_dates(json.loads(ds))

def _adapt_bufs(bufs):
    # this is *horrible*
    # copy buffers into single list and pickle it:
    if bufs and isinstance(bufs[0], (bytes, buffer)):
        return sqlite3.Binary(pickle.dumps(map(bytes, bufs), -1))
    elif bufs:
        return bufs
    else:
        return None

def _convert_bufs(bs):
    if bs is None:
        return []
    else:
        return pickle.loads(bytes(bs))
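
# Illustrative round trip through the adapters above (they are registered with
# sqlite3 in SQLiteDB._init_db, and applied to columns declared 'dict text' /
# 'bufs blob' via detect_types=PARSE_DECLTYPES, which keys converters on the
# first word of the declared type):
#     _adapt_dict({'a' : 1})          -> '{"a": 1}'  (stored as text)
#     _convert_dict('{"a": 1}')       -> {u'a': 1}, with ISO date strings
#                                        turned back into datetime objects
#     _adapt_bufs([b'abc'])           -> pickled Binary blob
#     _convert_bufs(<that blob>)      -> [b'abc']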

#-----------------------------------------------------------------------------
# SQLiteDB class
#-----------------------------------------------------------------------------

class SQLiteDB(BaseDB):
    """SQLite3 TaskRecord backend."""

    filename = Unicode('tasks.db', config=True,
        help="""The filename of the sqlite task database. [default: 'tasks.db']""")
    location = Unicode('', config=True,
        help="""The directory containing the sqlite task database. The default
        is to use the cluster_dir location.""")
    table = Unicode("", config=True,
        help="""The SQLite Table to use for storing tasks for this session. If unspecified,
        a new table will be created with the Hub's IDENT. Specifying the table will result
        in tasks from previous sessions being available via Clients' db_query and
        get_result methods.""")

    _db = Instance('sqlite3.Connection')

    # the ordered list of column names
    _keys = List(['msg_id',
            'header',
            'content',
            'buffers',
            'submitted',
            'client_uuid',
            'engine_uuid',
            'started',
            'completed',
            'resubmitted',
            'result_header',
            'result_content',
            'result_buffers',
            'queue',
            'pyin',
            'pyout',
            'pyerr',
            'stdout',
            'stderr',
        ])
    # sqlite datatypes for checking that db is current format
    _types = Dict({'msg_id'  : 'text',
            'header'         : 'dict text',
            'content'        : 'dict text',
            'buffers'        : 'bufs blob',
            'submitted'      : 'timestamp',
            'client_uuid'    : 'text',
            'engine_uuid'    : 'text',
            'started'        : 'timestamp',
            'completed'      : 'timestamp',
            'resubmitted'    : 'timestamp',
            'result_header'  : 'dict text',
            'result_content' : 'dict text',
            'result_buffers' : 'bufs blob',
            'queue'          : 'text',
            'pyin'           : 'text',
            'pyout'          : 'text',
            'pyerr'          : 'text',
            'stdout'         : 'text',
            'stderr'         : 'text',
        })

    def __init__(self, **kwargs):
        super(SQLiteDB, self).__init__(**kwargs)
        if not self.table:
            # use session, and prefix _, since starting with # is illegal
            self.table = '_'+self.session.replace('-','_')
        if not self.location:
            # get current profile
            from IPython.core.application import BaseIPythonApplication
            if BaseIPythonApplication.initialized():
                app = BaseIPythonApplication.instance()
                if app.profile_dir is not None:
                    self.location = app.profile_dir.location
                else:
                    self.location = u'.'
            else:
                self.location = u'.'
        self._init_db()

        # register db commit as 2s periodic callback
        # to prevent clogging pipes
        # assumes we are being run in a zmq ioloop app
        loop = ioloop.IOLoop.instance()
        pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
        pc.start()
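
        # Note: the periodic commit above only fires once a zmq IOLoop is
        # actually running in this process; a minimal sketch (outside this
        # class) would be:
        #     from zmq.eventloop import ioloop
        #     ioloop.IOLoop.instance().start()
        # Until then, inserts are visible to this connection but not committed.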

    def _defaults(self, keys=None):
        """create an empty record"""
        d = {}
        keys = self._keys if keys is None else keys
        for key in keys:
            d[key] = None
        return d

    def _check_table(self):
        """Ensure that an incorrect table doesn't exist

        If a bad (old) table does exist, return False
        """
        cursor = self._db.execute("PRAGMA table_info(%s)"%self.table)
        lines = cursor.fetchall()
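        # each row from PRAGMA table_info is (cid, name, type, notnull, dflt_value, pk);
        # only the column name (line[1]) and declared type (line[2]) are checked below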
        if not lines:
            # table does not exist
            return True
        types = {}
        keys = []
        for line in lines:
            keys.append(line[1])
            types[line[1]] = line[2]
        if self._keys != keys:
            # key mismatch
            self.log.warn('keys mismatch')
            return False
        for key in self._keys:
            if types[key] != self._types[key]:
                self.log.warn(
                    'type mismatch: %s: %s != %s'%(key, types[key], self._types[key])
                )
                return False
        return True

    def _init_db(self):
        """Connect to the database and get new session number."""
        # register adapters
        sqlite3.register_adapter(dict, _adapt_dict)
        sqlite3.register_converter('dict', _convert_dict)
        sqlite3.register_adapter(list, _adapt_bufs)
        sqlite3.register_converter('bufs', _convert_bufs)
        # connect to the db
        dbfile = os.path.join(self.location, self.filename)
        self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
            # isolation_level = None)#,
            cached_statements=64)
        # print dir(self._db)
        first_table = self.table
        i = 0
        while not self._check_table():
            i += 1
            self.table = first_table + '_%i'%i
            self.log.warn(
                "Table %s exists and doesn't match db format, trying %s"%
                (first_table, self.table)
            )

        self._db.execute("""CREATE TABLE IF NOT EXISTS %s
                (msg_id text PRIMARY KEY,
                header dict text,
                content dict text,
                buffers bufs blob,
                submitted timestamp,
                client_uuid text,
                engine_uuid text,
                started timestamp,
                completed timestamp,
                resubmitted timestamp,
                result_header dict text,
                result_content dict text,
                result_buffers bufs blob,
                queue text,
                pyin text,
                pyout text,
                pyerr text,
                stdout text,
                stderr text)
                """%self.table)
        self._db.commit()

    def _dict_to_list(self, d):
        """turn a mongodb-style record dict into a list."""
        return [ d[key] for key in self._keys ]

    def _list_to_dict(self, line, keys=None):
        """Inverse of dict_to_list"""
        keys = self._keys if keys is None else keys
        d = self._defaults(keys)
        for key, value in zip(keys, line):
            d[key] = value
        return d

    def _render_expression(self, check):
        """Turn a mongodb-style search dict into an SQL query."""
        expressions = []
        args = []

        skeys = set(check.keys())
        skeys.difference_update(set(self._keys))
        skeys.difference_update(set(['buffers', 'result_buffers']))
        if skeys:
            raise KeyError("Illegal testing key(s): %s"%skeys)

        for name, sub_check in check.iteritems():
            if isinstance(sub_check, dict):
                for test, value in sub_check.iteritems():
                    try:
                        op = operators[test]
                    except KeyError:
                        raise KeyError("Unsupported operator: %r"%test)
                    if isinstance(op, tuple):
                        op, join = op

                    if value is None and op in null_operators:
                        expr = "%s %s"%(name, null_operators[op])
                    else:
                        expr = "%s %s ?"%(name, op)
                        if isinstance(value, (tuple, list)):
                            if op in null_operators and any([v is None for v in value]):
                                # equality tests don't work with NULL
                                raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
                            expr = '( %s )'%( join.join([expr]*len(value)) )
                            args.extend(value)
                        else:
                            args.append(value)

                    expressions.append(expr)
            else:
                # it's an equality check
                if sub_check is None:
                    expressions.append("%s IS NULL"%name)
                else:
                    expressions.append("%s = ?"%name)
                    args.append(sub_check)

        expr = " AND ".join(expressions)
        return expr, args
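
    # A worked example of _render_expression (values are hypothetical; clause
    # order follows dict iteration order):
    #     self._render_expression({'engine_uuid' : 'abc', 'completed' : {'$ne' : None}})
    # returns something like
    #     ("engine_uuid = ? AND completed IS NOT NULL", ['abc'])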

    def add_record(self, msg_id, rec):
        """Add a new Task Record, by msg_id."""
        d = self._defaults()
        d.update(rec)
        d['msg_id'] = msg_id
        line = self._dict_to_list(d)
        tups = '(%s)'%(','.join(['?']*len(line)))
        self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
        # self._db.commit()

    def get_record(self, msg_id):
        """Get a specific Task Record, by msg_id."""
        cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
        line = cursor.fetchone()
        if line is None:
            raise KeyError("No such msg: %r"%msg_id)
        return self._list_to_dict(line)

    def update_record(self, msg_id, rec):
        """Update the data in an existing record."""
        query = "UPDATE %s SET "%self.table
        sets = []
        keys = sorted(rec.keys())
        values = []
        for key in keys:
            sets.append('%s = ?'%key)
            values.append(rec[key])
        query += ', '.join(sets)
        query += ' WHERE msg_id == ?'
        values.append(msg_id)
        self._db.execute(query, values)
        # self._db.commit()

    def drop_record(self, msg_id):
        """Remove a record from the DB."""
        self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
        # self._db.commit()

    def drop_matching_records(self, check):
        """Remove all records matching a query dict from the DB."""
        expr, args = self._render_expression(check)
        query = "DELETE FROM %s WHERE %s"%(self.table, expr)
        self._db.execute(query, args)
        # self._db.commit()

    def find_records(self, check, keys=None):
        """Find records matching a query dict, optionally extracting subset of keys.

        Returns list of matching records.

        Parameters
        ----------
        check: dict
            mongodb-style query argument
        keys: list of strs [optional]
            if specified, the subset of keys to extract. msg_id will *always* be
            included.
        """
        if keys:
            bad_keys = [ key for key in keys if key not in self._keys ]
            if bad_keys:
                raise KeyError("Bad record key(s): %s"%bad_keys)

        if keys:
            # ensure msg_id is present and first:
            if 'msg_id' in keys:
                keys.remove('msg_id')
            keys.insert(0, 'msg_id')
            req = ', '.join(keys)
        else:
            req = '*'
        expr, args = self._render_expression(check)
        query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
        cursor = self._db.execute(query, args)
        matches = cursor.fetchall()
        records = []
        for line in matches:
            rec = self._list_to_dict(line, keys)
            records.append(rec)
        return records

    def get_history(self):
        """get all msg_ids, ordered by time submitted."""
        query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
        cursor = self._db.execute(query)
        # will be a list of length 1 tuples
        return [ tup[0] for tup in cursor.fetchall() ]

__all__ = ['SQLiteDB']
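
if __name__ == '__main__':
    # A minimal, illustrative sketch of exercising the backend by hand.  It
    # assumes BaseDB provides usable defaults for its `session` and `log`
    # traits (as the dictdb module does); in normal use the Hub constructs and
    # configures this object.  Running it creates/updates 'tasks.db' in the
    # current directory.
    import uuid

    db = SQLiteDB(location=u'.', table=u'demo_tasks')
    msg_id = str(uuid.uuid4())
    db.add_record(msg_id, {
        'submitted' : datetime.now(),
        'header'    : {'msg_id' : msg_id, 'msg_type' : 'apply_request'},
        'queue'     : 'mux',
    })
    print db.get_record(msg_id)['queue']
    print db.find_records({'queue' : 'mux'}, keys=['msg_id', 'submitted'])
    print db.get_history()
    # nothing is committed here; the 2s PeriodicCallback registered in
    # __init__ handles that once an IOLoop is running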