sqlitedb.py
271 lines
| 8.7 KiB
| text/x-python
|
PythonLexer
MinRK
|
r3652 | """A TaskRecord backend using sqlite3""" | |
#----------------------------------------------------------------------------- | |||
# Copyright (C) 2011 The IPython Development Team | |||
# | |||
# Distributed under the terms of the BSD License. The full license is in | |||
# the file COPYING, distributed as part of this software. | |||
#----------------------------------------------------------------------------- | |||
import json | |||
import os | |||
import cPickle as pickle | |||
from datetime import datetime | |||
import sqlite3 | |||
from IPython.utils.traitlets import CUnicode, CStr, Instance, List | |||
from .dictdb import BaseDB | |||
from .util import ISO8601 | |||
#----------------------------------------------------------------------------- | |||
# SQLite operators, adapters, and converters | |||
#----------------------------------------------------------------------------- | |||
operators = {
    # Maps a mongodb-style query operator to the SQL comparison operator that
    # SQLiteDB._render_expression substitutes into "<column> <op> ?".
    # Tuple values are (operator, joiner) pairs: when the operand is a list,
    # one "<column> <op> ?" term is emitted per element, joined by the joiner.
    '$lt' : "<",
    '$gt' : ">",
    # use IS / IS NOT rather than = / != so that comparing against NULL
    # (None) behaves like equality instead of never matching
    '$eq' : "IS",
    '$ne' : "IS NOT",
    '$lte': "<=",
    '$gte': ">=",
    '$in' : ('IS', ' OR '),
    '$nin': ('IS NOT', ' AND '),
    # not yet supported:
    # '$all': None,
    # '$mod': None,
    # '$exists' : None
}
def _adapt_datetime(dt):
    """sqlite3 adapter: render a datetime as text using the ISO8601 format."""
    text = dt.strftime(ISO8601)
    return text
def _convert_datetime(ds):
    """sqlite3 converter: parse ISO8601 text back into a datetime.

    ``None`` is returned unchanged.
    """
    return None if ds is None else datetime.strptime(ds, ISO8601)
def _adapt_dict(d):
    """sqlite3 adapter: serialize a dict to a JSON string."""
    encoded = json.dumps(d)
    return encoded
def _convert_dict(ds):
    """sqlite3 converter: decode a JSON string back into a Python object.

    ``None`` is returned unchanged.
    """
    if ds is not None:
        return json.loads(ds)
    return None
def _adapt_bufs(bufs):
    """sqlite3 adapter for lists (registered for *every* list).

    - empty list -> None (stored as NULL)
    - list whose first item is bytes/buffer -> every item is copied to
      bytes, and the resulting list is pickled into a ``sqlite3.Binary`` blob
    - any other non-empty list -> returned unchanged
    """
    if not bufs:
        return None
    if not isinstance(bufs[0], (bytes, buffer)):
        return bufs
    # this is *horrible*: copy the buffers into one list and pickle it
    copied = [bytes(b) for b in bufs]
    return sqlite3.Binary(pickle.dumps(copied, -1))
def _convert_bufs(bs):
    """sqlite3 converter: unpickle a buffers blob back into a list.

    ``None`` becomes an empty list.
    NOTE(review): this unpickles data read from the database file; only safe
    if that file is trusted.
    """
    if bs is not None:
        return pickle.loads(bytes(bs))
    return []
