From ffe043d4d415461f21547b59685bf74b262e6b5c 2011-05-17 21:27:33
From: MinRK <benjaminrk@gmail.com>
Date: 2011-05-17 21:27:33
Subject: [PATCH] various db backend fixes

* use index on msg_id in mongodb backend instead of the in-memory _table mapping (_table prevented some methods from working outside the session)
* purge_request improved to use fewer db calls
* mongodb testcase split into its own file
* Fix equality testing and NULL handling in SQLiteDB backend

---

diff --git a/IPython/parallel/controller/dictdb.py b/IPython/parallel/controller/dictdb.py
index 6bf2ca8..0526f2d 100644
--- a/IPython/parallel/controller/dictdb.py
+++ b/IPython/parallel/controller/dictdb.py
@@ -146,7 +146,7 @@ class DictDB(BaseDB):
         """Remove a record from the DB."""
         matches = self._match(check)
         for m in matches:
-            del self._records[m]
+            del self._records[m['msg_id']]
         
     def drop_record(self, msg_id):
         """Remove a record from the DB."""
diff --git a/IPython/parallel/controller/hub.py b/IPython/parallel/controller/hub.py
index 9136ef2..fc5ec77 100755
--- a/IPython/parallel/controller/hub.py
+++ b/IPython/parallel/controller/hub.py
@@ -1066,36 +1066,34 @@ class Hub(LoggingFactory):
             except Exception:
                 reply = error.wrap_exception()
         else:
-            for msg_id in msg_ids:
-                if msg_id in self.all_completed:
-                    self.db.drop_record(msg_id)
-                else:
-                    if msg_id in self.pending:
-                        try:
-                            raise IndexError("msg pending: %r"%msg_id)
-                        except:
-                            reply = error.wrap_exception()
-                    else:
+            pending = filter(lambda m: m in self.pending, msg_ids)
+            if pending:
+                try:
+                    raise IndexError("msg pending: %r"%pending[0])
+                except:
+                    reply = error.wrap_exception()
+            else:
+                try:
+                    self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
+                except Exception:
+                    reply = error.wrap_exception()
+
+            if reply['status'] == 'ok':
+                eids = content.get('engine_ids', [])
+                for eid in eids:
+                    if eid not in self.engines:
                         try:
-                            raise IndexError("No such msg: %r"%msg_id)
+                            raise IndexError("No such engine: %i"%eid)
                         except:
                             reply = error.wrap_exception()
-                    break
-            eids = content.get('engine_ids', [])
-            for eid in eids:
-                if eid not in self.engines:
+                        break
+                    msg_ids = self.completed.pop(eid)
+                    uid = self.engines[eid].queue
                     try:
-                        raise IndexError("No such engine: %i"%eid)
-                    except:
+                        self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
+                    except Exception:
                         reply = error.wrap_exception()
-                    break
-                msg_ids = self.completed.pop(eid)
-                uid = self.engines[eid].queue
-                try:
-                    self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
-                except Exception:
-                    reply = error.wrap_exception()
-                    break
+                        break
         
         self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
     
diff --git a/IPython/parallel/controller/mongodb.py b/IPython/parallel/controller/mongodb.py
index 9100d48..71cf6b9 100644
--- a/IPython/parallel/controller/mongodb.py
+++ b/IPython/parallel/controller/mongodb.py
@@ -6,12 +6,10 @@
 #  the file COPYING, distributed as part of this software.
 #-----------------------------------------------------------------------------
 
-from datetime import datetime
-
 from pymongo import Connection
 from pymongo.binary import Binary
 
-from IPython.utils.traitlets import Dict, List, CUnicode
+from IPython.utils.traitlets import Dict, List, CUnicode, CStr, Instance
 
 from .dictdb import BaseDB
 
@@ -25,15 +23,20 @@ class MongoDB(BaseDB):
     connection_args = List(config=True) # args passed to pymongo.Connection
     connection_kwargs = Dict(config=True) # kwargs passed to pymongo.Connection
     database = CUnicode(config=True) # name of the mongodb database
-    _table = Dict()
+
+    _connection = Instance(Connection) # pymongo connection
     
     def __init__(self, **kwargs):
         super(MongoDB, self).__init__(**kwargs)
