"""A TaskRecord backend using sqlite3""" # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. import json import os try: import cPickle as pickle except ImportError: import pickle from datetime import datetime try: import sqlite3 except ImportError: sqlite3 = None from zmq.eventloop import ioloop from IPython.utils.traitlets import Unicode, Instance, List, Dict from .dictdb import BaseDB from IPython.utils.jsonutil import date_default, extract_dates, squash_dates from IPython.utils.py3compat import iteritems #----------------------------------------------------------------------------- # SQLite operators, adapters, and converters #----------------------------------------------------------------------------- try: buffer except NameError: # py3k buffer = memoryview operators = { '$lt' : "<", '$gt' : ">", # null is handled weird with ==,!= '$eq' : "=", '$ne' : "!=", '$lte': "<=", '$gte': ">=", '$in' : ('=', ' OR '), '$nin': ('!=', ' AND '), # '$all': None, # '$mod': None, # '$exists' : None } null_operators = { '=' : "IS NULL", '!=' : "IS NOT NULL", } def _adapt_dict(d): return json.dumps(d, default=date_default) def _convert_dict(ds): if ds is None: return ds else: if isinstance(ds, bytes): # If I understand the sqlite doc correctly, this will always be utf8 ds = ds.decode('utf8') return extract_dates(json.loads(ds)) def _adapt_bufs(bufs): # this is *horrible* # copy buffers into single list and pickle it: if bufs and isinstance(bufs[0], (bytes, buffer)): return sqlite3.Binary(pickle.dumps(list(map(bytes, bufs)),-1)) elif bufs: return bufs else: return None def _convert_bufs(bs): if bs is None: return [] else: return pickle.loads(bytes(bs)) #----------------------------------------------------------------------------- # SQLiteDB class #----------------------------------------------------------------------------- class SQLiteDB(BaseDB): """SQLite3 TaskRecord backend.""" filename = Unicode('tasks.db', config=True, help="""The filename of the sqlite task database. [default: 'tasks.db']""") location = Unicode('', config=True, help="""The directory containing the sqlite task database. The default is to use the cluster_dir location.""") table = Unicode("ipython-tasks", config=True, help="""The SQLite Table to use for storing tasks for this session. If unspecified, a new table will be created with the Hub's IDENT. Specifying the table will result in tasks from previous sessions being available via Clients' db_query and get_result methods.""") if sqlite3 is not None: _db = Instance('sqlite3.Connection', allow_none=True) else: _db = None # the ordered list of column names _keys = List(['msg_id' , 'header' , 'metadata', 'content', 'buffers', 'submitted', 'client_uuid' , 'engine_uuid' , 'started', 'completed', 'resubmitted', 'received', 'result_header' , 'result_metadata', 'result_content' , 'result_buffers' , 'queue' , 'execute_input' , 'execute_result', 'error', 'stdout', 'stderr', ]) # sqlite datatypes for checking that db is current format _types = Dict({'msg_id' : 'text' , 'header' : 'dict text', 'metadata' : 'dict text', 'content' : 'dict text', 'buffers' : 'bufs blob', 'submitted' : 'timestamp', 'client_uuid' : 'text', 'engine_uuid' : 'text', 'started' : 'timestamp', 'completed' : 'timestamp', 'resubmitted' : 'text', 'received' : 'timestamp', 'result_header' : 'dict text', 'result_metadata' : 'dict text', 'result_content' : 'dict text', 'result_buffers' : 'bufs blob', 'queue' : 'text', 'execute_input' : 'text', 'execute_result' : 'text', 'error' : 'text', 'stdout' : 'text', 'stderr' : 'text', }) def __init__(self, **kwargs): super(SQLiteDB, self).__init__(**kwargs) if sqlite3 is None: raise ImportError("SQLiteDB requires sqlite3") if not self.table: # use session, and prefix _, since starting with # is illegal self.table = '_'+self.session.replace('-','_') if not self.location: # get current profile from IPython.core.application import BaseIPythonApplication if BaseIPythonApplication.initialized(): app = BaseIPythonApplication.instance() if app.profile_dir is not None: self.location = app.profile_dir.location else: self.location = u'.' else: self.location = u'.' self._init_db() # register db commit as 2s periodic callback # to prevent clogging pipes # assumes we are being run in a zmq ioloop app loop = ioloop.IOLoop.instance() pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop) pc.start() def _defaults(self, keys=None): """create an empty record""" d = {} keys = self._keys if keys is None else keys for key in keys: d[key] = None return d def _check_table(self): """Ensure that an incorrect table doesn't exist If a bad (old) table does exist, return False """ cursor = self._db.execute("PRAGMA table_info('%s')"%self.table) lines = cursor.fetchall() if not lines: # table does not exist return True types = {} keys = [] for line in lines: keys.append(line[1]) types[line[1]] = line[2] if self._keys != keys: # key mismatch self.log.warn('keys mismatch') return False for key in self._keys: if types[key] != self._types[key]: self.log.warn( 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key]) ) return False return True def _init_db(self): """Connect to the database and get new session number.""" # register adapters sqlite3.register_adapter(dict, _adapt_dict) sqlite3.register_converter('dict', _convert_dict) sqlite3.register_adapter(list, _adapt_bufs) sqlite3.register_converter('bufs', _convert_bufs) # connect to the db dbfile = os.path.join(self.location, self.filename) self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES, # isolation_level = None)#, cached_statements=64) # print dir(self._db) first_table = previous_table = self.table i=0 while not self._check_table(): i+=1 self.table = first_table+'_%i'%i self.log.warn( "Table %s exists and doesn't match db format, trying %s"% (previous_table, self.table) ) previous_table = self.table self._db.execute("""CREATE TABLE IF NOT EXISTS '%s' (msg_id text PRIMARY KEY, header dict text, metadata dict text, content dict text, buffers bufs blob, submitted timestamp, client_uuid text, engine_uuid text, started timestamp, completed timestamp, resubmitted text, received timestamp, result_header dict text, result_metadata dict text, result_content dict text, result_buffers bufs blob, queue text, execute_input text, execute_result text, error text, stdout text, stderr text) """%self.table) self._db.commit() def _dict_to_list(self, d): """turn a mongodb-style record dict into a list.""" return [ d[key] for key in self._keys ] def _list_to_dict(self, line, keys=None): """Inverse of dict_to_list""" keys = self._keys if keys is None else keys d = self._defaults(keys) for key,value in zip(keys, line): d[key] = value return d def _render_expression(self, check): """Turn a mongodb-style search dict into an SQL query.""" expressions = [] args = [] skeys = set(check.keys()) skeys.difference_update(set(self._keys)) skeys.difference_update(set(['buffers', 'result_buffers'])) if skeys: raise KeyError("Illegal testing key(s): %s"%skeys) for name,sub_check in iteritems(check): if isinstance(sub_check, dict): for test,value in iteritems(sub_check): try: op = operators[test] except KeyError: raise KeyError("Unsupported operator: %r"%test) if isinstance(op, tuple): op, join = op if value is None and op in null_operators: expr = "%s %s" % (name, null_operators[op]) else: expr = "%s %s ?"%(name, op) if isinstance(value, (tuple,list)): if op in null_operators and any([v is None for v in value]): # equality tests don't work with NULL raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test) expr = '( %s )'%( join.join([expr]*len(value)) ) args.extend(value) else: args.append(value) expressions.append(expr) else: # it's an equality check if sub_check is None: expressions.append("%s IS NULL" % name) else: expressions.append("%s = ?"%name) args.append(sub_check) expr = " AND ".join(expressions) return expr, args def add_record(self, msg_id, rec): """Add a new Task Record, by msg_id.""" d = self._defaults() d.update(rec) d['msg_id'] = msg_id line = self._dict_to_list(d) tups = '(%s)'%(','.join(['?']*len(line))) self._db.execute("INSERT INTO '%s' VALUES %s"%(self.table, tups), line) # self._db.commit() def get_record(self, msg_id): """Get a specific Task Record, by msg_id.""" cursor = self._db.execute("""SELECT * FROM '%s' WHERE msg_id==?"""%self.table, (msg_id,)) line = cursor.fetchone() if line is None: raise KeyError("No such msg: %r"%msg_id) return self._list_to_dict(line) def update_record(self, msg_id, rec): """Update the data in an existing record.""" query = "UPDATE '%s' SET "%self.table sets = [] keys = sorted(rec.keys()) values = [] for key in keys: sets.append('%s = ?'%key) values.append(rec[key]) query += ', '.join(sets) query += ' WHERE msg_id == ?' values.append(msg_id) self._db.execute(query, values) # self._db.commit() def drop_record(self, msg_id): """Remove a record from the DB.""" self._db.execute("""DELETE FROM '%s' WHERE msg_id==?"""%self.table, (msg_id,)) # self._db.commit() def drop_matching_records(self, check): """Remove a record from the DB.""" expr,args = self._render_expression(check) query = "DELETE FROM '%s' WHERE %s"%(self.table, expr) self._db.execute(query,args) # self._db.commit() def find_records(self, check, keys=None): """Find records matching a query dict, optionally extracting subset of keys. Returns list of matching records. Parameters ---------- check: dict mongodb-style query argument keys: list of strs [optional] if specified, the subset of keys to extract. msg_id will *always* be included. """ if keys: bad_keys = [ key for key in keys if key not in self._keys ] if bad_keys: raise KeyError("Bad record key(s): %s"%bad_keys) if keys: # ensure msg_id is present and first: if 'msg_id' in keys: keys.remove('msg_id') keys.insert(0, 'msg_id') req = ', '.join(keys) else: req = '*' expr,args = self._render_expression(check) query = """SELECT %s FROM '%s' WHERE %s"""%(req, self.table, expr) cursor = self._db.execute(query, args) matches = cursor.fetchall() records = [] for line in matches: rec = self._list_to_dict(line, keys) records.append(rec) return records def get_history(self): """get all msg_ids, ordered by time submitted.""" query = """SELECT msg_id FROM '%s' ORDER by submitted ASC"""%self.table cursor = self._db.execute(query) # will be a list of length 1 tuples return [ tup[0] for tup in cursor.fetchall()] __all__ = ['SQLiteDB']