Show More
@@ -0,0 +1,37 b'' | |||
|
1 | """Tests for mongodb backend""" | |
|
2 | ||
|
3 | #------------------------------------------------------------------------------- | |
|
4 | # Copyright (C) 2011 The IPython Development Team | |
|
5 | # | |
|
6 | # Distributed under the terms of the BSD License. The full license is in | |
|
7 | # the file COPYING, distributed as part of this software. | |
|
8 | #------------------------------------------------------------------------------- | |
|
9 | ||
|
10 | #------------------------------------------------------------------------------- | |
|
11 | # Imports | |
|
12 | #------------------------------------------------------------------------------- | |
|
13 | ||
|
14 | from nose import SkipTest | |
|
15 | ||
|
16 | from pymongo import Connection | |
|
17 | from IPython.parallel.controller.mongodb import MongoDB | |
|
18 | ||
|
19 | from . import test_db | |
|
20 | ||
|
21 | try: | |
|
22 | c = Connection() | |
|
23 | except Exception: | |
|
24 | c=None | |
|
25 | ||
|
26 | class TestMongoBackend(test_db.TestDictBackend): | |
|
27 | """MongoDB backend tests""" | |
|
28 | ||
|
29 | def create_db(self): | |
|
30 | try: | |
|
31 | return MongoDB(database='iptestdb', _connection=c) | |
|
32 | except Exception: | |
|
33 | raise SkipTest("Couldn't connect to mongodb") | |
|
34 | ||
|
35 | def teardown(self): | |
|
36 | if c is not None: | |
|
37 | c.drop_database('iptestdb') |
@@ -146,7 +146,7 b' class DictDB(BaseDB):' | |||
|
146 | 146 | """Remove a record from the DB.""" |
|
147 | 147 | matches = self._match(check) |
|
148 | 148 | for m in matches: |
|
149 | del self._records[m] | |
|
149 | del self._records[m['msg_id']] | |
|
150 | 150 | |
|
151 | 151 | def drop_record(self, msg_id): |
|
152 | 152 | """Remove a record from the DB.""" |
@@ -1066,21 +1066,19 b' class Hub(LoggingFactory):' | |||
|
1066 | 1066 | except Exception: |
|
1067 | 1067 | reply = error.wrap_exception() |
|
1068 | 1068 | else: |
|
1069 | for msg_id in msg_ids: | |
|
1070 | if msg_id in self.all_completed: | |
|
1071 | self.db.drop_record(msg_id) | |
|
1072 | else: | |
|
1073 | if msg_id in self.pending: | |
|
1069 | pending = filter(lambda m: m in self.pending, msg_ids) | |
|
1070 | if pending: | |
|
1074 | 1071 |
|
|
1075 |
|
|
|
1072 | raise IndexError("msg pending: %r"%pending[0]) | |
|
1076 | 1073 |
|
|
1077 | 1074 |
|
|
1078 | 1075 |
|
|
1079 | 1076 |
|
|
1080 | raise IndexError("No such msg: %r"%msg_id) | |
|
1081 |
|
|
|
1077 | self.db.drop_matching_records(dict(msg_id={'$in':msg_ids})) | |
|
1078 | except Exception: | |
|
1082 | 1079 |
|
|
1083 | break | |
|
1080 | ||
|
1081 | if reply['status'] == 'ok': | |
|
1084 | 1082 | eids = content.get('engine_ids', []) |
|
1085 | 1083 | for eid in eids: |
|
1086 | 1084 | if eid not in self.engines: |
@@ -6,12 +6,10 b'' | |||
|
6 | 6 | # the file COPYING, distributed as part of this software. |
|
7 | 7 | #----------------------------------------------------------------------------- |
|
8 | 8 | |
|
9 | from datetime import datetime | |
|
10 | ||
|
11 | 9 | from pymongo import Connection |
|
12 | 10 | from pymongo.binary import Binary |
|
13 | 11 | |
|
14 | from IPython.utils.traitlets import Dict, List, CUnicode | |
|
12 | from IPython.utils.traitlets import Dict, List, CUnicode, CStr, Instance | |
|
15 | 13 | |
|
16 | 14 | from .dictdb import BaseDB |
|
17 | 15 | |
@@ -25,15 +23,20 b' class MongoDB(BaseDB):' | |||
|
25 | 23 | connection_args = List(config=True) # args passed to pymongo.Connection |
|
26 | 24 | connection_kwargs = Dict(config=True) # kwargs passed to pymongo.Connection |
|
27 | 25 | database = CUnicode(config=True) # name of the mongodb database |
|
28 | _table = Dict() | |
|
26 | ||
|
27 | _connection = Instance(Connection) # pymongo connection | |
|
29 | 28 | |
|
30 | 29 | def __init__(self, **kwargs): |
|
31 | 30 | super(MongoDB, self).__init__(**kwargs) |
|
31 | if self._connection is None: | |
|
32 | 32 | self._connection = Connection(*self.connection_args, **self.connection_kwargs) |
|
33 | 33 | if not self.database: |
|
34 | 34 | self.database = self.session |
|
35 | 35 | self._db = self._connection[self.database] |
|
36 | 36 | self._records = self._db['task_records'] |
|
37 | self._records.ensure_index('msg_id', unique=True) | |
|
38 | self._records.ensure_index('submitted') # for sorting history | |
|
39 | # for rec in self._records.find | |
|
37 | 40 | |
|
38 | 41 | def _binary_buffers(self, rec): |
|
39 | 42 | for key in ('buffers', 'result_buffers'): |
@@ -45,18 +48,21 b' class MongoDB(BaseDB):' | |||
|
45 | 48 | """Add a new Task Record, by msg_id.""" |
|
46 | 49 | # print rec |
|
47 | 50 | rec = self._binary_buffers(rec) |
|
48 |
|
|
|
49 | self._table[msg_id] = obj_id | |
|
51 | self._records.insert(rec) | |
|
50 | 52 | |
|
51 | 53 | def get_record(self, msg_id): |
|
52 | 54 | """Get a specific Task Record, by msg_id.""" |
|
53 |
r |
|
|
55 | r = self._records.find_one({'msg_id': msg_id}) | |
|
56 | if not r: | |
|
57 | # r will be '' if nothing is found | |
|
58 | raise KeyError(msg_id) | |
|
59 | return r | |
|
54 | 60 | |
|
55 | 61 | def update_record(self, msg_id, rec): |
|
56 | 62 | """Update the data in an existing record.""" |
|
57 | 63 | rec = self._binary_buffers(rec) |
|
58 | obj_id = self._table[msg_id] | |
|
59 |
self._records.update({'_id': |
|
|
64 | ||
|
65 | self._records.update({'msg_id':msg_id}, {'$set': rec}) | |
|
60 | 66 | |
|
61 | 67 | def drop_matching_records(self, check): |
|
62 | 68 | """Remove a record from the DB.""" |
@@ -64,8 +70,7 b' class MongoDB(BaseDB):' | |||
|
64 | 70 | |
|
65 | 71 | def drop_record(self, msg_id): |
|
66 | 72 | """Remove a record from the DB.""" |
|
67 | obj_id = self._table.pop(msg_id) | |
|
68 | self._records.remove(obj_id) | |
|
73 | self._records.remove({'msg_id':msg_id}) | |
|
69 | 74 | |
|
70 | 75 | def find_records(self, check, keys=None): |
|
71 | 76 | """Find records matching a query dict, optionally extracting subset of keys. |
@@ -27,16 +27,20 b' operators = {' | |||
|
27 | 27 | '$lt' : "<", |
|
28 | 28 | '$gt' : ">", |
|
29 | 29 | # null is handled weird with ==,!= |
|
30 |
'$eq' : " |
|
|
31 |
'$ne' : " |
|
|
30 | '$eq' : "=", | |
|
31 | '$ne' : "!=", | |
|
32 | 32 | '$lte': "<=", |
|
33 | 33 | '$gte': ">=", |
|
34 |
'$in' : (' |
|
|
35 |
'$nin': (' |
|
|
34 | '$in' : ('=', ' OR '), | |
|
35 | '$nin': ('!=', ' AND '), | |
|
36 | 36 | # '$all': None, |
|
37 | 37 | # '$mod': None, |
|
38 | 38 | # '$exists' : None |
|
39 | 39 | } |
|
40 | null_operators = { | |
|
41 | '=' : "IS NULL", | |
|
42 | '!=' : "IS NOT NULL", | |
|
43 | } | |
|
40 | 44 | |
|
41 | 45 | def _adapt_datetime(dt): |
|
42 | 46 | return dt.strftime(ISO8601) |
@@ -205,8 +209,15 b' class SQLiteDB(BaseDB):' | |||
|
205 | 209 | raise KeyError("Unsupported operator: %r"%test) |
|
206 | 210 | if isinstance(op, tuple): |
|
207 | 211 | op, join = op |
|
212 | ||
|
213 | if value is None and op in null_operators: | |
|
214 | expr = "%s %s"%null_operators[op] | |
|
215 | else: | |
|
208 | 216 | expr = "%s %s ?"%(name, op) |
|
209 | 217 | if isinstance(value, (tuple,list)): |
|
218 | if op in null_operators and any([v is None for v in value]): | |
|
219 | # equality tests don't work with NULL | |
|
220 | raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test) | |
|
210 | 221 | expr = '( %s )'%( join.join([expr]*len(value)) ) |
|
211 | 222 | args.extend(value) |
|
212 | 223 | else: |
@@ -214,7 +225,10 b' class SQLiteDB(BaseDB):' | |||
|
214 | 225 | expressions.append(expr) |
|
215 | 226 | else: |
|
216 | 227 | # it's an equality check |
|
217 | expressions.append("%s IS ?"%name) | |
|
228 | if sub_check is None: | |
|
229 | expressions.append("%s IS NULL") | |
|
230 | else: | |
|
231 | expressions.append("%s = ?"%name) | |
|
218 | 232 | args.append(sub_check) |
|
219 | 233 | |
|
220 | 234 | expr = " AND ".join(expressions) |
@@ -235,3 +235,10 b' class TestClient(ClusterTestCase):' | |||
|
235 | 235 | def test_resubmit_badkey(self): |
|
236 | 236 | """ensure KeyError on resubmit of nonexistant task""" |
|
237 | 237 | self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid']) |
|
238 | ||
|
239 | def test_purge_results(self): | |
|
240 | hist = self.client.hub_history() | |
|
241 | self.client.purge_results(hist) | |
|
242 | newhist = self.client.hub_history() | |
|
243 | self.assertTrue(len(newhist) == 0) | |
|
244 |
@@ -15,10 +15,7 b'' | |||
|
15 | 15 | import tempfile |
|
16 | 16 | import time |
|
17 | 17 | |
|
18 | import uuid | |
|
19 | ||
|
20 | 18 | from datetime import datetime, timedelta |
|
21 | from random import choice, randint | |
|
22 | 19 | from unittest import TestCase |
|
23 | 20 | |
|
24 | 21 | from nose import SkipTest |
@@ -158,25 +155,16 b' class TestDictBackend(TestCase):' | |||
|
158 | 155 | rec = self.db.get_record(msg_id) |
|
159 | 156 | self.assertTrue(isinstance(rec['completed'], datetime)) |
|
160 | 157 | |
|
158 | def test_drop_matching(self): | |
|
159 | msg_ids = self.load_records(10) | |
|
160 | query = {'msg_id' : {'$in':msg_ids}} | |
|
161 | self.db.drop_matching_records(query) | |
|
162 | recs = self.db.find_records(query) | |
|
163 | self.assertTrue(len(recs)==0) | |
|
164 | ||
|
161 | 165 | class TestSQLiteBackend(TestDictBackend): |
|
162 | 166 | def create_db(self): |
|
163 | 167 | return SQLiteDB(location=tempfile.gettempdir()) |
|
164 | 168 | |
|
165 | 169 | def tearDown(self): |
|
166 | 170 | self.db._db.close() |
|
167 | ||
|
168 | # optional MongoDB test | |
|
169 | try: | |
|
170 | from IPython.parallel.controller.mongodb import MongoDB | |
|
171 | except ImportError: | |
|
172 | pass | |
|
173 | else: | |
|
174 | class TestMongoBackend(TestDictBackend): | |
|
175 | def create_db(self): | |
|
176 | try: | |
|
177 | return MongoDB(database='iptestdb') | |
|
178 | except Exception: | |
|
179 | raise SkipTest("Couldn't connect to mongodb instance") | |
|
180 | ||
|
181 | def tearDown(self): | |
|
182 | self.db._connection.drop_database('iptestdb') |
@@ -199,6 +199,7 b' def make_exclude():' | |||
|
199 | 199 | |
|
200 | 200 | if not have['pymongo']: |
|
201 | 201 | exclusions.append(ipjoin('parallel', 'controller', 'mongodb')) |
|
202 | exclusions.append(ipjoin('parallel', 'tests', 'test_mongodb')) | |
|
202 | 203 | |
|
203 | 204 | if not have['matplotlib']: |
|
204 | 205 | exclusions.extend([ipjoin('lib', 'pylabtools'), |
General Comments 0
You need to be logged in to leave comments.
Login now