-        self._connection = Connection(*self.connection_args, **self.connection_kwargs)
+        if self._connection is None:
+            self._connection = Connection(*self.connection_args, **self.connection_kwargs)
         if not self.database:
             self.database = self.session
         self._db = self._connection[self.database]
         self._records = self._db['task_records']
+        self._records.ensure_index('msg_id', unique=True)
+        self._records.ensure_index('submitted') # for sorting history
+        # for rec in self._records.find
     
     def _binary_buffers(self, rec):
         for key in ('buffers', 'result_buffers'):
@@ -45,18 +48,21 @@ class MongoDB(BaseDB):
         """Add a new Task Record, by msg_id."""
         # print rec
         rec = self._binary_buffers(rec)
-        obj_id = self._records.insert(rec)
-        self._table[msg_id] = obj_id
+        self._records.insert(rec)
     
     def get_record(self, msg_id):
         """Get a specific Task Record, by msg_id."""
-        return self._records.find_one(self._table[msg_id])
+        r = self._records.find_one({'msg_id': msg_id})
+        if not r:
+            # find_one returns None when no document matches the query
+            raise KeyError(msg_id)
+        return r
     
     def update_record(self, msg_id, rec):
         """Update the data in an existing record."""
         rec = self._binary_buffers(rec)
-        obj_id = self._table[msg_id]
-        self._records.update({'_id':obj_id}, {'$set': rec})
+
+        self._records.update({'msg_id':msg_id}, {'$set': rec})
     
     def drop_matching_records(self, check):
         """Remove a record from the DB."""
@@ -64,8 +70,7 @@ class MongoDB(BaseDB):
         
     def drop_record(self, msg_id):
         """Remove a record from the DB."""
