##// END OF EJS Templates
fix out of sync parallel tests...
MinRK -
Show More
@@ -1,147 +1,148 b''
1 1 """Tests for parallel client.py"""
2 2
3 3 #-------------------------------------------------------------------------------
4 4 # Copyright (C) 2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-------------------------------------------------------------------------------
9 9
10 10 #-------------------------------------------------------------------------------
11 11 # Imports
12 12 #-------------------------------------------------------------------------------
13 13
14 14 import time
15 15 from tempfile import mktemp
16 16
17 17 import zmq
18 18
19 19 from IPython.parallel.client import client as clientmod
20 20 from IPython.parallel import error
21 21 from IPython.parallel import AsyncResult, AsyncHubResult
22 22 from IPython.parallel import LoadBalancedView, DirectView
23 23
24 24 from clienttest import ClusterTestCase, segfault, wait, add_engines
25 25
26 26 def setup():
27 27 add_engines(4)
28 28
29 29 class TestClient(ClusterTestCase):
30 30
31 31 def test_ids(self):
32 32 n = len(self.client.ids)
33 33 self.add_engines(3)
34 34 self.assertEquals(len(self.client.ids), n+3)
35 35
36 36 def test_view_indexing(self):
37 37 """test index access for views"""
38 38 self.add_engines(2)
39 39 targets = self.client._build_targets('all')[-1]
40 40 v = self.client[:]
41 41 self.assertEquals(v.targets, targets)
42 42 t = self.client.ids[2]
43 43 v = self.client[t]
44 44 self.assert_(isinstance(v, DirectView))
45 45 self.assertEquals(v.targets, t)
46 46 t = self.client.ids[2:4]
47 47 v = self.client[t]
48 48 self.assert_(isinstance(v, DirectView))
49 49 self.assertEquals(v.targets, t)
50 50 v = self.client[::2]
51 51 self.assert_(isinstance(v, DirectView))
52 52 self.assertEquals(v.targets, targets[::2])
53 53 v = self.client[1::3]
54 54 self.assert_(isinstance(v, DirectView))
55 55 self.assertEquals(v.targets, targets[1::3])
56 56 v = self.client[:-3]
57 57 self.assert_(isinstance(v, DirectView))
58 58 self.assertEquals(v.targets, targets[:-3])
59 59 v = self.client[-1]
60 60 self.assert_(isinstance(v, DirectView))
61 61 self.assertEquals(v.targets, targets[-1])
62 62 self.assertRaises(TypeError, lambda : self.client[None])
63 63
64 64 def test_lbview_targets(self):
65 65 """test load_balanced_view targets"""
66 66 v = self.client.load_balanced_view()
67 67 self.assertEquals(v.targets, None)
68 68 v = self.client.load_balanced_view(-1)
69 69 self.assertEquals(v.targets, [self.client.ids[-1]])
70 70 v = self.client.load_balanced_view('all')
71 71 self.assertEquals(v.targets, self.client.ids)
72 72
73 73 def test_targets(self):
74 74 """test various valid targets arguments"""
75 75 build = self.client._build_targets
76 76 ids = self.client.ids
77 77 idents,targets = build(None)
78 78 self.assertEquals(ids, targets)
79 79
80 80 def test_clear(self):
81 81 """test clear behavior"""
82 82 # self.add_engines(2)
83 83 v = self.client[:]
84 84 v.block=True
85 85 v.push(dict(a=5))
86 86 v.pull('a')
87 87 id0 = self.client.ids[-1]
88 88 self.client.clear(targets=id0)
89 89 self.client[:-1].pull('a')
90 90 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
91 91 self.client.clear(block=True)
92 92 for i in self.client.ids:
93 93 # print i
94 94 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
95 95
96 96 def test_get_result(self):
97 97 """test getting results from the Hub."""
98 98 c = clientmod.Client(profile='iptest')
99 99 # self.add_engines(1)
100 100 t = c.ids[-1]
101 101 ar = c[t].apply_async(wait, 1)
102 102 # give the monitor time to notice the message
103 103 time.sleep(.25)
104 104 ahr = self.client.get_result(ar.msg_ids)
105 105 self.assertTrue(isinstance(ahr, AsyncHubResult))
106 106 self.assertEquals(ahr.get(), ar.get())
107 107 ar2 = self.client.get_result(ar.msg_ids)
108 108 self.assertFalse(isinstance(ar2, AsyncHubResult))
109 109 c.close()
110 110
111 111 def test_ids_list(self):
112 112 """test client.ids"""
113 113 # self.add_engines(2)
114 114 ids = self.client.ids
115 115 self.assertEquals(ids, self.client._ids)
116 116 self.assertFalse(ids is self.client._ids)
117 117 ids.remove(ids[-1])
118 118 self.assertNotEquals(ids, self.client._ids)
119 119
120 120 def test_queue_status(self):
121 121 # self.addEngine(4)
122 122 ids = self.client.ids
123 123 id0 = ids[0]
124 124 qs = self.client.queue_status(targets=id0)
125 125 self.assertTrue(isinstance(qs, dict))
126 126 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
127 127 allqs = self.client.queue_status()
128 128 self.assertTrue(isinstance(allqs, dict))
129 self.assertEquals(sorted(allqs.keys()), self.client.ids)
129 self.assertEquals(sorted(allqs.keys()), sorted(self.client.ids + ['unassigned']))
130 unassigned = allqs.pop('unassigned')
130 131 for eid,qs in allqs.items():
131 132 self.assertTrue(isinstance(qs, dict))
132 133 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
133 134
134 135 def test_shutdown(self):
135 136 # self.addEngine(4)
136 137 ids = self.client.ids
137 138 id0 = ids[0]
138 139 self.client.shutdown(id0, block=True)
139 140 while id0 in self.client.ids:
140 141 time.sleep(0.1)
141 142 self.client.spin()
142 143
143 144 self.assertRaises(IndexError, lambda : self.client[id0])
144 145
145 146 def test_result_status(self):
146 147 pass
147 148 # to be written
@@ -1,301 +1,302 b''
1 1 """test View objects"""
2 2 #-------------------------------------------------------------------------------
3 3 # Copyright (C) 2011 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-------------------------------------------------------------------------------
8 8
9 9 #-------------------------------------------------------------------------------
10 10 # Imports
11 11 #-------------------------------------------------------------------------------
12 12
13 13 import time
14 14 from tempfile import mktemp
15 15
16 16 import zmq
17 17
18 18 from IPython import parallel as pmod
19 19 from IPython.parallel import error
20 20 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
21 21 from IPython.parallel import LoadBalancedView, DirectView
22 22 from IPython.parallel.util import interactive
23 23
24 24 from IPython.parallel.tests import add_engines
25 25
26 26 from .clienttest import ClusterTestCase, segfault, wait, skip_without
27 27
28 28 def setup():
29 29 add_engines(3)
30 30
31 31 class TestView(ClusterTestCase):
32 32
33 33 def test_segfault_task(self):
34 34 """test graceful handling of engine death (balanced)"""
35 35 # self.add_engines(1)
36 36 ar = self.client[-1].apply_async(segfault)
37 37 self.assertRaisesRemote(error.EngineError, ar.get)
38 38 eid = ar.engine_id
39 39 while eid in self.client.ids:
40 40 time.sleep(.01)
41 41 self.client.spin()
42 42
43 43 def test_segfault_mux(self):
44 44 """test graceful handling of engine death (direct)"""
45 45 # self.add_engines(1)
46 46 eid = self.client.ids[-1]
47 47 ar = self.client[eid].apply_async(segfault)
48 48 self.assertRaisesRemote(error.EngineError, ar.get)
49 49 eid = ar.engine_id
50 50 while eid in self.client.ids:
51 51 time.sleep(.01)
52 52 self.client.spin()
53 53
54 54 def test_push_pull(self):
55 55 """test pushing and pulling"""
56 56 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
57 57 t = self.client.ids[-1]
58 58 v = self.client[t]
59 59 push = v.push
60 60 pull = v.pull
61 61 v.block=True
62 62 nengines = len(self.client)
63 63 push({'data':data})
64 64 d = pull('data')
65 65 self.assertEquals(d, data)
66 66 self.client[:].push({'data':data})
67 67 d = self.client[:].pull('data', block=True)
68 68 self.assertEquals(d, nengines*[data])
69 69 ar = push({'data':data}, block=False)
70 70 self.assertTrue(isinstance(ar, AsyncResult))
71 71 r = ar.get()
72 72 ar = self.client[:].pull('data', block=False)
73 73 self.assertTrue(isinstance(ar, AsyncResult))
74 74 r = ar.get()
75 75 self.assertEquals(r, nengines*[data])
76 76 self.client[:].push(dict(a=10,b=20))
77 r = self.client[:].pull(('a','b'))
77 r = self.client[:].pull(('a','b'), block=True)
78 78 self.assertEquals(r, nengines*[[10,20]])
79 79
80 80 def test_push_pull_function(self):
81 81 "test pushing and pulling functions"
82 82 def testf(x):
83 83 return 2.0*x
84 84
85 85 t = self.client.ids[-1]
86 self.client[t].block=True
87 push = self.client[t].push
88 pull = self.client[t].pull
89 execute = self.client[t].execute
86 v = self.client[t]
87 v.block=True
88 push = v.push
89 pull = v.pull
90 execute = v.execute
90 91 push({'testf':testf})
91 92 r = pull('testf')
92 93 self.assertEqual(r(1.0), testf(1.0))
93 94 execute('r = testf(10)')
94 95 r = pull('r')
95 96 self.assertEquals(r, testf(10))
96 97 ar = self.client[:].push({'testf':testf}, block=False)
97 98 ar.get()
98 99 ar = self.client[:].pull('testf', block=False)
99 100 rlist = ar.get()
100 101 for r in rlist:
101 102 self.assertEqual(r(1.0), testf(1.0))
102 103 execute("def g(x): return x*x")
103 104 r = pull(('testf','g'))
104 105 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
105 106
106 107 def test_push_function_globals(self):
107 108 """test that pushed functions have access to globals"""
108 109 @interactive
109 110 def geta():
110 111 return a
111 112 # self.add_engines(1)
112 113 v = self.client[-1]
113 114 v.block=True
114 115 v['f'] = geta
115 116 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
116 117 v.execute('a=5')
117 118 v.execute('b=f()')
118 119 self.assertEquals(v['b'], 5)
119 120
120 121 def test_push_function_defaults(self):
121 122 """test that pushed functions preserve default args"""
122 123 def echo(a=10):
123 124 return a
124 125 v = self.client[-1]
125 126 v.block=True
126 127 v['f'] = echo
127 128 v.execute('b=f()')
128 129 self.assertEquals(v['b'], 10)
129 130
130 131 def test_get_result(self):
131 132 """test getting results from the Hub."""
132 133 c = pmod.Client(profile='iptest')
133 134 # self.add_engines(1)
134 135 t = c.ids[-1]
135 136 v = c[t]
136 137 v2 = self.client[t]
137 138 ar = v.apply_async(wait, 1)
138 139 # give the monitor time to notice the message
139 140 time.sleep(.25)
140 141 ahr = v2.get_result(ar.msg_ids)
141 142 self.assertTrue(isinstance(ahr, AsyncHubResult))
142 143 self.assertEquals(ahr.get(), ar.get())
143 144 ar2 = v2.get_result(ar.msg_ids)
144 145 self.assertFalse(isinstance(ar2, AsyncHubResult))
145 146 c.spin()
146 147 c.close()
147 148
148 149 def test_run_newline(self):
149 150 """test that run appends newline to files"""
150 151 tmpfile = mktemp()
151 152 with open(tmpfile, 'w') as f:
152 153 f.write("""def g():
153 154 return 5
154 155 """)
155 156 v = self.client[-1]
156 157 v.run(tmpfile, block=True)
157 158 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
158 159
159 160 def test_apply_tracked(self):
160 161 """test tracking for apply"""
161 162 # self.add_engines(1)
162 163 t = self.client.ids[-1]
163 164 v = self.client[t]
164 165 v.block=False
165 166 def echo(n=1024*1024, **kwargs):
166 167 with v.temp_flags(**kwargs):
167 168 return v.apply(lambda x: x, 'x'*n)
168 169 ar = echo(1, track=False)
169 170 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
170 171 self.assertTrue(ar.sent)
171 172 ar = echo(track=True)
172 173 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
173 174 self.assertEquals(ar.sent, ar._tracker.done)
174 175 ar._tracker.wait()
175 176 self.assertTrue(ar.sent)
176 177
177 178 def test_push_tracked(self):
178 179 t = self.client.ids[-1]
179 180 ns = dict(x='x'*1024*1024)
180 181 v = self.client[t]
181 182 ar = v.push(ns, block=False, track=False)
182 183 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
183 184 self.assertTrue(ar.sent)
184 185
185 186 ar = v.push(ns, block=False, track=True)
186 187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 188 self.assertEquals(ar.sent, ar._tracker.done)
188 189 ar._tracker.wait()
189 190 self.assertTrue(ar.sent)
190 191 ar.get()
191 192
192 193 def test_scatter_tracked(self):
193 194 t = self.client.ids
194 195 x='x'*1024*1024
195 196 ar = self.client[t].scatter('x', x, block=False, track=False)
196 197 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
197 198 self.assertTrue(ar.sent)
198 199
199 200 ar = self.client[t].scatter('x', x, block=False, track=True)
200 201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 202 self.assertEquals(ar.sent, ar._tracker.done)
202 203 ar._tracker.wait()
203 204 self.assertTrue(ar.sent)
204 205 ar.get()
205 206
206 207 def test_remote_reference(self):
207 208 v = self.client[-1]
208 209 v['a'] = 123
209 210 ra = pmod.Reference('a')
210 211 b = v.apply_sync(lambda x: x, ra)
211 212 self.assertEquals(b, 123)
212 213
213 214
214 215 def test_scatter_gather(self):
215 216 view = self.client[:]
216 217 seq1 = range(16)
217 218 view.scatter('a', seq1)
218 219 seq2 = view.gather('a', block=True)
219 220 self.assertEquals(seq2, seq1)
220 221 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
221 222
222 223 @skip_without('numpy')
223 224 def test_scatter_gather_numpy(self):
224 225 import numpy
225 226 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
226 227 view = self.client[:]
227 228 a = numpy.arange(64)
228 229 view.scatter('a', a)
229 230 b = view.gather('a', block=True)
230 231 assert_array_equal(b, a)
231 232
232 233 def test_map(self):
233 234 view = self.client[:]
234 235 def f(x):
235 236 return x**2
236 237 data = range(16)
237 238 r = view.map_sync(f, data)
238 239 self.assertEquals(r, map(f, data))
239 240
240 241 def test_scatterGatherNonblocking(self):
241 242 data = range(16)
242 243 view = self.client[:]
243 244 view.scatter('a', data, block=False)
244 245 ar = view.gather('a', block=False)
245 246 self.assertEquals(ar.get(), data)
246 247
247 248 @skip_without('numpy')
248 249 def test_scatter_gather_numpy_nonblocking(self):
249 250 import numpy
250 251 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
251 252 a = numpy.arange(64)
252 253 view = self.client[:]
253 254 ar = view.scatter('a', a, block=False)
254 255 self.assertTrue(isinstance(ar, AsyncResult))
255 256 amr = view.gather('a', block=False)
256 257 self.assertTrue(isinstance(amr, AsyncMapResult))
257 258 assert_array_equal(amr.get(), a)
258 259
259 260 def test_execute(self):
260 261 view = self.client[:]
261 262 # self.client.debug=True
262 263 execute = view.execute
263 264 ar = execute('c=30', block=False)
264 265 self.assertTrue(isinstance(ar, AsyncResult))
265 266 ar = execute('d=[0,1,2]', block=False)
266 267 self.client.wait(ar, 1)
267 268 self.assertEquals(len(ar.get()), len(self.client))
268 269 for c in view['c']:
269 270 self.assertEquals(c, 30)
270 271
271 272 def test_abort(self):
272 273 view = self.client[-1]
273 274 ar = view.execute('import time; time.sleep(0.25)', block=False)
274 275 ar2 = view.apply_async(lambda : 2)
275 276 ar3 = view.apply_async(lambda : 3)
276 277 view.abort(ar2)
277 278 view.abort(ar3.msg_ids)
278 279 self.assertRaises(error.TaskAborted, ar2.get)
279 280 self.assertRaises(error.TaskAborted, ar3.get)
280 281
281 282 def test_temp_flags(self):
282 283 view = self.client[-1]
283 284 view.block=True
284 285 with view.temp_flags(block=False):
285 286 self.assertFalse(view.block)
286 287 self.assertTrue(view.block)
287 288
288 289 def test_importer(self):
289 290 view = self.client[-1]
290 291 view.clear(block=True)
291 292 with view.importer:
292 293 import re
293 294
294 295 @interactive
295 296 def findall(pat, s):
296 297 # this globals() step isn't necessary in real code
297 298 # only to prevent a closure in the test
298 299 return globals()['re'].findall(pat, s)
299 300
300 301 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
301 302
General Comments 0
You need to be logged in to leave comments. Login now