From 7a85cbc1aa55a811e8a8da541ba64e5579cf7969 2012-06-13 07:09:08
From: MinRK <benjaminrk@gmail.com>
Date: 2012-06-13 07:09:08
Subject: [PATCH] add size-limiting to the DictDB backend
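
DictDB keeps every task record, including its buffers, in memory, so a
long-running controller could previously grow without bound.  This adds
size-limiting to the DictDB backend via three configurables:

* size_limit:    maximum total size in bytes of stored buffers
                 (default: 1MB)
* record_limit:  maximum number of records (default: 1024)
* cull_fraction: fraction of the db to cull when a limit is exceeded
                 (default: 0.1)

When record_limit is exceeded, the oldest record_limit * cull_fraction
records are culled.  When size_limit is exceeded, the oldest records
are culled one at a time until the total buffer size is at most
size_limit * (1 - cull_fraction).  Culled msg_ids are remembered, so
get_record / update_record on a culled record raise a KeyError saying
the record was culled for size, rather than a generic missing-key
error.

The limits can be set like any configurable; a sketch (the values here
are illustrative, not recommendations):

    # in ipcontroller_config.py
    c.DictDB.size_limit = 10 * 1024 * 1024  # allow 10MB of buffers
    c.DictDB.record_limit = 10000           # keep up to 10000 records
    c.DictDB.cull_fraction = 0.25           # cull 25% when a limit is hit

Also: NoDB now subclasses BaseDB directly, since it stores nothing and
should not inherit DictDB's storage and culling machinery, and the
common backend tests move to a TaskDBTest mix-in shared by the Dict,
SQLite, and Mongo TestCases.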

---

diff --git a/IPython/parallel/controller/dictdb.py b/IPython/parallel/controller/dictdb.py
index fd8be3f..2899f80 100644
--- a/IPython/parallel/controller/dictdb.py
+++ b/IPython/parallel/controller/dictdb.py
@@ -51,7 +51,7 @@ from datetime import datetime
 
 from IPython.config.configurable import LoggingConfigurable
 
-from IPython.utils.traitlets import Dict, Unicode, Instance
+from IPython.utils.traitlets import Dict, Unicode, Integer, Float, Set
 
 filters = {
  '$lt' : lambda a,b: a < b,
@@ -100,6 +100,33 @@ class DictDB(BaseDB):
     """
 
     _records = Dict()
+    _culled_ids = Set() # culled msg_ids; a Set trait, so each instance gets its own set
+    _buffer_bytes = Integer(0) # running total of the bytes in the DB
+    
+    size_limit = Integer(1024*1024, config=True,
+        help="""The maximum total size (in bytes) of the buffers stored in the db
+        
+        When the total size of the buffers exceeds this limit, the oldest records
+        will be culled until the total size is under size_limit * (1 - cull_fraction).
+        """
+    )
+    record_limit = Integer(1024, config=True,
+        help="""The maximum number of records in the db
+        
+        When the number of records exceeds this limit, the oldest
+        record_limit * cull_fraction records will be culled.
+        """
+    )
+    cull_fraction = Float(0.1, config=True,
+        help="""The fraction by which the db should culled when one of the limits is exceeded
+        
+        In general, the db size will spend most of its time with a size in the range:
+        
+        [limit * (1-cull_fraction), limit]
+        
+        for each of size_limit and record_limit.
+        """
+    )
 
     def _match_one(self, rec, tests):
         """Check if a specific record matches tests."""
@@ -130,34 +157,94 @@
         for key in keys:
             d[key] = rec[key]
         return copy(d)
+    
+    # methods for monitoring size / culling history
+    
+    def _add_bytes(self, rec):
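+        """add rec's buffer sizes to the running total (culls if over a limit)"""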
+        for key in ('buffers', 'result_buffers'):
+            for buf in rec.get(key) or []:
+                self._buffer_bytes += len(buf)
+        
+        self._maybe_cull()
+    
+    def _drop_bytes(self, rec):
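+        """subtract rec's buffer sizes from the running total"""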
+        for key in ('buffers', 'result_buffers'):
+            for buf in rec.get(key) or []:
+                self._buffer_bytes -= len(buf)
+    
+    def _cull_oldest(self, n=1):
+        """cull the oldest N records"""
+        for msg_id in self.get_history()[:n]: # get_history returns msg_ids sorted oldest-first
+            self.log.debug("Culling record: %r", msg_id)
+            self._culled_ids.add(msg_id)
+            self.drop_record(msg_id)
+    
+    def _maybe_cull(self):
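+        """cull the oldest records if either the record or size limit is exceeded"""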
+        # cull by count:
+        if len(self._records) > self.record_limit:
+            to_cull = int(self.cull_fraction * self.record_limit)
+            self.log.info("%i records exceeds limit of %i, culling oldest %i",
+                len(self._records), self.record_limit, to_cull
+            )
+            self._cull_oldest(to_cull)
+        
+        # cull by size:
+        if self._buffer_bytes > self.size_limit:
+            limit = self.size_limit * (1 - self.cull_fraction) # target size to cull down to
+            
+            before = self._buffer_bytes
+            before_count = len(self._records)
+            culled = 0
+            while self._buffer_bytes > limit:
+                self._cull_oldest(1)
+                culled += 1
+        
+            self.log.info("%i records with total buffer size %i exceeds limit: %i. Culled oldest %i records.",
+                before_count, before, self.size_limit, culled
+            )
+    
+    # public API methods:
 
     def add_record(self, msg_id, rec):
         """Add a new Task Record, by msg_id."""
         if self._records.has_key(msg_id):
             raise KeyError("Already have msg_id %r"%(msg_id))
         self._records[msg_id] = rec
+        self._add_bytes(rec) # _add_bytes culls if the new buffers push us over a limit
 
     def get_record(self, msg_id):
         """Get a specific Task Record, by msg_id."""
+        if msg_id in self._culled_ids:
+            raise KeyError("Record %r has been culled for size" % msg_id)
         if not msg_id in self._records:
             raise KeyError("No such msg_id %r"%(msg_id))
         return copy(self._records[msg_id])
 
     def update_record(self, msg_id, rec):
         """Update the data in an existing record."""
-        self._records[msg_id].update(rec)
+        if msg_id in self._culled_ids:
+            raise KeyError("Record %r has been culled for size" % msg_id)
+        _rec = self._records[msg_id]
+        self._drop_bytes(_rec)
+        _rec.update(rec)
+        self._add_bytes(_rec)
 
     def drop_matching_records(self, check):
         """Remove a record from the DB."""
         matches = self._match(check)
-        for m in matches:
-            del self._records[m['msg_id']]
+        for rec in matches:
+            self._drop_bytes(rec)
+            del self._records[rec['msg_id']]
 
     def drop_record(self, msg_id):
         """Remove a record from the DB."""
+        rec = self._records[msg_id]
+        self._drop_bytes(rec)
         del self._records[msg_id]
 
-
     def find_records(self, check, keys=None):
         """Find records matching a query dict, optionally extracting subset of keys.
 
@@ -178,17 +265,18 @@
         else:
             return matches
 
-
     def get_history(self):
         """get all msg_ids, ordered by time submitted."""
         msg_ids = self._records.keys()
         return sorted(msg_ids, key=lambda m: self._records[m]['submitted'])
 
+
 NODATA = KeyError("NoDB backend doesn't store any data. "
 "Start the Controller with a DB backend to enable resubmission / result persistence."
 )
 
-class NoDB(DictDB):
+
+class NoDB(BaseDB):
     """A blackhole db backend that actually stores no information.
     
     Provides the full DB interface, but raises KeyErrors on any
diff --git a/IPython/parallel/tests/test_db.py b/IPython/parallel/tests/test_db.py
index e1f2aed..85b3f14 100644
--- a/IPython/parallel/tests/test_db.py
+++ b/IPython/parallel/tests/test_db.py
@@ -45,23 +45,23 @@ def setup():
     temp_db = tempfile.NamedTemporaryFile(suffix='.db').name
 
 
-class TestDictBackend(TestCase):
+class TaskDBTest:
     def setUp(self):
         self.session = Session()
         self.db = self.create_db()
         self.load_records(16)
     
     def create_db(self):
-        return DictDB()
+        raise NotImplementedError
     
-    def load_records(self, n=1):
+    def load_records(self, n=1, buffer_size=100):
         """load n records for testing"""
         #sleep 1/10 s, to ensure timestamp is different to previous calls
         time.sleep(0.1)
         msg_ids = []
         for i in range(n):
             msg = self.session.msg('apply_request', content=dict(a=5))
-            msg['buffers'] = []
+            msg['buffers'] = [os.urandom(buffer_size)]
             rec = init_record(msg)
             msg_id = msg['header']['msg_id']
             msg_ids.append(msg_id)
@@ -228,7 +228,74 @@
         self.assertEquals(rec2['header']['msg_id'], msg_id)
 
 
-class TestSQLiteBackend(TestDictBackend):
+class TestDictBackend(TaskDBTest, TestCase):
+    
+    def create_db(self):
+        return DictDB()
+    
+    def test_cull_count(self):
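+        """exceeding record_limit culls the oldest record_limit * cull_fraction records"""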
+        self.db = self.create_db() # skip the load-records init from setUp
+        self.db.record_limit = 20
+        self.db.cull_fraction = 0.2
+        self.load_records(20)
+        self.assertEquals(len(self.db.get_history()), 20)
+        self.load_records(1)
+        # 0.2 * 20 = 4, 21 - 4 = 17
+        self.assertEquals(len(self.db.get_history()), 17)
+        self.load_records(3)
+        self.assertEquals(len(self.db.get_history()), 20)
+        self.load_records(1)
+        self.assertEquals(len(self.db.get_history()), 17)
+        
+        for i in range(100):
+            self.load_records(1)
+            self.assertTrue(len(self.db.get_history()) >= 17)
+            self.assertTrue(len(self.db.get_history()) <= 20)
+
+    def test_cull_size(self):
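+        """exceeding size_limit culls the oldest records until under the target size"""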
+        self.db = self.create_db() # skip the load-records init from setUp
+        self.db.size_limit = 1000
+        self.db.cull_fraction = 0.2
+        self.load_records(100, buffer_size=10)
+        self.assertEquals(len(self.db.get_history()), 100)
+        self.load_records(1, buffer_size=0)
+        self.assertEquals(len(self.db.get_history()), 101)
+        self.load_records(1, buffer_size=1)
+        # 102 records, 1001 bytes > 1000; cull oldest 10B records until <= 800: 102 - 21 = 81
+        self.assertEquals(len(self.db.get_history()), 81)
+    
+    def test_cull_size_drop(self):
+        """dropping records updates tracked buffer size"""
+        self.db = self.create_db() # skip the load-records init from setUp
+        self.db.size_limit = 1000
+        self.db.cull_fraction = 0.2
+        self.load_records(100, buffer_size=10)
+        self.assertEquals(len(self.db.get_history()), 100)
+        self.db.drop_record(self.db.get_history()[-1])
+        self.assertEquals(len(self.db.get_history()), 99)
+        self.load_records(1, buffer_size=5)
+        self.assertEquals(len(self.db.get_history()), 100)
+        self.load_records(1, buffer_size=5)
+        self.assertEquals(len(self.db.get_history()), 101)
+        self.load_records(1, buffer_size=1)
+        self.assertEquals(len(self.db.get_history()), 81) # 102 records, 1001 bytes; cull 21 oldest
+
+    def test_cull_size_update(self):
+        """updating records updates tracked buffer size"""
+        self.db = self.create_db() # skip the load-records init from setUp
+        self.db.size_limit = 1000
+        self.db.cull_fraction = 0.2
+        self.load_records(100, buffer_size=10)
+        self.assertEquals(len(self.db.get_history()), 100)
+        msg_id = self.db.get_history()[-1]
+        self.db.update_record(msg_id, dict(result_buffers = [os.urandom(10)], buffers=[]))
+        self.assertEquals(len(self.db.get_history()), 100)
+        self.db.update_record(msg_id, dict(result_buffers = [os.urandom(11)], buffers=[]))
+        self.assertEquals(len(self.db.get_history()), 79) # 1001 bytes; cull 21 oldest: 100 - 21 = 79
+
+class TestSQLiteBackend(TaskDBTest, TestCase):
 
     @dec.skip_without('sqlite3')
     def create_db(self):
diff --git a/IPython/parallel/tests/test_mongodb.py b/IPython/parallel/tests/test_mongodb.py
index 389db17..4bf212e 100644
--- a/IPython/parallel/tests/test_mongodb.py
+++ b/IPython/parallel/tests/test_mongodb.py
@@ -16,6 +16,8 @@ Authors:
 # Imports
 #-------------------------------------------------------------------------------
 
+from unittest import TestCase
+
 from nose import SkipTest
 
 from pymongo import Connection
@@ -28,7 +30,7 @@ try:
 except Exception:
     c=None
 
-class TestMongoBackend(test_db.TestDictBackend):
+class TestMongoBackend(test_db.TaskDBTest, TestCase):
     """MongoDB backend tests"""
 
     def create_db(self):