-        obj_id = self._table.pop(msg_id)
-        self._records.remove(obj_id)
+        self._records.remove({'msg_id':msg_id})
     
     def find_records(self, check, keys=None):
         """Find records matching a query dict, optionally extracting subset of keys.
diff --git a/IPython/parallel/controller/sqlitedb.py b/IPython/parallel/controller/sqlitedb.py
index 0b738ff..e542d14 100644
--- a/IPython/parallel/controller/sqlitedb.py
+++ b/IPython/parallel/controller/sqlitedb.py
@@ -27,16 +27,20 @@ operators = {
  '$lt' : "<",
  '$gt' : ">",
  # null is handled weird with ==,!=
- '$eq' : "IS",
- '$ne' : "IS NOT",
+ '$eq' : "=",
+ '$ne' : "!=",
  '$lte': "<=",
  '$gte': ">=",
- '$in' : ('IS', ' OR '),
- '$nin': ('IS NOT', ' AND '),
+ '$in' : ('=', ' OR '),
+ '$nin': ('!=', ' AND '),
  # '$all': None,
  # '$mod': None,
  # '$exists' : None
 }
+null_operators = {
+'=' : "IS NULL",
+'!=' : "IS NOT NULL",
+}
 
 def _adapt_datetime(dt):
     return dt.strftime(ISO8601)
@@ -205,17 +209,27 @@ class SQLiteDB(BaseDB):
                         raise KeyError("Unsupported operator: %r"%test)
                     if isinstance(op, tuple):
                         op, join = op
-                    expr = "%s %s ?"%(name, op)
-                    if isinstance(value, (tuple,list)):
-                        expr = '( %s )'%( join.join([expr]*len(value)) )
-                        args.extend(value)
+                    
+                    if value is None and op in null_operators:
+                        expr = "%s %s"%(name, null_operators[op]) # "<col> IS [NOT] NULL": no bound arg, since '= NULL' never matches
                     else:
-                        args.append(value)
+                        expr = "%s %s ?"%(name, op)
+                        if isinstance(value, (tuple,list)):
+                            if op in null_operators and any([v is None for v in value]):
+                                # equality tests don't work with NULL
+                                raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
+                            expr = '( %s )'%( join.join([expr]*len(value)) )
+                            args.extend(value)
+                        else:
+                            args.append(value)
                     expressions.append(expr)
             else:
                 # it's an equality check
-                expressions.append("%s IS ?"%name)
-                args.append(sub_check)
+                if sub_check is None:
+                    expressions.append("%s IS NULL"%name) # NULL needs IS; '= ?' bound to None never matches
+                else:
+                    expressions.append("%s = ?"%name)
+                    args.append(sub_check)
         
         expr = " AND ".join(expressions)
         return expr, args
diff --git a/IPython/parallel/tests/test_client.py b/IPython/parallel/tests/test_client.py
index 4d7d42d..d4e89a4 100644
--- a/IPython/parallel/tests/test_client.py
+++ b/IPython/parallel/tests/test_client.py
@@ -235,3 +235,10 @@ class TestClient(ClusterTestCase):
     def test_resubmit_badkey(self):
         """ensure KeyError on resubmit of nonexistant task"""
         self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
+
+    def test_purge_results(self):
+        """purging the full hub history should leave the history empty"""
+        self.client.purge_results(self.client.hub_history())
+        remaining = self.client.hub_history()
+        self.assertTrue(len(remaining) == 0)
+
diff --git a/IPython/parallel/tests/test_db.py b/IPython/parallel/tests/test_db.py
index e1dae1c..fc47b61 100644
--- a/IPython/parallel/tests/test_db.py
+++ b/IPython/parallel/tests/test_db.py
@@ -15,10 +15,7 @@
 import tempfile
 import time
 
-import uuid
-
 from datetime import datetime, timedelta
-from random import choice, randint
 from unittest import TestCase
 
 from nose import SkipTest
@@ -157,6 +154,13 @@ class TestDictBackend(TestCase):
         self.db.update_record(msg_id, dict(completed=datetime.now()))
         rec = self.db.get_record(msg_id)
         self.assertTrue(isinstance(rec['completed'], datetime))
+
+    def test_drop_matching(self):
+        """records dropped by a query should no longer match that query"""
+        query = {'msg_id' : {'$in':self.load_records(10)}}
+        self.db.drop_matching_records(query)
+        leftover = self.db.find_records(query)
+        self.assertTrue(len(leftover)==0)
             
 class TestSQLiteBackend(TestDictBackend):
     def create_db(self):
@@ -164,19 +168,3 @@ class TestSQLiteBackend(TestDictBackend):
     
     def tearDown(self):
         self.db._db.close()
-
-# optional MongoDB test
-try:
-    from IPython.parallel.controller.mongodb import MongoDB
-except ImportError:
-    pass
-else:
-    class TestMongoBackend(TestDictBackend):
-        def create_db(self):
-            try:
-                return MongoDB(database='iptestdb')
-            except Exception:
-                raise SkipTest("Couldn't connect to mongodb instance")
-            
-        def tearDown(self):
-            self.db._connection.drop_database('iptestdb')
diff --git a/IPython/parallel/tests/test_mongodb.py b/IPython/parallel/tests/test_mongodb.py
new file mode 100644
index 0000000..f36e1fb
--- /dev/null
+++ b/IPython/parallel/tests/test_mongodb.py
@@ -0,0 +1,37 @@
+"""Tests for mongodb backend"""
+
+#-------------------------------------------------------------------------------
+#  Copyright (C) 2011  The IPython Development Team
+#
+#  Distributed under the terms of the BSD License.  The full license is in
+#  the file COPYING, distributed as part of this software.
+#-------------------------------------------------------------------------------
+
+#-------------------------------------------------------------------------------
+# Imports
+#-------------------------------------------------------------------------------
+
+from nose import SkipTest
+
+from pymongo import Connection
+from IPython.parallel.controller.mongodb import MongoDB
+
+from . import test_db
+
+try:
+    c = Connection()
+except Exception:
+    c=None
+
+class TestMongoBackend(test_db.TestDictBackend):
+    """MongoDB backend tests"""
+
+    def create_db(self):
+        try:
+            return MongoDB(database='iptestdb', _connection=c)
+        except Exception:
+            raise SkipTest("Couldn't connect to mongodb")
+
+def teardown(self):
+    if c is not None:
+        c.drop_database('iptestdb')
diff --git a/IPython/testing/iptest.py b/IPython/testing/iptest.py
index 506e650..be593dd 100644
--- a/IPython/testing/iptest.py
+++ b/IPython/testing/iptest.py
@@ -199,6 +199,7 @@ def make_exclude():
     
     if not have['pymongo']:
         exclusions.append(ipjoin('parallel', 'controller', 'mongodb'))
+        exclusions.append(ipjoin('parallel', 'tests', 'test_mongodb'))
 
     if not have['matplotlib']:
         exclusions.extend([ipjoin('lib', 'pylabtools'),