##// END OF EJS Templates
remove some extraneous print statements from IPython.parallel...
MinRK -
Show More
@@ -1,283 +1,282 b''
1 1 """Remote Functions and decorators for Views.
2 2
3 3 Authors:
4 4
5 5 * Brian Granger
6 6 * Min RK
7 7 """
8 8 #-----------------------------------------------------------------------------
9 9 # Copyright (C) 2010-2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-----------------------------------------------------------------------------
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import sys
22 22 import warnings
23 23
24 24 from IPython.external.decorator import decorator
25 25 from IPython.testing.skipdoctest import skip_doctest
26 26
27 27 from . import map as Map
28 28 from .asyncresult import AsyncMapResult
29 29
30 30 #-----------------------------------------------------------------------------
31 31 # Functions and Decorators
32 32 #-----------------------------------------------------------------------------
33 33
34 34 @skip_doctest
35 35 def remote(view, block=None, **flags):
36 36 """Turn a function into a remote function.
37 37
38 38 This method can be used for map:
39 39
40 40 In [1]: @remote(view,block=True)
41 41 ...: def func(a):
42 42 ...: pass
43 43 """
44 44
45 45 def remote_function(f):
46 46 return RemoteFunction(view, f, block=block, **flags)
47 47 return remote_function
48 48
49 49 @skip_doctest
50 50 def parallel(view, dist='b', block=None, ordered=True, **flags):
51 51 """Turn a function into a parallel remote function.
52 52
53 53 This method can be used for map:
54 54
55 55 In [1]: @parallel(view, block=True)
56 56 ...: def func(a):
57 57 ...: pass
58 58 """
59 59
60 60 def parallel_function(f):
61 61 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
62 62 return parallel_function
63 63
64 64 def getname(f):
65 65 """Get the name of an object.
66 66
67 67 For use in case of callables that are not functions, and
68 68 thus may not have __name__ defined.
69 69
70 70 Order: f.__name__ > f.name > str(f)
71 71 """
72 72 try:
73 73 return f.__name__
74 74 except:
75 75 pass
76 76 try:
77 77 return f.name
78 78 except:
79 79 pass
80 80
81 81 return str(f)
82 82
83 83 @decorator
84 84 def sync_view_results(f, self, *args, **kwargs):
85 85 """sync relevant results from self.client to our results attribute.
86 86
87 87 This is a clone of view.sync_results, but for remote functions
88 88 """
89 89 view = self.view
90 90 if view._in_sync_results:
91 91 return f(self, *args, **kwargs)
92 print 'in sync results', f
93 92 view._in_sync_results = True
94 93 try:
95 94 ret = f(self, *args, **kwargs)
96 95 finally:
97 96 view._in_sync_results = False
98 97 view._sync_results()
99 98 return ret
100 99
101 100 #--------------------------------------------------------------------------
102 101 # Classes
103 102 #--------------------------------------------------------------------------
104 103
105 104 class RemoteFunction(object):
106 105 """Turn an existing function into a remote function.
107 106
108 107 Parameters
109 108 ----------
110 109
111 110 view : View instance
112 111 The view to be used for execution
113 112 f : callable
114 113 The function to be wrapped into a remote function
115 114 block : bool [default: None]
116 115 Whether to wait for results or not. The default behavior is
117 116 to use the current `block` attribute of `view`
118 117
119 118 **flags : remaining kwargs are passed to View.temp_flags
120 119 """
121 120
122 121 view = None # the remote connection
123 122 func = None # the wrapped function
124 123 block = None # whether to block
125 124 flags = None # dict of extra kwargs for temp_flags
126 125
127 126 def __init__(self, view, f, block=None, **flags):
128 127 self.view = view
129 128 self.func = f
130 129 self.block=block
131 130 self.flags=flags
132 131
133 132 def __call__(self, *args, **kwargs):
134 133 block = self.view.block if self.block is None else self.block
135 134 with self.view.temp_flags(block=block, **self.flags):
136 135 return self.view.apply(self.func, *args, **kwargs)
137 136
138 137
139 138 class ParallelFunction(RemoteFunction):
140 139 """Class for mapping a function to sequences.
141 140
142 141 This will distribute the sequences according the a mapper, and call
143 142 the function on each sub-sequence. If called via map, then the function
144 143 will be called once on each element, rather that each sub-sequence.
145 144
146 145 Parameters
147 146 ----------
148 147
149 148 view : View instance
150 149 The view to be used for execution
151 150 f : callable
152 151 The function to be wrapped into a remote function
153 152 dist : str [default: 'b']
154 153 The key for which mapObject to use to distribute sequences
155 154 options are:
156 155 * 'b' : use contiguous chunks in order
157 156 * 'r' : use round-robin striping
158 157 block : bool [default: None]
159 158 Whether to wait for results or not. The default behavior is
160 159 to use the current `block` attribute of `view`
161 160 chunksize : int or None
162 161 The size of chunk to use when breaking up sequences in a load-balanced manner
163 162 ordered : bool [default: True]
164 163 Whether the result should be kept in order. If False,
165 164 results become available as they arrive, regardless of submission order.
166 165 **flags : remaining kwargs are passed to View.temp_flags
167 166 """
168 167
169 168 chunksize = None
170 169 ordered = None
171 170 mapObject = None
172 171 _mapping = False
173 172
174 173 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
175 174 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
176 175 self.chunksize = chunksize
177 176 self.ordered = ordered
178 177
179 178 mapClass = Map.dists[dist]
180 179 self.mapObject = mapClass()
181 180
182 181 @sync_view_results
183 182 def __call__(self, *sequences):
184 183 client = self.view.client
185 184
186 185 lens = []
187 186 maxlen = minlen = -1
188 187 for i, seq in enumerate(sequences):
189 188 try:
190 189 n = len(seq)
191 190 except Exception:
192 191 seq = list(seq)
193 192 if isinstance(sequences, tuple):
194 193 # can't alter a tuple
195 194 sequences = list(sequences)
196 195 sequences[i] = seq
197 196 n = len(seq)
198 197 if n > maxlen:
199 198 maxlen = n
200 199 if minlen == -1 or n < minlen:
201 200 minlen = n
202 201 lens.append(n)
203 202
204 203 # check that the length of sequences match
205 204 if not self._mapping and minlen != maxlen:
206 205 msg = 'all sequences must have equal length, but have %s' % lens
207 206 raise ValueError(msg)
208 207
209 208 balanced = 'Balanced' in self.view.__class__.__name__
210 209 if balanced:
211 210 if self.chunksize:
212 211 nparts = maxlen // self.chunksize + int(maxlen % self.chunksize > 0)
213 212 else:
214 213 nparts = maxlen
215 214 targets = [None]*nparts
216 215 else:
217 216 if self.chunksize:
218 217 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
219 218 # multiplexed:
220 219 targets = self.view.targets
221 220 # 'all' is lazily evaluated at execution time, which is now:
222 221 if targets == 'all':
223 222 targets = client._build_targets(targets)[1]
224 223 elif isinstance(targets, int):
225 224 # single-engine view, targets must be iterable
226 225 targets = [targets]
227 226 nparts = len(targets)
228 227
229 228 msg_ids = []
230 229 for index, t in enumerate(targets):
231 230 args = []
232 231 for seq in sequences:
233 232 part = self.mapObject.getPartition(seq, index, nparts, maxlen)
234 233 args.append(part)
235 234
236 235 if sum([len(arg) for arg in args]) == 0:
237 236 continue
238 237
239 238 if self._mapping:
240 239 if sys.version_info[0] >= 3:
241 240 f = lambda f, *sequences: list(map(f, *sequences))
242 241 else:
243 242 f = map
244 243 args = [self.func] + args
245 244 else:
246 245 f=self.func
247 246
248 247 view = self.view if balanced else client[t]
249 248 with view.temp_flags(block=False, **self.flags):
250 249 ar = view.apply(f, *args)
251 250
252 251 msg_ids.extend(ar.msg_ids)
253 252
254 253 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
255 254 fname=getname(self.func),
256 255 ordered=self.ordered
257 256 )
258 257
259 258 if self.block:
260 259 try:
261 260 return r.get()
262 261 except KeyboardInterrupt:
263 262 return r
264 263 else:
265 264 return r
266 265
267 266 def map(self, *sequences):
268 267 """call a function on each element of one or more sequence(s) remotely.
269 268 This should behave very much like the builtin map, but return an AsyncMapResult
270 269 if self.block is False.
271 270
272 271 That means it can take generators (will be cast to lists locally),
273 272 and mismatched sequence lengths will be padded with None.
274 273 """
275 274 # set _mapping as a flag for use inside self.__call__
276 275 self._mapping = True
277 276 try:
278 277 ret = self(*sequences)
279 278 finally:
280 279 self._mapping = False
281 280 return ret
282 281
283 282 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,126 +1,125 b''
1 1 """toplevel setup/teardown for parallel tests."""
2 2
3 3 #-------------------------------------------------------------------------------
4 4 # Copyright (C) 2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-------------------------------------------------------------------------------
9 9
10 10 #-------------------------------------------------------------------------------
11 11 # Imports
12 12 #-------------------------------------------------------------------------------
13 13
14 14 import os
15 15 import tempfile
16 16 import time
17 17 from subprocess import Popen
18 18
19 19 from IPython.utils.path import get_ipython_dir
20 20 from IPython.parallel import Client
21 21 from IPython.parallel.apps.launcher import (LocalProcessLauncher,
22 22 ipengine_cmd_argv,
23 23 ipcontroller_cmd_argv,
24 24 SIGKILL,
25 25 ProcessStateError,
26 26 )
27 27
28 28 # globals
29 29 launchers = []
30 30 blackhole = open(os.devnull, 'w')
31 31
32 32 # Launcher class
33 33 class TestProcessLauncher(LocalProcessLauncher):
34 34 """subclass LocalProcessLauncher, to prevent extra sockets and threads being created on Windows"""
35 35 def start(self):
36 36 if self.state == 'before':
37 37 self.process = Popen(self.args,
38 38 stdout=blackhole, stderr=blackhole,
39 39 env=os.environ,
40 40 cwd=self.work_dir
41 41 )
42 42 self.notify_start(self.process.pid)
43 43 self.poll = self.process.poll
44 44 else:
45 45 s = 'The process was already started and has state: %r' % self.state
46 46 raise ProcessStateError(s)
47 47
48 48 # nose setup/teardown
49 49
50 50 def setup():
51 51 cluster_dir = os.path.join(get_ipython_dir(), 'profile_iptest')
52 52 engine_json = os.path.join(cluster_dir, 'security', 'ipcontroller-engine.json')
53 53 client_json = os.path.join(cluster_dir, 'security', 'ipcontroller-client.json')
54 54 for json in (engine_json, client_json):
55 55 if os.path.exists(json):
56 56 os.remove(json)
57 57
58 58 cp = TestProcessLauncher()
59 59 cp.cmd_and_args = ipcontroller_cmd_argv + \
60 60 ['--profile=iptest', '--log-level=50', '--ping=250', '--dictdb']
61 61 cp.start()
62 62 launchers.append(cp)
63 63 tic = time.time()
64 64 while not os.path.exists(engine_json) or not os.path.exists(client_json):
65 65 if cp.poll() is not None:
66 print cp.poll()
67 raise RuntimeError("The test controller failed to start.")
66 raise RuntimeError("The test controller exited with status %s" % cp.poll())
68 67 elif time.time()-tic > 15:
69 68 raise RuntimeError("Timeout waiting for the test controller to start.")
70 69 time.sleep(0.1)
71 70 add_engines(1)
72 71
73 72 def add_engines(n=1, profile='iptest', total=False):
74 73 """add a number of engines to a given profile.
75 74
76 75 If total is True, then already running engines are counted, and only
77 76 the additional engines necessary (if any) are started.
78 77 """
79 78 rc = Client(profile=profile)
80 79 base = len(rc)
81 80
82 81 if total:
83 82 n = max(n - base, 0)
84 83
85 84 eps = []
86 85 for i in range(n):
87 86 ep = TestProcessLauncher()
88 87 ep.cmd_and_args = ipengine_cmd_argv + [
89 88 '--profile=%s' % profile,
90 89 '--log-level=50',
91 90 '--InteractiveShell.colors=nocolor'
92 91 ]
93 92 ep.start()
94 93 launchers.append(ep)
95 94 eps.append(ep)
96 95 tic = time.time()
97 96 while len(rc) < base+n:
98 97 if any([ ep.poll() is not None for ep in eps ]):
99 98 raise RuntimeError("A test engine failed to start.")
100 99 elif time.time()-tic > 15:
101 100 raise RuntimeError("Timeout waiting for engines to connect.")
102 101 time.sleep(.1)
103 102 rc.spin()
104 103 rc.close()
105 104 return eps
106 105
107 106 def teardown():
108 107 time.sleep(1)
109 108 while launchers:
110 109 p = launchers.pop()
111 110 if p.poll() is None:
112 111 try:
113 112 p.stop()
114 113 except Exception as e:
115 114 print e
116 115 pass
117 116 if p.poll() is None:
118 117 time.sleep(.25)
119 118 if p.poll() is None:
120 119 try:
121 120 print 'cleaning up test process...'
122 121 p.signal(SIGKILL)
123 122 except:
124 123 print "couldn't shutdown process: ", p
125 124 blackhole.close()
126 125
@@ -1,518 +1,517 b''
1 1 """Tests for parallel client.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import time
22 22 from datetime import datetime
23 23 from tempfile import mktemp
24 24
25 25 import zmq
26 26
27 27 from IPython import parallel
28 28 from IPython.parallel.client import client as clientmod
29 29 from IPython.parallel import error
30 30 from IPython.parallel import AsyncResult, AsyncHubResult
31 31 from IPython.parallel import LoadBalancedView, DirectView
32 32
33 33 from clienttest import ClusterTestCase, segfault, wait, add_engines
34 34
35 35 def setup():
36 36 add_engines(4, total=True)
37 37
38 38 class TestClient(ClusterTestCase):
39 39
40 40 def test_ids(self):
41 41 n = len(self.client.ids)
42 42 self.add_engines(2)
43 43 self.assertEqual(len(self.client.ids), n+2)
44 44
45 45 def test_view_indexing(self):
46 46 """test index access for views"""
47 47 self.minimum_engines(4)
48 48 targets = self.client._build_targets('all')[-1]
49 49 v = self.client[:]
50 50 self.assertEqual(v.targets, targets)
51 51 t = self.client.ids[2]
52 52 v = self.client[t]
53 53 self.assertTrue(isinstance(v, DirectView))
54 54 self.assertEqual(v.targets, t)
55 55 t = self.client.ids[2:4]
56 56 v = self.client[t]
57 57 self.assertTrue(isinstance(v, DirectView))
58 58 self.assertEqual(v.targets, t)
59 59 v = self.client[::2]
60 60 self.assertTrue(isinstance(v, DirectView))
61 61 self.assertEqual(v.targets, targets[::2])
62 62 v = self.client[1::3]
63 63 self.assertTrue(isinstance(v, DirectView))
64 64 self.assertEqual(v.targets, targets[1::3])
65 65 v = self.client[:-3]
66 66 self.assertTrue(isinstance(v, DirectView))
67 67 self.assertEqual(v.targets, targets[:-3])
68 68 v = self.client[-1]
69 69 self.assertTrue(isinstance(v, DirectView))
70 70 self.assertEqual(v.targets, targets[-1])
71 71 self.assertRaises(TypeError, lambda : self.client[None])
72 72
73 73 def test_lbview_targets(self):
74 74 """test load_balanced_view targets"""
75 75 v = self.client.load_balanced_view()
76 76 self.assertEqual(v.targets, None)
77 77 v = self.client.load_balanced_view(-1)
78 78 self.assertEqual(v.targets, [self.client.ids[-1]])
79 79 v = self.client.load_balanced_view('all')
80 80 self.assertEqual(v.targets, None)
81 81
82 82 def test_dview_targets(self):
83 83 """test direct_view targets"""
84 84 v = self.client.direct_view()
85 85 self.assertEqual(v.targets, 'all')
86 86 v = self.client.direct_view('all')
87 87 self.assertEqual(v.targets, 'all')
88 88 v = self.client.direct_view(-1)
89 89 self.assertEqual(v.targets, self.client.ids[-1])
90 90
91 91 def test_lazy_all_targets(self):
92 92 """test lazy evaluation of rc.direct_view('all')"""
93 93 v = self.client.direct_view()
94 94 self.assertEqual(v.targets, 'all')
95 95
96 96 def double(x):
97 97 return x*2
98 98 seq = range(100)
99 99 ref = [ double(x) for x in seq ]
100 100
101 101 # add some engines, which should be used
102 102 self.add_engines(1)
103 103 n1 = len(self.client.ids)
104 104
105 105 # simple apply
106 106 r = v.apply_sync(lambda : 1)
107 107 self.assertEqual(r, [1] * n1)
108 108
109 109 # map goes through remotefunction
110 110 r = v.map_sync(double, seq)
111 111 self.assertEqual(r, ref)
112 112
113 113 # add a couple more engines, and try again
114 114 self.add_engines(2)
115 115 n2 = len(self.client.ids)
116 116 self.assertNotEqual(n2, n1)
117 117
118 118 # apply
119 119 r = v.apply_sync(lambda : 1)
120 120 self.assertEqual(r, [1] * n2)
121 121
122 122 # map
123 123 r = v.map_sync(double, seq)
124 124 self.assertEqual(r, ref)
125 125
126 126 def test_targets(self):
127 127 """test various valid targets arguments"""
128 128 build = self.client._build_targets
129 129 ids = self.client.ids
130 130 idents,targets = build(None)
131 131 self.assertEqual(ids, targets)
132 132
133 133 def test_clear(self):
134 134 """test clear behavior"""
135 135 self.minimum_engines(2)
136 136 v = self.client[:]
137 137 v.block=True
138 138 v.push(dict(a=5))
139 139 v.pull('a')
140 140 id0 = self.client.ids[-1]
141 141 self.client.clear(targets=id0, block=True)
142 142 a = self.client[:-1].get('a')
143 143 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
144 144 self.client.clear(block=True)
145 145 for i in self.client.ids:
146 146 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
147 147
148 148 def test_get_result(self):
149 149 """test getting results from the Hub."""
150 150 c = clientmod.Client(profile='iptest')
151 151 t = c.ids[-1]
152 152 ar = c[t].apply_async(wait, 1)
153 153 # give the monitor time to notice the message
154 154 time.sleep(.25)
155 155 ahr = self.client.get_result(ar.msg_ids[0])
156 156 self.assertTrue(isinstance(ahr, AsyncHubResult))
157 157 self.assertEqual(ahr.get(), ar.get())
158 158 ar2 = self.client.get_result(ar.msg_ids[0])
159 159 self.assertFalse(isinstance(ar2, AsyncHubResult))
160 160 c.close()
161 161
162 162 def test_get_execute_result(self):
163 163 """test getting execute results from the Hub."""
164 164 c = clientmod.Client(profile='iptest')
165 165 t = c.ids[-1]
166 166 cell = '\n'.join([
167 167 'import time',
168 168 'time.sleep(0.25)',
169 169 '5'
170 170 ])
171 171 ar = c[t].execute("import time; time.sleep(1)", silent=False)
172 172 # give the monitor time to notice the message
173 173 time.sleep(.25)
174 174 ahr = self.client.get_result(ar.msg_ids[0])
175 print ar.get(), ahr.get(), ar._single_result, ahr._single_result
176 175 self.assertTrue(isinstance(ahr, AsyncHubResult))
177 176 self.assertEqual(ahr.get().pyout, ar.get().pyout)
178 177 ar2 = self.client.get_result(ar.msg_ids[0])
179 178 self.assertFalse(isinstance(ar2, AsyncHubResult))
180 179 c.close()
181 180
182 181 def test_ids_list(self):
183 182 """test client.ids"""
184 183 ids = self.client.ids
185 184 self.assertEqual(ids, self.client._ids)
186 185 self.assertFalse(ids is self.client._ids)
187 186 ids.remove(ids[-1])
188 187 self.assertNotEqual(ids, self.client._ids)
189 188
190 189 def test_queue_status(self):
191 190 ids = self.client.ids
192 191 id0 = ids[0]
193 192 qs = self.client.queue_status(targets=id0)
194 193 self.assertTrue(isinstance(qs, dict))
195 194 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
196 195 allqs = self.client.queue_status()
197 196 self.assertTrue(isinstance(allqs, dict))
198 197 intkeys = list(allqs.keys())
199 198 intkeys.remove('unassigned')
200 199 self.assertEqual(sorted(intkeys), sorted(self.client.ids))
201 200 unassigned = allqs.pop('unassigned')
202 201 for eid,qs in allqs.items():
203 202 self.assertTrue(isinstance(qs, dict))
204 203 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
205 204
206 205 def test_shutdown(self):
207 206 ids = self.client.ids
208 207 id0 = ids[0]
209 208 self.client.shutdown(id0, block=True)
210 209 while id0 in self.client.ids:
211 210 time.sleep(0.1)
212 211 self.client.spin()
213 212
214 213 self.assertRaises(IndexError, lambda : self.client[id0])
215 214
216 215 def test_result_status(self):
217 216 pass
218 217 # to be written
219 218
220 219 def test_db_query_dt(self):
221 220 """test db query by date"""
222 221 hist = self.client.hub_history()
223 222 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
224 223 tic = middle['submitted']
225 224 before = self.client.db_query({'submitted' : {'$lt' : tic}})
226 225 after = self.client.db_query({'submitted' : {'$gte' : tic}})
227 226 self.assertEqual(len(before)+len(after),len(hist))
228 227 for b in before:
229 228 self.assertTrue(b['submitted'] < tic)
230 229 for a in after:
231 230 self.assertTrue(a['submitted'] >= tic)
232 231 same = self.client.db_query({'submitted' : tic})
233 232 for s in same:
234 233 self.assertTrue(s['submitted'] == tic)
235 234
236 235 def test_db_query_keys(self):
237 236 """test extracting subset of record keys"""
238 237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
239 238 for rec in found:
240 239 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
241 240
242 241 def test_db_query_default_keys(self):
243 242 """default db_query excludes buffers"""
244 243 found = self.client.db_query({'msg_id': {'$ne' : ''}})
245 244 for rec in found:
246 245 keys = set(rec.keys())
247 246 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
248 247 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
249 248
250 249 def test_db_query_msg_id(self):
251 250 """ensure msg_id is always in db queries"""
252 251 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
253 252 for rec in found:
254 253 self.assertTrue('msg_id' in rec.keys())
255 254 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
256 255 for rec in found:
257 256 self.assertTrue('msg_id' in rec.keys())
258 257 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
259 258 for rec in found:
260 259 self.assertTrue('msg_id' in rec.keys())
261 260
262 261 def test_db_query_get_result(self):
263 262 """pop in db_query shouldn't pop from result itself"""
264 263 self.client[:].apply_sync(lambda : 1)
265 264 found = self.client.db_query({'msg_id': {'$ne' : ''}})
266 265 rc2 = clientmod.Client(profile='iptest')
267 266 # If this bug is not fixed, this call will hang:
268 267 ar = rc2.get_result(self.client.history[-1])
269 268 ar.wait(2)
270 269 self.assertTrue(ar.ready())
271 270 ar.get()
272 271 rc2.close()
273 272
274 273 def test_db_query_in(self):
275 274 """test db query with '$in','$nin' operators"""
276 275 hist = self.client.hub_history()
277 276 even = hist[::2]
278 277 odd = hist[1::2]
279 278 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
280 279 found = [ r['msg_id'] for r in recs ]
281 280 self.assertEqual(set(even), set(found))
282 281 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
283 282 found = [ r['msg_id'] for r in recs ]
284 283 self.assertEqual(set(odd), set(found))
285 284
286 285 def test_hub_history(self):
287 286 hist = self.client.hub_history()
288 287 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
289 288 recdict = {}
290 289 for rec in recs:
291 290 recdict[rec['msg_id']] = rec
292 291
293 292 latest = datetime(1984,1,1)
294 293 for msg_id in hist:
295 294 rec = recdict[msg_id]
296 295 newt = rec['submitted']
297 296 self.assertTrue(newt >= latest)
298 297 latest = newt
299 298 ar = self.client[-1].apply_async(lambda : 1)
300 299 ar.get()
301 300 time.sleep(0.25)
302 301 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
303 302
304 303 def _wait_for_idle(self):
305 304 """wait for an engine to become idle, according to the Hub"""
306 305 rc = self.client
307 306
308 307 # step 1. wait for all requests to be noticed
309 308 # timeout 5s, polling every 100ms
310 309 msg_ids = set(rc.history)
311 310 hub_hist = rc.hub_history()
312 311 for i in range(50):
313 312 if msg_ids.difference(hub_hist):
314 313 time.sleep(0.1)
315 314 hub_hist = rc.hub_history()
316 315 else:
317 316 break
318 317
319 318 self.assertEqual(len(msg_ids.difference(hub_hist)), 0)
320 319
321 320 # step 2. wait for all requests to be done
322 321 # timeout 5s, polling every 100ms
323 322 qs = rc.queue_status()
324 323 for i in range(50):
325 324 if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
326 325 time.sleep(0.1)
327 326 qs = rc.queue_status()
328 327 else:
329 328 break
330 329
331 330 # ensure Hub up to date:
332 331 self.assertEqual(qs['unassigned'], 0)
333 332 for eid in rc.ids:
334 333 self.assertEqual(qs[eid]['tasks'], 0)
335 334
336 335
337 336 def test_resubmit(self):
338 337 def f():
339 338 import random
340 339 return random.random()
341 340 v = self.client.load_balanced_view()
342 341 ar = v.apply_async(f)
343 342 r1 = ar.get(1)
344 343 # give the Hub a chance to notice:
345 344 self._wait_for_idle()
346 345 ahr = self.client.resubmit(ar.msg_ids)
347 346 r2 = ahr.get(1)
348 347 self.assertFalse(r1 == r2)
349 348
350 349 def test_resubmit_chain(self):
351 350 """resubmit resubmitted tasks"""
352 351 v = self.client.load_balanced_view()
353 352 ar = v.apply_async(lambda x: x, 'x'*1024)
354 353 ar.get()
355 354 self._wait_for_idle()
356 355 ars = [ar]
357 356
358 357 for i in range(10):
359 358 ar = ars[-1]
360 359 ar2 = self.client.resubmit(ar.msg_ids)
361 360
362 361 [ ar.get() for ar in ars ]
363 362
364 363 def test_resubmit_header(self):
365 364 """resubmit shouldn't clobber the whole header"""
366 365 def f():
367 366 import random
368 367 return random.random()
369 368 v = self.client.load_balanced_view()
370 369 v.retries = 1
371 370 ar = v.apply_async(f)
372 371 r1 = ar.get(1)
373 372 # give the Hub a chance to notice:
374 373 self._wait_for_idle()
375 374 ahr = self.client.resubmit(ar.msg_ids)
376 375 ahr.get(1)
377 376 time.sleep(0.5)
378 377 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
379 378 h1,h2 = [ r['header'] for r in records ]
380 379 for key in set(h1.keys()).union(set(h2.keys())):
381 380 if key in ('msg_id', 'date'):
382 381 self.assertNotEqual(h1[key], h2[key])
383 382 else:
384 383 self.assertEqual(h1[key], h2[key])
385 384
386 385 def test_resubmit_aborted(self):
387 386 def f():
388 387 import random
389 388 return random.random()
390 389 v = self.client.load_balanced_view()
391 390 # restrict to one engine, so we can put a sleep
392 391 # ahead of the task, so it will get aborted
393 392 eid = self.client.ids[-1]
394 393 v.targets = [eid]
395 394 sleep = v.apply_async(time.sleep, 0.5)
396 395 ar = v.apply_async(f)
397 396 ar.abort()
398 397 self.assertRaises(error.TaskAborted, ar.get)
399 398 # Give the Hub a chance to get up to date:
400 399 self._wait_for_idle()
401 400 ahr = self.client.resubmit(ar.msg_ids)
402 401 r2 = ahr.get(1)
403 402
404 403 def test_resubmit_inflight(self):
405 404 """resubmit of inflight task"""
406 405 v = self.client.load_balanced_view()
407 406 ar = v.apply_async(time.sleep,1)
408 407 # give the message a chance to arrive
409 408 time.sleep(0.2)
410 409 ahr = self.client.resubmit(ar.msg_ids)
411 410 ar.get(2)
412 411 ahr.get(2)
413 412
414 413 def test_resubmit_badkey(self):
415 414 """ensure KeyError on resubmit of nonexistant task"""
416 415 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
417 416
418 417 def test_purge_hub_results(self):
419 418 # ensure there are some tasks
420 419 for i in range(5):
421 420 self.client[:].apply_sync(lambda : 1)
422 421 # Wait for the Hub to realise the result is done:
423 422 # This prevents a race condition, where we
424 423 # might purge a result the Hub still thinks is pending.
425 424 self._wait_for_idle()
426 425 rc2 = clientmod.Client(profile='iptest')
427 426 hist = self.client.hub_history()
428 427 ahr = rc2.get_result([hist[-1]])
429 428 ahr.wait(10)
430 429 self.client.purge_hub_results(hist[-1])
431 430 newhist = self.client.hub_history()
432 431 self.assertEqual(len(newhist)+1,len(hist))
433 432 rc2.spin()
434 433 rc2.close()
435 434
436 435 def test_purge_local_results(self):
437 436 # ensure there are some tasks
438 437 res = []
439 438 for i in range(5):
440 439 res.append(self.client[:].apply_async(lambda : 1))
441 440 self._wait_for_idle()
442 441 self.client.wait(10) # wait for the results to come back
443 442 before = len(self.client.results)
444 443 self.assertEqual(len(self.client.metadata),before)
445 444 self.client.purge_local_results(res[-1])
446 445 self.assertEqual(len(self.client.results),before-len(res[-1]), msg="Not removed from results")
447 446 self.assertEqual(len(self.client.metadata),before-len(res[-1]), msg="Not removed from metadata")
448 447
449 448 def test_purge_all_hub_results(self):
450 449 self.client.purge_hub_results('all')
451 450 hist = self.client.hub_history()
452 451 self.assertEqual(len(hist), 0)
453 452
454 453 def test_purge_all_local_results(self):
455 454 self.client.purge_local_results('all')
456 455 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
457 456 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
458 457
459 458 def test_purge_all_results(self):
460 459 # ensure there are some tasks
461 460 for i in range(5):
462 461 self.client[:].apply_sync(lambda : 1)
463 462 self.client.wait(10)
464 463 self._wait_for_idle()
465 464 self.client.purge_results('all')
466 465 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
467 466 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
468 467 hist = self.client.hub_history()
469 468 self.assertEqual(len(hist), 0, msg="hub history not empty")
470 469
471 470 def test_purge_everything(self):
472 471 # ensure there are some tasks
473 472 for i in range(5):
474 473 self.client[:].apply_sync(lambda : 1)
475 474 self.client.wait(10)
476 475 self._wait_for_idle()
477 476 self.client.purge_everything()
478 477 # The client results
479 478 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
480 479 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
481 480 # The client "bookkeeping"
482 481 self.assertEqual(len(self.client.session.digest_history), 0, msg="session digest not empty")
483 482 self.assertEqual(len(self.client.history), 0, msg="client history not empty")
484 483 # the hub results
485 484 hist = self.client.hub_history()
486 485 self.assertEqual(len(hist), 0, msg="hub history not empty")
487 486
488 487
489 488 def test_spin_thread(self):
490 489 self.client.spin_thread(0.01)
491 490 ar = self.client[-1].apply_async(lambda : 1)
492 491 time.sleep(0.1)
493 492 self.assertTrue(ar.wall_time < 0.1,
494 493 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
495 494 )
496 495
497 496 def test_stop_spin_thread(self):
498 497 self.client.spin_thread(0.01)
499 498 self.client.stop_spin_thread()
500 499 ar = self.client[-1].apply_async(lambda : 1)
501 500 time.sleep(0.15)
502 501 self.assertTrue(ar.wall_time > 0.1,
503 502 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
504 503 )
505 504
506 505 def test_activate(self):
507 506 ip = get_ipython()
508 507 magics = ip.magics_manager.magics
509 508 self.assertTrue('px' in magics['line'])
510 509 self.assertTrue('px' in magics['cell'])
511 510 v0 = self.client.activate(-1, '0')
512 511 self.assertTrue('px0' in magics['line'])
513 512 self.assertTrue('px0' in magics['cell'])
514 513 self.assertEqual(v0.targets, self.client.ids[-1])
515 514 v0 = self.client.activate('all', 'all')
516 515 self.assertTrue('pxall' in magics['line'])
517 516 self.assertTrue('pxall' in magics['cell'])
518 517 self.assertEqual(v0.targets, 'all')
General Comments 0
You need to be logged in to leave comments. Login now