##// END OF EJS Templates
dictdb queries should [shallow] copy records...
MinRK -
Show More
@@ -1,216 +1,216 b''
1 """A Task logger that presents our DB interface,
1 """A Task logger that presents our DB interface,
2 but exists entirely in memory and implemented with dicts.
2 but exists entirely in memory and implemented with dicts.
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7
7
8
8
9 TaskRecords are dicts of the form:
9 TaskRecords are dicts of the form:
10 {
10 {
11 'msg_id' : str(uuid),
11 'msg_id' : str(uuid),
12 'client_uuid' : str(uuid),
12 'client_uuid' : str(uuid),
13 'engine_uuid' : str(uuid) or None,
13 'engine_uuid' : str(uuid) or None,
14 'header' : dict(header),
14 'header' : dict(header),
15 'content': dict(content),
15 'content': dict(content),
16 'buffers': list(buffers),
16 'buffers': list(buffers),
17 'submitted': datetime,
17 'submitted': datetime,
18 'started': datetime or None,
18 'started': datetime or None,
19 'completed': datetime or None,
19 'completed': datetime or None,
20 'resubmitted': datetime or None,
20 'resubmitted': datetime or None,
21 'result_header' : dict(header) or None,
21 'result_header' : dict(header) or None,
22 'result_content' : dict(content) or None,
22 'result_content' : dict(content) or None,
23 'result_buffers' : list(buffers) or None,
23 'result_buffers' : list(buffers) or None,
24 }
24 }
25 With this info, many of the special categories of tasks can be defined by query:
25 With this info, many of the special categories of tasks can be defined by query:
26
26
27 pending: completed is None
27 pending: completed is None
28 client's outstanding: client_uuid = uuid && completed is None
28 client's outstanding: client_uuid = uuid && completed is None
29 MIA: arrived is None (and completed is None)
29 MIA: arrived is None (and completed is None)
30 etc.
30 etc.
31
31
32 EngineRecords are dicts of the form:
32 EngineRecords are dicts of the form:
33 {
33 {
34 'eid' : int(id),
34 'eid' : int(id),
35 'uuid': str(uuid)
35 'uuid': str(uuid)
36 }
36 }
37 This may be extended, but is currently.
37 This may be extended, but is currently.
38
38
39 We support a subset of mongodb operators:
39 We support a subset of mongodb operators:
40 $lt,$gt,$lte,$gte,$ne,$in,$nin,$all,$mod,$exists
40 $lt,$gt,$lte,$gte,$ne,$in,$nin,$all,$mod,$exists
41 """
41 """
42 #-----------------------------------------------------------------------------
42 #-----------------------------------------------------------------------------
43 # Copyright (C) 2010-2011 The IPython Development Team
43 # Copyright (C) 2010-2011 The IPython Development Team
44 #
44 #
45 # Distributed under the terms of the BSD License. The full license is in
45 # Distributed under the terms of the BSD License. The full license is in
46 # the file COPYING, distributed as part of this software.
46 # the file COPYING, distributed as part of this software.
47 #-----------------------------------------------------------------------------
47 #-----------------------------------------------------------------------------
48
48
49
49 from copy import copy
50 from datetime import datetime
50 from datetime import datetime
51
51
52 from IPython.config.configurable import LoggingConfigurable
52 from IPython.config.configurable import LoggingConfigurable
53
53
54 from IPython.utils.traitlets import Dict, Unicode, Instance
54 from IPython.utils.traitlets import Dict, Unicode, Instance
55
55
56 filters = {
56 filters = {
57 '$lt' : lambda a,b: a < b,
57 '$lt' : lambda a,b: a < b,
58 '$gt' : lambda a,b: b > a,
58 '$gt' : lambda a,b: b > a,
59 '$eq' : lambda a,b: a == b,
59 '$eq' : lambda a,b: a == b,
60 '$ne' : lambda a,b: a != b,
60 '$ne' : lambda a,b: a != b,
61 '$lte': lambda a,b: a <= b,
61 '$lte': lambda a,b: a <= b,
62 '$gte': lambda a,b: a >= b,
62 '$gte': lambda a,b: a >= b,
63 '$in' : lambda a,b: a in b,
63 '$in' : lambda a,b: a in b,
64 '$nin': lambda a,b: a not in b,
64 '$nin': lambda a,b: a not in b,
65 '$all': lambda a,b: all([ a in bb for bb in b ]),
65 '$all': lambda a,b: all([ a in bb for bb in b ]),
66 '$mod': lambda a,b: a%b[0] == b[1],
66 '$mod': lambda a,b: a%b[0] == b[1],
67 '$exists' : lambda a,b: (b and a is not None) or (a is None and not b)
67 '$exists' : lambda a,b: (b and a is not None) or (a is None and not b)
68 }
68 }
69
69
70
70
71 class CompositeFilter(object):
71 class CompositeFilter(object):
72 """Composite filter for matching multiple properties."""
72 """Composite filter for matching multiple properties."""
73
73
74 def __init__(self, dikt):
74 def __init__(self, dikt):
75 self.tests = []
75 self.tests = []
76 self.values = []
76 self.values = []
77 for key, value in dikt.iteritems():
77 for key, value in dikt.iteritems():
78 self.tests.append(filters[key])
78 self.tests.append(filters[key])
79 self.values.append(value)
79 self.values.append(value)
80
80
81 def __call__(self, value):
81 def __call__(self, value):
82 for test,check in zip(self.tests, self.values):
82 for test,check in zip(self.tests, self.values):
83 if not test(value, check):
83 if not test(value, check):
84 return False
84 return False
85 return True
85 return True
86
86
87 class BaseDB(LoggingConfigurable):
87 class BaseDB(LoggingConfigurable):
88 """Empty Parent class so traitlets work on DB."""
88 """Empty Parent class so traitlets work on DB."""
89 # base configurable traits:
89 # base configurable traits:
90 session = Unicode("")
90 session = Unicode("")
91
91
92 class DictDB(BaseDB):
92 class DictDB(BaseDB):
93 """Basic in-memory dict-based object for saving Task Records.
93 """Basic in-memory dict-based object for saving Task Records.
94
94
95 This is the first object to present the DB interface
95 This is the first object to present the DB interface
96 for logging tasks out of memory.
96 for logging tasks out of memory.
97
97
98 The interface is based on MongoDB, so adding a MongoDB
98 The interface is based on MongoDB, so adding a MongoDB
99 backend should be straightforward.
99 backend should be straightforward.
100 """
100 """
101
101
102 _records = Dict()
102 _records = Dict()
103
103
104 def _match_one(self, rec, tests):
104 def _match_one(self, rec, tests):
105 """Check if a specific record matches tests."""
105 """Check if a specific record matches tests."""
106 for key,test in tests.iteritems():
106 for key,test in tests.iteritems():
107 if not test(rec.get(key, None)):
107 if not test(rec.get(key, None)):
108 return False
108 return False
109 return True
109 return True
110
110
111 def _match(self, check):
111 def _match(self, check):
112 """Find all the matches for a check dict."""
112 """Find all the matches for a check dict."""
113 matches = []
113 matches = []
114 tests = {}
114 tests = {}
115 for k,v in check.iteritems():
115 for k,v in check.iteritems():
116 if isinstance(v, dict):
116 if isinstance(v, dict):
117 tests[k] = CompositeFilter(v)
117 tests[k] = CompositeFilter(v)
118 else:
118 else:
119 tests[k] = lambda o: o==v
119 tests[k] = lambda o: o==v
120
120
121 for rec in self._records.itervalues():
121 for rec in self._records.itervalues():
122 if self._match_one(rec, tests):
122 if self._match_one(rec, tests):
123 matches.append(rec)
123 matches.append(copy(rec))
124 return matches
124 return matches
125
125
126 def _extract_subdict(self, rec, keys):
126 def _extract_subdict(self, rec, keys):
127 """extract subdict of keys"""
127 """extract subdict of keys"""
128 d = {}
128 d = {}
129 d['msg_id'] = rec['msg_id']
129 d['msg_id'] = rec['msg_id']
130 for key in keys:
130 for key in keys:
131 d[key] = rec[key]
131 d[key] = rec[key]
132 return d
132 return d
133
133
134 def add_record(self, msg_id, rec):
134 def add_record(self, msg_id, rec):
135 """Add a new Task Record, by msg_id."""
135 """Add a new Task Record, by msg_id."""
136 if self._records.has_key(msg_id):
136 if self._records.has_key(msg_id):
137 raise KeyError("Already have msg_id %r"%(msg_id))
137 raise KeyError("Already have msg_id %r"%(msg_id))
138 self._records[msg_id] = rec
138 self._records[msg_id] = rec
139
139
140 def get_record(self, msg_id):
140 def get_record(self, msg_id):
141 """Get a specific Task Record, by msg_id."""
141 """Get a specific Task Record, by msg_id."""
142 if not self._records.has_key(msg_id):
142 if not msg_id in self._records:
143 raise KeyError("No such msg_id %r"%(msg_id))
143 raise KeyError("No such msg_id %r"%(msg_id))
144 return self._records[msg_id]
144 return copy(self._records[msg_id])
145
145
146 def update_record(self, msg_id, rec):
146 def update_record(self, msg_id, rec):
147 """Update the data in an existing record."""
147 """Update the data in an existing record."""
148 self._records[msg_id].update(rec)
148 self._records[msg_id].update(rec)
149
149
150 def drop_matching_records(self, check):
150 def drop_matching_records(self, check):
151 """Remove a record from the DB."""
151 """Remove a record from the DB."""
152 matches = self._match(check)
152 matches = self._match(check)
153 for m in matches:
153 for m in matches:
154 del self._records[m['msg_id']]
154 del self._records[m['msg_id']]
155
155
156 def drop_record(self, msg_id):
156 def drop_record(self, msg_id):
157 """Remove a record from the DB."""
157 """Remove a record from the DB."""
158 del self._records[msg_id]
158 del self._records[msg_id]
159
159
160
160
161 def find_records(self, check, keys=None):
161 def find_records(self, check, keys=None):
162 """Find records matching a query dict, optionally extracting subset of keys.
162 """Find records matching a query dict, optionally extracting subset of keys.
163
163
164 Returns dict keyed by msg_id of matching records.
164 Returns dict keyed by msg_id of matching records.
165
165
166 Parameters
166 Parameters
167 ----------
167 ----------
168
168
169 check: dict
169 check: dict
170 mongodb-style query argument
170 mongodb-style query argument
171 keys: list of strs [optional]
171 keys: list of strs [optional]
172 if specified, the subset of keys to extract. msg_id will *always* be
172 if specified, the subset of keys to extract. msg_id will *always* be
173 included.
173 included.
174 """
174 """
175 matches = self._match(check)
175 matches = self._match(check)
176 if keys:
176 if keys:
177 return [ self._extract_subdict(rec, keys) for rec in matches ]
177 return [ self._extract_subdict(rec, keys) for rec in matches ]
178 else:
178 else:
179 return matches
179 return matches
180
180
181
181
182 def get_history(self):
182 def get_history(self):
183 """get all msg_ids, ordered by time submitted."""
183 """get all msg_ids, ordered by time submitted."""
184 msg_ids = self._records.keys()
184 msg_ids = self._records.keys()
185 return sorted(msg_ids, key=lambda m: self._records[m]['submitted'])
185 return sorted(msg_ids, key=lambda m: self._records[m]['submitted'])
186
186
187 class NoDB(DictDB):
187 class NoDB(DictDB):
188 """A blackhole db backend that actually stores no information.
188 """A blackhole db backend that actually stores no information.
189
189
190 Provides the full DB interface, but raises KeyErrors on any
190 Provides the full DB interface, but raises KeyErrors on any
191 method that tries to access the records. This can be used to
191 method that tries to access the records. This can be used to
192 minimize the memory footprint of the Hub when its record-keeping
192 minimize the memory footprint of the Hub when its record-keeping
193 functionality is not required.
193 functionality is not required.
194 """
194 """
195
195
196 def add_record(self, msg_id, record):
196 def add_record(self, msg_id, record):
197 pass
197 pass
198
198
199 def get_record(self, msg_id):
199 def get_record(self, msg_id):
200 raise KeyError("NoDB does not support record access")
200 raise KeyError("NoDB does not support record access")
201
201
202 def update_record(self, msg_id, record):
202 def update_record(self, msg_id, record):
203 pass
203 pass
204
204
205 def drop_matching_records(self, check):
205 def drop_matching_records(self, check):
206 pass
206 pass
207
207
208 def drop_record(self, msg_id):
208 def drop_record(self, msg_id):
209 pass
209 pass
210
210
211 def find_records(self, check, keys=None):
211 def find_records(self, check, keys=None):
212 raise KeyError("NoDB does not store information")
212 raise KeyError("NoDB does not store information")
213
213
214 def get_history(self):
214 def get_history(self):
215 raise KeyError("NoDB does not store information")
215 raise KeyError("NoDB does not store information")
216
216
@@ -1,319 +1,331 b''
1 """Tests for parallel client.py
1 """Tests for parallel client.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import time
21 import time
22 from datetime import datetime
22 from datetime import datetime
23 from tempfile import mktemp
23 from tempfile import mktemp
24
24
25 import zmq
25 import zmq
26
26
27 from IPython.parallel.client import client as clientmod
27 from IPython.parallel.client import client as clientmod
28 from IPython.parallel import error
28 from IPython.parallel import error
29 from IPython.parallel import AsyncResult, AsyncHubResult
29 from IPython.parallel import AsyncResult, AsyncHubResult
30 from IPython.parallel import LoadBalancedView, DirectView
30 from IPython.parallel import LoadBalancedView, DirectView
31
31
32 from clienttest import ClusterTestCase, segfault, wait, add_engines
32 from clienttest import ClusterTestCase, segfault, wait, add_engines
33
33
34 def setup():
34 def setup():
35 add_engines(4, total=True)
35 add_engines(4, total=True)
36
36
37 class TestClient(ClusterTestCase):
37 class TestClient(ClusterTestCase):
38
38
39 def test_ids(self):
39 def test_ids(self):
40 n = len(self.client.ids)
40 n = len(self.client.ids)
41 self.add_engines(2)
41 self.add_engines(2)
42 self.assertEquals(len(self.client.ids), n+2)
42 self.assertEquals(len(self.client.ids), n+2)
43
43
44 def test_view_indexing(self):
44 def test_view_indexing(self):
45 """test index access for views"""
45 """test index access for views"""
46 self.minimum_engines(4)
46 self.minimum_engines(4)
47 targets = self.client._build_targets('all')[-1]
47 targets = self.client._build_targets('all')[-1]
48 v = self.client[:]
48 v = self.client[:]
49 self.assertEquals(v.targets, targets)
49 self.assertEquals(v.targets, targets)
50 t = self.client.ids[2]
50 t = self.client.ids[2]
51 v = self.client[t]
51 v = self.client[t]
52 self.assert_(isinstance(v, DirectView))
52 self.assert_(isinstance(v, DirectView))
53 self.assertEquals(v.targets, t)
53 self.assertEquals(v.targets, t)
54 t = self.client.ids[2:4]
54 t = self.client.ids[2:4]
55 v = self.client[t]
55 v = self.client[t]
56 self.assert_(isinstance(v, DirectView))
56 self.assert_(isinstance(v, DirectView))
57 self.assertEquals(v.targets, t)
57 self.assertEquals(v.targets, t)
58 v = self.client[::2]
58 v = self.client[::2]
59 self.assert_(isinstance(v, DirectView))
59 self.assert_(isinstance(v, DirectView))
60 self.assertEquals(v.targets, targets[::2])
60 self.assertEquals(v.targets, targets[::2])
61 v = self.client[1::3]
61 v = self.client[1::3]
62 self.assert_(isinstance(v, DirectView))
62 self.assert_(isinstance(v, DirectView))
63 self.assertEquals(v.targets, targets[1::3])
63 self.assertEquals(v.targets, targets[1::3])
64 v = self.client[:-3]
64 v = self.client[:-3]
65 self.assert_(isinstance(v, DirectView))
65 self.assert_(isinstance(v, DirectView))
66 self.assertEquals(v.targets, targets[:-3])
66 self.assertEquals(v.targets, targets[:-3])
67 v = self.client[-1]
67 v = self.client[-1]
68 self.assert_(isinstance(v, DirectView))
68 self.assert_(isinstance(v, DirectView))
69 self.assertEquals(v.targets, targets[-1])
69 self.assertEquals(v.targets, targets[-1])
70 self.assertRaises(TypeError, lambda : self.client[None])
70 self.assertRaises(TypeError, lambda : self.client[None])
71
71
72 def test_lbview_targets(self):
72 def test_lbview_targets(self):
73 """test load_balanced_view targets"""
73 """test load_balanced_view targets"""
74 v = self.client.load_balanced_view()
74 v = self.client.load_balanced_view()
75 self.assertEquals(v.targets, None)
75 self.assertEquals(v.targets, None)
76 v = self.client.load_balanced_view(-1)
76 v = self.client.load_balanced_view(-1)
77 self.assertEquals(v.targets, [self.client.ids[-1]])
77 self.assertEquals(v.targets, [self.client.ids[-1]])
78 v = self.client.load_balanced_view('all')
78 v = self.client.load_balanced_view('all')
79 self.assertEquals(v.targets, None)
79 self.assertEquals(v.targets, None)
80
80
81 def test_dview_targets(self):
81 def test_dview_targets(self):
82 """test direct_view targets"""
82 """test direct_view targets"""
83 v = self.client.direct_view()
83 v = self.client.direct_view()
84 self.assertEquals(v.targets, 'all')
84 self.assertEquals(v.targets, 'all')
85 v = self.client.direct_view('all')
85 v = self.client.direct_view('all')
86 self.assertEquals(v.targets, 'all')
86 self.assertEquals(v.targets, 'all')
87 v = self.client.direct_view(-1)
87 v = self.client.direct_view(-1)
88 self.assertEquals(v.targets, self.client.ids[-1])
88 self.assertEquals(v.targets, self.client.ids[-1])
89
89
90 def test_lazy_all_targets(self):
90 def test_lazy_all_targets(self):
91 """test lazy evaluation of rc.direct_view('all')"""
91 """test lazy evaluation of rc.direct_view('all')"""
92 v = self.client.direct_view()
92 v = self.client.direct_view()
93 self.assertEquals(v.targets, 'all')
93 self.assertEquals(v.targets, 'all')
94
94
95 def double(x):
95 def double(x):
96 return x*2
96 return x*2
97 seq = range(100)
97 seq = range(100)
98 ref = [ double(x) for x in seq ]
98 ref = [ double(x) for x in seq ]
99
99
100 # add some engines, which should be used
100 # add some engines, which should be used
101 self.add_engines(1)
101 self.add_engines(1)
102 n1 = len(self.client.ids)
102 n1 = len(self.client.ids)
103
103
104 # simple apply
104 # simple apply
105 r = v.apply_sync(lambda : 1)
105 r = v.apply_sync(lambda : 1)
106 self.assertEquals(r, [1] * n1)
106 self.assertEquals(r, [1] * n1)
107
107
108 # map goes through remotefunction
108 # map goes through remotefunction
109 r = v.map_sync(double, seq)
109 r = v.map_sync(double, seq)
110 self.assertEquals(r, ref)
110 self.assertEquals(r, ref)
111
111
112 # add a couple more engines, and try again
112 # add a couple more engines, and try again
113 self.add_engines(2)
113 self.add_engines(2)
114 n2 = len(self.client.ids)
114 n2 = len(self.client.ids)
115 self.assertNotEquals(n2, n1)
115 self.assertNotEquals(n2, n1)
116
116
117 # apply
117 # apply
118 r = v.apply_sync(lambda : 1)
118 r = v.apply_sync(lambda : 1)
119 self.assertEquals(r, [1] * n2)
119 self.assertEquals(r, [1] * n2)
120
120
121 # map
121 # map
122 r = v.map_sync(double, seq)
122 r = v.map_sync(double, seq)
123 self.assertEquals(r, ref)
123 self.assertEquals(r, ref)
124
124
125 def test_targets(self):
125 def test_targets(self):
126 """test various valid targets arguments"""
126 """test various valid targets arguments"""
127 build = self.client._build_targets
127 build = self.client._build_targets
128 ids = self.client.ids
128 ids = self.client.ids
129 idents,targets = build(None)
129 idents,targets = build(None)
130 self.assertEquals(ids, targets)
130 self.assertEquals(ids, targets)
131
131
132 def test_clear(self):
132 def test_clear(self):
133 """test clear behavior"""
133 """test clear behavior"""
134 self.minimum_engines(2)
134 self.minimum_engines(2)
135 v = self.client[:]
135 v = self.client[:]
136 v.block=True
136 v.block=True
137 v.push(dict(a=5))
137 v.push(dict(a=5))
138 v.pull('a')
138 v.pull('a')
139 id0 = self.client.ids[-1]
139 id0 = self.client.ids[-1]
140 self.client.clear(targets=id0, block=True)
140 self.client.clear(targets=id0, block=True)
141 a = self.client[:-1].get('a')
141 a = self.client[:-1].get('a')
142 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
142 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
143 self.client.clear(block=True)
143 self.client.clear(block=True)
144 for i in self.client.ids:
144 for i in self.client.ids:
145 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
145 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
146
146
147 def test_get_result(self):
147 def test_get_result(self):
148 """test getting results from the Hub."""
148 """test getting results from the Hub."""
149 c = clientmod.Client(profile='iptest')
149 c = clientmod.Client(profile='iptest')
150 t = c.ids[-1]
150 t = c.ids[-1]
151 ar = c[t].apply_async(wait, 1)
151 ar = c[t].apply_async(wait, 1)
152 # give the monitor time to notice the message
152 # give the monitor time to notice the message
153 time.sleep(.25)
153 time.sleep(.25)
154 ahr = self.client.get_result(ar.msg_ids)
154 ahr = self.client.get_result(ar.msg_ids)
155 self.assertTrue(isinstance(ahr, AsyncHubResult))
155 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 self.assertEquals(ahr.get(), ar.get())
156 self.assertEquals(ahr.get(), ar.get())
157 ar2 = self.client.get_result(ar.msg_ids)
157 ar2 = self.client.get_result(ar.msg_ids)
158 self.assertFalse(isinstance(ar2, AsyncHubResult))
158 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 c.close()
159 c.close()
160
160
161 def test_ids_list(self):
161 def test_ids_list(self):
162 """test client.ids"""
162 """test client.ids"""
163 ids = self.client.ids
163 ids = self.client.ids
164 self.assertEquals(ids, self.client._ids)
164 self.assertEquals(ids, self.client._ids)
165 self.assertFalse(ids is self.client._ids)
165 self.assertFalse(ids is self.client._ids)
166 ids.remove(ids[-1])
166 ids.remove(ids[-1])
167 self.assertNotEquals(ids, self.client._ids)
167 self.assertNotEquals(ids, self.client._ids)
168
168
169 def test_queue_status(self):
169 def test_queue_status(self):
170 ids = self.client.ids
170 ids = self.client.ids
171 id0 = ids[0]
171 id0 = ids[0]
172 qs = self.client.queue_status(targets=id0)
172 qs = self.client.queue_status(targets=id0)
173 self.assertTrue(isinstance(qs, dict))
173 self.assertTrue(isinstance(qs, dict))
174 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
174 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
175 allqs = self.client.queue_status()
175 allqs = self.client.queue_status()
176 self.assertTrue(isinstance(allqs, dict))
176 self.assertTrue(isinstance(allqs, dict))
177 intkeys = list(allqs.keys())
177 intkeys = list(allqs.keys())
178 intkeys.remove('unassigned')
178 intkeys.remove('unassigned')
179 self.assertEquals(sorted(intkeys), sorted(self.client.ids))
179 self.assertEquals(sorted(intkeys), sorted(self.client.ids))
180 unassigned = allqs.pop('unassigned')
180 unassigned = allqs.pop('unassigned')
181 for eid,qs in allqs.items():
181 for eid,qs in allqs.items():
182 self.assertTrue(isinstance(qs, dict))
182 self.assertTrue(isinstance(qs, dict))
183 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
183 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
184
184
185 def test_shutdown(self):
185 def test_shutdown(self):
186 ids = self.client.ids
186 ids = self.client.ids
187 id0 = ids[0]
187 id0 = ids[0]
188 self.client.shutdown(id0, block=True)
188 self.client.shutdown(id0, block=True)
189 while id0 in self.client.ids:
189 while id0 in self.client.ids:
190 time.sleep(0.1)
190 time.sleep(0.1)
191 self.client.spin()
191 self.client.spin()
192
192
193 self.assertRaises(IndexError, lambda : self.client[id0])
193 self.assertRaises(IndexError, lambda : self.client[id0])
194
194
195 def test_result_status(self):
195 def test_result_status(self):
196 pass
196 pass
197 # to be written
197 # to be written
198
198
199 def test_db_query_dt(self):
199 def test_db_query_dt(self):
200 """test db query by date"""
200 """test db query by date"""
201 hist = self.client.hub_history()
201 hist = self.client.hub_history()
202 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
202 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
203 tic = middle['submitted']
203 tic = middle['submitted']
204 before = self.client.db_query({'submitted' : {'$lt' : tic}})
204 before = self.client.db_query({'submitted' : {'$lt' : tic}})
205 after = self.client.db_query({'submitted' : {'$gte' : tic}})
205 after = self.client.db_query({'submitted' : {'$gte' : tic}})
206 self.assertEquals(len(before)+len(after),len(hist))
206 self.assertEquals(len(before)+len(after),len(hist))
207 for b in before:
207 for b in before:
208 self.assertTrue(b['submitted'] < tic)
208 self.assertTrue(b['submitted'] < tic)
209 for a in after:
209 for a in after:
210 self.assertTrue(a['submitted'] >= tic)
210 self.assertTrue(a['submitted'] >= tic)
211 same = self.client.db_query({'submitted' : tic})
211 same = self.client.db_query({'submitted' : tic})
212 for s in same:
212 for s in same:
213 self.assertTrue(s['submitted'] == tic)
213 self.assertTrue(s['submitted'] == tic)
214
214
215 def test_db_query_keys(self):
215 def test_db_query_keys(self):
216 """test extracting subset of record keys"""
216 """test extracting subset of record keys"""
217 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
217 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
218 for rec in found:
218 for rec in found:
219 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
219 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
220
220
221 def test_db_query_default_keys(self):
221 def test_db_query_default_keys(self):
222 """default db_query excludes buffers"""
222 """default db_query excludes buffers"""
223 found = self.client.db_query({'msg_id': {'$ne' : ''}})
223 found = self.client.db_query({'msg_id': {'$ne' : ''}})
224 for rec in found:
224 for rec in found:
225 keys = set(rec.keys())
225 keys = set(rec.keys())
226 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
226 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
227 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
227 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
228
228
229 def test_db_query_msg_id(self):
229 def test_db_query_msg_id(self):
230 """ensure msg_id is always in db queries"""
230 """ensure msg_id is always in db queries"""
231 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
231 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
232 for rec in found:
232 for rec in found:
233 self.assertTrue('msg_id' in rec.keys())
233 self.assertTrue('msg_id' in rec.keys())
234 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
234 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
235 for rec in found:
235 for rec in found:
236 self.assertTrue('msg_id' in rec.keys())
236 self.assertTrue('msg_id' in rec.keys())
237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
238 for rec in found:
238 for rec in found:
239 self.assertTrue('msg_id' in rec.keys())
239 self.assertTrue('msg_id' in rec.keys())
240
240
241 def test_db_query_get_result(self):
242 """pop in db_query shouldn't pop from result itself"""
243 self.client[:].apply_sync(lambda : 1)
244 found = self.client.db_query({'msg_id': {'$ne' : ''}})
245 rc2 = clientmod.Client(profile='iptest')
246 # If this bug is not fixed, this call will hang:
247 ar = rc2.get_result(self.client.history[-1])
248 ar.wait(2)
249 self.assertTrue(ar.ready())
250 ar.get()
251 rc2.close()
252
241 def test_db_query_in(self):
253 def test_db_query_in(self):
242 """test db query with '$in','$nin' operators"""
254 """test db query with '$in','$nin' operators"""
243 hist = self.client.hub_history()
255 hist = self.client.hub_history()
244 even = hist[::2]
256 even = hist[::2]
245 odd = hist[1::2]
257 odd = hist[1::2]
246 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
258 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
247 found = [ r['msg_id'] for r in recs ]
259 found = [ r['msg_id'] for r in recs ]
248 self.assertEquals(set(even), set(found))
260 self.assertEquals(set(even), set(found))
249 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
261 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
250 found = [ r['msg_id'] for r in recs ]
262 found = [ r['msg_id'] for r in recs ]
251 self.assertEquals(set(odd), set(found))
263 self.assertEquals(set(odd), set(found))
252
264
253 def test_hub_history(self):
265 def test_hub_history(self):
254 hist = self.client.hub_history()
266 hist = self.client.hub_history()
255 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
267 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
256 recdict = {}
268 recdict = {}
257 for rec in recs:
269 for rec in recs:
258 recdict[rec['msg_id']] = rec
270 recdict[rec['msg_id']] = rec
259
271
260 latest = datetime(1984,1,1)
272 latest = datetime(1984,1,1)
261 for msg_id in hist:
273 for msg_id in hist:
262 rec = recdict[msg_id]
274 rec = recdict[msg_id]
263 newt = rec['submitted']
275 newt = rec['submitted']
264 self.assertTrue(newt >= latest)
276 self.assertTrue(newt >= latest)
265 latest = newt
277 latest = newt
266 ar = self.client[-1].apply_async(lambda : 1)
278 ar = self.client[-1].apply_async(lambda : 1)
267 ar.get()
279 ar.get()
268 time.sleep(0.25)
280 time.sleep(0.25)
269 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
281 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
270
282
271 def test_resubmit(self):
283 def test_resubmit(self):
272 def f():
284 def f():
273 import random
285 import random
274 return random.random()
286 return random.random()
275 v = self.client.load_balanced_view()
287 v = self.client.load_balanced_view()
276 ar = v.apply_async(f)
288 ar = v.apply_async(f)
277 r1 = ar.get(1)
289 r1 = ar.get(1)
278 # give the Hub a chance to notice:
290 # give the Hub a chance to notice:
279 time.sleep(0.5)
291 time.sleep(0.5)
280 ahr = self.client.resubmit(ar.msg_ids)
292 ahr = self.client.resubmit(ar.msg_ids)
281 r2 = ahr.get(1)
293 r2 = ahr.get(1)
282 self.assertFalse(r1 == r2)
294 self.assertFalse(r1 == r2)
283
295
284 def test_resubmit_inflight(self):
296 def test_resubmit_inflight(self):
285 """ensure ValueError on resubmit of inflight task"""
297 """ensure ValueError on resubmit of inflight task"""
286 v = self.client.load_balanced_view()
298 v = self.client.load_balanced_view()
287 ar = v.apply_async(time.sleep,1)
299 ar = v.apply_async(time.sleep,1)
288 # give the message a chance to arrive
300 # give the message a chance to arrive
289 time.sleep(0.2)
301 time.sleep(0.2)
290 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
302 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
291 ar.get(2)
303 ar.get(2)
292
304
293 def test_resubmit_badkey(self):
305 def test_resubmit_badkey(self):
294 """ensure KeyError on resubmit of nonexistant task"""
306 """ensure KeyError on resubmit of nonexistant task"""
295 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
307 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
296
308
297 def test_purge_results(self):
309 def test_purge_results(self):
298 # ensure there are some tasks
310 # ensure there are some tasks
299 for i in range(5):
311 for i in range(5):
300 self.client[:].apply_sync(lambda : 1)
312 self.client[:].apply_sync(lambda : 1)
301 # Wait for the Hub to realise the result is done:
313 # Wait for the Hub to realise the result is done:
302 # This prevents a race condition, where we
314 # This prevents a race condition, where we
303 # might purge a result the Hub still thinks is pending.
315 # might purge a result the Hub still thinks is pending.
304 time.sleep(0.1)
316 time.sleep(0.1)
305 rc2 = clientmod.Client(profile='iptest')
317 rc2 = clientmod.Client(profile='iptest')
306 hist = self.client.hub_history()
318 hist = self.client.hub_history()
307 ahr = rc2.get_result([hist[-1]])
319 ahr = rc2.get_result([hist[-1]])
308 ahr.wait(10)
320 ahr.wait(10)
309 self.client.purge_results(hist[-1])
321 self.client.purge_results(hist[-1])
310 newhist = self.client.hub_history()
322 newhist = self.client.hub_history()
311 self.assertEquals(len(newhist)+1,len(hist))
323 self.assertEquals(len(newhist)+1,len(hist))
312 rc2.spin()
324 rc2.spin()
313 rc2.close()
325 rc2.close()
314
326
315 def test_purge_all_results(self):
327 def test_purge_all_results(self):
316 self.client.purge_results('all')
328 self.client.purge_results('all')
317 hist = self.client.hub_history()
329 hist = self.client.hub_history()
318 self.assertEquals(len(hist), 0)
330 self.assertEquals(len(hist), 0)
319
331
@@ -1,194 +1,224 b''
1 """Tests for db backends
1 """Tests for db backends
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import tempfile
21 import tempfile
22 import time
22 import time
23
23
24 from datetime import datetime, timedelta
24 from datetime import datetime, timedelta
25 from unittest import TestCase
25 from unittest import TestCase
26
26
27 from IPython.parallel import error
27 from IPython.parallel import error
28 from IPython.parallel.controller.dictdb import DictDB
28 from IPython.parallel.controller.dictdb import DictDB
29 from IPython.parallel.controller.sqlitedb import SQLiteDB
29 from IPython.parallel.controller.sqlitedb import SQLiteDB
30 from IPython.parallel.controller.hub import init_record, empty_record
30 from IPython.parallel.controller.hub import init_record, empty_record
31
31
32 from IPython.testing import decorators as dec
32 from IPython.testing import decorators as dec
33 from IPython.zmq.session import Session
33 from IPython.zmq.session import Session
34
34
35
35
36 #-------------------------------------------------------------------------------
36 #-------------------------------------------------------------------------------
37 # TestCases
37 # TestCases
38 #-------------------------------------------------------------------------------
38 #-------------------------------------------------------------------------------
39
39
40 class TestDictBackend(TestCase):
40 class TestDictBackend(TestCase):
41 def setUp(self):
41 def setUp(self):
42 self.session = Session()
42 self.session = Session()
43 self.db = self.create_db()
43 self.db = self.create_db()
44 self.load_records(16)
44 self.load_records(16)
45
45
46 def create_db(self):
46 def create_db(self):
47 return DictDB()
47 return DictDB()
48
48
49 def load_records(self, n=1):
49 def load_records(self, n=1):
50 """load n records for testing"""
50 """load n records for testing"""
51 #sleep 1/10 s, to ensure timestamp is different to previous calls
51 #sleep 1/10 s, to ensure timestamp is different to previous calls
52 time.sleep(0.1)
52 time.sleep(0.1)
53 msg_ids = []
53 msg_ids = []
54 for i in range(n):
54 for i in range(n):
55 msg = self.session.msg('apply_request', content=dict(a=5))
55 msg = self.session.msg('apply_request', content=dict(a=5))
56 msg['buffers'] = []
56 msg['buffers'] = []
57 rec = init_record(msg)
57 rec = init_record(msg)
58 msg_id = msg['header']['msg_id']
58 msg_id = msg['header']['msg_id']
59 msg_ids.append(msg_id)
59 msg_ids.append(msg_id)
60 self.db.add_record(msg_id, rec)
60 self.db.add_record(msg_id, rec)
61 return msg_ids
61 return msg_ids
62
62
63 def test_add_record(self):
63 def test_add_record(self):
64 before = self.db.get_history()
64 before = self.db.get_history()
65 self.load_records(5)
65 self.load_records(5)
66 after = self.db.get_history()
66 after = self.db.get_history()
67 self.assertEquals(len(after), len(before)+5)
67 self.assertEquals(len(after), len(before)+5)
68 self.assertEquals(after[:-5],before)
68 self.assertEquals(after[:-5],before)
69
69
70 def test_drop_record(self):
70 def test_drop_record(self):
71 msg_id = self.load_records()[-1]
71 msg_id = self.load_records()[-1]
72 rec = self.db.get_record(msg_id)
72 rec = self.db.get_record(msg_id)
73 self.db.drop_record(msg_id)
73 self.db.drop_record(msg_id)
74 self.assertRaises(KeyError,self.db.get_record, msg_id)
74 self.assertRaises(KeyError,self.db.get_record, msg_id)
75
75
76 def _round_to_millisecond(self, dt):
76 def _round_to_millisecond(self, dt):
77 """necessary because mongodb rounds microseconds"""
77 """necessary because mongodb rounds microseconds"""
78 micro = dt.microsecond
78 micro = dt.microsecond
79 extra = int(str(micro)[-3:])
79 extra = int(str(micro)[-3:])
80 return dt - timedelta(microseconds=extra)
80 return dt - timedelta(microseconds=extra)
81
81
82 def test_update_record(self):
82 def test_update_record(self):
83 now = self._round_to_millisecond(datetime.now())
83 now = self._round_to_millisecond(datetime.now())
84 #
84 #
85 msg_id = self.db.get_history()[-1]
85 msg_id = self.db.get_history()[-1]
86 rec1 = self.db.get_record(msg_id)
86 rec1 = self.db.get_record(msg_id)
87 data = {'stdout': 'hello there', 'completed' : now}
87 data = {'stdout': 'hello there', 'completed' : now}
88 self.db.update_record(msg_id, data)
88 self.db.update_record(msg_id, data)
89 rec2 = self.db.get_record(msg_id)
89 rec2 = self.db.get_record(msg_id)
90 self.assertEquals(rec2['stdout'], 'hello there')
90 self.assertEquals(rec2['stdout'], 'hello there')
91 self.assertEquals(rec2['completed'], now)
91 self.assertEquals(rec2['completed'], now)
92 rec1.update(data)
92 rec1.update(data)
93 self.assertEquals(rec1, rec2)
93 self.assertEquals(rec1, rec2)
94
94
95 # def test_update_record_bad(self):
95 # def test_update_record_bad(self):
96 # """test updating nonexistant records"""
96 # """test updating nonexistant records"""
97 # msg_id = str(uuid.uuid4())
97 # msg_id = str(uuid.uuid4())
98 # data = {'stdout': 'hello there'}
98 # data = {'stdout': 'hello there'}
99 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
99 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
100
100
101 def test_find_records_dt(self):
101 def test_find_records_dt(self):
102 """test finding records by date"""
102 """test finding records by date"""
103 hist = self.db.get_history()
103 hist = self.db.get_history()
104 middle = self.db.get_record(hist[len(hist)//2])
104 middle = self.db.get_record(hist[len(hist)//2])
105 tic = middle['submitted']
105 tic = middle['submitted']
106 before = self.db.find_records({'submitted' : {'$lt' : tic}})
106 before = self.db.find_records({'submitted' : {'$lt' : tic}})
107 after = self.db.find_records({'submitted' : {'$gte' : tic}})
107 after = self.db.find_records({'submitted' : {'$gte' : tic}})
108 self.assertEquals(len(before)+len(after),len(hist))
108 self.assertEquals(len(before)+len(after),len(hist))
109 for b in before:
109 for b in before:
110 self.assertTrue(b['submitted'] < tic)
110 self.assertTrue(b['submitted'] < tic)
111 for a in after:
111 for a in after:
112 self.assertTrue(a['submitted'] >= tic)
112 self.assertTrue(a['submitted'] >= tic)
113 same = self.db.find_records({'submitted' : tic})
113 same = self.db.find_records({'submitted' : tic})
114 for s in same:
114 for s in same:
115 self.assertTrue(s['submitted'] == tic)
115 self.assertTrue(s['submitted'] == tic)
116
116
117 def test_find_records_keys(self):
117 def test_find_records_keys(self):
118 """test extracting subset of record keys"""
118 """test extracting subset of record keys"""
119 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
119 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
120 for rec in found:
120 for rec in found:
121 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
121 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
122
122
123 def test_find_records_msg_id(self):
123 def test_find_records_msg_id(self):
124 """ensure msg_id is always in found records"""
124 """ensure msg_id is always in found records"""
125 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
125 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
126 for rec in found:
126 for rec in found:
127 self.assertTrue('msg_id' in rec.keys())
127 self.assertTrue('msg_id' in rec.keys())
128 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
128 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
129 for rec in found:
129 for rec in found:
130 self.assertTrue('msg_id' in rec.keys())
130 self.assertTrue('msg_id' in rec.keys())
131 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
131 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
132 for rec in found:
132 for rec in found:
133 self.assertTrue('msg_id' in rec.keys())
133 self.assertTrue('msg_id' in rec.keys())
134
134
135 def test_find_records_in(self):
135 def test_find_records_in(self):
136 """test finding records with '$in','$nin' operators"""
136 """test finding records with '$in','$nin' operators"""
137 hist = self.db.get_history()
137 hist = self.db.get_history()
138 even = hist[::2]
138 even = hist[::2]
139 odd = hist[1::2]
139 odd = hist[1::2]
140 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
140 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
141 found = [ r['msg_id'] for r in recs ]
141 found = [ r['msg_id'] for r in recs ]
142 self.assertEquals(set(even), set(found))
142 self.assertEquals(set(even), set(found))
143 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
143 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
144 found = [ r['msg_id'] for r in recs ]
144 found = [ r['msg_id'] for r in recs ]
145 self.assertEquals(set(odd), set(found))
145 self.assertEquals(set(odd), set(found))
146
146
147 def test_get_history(self):
147 def test_get_history(self):
148 msg_ids = self.db.get_history()
148 msg_ids = self.db.get_history()
149 latest = datetime(1984,1,1)
149 latest = datetime(1984,1,1)
150 for msg_id in msg_ids:
150 for msg_id in msg_ids:
151 rec = self.db.get_record(msg_id)
151 rec = self.db.get_record(msg_id)
152 newt = rec['submitted']
152 newt = rec['submitted']
153 self.assertTrue(newt >= latest)
153 self.assertTrue(newt >= latest)
154 latest = newt
154 latest = newt
155 msg_id = self.load_records(1)[-1]
155 msg_id = self.load_records(1)[-1]
156 self.assertEquals(self.db.get_history()[-1],msg_id)
156 self.assertEquals(self.db.get_history()[-1],msg_id)
157
157
158 def test_datetime(self):
158 def test_datetime(self):
159 """get/set timestamps with datetime objects"""
159 """get/set timestamps with datetime objects"""
160 msg_id = self.db.get_history()[-1]
160 msg_id = self.db.get_history()[-1]
161 rec = self.db.get_record(msg_id)
161 rec = self.db.get_record(msg_id)
162 self.assertTrue(isinstance(rec['submitted'], datetime))
162 self.assertTrue(isinstance(rec['submitted'], datetime))
163 self.db.update_record(msg_id, dict(completed=datetime.now()))
163 self.db.update_record(msg_id, dict(completed=datetime.now()))
164 rec = self.db.get_record(msg_id)
164 rec = self.db.get_record(msg_id)
165 self.assertTrue(isinstance(rec['completed'], datetime))
165 self.assertTrue(isinstance(rec['completed'], datetime))
166
166
167 def test_drop_matching(self):
167 def test_drop_matching(self):
168 msg_ids = self.load_records(10)
168 msg_ids = self.load_records(10)
169 query = {'msg_id' : {'$in':msg_ids}}
169 query = {'msg_id' : {'$in':msg_ids}}
170 self.db.drop_matching_records(query)
170 self.db.drop_matching_records(query)
171 recs = self.db.find_records(query)
171 recs = self.db.find_records(query)
172 self.assertEquals(len(recs), 0)
172 self.assertEquals(len(recs), 0)
173
173
174 def test_null(self):
174 def test_null(self):
175 """test None comparison queries"""
175 """test None comparison queries"""
176 msg_ids = self.load_records(10)
176 msg_ids = self.load_records(10)
177
177
178 query = {'msg_id' : None}
178 query = {'msg_id' : None}
179 recs = self.db.find_records(query)
179 recs = self.db.find_records(query)
180 self.assertEquals(len(recs), 0)
180 self.assertEquals(len(recs), 0)
181
181
182 query = {'msg_id' : {'$ne' : None}}
182 query = {'msg_id' : {'$ne' : None}}
183 recs = self.db.find_records(query)
183 recs = self.db.find_records(query)
184 self.assertTrue(len(recs) >= 10)
184 self.assertTrue(len(recs) >= 10)
185
185
186 def test_pop_safe_get(self):
187 """editing query results shouldn't affect record [get]"""
188 msg_id = self.db.get_history()[-1]
189 rec = self.db.get_record(msg_id)
190 rec.pop('buffers')
191 rec['garbage'] = 'hello'
192 rec2 = self.db.get_record(msg_id)
193 self.assertTrue('buffers' in rec2)
194 self.assertFalse('garbage' in rec2)
195
196 def test_pop_safe_find(self):
197 """editing query results shouldn't affect record [find]"""
198 msg_id = self.db.get_history()[-1]
199 rec = self.db.find_records({'msg_id' : msg_id})[0]
200 rec.pop('buffers')
201 rec['garbage'] = 'hello'
202 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
203 self.assertTrue('buffers' in rec2)
204 self.assertFalse('garbage' in rec2)
205
206 def test_pop_safe_find_keys(self):
207 """editing query results shouldn't affect record [find+keys]"""
208 msg_id = self.db.get_history()[-1]
209 rec = self.db.find_records({'msg_id' : msg_id}, keys=['buffers'])[0]
210 rec.pop('buffers')
211 rec['garbage'] = 'hello'
212 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
213 self.assertTrue('buffers' in rec2)
214 self.assertFalse('garbage' in rec2)
215
186
216
187 class TestSQLiteBackend(TestDictBackend):
217 class TestSQLiteBackend(TestDictBackend):
188
218
189 @dec.skip_without('sqlite3')
219 @dec.skip_without('sqlite3')
190 def create_db(self):
220 def create_db(self):
191 return SQLiteDB(location=tempfile.gettempdir())
221 return SQLiteDB(location=tempfile.gettempdir())
192
222
193 def tearDown(self):
223 def tearDown(self):
194 self.db._db.close()
224 self.db._db.close()
General Comments 0
You need to be logged in to leave comments. Login now