##// END OF EJS Templates
s/assertEquals/assertEqual (again)...
MinRK -
Show More
@@ -1,34 +1,34 b''
1 1 """Tests for the notebook manager."""
2 2
3 3 import os
4 4 from unittest import TestCase
5 5 from tempfile import NamedTemporaryFile
6 6
7 7 from IPython.utils.tempdir import TemporaryDirectory
8 8 from IPython.utils.traitlets import TraitError
9 9
10 10 from IPython.frontend.html.notebook.filenbmanager import FileNotebookManager
11 11
12 12 class TestNotebookManager(TestCase):
13 13
14 14 def test_nb_dir(self):
15 15 with TemporaryDirectory() as td:
16 16 km = FileNotebookManager(notebook_dir=td)
17 self.assertEquals(km.notebook_dir, td)
17 self.assertEqual(km.notebook_dir, td)
18 18
19 19 def test_create_nb_dir(self):
20 20 with TemporaryDirectory() as td:
21 21 nbdir = os.path.join(td, 'notebooks')
22 22 km = FileNotebookManager(notebook_dir=nbdir)
23 self.assertEquals(km.notebook_dir, nbdir)
23 self.assertEqual(km.notebook_dir, nbdir)
24 24
25 25 def test_missing_nb_dir(self):
26 26 with TemporaryDirectory() as td:
27 27 nbdir = os.path.join(td, 'notebook', 'dir', 'is', 'missing')
28 28 self.assertRaises(TraitError, FileNotebookManager, notebook_dir=nbdir)
29 29
30 30 def test_invalid_nb_dir(self):
31 31 with NamedTemporaryFile() as tf:
32 32 self.assertRaises(TraitError, FileNotebookManager, notebook_dir=tf.name)
33 33
34 34
@@ -1,314 +1,314 b''
1 1 """Tests for db backends
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import logging
22 22 import os
23 23 import tempfile
24 24 import time
25 25
26 26 from datetime import datetime, timedelta
27 27 from unittest import TestCase
28 28
29 29 from IPython.parallel import error
30 30 from IPython.parallel.controller.dictdb import DictDB
31 31 from IPython.parallel.controller.sqlitedb import SQLiteDB
32 32 from IPython.parallel.controller.hub import init_record, empty_record
33 33
34 34 from IPython.testing import decorators as dec
35 35 from IPython.zmq.session import Session
36 36
37 37
38 38 #-------------------------------------------------------------------------------
39 39 # TestCases
40 40 #-------------------------------------------------------------------------------
41 41
42 42
43 43 def setup():
44 44 global temp_db
45 45 temp_db = tempfile.NamedTemporaryFile(suffix='.db').name
46 46
47 47
48 48 class TaskDBTest:
49 49 def setUp(self):
50 50 self.session = Session()
51 51 self.db = self.create_db()
52 52 self.load_records(16)
53 53
54 54 def create_db(self):
55 55 raise NotImplementedError
56 56
57 57 def load_records(self, n=1, buffer_size=100):
58 58 """load n records for testing"""
59 59 #sleep 1/10 s, to ensure timestamp is different to previous calls
60 60 time.sleep(0.1)
61 61 msg_ids = []
62 62 for i in range(n):
63 63 msg = self.session.msg('apply_request', content=dict(a=5))
64 64 msg['buffers'] = [os.urandom(buffer_size)]
65 65 rec = init_record(msg)
66 66 msg_id = msg['header']['msg_id']
67 67 msg_ids.append(msg_id)
68 68 self.db.add_record(msg_id, rec)
69 69 return msg_ids
70 70
71 71 def test_add_record(self):
72 72 before = self.db.get_history()
73 73 self.load_records(5)
74 74 after = self.db.get_history()
75 75 self.assertEqual(len(after), len(before)+5)
76 76 self.assertEqual(after[:-5],before)
77 77
78 78 def test_drop_record(self):
79 79 msg_id = self.load_records()[-1]
80 80 rec = self.db.get_record(msg_id)
81 81 self.db.drop_record(msg_id)
82 82 self.assertRaises(KeyError,self.db.get_record, msg_id)
83 83
84 84 def _round_to_millisecond(self, dt):
85 85 """necessary because mongodb rounds microseconds"""
86 86 micro = dt.microsecond
87 87 extra = int(str(micro)[-3:])
88 88 return dt - timedelta(microseconds=extra)
89 89
90 90 def test_update_record(self):
91 91 now = self._round_to_millisecond(datetime.now())
92 92 #
93 93 msg_id = self.db.get_history()[-1]
94 94 rec1 = self.db.get_record(msg_id)
95 95 data = {'stdout': 'hello there', 'completed' : now}
96 96 self.db.update_record(msg_id, data)
97 97 rec2 = self.db.get_record(msg_id)
98 98 self.assertEqual(rec2['stdout'], 'hello there')
99 99 self.assertEqual(rec2['completed'], now)
100 100 rec1.update(data)
101 101 self.assertEqual(rec1, rec2)
102 102
103 103 # def test_update_record_bad(self):
104 104 # """test updating nonexistant records"""
105 105 # msg_id = str(uuid.uuid4())
106 106 # data = {'stdout': 'hello there'}
107 107 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
108 108
109 109 def test_find_records_dt(self):
110 110 """test finding records by date"""
111 111 hist = self.db.get_history()
112 112 middle = self.db.get_record(hist[len(hist)//2])
113 113 tic = middle['submitted']
114 114 before = self.db.find_records({'submitted' : {'$lt' : tic}})
115 115 after = self.db.find_records({'submitted' : {'$gte' : tic}})
116 116 self.assertEqual(len(before)+len(after),len(hist))
117 117 for b in before:
118 118 self.assertTrue(b['submitted'] < tic)
119 119 for a in after:
120 120 self.assertTrue(a['submitted'] >= tic)
121 121 same = self.db.find_records({'submitted' : tic})
122 122 for s in same:
123 123 self.assertTrue(s['submitted'] == tic)
124 124
125 125 def test_find_records_keys(self):
126 126 """test extracting subset of record keys"""
127 127 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
128 128 for rec in found:
129 129 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
130 130
131 131 def test_find_records_msg_id(self):
132 132 """ensure msg_id is always in found records"""
133 133 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
134 134 for rec in found:
135 135 self.assertTrue('msg_id' in rec.keys())
136 136 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
137 137 for rec in found:
138 138 self.assertTrue('msg_id' in rec.keys())
139 139 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
140 140 for rec in found:
141 141 self.assertTrue('msg_id' in rec.keys())
142 142
143 143 def test_find_records_in(self):
144 144 """test finding records with '$in','$nin' operators"""
145 145 hist = self.db.get_history()
146 146 even = hist[::2]
147 147 odd = hist[1::2]
148 148 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
149 149 found = [ r['msg_id'] for r in recs ]
150 150 self.assertEqual(set(even), set(found))
151 151 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
152 152 found = [ r['msg_id'] for r in recs ]
153 153 self.assertEqual(set(odd), set(found))
154 154
155 155 def test_get_history(self):
156 156 msg_ids = self.db.get_history()
157 157 latest = datetime(1984,1,1)
158 158 for msg_id in msg_ids:
159 159 rec = self.db.get_record(msg_id)
160 160 newt = rec['submitted']
161 161 self.assertTrue(newt >= latest)
162 162 latest = newt
163 163 msg_id = self.load_records(1)[-1]
164 164 self.assertEqual(self.db.get_history()[-1],msg_id)
165 165
166 166 def test_datetime(self):
167 167 """get/set timestamps with datetime objects"""
168 168 msg_id = self.db.get_history()[-1]
169 169 rec = self.db.get_record(msg_id)
170 170 self.assertTrue(isinstance(rec['submitted'], datetime))
171 171 self.db.update_record(msg_id, dict(completed=datetime.now()))
172 172 rec = self.db.get_record(msg_id)
173 173 self.assertTrue(isinstance(rec['completed'], datetime))
174 174
175 175 def test_drop_matching(self):
176 176 msg_ids = self.load_records(10)
177 177 query = {'msg_id' : {'$in':msg_ids}}
178 178 self.db.drop_matching_records(query)
179 179 recs = self.db.find_records(query)
180 180 self.assertEqual(len(recs), 0)
181 181
182 182 def test_null(self):
183 183 """test None comparison queries"""
184 184 msg_ids = self.load_records(10)
185 185
186 186 query = {'msg_id' : None}
187 187 recs = self.db.find_records(query)
188 188 self.assertEqual(len(recs), 0)
189 189
190 190 query = {'msg_id' : {'$ne' : None}}
191 191 recs = self.db.find_records(query)
192 192 self.assertTrue(len(recs) >= 10)
193 193
194 194 def test_pop_safe_get(self):
195 195 """editing query results shouldn't affect record [get]"""
196 196 msg_id = self.db.get_history()[-1]
197 197 rec = self.db.get_record(msg_id)
198 198 rec.pop('buffers')
199 199 rec['garbage'] = 'hello'
200 200 rec['header']['msg_id'] = 'fubar'
201 201 rec2 = self.db.get_record(msg_id)
202 202 self.assertTrue('buffers' in rec2)
203 203 self.assertFalse('garbage' in rec2)
204 204 self.assertEqual(rec2['header']['msg_id'], msg_id)
205 205
206 206 def test_pop_safe_find(self):
207 207 """editing query results shouldn't affect record [find]"""
208 208 msg_id = self.db.get_history()[-1]
209 209 rec = self.db.find_records({'msg_id' : msg_id})[0]
210 210 rec.pop('buffers')
211 211 rec['garbage'] = 'hello'
212 212 rec['header']['msg_id'] = 'fubar'
213 213 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
214 214 self.assertTrue('buffers' in rec2)
215 215 self.assertFalse('garbage' in rec2)
216 216 self.assertEqual(rec2['header']['msg_id'], msg_id)
217 217
218 218 def test_pop_safe_find_keys(self):
219 219 """editing query results shouldn't affect record [find+keys]"""
220 220 msg_id = self.db.get_history()[-1]
221 221 rec = self.db.find_records({'msg_id' : msg_id}, keys=['buffers', 'header'])[0]
222 222 rec.pop('buffers')
223 223 rec['garbage'] = 'hello'
224 224 rec['header']['msg_id'] = 'fubar'
225 225 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
226 226 self.assertTrue('buffers' in rec2)
227 227 self.assertFalse('garbage' in rec2)
228 228 self.assertEqual(rec2['header']['msg_id'], msg_id)
229 229
230 230
231 231 class TestDictBackend(TaskDBTest, TestCase):
232 232
233 233 def create_db(self):
234 234 return DictDB()
235 235
236 236 def test_cull_count(self):
237 237 self.db = self.create_db() # skip the load-records init from setUp
238 238 self.db.record_limit = 20
239 239 self.db.cull_fraction = 0.2
240 240 self.load_records(20)
241 self.assertEquals(len(self.db.get_history()), 20)
241 self.assertEqual(len(self.db.get_history()), 20)
242 242 self.load_records(1)
243 243 # 0.2 * 20 = 4, 21 - 4 = 17
244 self.assertEquals(len(self.db.get_history()), 17)
244 self.assertEqual(len(self.db.get_history()), 17)
245 245 self.load_records(3)
246 self.assertEquals(len(self.db.get_history()), 20)
246 self.assertEqual(len(self.db.get_history()), 20)
247 247 self.load_records(1)
248 self.assertEquals(len(self.db.get_history()), 17)
248 self.assertEqual(len(self.db.get_history()), 17)
249 249
250 250 for i in range(100):
251 251 self.load_records(1)
252 252 self.assertTrue(len(self.db.get_history()) >= 17)
253 253 self.assertTrue(len(self.db.get_history()) <= 20)
254 254
255 255 def test_cull_size(self):
256 256 self.db = self.create_db() # skip the load-records init from setUp
257 257 self.db.size_limit = 1000
258 258 self.db.cull_fraction = 0.2
259 259 self.load_records(100, buffer_size=10)
260 self.assertEquals(len(self.db.get_history()), 100)
260 self.assertEqual(len(self.db.get_history()), 100)
261 261 self.load_records(1, buffer_size=0)
262 self.assertEquals(len(self.db.get_history()), 101)
262 self.assertEqual(len(self.db.get_history()), 101)
263 263 self.load_records(1, buffer_size=1)
264 264 # 0.2 * 100 = 20, 101 - 20 = 81
265 self.assertEquals(len(self.db.get_history()), 81)
265 self.assertEqual(len(self.db.get_history()), 81)
266 266
267 267 def test_cull_size_drop(self):
268 268 """dropping records updates tracked buffer size"""
269 269 self.db = self.create_db() # skip the load-records init from setUp
270 270 self.db.size_limit = 1000
271 271 self.db.cull_fraction = 0.2
272 272 self.load_records(100, buffer_size=10)
273 self.assertEquals(len(self.db.get_history()), 100)
273 self.assertEqual(len(self.db.get_history()), 100)
274 274 self.db.drop_record(self.db.get_history()[-1])
275 self.assertEquals(len(self.db.get_history()), 99)
275 self.assertEqual(len(self.db.get_history()), 99)
276 276 self.load_records(1, buffer_size=5)
277 self.assertEquals(len(self.db.get_history()), 100)
277 self.assertEqual(len(self.db.get_history()), 100)
278 278 self.load_records(1, buffer_size=5)
279 self.assertEquals(len(self.db.get_history()), 101)
279 self.assertEqual(len(self.db.get_history()), 101)
280 280 self.load_records(1, buffer_size=1)
281 self.assertEquals(len(self.db.get_history()), 81)
281 self.assertEqual(len(self.db.get_history()), 81)
282 282
283 283 def test_cull_size_update(self):
284 284 """updating records updates tracked buffer size"""
285 285 self.db = self.create_db() # skip the load-records init from setUp
286 286 self.db.size_limit = 1000
287 287 self.db.cull_fraction = 0.2
288 288 self.load_records(100, buffer_size=10)
289 self.assertEquals(len(self.db.get_history()), 100)
289 self.assertEqual(len(self.db.get_history()), 100)
290 290 msg_id = self.db.get_history()[-1]
291 291 self.db.update_record(msg_id, dict(result_buffers = [os.urandom(10)], buffers=[]))
292 self.assertEquals(len(self.db.get_history()), 100)
292 self.assertEqual(len(self.db.get_history()), 100)
293 293 self.db.update_record(msg_id, dict(result_buffers = [os.urandom(11)], buffers=[]))
294 self.assertEquals(len(self.db.get_history()), 79)
294 self.assertEqual(len(self.db.get_history()), 79)
295 295
296 296 class TestSQLiteBackend(TaskDBTest, TestCase):
297 297
298 298 @dec.skip_without('sqlite3')
299 299 def create_db(self):
300 300 location, fname = os.path.split(temp_db)
301 301 log = logging.getLogger('test')
302 302 log.setLevel(logging.CRITICAL)
303 303 return SQLiteDB(location=location, fname=fname, log=log)
304 304
305 305 def tearDown(self):
306 306 self.db._db.close()
307 307
308 308
309 309 def teardown():
310 310 """cleanup task db file after all tests have run"""
311 311 try:
312 312 os.remove(temp_db)
313 313 except:
314 314 pass
@@ -1,678 +1,678 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import platform
21 21 import time
22 22 from tempfile import mktemp
23 23 from StringIO import StringIO
24 24
25 25 import zmq
26 26 from nose import SkipTest
27 27
28 28 from IPython.testing import decorators as dec
29 29 from IPython.testing.ipunittest import ParametricTestCase
30 30
31 31 from IPython import parallel as pmod
32 32 from IPython.parallel import error
33 33 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
34 34 from IPython.parallel import DirectView
35 35 from IPython.parallel.util import interactive
36 36
37 37 from IPython.parallel.tests import add_engines
38 38
39 39 from .clienttest import ClusterTestCase, crash, wait, skip_without
40 40
41 41 def setup():
42 42 add_engines(3, total=True)
43 43
44 44 class TestView(ClusterTestCase, ParametricTestCase):
45 45
46 46 def setUp(self):
47 47 # On Win XP, wait for resource cleanup, else parallel test group fails
48 48 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
49 49 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
50 50 time.sleep(2)
51 51 super(TestView, self).setUp()
52 52
53 53 def test_z_crash_mux(self):
54 54 """test graceful handling of engine death (direct)"""
55 55 raise SkipTest("crash tests disabled, due to undesirable crash reports")
56 56 # self.add_engines(1)
57 57 eid = self.client.ids[-1]
58 58 ar = self.client[eid].apply_async(crash)
59 59 self.assertRaisesRemote(error.EngineError, ar.get, 10)
60 60 eid = ar.engine_id
61 61 tic = time.time()
62 62 while eid in self.client.ids and time.time()-tic < 5:
63 63 time.sleep(.01)
64 64 self.client.spin()
65 65 self.assertFalse(eid in self.client.ids, "Engine should have died")
66 66
67 67 def test_push_pull(self):
68 68 """test pushing and pulling"""
69 69 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
70 70 t = self.client.ids[-1]
71 71 v = self.client[t]
72 72 push = v.push
73 73 pull = v.pull
74 74 v.block=True
75 75 nengines = len(self.client)
76 76 push({'data':data})
77 77 d = pull('data')
78 78 self.assertEqual(d, data)
79 79 self.client[:].push({'data':data})
80 80 d = self.client[:].pull('data', block=True)
81 81 self.assertEqual(d, nengines*[data])
82 82 ar = push({'data':data}, block=False)
83 83 self.assertTrue(isinstance(ar, AsyncResult))
84 84 r = ar.get()
85 85 ar = self.client[:].pull('data', block=False)
86 86 self.assertTrue(isinstance(ar, AsyncResult))
87 87 r = ar.get()
88 88 self.assertEqual(r, nengines*[data])
89 89 self.client[:].push(dict(a=10,b=20))
90 90 r = self.client[:].pull(('a','b'), block=True)
91 91 self.assertEqual(r, nengines*[[10,20]])
92 92
93 93 def test_push_pull_function(self):
94 94 "test pushing and pulling functions"
95 95 def testf(x):
96 96 return 2.0*x
97 97
98 98 t = self.client.ids[-1]
99 99 v = self.client[t]
100 100 v.block=True
101 101 push = v.push
102 102 pull = v.pull
103 103 execute = v.execute
104 104 push({'testf':testf})
105 105 r = pull('testf')
106 106 self.assertEqual(r(1.0), testf(1.0))
107 107 execute('r = testf(10)')
108 108 r = pull('r')
109 109 self.assertEqual(r, testf(10))
110 110 ar = self.client[:].push({'testf':testf}, block=False)
111 111 ar.get()
112 112 ar = self.client[:].pull('testf', block=False)
113 113 rlist = ar.get()
114 114 for r in rlist:
115 115 self.assertEqual(r(1.0), testf(1.0))
116 116 execute("def g(x): return x*x")
117 117 r = pull(('testf','g'))
118 118 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
119 119
120 120 def test_push_function_globals(self):
121 121 """test that pushed functions have access to globals"""
122 122 @interactive
123 123 def geta():
124 124 return a
125 125 # self.add_engines(1)
126 126 v = self.client[-1]
127 127 v.block=True
128 128 v['f'] = geta
129 129 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
130 130 v.execute('a=5')
131 131 v.execute('b=f()')
132 132 self.assertEqual(v['b'], 5)
133 133
134 134 def test_push_function_defaults(self):
135 135 """test that pushed functions preserve default args"""
136 136 def echo(a=10):
137 137 return a
138 138 v = self.client[-1]
139 139 v.block=True
140 140 v['f'] = echo
141 141 v.execute('b=f()')
142 142 self.assertEqual(v['b'], 10)
143 143
144 144 def test_get_result(self):
145 145 """test getting results from the Hub."""
146 146 c = pmod.Client(profile='iptest')
147 147 # self.add_engines(1)
148 148 t = c.ids[-1]
149 149 v = c[t]
150 150 v2 = self.client[t]
151 151 ar = v.apply_async(wait, 1)
152 152 # give the monitor time to notice the message
153 153 time.sleep(.25)
154 154 ahr = v2.get_result(ar.msg_ids)
155 155 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 156 self.assertEqual(ahr.get(), ar.get())
157 157 ar2 = v2.get_result(ar.msg_ids)
158 158 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 159 c.spin()
160 160 c.close()
161 161
162 162 def test_run_newline(self):
163 163 """test that run appends newline to files"""
164 164 tmpfile = mktemp()
165 165 with open(tmpfile, 'w') as f:
166 166 f.write("""def g():
167 167 return 5
168 168 """)
169 169 v = self.client[-1]
170 170 v.run(tmpfile, block=True)
171 171 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
172 172
173 173 def test_apply_tracked(self):
174 174 """test tracking for apply"""
175 175 # self.add_engines(1)
176 176 t = self.client.ids[-1]
177 177 v = self.client[t]
178 178 v.block=False
179 179 def echo(n=1024*1024, **kwargs):
180 180 with v.temp_flags(**kwargs):
181 181 return v.apply(lambda x: x, 'x'*n)
182 182 ar = echo(1, track=False)
183 183 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
184 184 self.assertTrue(ar.sent)
185 185 ar = echo(track=True)
186 186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 187 self.assertEqual(ar.sent, ar._tracker.done)
188 188 ar._tracker.wait()
189 189 self.assertTrue(ar.sent)
190 190
191 191 def test_push_tracked(self):
192 192 t = self.client.ids[-1]
193 193 ns = dict(x='x'*1024*1024)
194 194 v = self.client[t]
195 195 ar = v.push(ns, block=False, track=False)
196 196 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
197 197 self.assertTrue(ar.sent)
198 198
199 199 ar = v.push(ns, block=False, track=True)
200 200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 201 ar._tracker.wait()
202 202 self.assertEqual(ar.sent, ar._tracker.done)
203 203 self.assertTrue(ar.sent)
204 204 ar.get()
205 205
206 206 def test_scatter_tracked(self):
207 207 t = self.client.ids
208 208 x='x'*1024*1024
209 209 ar = self.client[t].scatter('x', x, block=False, track=False)
210 210 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
211 211 self.assertTrue(ar.sent)
212 212
213 213 ar = self.client[t].scatter('x', x, block=False, track=True)
214 214 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
215 215 self.assertEqual(ar.sent, ar._tracker.done)
216 216 ar._tracker.wait()
217 217 self.assertTrue(ar.sent)
218 218 ar.get()
219 219
220 220 def test_remote_reference(self):
221 221 v = self.client[-1]
222 222 v['a'] = 123
223 223 ra = pmod.Reference('a')
224 224 b = v.apply_sync(lambda x: x, ra)
225 225 self.assertEqual(b, 123)
226 226
227 227
228 228 def test_scatter_gather(self):
229 229 view = self.client[:]
230 230 seq1 = range(16)
231 231 view.scatter('a', seq1)
232 232 seq2 = view.gather('a', block=True)
233 233 self.assertEqual(seq2, seq1)
234 234 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
235 235
236 236 @skip_without('numpy')
237 237 def test_scatter_gather_numpy(self):
238 238 import numpy
239 239 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
240 240 view = self.client[:]
241 241 a = numpy.arange(64)
242 242 view.scatter('a', a, block=True)
243 243 b = view.gather('a', block=True)
244 244 assert_array_equal(b, a)
245 245
246 246 def test_scatter_gather_lazy(self):
247 247 """scatter/gather with targets='all'"""
248 248 view = self.client.direct_view(targets='all')
249 249 x = range(64)
250 250 view.scatter('x', x)
251 251 gathered = view.gather('x', block=True)
252 252 self.assertEqual(gathered, x)
253 253
254 254
255 255 @dec.known_failure_py3
256 256 @skip_without('numpy')
257 257 def test_push_numpy_nocopy(self):
258 258 import numpy
259 259 view = self.client[:]
260 260 a = numpy.arange(64)
261 261 view['A'] = a
262 262 @interactive
263 263 def check_writeable(x):
264 264 return x.flags.writeable
265 265
266 266 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
267 267 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
268 268
269 269 view.push(dict(B=a))
270 270 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
271 271 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
272 272
273 273 @skip_without('numpy')
274 274 def test_apply_numpy(self):
275 275 """view.apply(f, ndarray)"""
276 276 import numpy
277 277 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
278 278
279 279 A = numpy.random.random((100,100))
280 280 view = self.client[-1]
281 281 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
282 282 B = A.astype(dt)
283 283 C = view.apply_sync(lambda x:x, B)
284 284 assert_array_equal(B,C)
285 285
286 286 @skip_without('numpy')
287 287 def test_push_pull_recarray(self):
288 288 """push/pull recarrays"""
289 289 import numpy
290 290 from numpy.testing.utils import assert_array_equal
291 291
292 292 view = self.client[-1]
293 293
294 294 R = numpy.array([
295 295 (1, 'hi', 0.),
296 296 (2**30, 'there', 2.5),
297 297 (-99999, 'world', -12345.6789),
298 298 ], [('n', int), ('s', '|S10'), ('f', float)])
299 299
300 300 view['RR'] = R
301 301 R2 = view['RR']
302 302
303 303 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
304 304 self.assertEqual(r_dtype, R.dtype)
305 305 self.assertEqual(r_shape, R.shape)
306 306 self.assertEqual(R2.dtype, R.dtype)
307 307 self.assertEqual(R2.shape, R.shape)
308 308 assert_array_equal(R2, R)
309 309
310 310 def test_map(self):
311 311 view = self.client[:]
312 312 def f(x):
313 313 return x**2
314 314 data = range(16)
315 315 r = view.map_sync(f, data)
316 316 self.assertEqual(r, map(f, data))
317 317
318 318 def test_map_iterable(self):
319 319 """test map on iterables (direct)"""
320 320 view = self.client[:]
321 321 # 101 is prime, so it won't be evenly distributed
322 322 arr = range(101)
323 323 # ensure it will be an iterator, even in Python 3
324 324 it = iter(arr)
325 325 r = view.map_sync(lambda x:x, arr)
326 326 self.assertEqual(r, list(arr))
327 327
328 328 def test_scatter_gather_nonblocking(self):
329 329 data = range(16)
330 330 view = self.client[:]
331 331 view.scatter('a', data, block=False)
332 332 ar = view.gather('a', block=False)
333 333 self.assertEqual(ar.get(), data)
334 334
335 335 @skip_without('numpy')
336 336 def test_scatter_gather_numpy_nonblocking(self):
337 337 import numpy
338 338 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
339 339 a = numpy.arange(64)
340 340 view = self.client[:]
341 341 ar = view.scatter('a', a, block=False)
342 342 self.assertTrue(isinstance(ar, AsyncResult))
343 343 amr = view.gather('a', block=False)
344 344 self.assertTrue(isinstance(amr, AsyncMapResult))
345 345 assert_array_equal(amr.get(), a)
346 346
347 347 def test_execute(self):
348 348 view = self.client[:]
349 349 # self.client.debug=True
350 350 execute = view.execute
351 351 ar = execute('c=30', block=False)
352 352 self.assertTrue(isinstance(ar, AsyncResult))
353 353 ar = execute('d=[0,1,2]', block=False)
354 354 self.client.wait(ar, 1)
355 355 self.assertEqual(len(ar.get()), len(self.client))
356 356 for c in view['c']:
357 357 self.assertEqual(c, 30)
358 358
359 359 def test_abort(self):
360 360 view = self.client[-1]
361 361 ar = view.execute('import time; time.sleep(1)', block=False)
362 362 ar2 = view.apply_async(lambda : 2)
363 363 ar3 = view.apply_async(lambda : 3)
364 364 view.abort(ar2)
365 365 view.abort(ar3.msg_ids)
366 366 self.assertRaises(error.TaskAborted, ar2.get)
367 367 self.assertRaises(error.TaskAborted, ar3.get)
368 368
369 369 def test_abort_all(self):
370 370 """view.abort() aborts all outstanding tasks"""
371 371 view = self.client[-1]
372 372 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
373 373 view.abort()
374 374 view.wait(timeout=5)
375 375 for ar in ars[5:]:
376 376 self.assertRaises(error.TaskAborted, ar.get)
377 377
378 378 def test_temp_flags(self):
379 379 view = self.client[-1]
380 380 view.block=True
381 381 with view.temp_flags(block=False):
382 382 self.assertFalse(view.block)
383 383 self.assertTrue(view.block)
384 384
385 385 @dec.known_failure_py3
386 386 def test_importer(self):
387 387 view = self.client[-1]
388 388 view.clear(block=True)
389 389 with view.importer:
390 390 import re
391 391
392 392 @interactive
393 393 def findall(pat, s):
394 394 # this globals() step isn't necessary in real code
395 395 # only to prevent a closure in the test
396 396 re = globals()['re']
397 397 return re.findall(pat, s)
398 398
399 399 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
400 400
401 401 def test_unicode_execute(self):
402 402 """test executing unicode strings"""
403 403 v = self.client[-1]
404 404 v.block=True
405 405 if sys.version_info[0] >= 3:
406 406 code="a='é'"
407 407 else:
408 408 code=u"a=u'é'"
409 409 v.execute(code)
410 410 self.assertEqual(v['a'], u'é')
411 411
412 412 def test_unicode_apply_result(self):
413 413 """test unicode apply results"""
414 414 v = self.client[-1]
415 415 r = v.apply_sync(lambda : u'é')
416 416 self.assertEqual(r, u'é')
417 417
418 418 def test_unicode_apply_arg(self):
419 419 """test passing unicode arguments to apply"""
420 420 v = self.client[-1]
421 421
422 422 @interactive
423 423 def check_unicode(a, check):
424 424 assert isinstance(a, unicode), "%r is not unicode"%a
425 425 assert isinstance(check, bytes), "%r is not bytes"%check
426 426 assert a.encode('utf8') == check, "%s != %s"%(a,check)
427 427
428 428 for s in [ u'é', u'ßø®∫',u'asdf' ]:
429 429 try:
430 430 v.apply_sync(check_unicode, s, s.encode('utf8'))
431 431 except error.RemoteError as e:
432 432 if e.ename == 'AssertionError':
433 433 self.fail(e.evalue)
434 434 else:
435 435 raise e
436 436
437 437 def test_map_reference(self):
438 438 """view.map(<Reference>, *seqs) should work"""
439 439 v = self.client[:]
440 440 v.scatter('n', self.client.ids, flatten=True)
441 441 v.execute("f = lambda x,y: x*y")
442 442 rf = pmod.Reference('f')
443 443 nlist = list(range(10))
444 444 mlist = nlist[::-1]
445 445 expected = [ m*n for m,n in zip(mlist, nlist) ]
446 446 result = v.map_sync(rf, mlist, nlist)
447 447 self.assertEqual(result, expected)
448 448
449 449 def test_apply_reference(self):
450 450 """view.apply(<Reference>, *args) should work"""
451 451 v = self.client[:]
452 452 v.scatter('n', self.client.ids, flatten=True)
453 453 v.execute("f = lambda x: n*x")
454 454 rf = pmod.Reference('f')
455 455 result = v.apply_sync(rf, 5)
456 456 expected = [ 5*id for id in self.client.ids ]
457 457 self.assertEqual(result, expected)
458 458
459 459 def test_eval_reference(self):
460 460 v = self.client[self.client.ids[0]]
461 461 v['g'] = range(5)
462 462 rg = pmod.Reference('g[0]')
463 463 echo = lambda x:x
464 464 self.assertEqual(v.apply_sync(echo, rg), 0)
465 465
466 466 def test_reference_nameerror(self):
467 467 v = self.client[self.client.ids[0]]
468 468 r = pmod.Reference('elvis_has_left')
469 469 echo = lambda x:x
470 470 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
471 471
472 472 def test_single_engine_map(self):
473 473 e0 = self.client[self.client.ids[0]]
474 474 r = range(5)
475 475 check = [ -1*i for i in r ]
476 476 result = e0.map_sync(lambda x: -1*x, r)
477 477 self.assertEqual(result, check)
478 478
479 479 def test_len(self):
480 480 """len(view) makes sense"""
481 481 e0 = self.client[self.client.ids[0]]
482 482 yield self.assertEqual(len(e0), 1)
483 483 v = self.client[:]
484 484 yield self.assertEqual(len(v), len(self.client.ids))
485 485 v = self.client.direct_view('all')
486 486 yield self.assertEqual(len(v), len(self.client.ids))
487 487 v = self.client[:2]
488 488 yield self.assertEqual(len(v), 2)
489 489 v = self.client[:1]
490 490 yield self.assertEqual(len(v), 1)
491 491 v = self.client.load_balanced_view()
492 492 yield self.assertEqual(len(v), len(self.client.ids))
493 493 # parametric tests seem to require manual closing?
494 494 self.client.close()
495 495
496 496
497 497 # begin execute tests
498 498
499 499 def test_execute_reply(self):
500 500 e0 = self.client[self.client.ids[0]]
501 501 e0.block = True
502 502 ar = e0.execute("5", silent=False)
503 503 er = ar.get()
504 504 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
505 505 self.assertEqual(er.pyout['data']['text/plain'], '5')
506 506
507 507 def test_execute_reply_stdout(self):
508 508 e0 = self.client[self.client.ids[0]]
509 509 e0.block = True
510 510 ar = e0.execute("print (5)", silent=False)
511 511 er = ar.get()
512 512 self.assertEqual(er.stdout.strip(), '5')
513 513
514 514 def test_execute_pyout(self):
515 515 """execute triggers pyout with silent=False"""
516 516 view = self.client[:]
517 517 ar = view.execute("5", silent=False, block=True)
518 518
519 519 expected = [{'text/plain' : '5'}] * len(view)
520 520 mimes = [ out['data'] for out in ar.pyout ]
521 521 self.assertEqual(mimes, expected)
522 522
523 523 def test_execute_silent(self):
524 524 """execute does not trigger pyout with silent=True"""
525 525 view = self.client[:]
526 526 ar = view.execute("5", block=True)
527 527 expected = [None] * len(view)
528 528 self.assertEqual(ar.pyout, expected)
529 529
530 530 def test_execute_magic(self):
531 531 """execute accepts IPython commands"""
532 532 view = self.client[:]
533 533 view.execute("a = 5")
534 534 ar = view.execute("%whos", block=True)
535 535 # this will raise, if that failed
536 536 ar.get(5)
537 537 for stdout in ar.stdout:
538 538 lines = stdout.splitlines()
539 539 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
540 540 found = False
541 541 for line in lines[2:]:
542 542 split = line.split()
543 543 if split == ['a', 'int', '5']:
544 544 found = True
545 545 break
546 546 self.assertTrue(found, "whos output wrong: %s" % stdout)
547 547
548 548 def test_execute_displaypub(self):
549 549 """execute tracks display_pub output"""
550 550 view = self.client[:]
551 551 view.execute("from IPython.core.display import *")
552 552 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
553 553
554 554 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
555 555 for outputs in ar.outputs:
556 556 mimes = [ out['data'] for out in outputs ]
557 557 self.assertEqual(mimes, expected)
558 558
559 559 def test_apply_displaypub(self):
560 560 """apply tracks display_pub output"""
561 561 view = self.client[:]
562 562 view.execute("from IPython.core.display import *")
563 563
564 564 @interactive
565 565 def publish():
566 566 [ display(i) for i in range(5) ]
567 567
568 568 ar = view.apply_async(publish)
569 569 ar.get(5)
570 570 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
571 571 for outputs in ar.outputs:
572 572 mimes = [ out['data'] for out in outputs ]
573 573 self.assertEqual(mimes, expected)
574 574
575 575 def test_execute_raises(self):
576 576 """exceptions in execute requests raise appropriately"""
577 577 view = self.client[-1]
578 578 ar = view.execute("1/0")
579 579 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
580 580
581 581 @dec.skipif_not_matplotlib
582 582 def test_magic_pylab(self):
583 583 """%pylab works on engines"""
584 584 view = self.client[-1]
585 585 ar = view.execute("%pylab inline")
586 586 # at least check if this raised:
587 587 reply = ar.get(5)
588 588 # include imports, in case user config
589 589 ar = view.execute("plot(rand(100))", silent=False)
590 590 reply = ar.get(5)
591 591 self.assertEqual(len(reply.outputs), 1)
592 592 output = reply.outputs[0]
593 593 self.assertTrue("data" in output)
594 594 data = output['data']
595 595 self.assertTrue("image/png" in data)
596 596
597 597 def test_func_default_func(self):
598 598 """interactively defined function as apply func default"""
599 599 def foo():
600 600 return 'foo'
601 601
602 602 def bar(f=foo):
603 603 return f()
604 604
605 605 view = self.client[-1]
606 606 ar = view.apply_async(bar)
607 607 r = ar.get(10)
608 self.assertEquals(r, 'foo')
608 self.assertEqual(r, 'foo')
609 609 def test_data_pub_single(self):
610 610 view = self.client[-1]
611 611 ar = view.execute('\n'.join([
612 612 'from IPython.zmq.datapub import publish_data',
613 613 'for i in range(5):',
614 614 ' publish_data(dict(i=i))'
615 615 ]), block=False)
616 616 self.assertTrue(isinstance(ar.data, dict))
617 617 ar.get(5)
618 618 self.assertEqual(ar.data, dict(i=4))
619 619
620 620 def test_data_pub(self):
621 621 view = self.client[:]
622 622 ar = view.execute('\n'.join([
623 623 'from IPython.zmq.datapub import publish_data',
624 624 'for i in range(5):',
625 625 ' publish_data(dict(i=i))'
626 626 ]), block=False)
627 627 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
628 628 ar.get(5)
629 629 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
630 630
631 631 def test_can_list_arg(self):
632 632 """args in lists are canned"""
633 633 view = self.client[-1]
634 634 view['a'] = 128
635 635 rA = pmod.Reference('a')
636 636 ar = view.apply_async(lambda x: x, [rA])
637 637 r = ar.get(5)
638 638 self.assertEqual(r, [128])
639 639
640 640 def test_can_dict_arg(self):
641 641 """args in dicts are canned"""
642 642 view = self.client[-1]
643 643 view['a'] = 128
644 644 rA = pmod.Reference('a')
645 645 ar = view.apply_async(lambda x: x, dict(foo=rA))
646 646 r = ar.get(5)
647 647 self.assertEqual(r, dict(foo=128))
648 648
649 649 def test_can_list_kwarg(self):
650 650 """kwargs in lists are canned"""
651 651 view = self.client[-1]
652 652 view['a'] = 128
653 653 rA = pmod.Reference('a')
654 654 ar = view.apply_async(lambda x=5: x, x=[rA])
655 655 r = ar.get(5)
656 656 self.assertEqual(r, [128])
657 657
658 658 def test_can_dict_kwarg(self):
659 659 """kwargs in dicts are canned"""
660 660 view = self.client[-1]
661 661 view['a'] = 128
662 662 rA = pmod.Reference('a')
663 663 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
664 664 r = ar.get(5)
665 665 self.assertEqual(r, dict(foo=128))
666 666
667 667 def test_map_ref(self):
668 668 """view.map works with references"""
669 669 view = self.client[:]
670 670 ranks = sorted(self.client.ids)
671 671 view.scatter('rank', ranks, flatten=True)
672 672 rrank = pmod.Reference('rank')
673 673
674 674 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
675 675 drank = amr.get(5)
676 676 self.assertEqual(drank, [ r*2 for r in ranks ])
677 677
678 678
General Comments 0
You need to be logged in to leave comments. Login now