"""A TaskRecord backend using sqlite3""" #----------------------------------------------------------------------------- # Copyright (C) 2011 The IPython Development Team # # Distributed under the terms of the BSD License. The full license is in # the file COPYING, distributed as part of this software. #----------------------------------------------------------------------------- import json import os import cPickle as pickle from datetime import datetime import sqlite3 from IPython.utils.traitlets import CUnicode, CStr, Instance, List from .dictdb import BaseDB from .util import ISO8601 #----------------------------------------------------------------------------- # SQLite operators, adapters, and converters #----------------------------------------------------------------------------- operators = { '$lt' : lambda a,b: "%s < ?", '$gt' : ">", # null is handled weird with ==,!= '$eq' : "IS", '$ne' : "IS NOT", '$lte': "<=", '$gte': ">=", '$in' : ('IS', ' OR '), '$nin': ('IS NOT', ' AND '), # '$all': None, # '$mod': None, # '$exists' : None } def _adapt_datetime(dt): return dt.strftime(ISO8601) def _convert_datetime(ds): if ds is None: return ds else: return datetime.strptime(ds, ISO8601) def _adapt_dict(d): return json.dumps(d) def _convert_dict(ds): if ds is None: return ds else: return json.loads(ds) def _adapt_bufs(bufs): # this is *horrible* # copy buffers into single list and pickle it: if bufs and isinstance(bufs[0], (bytes, buffer)): return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1)) elif bufs: return bufs else: return None def _convert_bufs(bs): if bs is None: return [] else: return pickle.loads(bytes(bs)) #----------------------------------------------------------------------------- # SQLiteDB class #----------------------------------------------------------------------------- class SQLiteDB(BaseDB): """SQLite3 TaskRecord backend.""" filename = CUnicode('tasks.db', config=True) location = CUnicode('', config=True) table = CUnicode("", config=True) _db = Instance('sqlite3.Connection') _keys = List(['msg_id' , 'header' , 'content', 'buffers', 'submitted', 'client_uuid' , 'engine_uuid' , 'started', 'completed', 'resubmitted', 'result_header' , 'result_content' , 'result_buffers' , 'queue' , 'pyin' , 'pyout', 'pyerr', 'stdout', 'stderr', ]) def __init__(self, **kwargs): super(SQLiteDB, self).__init__(**kwargs) if not self.table: # use session, and prefix _, since starting with # is illegal self.table = '_'+self.session.replace('-','_') if not self.location: if hasattr(self.config.Global, 'cluster_dir'): self.location = self.config.Global.cluster_dir else: self.location = '.' self._init_db() def _defaults(self): """create an empty record""" d = {} for key in self._keys: d[key] = None return d def _init_db(self): """Connect to the database and get new session number.""" # register adapters sqlite3.register_adapter(datetime, _adapt_datetime) sqlite3.register_converter('datetime', _convert_datetime) sqlite3.register_adapter(dict, _adapt_dict) sqlite3.register_converter('dict', _convert_dict) sqlite3.register_adapter(list, _adapt_bufs) sqlite3.register_converter('bufs', _convert_bufs) # connect to the db dbfile = os.path.join(self.location, self.filename) self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES) self._db.execute("""CREATE TABLE IF NOT EXISTS %s (msg_id text PRIMARY KEY, header dict text, content dict text, buffers bufs blob, submitted datetime text, client_uuid text, engine_uuid text, started datetime text, completed datetime text, resubmitted datetime text, result_header dict text, result_content dict text, result_buffers bufs blob, queue text, pyin text, pyout text, pyerr text, stdout text, stderr text) """%self.table) # self._db.execute("""CREATE TABLE IF NOT EXISTS %s_buffers # (msg_id text, result integer, buffer blob) # """%self.table) self._db.commit() def _dict_to_list(self, d): """turn a mongodb-style record dict into a list.""" return [ d[key] for key in self._keys ] def _list_to_dict(self, line): """Inverse of dict_to_list""" d = self._defaults() for key,value in zip(self._keys, line): d[key] = value return d def _render_expression(self, check): """Turn a mongodb-style search dict into an SQL query.""" expressions = [] args = [] skeys = set(check.keys()) skeys.difference_update(set(self._keys)) skeys.difference_update(set(['buffers', 'result_buffers'])) if skeys: raise KeyError("Illegal testing key(s): %s"%skeys) for name,sub_check in check.iteritems(): if isinstance(sub_check, dict): for test,value in sub_check.iteritems(): try: op = operators[test] except KeyError: raise KeyError("Unsupported operator: %r"%test) if isinstance(op, tuple): op, join = op expr = "%s %s ?"%(name, op) if isinstance(value, (tuple,list)): expr = '( %s )'%( join.join([expr]*len(value)) ) args.extend(value) else: args.append(value) expressions.append(expr) else: # it's an equality check expressions.append("%s IS ?"%name) args.append(sub_check) expr = " AND ".join(expressions) return expr, args def add_record(self, msg_id, rec): """Add a new Task Record, by msg_id.""" d = self._defaults() d.update(rec) d['msg_id'] = msg_id line = self._dict_to_list(d) tups = '(%s)'%(','.join(['?']*len(line))) self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line) self._db.commit() def get_record(self, msg_id): """Get a specific Task Record, by msg_id.""" cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,)) line = cursor.fetchone() if line is None: raise KeyError("No such msg: %r"%msg_id) return self._list_to_dict(line) def update_record(self, msg_id, rec): """Update the data in an existing record.""" query = "UPDATE %s SET "%self.table sets = [] keys = sorted(rec.keys()) values = [] for key in keys: sets.append('%s = ?'%key) values.append(rec[key]) query += ', '.join(sets) query += ' WHERE msg_id == %r'%msg_id self._db.execute(query, values) self._db.commit() def drop_record(self, msg_id): """Remove a record from the DB.""" self._db.execute("""DELETE FROM %s WHERE mgs_id==?"""%self.table, (msg_id,)) self._db.commit() def drop_matching_records(self, check): """Remove a record from the DB.""" expr,args = self._render_expression(check) query = "DELETE FROM %s WHERE %s"%(self.table, expr) self._db.execute(query,args) self._db.commit() def find_records(self, check, id_only=False): """Find records matching a query dict.""" req = 'msg_id' if id_only else '*' expr,args = self._render_expression(check) query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr) cursor = self._db.execute(query, args) matches = cursor.fetchall() if id_only: return [ m[0] for m in matches ] else: records = {} for line in matches: rec = self._list_to_dict(line) records[rec['msg_id']] = rec return records __all__ = ['SQLiteDB']