diff --git a/IPython/parallel/controller/dictdb.py b/IPython/parallel/controller/dictdb.py index 6bf2ca8..0526f2d 100644 --- a/IPython/parallel/controller/dictdb.py +++ b/IPython/parallel/controller/dictdb.py @@ -146,7 +146,7 @@ class DictDB(BaseDB): """Remove a record from the DB.""" matches = self._match(check) for m in matches: - del self._records[m] + del self._records[m['msg_id']] def drop_record(self, msg_id): """Remove a record from the DB.""" diff --git a/IPython/parallel/controller/hub.py b/IPython/parallel/controller/hub.py index 9136ef2..fc5ec77 100755 --- a/IPython/parallel/controller/hub.py +++ b/IPython/parallel/controller/hub.py @@ -1066,36 +1066,34 @@ class Hub(LoggingFactory): except Exception: reply = error.wrap_exception() else: - for msg_id in msg_ids: - if msg_id in self.all_completed: - self.db.drop_record(msg_id) - else: - if msg_id in self.pending: - try: - raise IndexError("msg pending: %r"%msg_id) - except: - reply = error.wrap_exception() - else: + pending = filter(lambda m: m in self.pending, msg_ids) + if pending: + try: + raise IndexError("msg pending: %r"%pending[0]) + except: + reply = error.wrap_exception() + else: + try: + self.db.drop_matching_records(dict(msg_id={'$in':msg_ids})) + except Exception: + reply = error.wrap_exception() + + if reply['status'] == 'ok': + eids = content.get('engine_ids', []) + for eid in eids: + if eid not in self.engines: try: - raise IndexError("No such msg: %r"%msg_id) + raise IndexError("No such engine: %i"%eid) except: reply = error.wrap_exception() - break - eids = content.get('engine_ids', []) - for eid in eids: - if eid not in self.engines: + break + msg_ids = self.completed.pop(eid) + uid = self.engines[eid].queue try: - raise IndexError("No such engine: %i"%eid) - except: + self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None})) + except Exception: reply = error.wrap_exception() - break - msg_ids = self.completed.pop(eid) - uid = self.engines[eid].queue - try: - self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None})) - except Exception: - reply = error.wrap_exception() - break + break self.session.send(self.query, 'purge_reply', content=reply, ident=client_id) diff --git a/IPython/parallel/controller/mongodb.py b/IPython/parallel/controller/mongodb.py index 9100d48..71cf6b9 100644 --- a/IPython/parallel/controller/mongodb.py +++ b/IPython/parallel/controller/mongodb.py @@ -6,12 +6,10 @@ # the file COPYING, distributed as part of this software. #----------------------------------------------------------------------------- -from datetime import datetime - from pymongo import Connection from pymongo.binary import Binary -from IPython.utils.traitlets import Dict, List, CUnicode +from IPython.utils.traitlets import Dict, List, CUnicode, CStr, Instance from .dictdb import BaseDB @@ -25,15 +23,20 @@ class MongoDB(BaseDB): connection_args = List(config=True) # args passed to pymongo.Connection connection_kwargs = Dict(config=True) # kwargs passed to pymongo.Connection database = CUnicode(config=True) # name of the mongodb database - _table = Dict() + + _connection = Instance(Connection) # pymongo connection def __init__(self, **kwargs): super(MongoDB, self).__init__(**kwargs) - self._connection = Connection(*self.connection_args, **self.connection_kwargs) + if self._connection is None: + self._connection = Connection(*self.connection_args, **self.connection_kwargs) if not self.database: self.database = self.session self._db = self._connection[self.database] self._records = self._db['task_records'] + self._records.ensure_index('msg_id', unique=True) + self._records.ensure_index('submitted') # for sorting history + # for rec in self._records.find def _binary_buffers(self, rec): for key in ('buffers', 'result_buffers'): @@ -45,18 +48,21 @@ class MongoDB(BaseDB): """Add a new Task Record, by msg_id.""" # print rec rec = self._binary_buffers(rec) - obj_id = self._records.insert(rec) - self._table[msg_id] = obj_id + self._records.insert(rec) def get_record(self, msg_id): """Get a specific Task Record, by msg_id.""" - return self._records.find_one(self._table[msg_id]) + r = self._records.find_one({'msg_id': msg_id}) + if not r: + # r will be '' if nothing is found + raise KeyError(msg_id) + return r def update_record(self, msg_id, rec): """Update the data in an existing record.""" rec = self._binary_buffers(rec) - obj_id = self._table[msg_id] - self._records.update({'_id':obj_id}, {'$set': rec}) + + self._records.update({'msg_id':msg_id}, {'$set': rec}) def drop_matching_records(self, check): """Remove a record from the DB.""" @@ -64,8 +70,7 @@ class MongoDB(BaseDB): def drop_record(self, msg_id): """Remove a record from the DB.""" - obj_id = self._table.pop(msg_id) - self._records.remove(obj_id) + self._records.remove({'msg_id':msg_id}) def find_records(self, check, keys=None): """Find records matching a query dict, optionally extracting subset of keys. diff --git a/IPython/parallel/controller/sqlitedb.py b/IPython/parallel/controller/sqlitedb.py index 0b738ff..e542d14 100644 --- a/IPython/parallel/controller/sqlitedb.py +++ b/IPython/parallel/controller/sqlitedb.py @@ -27,16 +27,20 @@ operators = { '$lt' : "<", '$gt' : ">", # null is handled weird with ==,!= - '$eq' : "IS", - '$ne' : "IS NOT", + '$eq' : "=", + '$ne' : "!=", '$lte': "<=", '$gte': ">=", - '$in' : ('IS', ' OR '), - '$nin': ('IS NOT', ' AND '), + '$in' : ('=', ' OR '), + '$nin': ('!=', ' AND '), # '$all': None, # '$mod': None, # '$exists' : None } +null_operators = { +'=' : "IS NULL", +'!=' : "IS NOT NULL", +} def _adapt_datetime(dt): return dt.strftime(ISO8601) @@ -205,17 +209,27 @@ class SQLiteDB(BaseDB): raise KeyError("Unsupported operator: %r"%test) if isinstance(op, tuple): op, join = op - expr = "%s %s ?"%(name, op) - if isinstance(value, (tuple,list)): - expr = '( %s )'%( join.join([expr]*len(value)) ) - args.extend(value) + + if value is None and op in null_operators: + expr = "%s %s"%null_operators[op] else: - args.append(value) + expr = "%s %s ?"%(name, op) + if isinstance(value, (tuple,list)): + if op in null_operators and any([v is None for v in value]): + # equality tests don't work with NULL + raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test) + expr = '( %s )'%( join.join([expr]*len(value)) ) + args.extend(value) + else: + args.append(value) expressions.append(expr) else: # it's an equality check - expressions.append("%s IS ?"%name) - args.append(sub_check) + if sub_check is None: + expressions.append("%s IS NULL") + else: + expressions.append("%s = ?"%name) + args.append(sub_check) expr = " AND ".join(expressions) return expr, args diff --git a/IPython/parallel/tests/test_client.py b/IPython/parallel/tests/test_client.py index 4d7d42d..d4e89a4 100644 --- a/IPython/parallel/tests/test_client.py +++ b/IPython/parallel/tests/test_client.py @@ -235,3 +235,10 @@ class TestClient(ClusterTestCase): def test_resubmit_badkey(self): """ensure KeyError on resubmit of nonexistant task""" self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid']) + + def test_purge_results(self): + hist = self.client.hub_history() + self.client.purge_results(hist) + newhist = self.client.hub_history() + self.assertTrue(len(newhist) == 0) + diff --git a/IPython/parallel/tests/test_db.py b/IPython/parallel/tests/test_db.py index e1dae1c..fc47b61 100644 --- a/IPython/parallel/tests/test_db.py +++ b/IPython/parallel/tests/test_db.py @@ -15,10 +15,7 @@ import tempfile import time -import uuid - from datetime import datetime, timedelta -from random import choice, randint from unittest import TestCase from nose import SkipTest @@ -157,6 +154,13 @@ class TestDictBackend(TestCase): self.db.update_record(msg_id, dict(completed=datetime.now())) rec = self.db.get_record(msg_id) self.assertTrue(isinstance(rec['completed'], datetime)) + + def test_drop_matching(self): + msg_ids = self.load_records(10) + query = {'msg_id' : {'$in':msg_ids}} + self.db.drop_matching_records(query) + recs = self.db.find_records(query) + self.assertTrue(len(recs)==0) class TestSQLiteBackend(TestDictBackend): def create_db(self): @@ -164,19 +168,3 @@ class TestSQLiteBackend(TestDictBackend): def tearDown(self): self.db._db.close() - -# optional MongoDB test -try: - from IPython.parallel.controller.mongodb import MongoDB -except ImportError: - pass -else: - class TestMongoBackend(TestDictBackend): - def create_db(self): - try: - return MongoDB(database='iptestdb') - except Exception: - raise SkipTest("Couldn't connect to mongodb instance") - - def tearDown(self): - self.db._connection.drop_database('iptestdb') diff --git a/IPython/parallel/tests/test_mongodb.py b/IPython/parallel/tests/test_mongodb.py new file mode 100644 index 0000000..f36e1fb --- /dev/null +++ b/IPython/parallel/tests/test_mongodb.py @@ -0,0 +1,37 @@ +"""Tests for mongodb backend""" + +#------------------------------------------------------------------------------- +# Copyright (C) 2011 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +from nose import SkipTest + +from pymongo import Connection +from IPython.parallel.controller.mongodb import MongoDB + +from . import test_db + +try: + c = Connection() +except Exception: + c=None + +class TestMongoBackend(test_db.TestDictBackend): + """MongoDB backend tests""" + + def create_db(self): + try: + return MongoDB(database='iptestdb', _connection=c) + except Exception: + raise SkipTest("Couldn't connect to mongodb") + +def teardown(self): + if c is not None: + c.drop_database('iptestdb') diff --git a/IPython/testing/iptest.py b/IPython/testing/iptest.py index 506e650..be593dd 100644 --- a/IPython/testing/iptest.py +++ b/IPython/testing/iptest.py @@ -199,6 +199,7 @@ def make_exclude(): if not have['pymongo']: exclusions.append(ipjoin('parallel', 'controller', 'mongodb')) + exclusions.append(ipjoin('parallel', 'tests', 'test_mongodb')) if not have['matplotlib']: exclusions.extend([ipjoin('lib', 'pylabtools'),