diff --git a/IPython/parallel/client/client.py b/IPython/parallel/client/client.py index a165c23..de810a2 100644 --- a/IPython/parallel/client/client.py +++ b/IPython/parallel/client/client.py @@ -1215,5 +1215,78 @@ class Client(HasTraits): if content['status'] != 'ok': raise self._unwrap_exception(content) + @spin_first + def hub_history(self): + """Get the Hub's history + + Just like the Client, the Hub has a history, which is a list of msg_ids. + This will contain the history of all clients, and, depending on configuration, + may contain history across multiple cluster sessions. + + Any msg_id returned here is a valid argument to `get_result`. + + Returns + ------- + + msg_ids : list of strs + list of all msg_ids, ordered by task submission time. + """ + + self.session.send(self._query_socket, "history_request", content={}) + idents, msg = self.session.recv(self._query_socket, 0) + + if self.debug: + pprint(msg) + content = msg['content'] + if content['status'] != 'ok': + raise self._unwrap_exception(content) + else: + return content['history'] + + @spin_first + def db_query(self, query, keys=None): + """Query the Hub's TaskRecord database + + This will return a list of task record dicts that match `query` + + Parameters + ---------- + + query : mongodb query dict + The search dict. See mongodb query docs for details. + keys : list of strs [optional] + THe subset of keys to be returned. The default is to fetch everything. + 'msg_id' will *always* be included. + """ + content = dict(query=query, keys=keys) + self.session.send(self._query_socket, "db_request", content=content) + idents, msg = self.session.recv(self._query_socket, 0) + if self.debug: + pprint(msg) + content = msg['content'] + if content['status'] != 'ok': + raise self._unwrap_exception(content) + + records = content['records'] + buffer_lens = content['buffer_lens'] + result_buffer_lens = content['result_buffer_lens'] + buffers = msg['buffers'] + has_bufs = buffer_lens is not None + has_rbufs = result_buffer_lens is not None + for i,rec in enumerate(records): + # relink buffers + if has_bufs: + blen = buffer_lens[i] + rec['buffers'], buffers = buffers[:blen],buffers[blen:] + if has_rbufs: + blen = result_buffer_lens[i] + rec['result_buffers'], buffers = buffers[:blen],buffers[blen:] + # turn timestamps back into times + for key in 'submitted started completed resubmitted'.split(): + maybedate = rec.get(key, None) + if maybedate and util.ISO8601_RE.match(maybedate): + rec[key] = datetime.strptime(maybedate, util.ISO8601) + + return records __all__ = [ 'Client' ] diff --git a/IPython/parallel/controller/dictdb.py b/IPython/parallel/controller/dictdb.py index 9b7e48d..6bf2ca8 100644 --- a/IPython/parallel/controller/dictdb.py +++ b/IPython/parallel/controller/dictdb.py @@ -103,9 +103,9 @@ class DictDB(BaseDB): return False return True - def _match(self, check, id_only=True): + def _match(self, check): """Find all the matches for a check dict.""" - matches = {} + matches = [] tests = {} for k,v in check.iteritems(): if isinstance(v, dict): @@ -113,14 +113,18 @@ class DictDB(BaseDB): else: tests[k] = lambda o: o==v - for msg_id, rec in self._records.iteritems(): + for rec in self._records.itervalues(): if self._match_one(rec, tests): - matches[msg_id] = rec - if id_only: - return matches.keys() - else: - return matches - + matches.append(rec) + return matches + + def _extract_subdict(self, rec, keys): + """extract subdict of keys""" + d = {} + d['msg_id'] = rec['msg_id'] + for key in keys: + d[key] = rec[key] + return d def 
add_record(self, msg_id, rec): """Add a new Task Record, by msg_id.""" @@ -140,7 +144,7 @@ class DictDB(BaseDB): def drop_matching_records(self, check): """Remove a record from the DB.""" - matches = self._match(check, id_only=True) + matches = self._match(check) for m in matches: del self._records[m] @@ -149,7 +153,28 @@ class DictDB(BaseDB): del self._records[msg_id] - def find_records(self, check, id_only=False): - """Find records matching a query dict.""" - matches = self._match(check, id_only) - return matches \ No newline at end of file + def find_records(self, check, keys=None): + """Find records matching a query dict, optionally extracting subset of keys. + + Returns dict keyed by msg_id of matching records. + + Parameters + ---------- + + check: dict + mongodb-style query argument + keys: list of strs [optional] + if specified, the subset of keys to extract. msg_id will *always* be + included. + """ + matches = self._match(check) + if keys: + return [ self._extract_subdict(rec, keys) for rec in matches ] + else: + return matches + + + def get_history(self): + """get all msg_ids, ordered by time submitted.""" + msg_ids = self._records.keys() + return sorted(msg_ids, key=lambda m: self._records[m]['submitted']) diff --git a/IPython/parallel/controller/hub.py b/IPython/parallel/controller/hub.py index a32dde7..9b2ffcb 100755 --- a/IPython/parallel/controller/hub.py +++ b/IPython/parallel/controller/hub.py @@ -27,9 +27,8 @@ from zmq.eventloop.zmqstream import ZMQStream from IPython.utils.importstring import import_item from IPython.utils.traitlets import HasTraits, Instance, Int, CStr, Str, Dict, Set, List, Bool -from IPython.parallel import error +from IPython.parallel import error, util from IPython.parallel.factory import RegistrationFactory, LoggingFactory -from IPython.parallel.util import select_random_ports, validate_url_container, ISO8601 from .heartmonitor import HeartMonitor @@ -76,7 +75,7 @@ def init_record(msg): 'header' : header, 'content': msg['content'], 'buffers': msg['buffers'], - 'submitted': datetime.strptime(header['date'], ISO8601), + 'submitted': datetime.strptime(header['date'], util.ISO8601), 'client_uuid' : None, 'engine_uuid' : None, 'started': None, @@ -119,32 +118,32 @@ class HubFactory(RegistrationFactory): # port-pairs for monitoredqueues: hb = Instance(list, config=True) def _hb_default(self): - return select_random_ports(2) + return util.select_random_ports(2) mux = Instance(list, config=True) def _mux_default(self): - return select_random_ports(2) + return util.select_random_ports(2) task = Instance(list, config=True) def _task_default(self): - return select_random_ports(2) + return util.select_random_ports(2) control = Instance(list, config=True) def _control_default(self): - return select_random_ports(2) + return util.select_random_ports(2) iopub = Instance(list, config=True) def _iopub_default(self): - return select_random_ports(2) + return util.select_random_ports(2) # single ports: mon_port = Instance(int, config=True) def _mon_port_default(self): - return select_random_ports(1)[0] + return util.select_random_ports(1)[0] notifier_port = Instance(int, config=True) def _notifier_port_default(self): - return select_random_ports(1)[0] + return util.select_random_ports(1)[0] ping = Int(1000, config=True) # ping frequency @@ -344,11 +343,11 @@ class Hub(LoggingFactory): # validate connection dicts: for k,v in self.client_info.iteritems(): if k == 'task': - validate_url_container(v[1]) + util.validate_url_container(v[1]) else: - validate_url_container(v) 
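For readers skimming the DictDB changes above: `find_records` now returns a plain list of matching record dicts (optionally trimmed to a key subset) instead of a msg_id-keyed dict, and the matching itself is an in-memory evaluation of a mongodb-style check dict. The standalone sketch below illustrates that matching idea only; the operator table, the `match_records` helper, and the sample records are invented for the example and are not part of this patch:

    import operator

    _ops = {'$lt': operator.lt, '$gt': operator.gt, '$eq': operator.eq,
            '$ne': operator.ne, '$in': lambda a, b: a in b}

    def match_records(records, check):
        """Return the records that satisfy a mongodb-style check dict."""
        tests = {}
        for k, v in check.items():
            if isinstance(v, dict):
                # e.g. {'completed': {'$ne': None}} -> one test per operator
                tests[k] = [lambda o, op=_ops[op], val=val: op(o, val)
                            for op, val in v.items()]
            else:
                # default-argument capture: a bare `lambda o: o == v` would
                # close over the loop variable and always test the last value
                tests[k] = [lambda o, val=v: o == val]
        return [rec for rec in records
                if all(t(rec.get(k)) for k, ts in tests.items() for t in ts)]

    recs = [{'msg_id': 'a', 'completed': None},
            {'msg_id': 'b', 'completed': 1}]
    print(match_records(recs, {'completed': {'$ne': None}}))  # -> only 'b'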
- # validate_url_container(self.client_info) - validate_url_container(self.engine_info) + util.validate_url_container(v) + # util.validate_url_container(self.client_info) + util.validate_url_container(self.engine_info) # register our callbacks self.query.on_recv(self.dispatch_query) @@ -369,6 +368,8 @@ class Hub(LoggingFactory): self.query_handlers = {'queue_request': self.queue_status, 'result_request': self.get_results, + 'history_request': self.get_history, + 'db_request': self.db_query, 'purge_request': self.purge_results, 'load_request': self.check_load, 'resubmit_request': self.resubmit_task, @@ -606,10 +607,10 @@ class Hub(LoggingFactory): return # update record anyway, because the unregistration could have been premature rheader = msg['header'] - completed = datetime.strptime(rheader['date'], ISO8601) + completed = datetime.strptime(rheader['date'], util.ISO8601) started = rheader.get('started', None) if started is not None: - started = datetime.strptime(started, ISO8601) + started = datetime.strptime(started, util.ISO8601) result = { 'result_header' : rheader, 'result_content': msg['content'], @@ -618,7 +619,10 @@ class Hub(LoggingFactory): } result['result_buffers'] = msg['buffers'] - self.db.update_record(msg_id, result) + try: + self.db.update_record(msg_id, result) + except Exception: + self.log.error("DB Error updating record %r"%msg_id, exc_info=True) #--------------------- Task Queue Traffic ------------------------------ @@ -653,6 +657,8 @@ class Hub(LoggingFactory): self.db.update_record(msg_id, record) except KeyError: self.db.add_record(msg_id, record) + except Exception: + self.log.error("DB Error saving task request %r"%msg_id, exc_info=True) def save_task_result(self, idents, msg): """save the result of a completed task.""" @@ -685,10 +691,10 @@ class Hub(LoggingFactory): self.completed[eid].append(msg_id) if msg_id in self.tasks[eid]: self.tasks[eid].remove(msg_id) - completed = datetime.strptime(header['date'], ISO8601) + completed = datetime.strptime(header['date'], util.ISO8601) started = header.get('started', None) if started is not None: - started = datetime.strptime(started, ISO8601) + started = datetime.strptime(started, util.ISO8601) result = { 'result_header' : header, 'result_content': msg['content'], @@ -698,7 +704,10 @@ class Hub(LoggingFactory): } result['result_buffers'] = msg['buffers'] - self.db.update_record(msg_id, result) + try: + self.db.update_record(msg_id, result) + except Exception: + self.log.error("DB Error saving task request %r"%msg_id, exc_info=True) else: self.log.debug("task::unknown task %s finished"%msg_id) @@ -723,7 +732,11 @@ class Hub(LoggingFactory): self.tasks[eid].append(msg_id) # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid)) - self.db.update_record(msg_id, dict(engine_uuid=engine_uuid)) + try: + self.db.update_record(msg_id, dict(engine_uuid=engine_uuid)) + except Exception: + self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True) + def mia_task_request(self, idents, msg): raise NotImplementedError @@ -772,7 +785,10 @@ class Hub(LoggingFactory): else: d[msg_type] = content.get('data', '') - self.db.update_record(msg_id, d) + try: + self.db.update_record(msg_id, d) + except Exception: + self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True) @@ -904,11 +920,15 @@ class Hub(LoggingFactory): # build a fake header: header = {} header['engine'] = uuid - header['date'] = datetime.now().strftime(ISO8601) + header['date'] = datetime.now() rec = 
dict(result_content=content, result_header=header, result_buffers=[]) rec['completed'] = header['date'] rec['engine_uuid'] = uuid - self.db.update_record(msg_id, rec) + try: + self.db.update_record(msg_id, rec) + except Exception: + self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True) + def finish_registration(self, heart): """Second half of engine registration, called after our HeartMonitor @@ -1017,7 +1037,10 @@ class Hub(LoggingFactory): msg_ids = content.get('msg_ids', []) reply = dict(status='ok') if msg_ids == 'all': - self.db.drop_matching_records(dict(completed={'$ne':None})) + try: + self.db.drop_matching_records(dict(completed={'$ne':None})) + except Exception: + reply = error.wrap_exception() else: for msg_id in msg_ids: if msg_id in self.all_completed: @@ -1044,7 +1067,11 @@ class Hub(LoggingFactory): break msg_ids = self.completed.pop(eid) uid = self.engines[eid].queue - self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None})) + try: + self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None})) + except Exception: + reply = error.wrap_exception() + break self.session.send(self.query, 'purge_reply', content=reply, ident=client_id) @@ -1052,6 +1079,23 @@ class Hub(LoggingFactory): """Resubmit a task.""" raise NotImplementedError + def _extract_record(self, rec): + """decompose a TaskRecord dict into subsection of reply for get_result""" + io_dict = {} + for key in 'pyin pyout pyerr stdout stderr'.split(): + io_dict[key] = rec[key] + content = { 'result_content': rec['result_content'], + 'header': rec['header'], + 'result_header' : rec['result_header'], + 'io' : io_dict, + } + if rec['result_buffers']: + buffers = map(str, rec['result_buffers']) + else: + buffers = [] + + return content, buffers + def get_results(self, client_id, msg): """Get the result of 1 or more messages.""" content = msg['content'] @@ -1064,25 +1108,28 @@ class Hub(LoggingFactory): content['completed'] = completed buffers = [] if not statusonly: - content['results'] = {} - records = self.db.find_records(dict(msg_id={'$in':msg_ids})) + try: + matches = self.db.find_records(dict(msg_id={'$in':msg_ids})) + # turn match list into dict, for faster lookup + records = {} + for rec in matches: + records[rec['msg_id']] = rec + except Exception: + content = error.wrap_exception() + self.session.send(self.query, "result_reply", content=content, + parent=msg, ident=client_id) + return + else: + records = {} for msg_id in msg_ids: if msg_id in self.pending: pending.append(msg_id) - elif msg_id in self.all_completed: + elif msg_id in self.all_completed or msg_id in records: completed.append(msg_id) if not statusonly: - rec = records[msg_id] - io_dict = {} - for key in 'pyin pyout pyerr stdout stderr'.split(): - io_dict[key] = rec[key] - content[msg_id] = { 'result_content': rec['result_content'], - 'header': rec['header'], - 'result_header' : rec['result_header'], - 'io' : io_dict, - } - if rec['result_buffers']: - buffers.extend(map(str, rec['result_buffers'])) + c,bufs = self._extract_record(records[msg_id]) + content[msg_id] = c + buffers.extend(bufs) else: try: raise KeyError('No such message: '+msg_id) @@ -1093,3 +1140,54 @@ class Hub(LoggingFactory): parent=msg, ident=client_id, buffers=buffers) + def get_history(self, client_id, msg): + """Get a list of all msg_ids in our DB records""" + try: + msg_ids = self.db.get_history() + except Exception as e: + content = error.wrap_exception() + else: + content = dict(status='ok', history=msg_ids) + + 
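A note on the buffer bookkeeping used by `get_results` above and by the `db_query` handler in the next hunk: per-record buffers are flattened into one list for the wire, with a parallel list of lengths, and the client (see `Client.db_query` at the top of this patch) slices them back onto each record. A rough sketch of that convention; the helper names and sample data are made up for illustration:

    def flatten(records, key='buffers'):
        """Pop each record's buffer list; return (lengths, flat buffer list)."""
        lens, flat = [], []
        for rec in records:
            bufs = rec.pop(key, []) or []
            lens.append(len(bufs))
            flat.extend(bufs)
        return lens, flat

    def relink(records, lens, flat, key='buffers'):
        """Slice a flat buffer list back onto records using the lengths."""
        for rec, n in zip(records, lens):
            rec[key], flat = flat[:n], flat[n:]
        return records

    sent = [{'msg_id': 'a', 'buffers': [b'x', b'y']},
            {'msg_id': 'b', 'buffers': [b'z']}]
    lens, flat = flatten(sent)                                    # hub side, before send
    got = relink([{'msg_id': 'a'}, {'msg_id': 'b'}], lens, flat)  # client side, on receipt
    assert got[0]['buffers'] == [b'x', b'y'] and got[1]['buffers'] == [b'z']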
self.session.send(self.query, "history_reply", content=content, + parent=msg, ident=client_id) + + def db_query(self, client_id, msg): + """Perform a raw query on the task record database.""" + content = msg['content'] + query = content.get('query', {}) + keys = content.get('keys', None) + query = util.extract_dates(query) + buffers = [] + empty = list() + + try: + records = self.db.find_records(query, keys) + except Exception as e: + content = error.wrap_exception() + else: + # extract buffers from reply content: + if keys is not None: + buffer_lens = [] if 'buffers' in keys else None + result_buffer_lens = [] if 'result_buffers' in keys else None + else: + buffer_lens = [] + result_buffer_lens = [] + + for rec in records: + # buffers may be None, so double check + if buffer_lens is not None: + b = rec.pop('buffers', empty) or empty + buffer_lens.append(len(b)) + buffers.extend(b) + if result_buffer_lens is not None: + rb = rec.pop('result_buffers', empty) or empty + result_buffer_lens.append(len(rb)) + buffers.extend(rb) + content = dict(status='ok', records=records, buffer_lens=buffer_lens, + result_buffer_lens=result_buffer_lens) + + self.session.send(self.query, "db_reply", content=content, + parent=msg, ident=client_id, + buffers=buffers) + diff --git a/IPython/parallel/controller/mongodb.py b/IPython/parallel/controller/mongodb.py index d2c4080..9100d48 100644 --- a/IPython/parallel/controller/mongodb.py +++ b/IPython/parallel/controller/mongodb.py @@ -22,9 +22,9 @@ from .dictdb import BaseDB class MongoDB(BaseDB): """MongoDB TaskRecord backend.""" - connection_args = List(config=True) - connection_kwargs = Dict(config=True) - database = CUnicode(config=True) + connection_args = List(config=True) # args passed to pymongo.Connection + connection_kwargs = Dict(config=True) # kwargs passed to pymongo.Connection + database = CUnicode(config=True) # name of the mongodb database _table = Dict() def __init__(self, **kwargs): @@ -37,13 +37,14 @@ class MongoDB(BaseDB): def _binary_buffers(self, rec): for key in ('buffers', 'result_buffers'): - if key in rec: + if rec.get(key, None): rec[key] = map(Binary, rec[key]) + return rec def add_record(self, msg_id, rec): """Add a new Task Record, by msg_id.""" # print rec - rec = _binary_buffers(rec) + rec = self._binary_buffers(rec) obj_id = self._records.insert(rec) self._table[msg_id] = obj_id @@ -53,7 +54,7 @@ class MongoDB(BaseDB): def update_record(self, msg_id, rec): """Update the data in an existing record.""" - rec = _binary_buffers(rec) + rec = self._binary_buffers(rec) obj_id = self._table[msg_id] self._records.update({'_id':obj_id}, {'$set': rec}) @@ -66,15 +67,30 @@ class MongoDB(BaseDB): obj_id = self._table.pop(msg_id) self._records.remove(obj_id) - def find_records(self, check, id_only=False): - """Find records matching a query dict.""" - matches = list(self._records.find(check)) - if id_only: - return [ rec['msg_id'] for rec in matches ] - else: - data = {} - for rec in matches: - data[rec['msg_id']] = rec - return data + def find_records(self, check, keys=None): + """Find records matching a query dict, optionally extracting subset of keys. + + Returns list of matching records. + + Parameters + ---------- + + check: dict + mongodb-style query argument + keys: list of strs [optional] + if specified, the subset of keys to extract. msg_id will *always* be + included. 
+ """ + if keys and 'msg_id' not in keys: + keys.append('msg_id') + matches = list(self._records.find(check,keys)) + for rec in matches: + rec.pop('_id') + return matches + + def get_history(self): + """get all msg_ids, ordered by time submitted.""" + cursor = self._records.find({},{'msg_id':1}).sort('submitted') + return [ rec['msg_id'] for rec in cursor ] diff --git a/IPython/parallel/controller/sqlitedb.py b/IPython/parallel/controller/sqlitedb.py index 8a6bd31..0b738ff 100644 --- a/IPython/parallel/controller/sqlitedb.py +++ b/IPython/parallel/controller/sqlitedb.py @@ -24,7 +24,7 @@ from IPython.parallel.util import ISO8601 #----------------------------------------------------------------------------- operators = { - '$lt' : lambda a,b: "%s < ?", + '$lt' : "<", '$gt' : ">", # null is handled weird with ==,!= '$eq' : "IS", @@ -124,10 +124,11 @@ class SQLiteDB(BaseDB): pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop) pc.start() - def _defaults(self): + def _defaults(self, keys=None): """create an empty record""" d = {} - for key in self._keys: + keys = self._keys if keys is None else keys + for key in keys: d[key] = None return d @@ -168,9 +169,6 @@ class SQLiteDB(BaseDB): stdout text, stderr text) """%self.table) - # self._db.execute("""CREATE TABLE IF NOT EXISTS %s_buffers - # (msg_id text, result integer, buffer blob) - # """%self.table) self._db.commit() def _dict_to_list(self, d): @@ -178,10 +176,11 @@ class SQLiteDB(BaseDB): return [ d[key] for key in self._keys ] - def _list_to_dict(self, line): + def _list_to_dict(self, line, keys=None): """Inverse of dict_to_list""" - d = self._defaults() - for key,value in zip(self._keys, line): + keys = self._keys if keys is None else keys + d = self._defaults(keys) + for key,value in zip(keys, line): d[key] = value return d @@ -249,13 +248,14 @@ class SQLiteDB(BaseDB): sets.append('%s = ?'%key) values.append(rec[key]) query += ', '.join(sets) - query += ' WHERE msg_id == %r'%msg_id + query += ' WHERE msg_id == ?' + values.append(msg_id) self._db.execute(query, values) # self._db.commit() def drop_record(self, msg_id): """Remove a record from the DB.""" - self._db.execute("""DELETE FROM %s WHERE mgs_id==?"""%self.table, (msg_id,)) + self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,)) # self._db.commit() def drop_matching_records(self, check): @@ -265,20 +265,48 @@ class SQLiteDB(BaseDB): self._db.execute(query,args) # self._db.commit() - def find_records(self, check, id_only=False): - """Find records matching a query dict.""" - req = 'msg_id' if id_only else '*' + def find_records(self, check, keys=None): + """Find records matching a query dict, optionally extracting subset of keys. + + Returns list of matching records. + + Parameters + ---------- + + check: dict + mongodb-style query argument + keys: list of strs [optional] + if specified, the subset of keys to extract. msg_id will *always* be + included. 
+ """ + if keys: + bad_keys = [ key for key in keys if key not in self._keys ] + if bad_keys: + raise KeyError("Bad record key(s): %s"%bad_keys) + + if keys: + # ensure msg_id is present and first: + if 'msg_id' in keys: + keys.remove('msg_id') + keys.insert(0, 'msg_id') + req = ', '.join(keys) + else: + req = '*' expr,args = self._render_expression(check) query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr) cursor = self._db.execute(query, args) matches = cursor.fetchall() - if id_only: - return [ m[0] for m in matches ] - else: - records = {} - for line in matches: - rec = self._list_to_dict(line) - records[rec['msg_id']] = rec - return records + records = [] + for line in matches: + rec = self._list_to_dict(line, keys) + records.append(rec) + return records + + def get_history(self): + """get all msg_ids, ordered by time submitted.""" + query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table + cursor = self._db.execute(query) + # will be a list of length 1 tuples + return [ tup[0] for tup in cursor.fetchall()] __all__ = ['SQLiteDB'] \ No newline at end of file diff --git a/IPython/parallel/streamsession.py b/IPython/parallel/streamsession.py index 681083e..e5e8bd4 100644 --- a/IPython/parallel/streamsession.py +++ b/IPython/parallel/streamsession.py @@ -28,6 +28,7 @@ from zmq.eventloop.zmqstream import ZMQStream from .util import ISO8601 def squash_unicode(obj): + """coerce unicode back to bytestrings.""" if isinstance(obj,dict): for key in obj.keys(): obj[key] = squash_unicode(obj[key]) @@ -40,7 +41,14 @@ def squash_unicode(obj): obj = obj.encode('utf8') return obj -json_packer = jsonapi.dumps +def _date_default(obj): + if isinstance(obj, datetime): + return obj.strftime(ISO8601) + else: + raise TypeError("%r is not JSON serializable"%obj) + +_default_key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default' +json_packer = lambda obj: jsonapi.dumps(obj, **{_default_key:_date_default}) json_unpacker = lambda s: squash_unicode(jsonapi.loads(s)) pickle_packer = lambda o: pickle.dumps(o,-1) diff --git a/IPython/parallel/tests/test_db.py b/IPython/parallel/tests/test_db.py new file mode 100644 index 0000000..e1dae1c --- /dev/null +++ b/IPython/parallel/tests/test_db.py @@ -0,0 +1,182 @@ +"""Tests for db backends""" + +#------------------------------------------------------------------------------- +# Copyright (C) 2011 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + + +import tempfile +import time + +import uuid + +from datetime import datetime, timedelta +from random import choice, randint +from unittest import TestCase + +from nose import SkipTest + +from IPython.parallel import error, streamsession as ss +from IPython.parallel.controller.dictdb import DictDB +from IPython.parallel.controller.sqlitedb import SQLiteDB +from IPython.parallel.controller.hub import init_record, empty_record + +#------------------------------------------------------------------------------- +# TestCases +#------------------------------------------------------------------------------- + +class TestDictBackend(TestCase): + def setUp(self): + self.session = ss.StreamSession() + self.db = self.create_db() + self.load_records(16) + + def create_db(self): + return DictDB() + + def load_records(self, n=1): + """load n records for testing""" + #sleep 1/10 s, to ensure timestamp is different to previous calls + time.sleep(0.1) + msg_ids = [] + for i in range(n): + msg = self.session.msg('apply_request', content=dict(a=5)) + msg['buffers'] = [] + rec = init_record(msg) + msg_ids.append(msg['msg_id']) + self.db.add_record(msg['msg_id'], rec) + return msg_ids + + def test_add_record(self): + before = self.db.get_history() + self.load_records(5) + after = self.db.get_history() + self.assertEquals(len(after), len(before)+5) + self.assertEquals(after[:-5],before) + + def test_drop_record(self): + msg_id = self.load_records()[-1] + rec = self.db.get_record(msg_id) + self.db.drop_record(msg_id) + self.assertRaises(KeyError,self.db.get_record, msg_id) + + def _round_to_millisecond(self, dt): + """necessary because mongodb rounds microseconds""" + micro = dt.microsecond + extra = int(str(micro)[-3:]) + return dt - timedelta(microseconds=extra) + + def test_update_record(self): + now = self._round_to_millisecond(datetime.now()) + # + msg_id = self.db.get_history()[-1] + rec1 = self.db.get_record(msg_id) + data = {'stdout': 'hello there', 'completed' : now} + self.db.update_record(msg_id, data) + rec2 = self.db.get_record(msg_id) + self.assertEquals(rec2['stdout'], 'hello there') + self.assertEquals(rec2['completed'], now) + rec1.update(data) + self.assertEquals(rec1, rec2) + + # def test_update_record_bad(self): + # """test updating nonexistant records""" + # msg_id = str(uuid.uuid4()) + # data = {'stdout': 'hello there'} + # self.assertRaises(KeyError, self.db.update_record, msg_id, data) + + def test_find_records_dt(self): + """test finding records by date""" + hist = self.db.get_history() + middle = self.db.get_record(hist[len(hist)/2]) + tic = middle['submitted'] + before = self.db.find_records({'submitted' : {'$lt' : tic}}) + after = self.db.find_records({'submitted' : {'$gte' : tic}}) + self.assertEquals(len(before)+len(after),len(hist)) + for b in before: + self.assertTrue(b['submitted'] < tic) + for a in after: + self.assertTrue(a['submitted'] >= tic) + same = self.db.find_records({'submitted' : tic}) + for s in same: + self.assertTrue(s['submitted'] == tic) + + def test_find_records_keys(self): + """test extracting subset of record keys""" + found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed']) + for rec in found: + self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed'])) 
+ + def test_find_records_msg_id(self): + """ensure msg_id is always in found records""" + found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed']) + for rec in found: + self.assertTrue('msg_id' in rec.keys()) + found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted']) + for rec in found: + self.assertTrue('msg_id' in rec.keys()) + found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id']) + for rec in found: + self.assertTrue('msg_id' in rec.keys()) + + def test_find_records_in(self): + """test finding records with '$in','$nin' operators""" + hist = self.db.get_history() + even = hist[::2] + odd = hist[1::2] + recs = self.db.find_records({ 'msg_id' : {'$in' : even}}) + found = [ r['msg_id'] for r in recs ] + self.assertEquals(set(even), set(found)) + recs = self.db.find_records({ 'msg_id' : {'$nin' : even}}) + found = [ r['msg_id'] for r in recs ] + self.assertEquals(set(odd), set(found)) + + def test_get_history(self): + msg_ids = self.db.get_history() + latest = datetime(1984,1,1) + for msg_id in msg_ids: + rec = self.db.get_record(msg_id) + newt = rec['submitted'] + self.assertTrue(newt >= latest) + latest = newt + msg_id = self.load_records(1)[-1] + self.assertEquals(self.db.get_history()[-1],msg_id) + + def test_datetime(self): + """get/set timestamps with datetime objects""" + msg_id = self.db.get_history()[-1] + rec = self.db.get_record(msg_id) + self.assertTrue(isinstance(rec['submitted'], datetime)) + self.db.update_record(msg_id, dict(completed=datetime.now())) + rec = self.db.get_record(msg_id) + self.assertTrue(isinstance(rec['completed'], datetime)) + +class TestSQLiteBackend(TestDictBackend): + def create_db(self): + return SQLiteDB(location=tempfile.gettempdir()) + + def tearDown(self): + self.db._db.close() + +# optional MongoDB test +try: + from IPython.parallel.controller.mongodb import MongoDB +except ImportError: + pass +else: + class TestMongoBackend(TestDictBackend): + def create_db(self): + try: + return MongoDB(database='iptestdb') + except Exception: + raise SkipTest("Couldn't connect to mongodb instance") + + def tearDown(self): + self.db._connection.drop_database('iptestdb') diff --git a/IPython/parallel/util.py b/IPython/parallel/util.py index dc00ed4..9bac6d1 100644 --- a/IPython/parallel/util.py +++ b/IPython/parallel/util.py @@ -17,6 +17,7 @@ import re import stat import socket import sys +from datetime import datetime from signal import signal, SIGINT, SIGABRT, SIGTERM try: from signal import SIGKILL @@ -41,6 +42,7 @@ from IPython.zmq.log import EnginePUBHandler # globals ISO8601="%Y-%m-%dT%H:%M:%S.%f" +ISO8601_RE=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$") #----------------------------------------------------------------------------- # Classes @@ -99,6 +101,18 @@ class ReverseDict(dict): # Functions #----------------------------------------------------------------------------- +def extract_dates(obj): + """extract ISO8601 dates from unpacked JSON""" + if isinstance(obj, dict): + for k,v in obj.iteritems(): + obj[k] = extract_dates(v) + elif isinstance(obj, list): + obj = [ extract_dates(o) for o in obj ] + elif isinstance(obj, basestring): + if ISO8601_RE.match(obj): + obj = datetime.strptime(obj, ISO8601) + return obj + def validate_url(url): """validate a url for zeromq""" if not isinstance(url, basestring): @@ -460,3 +474,4 @@ def local_logger(logname, loglevel=logging.DEBUG): handler.setLevel(loglevel) logger.addHandler(handler) logger.setLevel(loglevel) +
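Finally, a self-contained illustration of the date handling this patch standardizes on: outgoing datetimes are serialized through the json `default` hook as ISO8601 strings (`_date_default` in streamsession), and incoming strings that look like ISO8601 are revived with a regex check plus `strptime` (`extract_dates`/`ISO8601_RE` in util). The snippet reimplements both ends in a few lines for clarity; it uses Python 3 `str` and dict comprehensions for brevity, whereas the patch itself keeps the Python 2 `basestring`/`iteritems` idioms:

    import json, re
    from datetime import datetime

    ISO8601 = "%Y-%m-%dT%H:%M:%S.%f"
    ISO8601_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")

    def date_default(obj):
        if isinstance(obj, datetime):
            return obj.strftime(ISO8601)
        raise TypeError("%r is not JSON serializable" % obj)

    def extract_dates(obj):
        if isinstance(obj, dict):
            return {k: extract_dates(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [extract_dates(o) for o in obj]
        if isinstance(obj, str) and ISO8601_RE.match(obj):
            return datetime.strptime(obj, ISO8601)
        return obj

    query = {'submitted': {'$gte': datetime(2011, 3, 1, 12, 30, 0, 500)}}
    wire = json.dumps(query, default=date_default)       # datetime -> string
    assert extract_dates(json.loads(wire)) == query       # string -> datetime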