#----------------------------------------------------------------------------- | |||
# SQLiteDB class | |||
#----------------------------------------------------------------------------- | |||
class SQLiteDB(BaseDB):
    """SQLite3 TaskRecord backend.

    Every task record is one row of the table ``self.table`` in the sqlite
    file ``os.path.join(self.location, self.filename)``. The columns are the
    names in ``_keys``. Dict, datetime and buffer-list values are converted
    by the module-level adapters/converters registered in ``_init_db``.
    """

    # name of the database file, joined onto `location`
    filename = CUnicode('tasks.db', config=True)
    # directory holding the database file; if empty, __init__ uses
    # config.Global.cluster_dir when present, otherwise '.'
    location = CUnicode('', config=True)
    # table name; if empty, __init__ derives it from the session id
    table = CUnicode("", config=True)

    _db = Instance('sqlite3.Connection')
    # column names in table order: _dict_to_list/_list_to_dict and the
    # positional INSERT in add_record rely on this matching _init_db's schema
    _keys = List(['msg_id' ,
            'header' ,
            'content',
            'buffers',
            'submitted',
            'client_uuid' ,
            'engine_uuid' ,
            'started',
            'completed',
            'resubmitted',
            'result_header' ,
            'result_content' ,
            'result_buffers' ,
            'queue' ,
            'pyin' ,
            'pyout',
            'pyerr',
            'stdout',
            'stderr',
        ])

    def __init__(self, **kwargs):
        super(SQLiteDB, self).__init__(**kwargs)
        if not self.table:
            # derive the table name from the session id: prefix '_' because an
            # unquoted SQL identifier may not start with a digit, and replace
            # '-', which is not allowed in one
            self.table = '_'+self.session.replace('-','_')
        if not self.location:
            if hasattr(self.config.Global, 'cluster_dir'):
                self.location = self.config.Global.cluster_dir
            else:
                self.location = '.'
        self._init_db()

    def _defaults(self):
        """Create an empty record: a dict mapping every column to None."""
        return dict.fromkeys(self._keys)

    def _init_db(self):
        """Connect to the database file and create the table if needed."""
        # adapters convert Python -> sqlite; converters convert sqlite ->
        # Python, chosen by the declared column type (PARSE_DECLTYPES).
        # NOTE: these registrations are global to the sqlite3 module, and the
        # list adapter applies to every list bound as a query parameter.
        sqlite3.register_adapter(datetime, _adapt_datetime)
        sqlite3.register_converter('datetime', _convert_datetime)
        sqlite3.register_adapter(dict, _adapt_dict)
        sqlite3.register_converter('dict', _convert_dict)
        sqlite3.register_adapter(list, _adapt_bufs)
        sqlite3.register_converter('bufs', _convert_bufs)
        # connect to the db
        dbfile = os.path.join(self.location, self.filename)
        self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES)

        # the first word of each type ('dict', 'bufs', 'datetime') selects the
        # converter registered above
        self._db.execute("""CREATE TABLE IF NOT EXISTS %s
                (msg_id text PRIMARY KEY,
                header dict text,
                content dict text,
                buffers bufs blob,
                submitted datetime text,
                client_uuid text,
                engine_uuid text,
                started datetime text,
                completed datetime text,
                resubmitted datetime text,
                result_header dict text,
                result_content dict text,
                result_buffers bufs blob,
                queue text,
                pyin text,
                pyout text,
                pyerr text,
                stdout text,
                stderr text)
                """%self.table)
        self._db.commit()

    def _dict_to_list(self, d):
        """Turn a mongodb-style record dict into a list, in _keys order."""
        return [ d[key] for key in self._keys ]

    def _list_to_dict(self, line):
        """Inverse of _dict_to_list: turn a row (in _keys order) into a dict."""
        d = self._defaults()
        d.update(zip(self._keys, line))
        return d

    def _render_expression(self, check):
        """Turn a mongodb-style search dict into an SQL WHERE expression.

        Parameters
        ----------
        check : dict
            Maps column names to either a plain value (equality test) or a
            dict of ``{operator: operand}`` pairs whose operators are keys of
            the module-level ``operators`` table. ``$in``/``$nin`` accept a
            list or tuple of operands.

        Returns
        -------
        (expr, args)
            ``expr`` is the expression text with ``?`` placeholders (all tests
            ANDed together), ``args`` the values to bind to them. An empty
            `check` yields an empty ``expr``.

        Raises
        ------
        KeyError
            for a column name not in ``_keys`` or an unsupported operator.
        ValueError
            if a list/tuple operand is given to an operator other than
            ``$in``/``$nin``.
        """
        expressions = []
        args = []

        # NOTE(review): 'buffers'/'result_buffers' are members of _keys and so
        # are accepted as test keys; confirm they were not meant to be rejected
        illegal = set(check.keys()) - set(self._keys)
        if illegal:
            raise KeyError("Illegal testing key(s): %s"%illegal)

        for name, sub_check in check.iteritems():
            if not isinstance(sub_check, dict):
                # it's an equality check; IS also matches None against NULL
                expressions.append("%s IS ?"%name)
                args.append(sub_check)
                continue
            for test, value in sub_check.iteritems():
                try:
                    op = operators[test]
                except KeyError:
                    raise KeyError("Unsupported operator: %r"%test)
                if not isinstance(op, tuple):
                    # plain binary comparison: exactly one operand
                    if isinstance(value, (tuple, list)):
                        raise ValueError(
                            "Operator %r does not accept a list of values"%test)
                    expressions.append("%s %s ?"%(name, op))
                    args.append(value)
                    continue
                # membership operator ($in / $nin)
                op, join = op
                term = "%s %s ?"%(name, op)
                if not isinstance(value, (tuple, list)):
                    expressions.append(term)
                    args.append(value)
                elif value:
                    expressions.append('( %s )'%(join.join([term]*len(value))))
                    args.extend(value)
                else:
                    # an empty "( )" group is invalid SQL: an empty OR ($in)
                    # matches nothing, an empty AND ($nin) matches everything
                    expressions.append('0' if join.strip() == 'OR' else '1')
        return " AND ".join(expressions), args

    def add_record(self, msg_id, rec):
        """Add a new Task Record, by msg_id.

        Columns missing from `rec` are stored as NULL; keys of `rec` that are
        not columns are ignored. Since msg_id is the PRIMARY KEY, inserting an
        existing id presumably raises sqlite3.IntegrityError.
        """
        d = self._defaults()
        d.update(rec)
        d['msg_id'] = msg_id
        line = self._dict_to_list(d)
        tups = '(%s)'%(','.join(['?']*len(line)))
        self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
        self._db.commit()

    def get_record(self, msg_id):
        """Get a specific Task Record, by msg_id.

        Raises KeyError if no record has that msg_id.
        """
        cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
        line = cursor.fetchone()
        if line is None:
            raise KeyError("No such msg: %r"%msg_id)
        return self._list_to_dict(line)

    def update_record(self, msg_id, rec):
        """Update the data in an existing record.

        Only the columns named by the keys of `rec` are changed; an empty
        `rec` is a no-op.
        """
        if not rec:
            return
        keys = sorted(rec.keys())
        sets = ', '.join(['%s = ?'%key for key in keys])
        values = [rec[key] for key in keys]
        # bind msg_id as a parameter: interpolating its repr() yields invalid
        # SQL for unicode ids (u'...') and is open to SQL injection
        values.append(msg_id)
        query = "UPDATE %s SET %s WHERE msg_id == ?"%(self.table, sets)
        self._db.execute(query, values)
        self._db.commit()

    def drop_record(self, msg_id):
        """Remove a record from the DB, by msg_id."""
        self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
        self._db.commit()

    def drop_matching_records(self, check):
        """Remove all records matching a query dict (see _render_expression)."""
        expr,args = self._render_expression(check)
        query = "DELETE FROM %s WHERE %s"%(self.table, expr)
        self._db.execute(query,args)
        self._db.commit()

    def find_records(self, check, id_only=False):
        """Find records matching a query dict (see _render_expression).

        Returns a list of msg_ids if `id_only`, otherwise a dict mapping
        msg_id to the full record dict.
        """
        req = 'msg_id' if id_only else '*'
        expr,args = self._render_expression(check)
        query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
        cursor = self._db.execute(query, args)
        matches = cursor.fetchall()

        if id_only:
            return [ m[0] for m in matches ]
        records = {}
        for line in matches:
            rec = self._list_to_dict(line)
            records[rec['msg_id']] = rec
        return records
# the backend class is the only public name exported by this module
__all__ = ["SQLiteDB"]