##// END OF EJS Templates
remove some extraneous print statements from IPython.parallel...
MinRK -
Show More
@@ -1,283 +1,282 b''
1 """Remote Functions and decorators for Views.
1 """Remote Functions and decorators for Views.
2
2
3 Authors:
3 Authors:
4
4
5 * Brian Granger
5 * Brian Granger
6 * Min RK
6 * Min RK
7 """
7 """
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2010-2011 The IPython Development Team
9 # Copyright (C) 2010-2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import sys
21 import sys
22 import warnings
22 import warnings
23
23
24 from IPython.external.decorator import decorator
24 from IPython.external.decorator import decorator
25 from IPython.testing.skipdoctest import skip_doctest
25 from IPython.testing.skipdoctest import skip_doctest
26
26
27 from . import map as Map
27 from . import map as Map
28 from .asyncresult import AsyncMapResult
28 from .asyncresult import AsyncMapResult
29
29
30 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
31 # Functions and Decorators
31 # Functions and Decorators
32 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
33
33
34 @skip_doctest
34 @skip_doctest
35 def remote(view, block=None, **flags):
35 def remote(view, block=None, **flags):
36 """Turn a function into a remote function.
36 """Turn a function into a remote function.
37
37
38 This method can be used for map:
38 This method can be used for map:
39
39
40 In [1]: @remote(view,block=True)
40 In [1]: @remote(view,block=True)
41 ...: def func(a):
41 ...: def func(a):
42 ...: pass
42 ...: pass
43 """
43 """
44
44
45 def remote_function(f):
45 def remote_function(f):
46 return RemoteFunction(view, f, block=block, **flags)
46 return RemoteFunction(view, f, block=block, **flags)
47 return remote_function
47 return remote_function
48
48
49 @skip_doctest
49 @skip_doctest
50 def parallel(view, dist='b', block=None, ordered=True, **flags):
50 def parallel(view, dist='b', block=None, ordered=True, **flags):
51 """Turn a function into a parallel remote function.
51 """Turn a function into a parallel remote function.
52
52
53 This method can be used for map:
53 This method can be used for map:
54
54
55 In [1]: @parallel(view, block=True)
55 In [1]: @parallel(view, block=True)
56 ...: def func(a):
56 ...: def func(a):
57 ...: pass
57 ...: pass
58 """
58 """
59
59
60 def parallel_function(f):
60 def parallel_function(f):
61 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
61 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
62 return parallel_function
62 return parallel_function
63
63
64 def getname(f):
64 def getname(f):
65 """Get the name of an object.
65 """Get the name of an object.
66
66
67 For use in case of callables that are not functions, and
67 For use in case of callables that are not functions, and
68 thus may not have __name__ defined.
68 thus may not have __name__ defined.
69
69
70 Order: f.__name__ > f.name > str(f)
70 Order: f.__name__ > f.name > str(f)
71 """
71 """
72 try:
72 try:
73 return f.__name__
73 return f.__name__
74 except:
74 except:
75 pass
75 pass
76 try:
76 try:
77 return f.name
77 return f.name
78 except:
78 except:
79 pass
79 pass
80
80
81 return str(f)
81 return str(f)
82
82
83 @decorator
83 @decorator
84 def sync_view_results(f, self, *args, **kwargs):
84 def sync_view_results(f, self, *args, **kwargs):
85 """sync relevant results from self.client to our results attribute.
85 """sync relevant results from self.client to our results attribute.
86
86
87 This is a clone of view.sync_results, but for remote functions
87 This is a clone of view.sync_results, but for remote functions
88 """
88 """
89 view = self.view
89 view = self.view
90 if view._in_sync_results:
90 if view._in_sync_results:
91 return f(self, *args, **kwargs)
91 return f(self, *args, **kwargs)
92 print 'in sync results', f
93 view._in_sync_results = True
92 view._in_sync_results = True
94 try:
93 try:
95 ret = f(self, *args, **kwargs)
94 ret = f(self, *args, **kwargs)
96 finally:
95 finally:
97 view._in_sync_results = False
96 view._in_sync_results = False
98 view._sync_results()
97 view._sync_results()
99 return ret
98 return ret
100
99
101 #--------------------------------------------------------------------------
100 #--------------------------------------------------------------------------
102 # Classes
101 # Classes
103 #--------------------------------------------------------------------------
102 #--------------------------------------------------------------------------
104
103
105 class RemoteFunction(object):
104 class RemoteFunction(object):
106 """Turn an existing function into a remote function.
105 """Turn an existing function into a remote function.
107
106
108 Parameters
107 Parameters
109 ----------
108 ----------
110
109
111 view : View instance
110 view : View instance
112 The view to be used for execution
111 The view to be used for execution
113 f : callable
112 f : callable
114 The function to be wrapped into a remote function
113 The function to be wrapped into a remote function
115 block : bool [default: None]
114 block : bool [default: None]
116 Whether to wait for results or not. The default behavior is
115 Whether to wait for results or not. The default behavior is
117 to use the current `block` attribute of `view`
116 to use the current `block` attribute of `view`
118
117
119 **flags : remaining kwargs are passed to View.temp_flags
118 **flags : remaining kwargs are passed to View.temp_flags
120 """
119 """
121
120
122 view = None # the remote connection
121 view = None # the remote connection
123 func = None # the wrapped function
122 func = None # the wrapped function
124 block = None # whether to block
123 block = None # whether to block
125 flags = None # dict of extra kwargs for temp_flags
124 flags = None # dict of extra kwargs for temp_flags
126
125
127 def __init__(self, view, f, block=None, **flags):
126 def __init__(self, view, f, block=None, **flags):
128 self.view = view
127 self.view = view
129 self.func = f
128 self.func = f
130 self.block=block
129 self.block=block
131 self.flags=flags
130 self.flags=flags
132
131
133 def __call__(self, *args, **kwargs):
132 def __call__(self, *args, **kwargs):
134 block = self.view.block if self.block is None else self.block
133 block = self.view.block if self.block is None else self.block
135 with self.view.temp_flags(block=block, **self.flags):
134 with self.view.temp_flags(block=block, **self.flags):
136 return self.view.apply(self.func, *args, **kwargs)
135 return self.view.apply(self.func, *args, **kwargs)
137
136
138
137
139 class ParallelFunction(RemoteFunction):
138 class ParallelFunction(RemoteFunction):
140 """Class for mapping a function to sequences.
139 """Class for mapping a function to sequences.
141
140
142 This will distribute the sequences according the a mapper, and call
141 This will distribute the sequences according the a mapper, and call
143 the function on each sub-sequence. If called via map, then the function
142 the function on each sub-sequence. If called via map, then the function
144 will be called once on each element, rather that each sub-sequence.
143 will be called once on each element, rather that each sub-sequence.
145
144
146 Parameters
145 Parameters
147 ----------
146 ----------
148
147
149 view : View instance
148 view : View instance
150 The view to be used for execution
149 The view to be used for execution
151 f : callable
150 f : callable
152 The function to be wrapped into a remote function
151 The function to be wrapped into a remote function
153 dist : str [default: 'b']
152 dist : str [default: 'b']
154 The key for which mapObject to use to distribute sequences
153 The key for which mapObject to use to distribute sequences
155 options are:
154 options are:
156 * 'b' : use contiguous chunks in order
155 * 'b' : use contiguous chunks in order
157 * 'r' : use round-robin striping
156 * 'r' : use round-robin striping
158 block : bool [default: None]
157 block : bool [default: None]
159 Whether to wait for results or not. The default behavior is
158 Whether to wait for results or not. The default behavior is
160 to use the current `block` attribute of `view`
159 to use the current `block` attribute of `view`
161 chunksize : int or None
160 chunksize : int or None
162 The size of chunk to use when breaking up sequences in a load-balanced manner
161 The size of chunk to use when breaking up sequences in a load-balanced manner
163 ordered : bool [default: True]
162 ordered : bool [default: True]
164 Whether the result should be kept in order. If False,
163 Whether the result should be kept in order. If False,
165 results become available as they arrive, regardless of submission order.
164 results become available as they arrive, regardless of submission order.
166 **flags : remaining kwargs are passed to View.temp_flags
165 **flags : remaining kwargs are passed to View.temp_flags
167 """
166 """
168
167
169 chunksize = None
168 chunksize = None
170 ordered = None
169 ordered = None
171 mapObject = None
170 mapObject = None
172 _mapping = False
171 _mapping = False
173
172
174 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
173 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
175 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
174 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
176 self.chunksize = chunksize
175 self.chunksize = chunksize
177 self.ordered = ordered
176 self.ordered = ordered
178
177
179 mapClass = Map.dists[dist]
178 mapClass = Map.dists[dist]
180 self.mapObject = mapClass()
179 self.mapObject = mapClass()
181
180
182 @sync_view_results
181 @sync_view_results
183 def __call__(self, *sequences):
182 def __call__(self, *sequences):
184 client = self.view.client
183 client = self.view.client
185
184
186 lens = []
185 lens = []
187 maxlen = minlen = -1
186 maxlen = minlen = -1
188 for i, seq in enumerate(sequences):
187 for i, seq in enumerate(sequences):
189 try:
188 try:
190 n = len(seq)
189 n = len(seq)
191 except Exception:
190 except Exception:
192 seq = list(seq)
191 seq = list(seq)
193 if isinstance(sequences, tuple):
192 if isinstance(sequences, tuple):
194 # can't alter a tuple
193 # can't alter a tuple
195 sequences = list(sequences)
194 sequences = list(sequences)
196 sequences[i] = seq
195 sequences[i] = seq
197 n = len(seq)
196 n = len(seq)
198 if n > maxlen:
197 if n > maxlen:
199 maxlen = n
198 maxlen = n
200 if minlen == -1 or n < minlen:
199 if minlen == -1 or n < minlen:
201 minlen = n
200 minlen = n
202 lens.append(n)
201 lens.append(n)
203
202
204 # check that the length of sequences match
203 # check that the length of sequences match
205 if not self._mapping and minlen != maxlen:
204 if not self._mapping and minlen != maxlen:
206 msg = 'all sequences must have equal length, but have %s' % lens
205 msg = 'all sequences must have equal length, but have %s' % lens
207 raise ValueError(msg)
206 raise ValueError(msg)
208
207
209 balanced = 'Balanced' in self.view.__class__.__name__
208 balanced = 'Balanced' in self.view.__class__.__name__
210 if balanced:
209 if balanced:
211 if self.chunksize:
210 if self.chunksize:
212 nparts = maxlen // self.chunksize + int(maxlen % self.chunksize > 0)
211 nparts = maxlen // self.chunksize + int(maxlen % self.chunksize > 0)
213 else:
212 else:
214 nparts = maxlen
213 nparts = maxlen
215 targets = [None]*nparts
214 targets = [None]*nparts
216 else:
215 else:
217 if self.chunksize:
216 if self.chunksize:
218 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
217 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
219 # multiplexed:
218 # multiplexed:
220 targets = self.view.targets
219 targets = self.view.targets
221 # 'all' is lazily evaluated at execution time, which is now:
220 # 'all' is lazily evaluated at execution time, which is now:
222 if targets == 'all':
221 if targets == 'all':
223 targets = client._build_targets(targets)[1]
222 targets = client._build_targets(targets)[1]
224 elif isinstance(targets, int):
223 elif isinstance(targets, int):
225 # single-engine view, targets must be iterable
224 # single-engine view, targets must be iterable
226 targets = [targets]
225 targets = [targets]
227 nparts = len(targets)
226 nparts = len(targets)
228
227
229 msg_ids = []
228 msg_ids = []
230 for index, t in enumerate(targets):
229 for index, t in enumerate(targets):
231 args = []
230 args = []
232 for seq in sequences:
231 for seq in sequences:
233 part = self.mapObject.getPartition(seq, index, nparts, maxlen)
232 part = self.mapObject.getPartition(seq, index, nparts, maxlen)
234 args.append(part)
233 args.append(part)
235
234
236 if sum([len(arg) for arg in args]) == 0:
235 if sum([len(arg) for arg in args]) == 0:
237 continue
236 continue
238
237
239 if self._mapping:
238 if self._mapping:
240 if sys.version_info[0] >= 3:
239 if sys.version_info[0] >= 3:
241 f = lambda f, *sequences: list(map(f, *sequences))
240 f = lambda f, *sequences: list(map(f, *sequences))
242 else:
241 else:
243 f = map
242 f = map
244 args = [self.func] + args
243 args = [self.func] + args
245 else:
244 else:
246 f=self.func
245 f=self.func
247
246
248 view = self.view if balanced else client[t]
247 view = self.view if balanced else client[t]
249 with view.temp_flags(block=False, **self.flags):
248 with view.temp_flags(block=False, **self.flags):
250 ar = view.apply(f, *args)
249 ar = view.apply(f, *args)
251
250
252 msg_ids.extend(ar.msg_ids)
251 msg_ids.extend(ar.msg_ids)
253
252
254 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
253 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
255 fname=getname(self.func),
254 fname=getname(self.func),
256 ordered=self.ordered
255 ordered=self.ordered
257 )
256 )
258
257
259 if self.block:
258 if self.block:
260 try:
259 try:
261 return r.get()
260 return r.get()
262 except KeyboardInterrupt:
261 except KeyboardInterrupt:
263 return r
262 return r
264 else:
263 else:
265 return r
264 return r
266
265
267 def map(self, *sequences):
266 def map(self, *sequences):
268 """call a function on each element of one or more sequence(s) remotely.
267 """call a function on each element of one or more sequence(s) remotely.
269 This should behave very much like the builtin map, but return an AsyncMapResult
268 This should behave very much like the builtin map, but return an AsyncMapResult
270 if self.block is False.
269 if self.block is False.
271
270
272 That means it can take generators (will be cast to lists locally),
271 That means it can take generators (will be cast to lists locally),
273 and mismatched sequence lengths will be padded with None.
272 and mismatched sequence lengths will be padded with None.
274 """
273 """
275 # set _mapping as a flag for use inside self.__call__
274 # set _mapping as a flag for use inside self.__call__
276 self._mapping = True
275 self._mapping = True
277 try:
276 try:
278 ret = self(*sequences)
277 ret = self(*sequences)
279 finally:
278 finally:
280 self._mapping = False
279 self._mapping = False
281 return ret
280 return ret
282
281
283 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
282 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,126 +1,125 b''
1 """toplevel setup/teardown for parallel tests."""
1 """toplevel setup/teardown for parallel tests."""
2
2
3 #-------------------------------------------------------------------------------
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
4 # Copyright (C) 2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14 import os
14 import os
15 import tempfile
15 import tempfile
16 import time
16 import time
17 from subprocess import Popen
17 from subprocess import Popen
18
18
19 from IPython.utils.path import get_ipython_dir
19 from IPython.utils.path import get_ipython_dir
20 from IPython.parallel import Client
20 from IPython.parallel import Client
21 from IPython.parallel.apps.launcher import (LocalProcessLauncher,
21 from IPython.parallel.apps.launcher import (LocalProcessLauncher,
22 ipengine_cmd_argv,
22 ipengine_cmd_argv,
23 ipcontroller_cmd_argv,
23 ipcontroller_cmd_argv,
24 SIGKILL,
24 SIGKILL,
25 ProcessStateError,
25 ProcessStateError,
26 )
26 )
27
27
28 # globals
28 # globals
29 launchers = []
29 launchers = []
30 blackhole = open(os.devnull, 'w')
30 blackhole = open(os.devnull, 'w')
31
31
32 # Launcher class
32 # Launcher class
33 class TestProcessLauncher(LocalProcessLauncher):
33 class TestProcessLauncher(LocalProcessLauncher):
34 """subclass LocalProcessLauncher, to prevent extra sockets and threads being created on Windows"""
34 """subclass LocalProcessLauncher, to prevent extra sockets and threads being created on Windows"""
35 def start(self):
35 def start(self):
36 if self.state == 'before':
36 if self.state == 'before':
37 self.process = Popen(self.args,
37 self.process = Popen(self.args,
38 stdout=blackhole, stderr=blackhole,
38 stdout=blackhole, stderr=blackhole,
39 env=os.environ,
39 env=os.environ,
40 cwd=self.work_dir
40 cwd=self.work_dir
41 )
41 )
42 self.notify_start(self.process.pid)
42 self.notify_start(self.process.pid)
43 self.poll = self.process.poll
43 self.poll = self.process.poll
44 else:
44 else:
45 s = 'The process was already started and has state: %r' % self.state
45 s = 'The process was already started and has state: %r' % self.state
46 raise ProcessStateError(s)
46 raise ProcessStateError(s)
47
47
48 # nose setup/teardown
48 # nose setup/teardown
49
49
50 def setup():
50 def setup():
51 cluster_dir = os.path.join(get_ipython_dir(), 'profile_iptest')
51 cluster_dir = os.path.join(get_ipython_dir(), 'profile_iptest')
52 engine_json = os.path.join(cluster_dir, 'security', 'ipcontroller-engine.json')
52 engine_json = os.path.join(cluster_dir, 'security', 'ipcontroller-engine.json')
53 client_json = os.path.join(cluster_dir, 'security', 'ipcontroller-client.json')
53 client_json = os.path.join(cluster_dir, 'security', 'ipcontroller-client.json')
54 for json in (engine_json, client_json):
54 for json in (engine_json, client_json):
55 if os.path.exists(json):
55 if os.path.exists(json):
56 os.remove(json)
56 os.remove(json)
57
57
58 cp = TestProcessLauncher()
58 cp = TestProcessLauncher()
59 cp.cmd_and_args = ipcontroller_cmd_argv + \
59 cp.cmd_and_args = ipcontroller_cmd_argv + \
60 ['--profile=iptest', '--log-level=50', '--ping=250', '--dictdb']
60 ['--profile=iptest', '--log-level=50', '--ping=250', '--dictdb']
61 cp.start()
61 cp.start()
62 launchers.append(cp)
62 launchers.append(cp)
63 tic = time.time()
63 tic = time.time()
64 while not os.path.exists(engine_json) or not os.path.exists(client_json):
64 while not os.path.exists(engine_json) or not os.path.exists(client_json):
65 if cp.poll() is not None:
65 if cp.poll() is not None:
66 print cp.poll()
66 raise RuntimeError("The test controller exited with status %s" % cp.poll())
67 raise RuntimeError("The test controller failed to start.")
68 elif time.time()-tic > 15:
67 elif time.time()-tic > 15:
69 raise RuntimeError("Timeout waiting for the test controller to start.")
68 raise RuntimeError("Timeout waiting for the test controller to start.")
70 time.sleep(0.1)
69 time.sleep(0.1)
71 add_engines(1)
70 add_engines(1)
72
71
73 def add_engines(n=1, profile='iptest', total=False):
72 def add_engines(n=1, profile='iptest', total=False):
74 """add a number of engines to a given profile.
73 """add a number of engines to a given profile.
75
74
76 If total is True, then already running engines are counted, and only
75 If total is True, then already running engines are counted, and only
77 the additional engines necessary (if any) are started.
76 the additional engines necessary (if any) are started.
78 """
77 """
79 rc = Client(profile=profile)
78 rc = Client(profile=profile)
80 base = len(rc)
79 base = len(rc)
81
80
82 if total:
81 if total:
83 n = max(n - base, 0)
82 n = max(n - base, 0)
84
83
85 eps = []
84 eps = []
86 for i in range(n):
85 for i in range(n):
87 ep = TestProcessLauncher()
86 ep = TestProcessLauncher()
88 ep.cmd_and_args = ipengine_cmd_argv + [
87 ep.cmd_and_args = ipengine_cmd_argv + [
89 '--profile=%s' % profile,
88 '--profile=%s' % profile,
90 '--log-level=50',
89 '--log-level=50',
91 '--InteractiveShell.colors=nocolor'
90 '--InteractiveShell.colors=nocolor'
92 ]
91 ]
93 ep.start()
92 ep.start()
94 launchers.append(ep)
93 launchers.append(ep)
95 eps.append(ep)
94 eps.append(ep)
96 tic = time.time()
95 tic = time.time()
97 while len(rc) < base+n:
96 while len(rc) < base+n:
98 if any([ ep.poll() is not None for ep in eps ]):
97 if any([ ep.poll() is not None for ep in eps ]):
99 raise RuntimeError("A test engine failed to start.")
98 raise RuntimeError("A test engine failed to start.")
100 elif time.time()-tic > 15:
99 elif time.time()-tic > 15:
101 raise RuntimeError("Timeout waiting for engines to connect.")
100 raise RuntimeError("Timeout waiting for engines to connect.")
102 time.sleep(.1)
101 time.sleep(.1)
103 rc.spin()
102 rc.spin()
104 rc.close()
103 rc.close()
105 return eps
104 return eps
106
105
107 def teardown():
106 def teardown():
108 time.sleep(1)
107 time.sleep(1)
109 while launchers:
108 while launchers:
110 p = launchers.pop()
109 p = launchers.pop()
111 if p.poll() is None:
110 if p.poll() is None:
112 try:
111 try:
113 p.stop()
112 p.stop()
114 except Exception as e:
113 except Exception as e:
115 print e
114 print e
116 pass
115 pass
117 if p.poll() is None:
116 if p.poll() is None:
118 time.sleep(.25)
117 time.sleep(.25)
119 if p.poll() is None:
118 if p.poll() is None:
120 try:
119 try:
121 print 'cleaning up test process...'
120 print 'cleaning up test process...'
122 p.signal(SIGKILL)
121 p.signal(SIGKILL)
123 except:
122 except:
124 print "couldn't shutdown process: ", p
123 print "couldn't shutdown process: ", p
125 blackhole.close()
124 blackhole.close()
126
125
@@ -1,518 +1,517 b''
1 """Tests for parallel client.py
1 """Tests for parallel client.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import time
21 import time
22 from datetime import datetime
22 from datetime import datetime
23 from tempfile import mktemp
23 from tempfile import mktemp
24
24
25 import zmq
25 import zmq
26
26
27 from IPython import parallel
27 from IPython import parallel
28 from IPython.parallel.client import client as clientmod
28 from IPython.parallel.client import client as clientmod
29 from IPython.parallel import error
29 from IPython.parallel import error
30 from IPython.parallel import AsyncResult, AsyncHubResult
30 from IPython.parallel import AsyncResult, AsyncHubResult
31 from IPython.parallel import LoadBalancedView, DirectView
31 from IPython.parallel import LoadBalancedView, DirectView
32
32
33 from clienttest import ClusterTestCase, segfault, wait, add_engines
33 from clienttest import ClusterTestCase, segfault, wait, add_engines
34
34
35 def setup():
35 def setup():
36 add_engines(4, total=True)
36 add_engines(4, total=True)
37
37
38 class TestClient(ClusterTestCase):
38 class TestClient(ClusterTestCase):
39
39
40 def test_ids(self):
40 def test_ids(self):
41 n = len(self.client.ids)
41 n = len(self.client.ids)
42 self.add_engines(2)
42 self.add_engines(2)
43 self.assertEqual(len(self.client.ids), n+2)
43 self.assertEqual(len(self.client.ids), n+2)
44
44
45 def test_view_indexing(self):
45 def test_view_indexing(self):
46 """test index access for views"""
46 """test index access for views"""
47 self.minimum_engines(4)
47 self.minimum_engines(4)
48 targets = self.client._build_targets('all')[-1]
48 targets = self.client._build_targets('all')[-1]
49 v = self.client[:]
49 v = self.client[:]
50 self.assertEqual(v.targets, targets)
50 self.assertEqual(v.targets, targets)
51 t = self.client.ids[2]
51 t = self.client.ids[2]
52 v = self.client[t]
52 v = self.client[t]
53 self.assertTrue(isinstance(v, DirectView))
53 self.assertTrue(isinstance(v, DirectView))
54 self.assertEqual(v.targets, t)
54 self.assertEqual(v.targets, t)
55 t = self.client.ids[2:4]
55 t = self.client.ids[2:4]
56 v = self.client[t]
56 v = self.client[t]
57 self.assertTrue(isinstance(v, DirectView))
57 self.assertTrue(isinstance(v, DirectView))
58 self.assertEqual(v.targets, t)
58 self.assertEqual(v.targets, t)
59 v = self.client[::2]
59 v = self.client[::2]
60 self.assertTrue(isinstance(v, DirectView))
60 self.assertTrue(isinstance(v, DirectView))
61 self.assertEqual(v.targets, targets[::2])
61 self.assertEqual(v.targets, targets[::2])
62 v = self.client[1::3]
62 v = self.client[1::3]
63 self.assertTrue(isinstance(v, DirectView))
63 self.assertTrue(isinstance(v, DirectView))
64 self.assertEqual(v.targets, targets[1::3])
64 self.assertEqual(v.targets, targets[1::3])
65 v = self.client[:-3]
65 v = self.client[:-3]
66 self.assertTrue(isinstance(v, DirectView))
66 self.assertTrue(isinstance(v, DirectView))
67 self.assertEqual(v.targets, targets[:-3])
67 self.assertEqual(v.targets, targets[:-3])
68 v = self.client[-1]
68 v = self.client[-1]
69 self.assertTrue(isinstance(v, DirectView))
69 self.assertTrue(isinstance(v, DirectView))
70 self.assertEqual(v.targets, targets[-1])
70 self.assertEqual(v.targets, targets[-1])
71 self.assertRaises(TypeError, lambda : self.client[None])
71 self.assertRaises(TypeError, lambda : self.client[None])
72
72
73 def test_lbview_targets(self):
73 def test_lbview_targets(self):
74 """test load_balanced_view targets"""
74 """test load_balanced_view targets"""
75 v = self.client.load_balanced_view()
75 v = self.client.load_balanced_view()
76 self.assertEqual(v.targets, None)
76 self.assertEqual(v.targets, None)
77 v = self.client.load_balanced_view(-1)
77 v = self.client.load_balanced_view(-1)
78 self.assertEqual(v.targets, [self.client.ids[-1]])
78 self.assertEqual(v.targets, [self.client.ids[-1]])
79 v = self.client.load_balanced_view('all')
79 v = self.client.load_balanced_view('all')
80 self.assertEqual(v.targets, None)
80 self.assertEqual(v.targets, None)
81
81
82 def test_dview_targets(self):
82 def test_dview_targets(self):
83 """test direct_view targets"""
83 """test direct_view targets"""
84 v = self.client.direct_view()
84 v = self.client.direct_view()
85 self.assertEqual(v.targets, 'all')
85 self.assertEqual(v.targets, 'all')
86 v = self.client.direct_view('all')
86 v = self.client.direct_view('all')
87 self.assertEqual(v.targets, 'all')
87 self.assertEqual(v.targets, 'all')
88 v = self.client.direct_view(-1)
88 v = self.client.direct_view(-1)
89 self.assertEqual(v.targets, self.client.ids[-1])
89 self.assertEqual(v.targets, self.client.ids[-1])
90
90
91 def test_lazy_all_targets(self):
91 def test_lazy_all_targets(self):
92 """test lazy evaluation of rc.direct_view('all')"""
92 """test lazy evaluation of rc.direct_view('all')"""
93 v = self.client.direct_view()
93 v = self.client.direct_view()
94 self.assertEqual(v.targets, 'all')
94 self.assertEqual(v.targets, 'all')
95
95
96 def double(x):
96 def double(x):
97 return x*2
97 return x*2
98 seq = range(100)
98 seq = range(100)
99 ref = [ double(x) for x in seq ]
99 ref = [ double(x) for x in seq ]
100
100
101 # add some engines, which should be used
101 # add some engines, which should be used
102 self.add_engines(1)
102 self.add_engines(1)
103 n1 = len(self.client.ids)
103 n1 = len(self.client.ids)
104
104
105 # simple apply
105 # simple apply
106 r = v.apply_sync(lambda : 1)
106 r = v.apply_sync(lambda : 1)
107 self.assertEqual(r, [1] * n1)
107 self.assertEqual(r, [1] * n1)
108
108
109 # map goes through remotefunction
109 # map goes through remotefunction
110 r = v.map_sync(double, seq)
110 r = v.map_sync(double, seq)
111 self.assertEqual(r, ref)
111 self.assertEqual(r, ref)
112
112
113 # add a couple more engines, and try again
113 # add a couple more engines, and try again
114 self.add_engines(2)
114 self.add_engines(2)
115 n2 = len(self.client.ids)
115 n2 = len(self.client.ids)
116 self.assertNotEqual(n2, n1)
116 self.assertNotEqual(n2, n1)
117
117
118 # apply
118 # apply
119 r = v.apply_sync(lambda : 1)
119 r = v.apply_sync(lambda : 1)
120 self.assertEqual(r, [1] * n2)
120 self.assertEqual(r, [1] * n2)
121
121
122 # map
122 # map
123 r = v.map_sync(double, seq)
123 r = v.map_sync(double, seq)
124 self.assertEqual(r, ref)
124 self.assertEqual(r, ref)
125
125
126 def test_targets(self):
126 def test_targets(self):
127 """test various valid targets arguments"""
127 """test various valid targets arguments"""
128 build = self.client._build_targets
128 build = self.client._build_targets
129 ids = self.client.ids
129 ids = self.client.ids
130 idents,targets = build(None)
130 idents,targets = build(None)
131 self.assertEqual(ids, targets)
131 self.assertEqual(ids, targets)
132
132
133 def test_clear(self):
133 def test_clear(self):
134 """test clear behavior"""
134 """test clear behavior"""
135 self.minimum_engines(2)
135 self.minimum_engines(2)
136 v = self.client[:]
136 v = self.client[:]
137 v.block=True
137 v.block=True
138 v.push(dict(a=5))
138 v.push(dict(a=5))
139 v.pull('a')
139 v.pull('a')
140 id0 = self.client.ids[-1]
140 id0 = self.client.ids[-1]
141 self.client.clear(targets=id0, block=True)
141 self.client.clear(targets=id0, block=True)
142 a = self.client[:-1].get('a')
142 a = self.client[:-1].get('a')
143 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
143 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
144 self.client.clear(block=True)
144 self.client.clear(block=True)
145 for i in self.client.ids:
145 for i in self.client.ids:
146 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
146 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
147
147
148 def test_get_result(self):
148 def test_get_result(self):
149 """test getting results from the Hub."""
149 """test getting results from the Hub."""
150 c = clientmod.Client(profile='iptest')
150 c = clientmod.Client(profile='iptest')
151 t = c.ids[-1]
151 t = c.ids[-1]
152 ar = c[t].apply_async(wait, 1)
152 ar = c[t].apply_async(wait, 1)
153 # give the monitor time to notice the message
153 # give the monitor time to notice the message
154 time.sleep(.25)
154 time.sleep(.25)
155 ahr = self.client.get_result(ar.msg_ids[0])
155 ahr = self.client.get_result(ar.msg_ids[0])
156 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 self.assertTrue(isinstance(ahr, AsyncHubResult))
157 self.assertEqual(ahr.get(), ar.get())
157 self.assertEqual(ahr.get(), ar.get())
158 ar2 = self.client.get_result(ar.msg_ids[0])
158 ar2 = self.client.get_result(ar.msg_ids[0])
159 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 self.assertFalse(isinstance(ar2, AsyncHubResult))
160 c.close()
160 c.close()
161
161
162 def test_get_execute_result(self):
162 def test_get_execute_result(self):
163 """test getting execute results from the Hub."""
163 """test getting execute results from the Hub."""
164 c = clientmod.Client(profile='iptest')
164 c = clientmod.Client(profile='iptest')
165 t = c.ids[-1]
165 t = c.ids[-1]
166 cell = '\n'.join([
166 cell = '\n'.join([
167 'import time',
167 'import time',
168 'time.sleep(0.25)',
168 'time.sleep(0.25)',
169 '5'
169 '5'
170 ])
170 ])
171 ar = c[t].execute("import time; time.sleep(1)", silent=False)
171 ar = c[t].execute("import time; time.sleep(1)", silent=False)
172 # give the monitor time to notice the message
172 # give the monitor time to notice the message
173 time.sleep(.25)
173 time.sleep(.25)
174 ahr = self.client.get_result(ar.msg_ids[0])
174 ahr = self.client.get_result(ar.msg_ids[0])
175 print ar.get(), ahr.get(), ar._single_result, ahr._single_result
176 self.assertTrue(isinstance(ahr, AsyncHubResult))
175 self.assertTrue(isinstance(ahr, AsyncHubResult))
177 self.assertEqual(ahr.get().pyout, ar.get().pyout)
176 self.assertEqual(ahr.get().pyout, ar.get().pyout)
178 ar2 = self.client.get_result(ar.msg_ids[0])
177 ar2 = self.client.get_result(ar.msg_ids[0])
179 self.assertFalse(isinstance(ar2, AsyncHubResult))
178 self.assertFalse(isinstance(ar2, AsyncHubResult))
180 c.close()
179 c.close()
181
180
182 def test_ids_list(self):
181 def test_ids_list(self):
183 """test client.ids"""
182 """test client.ids"""
184 ids = self.client.ids
183 ids = self.client.ids
185 self.assertEqual(ids, self.client._ids)
184 self.assertEqual(ids, self.client._ids)
186 self.assertFalse(ids is self.client._ids)
185 self.assertFalse(ids is self.client._ids)
187 ids.remove(ids[-1])
186 ids.remove(ids[-1])
188 self.assertNotEqual(ids, self.client._ids)
187 self.assertNotEqual(ids, self.client._ids)
189
188
190 def test_queue_status(self):
189 def test_queue_status(self):
191 ids = self.client.ids
190 ids = self.client.ids
192 id0 = ids[0]
191 id0 = ids[0]
193 qs = self.client.queue_status(targets=id0)
192 qs = self.client.queue_status(targets=id0)
194 self.assertTrue(isinstance(qs, dict))
193 self.assertTrue(isinstance(qs, dict))
195 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
194 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
196 allqs = self.client.queue_status()
195 allqs = self.client.queue_status()
197 self.assertTrue(isinstance(allqs, dict))
196 self.assertTrue(isinstance(allqs, dict))
198 intkeys = list(allqs.keys())
197 intkeys = list(allqs.keys())
199 intkeys.remove('unassigned')
198 intkeys.remove('unassigned')
200 self.assertEqual(sorted(intkeys), sorted(self.client.ids))
199 self.assertEqual(sorted(intkeys), sorted(self.client.ids))
201 unassigned = allqs.pop('unassigned')
200 unassigned = allqs.pop('unassigned')
202 for eid,qs in allqs.items():
201 for eid,qs in allqs.items():
203 self.assertTrue(isinstance(qs, dict))
202 self.assertTrue(isinstance(qs, dict))
204 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
203 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
205
204
206 def test_shutdown(self):
205 def test_shutdown(self):
207 ids = self.client.ids
206 ids = self.client.ids
208 id0 = ids[0]
207 id0 = ids[0]
209 self.client.shutdown(id0, block=True)
208 self.client.shutdown(id0, block=True)
210 while id0 in self.client.ids:
209 while id0 in self.client.ids:
211 time.sleep(0.1)
210 time.sleep(0.1)
212 self.client.spin()
211 self.client.spin()
213
212
214 self.assertRaises(IndexError, lambda : self.client[id0])
213 self.assertRaises(IndexError, lambda : self.client[id0])
215
214
216 def test_result_status(self):
215 def test_result_status(self):
217 pass
216 pass
218 # to be written
217 # to be written
219
218
220 def test_db_query_dt(self):
219 def test_db_query_dt(self):
221 """test db query by date"""
220 """test db query by date"""
222 hist = self.client.hub_history()
221 hist = self.client.hub_history()
223 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
222 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
224 tic = middle['submitted']
223 tic = middle['submitted']
225 before = self.client.db_query({'submitted' : {'$lt' : tic}})
224 before = self.client.db_query({'submitted' : {'$lt' : tic}})
226 after = self.client.db_query({'submitted' : {'$gte' : tic}})
225 after = self.client.db_query({'submitted' : {'$gte' : tic}})
227 self.assertEqual(len(before)+len(after),len(hist))
226 self.assertEqual(len(before)+len(after),len(hist))
228 for b in before:
227 for b in before:
229 self.assertTrue(b['submitted'] < tic)
228 self.assertTrue(b['submitted'] < tic)
230 for a in after:
229 for a in after:
231 self.assertTrue(a['submitted'] >= tic)
230 self.assertTrue(a['submitted'] >= tic)
232 same = self.client.db_query({'submitted' : tic})
231 same = self.client.db_query({'submitted' : tic})
233 for s in same:
232 for s in same:
234 self.assertTrue(s['submitted'] == tic)
233 self.assertTrue(s['submitted'] == tic)
235
234
236 def test_db_query_keys(self):
235 def test_db_query_keys(self):
237 """test extracting subset of record keys"""
236 """test extracting subset of record keys"""
238 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
239 for rec in found:
238 for rec in found:
240 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
239 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
241
240
242 def test_db_query_default_keys(self):
241 def test_db_query_default_keys(self):
243 """default db_query excludes buffers"""
242 """default db_query excludes buffers"""
244 found = self.client.db_query({'msg_id': {'$ne' : ''}})
243 found = self.client.db_query({'msg_id': {'$ne' : ''}})
245 for rec in found:
244 for rec in found:
246 keys = set(rec.keys())
245 keys = set(rec.keys())
247 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
246 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
248 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
247 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
249
248
250 def test_db_query_msg_id(self):
249 def test_db_query_msg_id(self):
251 """ensure msg_id is always in db queries"""
250 """ensure msg_id is always in db queries"""
252 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
251 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
253 for rec in found:
252 for rec in found:
254 self.assertTrue('msg_id' in rec.keys())
253 self.assertTrue('msg_id' in rec.keys())
255 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
254 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
256 for rec in found:
255 for rec in found:
257 self.assertTrue('msg_id' in rec.keys())
256 self.assertTrue('msg_id' in rec.keys())
258 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
257 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
259 for rec in found:
258 for rec in found:
260 self.assertTrue('msg_id' in rec.keys())
259 self.assertTrue('msg_id' in rec.keys())
261
260
262 def test_db_query_get_result(self):
261 def test_db_query_get_result(self):
263 """pop in db_query shouldn't pop from result itself"""
262 """pop in db_query shouldn't pop from result itself"""
264 self.client[:].apply_sync(lambda : 1)
263 self.client[:].apply_sync(lambda : 1)
265 found = self.client.db_query({'msg_id': {'$ne' : ''}})
264 found = self.client.db_query({'msg_id': {'$ne' : ''}})
266 rc2 = clientmod.Client(profile='iptest')
265 rc2 = clientmod.Client(profile='iptest')
267 # If this bug is not fixed, this call will hang:
266 # If this bug is not fixed, this call will hang:
268 ar = rc2.get_result(self.client.history[-1])
267 ar = rc2.get_result(self.client.history[-1])
269 ar.wait(2)
268 ar.wait(2)
270 self.assertTrue(ar.ready())
269 self.assertTrue(ar.ready())
271 ar.get()
270 ar.get()
272 rc2.close()
271 rc2.close()
273
272
274 def test_db_query_in(self):
273 def test_db_query_in(self):
275 """test db query with '$in','$nin' operators"""
274 """test db query with '$in','$nin' operators"""
276 hist = self.client.hub_history()
275 hist = self.client.hub_history()
277 even = hist[::2]
276 even = hist[::2]
278 odd = hist[1::2]
277 odd = hist[1::2]
279 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
278 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
280 found = [ r['msg_id'] for r in recs ]
279 found = [ r['msg_id'] for r in recs ]
281 self.assertEqual(set(even), set(found))
280 self.assertEqual(set(even), set(found))
282 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
281 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
283 found = [ r['msg_id'] for r in recs ]
282 found = [ r['msg_id'] for r in recs ]
284 self.assertEqual(set(odd), set(found))
283 self.assertEqual(set(odd), set(found))
285
284
286 def test_hub_history(self):
285 def test_hub_history(self):
287 hist = self.client.hub_history()
286 hist = self.client.hub_history()
288 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
287 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
289 recdict = {}
288 recdict = {}
290 for rec in recs:
289 for rec in recs:
291 recdict[rec['msg_id']] = rec
290 recdict[rec['msg_id']] = rec
292
291
293 latest = datetime(1984,1,1)
292 latest = datetime(1984,1,1)
294 for msg_id in hist:
293 for msg_id in hist:
295 rec = recdict[msg_id]
294 rec = recdict[msg_id]
296 newt = rec['submitted']
295 newt = rec['submitted']
297 self.assertTrue(newt >= latest)
296 self.assertTrue(newt >= latest)
298 latest = newt
297 latest = newt
299 ar = self.client[-1].apply_async(lambda : 1)
298 ar = self.client[-1].apply_async(lambda : 1)
300 ar.get()
299 ar.get()
301 time.sleep(0.25)
300 time.sleep(0.25)
302 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
301 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
303
302
304 def _wait_for_idle(self):
303 def _wait_for_idle(self):
305 """wait for an engine to become idle, according to the Hub"""
304 """wait for an engine to become idle, according to the Hub"""
306 rc = self.client
305 rc = self.client
307
306
308 # step 1. wait for all requests to be noticed
307 # step 1. wait for all requests to be noticed
309 # timeout 5s, polling every 100ms
308 # timeout 5s, polling every 100ms
310 msg_ids = set(rc.history)
309 msg_ids = set(rc.history)
311 hub_hist = rc.hub_history()
310 hub_hist = rc.hub_history()
312 for i in range(50):
311 for i in range(50):
313 if msg_ids.difference(hub_hist):
312 if msg_ids.difference(hub_hist):
314 time.sleep(0.1)
313 time.sleep(0.1)
315 hub_hist = rc.hub_history()
314 hub_hist = rc.hub_history()
316 else:
315 else:
317 break
316 break
318
317
319 self.assertEqual(len(msg_ids.difference(hub_hist)), 0)
318 self.assertEqual(len(msg_ids.difference(hub_hist)), 0)
320
319
321 # step 2. wait for all requests to be done
320 # step 2. wait for all requests to be done
322 # timeout 5s, polling every 100ms
321 # timeout 5s, polling every 100ms
323 qs = rc.queue_status()
322 qs = rc.queue_status()
324 for i in range(50):
323 for i in range(50):
325 if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
324 if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
326 time.sleep(0.1)
325 time.sleep(0.1)
327 qs = rc.queue_status()
326 qs = rc.queue_status()
328 else:
327 else:
329 break
328 break
330
329
331 # ensure Hub up to date:
330 # ensure Hub up to date:
332 self.assertEqual(qs['unassigned'], 0)
331 self.assertEqual(qs['unassigned'], 0)
333 for eid in rc.ids:
332 for eid in rc.ids:
334 self.assertEqual(qs[eid]['tasks'], 0)
333 self.assertEqual(qs[eid]['tasks'], 0)
335
334
336
335
337 def test_resubmit(self):
336 def test_resubmit(self):
338 def f():
337 def f():
339 import random
338 import random
340 return random.random()
339 return random.random()
341 v = self.client.load_balanced_view()
340 v = self.client.load_balanced_view()
342 ar = v.apply_async(f)
341 ar = v.apply_async(f)
343 r1 = ar.get(1)
342 r1 = ar.get(1)
344 # give the Hub a chance to notice:
343 # give the Hub a chance to notice:
345 self._wait_for_idle()
344 self._wait_for_idle()
346 ahr = self.client.resubmit(ar.msg_ids)
345 ahr = self.client.resubmit(ar.msg_ids)
347 r2 = ahr.get(1)
346 r2 = ahr.get(1)
348 self.assertFalse(r1 == r2)
347 self.assertFalse(r1 == r2)
349
348
350 def test_resubmit_chain(self):
349 def test_resubmit_chain(self):
351 """resubmit resubmitted tasks"""
350 """resubmit resubmitted tasks"""
352 v = self.client.load_balanced_view()
351 v = self.client.load_balanced_view()
353 ar = v.apply_async(lambda x: x, 'x'*1024)
352 ar = v.apply_async(lambda x: x, 'x'*1024)
354 ar.get()
353 ar.get()
355 self._wait_for_idle()
354 self._wait_for_idle()
356 ars = [ar]
355 ars = [ar]
357
356
358 for i in range(10):
357 for i in range(10):
359 ar = ars[-1]
358 ar = ars[-1]
360 ar2 = self.client.resubmit(ar.msg_ids)
359 ar2 = self.client.resubmit(ar.msg_ids)
361
360
362 [ ar.get() for ar in ars ]
361 [ ar.get() for ar in ars ]
363
362
364 def test_resubmit_header(self):
363 def test_resubmit_header(self):
365 """resubmit shouldn't clobber the whole header"""
364 """resubmit shouldn't clobber the whole header"""
366 def f():
365 def f():
367 import random
366 import random
368 return random.random()
367 return random.random()
369 v = self.client.load_balanced_view()
368 v = self.client.load_balanced_view()
370 v.retries = 1
369 v.retries = 1
371 ar = v.apply_async(f)
370 ar = v.apply_async(f)
372 r1 = ar.get(1)
371 r1 = ar.get(1)
373 # give the Hub a chance to notice:
372 # give the Hub a chance to notice:
374 self._wait_for_idle()
373 self._wait_for_idle()
375 ahr = self.client.resubmit(ar.msg_ids)
374 ahr = self.client.resubmit(ar.msg_ids)
376 ahr.get(1)
375 ahr.get(1)
377 time.sleep(0.5)
376 time.sleep(0.5)
378 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
377 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
379 h1,h2 = [ r['header'] for r in records ]
378 h1,h2 = [ r['header'] for r in records ]
380 for key in set(h1.keys()).union(set(h2.keys())):
379 for key in set(h1.keys()).union(set(h2.keys())):
381 if key in ('msg_id', 'date'):
380 if key in ('msg_id', 'date'):
382 self.assertNotEqual(h1[key], h2[key])
381 self.assertNotEqual(h1[key], h2[key])
383 else:
382 else:
384 self.assertEqual(h1[key], h2[key])
383 self.assertEqual(h1[key], h2[key])
385
384
386 def test_resubmit_aborted(self):
385 def test_resubmit_aborted(self):
387 def f():
386 def f():
388 import random
387 import random
389 return random.random()
388 return random.random()
390 v = self.client.load_balanced_view()
389 v = self.client.load_balanced_view()
391 # restrict to one engine, so we can put a sleep
390 # restrict to one engine, so we can put a sleep
392 # ahead of the task, so it will get aborted
391 # ahead of the task, so it will get aborted
393 eid = self.client.ids[-1]
392 eid = self.client.ids[-1]
394 v.targets = [eid]
393 v.targets = [eid]
395 sleep = v.apply_async(time.sleep, 0.5)
394 sleep = v.apply_async(time.sleep, 0.5)
396 ar = v.apply_async(f)
395 ar = v.apply_async(f)
397 ar.abort()
396 ar.abort()
398 self.assertRaises(error.TaskAborted, ar.get)
397 self.assertRaises(error.TaskAborted, ar.get)
399 # Give the Hub a chance to get up to date:
398 # Give the Hub a chance to get up to date:
400 self._wait_for_idle()
399 self._wait_for_idle()
401 ahr = self.client.resubmit(ar.msg_ids)
400 ahr = self.client.resubmit(ar.msg_ids)
402 r2 = ahr.get(1)
401 r2 = ahr.get(1)
403
402
404 def test_resubmit_inflight(self):
403 def test_resubmit_inflight(self):
405 """resubmit of inflight task"""
404 """resubmit of inflight task"""
406 v = self.client.load_balanced_view()
405 v = self.client.load_balanced_view()
407 ar = v.apply_async(time.sleep,1)
406 ar = v.apply_async(time.sleep,1)
408 # give the message a chance to arrive
407 # give the message a chance to arrive
409 time.sleep(0.2)
408 time.sleep(0.2)
410 ahr = self.client.resubmit(ar.msg_ids)
409 ahr = self.client.resubmit(ar.msg_ids)
411 ar.get(2)
410 ar.get(2)
412 ahr.get(2)
411 ahr.get(2)
413
412
414 def test_resubmit_badkey(self):
413 def test_resubmit_badkey(self):
415 """ensure KeyError on resubmit of nonexistant task"""
414 """ensure KeyError on resubmit of nonexistant task"""
416 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
415 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
417
416
418 def test_purge_hub_results(self):
417 def test_purge_hub_results(self):
419 # ensure there are some tasks
418 # ensure there are some tasks
420 for i in range(5):
419 for i in range(5):
421 self.client[:].apply_sync(lambda : 1)
420 self.client[:].apply_sync(lambda : 1)
422 # Wait for the Hub to realise the result is done:
421 # Wait for the Hub to realise the result is done:
423 # This prevents a race condition, where we
422 # This prevents a race condition, where we
424 # might purge a result the Hub still thinks is pending.
423 # might purge a result the Hub still thinks is pending.
425 self._wait_for_idle()
424 self._wait_for_idle()
426 rc2 = clientmod.Client(profile='iptest')
425 rc2 = clientmod.Client(profile='iptest')
427 hist = self.client.hub_history()
426 hist = self.client.hub_history()
428 ahr = rc2.get_result([hist[-1]])
427 ahr = rc2.get_result([hist[-1]])
429 ahr.wait(10)
428 ahr.wait(10)
430 self.client.purge_hub_results(hist[-1])
429 self.client.purge_hub_results(hist[-1])
431 newhist = self.client.hub_history()
430 newhist = self.client.hub_history()
432 self.assertEqual(len(newhist)+1,len(hist))
431 self.assertEqual(len(newhist)+1,len(hist))
433 rc2.spin()
432 rc2.spin()
434 rc2.close()
433 rc2.close()
435
434
436 def test_purge_local_results(self):
435 def test_purge_local_results(self):
437 # ensure there are some tasks
436 # ensure there are some tasks
438 res = []
437 res = []
439 for i in range(5):
438 for i in range(5):
440 res.append(self.client[:].apply_async(lambda : 1))
439 res.append(self.client[:].apply_async(lambda : 1))
441 self._wait_for_idle()
440 self._wait_for_idle()
442 self.client.wait(10) # wait for the results to come back
441 self.client.wait(10) # wait for the results to come back
443 before = len(self.client.results)
442 before = len(self.client.results)
444 self.assertEqual(len(self.client.metadata),before)
443 self.assertEqual(len(self.client.metadata),before)
445 self.client.purge_local_results(res[-1])
444 self.client.purge_local_results(res[-1])
446 self.assertEqual(len(self.client.results),before-len(res[-1]), msg="Not removed from results")
445 self.assertEqual(len(self.client.results),before-len(res[-1]), msg="Not removed from results")
447 self.assertEqual(len(self.client.metadata),before-len(res[-1]), msg="Not removed from metadata")
446 self.assertEqual(len(self.client.metadata),before-len(res[-1]), msg="Not removed from metadata")
448
447
449 def test_purge_all_hub_results(self):
448 def test_purge_all_hub_results(self):
450 self.client.purge_hub_results('all')
449 self.client.purge_hub_results('all')
451 hist = self.client.hub_history()
450 hist = self.client.hub_history()
452 self.assertEqual(len(hist), 0)
451 self.assertEqual(len(hist), 0)
453
452
454 def test_purge_all_local_results(self):
453 def test_purge_all_local_results(self):
455 self.client.purge_local_results('all')
454 self.client.purge_local_results('all')
456 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
455 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
457 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
456 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
458
457
459 def test_purge_all_results(self):
458 def test_purge_all_results(self):
460 # ensure there are some tasks
459 # ensure there are some tasks
461 for i in range(5):
460 for i in range(5):
462 self.client[:].apply_sync(lambda : 1)
461 self.client[:].apply_sync(lambda : 1)
463 self.client.wait(10)
462 self.client.wait(10)
464 self._wait_for_idle()
463 self._wait_for_idle()
465 self.client.purge_results('all')
464 self.client.purge_results('all')
466 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
465 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
467 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
466 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
468 hist = self.client.hub_history()
467 hist = self.client.hub_history()
469 self.assertEqual(len(hist), 0, msg="hub history not empty")
468 self.assertEqual(len(hist), 0, msg="hub history not empty")
470
469
471 def test_purge_everything(self):
470 def test_purge_everything(self):
472 # ensure there are some tasks
471 # ensure there are some tasks
473 for i in range(5):
472 for i in range(5):
474 self.client[:].apply_sync(lambda : 1)
473 self.client[:].apply_sync(lambda : 1)
475 self.client.wait(10)
474 self.client.wait(10)
476 self._wait_for_idle()
475 self._wait_for_idle()
477 self.client.purge_everything()
476 self.client.purge_everything()
478 # The client results
477 # The client results
479 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
478 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
480 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
479 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
481 # The client "bookkeeping"
480 # The client "bookkeeping"
482 self.assertEqual(len(self.client.session.digest_history), 0, msg="session digest not empty")
481 self.assertEqual(len(self.client.session.digest_history), 0, msg="session digest not empty")
483 self.assertEqual(len(self.client.history), 0, msg="client history not empty")
482 self.assertEqual(len(self.client.history), 0, msg="client history not empty")
484 # the hub results
483 # the hub results
485 hist = self.client.hub_history()
484 hist = self.client.hub_history()
486 self.assertEqual(len(hist), 0, msg="hub history not empty")
485 self.assertEqual(len(hist), 0, msg="hub history not empty")
487
486
488
487
489 def test_spin_thread(self):
488 def test_spin_thread(self):
490 self.client.spin_thread(0.01)
489 self.client.spin_thread(0.01)
491 ar = self.client[-1].apply_async(lambda : 1)
490 ar = self.client[-1].apply_async(lambda : 1)
492 time.sleep(0.1)
491 time.sleep(0.1)
493 self.assertTrue(ar.wall_time < 0.1,
492 self.assertTrue(ar.wall_time < 0.1,
494 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
493 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
495 )
494 )
496
495
497 def test_stop_spin_thread(self):
496 def test_stop_spin_thread(self):
498 self.client.spin_thread(0.01)
497 self.client.spin_thread(0.01)
499 self.client.stop_spin_thread()
498 self.client.stop_spin_thread()
500 ar = self.client[-1].apply_async(lambda : 1)
499 ar = self.client[-1].apply_async(lambda : 1)
501 time.sleep(0.15)
500 time.sleep(0.15)
502 self.assertTrue(ar.wall_time > 0.1,
501 self.assertTrue(ar.wall_time > 0.1,
503 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
502 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
504 )
503 )
505
504
506 def test_activate(self):
505 def test_activate(self):
507 ip = get_ipython()
506 ip = get_ipython()
508 magics = ip.magics_manager.magics
507 magics = ip.magics_manager.magics
509 self.assertTrue('px' in magics['line'])
508 self.assertTrue('px' in magics['line'])
510 self.assertTrue('px' in magics['cell'])
509 self.assertTrue('px' in magics['cell'])
511 v0 = self.client.activate(-1, '0')
510 v0 = self.client.activate(-1, '0')
512 self.assertTrue('px0' in magics['line'])
511 self.assertTrue('px0' in magics['line'])
513 self.assertTrue('px0' in magics['cell'])
512 self.assertTrue('px0' in magics['cell'])
514 self.assertEqual(v0.targets, self.client.ids[-1])
513 self.assertEqual(v0.targets, self.client.ids[-1])
515 v0 = self.client.activate('all', 'all')
514 v0 = self.client.activate('all', 'all')
516 self.assertTrue('pxall' in magics['line'])
515 self.assertTrue('pxall' in magics['line'])
517 self.assertTrue('pxall' in magics['cell'])
516 self.assertTrue('pxall' in magics['cell'])
518 self.assertEqual(v0.targets, 'all')
517 self.assertEqual(v0.targets, 'all')
General Comments 0
You need to be logged in to leave comments. Login now