Show More
@@ -0,0 +1,37 b'' | |||||
|
1 | """Tests for mongodb backend""" | |||
|
2 | ||||
|
3 | #------------------------------------------------------------------------------- | |||
|
4 | # Copyright (C) 2011 The IPython Development Team | |||
|
5 | # | |||
|
6 | # Distributed under the terms of the BSD License. The full license is in | |||
|
7 | # the file COPYING, distributed as part of this software. | |||
|
8 | #------------------------------------------------------------------------------- | |||
|
9 | ||||
|
10 | #------------------------------------------------------------------------------- | |||
|
11 | # Imports | |||
|
12 | #------------------------------------------------------------------------------- | |||
|
13 | ||||
|
14 | from nose import SkipTest | |||
|
15 | ||||
|
16 | from pymongo import Connection | |||
|
17 | from IPython.parallel.controller.mongodb import MongoDB | |||
|
18 | ||||
|
19 | from . import test_db | |||
|
20 | ||||
|
21 | try: | |||
|
22 | c = Connection() | |||
|
23 | except Exception: | |||
|
24 | c=None | |||
|
25 | ||||
|
26 | class TestMongoBackend(test_db.TestDictBackend): | |||
|
27 | """MongoDB backend tests""" | |||
|
28 | ||||
|
29 | def create_db(self): | |||
|
30 | try: | |||
|
31 | return MongoDB(database='iptestdb', _connection=c) | |||
|
32 | except Exception: | |||
|
33 | raise SkipTest("Couldn't connect to mongodb") | |||
|
34 | ||||
|
35 | def teardown(self): | |||
|
36 | if c is not None: | |||
|
37 | c.drop_database('iptestdb') |
@@ -146,7 +146,7 b' class DictDB(BaseDB):' | |||||
146 | """Remove a record from the DB.""" |
|
146 | """Remove a record from the DB.""" | |
147 | matches = self._match(check) |
|
147 | matches = self._match(check) | |
148 | for m in matches: |
|
148 | for m in matches: | |
149 | del self._records[m] |
|
149 | del self._records[m['msg_id']] | |
150 |
|
150 | |||
151 | def drop_record(self, msg_id): |
|
151 | def drop_record(self, msg_id): | |
152 | """Remove a record from the DB.""" |
|
152 | """Remove a record from the DB.""" |
@@ -1066,21 +1066,19 b' class Hub(LoggingFactory):' | |||||
1066 | except Exception: |
|
1066 | except Exception: | |
1067 | reply = error.wrap_exception() |
|
1067 | reply = error.wrap_exception() | |
1068 | else: |
|
1068 | else: | |
1069 | for msg_id in msg_ids: |
|
1069 | pending = filter(lambda m: m in self.pending, msg_ids) | |
1070 | if msg_id in self.all_completed: |
|
1070 | if pending: | |
1071 | self.db.drop_record(msg_id) |
|
|||
1072 | else: |
|
|||
1073 | if msg_id in self.pending: |
|
|||
1074 |
|
|
1071 | try: | |
1075 |
|
|
1072 | raise IndexError("msg pending: %r"%pending[0]) | |
1076 |
|
|
1073 | except: | |
1077 |
|
|
1074 | reply = error.wrap_exception() | |
1078 |
|
|
1075 | else: | |
1079 |
|
|
1076 | try: | |
1080 | raise IndexError("No such msg: %r"%msg_id) |
|
1077 | self.db.drop_matching_records(dict(msg_id={'$in':msg_ids})) | |
1081 |
|
|
1078 | except Exception: | |
1082 |
|
|
1079 | reply = error.wrap_exception() | |
1083 | break |
|
1080 | ||
|
1081 | if reply['status'] == 'ok': | |||
1084 | eids = content.get('engine_ids', []) |
|
1082 | eids = content.get('engine_ids', []) | |
1085 | for eid in eids: |
|
1083 | for eid in eids: | |
1086 | if eid not in self.engines: |
|
1084 | if eid not in self.engines: |
@@ -6,12 +6,10 b'' | |||||
6 | # the file COPYING, distributed as part of this software. |
|
6 | # the file COPYING, distributed as part of this software. | |
7 | #----------------------------------------------------------------------------- |
|
7 | #----------------------------------------------------------------------------- | |
8 |
|
8 | |||
9 | from datetime import datetime |
|
|||
10 |
|
||||
11 | from pymongo import Connection |
|
9 | from pymongo import Connection | |
12 | from pymongo.binary import Binary |
|
10 | from pymongo.binary import Binary | |
13 |
|
11 | |||
14 | from IPython.utils.traitlets import Dict, List, CUnicode |
|
12 | from IPython.utils.traitlets import Dict, List, CUnicode, CStr, Instance | |
15 |
|
13 | |||
16 | from .dictdb import BaseDB |
|
14 | from .dictdb import BaseDB | |
17 |
|
15 | |||
@@ -25,15 +23,20 b' class MongoDB(BaseDB):' | |||||
25 | connection_args = List(config=True) # args passed to pymongo.Connection |
|
23 | connection_args = List(config=True) # args passed to pymongo.Connection | |
26 | connection_kwargs = Dict(config=True) # kwargs passed to pymongo.Connection |
|
24 | connection_kwargs = Dict(config=True) # kwargs passed to pymongo.Connection | |
27 | database = CUnicode(config=True) # name of the mongodb database |
|
25 | database = CUnicode(config=True) # name of the mongodb database | |
28 | _table = Dict() |
|
26 | ||
|
27 | _connection = Instance(Connection) # pymongo connection | |||
29 |
|
28 | |||
30 | def __init__(self, **kwargs): |
|
29 | def __init__(self, **kwargs): | |
31 | super(MongoDB, self).__init__(**kwargs) |
|
30 | super(MongoDB, self).__init__(**kwargs) | |
|
31 | if self._connection is None: | |||
32 | self._connection = Connection(*self.connection_args, **self.connection_kwargs) |
|
32 | self._connection = Connection(*self.connection_args, **self.connection_kwargs) | |
33 | if not self.database: |
|
33 | if not self.database: | |
34 | self.database = self.session |
|
34 | self.database = self.session | |
35 | self._db = self._connection[self.database] |
|
35 | self._db = self._connection[self.database] | |
36 | self._records = self._db['task_records'] |
|
36 | self._records = self._db['task_records'] | |
|
37 | self._records.ensure_index('msg_id', unique=True) | |||
|
38 | self._records.ensure_index('submitted') # for sorting history | |||
|
39 | # for rec in self._records.find | |||
37 |
|
40 | |||
38 | def _binary_buffers(self, rec): |
|
41 | def _binary_buffers(self, rec): | |
39 | for key in ('buffers', 'result_buffers'): |
|
42 | for key in ('buffers', 'result_buffers'): | |
@@ -45,18 +48,21 b' class MongoDB(BaseDB):' | |||||
45 | """Add a new Task Record, by msg_id.""" |
|
48 | """Add a new Task Record, by msg_id.""" | |
46 | # print rec |
|
49 | # print rec | |
47 | rec = self._binary_buffers(rec) |
|
50 | rec = self._binary_buffers(rec) | |
48 |
|
|
51 | self._records.insert(rec) | |
49 | self._table[msg_id] = obj_id |
|
|||
50 |
|
52 | |||
51 | def get_record(self, msg_id): |
|
53 | def get_record(self, msg_id): | |
52 | """Get a specific Task Record, by msg_id.""" |
|
54 | """Get a specific Task Record, by msg_id.""" | |
53 |
r |
|
55 | r = self._records.find_one({'msg_id': msg_id}) | |
|
56 | if not r: | |||
|
57 | # r will be '' if nothing is found | |||
|
58 | raise KeyError(msg_id) | |||
|
59 | return r | |||
54 |
|
60 | |||
55 | def update_record(self, msg_id, rec): |
|
61 | def update_record(self, msg_id, rec): | |
56 | """Update the data in an existing record.""" |
|
62 | """Update the data in an existing record.""" | |
57 | rec = self._binary_buffers(rec) |
|
63 | rec = self._binary_buffers(rec) | |
58 | obj_id = self._table[msg_id] |
|
64 | ||
59 |
self._records.update({'_id': |
|
65 | self._records.update({'msg_id':msg_id}, {'$set': rec}) | |
60 |
|
66 | |||
61 | def drop_matching_records(self, check): |
|
67 | def drop_matching_records(self, check): | |
62 | """Remove a record from the DB.""" |
|
68 | """Remove a record from the DB.""" | |
@@ -64,8 +70,7 b' class MongoDB(BaseDB):' | |||||
64 |
|
70 | |||
65 | def drop_record(self, msg_id): |
|
71 | def drop_record(self, msg_id): | |
66 | """Remove a record from the DB.""" |
|
72 | """Remove a record from the DB.""" | |
67 | obj_id = self._table.pop(msg_id) |
|
73 | self._records.remove({'msg_id':msg_id}) | |
68 | self._records.remove(obj_id) |
|
|||
69 |
|
74 | |||
70 | def find_records(self, check, keys=None): |
|
75 | def find_records(self, check, keys=None): | |
71 | """Find records matching a query dict, optionally extracting subset of keys. |
|
76 | """Find records matching a query dict, optionally extracting subset of keys. |
@@ -27,16 +27,20 b' operators = {' | |||||
27 | '$lt' : "<", |
|
27 | '$lt' : "<", | |
28 | '$gt' : ">", |
|
28 | '$gt' : ">", | |
29 | # null is handled weird with ==,!= |
|
29 | # null is handled weird with ==,!= | |
30 |
'$eq' : " |
|
30 | '$eq' : "=", | |
31 |
'$ne' : " |
|
31 | '$ne' : "!=", | |
32 | '$lte': "<=", |
|
32 | '$lte': "<=", | |
33 | '$gte': ">=", |
|
33 | '$gte': ">=", | |
34 |
'$in' : (' |
|
34 | '$in' : ('=', ' OR '), | |
35 |
'$nin': (' |
|
35 | '$nin': ('!=', ' AND '), | |
36 | # '$all': None, |
|
36 | # '$all': None, | |
37 | # '$mod': None, |
|
37 | # '$mod': None, | |
38 | # '$exists' : None |
|
38 | # '$exists' : None | |
39 | } |
|
39 | } | |
|
40 | null_operators = { | |||
|
41 | '=' : "IS NULL", | |||
|
42 | '!=' : "IS NOT NULL", | |||
|
43 | } | |||
40 |
|
44 | |||
41 | def _adapt_datetime(dt): |
|
45 | def _adapt_datetime(dt): | |
42 | return dt.strftime(ISO8601) |
|
46 | return dt.strftime(ISO8601) | |
@@ -205,8 +209,15 b' class SQLiteDB(BaseDB):' | |||||
205 | raise KeyError("Unsupported operator: %r"%test) |
|
209 | raise KeyError("Unsupported operator: %r"%test) | |
206 | if isinstance(op, tuple): |
|
210 | if isinstance(op, tuple): | |
207 | op, join = op |
|
211 | op, join = op | |
|
212 | ||||
|
213 | if value is None and op in null_operators: | |||
|
214 | expr = "%s %s"%null_operators[op] | |||
|
215 | else: | |||
208 | expr = "%s %s ?"%(name, op) |
|
216 | expr = "%s %s ?"%(name, op) | |
209 | if isinstance(value, (tuple,list)): |
|
217 | if isinstance(value, (tuple,list)): | |
|
218 | if op in null_operators and any([v is None for v in value]): | |||
|
219 | # equality tests don't work with NULL | |||
|
220 | raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test) | |||
210 | expr = '( %s )'%( join.join([expr]*len(value)) ) |
|
221 | expr = '( %s )'%( join.join([expr]*len(value)) ) | |
211 | args.extend(value) |
|
222 | args.extend(value) | |
212 | else: |
|
223 | else: | |
@@ -214,7 +225,10 b' class SQLiteDB(BaseDB):' | |||||
214 | expressions.append(expr) |
|
225 | expressions.append(expr) | |
215 | else: |
|
226 | else: | |
216 | # it's an equality check |
|
227 | # it's an equality check | |
217 | expressions.append("%s IS ?"%name) |
|
228 | if sub_check is None: | |
|
229 | expressions.append("%s IS NULL") | |||
|
230 | else: | |||
|
231 | expressions.append("%s = ?"%name) | |||
218 | args.append(sub_check) |
|
232 | args.append(sub_check) | |
219 |
|
233 | |||
220 | expr = " AND ".join(expressions) |
|
234 | expr = " AND ".join(expressions) |
@@ -235,3 +235,10 b' class TestClient(ClusterTestCase):' | |||||
235 | def test_resubmit_badkey(self): |
|
235 | def test_resubmit_badkey(self): | |
236 | """ensure KeyError on resubmit of nonexistant task""" |
|
236 | """ensure KeyError on resubmit of nonexistant task""" | |
237 | self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid']) |
|
237 | self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid']) | |
|
238 | ||||
|
239 | def test_purge_results(self): | |||
|
240 | hist = self.client.hub_history() | |||
|
241 | self.client.purge_results(hist) | |||
|
242 | newhist = self.client.hub_history() | |||
|
243 | self.assertTrue(len(newhist) == 0) | |||
|
244 |
@@ -15,10 +15,7 b'' | |||||
15 | import tempfile |
|
15 | import tempfile | |
16 | import time |
|
16 | import time | |
17 |
|
17 | |||
18 | import uuid |
|
|||
19 |
|
||||
20 | from datetime import datetime, timedelta |
|
18 | from datetime import datetime, timedelta | |
21 | from random import choice, randint |
|
|||
22 | from unittest import TestCase |
|
19 | from unittest import TestCase | |
23 |
|
20 | |||
24 | from nose import SkipTest |
|
21 | from nose import SkipTest | |
@@ -158,25 +155,16 b' class TestDictBackend(TestCase):' | |||||
158 | rec = self.db.get_record(msg_id) |
|
155 | rec = self.db.get_record(msg_id) | |
159 | self.assertTrue(isinstance(rec['completed'], datetime)) |
|
156 | self.assertTrue(isinstance(rec['completed'], datetime)) | |
160 |
|
157 | |||
|
158 | def test_drop_matching(self): | |||
|
159 | msg_ids = self.load_records(10) | |||
|
160 | query = {'msg_id' : {'$in':msg_ids}} | |||
|
161 | self.db.drop_matching_records(query) | |||
|
162 | recs = self.db.find_records(query) | |||
|
163 | self.assertTrue(len(recs)==0) | |||
|
164 | ||||
161 | class TestSQLiteBackend(TestDictBackend): |
|
165 | class TestSQLiteBackend(TestDictBackend): | |
162 | def create_db(self): |
|
166 | def create_db(self): | |
163 | return SQLiteDB(location=tempfile.gettempdir()) |
|
167 | return SQLiteDB(location=tempfile.gettempdir()) | |
164 |
|
168 | |||
165 | def tearDown(self): |
|
169 | def tearDown(self): | |
166 | self.db._db.close() |
|
170 | self.db._db.close() | |
167 |
|
||||
168 | # optional MongoDB test |
|
|||
169 | try: |
|
|||
170 | from IPython.parallel.controller.mongodb import MongoDB |
|
|||
171 | except ImportError: |
|
|||
172 | pass |
|
|||
173 | else: |
|
|||
174 | class TestMongoBackend(TestDictBackend): |
|
|||
175 | def create_db(self): |
|
|||
176 | try: |
|
|||
177 | return MongoDB(database='iptestdb') |
|
|||
178 | except Exception: |
|
|||
179 | raise SkipTest("Couldn't connect to mongodb instance") |
|
|||
180 |
|
||||
181 | def tearDown(self): |
|
|||
182 | self.db._connection.drop_database('iptestdb') |
|
@@ -199,6 +199,7 b' def make_exclude():' | |||||
199 |
|
199 | |||
200 | if not have['pymongo']: |
|
200 | if not have['pymongo']: | |
201 | exclusions.append(ipjoin('parallel', 'controller', 'mongodb')) |
|
201 | exclusions.append(ipjoin('parallel', 'controller', 'mongodb')) | |
|
202 | exclusions.append(ipjoin('parallel', 'tests', 'test_mongodb')) | |||
202 |
|
203 | |||
203 | if not have['matplotlib']: |
|
204 | if not have['matplotlib']: | |
204 | exclusions.extend([ipjoin('lib', 'pylabtools'), |
|
205 | exclusions.extend([ipjoin('lib', 'pylabtools'), |
General Comments 0
You need to be logged in to leave comments.
Login now