Show More
@@ -0,0 +1,272 b'' | |||||
|
1 | """A TaskRecord backend using sqlite3""" | |||
|
2 | #----------------------------------------------------------------------------- | |||
|
3 | # Copyright (C) 2011 The IPython Development Team | |||
|
4 | # | |||
|
5 | # Distributed under the terms of the BSD License. The full license is in | |||
|
6 | # the file COPYING, distributed as part of this software. | |||
|
7 | #----------------------------------------------------------------------------- | |||
|
8 | ||||
|
9 | import json | |||
|
10 | import os | |||
|
11 | import cPickle as pickle | |||
|
12 | from datetime import datetime | |||
|
13 | ||||
|
14 | import sqlite3 | |||
|
15 | ||||
|
16 | from IPython.utils.traitlets import CUnicode, CStr, Instance, List | |||
|
17 | from .dictdb import BaseDB | |||
|
18 | from .util import ISO8601 | |||
|
19 | ||||
|
20 | #----------------------------------------------------------------------------- | |||
|
21 | # SQLite operators, adapters, and converters | |||
|
22 | #----------------------------------------------------------------------------- | |||
|
23 | ||||
|
24 | operators = { | |||
|
25 | '$lt' : lambda a,b: "%s < ?", | |||
|
26 | '$gt' : ">", | |||
|
27 | # null is handled weird with ==,!= | |||
|
28 | '$eq' : "IS", | |||
|
29 | '$ne' : "IS NOT", | |||
|
30 | '$lte': "<=", | |||
|
31 | '$gte': ">=", | |||
|
32 | '$in' : ('IS', ' OR '), | |||
|
33 | '$nin': ('IS NOT', ' AND '), | |||
|
34 | # '$all': None, | |||
|
35 | # '$mod': None, | |||
|
36 | # '$exists' : None | |||
|
37 | } | |||
|
38 | ||||
|
39 | def _adapt_datetime(dt): | |||
|
40 | return dt.strftime(ISO8601) | |||
|
41 | ||||
|
42 | def _convert_datetime(ds): | |||
|
43 | if ds is None: | |||
|
44 | return ds | |||
|
45 | else: | |||
|
46 | return datetime.strptime(ds, ISO8601) | |||
|
47 | ||||
|
48 | def _adapt_dict(d): | |||
|
49 | return json.dumps(d) | |||
|
50 | ||||
|
51 | def _convert_dict(ds): | |||
|
52 | if ds is None: | |||
|
53 | return ds | |||
|
54 | else: | |||
|
55 | return json.loads(ds) | |||
|
56 | ||||
|
57 | def _adapt_bufs(bufs): | |||
|
58 | # this is *horrible* | |||
|
59 | # copy buffers into single list and pickle it: | |||
|
60 | if bufs and isinstance(bufs[0], (bytes, buffer)): | |||
|
61 | return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1)) | |||
|
62 | elif bufs: | |||
|
63 | return bufs | |||
|
64 | else: | |||
|
65 | return None | |||
|
66 | ||||
|
67 | def _convert_bufs(bs): | |||
|
68 | if bs is None: | |||
|
69 | return [] | |||
|
70 | else: | |||
|
71 | return pickle.loads(bytes(bs)) | |||
|
72 | ||||
|
73 | #----------------------------------------------------------------------------- | |||
|
74 | # SQLiteDB class | |||
|
75 | #----------------------------------------------------------------------------- | |||
|
76 | ||||
|
77 | class SQLiteDB(BaseDB): | |||
|
78 | """SQLite3 TaskRecord backend.""" | |||
|
79 | ||||
|
80 | filename = CUnicode('tasks.db', config=True) | |||
|
81 | location = CUnicode('', config=True) | |||
|
82 | table = CUnicode("", config=True) | |||
|
83 | ||||
|
84 | _db = Instance('sqlite3.Connection') | |||
|
85 | _keys = List(['msg_id' , | |||
|
86 | 'header' , | |||
|
87 | 'content', | |||
|
88 | 'buffers', | |||
|
89 | 'submitted', | |||
|
90 | 'client_uuid' , | |||
|
91 | 'engine_uuid' , | |||
|
92 | 'started', | |||
|
93 | 'completed', | |||
|
94 | 'resubmitted', | |||
|
95 | 'result_header' , | |||
|
96 | 'result_content' , | |||
|
97 | 'result_buffers' , | |||
|
98 | 'queue' , | |||
|
99 | 'pyin' , | |||
|
100 | 'pyout', | |||
|
101 | 'pyerr', | |||
|
102 | 'stdout', | |||
|
103 | 'stderr', | |||
|
104 | ]) | |||
|
105 | ||||
|
106 | def __init__(self, **kwargs): | |||
|
107 | super(SQLiteDB, self).__init__(**kwargs) | |||
|
108 | if not self.table: | |||
|
109 | # use session, and prefix _, since starting with # is illegal | |||
|
110 | self.table = '_'+self.session.replace('-','_') | |||
|
111 | if not self.location: | |||
|
112 | if hasattr(self.config.Global, 'cluster_dir'): | |||
|
113 | self.location = self.config.Global.cluster_dir | |||
|
114 | else: | |||
|
115 | self.location = '.' | |||
|
116 | self._init_db() | |||
|
117 | ||||
|
118 | def _defaults(self): | |||
|
119 | """create an empty record""" | |||
|
120 | d = {} | |||
|
121 | for key in self._keys: | |||
|
122 | d[key] = None | |||
|
123 | return d | |||
|
124 | ||||
|
125 | def _init_db(self): | |||
|
126 | """Connect to the database and get new session number.""" | |||
|
127 | # register adapters | |||
|
128 | sqlite3.register_adapter(datetime, _adapt_datetime) | |||
|
129 | sqlite3.register_converter('datetime', _convert_datetime) | |||
|
130 | sqlite3.register_adapter(dict, _adapt_dict) | |||
|
131 | sqlite3.register_converter('dict', _convert_dict) | |||
|
132 | sqlite3.register_adapter(list, _adapt_bufs) | |||
|
133 | sqlite3.register_converter('bufs', _convert_bufs) | |||
|
134 | # connect to the db | |||
|
135 | dbfile = os.path.join(self.location, self.filename) | |||
|
136 | self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES) | |||
|
137 | ||||
|
138 | self._db.execute("""CREATE TABLE IF NOT EXISTS %s | |||
|
139 | (msg_id text PRIMARY KEY, | |||
|
140 | header dict text, | |||
|
141 | content dict text, | |||
|
142 | buffers bufs blob, | |||
|
143 | submitted datetime text, | |||
|
144 | client_uuid text, | |||
|
145 | engine_uuid text, | |||
|
146 | started datetime text, | |||
|
147 | completed datetime text, | |||
|
148 | resubmitted datetime text, | |||
|
149 | result_header dict text, | |||
|
150 | result_content dict text, | |||
|
151 | result_buffers bufs blob, | |||
|
152 | queue text, | |||
|
153 | pyin text, | |||
|
154 | pyout text, | |||
|
155 | pyerr text, | |||
|
156 | stdout text, | |||
|
157 | stderr text) | |||
|
158 | """%self.table) | |||
|
159 | # self._db.execute("""CREATE TABLE IF NOT EXISTS %s_buffers | |||
|
160 | # (msg_id text, result integer, buffer blob) | |||
|
161 | # """%self.table) | |||
|
162 | self._db.commit() | |||
|
163 | ||||
|
164 | def _dict_to_list(self, d): | |||
|
165 | """turn a mongodb-style record dict into a list.""" | |||
|
166 | ||||
|
167 | return [ d[key] for key in self._keys ] | |||
|
168 | ||||
|
169 | def _list_to_dict(self, line): | |||
|
170 | """Inverse of dict_to_list""" | |||
|
171 | d = self._defaults() | |||
|
172 | for key,value in zip(self._keys, line): | |||
|
173 | d[key] = value | |||
|
174 | ||||
|
175 | return d | |||
|
176 | ||||
|
177 | def _render_expression(self, check): | |||
|
178 | """Turn a mongodb-style search dict into an SQL query.""" | |||
|
179 | expressions = [] | |||
|
180 | args = [] | |||
|
181 | ||||
|
182 | skeys = set(check.keys()) | |||
|
183 | skeys.difference_update(set(self._keys)) | |||
|
184 | skeys.difference_update(set(['buffers', 'result_buffers'])) | |||
|
185 | if skeys: | |||
|
186 | raise KeyError("Illegal testing key(s): %s"%skeys) | |||
|
187 | ||||
|
188 | for name,sub_check in check.iteritems(): | |||
|
189 | if isinstance(sub_check, dict): | |||
|
190 | for test,value in sub_check.iteritems(): | |||
|
191 | try: | |||
|
192 | op = operators[test] | |||
|
193 | except KeyError: | |||
|
194 | raise KeyError("Unsupported operator: %r"%test) | |||
|
195 | if isinstance(op, tuple): | |||
|
196 | op, join = op | |||
|
197 | expr = "%s %s ?"%(name, op) | |||
|
198 | if isinstance(value, (tuple,list)): | |||
|
199 | expr = '( %s )'%( join.join([expr]*len(value)) ) | |||
|
200 | args.extend(value) | |||
|
201 | else: | |||
|
202 | args.append(value) | |||
|
203 | expressions.append(expr) | |||
|
204 | else: | |||
|
205 | # it's an equality check | |||
|
206 | expressions.append("%s IS ?"%name) | |||
|
207 | args.append(sub_check) | |||
|
208 | ||||
|
209 | expr = " AND ".join(expressions) | |||
|
210 | return expr, args | |||
|
211 | ||||
|
212 | def add_record(self, msg_id, rec): | |||
|
213 | """Add a new Task Record, by msg_id.""" | |||
|
214 | d = self._defaults() | |||
|
215 | d.update(rec) | |||
|
216 | d['msg_id'] = msg_id | |||
|
217 | line = self._dict_to_list(d) | |||
|
218 | tups = '(%s)'%(','.join(['?']*len(line))) | |||
|
219 | self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line) | |||
|
220 | self._db.commit() | |||
|
221 | ||||
|
222 | def get_record(self, msg_id): | |||
|
223 | """Get a specific Task Record, by msg_id.""" | |||
|
224 | cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,)) | |||
|
225 | line = cursor.fetchone() | |||
|
226 | if line is None: | |||
|
227 | raise KeyError("No such msg: %r"%msg_id) | |||
|
228 | return self._list_to_dict(line) | |||
|
229 | ||||
|
230 | def update_record(self, msg_id, rec): | |||
|
231 | """Update the data in an existing record.""" | |||
|
232 | query = "UPDATE %s SET "%self.table | |||
|
233 | sets = [] | |||
|
234 | keys = sorted(rec.keys()) | |||
|
235 | values = [] | |||
|
236 | for key in keys: | |||
|
237 | sets.append('%s = ?'%key) | |||
|
238 | values.append(rec[key]) | |||
|
239 | query += ', '.join(sets) | |||
|
240 | query += ' WHERE msg_id == %r'%msg_id | |||
|
241 | self._db.execute(query, values) | |||
|
242 | self._db.commit() | |||
|
243 | ||||
|
244 | def drop_record(self, msg_id): | |||
|
245 | """Remove a record from the DB.""" | |||
|
246 | self._db.execute("""DELETE FROM %s WHERE mgs_id==?"""%self.table, (msg_id,)) | |||
|
247 | self._db.commit() | |||
|
248 | ||||
|
249 | def drop_matching_records(self, check): | |||
|
250 | """Remove a record from the DB.""" | |||
|
251 | expr,args = self._render_expression(check) | |||
|
252 | query = "DELETE FROM %s WHERE %s"%(self.table, expr) | |||
|
253 | self._db.execute(query,args) | |||
|
254 | self._db.commit() | |||
|
255 | ||||
|
256 | def find_records(self, check, id_only=False): | |||
|
257 | """Find records matching a query dict.""" | |||
|
258 | req = 'msg_id' if id_only else '*' | |||
|
259 | expr,args = self._render_expression(check) | |||
|
260 | query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr) | |||
|
261 | cursor = self._db.execute(query, args) | |||
|
262 | matches = cursor.fetchall() | |||
|
263 | if id_only: | |||
|
264 | return [ m[0] for m in matches ] | |||
|
265 | else: | |||
|
266 | records = {} | |||
|
267 | for line in matches: | |||
|
268 | rec = self._list_to_dict(line) | |||
|
269 | records[rec['msg_id']] = rec | |||
|
270 | return records | |||
|
271 | ||||
|
272 | __all__ = ['SQLiteDB'] No newline at end of file |
General Comments 0
You need to be logged in to leave comments.
Login now