##// END OF EJS Templates
test view.map on a generator
MinRK -
Show More
@@ -1,203 +1,211
1 1 # -*- coding: utf-8 -*-
2 2 """test LoadBalancedView objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import time
21 21
22 22 import zmq
23 23 from nose import SkipTest
24 24 from nose.plugins.attrib import attr
25 25
26 26 from IPython import parallel as pmod
27 27 from IPython.parallel import error
28 28
29 29 from IPython.parallel.tests import add_engines
30 30
31 31 from .clienttest import ClusterTestCase, crash, wait, skip_without
32 32
33 33 def setup():
34 34 add_engines(3, total=True)
35 35
36 36 class TestLoadBalancedView(ClusterTestCase):
37 37
38 38 def setUp(self):
39 39 ClusterTestCase.setUp(self)
40 40 self.view = self.client.load_balanced_view()
41 41
42 42 @attr('crash')
43 43 def test_z_crash_task(self):
44 44 """test graceful handling of engine death (balanced)"""
45 45 # self.add_engines(1)
46 46 ar = self.view.apply_async(crash)
47 47 self.assertRaisesRemote(error.EngineError, ar.get, 10)
48 48 eid = ar.engine_id
49 49 tic = time.time()
50 50 while eid in self.client.ids and time.time()-tic < 5:
51 51 time.sleep(.01)
52 52 self.client.spin()
53 53 self.assertFalse(eid in self.client.ids, "Engine should have died")
54 54
55 55 def test_map(self):
56 56 def f(x):
57 57 return x**2
58 58 data = range(16)
59 59 r = self.view.map_sync(f, data)
60 60 self.assertEqual(r, map(f, data))
61 61
62 def test_map_generator(self):
63 def f(x):
64 return x**2
65
66 data = range(16)
67 r = self.view.map_sync(f, iter(data))
68 self.assertEqual(r, map(f, iter(data)))
69
62 70 def test_map_short_first(self):
63 71 def f(x,y):
64 72 if y is None:
65 73 return y
66 74 if x is None:
67 75 return x
68 76 return x*y
69 77 data = range(10)
70 78 data2 = range(4)
71 79
72 80 r = self.view.map_sync(f, data, data2)
73 81 self.assertEqual(r, map(f, data, data2))
74 82
75 83 def test_map_short_last(self):
76 84 def f(x,y):
77 85 if y is None:
78 86 return y
79 87 if x is None:
80 88 return x
81 89 return x*y
82 90 data = range(4)
83 91 data2 = range(10)
84 92
85 93 r = self.view.map_sync(f, data, data2)
86 94 self.assertEqual(r, map(f, data, data2))
87 95
88 96 def test_map_unordered(self):
89 97 def f(x):
90 98 return x**2
91 99 def slow_f(x):
92 100 import time
93 101 time.sleep(0.05*x)
94 102 return x**2
95 103 data = range(16,0,-1)
96 104 reference = map(f, data)
97 105
98 106 amr = self.view.map_async(slow_f, data, ordered=False)
99 107 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
100 108 # check individual elements, retrieved as they come
101 109 # list comprehension uses __iter__
102 110 astheycame = [ r for r in amr ]
103 111 # Ensure that at least one result came out of order:
104 112 self.assertNotEqual(astheycame, reference, "should not have preserved order")
105 113 self.assertEqual(sorted(astheycame, reverse=True), reference, "result corrupted")
106 114
107 115 def test_map_ordered(self):
108 116 def f(x):
109 117 return x**2
110 118 def slow_f(x):
111 119 import time
112 120 time.sleep(0.05*x)
113 121 return x**2
114 122 data = range(16,0,-1)
115 123 reference = map(f, data)
116 124
117 125 amr = self.view.map_async(slow_f, data)
118 126 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
119 127 # check individual elements, retrieved as they come
120 128 # list(amr) uses __iter__
121 129 astheycame = list(amr)
122 130 # Ensure that results came in order
123 131 self.assertEqual(astheycame, reference)
124 132 self.assertEqual(amr.result, reference)
125 133
126 134 def test_map_iterable(self):
127 135 """test map on iterables (balanced)"""
128 136 view = self.view
129 137 # 101 is prime, so it won't be evenly distributed
130 138 arr = range(101)
131 139 # so that it will be an iterator, even in Python 3
132 140 it = iter(arr)
133 141 r = view.map_sync(lambda x:x, arr)
134 142 self.assertEqual(r, list(arr))
135 143
136 144
137 145 def test_abort(self):
138 146 view = self.view
139 147 ar = self.client[:].apply_async(time.sleep, .5)
140 148 ar = self.client[:].apply_async(time.sleep, .5)
141 149 time.sleep(0.2)
142 150 ar2 = view.apply_async(lambda : 2)
143 151 ar3 = view.apply_async(lambda : 3)
144 152 view.abort(ar2)
145 153 view.abort(ar3.msg_ids)
146 154 self.assertRaises(error.TaskAborted, ar2.get)
147 155 self.assertRaises(error.TaskAborted, ar3.get)
148 156
149 157 def test_retries(self):
150 158 view = self.view
151 159 view.timeout = 1 # prevent hang if this doesn't behave
152 160 def fail():
153 161 assert False
154 162 for r in range(len(self.client)-1):
155 163 with view.temp_flags(retries=r):
156 164 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
157 165
158 166 with view.temp_flags(retries=len(self.client), timeout=0.25):
159 167 self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
160 168
161 169 def test_invalid_dependency(self):
162 170 view = self.view
163 171 with view.temp_flags(after='12345'):
164 172 self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
165 173
166 174 def test_impossible_dependency(self):
167 175 self.minimum_engines(2)
168 176 view = self.client.load_balanced_view()
169 177 ar1 = view.apply_async(lambda : 1)
170 178 ar1.get()
171 179 e1 = ar1.engine_id
172 180 e2 = e1
173 181 while e2 == e1:
174 182 ar2 = view.apply_async(lambda : 1)
175 183 ar2.get()
176 184 e2 = ar2.engine_id
177 185
178 186 with view.temp_flags(follow=[ar1, ar2]):
179 187 self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
180 188
181 189
182 190 def test_follow(self):
183 191 ar = self.view.apply_async(lambda : 1)
184 192 ar.get()
185 193 ars = []
186 194 first_id = ar.engine_id
187 195
188 196 self.view.follow = ar
189 197 for i in range(5):
190 198 ars.append(self.view.apply_async(lambda : 1))
191 199 self.view.wait(ars)
192 200 for ar in ars:
193 201 self.assertEqual(ar.engine_id, first_id)
194 202
195 203 def test_after(self):
196 204 view = self.view
197 205 ar = view.apply_async(time.sleep, 0.5)
198 206 with view.temp_flags(after=ar):
199 207 ar2 = view.apply_async(lambda : 1)
200 208
201 209 ar.wait()
202 210 ar2.wait()
203 211 self.assertTrue(ar2.started >= ar.completed, "%s not >= %s"%(ar.started, ar.completed))
General Comments 0
You need to be logged in to leave comments. Login now