diff --git a/IPython/parallel/controller/dictdb.py b/IPython/parallel/controller/dictdb.py
index eaf106a..e0a00a6 100644
--- a/IPython/parallel/controller/dictdb.py
+++ b/IPython/parallel/controller/dictdb.py
@@ -51,7 +51,7 @@ from datetime import datetime
 
 from IPython.config.configurable import LoggingConfigurable
 
-from IPython.utils.traitlets import Dict, Unicode, Instance
+from IPython.utils.traitlets import Dict, Unicode, Integer, Float
 
 filters = {
  '$lt' : lambda a,b: a < b,
@@ -100,6 +100,33 @@ class DictDB(BaseDB):
     """
 
     _records = Dict()
+    _culled_ids = set()  # set of ids which have been culled
+    _buffer_bytes = Integer(0)  # running total of the bytes in the DB
+
+    size_limit = Integer(1024*1024, config=True,
+        help="""The maximum total size (in bytes) of the buffers stored in the db
+
+        When the db exceeds this size, the oldest records will be culled until
+        the total size is under size_limit * (1-cull_fraction).
+        """
+    )
+    record_limit = Integer(1024, config=True,
+        help="""The maximum number of records in the db
+
+        When the history exceeds this limit, the oldest record_limit * cull_fraction
+        records will be culled.
+        """
+    )
+    cull_fraction = Float(0.1, config=True,
+        help="""The fraction by which the db should be culled when one of the limits is exceeded
+
+        In general, the db will spend most of its time with a size in the range:
+
+        [limit * (1-cull_fraction), limit]
+
+        for each of size_limit and record_limit.
+        """
+    )
 
     def _match_one(self, rec, tests):
         """Check if a specific record matches tests."""
@@ -130,34 +157,92 @@ class DictDB(BaseDB):
         for key in keys:
             d[key] = rec[key]
         return copy(d)
+
+    # methods for monitoring size / culling history
+
+    def _add_bytes(self, rec):
+        for key in ('buffers', 'result_buffers'):
+            for buf in rec.get(key) or []:
+                self._buffer_bytes += len(buf)
+
+        self._maybe_cull()
+
+    def _drop_bytes(self, rec):
+        for key in ('buffers', 'result_buffers'):
+            for buf in rec.get(key) or []:
+                self._buffer_bytes -= len(buf)
+
+    def _cull_oldest(self, n=1):
+        """cull the oldest N records"""
+        for msg_id in self.get_history()[:n]:
+            self.log.debug("Culling record: %r", msg_id)
+            self._culled_ids.add(msg_id)
+            self.drop_record(msg_id)
+
+    def _maybe_cull(self):
+        # cull by count:
+        if len(self._records) > self.record_limit:
+            to_cull = int(self.cull_fraction * self.record_limit)
+            self.log.info("%i records exceed limit of %i, culling oldest %i",
+                len(self._records), self.record_limit, to_cull
+            )
+            self._cull_oldest(to_cull)
+
+        # cull by size:
+        if self._buffer_bytes > self.size_limit:
+            limit = self.size_limit * (1 - self.cull_fraction)
+
+            before = self._buffer_bytes
+            before_count = len(self._records)
+            culled = 0
+            while self._buffer_bytes > limit:
+                self._cull_oldest(1)
+                culled += 1
+
+            self.log.info("%i records with total buffer size %i exceeds limit: %i. Culled oldest %i records.",
+                before_count, before, self.size_limit, culled
+            )
+
+    # public API methods:
 
     def add_record(self, msg_id, rec):
         """Add a new Task Record, by msg_id."""
         if msg_id in self._records:
             raise KeyError("Already have msg_id %r"%(msg_id))
         self._records[msg_id] = rec
+        self._add_bytes(rec)
+        self._maybe_cull()
 
     def get_record(self, msg_id):
         """Get a specific Task Record, by msg_id."""
+        if msg_id in self._culled_ids:
+            raise KeyError("Record %r has been culled for size" % msg_id)
         if not msg_id in self._records:
             raise KeyError("No such msg_id %r"%(msg_id))
         return copy(self._records[msg_id])
 
     def update_record(self, msg_id, rec):
         """Update the data in an existing record."""
-        self._records[msg_id].update(rec)
+        if msg_id in self._culled_ids:
+            raise KeyError("Record %r has been culled for size" % msg_id)
+        _rec = self._records[msg_id]
+        self._drop_bytes(_rec)
+        _rec.update(rec)
+        self._add_bytes(_rec)
 
     def drop_matching_records(self, check):
         """Remove a record from the DB."""
         matches = self._match(check)
-        for m in matches:
-            del self._records[m['msg_id']]
+        for rec in matches:
+            self._drop_bytes(rec)
+            del self._records[rec['msg_id']]
 
     def drop_record(self, msg_id):
         """Remove a record from the DB."""
+        rec = self._records[msg_id]
+        self._drop_bytes(rec)
         del self._records[msg_id]
-
     def find_records(self, check, keys=None):
         """Find records matching a query dict, optionally extracting subset of keys.
@@ -178,17 +263,18 @@ class DictDB(BaseDB):
         else:
             return matches
-
     def get_history(self):
         """get all msg_ids, ordered by time submitted."""
         msg_ids = self._records.keys()
         return sorted(msg_ids, key=lambda m: self._records[m]['submitted'])
+
 
 NODATA = KeyError("NoDB backend doesn't store any data. "
 "Start the Controller with a DB backend to enable resubmission / result persistence."
 )
 
-class NoDB(DictDB):
+
+class NoDB(BaseDB):
     """A blackhole db backend that actually stores no information.
 
     Provides the full DB interface, but raises KeyErrors on any
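Note, not part of the patch: all three new traits are `config=True`, so the culling thresholds can be tuned from the controller's configuration file. A minimal sketch of an `ipcontroller_config.py`, with illustrative values rather than the shipped defaults:

```python
# ipcontroller_config.py -- illustrative values, not the shipped defaults
c = get_config()

c.DictDB.size_limit = 10 * 1024**2  # cull once stored buffers exceed 10 MiB
c.DictDB.record_limit = 2048        # ...or once there are more than 2048 records
c.DictDB.cull_fraction = 0.25       # each cull frees 25% of the exceeded limit
```

With these values the buffer total spends most of its time between `10 MiB * (1 - 0.25) = 7.5 MiB` and the 10 MiB ceiling, per the range documented in the `cull_fraction` help string.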
diff --git a/IPython/parallel/tests/test_db.py b/IPython/parallel/tests/test_db.py
index a56dcb2..67635cb 100644
--- a/IPython/parallel/tests/test_db.py
+++ b/IPython/parallel/tests/test_db.py
@@ -45,23 +45,23 @@ def setup():
     temp_db = tempfile.NamedTemporaryFile(suffix='.db').name
 
 
-class TestDictBackend(TestCase):
+class TaskDBTest:
     def setUp(self):
         self.session = Session()
         self.db = self.create_db()
         self.load_records(16)
 
     def create_db(self):
-        return DictDB()
+        raise NotImplementedError
 
-    def load_records(self, n=1):
+    def load_records(self, n=1, buffer_size=100):
         """load n records for testing"""
         #sleep 1/10 s, to ensure timestamp is different to previous calls
         time.sleep(0.1)
         msg_ids = []
         for i in range(n):
             msg = self.session.msg('apply_request', content=dict(a=5))
-            msg['buffers'] = []
+            msg['buffers'] = [os.urandom(buffer_size)]
             rec = init_record(msg)
             msg_id = msg['header']['msg_id']
             msg_ids.append(msg_id)
@@ -228,7 +228,72 @@ class TestDictBackend(TestCase):
         self.assertEqual(rec2['header']['msg_id'], msg_id)
 
 
-class TestSQLiteBackend(TestDictBackend):
+class TestDictBackend(TaskDBTest, TestCase):
+
+    def create_db(self):
+        return DictDB()
+
+    def test_cull_count(self):
+        self.db = self.create_db()  # skip the load-records init from setUp
+        self.db.record_limit = 20
+        self.db.cull_fraction = 0.2
+        self.load_records(20)
+        self.assertEqual(len(self.db.get_history()), 20)
+        self.load_records(1)
+        # 0.2 * 20 = 4, 21 - 4 = 17
+        self.assertEqual(len(self.db.get_history()), 17)
+        self.load_records(3)
+        self.assertEqual(len(self.db.get_history()), 20)
+        self.load_records(1)
+        self.assertEqual(len(self.db.get_history()), 17)
+
+        for i in range(100):
+            self.load_records(1)
+            self.assertTrue(len(self.db.get_history()) >= 17)
+            self.assertTrue(len(self.db.get_history()) <= 20)
+
+    def test_cull_size(self):
+        self.db = self.create_db()  # skip the load-records init from setUp
+        self.db.size_limit = 1000
+        self.db.cull_fraction = 0.2
+        self.load_records(100, buffer_size=10)
+        self.assertEqual(len(self.db.get_history()), 100)
+        self.load_records(1, buffer_size=0)
+        self.assertEqual(len(self.db.get_history()), 101)
+        self.load_records(1, buffer_size=1)
+        # 0.2 * 100 = 20, 101 - 20 = 81
+        self.assertEqual(len(self.db.get_history()), 81)
+
+    def test_cull_size_drop(self):
+        """dropping records updates tracked buffer size"""
+        self.db = self.create_db()  # skip the load-records init from setUp
+        self.db.size_limit = 1000
+        self.db.cull_fraction = 0.2
+        self.load_records(100, buffer_size=10)
+        self.assertEqual(len(self.db.get_history()), 100)
+        self.db.drop_record(self.db.get_history()[-1])
+        self.assertEqual(len(self.db.get_history()), 99)
+        self.load_records(1, buffer_size=5)
+        self.assertEqual(len(self.db.get_history()), 100)
+        self.load_records(1, buffer_size=5)
+        self.assertEqual(len(self.db.get_history()), 101)
+        self.load_records(1, buffer_size=1)
+        self.assertEqual(len(self.db.get_history()), 81)
+
+    def test_cull_size_update(self):
+        """updating records updates tracked buffer size"""
+        self.db = self.create_db()  # skip the load-records init from setUp
+        self.db.size_limit = 1000
+        self.db.cull_fraction = 0.2
+        self.load_records(100, buffer_size=10)
+        self.assertEqual(len(self.db.get_history()), 100)
+        msg_id = self.db.get_history()[-1]
+        self.db.update_record(msg_id, dict(result_buffers=[os.urandom(10)], buffers=[]))
+        self.assertEqual(len(self.db.get_history()), 100)
+        self.db.update_record(msg_id, dict(result_buffers=[os.urandom(11)], buffers=[]))
+        self.assertEqual(len(self.db.get_history()), 79)
+
+class TestSQLiteBackend(TaskDBTest, TestCase):
 
     @dec.skip_without('sqlite3')
     def create_db(self):
diff --git a/IPython/parallel/tests/test_mongodb.py b/IPython/parallel/tests/test_mongodb.py
index 389db17..4bf212e 100644
--- a/IPython/parallel/tests/test_mongodb.py
+++ b/IPython/parallel/tests/test_mongodb.py
@@ -16,6 +16,8 @@ Authors:
 # Imports
 #-------------------------------------------------------------------------------
 
+from unittest import TestCase
+
 from nose import SkipTest
 
 from pymongo import Connection
@@ -28,7 +30,7 @@ try:
 except Exception:
     c=None
 
-class TestMongoBackend(test_db.TestDictBackend):
+class TestMongoBackend(test_db.TaskDBTest, TestCase):
     """MongoDB backend tests"""
     def create_db(self):
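For a quick end-to-end check of the count-based cull outside the test suite, here is a minimal sketch against this branch; the records are hypothetical bare dicts rather than real task records, which works because `_add_bytes` tolerates missing buffer keys and `get_history` only needs a sortable `submitted` field:

```python
from IPython.parallel.controller.dictdb import DictDB

db = DictDB(record_limit=10, cull_fraction=0.5)

# add 11 minimal records; get_history() orders them by 'submitted'
for i in range(11):
    db.add_record('msg-%i' % i, dict(submitted=i, buffers=[], result_buffers=[]))

# the 11th add trips record_limit: int(0.5 * 10) = 5 oldest records are culled
print(len(db.get_history()))  # -> 6
db.get_record('msg-0')        # KeyError: "Record 'msg-0' has been culled for size"
```

Because `_cull_oldest` remembers culled ids in `_culled_ids`, the final `get_record` raises the explicit "culled for size" KeyError rather than the generic "No such msg_id" one.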