|
|
"""Tests for db backends"""
|
|
|
|
|
|
#-------------------------------------------------------------------------------
|
|
|
# Copyright (C) 2011 The IPython Development Team
|
|
|
#
|
|
|
# Distributed under the terms of the BSD License. The full license is in
|
|
|
# the file COPYING, distributed as part of this software.
|
|
|
#-------------------------------------------------------------------------------
|
|
|
|
|
|
#-------------------------------------------------------------------------------
|
|
|
# Imports
|
|
|
#-------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
import tempfile
|
|
|
import time
|
|
|
|
|
|
import uuid
|
|
|
|
|
|
from datetime import datetime, timedelta
|
|
|
from random import choice, randint
|
|
|
from unittest import TestCase
|
|
|
|
|
|
from nose import SkipTest
|
|
|
|
|
|
from IPython.parallel import error, streamsession as ss
|
|
|
from IPython.parallel.controller.dictdb import DictDB
|
|
|
from IPython.parallel.controller.sqlitedb import SQLiteDB
|
|
|
from IPython.parallel.controller.hub import init_record, empty_record
|
|
|
|
|
|
#-------------------------------------------------------------------------------
|
|
|
# TestCases
|
|
|
#-------------------------------------------------------------------------------
|
|
|
|
|
|
class TestDictBackend(TestCase):
    """Tests for the task-record database API, run against DictDB.

    Other backends reuse every test here by subclassing and overriding
    ``create_db`` (plus ``tearDown`` where the backend needs cleanup).
    """

    def setUp(self):
        self.session = ss.StreamSession()
        self.db = self.create_db()
        self.load_records(16)

    def create_db(self):
        """Return the backend under test; subclasses override this."""
        return DictDB()

    def load_records(self, n=1):
        """Add ``n`` fresh 'apply_request' records to ``self.db``.

        Returns the list of the new msg_ids, in insertion order.
        """
        # Sleep 1/10 s so these records' timestamps differ from those
        # created by any previous call.
        time.sleep(0.1)
        msg_ids = []
        for _ in range(n):
            msg = self.session.msg('apply_request', content=dict(a=5))
            msg['buffers'] = []
            rec = init_record(msg)
            msg_ids.append(msg['msg_id'])
            self.db.add_record(msg['msg_id'], rec)
        return msg_ids

    def test_add_record(self):
        """Newly added records are appended to the end of the history."""
        before = self.db.get_history()
        self.load_records(5)
        after = self.db.get_history()
        self.assertEqual(len(after), len(before) + 5)
        self.assertEqual(after[:-5], before)

    def test_drop_record(self):
        """A dropped record can no longer be fetched."""
        msg_id = self.load_records()[-1]
        # Sanity check: the record is retrievable before it is dropped.
        self.db.get_record(msg_id)
        self.db.drop_record(msg_id)
        self.assertRaises(KeyError, self.db.get_record, msg_id)

    def _round_to_millisecond(self, dt):
        """Return ``dt`` truncated to whole milliseconds.

        Necessary because MongoDB does not preserve sub-millisecond
        precision, so a datetime must already be millisecond-aligned to
        compare equal after a round trip through that backend.
        """
        micro = dt.microsecond
        return dt.replace(microsecond=micro - micro % 1000)

    def test_update_record(self):
        """update_record merges the given fields into an existing record."""
        now = self._round_to_millisecond(datetime.now())
        msg_id = self.db.get_history()[-1]
        rec1 = self.db.get_record(msg_id)
        data = {'stdout': 'hello there', 'completed': now}
        self.db.update_record(msg_id, data)
        rec2 = self.db.get_record(msg_id)
        self.assertEqual(rec2['stdout'], 'hello there')
        self.assertEqual(rec2['completed'], now)
        # The stored record must equal the old record with `data` merged in.
        rec1.update(data)
        self.assertEqual(rec1, rec2)

    # NOTE(review): disabled in the original; presumably not every backend
    # raises KeyError when updating a missing record -- confirm before
    # re-enabling.
    # def test_update_record_bad(self):
    #     """test updating nonexistent records"""
    #     msg_id = str(uuid.uuid4())
    #     data = {'stdout': 'hello there'}
    #     self.assertRaises(KeyError, self.db.update_record, msg_id, data)

    def test_find_records_dt(self):
        """test finding records by date"""
        hist = self.db.get_history()
        # Floor division: '/' yields a float under Python 3 (or true
        # division), and a float cannot index a list.
        middle = self.db.get_record(hist[len(hist) // 2])
        tic = middle['submitted']
        before = self.db.find_records({'submitted': {'$lt': tic}})
        after = self.db.find_records({'submitted': {'$gte': tic}})
        # '$lt' and '$gte' partition the history.
        self.assertEqual(len(before) + len(after), len(hist))
        for b in before:
            self.assertTrue(b['submitted'] < tic)
        for a in after:
            self.assertTrue(a['submitted'] >= tic)
        same = self.db.find_records({'submitted': tic})
        for s in same:
            self.assertTrue(s['submitted'] == tic)

    def test_find_records_keys(self):
        """test extracting subset of record keys (msg_id is always added)"""
        found = self.db.find_records({'msg_id': {'$ne': ''}},
                                     keys=['submitted', 'completed'])
        for rec in found:
            self.assertEqual(set(rec.keys()),
                             set(['msg_id', 'submitted', 'completed']))

    def test_find_records_msg_id(self):
        """ensure msg_id is always in found records"""
        for keys in (['submitted', 'completed'], ['submitted'], ['msg_id']):
            found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=keys)
            for rec in found:
                self.assertTrue('msg_id' in rec)

    def test_find_records_in(self):
        """test finding records with '$in','$nin' operators"""
        hist = self.db.get_history()
        even = hist[::2]
        odd = hist[1::2]
        recs = self.db.find_records({'msg_id': {'$in': even}})
        found = [r['msg_id'] for r in recs]
        self.assertEqual(set(even), set(found))
        recs = self.db.find_records({'msg_id': {'$nin': even}})
        found = [r['msg_id'] for r in recs]
        self.assertEqual(set(odd), set(found))

    def test_get_history(self):
        """History is ordered by submission time, newest record last."""
        msg_ids = self.db.get_history()
        latest = datetime(1984, 1, 1)
        for msg_id in msg_ids:
            rec = self.db.get_record(msg_id)
            newt = rec['submitted']
            self.assertTrue(newt >= latest)
            latest = newt
        msg_id = self.load_records(1)[-1]
        self.assertEqual(self.db.get_history()[-1], msg_id)

    def test_datetime(self):
        """get/set timestamps with datetime objects"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        self.assertTrue(isinstance(rec['submitted'], datetime))
        self.db.update_record(msg_id, dict(completed=datetime.now()))
        rec = self.db.get_record(msg_id)
        self.assertTrue(isinstance(rec['completed'], datetime))
|
|
|
|
|
|
class TestSQLiteBackend(TestDictBackend):
    """Run the TestDictBackend suite against the SQLite backend."""

    def create_db(self):
        """Create a SQLiteDB inside a fresh, private temporary directory.

        The original placed the database in the shared system temp dir and
        never removed it. A per-test directory keeps each test's file
        isolated and lets tearDown delete it.
        """
        self._tempdir = tempfile.mkdtemp()
        return SQLiteDB(location=self._tempdir)

    def tearDown(self):
        """Close the sqlite connection, then remove its temp directory."""
        import shutil
        # Close first so the file is no longer held open when deleting it.
        self.db._db.close()
        shutil.rmtree(self._tempdir, ignore_errors=True)
|
|
|
|
|
|
# Optional MongoDB test: the test case is only defined when the MongoDB
# backend module can be imported.
try:
    from IPython.parallel.controller.mongodb import MongoDB
except ImportError:
    pass
else:
    class TestMongoBackend(TestDictBackend):
        """Run the TestDictBackend suite against a MongoDB database."""

        def create_db(self):
            """Connect to the 'iptestdb' database, skipping if that fails."""
            try:
                backend = MongoDB(database='iptestdb')
            except Exception:
                # Best effort: any connection problem skips the test
                # instead of failing it.
                raise SkipTest("Couldn't connect to mongodb instance")
            return backend

        def tearDown(self):
            """Drop the scratch 'iptestdb' database after each test."""
            connection = self.db._connection
            connection.drop_database('iptestdb')
|
|
|
|