##// END OF EJS Templates
s/assertEquals/assertEqual (again)...
MinRK -
Show More
@@ -1,34 +1,34 b''
1 """Tests for the notebook manager."""
1 """Tests for the notebook manager."""
2
2
3 import os
3 import os
4 from unittest import TestCase
4 from unittest import TestCase
5 from tempfile import NamedTemporaryFile
5 from tempfile import NamedTemporaryFile
6
6
7 from IPython.utils.tempdir import TemporaryDirectory
7 from IPython.utils.tempdir import TemporaryDirectory
8 from IPython.utils.traitlets import TraitError
8 from IPython.utils.traitlets import TraitError
9
9
10 from IPython.frontend.html.notebook.filenbmanager import FileNotebookManager
10 from IPython.frontend.html.notebook.filenbmanager import FileNotebookManager
11
11
12 class TestNotebookManager(TestCase):
12 class TestNotebookManager(TestCase):
13
13
14 def test_nb_dir(self):
14 def test_nb_dir(self):
15 with TemporaryDirectory() as td:
15 with TemporaryDirectory() as td:
16 km = FileNotebookManager(notebook_dir=td)
16 km = FileNotebookManager(notebook_dir=td)
17 self.assertEquals(km.notebook_dir, td)
17 self.assertEqual(km.notebook_dir, td)
18
18
19 def test_create_nb_dir(self):
19 def test_create_nb_dir(self):
20 with TemporaryDirectory() as td:
20 with TemporaryDirectory() as td:
21 nbdir = os.path.join(td, 'notebooks')
21 nbdir = os.path.join(td, 'notebooks')
22 km = FileNotebookManager(notebook_dir=nbdir)
22 km = FileNotebookManager(notebook_dir=nbdir)
23 self.assertEquals(km.notebook_dir, nbdir)
23 self.assertEqual(km.notebook_dir, nbdir)
24
24
25 def test_missing_nb_dir(self):
25 def test_missing_nb_dir(self):
26 with TemporaryDirectory() as td:
26 with TemporaryDirectory() as td:
27 nbdir = os.path.join(td, 'notebook', 'dir', 'is', 'missing')
27 nbdir = os.path.join(td, 'notebook', 'dir', 'is', 'missing')
28 self.assertRaises(TraitError, FileNotebookManager, notebook_dir=nbdir)
28 self.assertRaises(TraitError, FileNotebookManager, notebook_dir=nbdir)
29
29
30 def test_invalid_nb_dir(self):
30 def test_invalid_nb_dir(self):
31 with NamedTemporaryFile() as tf:
31 with NamedTemporaryFile() as tf:
32 self.assertRaises(TraitError, FileNotebookManager, notebook_dir=tf.name)
32 self.assertRaises(TraitError, FileNotebookManager, notebook_dir=tf.name)
33
33
34
34
@@ -1,314 +1,314 b''
1 """Tests for db backends
1 """Tests for db backends
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import logging
21 import logging
22 import os
22 import os
23 import tempfile
23 import tempfile
24 import time
24 import time
25
25
26 from datetime import datetime, timedelta
26 from datetime import datetime, timedelta
27 from unittest import TestCase
27 from unittest import TestCase
28
28
29 from IPython.parallel import error
29 from IPython.parallel import error
30 from IPython.parallel.controller.dictdb import DictDB
30 from IPython.parallel.controller.dictdb import DictDB
31 from IPython.parallel.controller.sqlitedb import SQLiteDB
31 from IPython.parallel.controller.sqlitedb import SQLiteDB
32 from IPython.parallel.controller.hub import init_record, empty_record
32 from IPython.parallel.controller.hub import init_record, empty_record
33
33
34 from IPython.testing import decorators as dec
34 from IPython.testing import decorators as dec
35 from IPython.zmq.session import Session
35 from IPython.zmq.session import Session
36
36
37
37
38 #-------------------------------------------------------------------------------
38 #-------------------------------------------------------------------------------
39 # TestCases
39 # TestCases
40 #-------------------------------------------------------------------------------
40 #-------------------------------------------------------------------------------
41
41
42
42
43 def setup():
43 def setup():
44 global temp_db
44 global temp_db
45 temp_db = tempfile.NamedTemporaryFile(suffix='.db').name
45 temp_db = tempfile.NamedTemporaryFile(suffix='.db').name
46
46
47
47
48 class TaskDBTest:
48 class TaskDBTest:
49 def setUp(self):
49 def setUp(self):
50 self.session = Session()
50 self.session = Session()
51 self.db = self.create_db()
51 self.db = self.create_db()
52 self.load_records(16)
52 self.load_records(16)
53
53
54 def create_db(self):
54 def create_db(self):
55 raise NotImplementedError
55 raise NotImplementedError
56
56
57 def load_records(self, n=1, buffer_size=100):
57 def load_records(self, n=1, buffer_size=100):
58 """load n records for testing"""
58 """load n records for testing"""
59 #sleep 1/10 s, to ensure timestamp is different to previous calls
59 #sleep 1/10 s, to ensure timestamp is different to previous calls
60 time.sleep(0.1)
60 time.sleep(0.1)
61 msg_ids = []
61 msg_ids = []
62 for i in range(n):
62 for i in range(n):
63 msg = self.session.msg('apply_request', content=dict(a=5))
63 msg = self.session.msg('apply_request', content=dict(a=5))
64 msg['buffers'] = [os.urandom(buffer_size)]
64 msg['buffers'] = [os.urandom(buffer_size)]
65 rec = init_record(msg)
65 rec = init_record(msg)
66 msg_id = msg['header']['msg_id']
66 msg_id = msg['header']['msg_id']
67 msg_ids.append(msg_id)
67 msg_ids.append(msg_id)
68 self.db.add_record(msg_id, rec)
68 self.db.add_record(msg_id, rec)
69 return msg_ids
69 return msg_ids
70
70
71 def test_add_record(self):
71 def test_add_record(self):
72 before = self.db.get_history()
72 before = self.db.get_history()
73 self.load_records(5)
73 self.load_records(5)
74 after = self.db.get_history()
74 after = self.db.get_history()
75 self.assertEqual(len(after), len(before)+5)
75 self.assertEqual(len(after), len(before)+5)
76 self.assertEqual(after[:-5],before)
76 self.assertEqual(after[:-5],before)
77
77
78 def test_drop_record(self):
78 def test_drop_record(self):
79 msg_id = self.load_records()[-1]
79 msg_id = self.load_records()[-1]
80 rec = self.db.get_record(msg_id)
80 rec = self.db.get_record(msg_id)
81 self.db.drop_record(msg_id)
81 self.db.drop_record(msg_id)
82 self.assertRaises(KeyError,self.db.get_record, msg_id)
82 self.assertRaises(KeyError,self.db.get_record, msg_id)
83
83
84 def _round_to_millisecond(self, dt):
84 def _round_to_millisecond(self, dt):
85 """necessary because mongodb rounds microseconds"""
85 """necessary because mongodb rounds microseconds"""
86 micro = dt.microsecond
86 micro = dt.microsecond
87 extra = int(str(micro)[-3:])
87 extra = int(str(micro)[-3:])
88 return dt - timedelta(microseconds=extra)
88 return dt - timedelta(microseconds=extra)
89
89
90 def test_update_record(self):
90 def test_update_record(self):
91 now = self._round_to_millisecond(datetime.now())
91 now = self._round_to_millisecond(datetime.now())
92 #
92 #
93 msg_id = self.db.get_history()[-1]
93 msg_id = self.db.get_history()[-1]
94 rec1 = self.db.get_record(msg_id)
94 rec1 = self.db.get_record(msg_id)
95 data = {'stdout': 'hello there', 'completed' : now}
95 data = {'stdout': 'hello there', 'completed' : now}
96 self.db.update_record(msg_id, data)
96 self.db.update_record(msg_id, data)
97 rec2 = self.db.get_record(msg_id)
97 rec2 = self.db.get_record(msg_id)
98 self.assertEqual(rec2['stdout'], 'hello there')
98 self.assertEqual(rec2['stdout'], 'hello there')
99 self.assertEqual(rec2['completed'], now)
99 self.assertEqual(rec2['completed'], now)
100 rec1.update(data)
100 rec1.update(data)
101 self.assertEqual(rec1, rec2)
101 self.assertEqual(rec1, rec2)
102
102
103 # def test_update_record_bad(self):
103 # def test_update_record_bad(self):
104 # """test updating nonexistant records"""
104 # """test updating nonexistant records"""
105 # msg_id = str(uuid.uuid4())
105 # msg_id = str(uuid.uuid4())
106 # data = {'stdout': 'hello there'}
106 # data = {'stdout': 'hello there'}
107 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
107 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
108
108
109 def test_find_records_dt(self):
109 def test_find_records_dt(self):
110 """test finding records by date"""
110 """test finding records by date"""
111 hist = self.db.get_history()
111 hist = self.db.get_history()
112 middle = self.db.get_record(hist[len(hist)//2])
112 middle = self.db.get_record(hist[len(hist)//2])
113 tic = middle['submitted']
113 tic = middle['submitted']
114 before = self.db.find_records({'submitted' : {'$lt' : tic}})
114 before = self.db.find_records({'submitted' : {'$lt' : tic}})
115 after = self.db.find_records({'submitted' : {'$gte' : tic}})
115 after = self.db.find_records({'submitted' : {'$gte' : tic}})
116 self.assertEqual(len(before)+len(after),len(hist))
116 self.assertEqual(len(before)+len(after),len(hist))
117 for b in before:
117 for b in before:
118 self.assertTrue(b['submitted'] < tic)
118 self.assertTrue(b['submitted'] < tic)
119 for a in after:
119 for a in after:
120 self.assertTrue(a['submitted'] >= tic)
120 self.assertTrue(a['submitted'] >= tic)
121 same = self.db.find_records({'submitted' : tic})
121 same = self.db.find_records({'submitted' : tic})
122 for s in same:
122 for s in same:
123 self.assertTrue(s['submitted'] == tic)
123 self.assertTrue(s['submitted'] == tic)
124
124
125 def test_find_records_keys(self):
125 def test_find_records_keys(self):
126 """test extracting subset of record keys"""
126 """test extracting subset of record keys"""
127 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
127 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
128 for rec in found:
128 for rec in found:
129 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
129 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
130
130
131 def test_find_records_msg_id(self):
131 def test_find_records_msg_id(self):
132 """ensure msg_id is always in found records"""
132 """ensure msg_id is always in found records"""
133 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
133 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
134 for rec in found:
134 for rec in found:
135 self.assertTrue('msg_id' in rec.keys())
135 self.assertTrue('msg_id' in rec.keys())
136 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
136 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
137 for rec in found:
137 for rec in found:
138 self.assertTrue('msg_id' in rec.keys())
138 self.assertTrue('msg_id' in rec.keys())
139 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
139 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
140 for rec in found:
140 for rec in found:
141 self.assertTrue('msg_id' in rec.keys())
141 self.assertTrue('msg_id' in rec.keys())
142
142
143 def test_find_records_in(self):
143 def test_find_records_in(self):
144 """test finding records with '$in','$nin' operators"""
144 """test finding records with '$in','$nin' operators"""
145 hist = self.db.get_history()
145 hist = self.db.get_history()
146 even = hist[::2]
146 even = hist[::2]
147 odd = hist[1::2]
147 odd = hist[1::2]
148 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
148 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
149 found = [ r['msg_id'] for r in recs ]
149 found = [ r['msg_id'] for r in recs ]
150 self.assertEqual(set(even), set(found))
150 self.assertEqual(set(even), set(found))
151 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
151 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
152 found = [ r['msg_id'] for r in recs ]
152 found = [ r['msg_id'] for r in recs ]
153 self.assertEqual(set(odd), set(found))
153 self.assertEqual(set(odd), set(found))
154
154
155 def test_get_history(self):
155 def test_get_history(self):
156 msg_ids = self.db.get_history()
156 msg_ids = self.db.get_history()
157 latest = datetime(1984,1,1)
157 latest = datetime(1984,1,1)
158 for msg_id in msg_ids:
158 for msg_id in msg_ids:
159 rec = self.db.get_record(msg_id)
159 rec = self.db.get_record(msg_id)
160 newt = rec['submitted']
160 newt = rec['submitted']
161 self.assertTrue(newt >= latest)
161 self.assertTrue(newt >= latest)
162 latest = newt
162 latest = newt
163 msg_id = self.load_records(1)[-1]
163 msg_id = self.load_records(1)[-1]
164 self.assertEqual(self.db.get_history()[-1],msg_id)
164 self.assertEqual(self.db.get_history()[-1],msg_id)
165
165
166 def test_datetime(self):
166 def test_datetime(self):
167 """get/set timestamps with datetime objects"""
167 """get/set timestamps with datetime objects"""
168 msg_id = self.db.get_history()[-1]
168 msg_id = self.db.get_history()[-1]
169 rec = self.db.get_record(msg_id)
169 rec = self.db.get_record(msg_id)
170 self.assertTrue(isinstance(rec['submitted'], datetime))
170 self.assertTrue(isinstance(rec['submitted'], datetime))
171 self.db.update_record(msg_id, dict(completed=datetime.now()))
171 self.db.update_record(msg_id, dict(completed=datetime.now()))
172 rec = self.db.get_record(msg_id)
172 rec = self.db.get_record(msg_id)
173 self.assertTrue(isinstance(rec['completed'], datetime))
173 self.assertTrue(isinstance(rec['completed'], datetime))
174
174
175 def test_drop_matching(self):
175 def test_drop_matching(self):
176 msg_ids = self.load_records(10)
176 msg_ids = self.load_records(10)
177 query = {'msg_id' : {'$in':msg_ids}}
177 query = {'msg_id' : {'$in':msg_ids}}
178 self.db.drop_matching_records(query)
178 self.db.drop_matching_records(query)
179 recs = self.db.find_records(query)
179 recs = self.db.find_records(query)
180 self.assertEqual(len(recs), 0)
180 self.assertEqual(len(recs), 0)
181
181
182 def test_null(self):
182 def test_null(self):
183 """test None comparison queries"""
183 """test None comparison queries"""
184 msg_ids = self.load_records(10)
184 msg_ids = self.load_records(10)
185
185
186 query = {'msg_id' : None}
186 query = {'msg_id' : None}
187 recs = self.db.find_records(query)
187 recs = self.db.find_records(query)
188 self.assertEqual(len(recs), 0)
188 self.assertEqual(len(recs), 0)
189
189
190 query = {'msg_id' : {'$ne' : None}}
190 query = {'msg_id' : {'$ne' : None}}
191 recs = self.db.find_records(query)
191 recs = self.db.find_records(query)
192 self.assertTrue(len(recs) >= 10)
192 self.assertTrue(len(recs) >= 10)
193
193
194 def test_pop_safe_get(self):
194 def test_pop_safe_get(self):
195 """editing query results shouldn't affect record [get]"""
195 """editing query results shouldn't affect record [get]"""
196 msg_id = self.db.get_history()[-1]
196 msg_id = self.db.get_history()[-1]
197 rec = self.db.get_record(msg_id)
197 rec = self.db.get_record(msg_id)
198 rec.pop('buffers')
198 rec.pop('buffers')
199 rec['garbage'] = 'hello'
199 rec['garbage'] = 'hello'
200 rec['header']['msg_id'] = 'fubar'
200 rec['header']['msg_id'] = 'fubar'
201 rec2 = self.db.get_record(msg_id)
201 rec2 = self.db.get_record(msg_id)
202 self.assertTrue('buffers' in rec2)
202 self.assertTrue('buffers' in rec2)
203 self.assertFalse('garbage' in rec2)
203 self.assertFalse('garbage' in rec2)
204 self.assertEqual(rec2['header']['msg_id'], msg_id)
204 self.assertEqual(rec2['header']['msg_id'], msg_id)
205
205
206 def test_pop_safe_find(self):
206 def test_pop_safe_find(self):
207 """editing query results shouldn't affect record [find]"""
207 """editing query results shouldn't affect record [find]"""
208 msg_id = self.db.get_history()[-1]
208 msg_id = self.db.get_history()[-1]
209 rec = self.db.find_records({'msg_id' : msg_id})[0]
209 rec = self.db.find_records({'msg_id' : msg_id})[0]
210 rec.pop('buffers')
210 rec.pop('buffers')
211 rec['garbage'] = 'hello'
211 rec['garbage'] = 'hello'
212 rec['header']['msg_id'] = 'fubar'
212 rec['header']['msg_id'] = 'fubar'
213 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
213 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
214 self.assertTrue('buffers' in rec2)
214 self.assertTrue('buffers' in rec2)
215 self.assertFalse('garbage' in rec2)
215 self.assertFalse('garbage' in rec2)
216 self.assertEqual(rec2['header']['msg_id'], msg_id)
216 self.assertEqual(rec2['header']['msg_id'], msg_id)
217
217
218 def test_pop_safe_find_keys(self):
218 def test_pop_safe_find_keys(self):
219 """editing query results shouldn't affect record [find+keys]"""
219 """editing query results shouldn't affect record [find+keys]"""
220 msg_id = self.db.get_history()[-1]
220 msg_id = self.db.get_history()[-1]
221 rec = self.db.find_records({'msg_id' : msg_id}, keys=['buffers', 'header'])[0]
221 rec = self.db.find_records({'msg_id' : msg_id}, keys=['buffers', 'header'])[0]
222 rec.pop('buffers')
222 rec.pop('buffers')
223 rec['garbage'] = 'hello'
223 rec['garbage'] = 'hello'
224 rec['header']['msg_id'] = 'fubar'
224 rec['header']['msg_id'] = 'fubar'
225 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
225 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
226 self.assertTrue('buffers' in rec2)
226 self.assertTrue('buffers' in rec2)
227 self.assertFalse('garbage' in rec2)
227 self.assertFalse('garbage' in rec2)
228 self.assertEqual(rec2['header']['msg_id'], msg_id)
228 self.assertEqual(rec2['header']['msg_id'], msg_id)
229
229
230
230
231 class TestDictBackend(TaskDBTest, TestCase):
231 class TestDictBackend(TaskDBTest, TestCase):
232
232
233 def create_db(self):
233 def create_db(self):
234 return DictDB()
234 return DictDB()
235
235
236 def test_cull_count(self):
236 def test_cull_count(self):
237 self.db = self.create_db() # skip the load-records init from setUp
237 self.db = self.create_db() # skip the load-records init from setUp
238 self.db.record_limit = 20
238 self.db.record_limit = 20
239 self.db.cull_fraction = 0.2
239 self.db.cull_fraction = 0.2
240 self.load_records(20)
240 self.load_records(20)
241 self.assertEquals(len(self.db.get_history()), 20)
241 self.assertEqual(len(self.db.get_history()), 20)
242 self.load_records(1)
242 self.load_records(1)
243 # 0.2 * 20 = 4, 21 - 4 = 17
243 # 0.2 * 20 = 4, 21 - 4 = 17
244 self.assertEquals(len(self.db.get_history()), 17)
244 self.assertEqual(len(self.db.get_history()), 17)
245 self.load_records(3)
245 self.load_records(3)
246 self.assertEquals(len(self.db.get_history()), 20)
246 self.assertEqual(len(self.db.get_history()), 20)
247 self.load_records(1)
247 self.load_records(1)
248 self.assertEquals(len(self.db.get_history()), 17)
248 self.assertEqual(len(self.db.get_history()), 17)
249
249
250 for i in range(100):
250 for i in range(100):
251 self.load_records(1)
251 self.load_records(1)
252 self.assertTrue(len(self.db.get_history()) >= 17)
252 self.assertTrue(len(self.db.get_history()) >= 17)
253 self.assertTrue(len(self.db.get_history()) <= 20)
253 self.assertTrue(len(self.db.get_history()) <= 20)
254
254
255 def test_cull_size(self):
255 def test_cull_size(self):
256 self.db = self.create_db() # skip the load-records init from setUp
256 self.db = self.create_db() # skip the load-records init from setUp
257 self.db.size_limit = 1000
257 self.db.size_limit = 1000
258 self.db.cull_fraction = 0.2
258 self.db.cull_fraction = 0.2
259 self.load_records(100, buffer_size=10)
259 self.load_records(100, buffer_size=10)
260 self.assertEquals(len(self.db.get_history()), 100)
260 self.assertEqual(len(self.db.get_history()), 100)
261 self.load_records(1, buffer_size=0)
261 self.load_records(1, buffer_size=0)
262 self.assertEquals(len(self.db.get_history()), 101)
262 self.assertEqual(len(self.db.get_history()), 101)
263 self.load_records(1, buffer_size=1)
263 self.load_records(1, buffer_size=1)
264 # 0.2 * 100 = 20, 101 - 20 = 81
264 # 0.2 * 100 = 20, 101 - 20 = 81
265 self.assertEquals(len(self.db.get_history()), 81)
265 self.assertEqual(len(self.db.get_history()), 81)
266
266
267 def test_cull_size_drop(self):
267 def test_cull_size_drop(self):
268 """dropping records updates tracked buffer size"""
268 """dropping records updates tracked buffer size"""
269 self.db = self.create_db() # skip the load-records init from setUp
269 self.db = self.create_db() # skip the load-records init from setUp
270 self.db.size_limit = 1000
270 self.db.size_limit = 1000
271 self.db.cull_fraction = 0.2
271 self.db.cull_fraction = 0.2
272 self.load_records(100, buffer_size=10)
272 self.load_records(100, buffer_size=10)
273 self.assertEquals(len(self.db.get_history()), 100)
273 self.assertEqual(len(self.db.get_history()), 100)
274 self.db.drop_record(self.db.get_history()[-1])
274 self.db.drop_record(self.db.get_history()[-1])
275 self.assertEquals(len(self.db.get_history()), 99)
275 self.assertEqual(len(self.db.get_history()), 99)
276 self.load_records(1, buffer_size=5)
276 self.load_records(1, buffer_size=5)
277 self.assertEquals(len(self.db.get_history()), 100)
277 self.assertEqual(len(self.db.get_history()), 100)
278 self.load_records(1, buffer_size=5)
278 self.load_records(1, buffer_size=5)
279 self.assertEquals(len(self.db.get_history()), 101)
279 self.assertEqual(len(self.db.get_history()), 101)
280 self.load_records(1, buffer_size=1)
280 self.load_records(1, buffer_size=1)
281 self.assertEquals(len(self.db.get_history()), 81)
281 self.assertEqual(len(self.db.get_history()), 81)
282
282
283 def test_cull_size_update(self):
283 def test_cull_size_update(self):
284 """updating records updates tracked buffer size"""
284 """updating records updates tracked buffer size"""
285 self.db = self.create_db() # skip the load-records init from setUp
285 self.db = self.create_db() # skip the load-records init from setUp
286 self.db.size_limit = 1000
286 self.db.size_limit = 1000
287 self.db.cull_fraction = 0.2
287 self.db.cull_fraction = 0.2
288 self.load_records(100, buffer_size=10)
288 self.load_records(100, buffer_size=10)
289 self.assertEquals(len(self.db.get_history()), 100)
289 self.assertEqual(len(self.db.get_history()), 100)
290 msg_id = self.db.get_history()[-1]
290 msg_id = self.db.get_history()[-1]
291 self.db.update_record(msg_id, dict(result_buffers = [os.urandom(10)], buffers=[]))
291 self.db.update_record(msg_id, dict(result_buffers = [os.urandom(10)], buffers=[]))
292 self.assertEquals(len(self.db.get_history()), 100)
292 self.assertEqual(len(self.db.get_history()), 100)
293 self.db.update_record(msg_id, dict(result_buffers = [os.urandom(11)], buffers=[]))
293 self.db.update_record(msg_id, dict(result_buffers = [os.urandom(11)], buffers=[]))
294 self.assertEquals(len(self.db.get_history()), 79)
294 self.assertEqual(len(self.db.get_history()), 79)
295
295
296 class TestSQLiteBackend(TaskDBTest, TestCase):
296 class TestSQLiteBackend(TaskDBTest, TestCase):
297
297
298 @dec.skip_without('sqlite3')
298 @dec.skip_without('sqlite3')
299 def create_db(self):
299 def create_db(self):
300 location, fname = os.path.split(temp_db)
300 location, fname = os.path.split(temp_db)
301 log = logging.getLogger('test')
301 log = logging.getLogger('test')
302 log.setLevel(logging.CRITICAL)
302 log.setLevel(logging.CRITICAL)
303 return SQLiteDB(location=location, fname=fname, log=log)
303 return SQLiteDB(location=location, fname=fname, log=log)
304
304
305 def tearDown(self):
305 def tearDown(self):
306 self.db._db.close()
306 self.db._db.close()
307
307
308
308
309 def teardown():
309 def teardown():
310 """cleanup task db file after all tests have run"""
310 """cleanup task db file after all tests have run"""
311 try:
311 try:
312 os.remove(temp_db)
312 os.remove(temp_db)
313 except:
313 except:
314 pass
314 pass
@@ -1,678 +1,678 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import platform
20 import platform
21 import time
21 import time
22 from tempfile import mktemp
22 from tempfile import mktemp
23 from StringIO import StringIO
23 from StringIO import StringIO
24
24
25 import zmq
25 import zmq
26 from nose import SkipTest
26 from nose import SkipTest
27
27
28 from IPython.testing import decorators as dec
28 from IPython.testing import decorators as dec
29 from IPython.testing.ipunittest import ParametricTestCase
29 from IPython.testing.ipunittest import ParametricTestCase
30
30
31 from IPython import parallel as pmod
31 from IPython import parallel as pmod
32 from IPython.parallel import error
32 from IPython.parallel import error
33 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
33 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
34 from IPython.parallel import DirectView
34 from IPython.parallel import DirectView
35 from IPython.parallel.util import interactive
35 from IPython.parallel.util import interactive
36
36
37 from IPython.parallel.tests import add_engines
37 from IPython.parallel.tests import add_engines
38
38
39 from .clienttest import ClusterTestCase, crash, wait, skip_without
39 from .clienttest import ClusterTestCase, crash, wait, skip_without
40
40
41 def setup():
41 def setup():
42 add_engines(3, total=True)
42 add_engines(3, total=True)
43
43
44 class TestView(ClusterTestCase, ParametricTestCase):
44 class TestView(ClusterTestCase, ParametricTestCase):
45
45
46 def setUp(self):
46 def setUp(self):
47 # On Win XP, wait for resource cleanup, else parallel test group fails
47 # On Win XP, wait for resource cleanup, else parallel test group fails
48 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
48 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
49 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
49 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
50 time.sleep(2)
50 time.sleep(2)
51 super(TestView, self).setUp()
51 super(TestView, self).setUp()
52
52
53 def test_z_crash_mux(self):
53 def test_z_crash_mux(self):
54 """test graceful handling of engine death (direct)"""
54 """test graceful handling of engine death (direct)"""
55 raise SkipTest("crash tests disabled, due to undesirable crash reports")
55 raise SkipTest("crash tests disabled, due to undesirable crash reports")
56 # self.add_engines(1)
56 # self.add_engines(1)
57 eid = self.client.ids[-1]
57 eid = self.client.ids[-1]
58 ar = self.client[eid].apply_async(crash)
58 ar = self.client[eid].apply_async(crash)
59 self.assertRaisesRemote(error.EngineError, ar.get, 10)
59 self.assertRaisesRemote(error.EngineError, ar.get, 10)
60 eid = ar.engine_id
60 eid = ar.engine_id
61 tic = time.time()
61 tic = time.time()
62 while eid in self.client.ids and time.time()-tic < 5:
62 while eid in self.client.ids and time.time()-tic < 5:
63 time.sleep(.01)
63 time.sleep(.01)
64 self.client.spin()
64 self.client.spin()
65 self.assertFalse(eid in self.client.ids, "Engine should have died")
65 self.assertFalse(eid in self.client.ids, "Engine should have died")
66
66
67 def test_push_pull(self):
67 def test_push_pull(self):
68 """test pushing and pulling"""
68 """test pushing and pulling"""
69 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
69 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
70 t = self.client.ids[-1]
70 t = self.client.ids[-1]
71 v = self.client[t]
71 v = self.client[t]
72 push = v.push
72 push = v.push
73 pull = v.pull
73 pull = v.pull
74 v.block=True
74 v.block=True
75 nengines = len(self.client)
75 nengines = len(self.client)
76 push({'data':data})
76 push({'data':data})
77 d = pull('data')
77 d = pull('data')
78 self.assertEqual(d, data)
78 self.assertEqual(d, data)
79 self.client[:].push({'data':data})
79 self.client[:].push({'data':data})
80 d = self.client[:].pull('data', block=True)
80 d = self.client[:].pull('data', block=True)
81 self.assertEqual(d, nengines*[data])
81 self.assertEqual(d, nengines*[data])
82 ar = push({'data':data}, block=False)
82 ar = push({'data':data}, block=False)
83 self.assertTrue(isinstance(ar, AsyncResult))
83 self.assertTrue(isinstance(ar, AsyncResult))
84 r = ar.get()
84 r = ar.get()
85 ar = self.client[:].pull('data', block=False)
85 ar = self.client[:].pull('data', block=False)
86 self.assertTrue(isinstance(ar, AsyncResult))
86 self.assertTrue(isinstance(ar, AsyncResult))
87 r = ar.get()
87 r = ar.get()
88 self.assertEqual(r, nengines*[data])
88 self.assertEqual(r, nengines*[data])
89 self.client[:].push(dict(a=10,b=20))
89 self.client[:].push(dict(a=10,b=20))
90 r = self.client[:].pull(('a','b'), block=True)
90 r = self.client[:].pull(('a','b'), block=True)
91 self.assertEqual(r, nengines*[[10,20]])
91 self.assertEqual(r, nengines*[[10,20]])
92
92
93 def test_push_pull_function(self):
93 def test_push_pull_function(self):
94 "test pushing and pulling functions"
94 "test pushing and pulling functions"
95 def testf(x):
95 def testf(x):
96 return 2.0*x
96 return 2.0*x
97
97
98 t = self.client.ids[-1]
98 t = self.client.ids[-1]
99 v = self.client[t]
99 v = self.client[t]
100 v.block=True
100 v.block=True
101 push = v.push
101 push = v.push
102 pull = v.pull
102 pull = v.pull
103 execute = v.execute
103 execute = v.execute
104 push({'testf':testf})
104 push({'testf':testf})
105 r = pull('testf')
105 r = pull('testf')
106 self.assertEqual(r(1.0), testf(1.0))
106 self.assertEqual(r(1.0), testf(1.0))
107 execute('r = testf(10)')
107 execute('r = testf(10)')
108 r = pull('r')
108 r = pull('r')
109 self.assertEqual(r, testf(10))
109 self.assertEqual(r, testf(10))
110 ar = self.client[:].push({'testf':testf}, block=False)
110 ar = self.client[:].push({'testf':testf}, block=False)
111 ar.get()
111 ar.get()
112 ar = self.client[:].pull('testf', block=False)
112 ar = self.client[:].pull('testf', block=False)
113 rlist = ar.get()
113 rlist = ar.get()
114 for r in rlist:
114 for r in rlist:
115 self.assertEqual(r(1.0), testf(1.0))
115 self.assertEqual(r(1.0), testf(1.0))
116 execute("def g(x): return x*x")
116 execute("def g(x): return x*x")
117 r = pull(('testf','g'))
117 r = pull(('testf','g'))
118 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
118 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
119
119
120 def test_push_function_globals(self):
120 def test_push_function_globals(self):
121 """test that pushed functions have access to globals"""
121 """test that pushed functions have access to globals"""
122 @interactive
122 @interactive
123 def geta():
123 def geta():
124 return a
124 return a
125 # self.add_engines(1)
125 # self.add_engines(1)
126 v = self.client[-1]
126 v = self.client[-1]
127 v.block=True
127 v.block=True
128 v['f'] = geta
128 v['f'] = geta
129 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
129 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
130 v.execute('a=5')
130 v.execute('a=5')
131 v.execute('b=f()')
131 v.execute('b=f()')
132 self.assertEqual(v['b'], 5)
132 self.assertEqual(v['b'], 5)
133
133
134 def test_push_function_defaults(self):
134 def test_push_function_defaults(self):
135 """test that pushed functions preserve default args"""
135 """test that pushed functions preserve default args"""
136 def echo(a=10):
136 def echo(a=10):
137 return a
137 return a
138 v = self.client[-1]
138 v = self.client[-1]
139 v.block=True
139 v.block=True
140 v['f'] = echo
140 v['f'] = echo
141 v.execute('b=f()')
141 v.execute('b=f()')
142 self.assertEqual(v['b'], 10)
142 self.assertEqual(v['b'], 10)
143
143
144 def test_get_result(self):
144 def test_get_result(self):
145 """test getting results from the Hub."""
145 """test getting results from the Hub."""
146 c = pmod.Client(profile='iptest')
146 c = pmod.Client(profile='iptest')
147 # self.add_engines(1)
147 # self.add_engines(1)
148 t = c.ids[-1]
148 t = c.ids[-1]
149 v = c[t]
149 v = c[t]
150 v2 = self.client[t]
150 v2 = self.client[t]
151 ar = v.apply_async(wait, 1)
151 ar = v.apply_async(wait, 1)
152 # give the monitor time to notice the message
152 # give the monitor time to notice the message
153 time.sleep(.25)
153 time.sleep(.25)
154 ahr = v2.get_result(ar.msg_ids)
154 ahr = v2.get_result(ar.msg_ids)
155 self.assertTrue(isinstance(ahr, AsyncHubResult))
155 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 self.assertEqual(ahr.get(), ar.get())
156 self.assertEqual(ahr.get(), ar.get())
157 ar2 = v2.get_result(ar.msg_ids)
157 ar2 = v2.get_result(ar.msg_ids)
158 self.assertFalse(isinstance(ar2, AsyncHubResult))
158 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 c.spin()
159 c.spin()
160 c.close()
160 c.close()
161
161
162 def test_run_newline(self):
162 def test_run_newline(self):
163 """test that run appends newline to files"""
163 """test that run appends newline to files"""
164 tmpfile = mktemp()
164 tmpfile = mktemp()
165 with open(tmpfile, 'w') as f:
165 with open(tmpfile, 'w') as f:
166 f.write("""def g():
166 f.write("""def g():
167 return 5
167 return 5
168 """)
168 """)
169 v = self.client[-1]
169 v = self.client[-1]
170 v.run(tmpfile, block=True)
170 v.run(tmpfile, block=True)
171 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
171 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
172
172
173 def test_apply_tracked(self):
173 def test_apply_tracked(self):
174 """test tracking for apply"""
174 """test tracking for apply"""
175 # self.add_engines(1)
175 # self.add_engines(1)
176 t = self.client.ids[-1]
176 t = self.client.ids[-1]
177 v = self.client[t]
177 v = self.client[t]
178 v.block=False
178 v.block=False
179 def echo(n=1024*1024, **kwargs):
179 def echo(n=1024*1024, **kwargs):
180 with v.temp_flags(**kwargs):
180 with v.temp_flags(**kwargs):
181 return v.apply(lambda x: x, 'x'*n)
181 return v.apply(lambda x: x, 'x'*n)
182 ar = echo(1, track=False)
182 ar = echo(1, track=False)
183 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
183 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
184 self.assertTrue(ar.sent)
184 self.assertTrue(ar.sent)
185 ar = echo(track=True)
185 ar = echo(track=True)
186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertEqual(ar.sent, ar._tracker.done)
187 self.assertEqual(ar.sent, ar._tracker.done)
188 ar._tracker.wait()
188 ar._tracker.wait()
189 self.assertTrue(ar.sent)
189 self.assertTrue(ar.sent)
190
190
191 def test_push_tracked(self):
191 def test_push_tracked(self):
192 t = self.client.ids[-1]
192 t = self.client.ids[-1]
193 ns = dict(x='x'*1024*1024)
193 ns = dict(x='x'*1024*1024)
194 v = self.client[t]
194 v = self.client[t]
195 ar = v.push(ns, block=False, track=False)
195 ar = v.push(ns, block=False, track=False)
196 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
196 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
197 self.assertTrue(ar.sent)
197 self.assertTrue(ar.sent)
198
198
199 ar = v.push(ns, block=False, track=True)
199 ar = v.push(ns, block=False, track=True)
200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 ar._tracker.wait()
201 ar._tracker.wait()
202 self.assertEqual(ar.sent, ar._tracker.done)
202 self.assertEqual(ar.sent, ar._tracker.done)
203 self.assertTrue(ar.sent)
203 self.assertTrue(ar.sent)
204 ar.get()
204 ar.get()
205
205
206 def test_scatter_tracked(self):
206 def test_scatter_tracked(self):
207 t = self.client.ids
207 t = self.client.ids
208 x='x'*1024*1024
208 x='x'*1024*1024
209 ar = self.client[t].scatter('x', x, block=False, track=False)
209 ar = self.client[t].scatter('x', x, block=False, track=False)
210 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
210 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
211 self.assertTrue(ar.sent)
211 self.assertTrue(ar.sent)
212
212
213 ar = self.client[t].scatter('x', x, block=False, track=True)
213 ar = self.client[t].scatter('x', x, block=False, track=True)
214 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
214 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
215 self.assertEqual(ar.sent, ar._tracker.done)
215 self.assertEqual(ar.sent, ar._tracker.done)
216 ar._tracker.wait()
216 ar._tracker.wait()
217 self.assertTrue(ar.sent)
217 self.assertTrue(ar.sent)
218 ar.get()
218 ar.get()
219
219
220 def test_remote_reference(self):
220 def test_remote_reference(self):
221 v = self.client[-1]
221 v = self.client[-1]
222 v['a'] = 123
222 v['a'] = 123
223 ra = pmod.Reference('a')
223 ra = pmod.Reference('a')
224 b = v.apply_sync(lambda x: x, ra)
224 b = v.apply_sync(lambda x: x, ra)
225 self.assertEqual(b, 123)
225 self.assertEqual(b, 123)
226
226
227
227
228 def test_scatter_gather(self):
228 def test_scatter_gather(self):
229 view = self.client[:]
229 view = self.client[:]
230 seq1 = range(16)
230 seq1 = range(16)
231 view.scatter('a', seq1)
231 view.scatter('a', seq1)
232 seq2 = view.gather('a', block=True)
232 seq2 = view.gather('a', block=True)
233 self.assertEqual(seq2, seq1)
233 self.assertEqual(seq2, seq1)
234 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
234 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
235
235
236 @skip_without('numpy')
236 @skip_without('numpy')
237 def test_scatter_gather_numpy(self):
237 def test_scatter_gather_numpy(self):
238 import numpy
238 import numpy
239 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
239 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
240 view = self.client[:]
240 view = self.client[:]
241 a = numpy.arange(64)
241 a = numpy.arange(64)
242 view.scatter('a', a, block=True)
242 view.scatter('a', a, block=True)
243 b = view.gather('a', block=True)
243 b = view.gather('a', block=True)
244 assert_array_equal(b, a)
244 assert_array_equal(b, a)
245
245
246 def test_scatter_gather_lazy(self):
246 def test_scatter_gather_lazy(self):
247 """scatter/gather with targets='all'"""
247 """scatter/gather with targets='all'"""
248 view = self.client.direct_view(targets='all')
248 view = self.client.direct_view(targets='all')
249 x = range(64)
249 x = range(64)
250 view.scatter('x', x)
250 view.scatter('x', x)
251 gathered = view.gather('x', block=True)
251 gathered = view.gather('x', block=True)
252 self.assertEqual(gathered, x)
252 self.assertEqual(gathered, x)
253
253
254
254
255 @dec.known_failure_py3
255 @dec.known_failure_py3
256 @skip_without('numpy')
256 @skip_without('numpy')
257 def test_push_numpy_nocopy(self):
257 def test_push_numpy_nocopy(self):
258 import numpy
258 import numpy
259 view = self.client[:]
259 view = self.client[:]
260 a = numpy.arange(64)
260 a = numpy.arange(64)
261 view['A'] = a
261 view['A'] = a
262 @interactive
262 @interactive
263 def check_writeable(x):
263 def check_writeable(x):
264 return x.flags.writeable
264 return x.flags.writeable
265
265
266 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
266 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
267 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
267 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
268
268
269 view.push(dict(B=a))
269 view.push(dict(B=a))
270 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
270 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
271 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
271 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
272
272
273 @skip_without('numpy')
273 @skip_without('numpy')
274 def test_apply_numpy(self):
274 def test_apply_numpy(self):
275 """view.apply(f, ndarray)"""
275 """view.apply(f, ndarray)"""
276 import numpy
276 import numpy
277 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
277 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
278
278
279 A = numpy.random.random((100,100))
279 A = numpy.random.random((100,100))
280 view = self.client[-1]
280 view = self.client[-1]
281 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
281 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
282 B = A.astype(dt)
282 B = A.astype(dt)
283 C = view.apply_sync(lambda x:x, B)
283 C = view.apply_sync(lambda x:x, B)
284 assert_array_equal(B,C)
284 assert_array_equal(B,C)
285
285
286 @skip_without('numpy')
286 @skip_without('numpy')
287 def test_push_pull_recarray(self):
287 def test_push_pull_recarray(self):
288 """push/pull recarrays"""
288 """push/pull recarrays"""
289 import numpy
289 import numpy
290 from numpy.testing.utils import assert_array_equal
290 from numpy.testing.utils import assert_array_equal
291
291
292 view = self.client[-1]
292 view = self.client[-1]
293
293
294 R = numpy.array([
294 R = numpy.array([
295 (1, 'hi', 0.),
295 (1, 'hi', 0.),
296 (2**30, 'there', 2.5),
296 (2**30, 'there', 2.5),
297 (-99999, 'world', -12345.6789),
297 (-99999, 'world', -12345.6789),
298 ], [('n', int), ('s', '|S10'), ('f', float)])
298 ], [('n', int), ('s', '|S10'), ('f', float)])
299
299
300 view['RR'] = R
300 view['RR'] = R
301 R2 = view['RR']
301 R2 = view['RR']
302
302
303 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
303 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
304 self.assertEqual(r_dtype, R.dtype)
304 self.assertEqual(r_dtype, R.dtype)
305 self.assertEqual(r_shape, R.shape)
305 self.assertEqual(r_shape, R.shape)
306 self.assertEqual(R2.dtype, R.dtype)
306 self.assertEqual(R2.dtype, R.dtype)
307 self.assertEqual(R2.shape, R.shape)
307 self.assertEqual(R2.shape, R.shape)
308 assert_array_equal(R2, R)
308 assert_array_equal(R2, R)
309
309
310 def test_map(self):
310 def test_map(self):
311 view = self.client[:]
311 view = self.client[:]
312 def f(x):
312 def f(x):
313 return x**2
313 return x**2
314 data = range(16)
314 data = range(16)
315 r = view.map_sync(f, data)
315 r = view.map_sync(f, data)
316 self.assertEqual(r, map(f, data))
316 self.assertEqual(r, map(f, data))
317
317
318 def test_map_iterable(self):
318 def test_map_iterable(self):
319 """test map on iterables (direct)"""
319 """test map on iterables (direct)"""
320 view = self.client[:]
320 view = self.client[:]
321 # 101 is prime, so it won't be evenly distributed
321 # 101 is prime, so it won't be evenly distributed
322 arr = range(101)
322 arr = range(101)
323 # ensure it will be an iterator, even in Python 3
323 # ensure it will be an iterator, even in Python 3
324 it = iter(arr)
324 it = iter(arr)
325 r = view.map_sync(lambda x:x, arr)
325 r = view.map_sync(lambda x:x, arr)
326 self.assertEqual(r, list(arr))
326 self.assertEqual(r, list(arr))
327
327
328 def test_scatter_gather_nonblocking(self):
328 def test_scatter_gather_nonblocking(self):
329 data = range(16)
329 data = range(16)
330 view = self.client[:]
330 view = self.client[:]
331 view.scatter('a', data, block=False)
331 view.scatter('a', data, block=False)
332 ar = view.gather('a', block=False)
332 ar = view.gather('a', block=False)
333 self.assertEqual(ar.get(), data)
333 self.assertEqual(ar.get(), data)
334
334
335 @skip_without('numpy')
335 @skip_without('numpy')
336 def test_scatter_gather_numpy_nonblocking(self):
336 def test_scatter_gather_numpy_nonblocking(self):
337 import numpy
337 import numpy
338 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
338 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
339 a = numpy.arange(64)
339 a = numpy.arange(64)
340 view = self.client[:]
340 view = self.client[:]
341 ar = view.scatter('a', a, block=False)
341 ar = view.scatter('a', a, block=False)
342 self.assertTrue(isinstance(ar, AsyncResult))
342 self.assertTrue(isinstance(ar, AsyncResult))
343 amr = view.gather('a', block=False)
343 amr = view.gather('a', block=False)
344 self.assertTrue(isinstance(amr, AsyncMapResult))
344 self.assertTrue(isinstance(amr, AsyncMapResult))
345 assert_array_equal(amr.get(), a)
345 assert_array_equal(amr.get(), a)
346
346
347 def test_execute(self):
347 def test_execute(self):
348 view = self.client[:]
348 view = self.client[:]
349 # self.client.debug=True
349 # self.client.debug=True
350 execute = view.execute
350 execute = view.execute
351 ar = execute('c=30', block=False)
351 ar = execute('c=30', block=False)
352 self.assertTrue(isinstance(ar, AsyncResult))
352 self.assertTrue(isinstance(ar, AsyncResult))
353 ar = execute('d=[0,1,2]', block=False)
353 ar = execute('d=[0,1,2]', block=False)
354 self.client.wait(ar, 1)
354 self.client.wait(ar, 1)
355 self.assertEqual(len(ar.get()), len(self.client))
355 self.assertEqual(len(ar.get()), len(self.client))
356 for c in view['c']:
356 for c in view['c']:
357 self.assertEqual(c, 30)
357 self.assertEqual(c, 30)
358
358
359 def test_abort(self):
359 def test_abort(self):
360 view = self.client[-1]
360 view = self.client[-1]
361 ar = view.execute('import time; time.sleep(1)', block=False)
361 ar = view.execute('import time; time.sleep(1)', block=False)
362 ar2 = view.apply_async(lambda : 2)
362 ar2 = view.apply_async(lambda : 2)
363 ar3 = view.apply_async(lambda : 3)
363 ar3 = view.apply_async(lambda : 3)
364 view.abort(ar2)
364 view.abort(ar2)
365 view.abort(ar3.msg_ids)
365 view.abort(ar3.msg_ids)
366 self.assertRaises(error.TaskAborted, ar2.get)
366 self.assertRaises(error.TaskAborted, ar2.get)
367 self.assertRaises(error.TaskAborted, ar3.get)
367 self.assertRaises(error.TaskAborted, ar3.get)
368
368
369 def test_abort_all(self):
369 def test_abort_all(self):
370 """view.abort() aborts all outstanding tasks"""
370 """view.abort() aborts all outstanding tasks"""
371 view = self.client[-1]
371 view = self.client[-1]
372 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
372 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
373 view.abort()
373 view.abort()
374 view.wait(timeout=5)
374 view.wait(timeout=5)
375 for ar in ars[5:]:
375 for ar in ars[5:]:
376 self.assertRaises(error.TaskAborted, ar.get)
376 self.assertRaises(error.TaskAborted, ar.get)
377
377
378 def test_temp_flags(self):
378 def test_temp_flags(self):
379 view = self.client[-1]
379 view = self.client[-1]
380 view.block=True
380 view.block=True
381 with view.temp_flags(block=False):
381 with view.temp_flags(block=False):
382 self.assertFalse(view.block)
382 self.assertFalse(view.block)
383 self.assertTrue(view.block)
383 self.assertTrue(view.block)
384
384
385 @dec.known_failure_py3
385 @dec.known_failure_py3
386 def test_importer(self):
386 def test_importer(self):
387 view = self.client[-1]
387 view = self.client[-1]
388 view.clear(block=True)
388 view.clear(block=True)
389 with view.importer:
389 with view.importer:
390 import re
390 import re
391
391
392 @interactive
392 @interactive
393 def findall(pat, s):
393 def findall(pat, s):
394 # this globals() step isn't necessary in real code
394 # this globals() step isn't necessary in real code
395 # only to prevent a closure in the test
395 # only to prevent a closure in the test
396 re = globals()['re']
396 re = globals()['re']
397 return re.findall(pat, s)
397 return re.findall(pat, s)
398
398
399 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
399 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
400
400
401 def test_unicode_execute(self):
401 def test_unicode_execute(self):
402 """test executing unicode strings"""
402 """test executing unicode strings"""
403 v = self.client[-1]
403 v = self.client[-1]
404 v.block=True
404 v.block=True
405 if sys.version_info[0] >= 3:
405 if sys.version_info[0] >= 3:
406 code="a='é'"
406 code="a='é'"
407 else:
407 else:
408 code=u"a=u'é'"
408 code=u"a=u'é'"
409 v.execute(code)
409 v.execute(code)
410 self.assertEqual(v['a'], u'é')
410 self.assertEqual(v['a'], u'é')
411
411
412 def test_unicode_apply_result(self):
412 def test_unicode_apply_result(self):
413 """test unicode apply results"""
413 """test unicode apply results"""
414 v = self.client[-1]
414 v = self.client[-1]
415 r = v.apply_sync(lambda : u'é')
415 r = v.apply_sync(lambda : u'é')
416 self.assertEqual(r, u'é')
416 self.assertEqual(r, u'é')
417
417
418 def test_unicode_apply_arg(self):
418 def test_unicode_apply_arg(self):
419 """test passing unicode arguments to apply"""
419 """test passing unicode arguments to apply"""
420 v = self.client[-1]
420 v = self.client[-1]
421
421
422 @interactive
422 @interactive
423 def check_unicode(a, check):
423 def check_unicode(a, check):
424 assert isinstance(a, unicode), "%r is not unicode"%a
424 assert isinstance(a, unicode), "%r is not unicode"%a
425 assert isinstance(check, bytes), "%r is not bytes"%check
425 assert isinstance(check, bytes), "%r is not bytes"%check
426 assert a.encode('utf8') == check, "%s != %s"%(a,check)
426 assert a.encode('utf8') == check, "%s != %s"%(a,check)
427
427
428 for s in [ u'é', u'ßø®∫',u'asdf' ]:
428 for s in [ u'é', u'ßø®∫',u'asdf' ]:
429 try:
429 try:
430 v.apply_sync(check_unicode, s, s.encode('utf8'))
430 v.apply_sync(check_unicode, s, s.encode('utf8'))
431 except error.RemoteError as e:
431 except error.RemoteError as e:
432 if e.ename == 'AssertionError':
432 if e.ename == 'AssertionError':
433 self.fail(e.evalue)
433 self.fail(e.evalue)
434 else:
434 else:
435 raise e
435 raise e
436
436
437 def test_map_reference(self):
437 def test_map_reference(self):
438 """view.map(<Reference>, *seqs) should work"""
438 """view.map(<Reference>, *seqs) should work"""
439 v = self.client[:]
439 v = self.client[:]
440 v.scatter('n', self.client.ids, flatten=True)
440 v.scatter('n', self.client.ids, flatten=True)
441 v.execute("f = lambda x,y: x*y")
441 v.execute("f = lambda x,y: x*y")
442 rf = pmod.Reference('f')
442 rf = pmod.Reference('f')
443 nlist = list(range(10))
443 nlist = list(range(10))
444 mlist = nlist[::-1]
444 mlist = nlist[::-1]
445 expected = [ m*n for m,n in zip(mlist, nlist) ]
445 expected = [ m*n for m,n in zip(mlist, nlist) ]
446 result = v.map_sync(rf, mlist, nlist)
446 result = v.map_sync(rf, mlist, nlist)
447 self.assertEqual(result, expected)
447 self.assertEqual(result, expected)
448
448
449 def test_apply_reference(self):
449 def test_apply_reference(self):
450 """view.apply(<Reference>, *args) should work"""
450 """view.apply(<Reference>, *args) should work"""
451 v = self.client[:]
451 v = self.client[:]
452 v.scatter('n', self.client.ids, flatten=True)
452 v.scatter('n', self.client.ids, flatten=True)
453 v.execute("f = lambda x: n*x")
453 v.execute("f = lambda x: n*x")
454 rf = pmod.Reference('f')
454 rf = pmod.Reference('f')
455 result = v.apply_sync(rf, 5)
455 result = v.apply_sync(rf, 5)
456 expected = [ 5*id for id in self.client.ids ]
456 expected = [ 5*id for id in self.client.ids ]
457 self.assertEqual(result, expected)
457 self.assertEqual(result, expected)
458
458
459 def test_eval_reference(self):
459 def test_eval_reference(self):
460 v = self.client[self.client.ids[0]]
460 v = self.client[self.client.ids[0]]
461 v['g'] = range(5)
461 v['g'] = range(5)
462 rg = pmod.Reference('g[0]')
462 rg = pmod.Reference('g[0]')
463 echo = lambda x:x
463 echo = lambda x:x
464 self.assertEqual(v.apply_sync(echo, rg), 0)
464 self.assertEqual(v.apply_sync(echo, rg), 0)
465
465
466 def test_reference_nameerror(self):
466 def test_reference_nameerror(self):
467 v = self.client[self.client.ids[0]]
467 v = self.client[self.client.ids[0]]
468 r = pmod.Reference('elvis_has_left')
468 r = pmod.Reference('elvis_has_left')
469 echo = lambda x:x
469 echo = lambda x:x
470 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
470 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
471
471
472 def test_single_engine_map(self):
472 def test_single_engine_map(self):
473 e0 = self.client[self.client.ids[0]]
473 e0 = self.client[self.client.ids[0]]
474 r = range(5)
474 r = range(5)
475 check = [ -1*i for i in r ]
475 check = [ -1*i for i in r ]
476 result = e0.map_sync(lambda x: -1*x, r)
476 result = e0.map_sync(lambda x: -1*x, r)
477 self.assertEqual(result, check)
477 self.assertEqual(result, check)
478
478
479 def test_len(self):
479 def test_len(self):
480 """len(view) makes sense"""
480 """len(view) makes sense"""
481 e0 = self.client[self.client.ids[0]]
481 e0 = self.client[self.client.ids[0]]
482 yield self.assertEqual(len(e0), 1)
482 yield self.assertEqual(len(e0), 1)
483 v = self.client[:]
483 v = self.client[:]
484 yield self.assertEqual(len(v), len(self.client.ids))
484 yield self.assertEqual(len(v), len(self.client.ids))
485 v = self.client.direct_view('all')
485 v = self.client.direct_view('all')
486 yield self.assertEqual(len(v), len(self.client.ids))
486 yield self.assertEqual(len(v), len(self.client.ids))
487 v = self.client[:2]
487 v = self.client[:2]
488 yield self.assertEqual(len(v), 2)
488 yield self.assertEqual(len(v), 2)
489 v = self.client[:1]
489 v = self.client[:1]
490 yield self.assertEqual(len(v), 1)
490 yield self.assertEqual(len(v), 1)
491 v = self.client.load_balanced_view()
491 v = self.client.load_balanced_view()
492 yield self.assertEqual(len(v), len(self.client.ids))
492 yield self.assertEqual(len(v), len(self.client.ids))
493 # parametric tests seem to require manual closing?
493 # parametric tests seem to require manual closing?
494 self.client.close()
494 self.client.close()
495
495
496
496
497 # begin execute tests
497 # begin execute tests
498
498
499 def test_execute_reply(self):
499 def test_execute_reply(self):
500 e0 = self.client[self.client.ids[0]]
500 e0 = self.client[self.client.ids[0]]
501 e0.block = True
501 e0.block = True
502 ar = e0.execute("5", silent=False)
502 ar = e0.execute("5", silent=False)
503 er = ar.get()
503 er = ar.get()
504 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
504 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
505 self.assertEqual(er.pyout['data']['text/plain'], '5')
505 self.assertEqual(er.pyout['data']['text/plain'], '5')
506
506
507 def test_execute_reply_stdout(self):
507 def test_execute_reply_stdout(self):
508 e0 = self.client[self.client.ids[0]]
508 e0 = self.client[self.client.ids[0]]
509 e0.block = True
509 e0.block = True
510 ar = e0.execute("print (5)", silent=False)
510 ar = e0.execute("print (5)", silent=False)
511 er = ar.get()
511 er = ar.get()
512 self.assertEqual(er.stdout.strip(), '5')
512 self.assertEqual(er.stdout.strip(), '5')
513
513
514 def test_execute_pyout(self):
514 def test_execute_pyout(self):
515 """execute triggers pyout with silent=False"""
515 """execute triggers pyout with silent=False"""
516 view = self.client[:]
516 view = self.client[:]
517 ar = view.execute("5", silent=False, block=True)
517 ar = view.execute("5", silent=False, block=True)
518
518
519 expected = [{'text/plain' : '5'}] * len(view)
519 expected = [{'text/plain' : '5'}] * len(view)
520 mimes = [ out['data'] for out in ar.pyout ]
520 mimes = [ out['data'] for out in ar.pyout ]
521 self.assertEqual(mimes, expected)
521 self.assertEqual(mimes, expected)
522
522
523 def test_execute_silent(self):
523 def test_execute_silent(self):
524 """execute does not trigger pyout with silent=True"""
524 """execute does not trigger pyout with silent=True"""
525 view = self.client[:]
525 view = self.client[:]
526 ar = view.execute("5", block=True)
526 ar = view.execute("5", block=True)
527 expected = [None] * len(view)
527 expected = [None] * len(view)
528 self.assertEqual(ar.pyout, expected)
528 self.assertEqual(ar.pyout, expected)
529
529
530 def test_execute_magic(self):
530 def test_execute_magic(self):
531 """execute accepts IPython commands"""
531 """execute accepts IPython commands"""
532 view = self.client[:]
532 view = self.client[:]
533 view.execute("a = 5")
533 view.execute("a = 5")
534 ar = view.execute("%whos", block=True)
534 ar = view.execute("%whos", block=True)
535 # this will raise, if that failed
535 # this will raise, if that failed
536 ar.get(5)
536 ar.get(5)
537 for stdout in ar.stdout:
537 for stdout in ar.stdout:
538 lines = stdout.splitlines()
538 lines = stdout.splitlines()
539 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
539 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
540 found = False
540 found = False
541 for line in lines[2:]:
541 for line in lines[2:]:
542 split = line.split()
542 split = line.split()
543 if split == ['a', 'int', '5']:
543 if split == ['a', 'int', '5']:
544 found = True
544 found = True
545 break
545 break
546 self.assertTrue(found, "whos output wrong: %s" % stdout)
546 self.assertTrue(found, "whos output wrong: %s" % stdout)
547
547
548 def test_execute_displaypub(self):
548 def test_execute_displaypub(self):
549 """execute tracks display_pub output"""
549 """execute tracks display_pub output"""
550 view = self.client[:]
550 view = self.client[:]
551 view.execute("from IPython.core.display import *")
551 view.execute("from IPython.core.display import *")
552 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
552 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
553
553
554 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
554 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
555 for outputs in ar.outputs:
555 for outputs in ar.outputs:
556 mimes = [ out['data'] for out in outputs ]
556 mimes = [ out['data'] for out in outputs ]
557 self.assertEqual(mimes, expected)
557 self.assertEqual(mimes, expected)
558
558
559 def test_apply_displaypub(self):
559 def test_apply_displaypub(self):
560 """apply tracks display_pub output"""
560 """apply tracks display_pub output"""
561 view = self.client[:]
561 view = self.client[:]
562 view.execute("from IPython.core.display import *")
562 view.execute("from IPython.core.display import *")
563
563
564 @interactive
564 @interactive
565 def publish():
565 def publish():
566 [ display(i) for i in range(5) ]
566 [ display(i) for i in range(5) ]
567
567
568 ar = view.apply_async(publish)
568 ar = view.apply_async(publish)
569 ar.get(5)
569 ar.get(5)
570 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
570 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
571 for outputs in ar.outputs:
571 for outputs in ar.outputs:
572 mimes = [ out['data'] for out in outputs ]
572 mimes = [ out['data'] for out in outputs ]
573 self.assertEqual(mimes, expected)
573 self.assertEqual(mimes, expected)
574
574
575 def test_execute_raises(self):
575 def test_execute_raises(self):
576 """exceptions in execute requests raise appropriately"""
576 """exceptions in execute requests raise appropriately"""
577 view = self.client[-1]
577 view = self.client[-1]
578 ar = view.execute("1/0")
578 ar = view.execute("1/0")
579 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
579 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
580
580
581 @dec.skipif_not_matplotlib
581 @dec.skipif_not_matplotlib
582 def test_magic_pylab(self):
582 def test_magic_pylab(self):
583 """%pylab works on engines"""
583 """%pylab works on engines"""
584 view = self.client[-1]
584 view = self.client[-1]
585 ar = view.execute("%pylab inline")
585 ar = view.execute("%pylab inline")
586 # at least check if this raised:
586 # at least check if this raised:
587 reply = ar.get(5)
587 reply = ar.get(5)
588 # include imports, in case user config
588 # include imports, in case user config
589 ar = view.execute("plot(rand(100))", silent=False)
589 ar = view.execute("plot(rand(100))", silent=False)
590 reply = ar.get(5)
590 reply = ar.get(5)
591 self.assertEqual(len(reply.outputs), 1)
591 self.assertEqual(len(reply.outputs), 1)
592 output = reply.outputs[0]
592 output = reply.outputs[0]
593 self.assertTrue("data" in output)
593 self.assertTrue("data" in output)
594 data = output['data']
594 data = output['data']
595 self.assertTrue("image/png" in data)
595 self.assertTrue("image/png" in data)
596
596
597 def test_func_default_func(self):
597 def test_func_default_func(self):
598 """interactively defined function as apply func default"""
598 """interactively defined function as apply func default"""
599 def foo():
599 def foo():
600 return 'foo'
600 return 'foo'
601
601
602 def bar(f=foo):
602 def bar(f=foo):
603 return f()
603 return f()
604
604
605 view = self.client[-1]
605 view = self.client[-1]
606 ar = view.apply_async(bar)
606 ar = view.apply_async(bar)
607 r = ar.get(10)
607 r = ar.get(10)
608 self.assertEquals(r, 'foo')
608 self.assertEqual(r, 'foo')
609 def test_data_pub_single(self):
609 def test_data_pub_single(self):
610 view = self.client[-1]
610 view = self.client[-1]
611 ar = view.execute('\n'.join([
611 ar = view.execute('\n'.join([
612 'from IPython.zmq.datapub import publish_data',
612 'from IPython.zmq.datapub import publish_data',
613 'for i in range(5):',
613 'for i in range(5):',
614 ' publish_data(dict(i=i))'
614 ' publish_data(dict(i=i))'
615 ]), block=False)
615 ]), block=False)
616 self.assertTrue(isinstance(ar.data, dict))
616 self.assertTrue(isinstance(ar.data, dict))
617 ar.get(5)
617 ar.get(5)
618 self.assertEqual(ar.data, dict(i=4))
618 self.assertEqual(ar.data, dict(i=4))
619
619
620 def test_data_pub(self):
620 def test_data_pub(self):
621 view = self.client[:]
621 view = self.client[:]
622 ar = view.execute('\n'.join([
622 ar = view.execute('\n'.join([
623 'from IPython.zmq.datapub import publish_data',
623 'from IPython.zmq.datapub import publish_data',
624 'for i in range(5):',
624 'for i in range(5):',
625 ' publish_data(dict(i=i))'
625 ' publish_data(dict(i=i))'
626 ]), block=False)
626 ]), block=False)
627 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
627 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
628 ar.get(5)
628 ar.get(5)
629 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
629 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
630
630
631 def test_can_list_arg(self):
631 def test_can_list_arg(self):
632 """args in lists are canned"""
632 """args in lists are canned"""
633 view = self.client[-1]
633 view = self.client[-1]
634 view['a'] = 128
634 view['a'] = 128
635 rA = pmod.Reference('a')
635 rA = pmod.Reference('a')
636 ar = view.apply_async(lambda x: x, [rA])
636 ar = view.apply_async(lambda x: x, [rA])
637 r = ar.get(5)
637 r = ar.get(5)
638 self.assertEqual(r, [128])
638 self.assertEqual(r, [128])
639
639
640 def test_can_dict_arg(self):
640 def test_can_dict_arg(self):
641 """args in dicts are canned"""
641 """args in dicts are canned"""
642 view = self.client[-1]
642 view = self.client[-1]
643 view['a'] = 128
643 view['a'] = 128
644 rA = pmod.Reference('a')
644 rA = pmod.Reference('a')
645 ar = view.apply_async(lambda x: x, dict(foo=rA))
645 ar = view.apply_async(lambda x: x, dict(foo=rA))
646 r = ar.get(5)
646 r = ar.get(5)
647 self.assertEqual(r, dict(foo=128))
647 self.assertEqual(r, dict(foo=128))
648
648
649 def test_can_list_kwarg(self):
649 def test_can_list_kwarg(self):
650 """kwargs in lists are canned"""
650 """kwargs in lists are canned"""
651 view = self.client[-1]
651 view = self.client[-1]
652 view['a'] = 128
652 view['a'] = 128
653 rA = pmod.Reference('a')
653 rA = pmod.Reference('a')
654 ar = view.apply_async(lambda x=5: x, x=[rA])
654 ar = view.apply_async(lambda x=5: x, x=[rA])
655 r = ar.get(5)
655 r = ar.get(5)
656 self.assertEqual(r, [128])
656 self.assertEqual(r, [128])
657
657
658 def test_can_dict_kwarg(self):
658 def test_can_dict_kwarg(self):
659 """kwargs in dicts are canned"""
659 """kwargs in dicts are canned"""
660 view = self.client[-1]
660 view = self.client[-1]
661 view['a'] = 128
661 view['a'] = 128
662 rA = pmod.Reference('a')
662 rA = pmod.Reference('a')
663 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
663 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
664 r = ar.get(5)
664 r = ar.get(5)
665 self.assertEqual(r, dict(foo=128))
665 self.assertEqual(r, dict(foo=128))
666
666
667 def test_map_ref(self):
667 def test_map_ref(self):
668 """view.map works with references"""
668 """view.map works with references"""
669 view = self.client[:]
669 view = self.client[:]
670 ranks = sorted(self.client.ids)
670 ranks = sorted(self.client.ids)
671 view.scatter('rank', ranks, flatten=True)
671 view.scatter('rank', ranks, flatten=True)
672 rrank = pmod.Reference('rank')
672 rrank = pmod.Reference('rank')
673
673
674 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
674 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
675 drank = amr.get(5)
675 drank = amr.get(5)
676 self.assertEqual(drank, [ r*2 for r in ranks ])
676 self.assertEqual(drank, [ r*2 for r in ranks ])
677
677
678
678
General Comments 0
You need to be logged in to leave comments. Login now