##// END OF EJS Templates
add retries flag to LoadBalancedView...
MinRK -
Show More
@@ -0,0 +1,120 b''
1 """test LoadBalancedView objects"""
2 # -*- coding: utf-8 -*-
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
5 #
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
9
10 #-------------------------------------------------------------------------------
11 # Imports
12 #-------------------------------------------------------------------------------
13
14 import sys
15 import time
16
17 import zmq
18
19 from IPython import parallel as pmod
20 from IPython.parallel import error
21
22 from IPython.parallel.tests import add_engines
23
24 from .clienttest import ClusterTestCase, crash, wait, skip_without
25
def setup():
    """Module-level test setup (nose hook): ensure at least 3 engines exist."""
    add_engines(3)
class TestLoadBalancedView(ClusterTestCase):
    """Exercise the load-balanced (task) scheduler through LoadBalancedView:
    engine-death handling, map, abort, retries, and dependency flags."""

    def setUp(self):
        ClusterTestCase.setUp(self)
        self.view = self.client.load_balanced_view()

    def test_z_crash_task(self):
        """test graceful handling of engine death (balanced)"""
        # self.add_engines(1)
        ar = self.view.apply_async(crash)
        self.assertRaisesRemote(error.EngineError, ar.get)
        dead = ar.engine_id
        # give the hub up to 5s to notice the engine is gone
        deadline = time.time() + 5
        while dead in self.client.ids and time.time() < deadline:
            time.sleep(.01)
            self.client.spin()
        self.assertFalse(dead in self.client.ids, "Engine should have died")

    def test_map(self):
        """map_sync should agree with the builtin map"""
        def square(x):
            return x**2
        data = range(16)
        self.assertEquals(self.view.map_sync(square, data), map(square, data))

    def test_abort(self):
        """aborted tasks should raise TaskAborted"""
        # occupy all engines so the two tasks below stay queued
        blocker = self.client[:].apply_async(time.sleep, .5)
        ar2 = self.view.apply_async(lambda : 2)
        ar3 = self.view.apply_async(lambda : 3)
        # abort accepts either an AsyncResult or raw msg_ids
        self.view.abort(ar2)
        self.view.abort(ar3.msg_ids)
        self.assertRaises(error.TaskAborted, ar2.get)
        self.assertRaises(error.TaskAborted, ar3.get)

    def test_retries(self):
        """failing tasks are resubmitted `retries` times before the error propagates"""
        add_engines(3)
        view = self.view
        view.timeout = 1 # prevent hang if this doesn't behave
        def fail():
            assert False
        # with fewer retries than engines, the failure eventually propagates
        for n in range(len(self.client)-1):
            with view.temp_flags(retries=n):
                self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
        # with retries >= engine count, the short timeout fires first
        with view.temp_flags(retries=len(self.client), timeout=0.25):
            self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)

    def test_invalid_dependency(self):
        """an unknown msg_id as a dependency should be rejected"""
        with self.view.temp_flags(after='12345'):
            self.assertRaisesRemote(error.InvalidDependency, self.view.apply_sync, lambda : 1)

    def test_impossible_dependency(self):
        """following results from two different engines at once is impossible"""
        if len(self.client) < 2:
            add_engines(2)
        view = self.client.load_balanced_view()
        ar1 = view.apply_async(lambda : 1)
        ar1.get()
        e1 = ar1.engine_id
        # resubmit until a task lands on a second, different engine
        e2 = e1
        while e2 == e1:
            ar2 = view.apply_async(lambda : 1)
            ar2.get()
            e2 = ar2.engine_id

        with view.temp_flags(follow=[ar1, ar2]):
            self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)

    def test_follow(self):
        """the follow flag should pin subsequent tasks to the same engine"""
        ar = self.view.apply_async(lambda : 1)
        ar.get()
        first_id = ar.engine_id

        self.view.follow = ar
        followers = [self.view.apply_async(lambda : 1) for i in range(5)]
        self.view.wait(followers)
        for f in followers:
            self.assertEquals(f.engine_id, first_id)

    def test_after(self):
        """the after flag should delay start until the dependency completes"""
        view = self.view
        ar = view.apply_async(time.sleep, 0.5)
        with view.temp_flags(after=ar):
            ar2 = view.apply_async(lambda : 1)

        ar.wait()
        ar2.wait()
        self.assertTrue(ar2.started > ar.completed)
@@ -1,1033 +1,1042 b''
1 """Views of remote engines."""
1 """Views of remote engines."""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
3 # Copyright (C) 2010 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 import imp
13 import imp
14 import sys
14 import sys
15 import warnings
15 import warnings
16 from contextlib import contextmanager
16 from contextlib import contextmanager
17 from types import ModuleType
17 from types import ModuleType
18
18
19 import zmq
19 import zmq
20
20
21 from IPython.testing import decorators as testdec
21 from IPython.testing import decorators as testdec
22 from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat
22 from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat, CInt
23
23
24 from IPython.external.decorator import decorator
24 from IPython.external.decorator import decorator
25
25
26 from IPython.parallel import util
26 from IPython.parallel import util
27 from IPython.parallel.controller.dependency import Dependency, dependent
27 from IPython.parallel.controller.dependency import Dependency, dependent
28
28
29 from . import map as Map
29 from . import map as Map
30 from .asyncresult import AsyncResult, AsyncMapResult
30 from .asyncresult import AsyncResult, AsyncMapResult
31 from .remotefunction import ParallelFunction, parallel, remote
31 from .remotefunction import ParallelFunction, parallel, remote
32
32
33 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
34 # Decorators
34 # Decorators
35 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
36
36
@decorator
def save_ids(f, self, *args, **kwargs):
    """Keep our history and outstanding attributes up to date after a method call."""
    before = len(self.client.history)
    try:
        result = f(self, *args, **kwargs)
    finally:
        # everything appended to client.history during the call belongs to us
        new_count = len(self.client.history) - before
        new_ids = self.client.history[-new_count:]
        self.history.extend(new_ids)
        map(self.outstanding.add, new_ids)
    return result
49
49
@decorator
def sync_results(f, self, *args, **kwargs):
    """sync relevant results from self.client to our results attribute."""
    result = f(self, *args, **kwargs)
    # anything we consider outstanding that the client no longer does is done
    finished = self.outstanding.intersection(
        self.outstanding.difference(self.client.outstanding))
    self.outstanding = self.outstanding.difference(finished)
    for mid in finished:
        self.results[mid] = self.client.results[mid]
    return result
60
60
@decorator
def spin_after(f, self, *args, **kwargs):
    """call spin after the method."""
    result = f(self, *args, **kwargs)
    self.spin()
    return result
67
67
68 #-----------------------------------------------------------------------------
68 #-----------------------------------------------------------------------------
69 # Classes
69 # Classes
70 #-----------------------------------------------------------------------------
70 #-----------------------------------------------------------------------------
71
71
@testdec.skip_doctest
class View(HasTraits):
    """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.

    Don't use this class, use subclasses.

    Methods
    -------

    spin
        flushes incoming results and registration state changes
        control methods spin, and requesting `ids` also ensures up to date

    wait
        wait on one or more msg_ids

    execution methods
        apply
        legacy: execute, run

    data movement
        push, pull, scatter, gather

    query methods
        get_result, queue_status, purge_results, result_status

    control methods
        abort, shutdown

    """
    # per-call behavior flags, adjustable via set_flags / temp_flags
    block = Bool(False)
    track = Bool(True)
    targets = Any()

    # bookkeeping: submissions made through this view
    history = List()
    outstanding = Set()
    results = Dict()
    client = Instance('IPython.parallel.Client')

    _socket = Instance('zmq.Socket')
    _flag_names = List(['targets', 'block', 'track'])
    _targets = Any()
    _idents = Any()

    def __init__(self, client=None, socket=None, **flags):
        super(View, self).__init__(client=client, _socket=socket)
        # inherit the client's blocking behavior by default
        self.block = client.block

        self.set_flags(**flags)

        assert not self.__class__ is View, "Don't use base View objects, use subclasses"


    def __repr__(self):
        strtargets = str(self.targets)
        # truncate long target lists for readability
        if len(strtargets) > 16:
            strtargets = strtargets[:12]+'...]'
        return "<%s %s>"%(self.__class__.__name__, strtargets)

    def set_flags(self, **kwargs):
        """set my attribute flags by keyword.

        Views determine behavior with a few attributes (`block`, `track`, etc.).
        These attributes can be set all at once by name with this method.

        Parameters
        ----------

        block : bool
            whether to wait for results
        track : bool
            whether to create a MessageTracker to allow the user to
            safely edit after arrays and buffers during non-copying
            sends.
        """
        for name, value in kwargs.iteritems():
            if name not in self._flag_names:
                raise KeyError("Invalid name: %r"%name)
            setattr(self, name, value)

    @contextmanager
    def temp_flags(self, **kwargs):
        """temporarily set flags, for use in `with` statements.

        See set_flags for permanent setting of flags

        Examples
        --------

        >>> view.track=False
        ...
        >>> with view.temp_flags(track=True):
        ...    ar = view.apply(dostuff, my_big_array)
        ...    ar.tracker.wait() # wait for send to finish
        >>> view.track
        False

        """
        # remember current flag values before applying the temporaries
        saved = dict((name, getattr(self, name)) for name in self._flag_names)
        self.set_flags(**kwargs)
        try:
            # hand control to the with-statement body
            yield
        finally:
            # always restore the saved flags, even on error
            self.set_flags(**saved)


    #----------------------------------------------------------------
    # apply
    #----------------------------------------------------------------

    @sync_results
    @save_ids
    def _really_apply(self, f, args, kwargs, block=None, **options):
        """wrapper for client.send_apply_message"""
        raise NotImplementedError("Implement in subclasses")

    def apply(self, f, *args, **kwargs):
        """calls f(*args, **kwargs) on remote engines, returning the result.

        This method sets all apply flags via this View's attributes.

        if self.block is False:
            returns AsyncResult
        else:
            returns actual result of f(*args, **kwargs)
        """
        return self._really_apply(f, args, kwargs)

    def apply_async(self, f, *args, **kwargs):
        """calls f(*args, **kwargs) on remote engines in a nonblocking manner.

        returns AsyncResult
        """
        return self._really_apply(f, args, kwargs, block=False)

    @spin_after
    def apply_sync(self, f, *args, **kwargs):
        """calls f(*args, **kwargs) on remote engines in a blocking manner,
        returning the result.

        returns: actual result of f(*args, **kwargs)
        """
        return self._really_apply(f, args, kwargs, block=True)

    #----------------------------------------------------------------
    # wrappers for client and control methods
    #----------------------------------------------------------------

    @sync_results
    def spin(self):
        """spin the client, and sync"""
        self.client.spin()

    @sync_results
    def wait(self, jobs=None, timeout=-1):
        """waits on one or more `jobs`, for up to `timeout` seconds.

        Parameters
        ----------

        jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
                ints are indices to self.history
                strs are msg_ids
                default: wait on all outstanding messages
        timeout : float
                a time in seconds, after which to give up.
                default is -1, which means no timeout

        Returns
        -------

        True : when all msg_ids are done
        False : timeout reached, some msg_ids still outstanding
        """
        if jobs is None:
            jobs = self.history
        return self.client.wait(jobs, timeout)

    def abort(self, jobs=None, targets=None, block=None):
        """Abort jobs on my engines.

        Parameters
        ----------

        jobs : None, str, list of strs, optional
            if None: abort all jobs.
            else: abort specific msg_id(s).
        """
        block = self.block if block is None else block
        targets = self.targets if targets is None else targets
        return self.client.abort(jobs=jobs, targets=targets, block=block)

    def queue_status(self, targets=None, verbose=False):
        """Fetch the Queue status of my engines"""
        targets = self.targets if targets is None else targets
        return self.client.queue_status(targets=targets, verbose=verbose)

    def purge_results(self, jobs=[], targets=[]):
        """Instruct the controller to forget specific results."""
        if targets is None or targets == 'all':
            targets = self.targets
        return self.client.purge_results(jobs=jobs, targets=targets)

    def shutdown(self, targets=None, restart=False, hub=False, block=None):
        """Terminates one or more engine processes, optionally including the hub.
        """
        block = self.block if block is None else block
        if targets is None or targets == 'all':
            targets = self.targets
        return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)

    @spin_after
    def get_result(self, indices_or_msg_ids=None):
        """return one or more results, specified by history index or msg_id.

        See client.get_result for details.

        """

        spec = indices_or_msg_ids
        if spec is None:
            # default: most recent submission
            spec = -1
        if isinstance(spec, int):
            spec = self.history[spec]
        elif isinstance(spec, (list,tuple,set)):
            # translate any history indices into msg_ids
            spec = [self.history[s] if isinstance(s, int) else s for s in spec]
        return self.client.get_result(spec)

    #-------------------------------------------------------------------
    # Map
    #-------------------------------------------------------------------

    def map(self, f, *sequences, **kwargs):
        """override in subclasses"""
        raise NotImplementedError

    def map_async(self, f, *sequences, **kwargs):
        """Parallel version of builtin `map`, using this view's engines.

        This is equivalent to map(...block=False)

        See `self.map` for details.
        """
        if 'block' in kwargs:
            raise TypeError("map_async doesn't take a `block` keyword argument.")
        kwargs['block'] = False
        return self.map(f,*sequences,**kwargs)

    def map_sync(self, f, *sequences, **kwargs):
        """Parallel version of builtin `map`, using this view's engines.

        This is equivalent to map(...block=True)

        See `self.map` for details.
        """
        if 'block' in kwargs:
            raise TypeError("map_sync doesn't take a `block` keyword argument.")
        kwargs['block'] = True
        return self.map(f,*sequences,**kwargs)

    def imap(self, f, *sequences, **kwargs):
        """Parallel version of `itertools.imap`.

        See `self.map` for details.

        """

        return iter(self.map_async(f,*sequences, **kwargs))

    #-------------------------------------------------------------------
    # Decorators
    #-------------------------------------------------------------------

    def remote(self, block=True, **flags):
        """Decorator for making a RemoteFunction"""
        block = self.block if block is None else block
        return remote(self, block=block, **flags)

    def parallel(self, dist='b', block=None, **flags):
        """Decorator for making a ParallelFunction"""
        block = self.block if block is None else block
        return parallel(self, dist=dist, block=block, **flags)

363 @testdec.skip_doctest
363 @testdec.skip_doctest
364 class DirectView(View):
364 class DirectView(View):
365 """Direct Multiplexer View of one or more engines.
365 """Direct Multiplexer View of one or more engines.
366
366
367 These are created via indexed access to a client:
367 These are created via indexed access to a client:
368
368
369 >>> dv_1 = client[1]
369 >>> dv_1 = client[1]
370 >>> dv_all = client[:]
370 >>> dv_all = client[:]
371 >>> dv_even = client[::2]
371 >>> dv_even = client[::2]
372 >>> dv_some = client[1:3]
372 >>> dv_some = client[1:3]
373
373
374 This object provides dictionary access to engine namespaces:
374 This object provides dictionary access to engine namespaces:
375
375
376 # push a=5:
376 # push a=5:
377 >>> dv['a'] = 5
377 >>> dv['a'] = 5
378 # pull 'foo':
378 # pull 'foo':
379 >>> db['foo']
379 >>> db['foo']
380
380
381 """
381 """
382
382
383 def __init__(self, client=None, socket=None, targets=None):
383 def __init__(self, client=None, socket=None, targets=None):
384 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
384 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
385
385
386 @property
386 @property
387 def importer(self):
387 def importer(self):
388 """sync_imports(local=True) as a property.
388 """sync_imports(local=True) as a property.
389
389
390 See sync_imports for details.
390 See sync_imports for details.
391
391
392 """
392 """
393 return self.sync_imports(True)
393 return self.sync_imports(True)
394
394
395 @contextmanager
395 @contextmanager
396 def sync_imports(self, local=True):
396 def sync_imports(self, local=True):
397 """Context Manager for performing simultaneous local and remote imports.
397 """Context Manager for performing simultaneous local and remote imports.
398
398
399 'import x as y' will *not* work. The 'as y' part will simply be ignored.
399 'import x as y' will *not* work. The 'as y' part will simply be ignored.
400
400
401 >>> with view.sync_imports():
401 >>> with view.sync_imports():
402 ... from numpy import recarray
402 ... from numpy import recarray
403 importing recarray from numpy on engine(s)
403 importing recarray from numpy on engine(s)
404
404
405 """
405 """
406 import __builtin__
406 import __builtin__
407 local_import = __builtin__.__import__
407 local_import = __builtin__.__import__
408 modules = set()
408 modules = set()
409 results = []
409 results = []
410 @util.interactive
410 @util.interactive
411 def remote_import(name, fromlist, level):
411 def remote_import(name, fromlist, level):
412 """the function to be passed to apply, that actually performs the import
412 """the function to be passed to apply, that actually performs the import
413 on the engine, and loads up the user namespace.
413 on the engine, and loads up the user namespace.
414 """
414 """
415 import sys
415 import sys
416 user_ns = globals()
416 user_ns = globals()
417 mod = __import__(name, fromlist=fromlist, level=level)
417 mod = __import__(name, fromlist=fromlist, level=level)
418 if fromlist:
418 if fromlist:
419 for key in fromlist:
419 for key in fromlist:
420 user_ns[key] = getattr(mod, key)
420 user_ns[key] = getattr(mod, key)
421 else:
421 else:
422 user_ns[name] = sys.modules[name]
422 user_ns[name] = sys.modules[name]
423
423
424 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
424 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
425 """the drop-in replacement for __import__, that optionally imports
425 """the drop-in replacement for __import__, that optionally imports
426 locally as well.
426 locally as well.
427 """
427 """
428 # don't override nested imports
428 # don't override nested imports
429 save_import = __builtin__.__import__
429 save_import = __builtin__.__import__
430 __builtin__.__import__ = local_import
430 __builtin__.__import__ = local_import
431
431
432 if imp.lock_held():
432 if imp.lock_held():
433 # this is a side-effect import, don't do it remotely, or even
433 # this is a side-effect import, don't do it remotely, or even
434 # ignore the local effects
434 # ignore the local effects
435 return local_import(name, globals, locals, fromlist, level)
435 return local_import(name, globals, locals, fromlist, level)
436
436
437 imp.acquire_lock()
437 imp.acquire_lock()
438 if local:
438 if local:
439 mod = local_import(name, globals, locals, fromlist, level)
439 mod = local_import(name, globals, locals, fromlist, level)
440 else:
440 else:
441 raise NotImplementedError("remote-only imports not yet implemented")
441 raise NotImplementedError("remote-only imports not yet implemented")
442 imp.release_lock()
442 imp.release_lock()
443
443
444 key = name+':'+','.join(fromlist or [])
444 key = name+':'+','.join(fromlist or [])
445 if level == -1 and key not in modules:
445 if level == -1 and key not in modules:
446 modules.add(key)
446 modules.add(key)
447 if fromlist:
447 if fromlist:
448 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
448 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
449 else:
449 else:
450 print "importing %s on engine(s)"%name
450 print "importing %s on engine(s)"%name
451 results.append(self.apply_async(remote_import, name, fromlist, level))
451 results.append(self.apply_async(remote_import, name, fromlist, level))
452 # restore override
452 # restore override
453 __builtin__.__import__ = save_import
453 __builtin__.__import__ = save_import
454
454
455 return mod
455 return mod
456
456
457 # override __import__
457 # override __import__
458 __builtin__.__import__ = view_import
458 __builtin__.__import__ = view_import
459 try:
459 try:
460 # enter the block
460 # enter the block
461 yield
461 yield
462 except ImportError:
462 except ImportError:
463 if not local:
463 if not local:
464 # ignore import errors if not doing local imports
464 # ignore import errors if not doing local imports
465 pass
465 pass
466 finally:
466 finally:
467 # always restore __import__
467 # always restore __import__
468 __builtin__.__import__ = local_import
468 __builtin__.__import__ = local_import
469
469
470 for r in results:
470 for r in results:
471 # raise possible remote ImportErrors here
471 # raise possible remote ImportErrors here
472 r.get()
472 r.get()
473
473
474
474
475 @sync_results
475 @sync_results
476 @save_ids
476 @save_ids
477 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
477 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
478 """calls f(*args, **kwargs) on remote engines, returning the result.
478 """calls f(*args, **kwargs) on remote engines, returning the result.
479
479
480 This method sets all of `apply`'s flags via this View's attributes.
480 This method sets all of `apply`'s flags via this View's attributes.
481
481
482 Parameters
482 Parameters
483 ----------
483 ----------
484
484
485 f : callable
485 f : callable
486
486
487 args : list [default: empty]
487 args : list [default: empty]
488
488
489 kwargs : dict [default: empty]
489 kwargs : dict [default: empty]
490
490
491 targets : target list [default: self.targets]
491 targets : target list [default: self.targets]
492 where to run
492 where to run
493 block : bool [default: self.block]
493 block : bool [default: self.block]
494 whether to block
494 whether to block
495 track : bool [default: self.track]
495 track : bool [default: self.track]
496 whether to ask zmq to track the message, for safe non-copying sends
496 whether to ask zmq to track the message, for safe non-copying sends
497
497
498 Returns
498 Returns
499 -------
499 -------
500
500
501 if self.block is False:
501 if self.block is False:
502 returns AsyncResult
502 returns AsyncResult
503 else:
503 else:
504 returns actual result of f(*args, **kwargs) on the engine(s)
504 returns actual result of f(*args, **kwargs) on the engine(s)
505 This will be a list of self.targets is also a list (even length 1), or
505 This will be a list of self.targets is also a list (even length 1), or
506 the single result if self.targets is an integer engine id
506 the single result if self.targets is an integer engine id
507 """
507 """
508 args = [] if args is None else args
508 args = [] if args is None else args
509 kwargs = {} if kwargs is None else kwargs
509 kwargs = {} if kwargs is None else kwargs
510 block = self.block if block is None else block
510 block = self.block if block is None else block
511 track = self.track if track is None else track
511 track = self.track if track is None else track
512 targets = self.targets if targets is None else targets
512 targets = self.targets if targets is None else targets
513
513
514 _idents = self.client._build_targets(targets)[0]
514 _idents = self.client._build_targets(targets)[0]
515 msg_ids = []
515 msg_ids = []
516 trackers = []
516 trackers = []
517 for ident in _idents:
517 for ident in _idents:
518 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
518 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
519 ident=ident)
519 ident=ident)
520 if track:
520 if track:
521 trackers.append(msg['tracker'])
521 trackers.append(msg['tracker'])
522 msg_ids.append(msg['msg_id'])
522 msg_ids.append(msg['msg_id'])
523 tracker = None if track is False else zmq.MessageTracker(*trackers)
523 tracker = None if track is False else zmq.MessageTracker(*trackers)
524 ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
524 ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
525 if block:
525 if block:
526 try:
526 try:
527 return ar.get()
527 return ar.get()
528 except KeyboardInterrupt:
528 except KeyboardInterrupt:
529 pass
529 pass
530 return ar
530 return ar
531
531
532 @spin_after
532 @spin_after
533 def map(self, f, *sequences, **kwargs):
533 def map(self, f, *sequences, **kwargs):
534 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
534 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
535
535
536 Parallel version of builtin `map`, using this View's `targets`.
536 Parallel version of builtin `map`, using this View's `targets`.
537
537
538 There will be one task per target, so work will be chunked
538 There will be one task per target, so work will be chunked
539 if the sequences are longer than `targets`.
539 if the sequences are longer than `targets`.
540
540
541 Results can be iterated as they are ready, but will become available in chunks.
541 Results can be iterated as they are ready, but will become available in chunks.
542
542
543 Parameters
543 Parameters
544 ----------
544 ----------
545
545
546 f : callable
546 f : callable
547 function to be mapped
547 function to be mapped
548 *sequences: one or more sequences of matching length
548 *sequences: one or more sequences of matching length
549 the sequences to be distributed and passed to `f`
549 the sequences to be distributed and passed to `f`
550 block : bool
550 block : bool
551 whether to wait for the result or not [default self.block]
551 whether to wait for the result or not [default self.block]
552
552
553 Returns
553 Returns
554 -------
554 -------
555
555
556 if block=False:
556 if block=False:
557 AsyncMapResult
557 AsyncMapResult
558 An object like AsyncResult, but which reassembles the sequence of results
558 An object like AsyncResult, but which reassembles the sequence of results
559 into a single list. AsyncMapResults can be iterated through before all
559 into a single list. AsyncMapResults can be iterated through before all
560 results are complete.
560 results are complete.
561 else:
561 else:
562 list
562 list
563 the result of map(f,*sequences)
563 the result of map(f,*sequences)
564 """
564 """
565
565
566 block = kwargs.pop('block', self.block)
566 block = kwargs.pop('block', self.block)
567 for k in kwargs.keys():
567 for k in kwargs.keys():
568 if k not in ['block', 'track']:
568 if k not in ['block', 'track']:
569 raise TypeError("invalid keyword arg, %r"%k)
569 raise TypeError("invalid keyword arg, %r"%k)
570
570
571 assert len(sequences) > 0, "must have some sequences to map onto!"
571 assert len(sequences) > 0, "must have some sequences to map onto!"
572 pf = ParallelFunction(self, f, block=block, **kwargs)
572 pf = ParallelFunction(self, f, block=block, **kwargs)
573 return pf.map(*sequences)
573 return pf.map(*sequences)
574
574
575 def execute(self, code, targets=None, block=None):
575 def execute(self, code, targets=None, block=None):
576 """Executes `code` on `targets` in blocking or nonblocking manner.
576 """Executes `code` on `targets` in blocking or nonblocking manner.
577
577
578 ``execute`` is always `bound` (affects engine namespace)
578 ``execute`` is always `bound` (affects engine namespace)
579
579
580 Parameters
580 Parameters
581 ----------
581 ----------
582
582
583 code : str
583 code : str
584 the code string to be executed
584 the code string to be executed
585 block : bool
585 block : bool
586 whether or not to wait until done to return
586 whether or not to wait until done to return
587 default: self.block
587 default: self.block
588 """
588 """
589 return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
589 return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
590
590
591 def run(self, filename, targets=None, block=None):
591 def run(self, filename, targets=None, block=None):
592 """Execute contents of `filename` on my engine(s).
592 """Execute contents of `filename` on my engine(s).
593
593
594 This simply reads the contents of the file and calls `execute`.
594 This simply reads the contents of the file and calls `execute`.
595
595
596 Parameters
596 Parameters
597 ----------
597 ----------
598
598
599 filename : str
599 filename : str
600 The path to the file
600 The path to the file
601 targets : int/str/list of ints/strs
601 targets : int/str/list of ints/strs
602 the engines on which to execute
602 the engines on which to execute
603 default : all
603 default : all
604 block : bool
604 block : bool
605 whether or not to wait until done
605 whether or not to wait until done
606 default: self.block
606 default: self.block
607
607
608 """
608 """
609 with open(filename, 'r') as f:
609 with open(filename, 'r') as f:
610 # add newline in case of trailing indented whitespace
610 # add newline in case of trailing indented whitespace
611 # which will cause SyntaxError
611 # which will cause SyntaxError
612 code = f.read()+'\n'
612 code = f.read()+'\n'
613 return self.execute(code, block=block, targets=targets)
613 return self.execute(code, block=block, targets=targets)
614
614
615 def update(self, ns):
615 def update(self, ns):
616 """update remote namespace with dict `ns`
616 """update remote namespace with dict `ns`
617
617
618 See `push` for details.
618 See `push` for details.
619 """
619 """
620 return self.push(ns, block=self.block, track=self.track)
620 return self.push(ns, block=self.block, track=self.track)
621
621
622 def push(self, ns, targets=None, block=None, track=None):
622 def push(self, ns, targets=None, block=None, track=None):
623 """update remote namespace with dict `ns`
623 """update remote namespace with dict `ns`
624
624
625 Parameters
625 Parameters
626 ----------
626 ----------
627
627
628 ns : dict
628 ns : dict
629 dict of keys with which to update engine namespace(s)
629 dict of keys with which to update engine namespace(s)
630 block : bool [default : self.block]
630 block : bool [default : self.block]
631 whether to wait to be notified of engine receipt
631 whether to wait to be notified of engine receipt
632
632
633 """
633 """
634
634
635 block = block if block is not None else self.block
635 block = block if block is not None else self.block
636 track = track if track is not None else self.track
636 track = track if track is not None else self.track
637 targets = targets if targets is not None else self.targets
637 targets = targets if targets is not None else self.targets
638 # applier = self.apply_sync if block else self.apply_async
638 # applier = self.apply_sync if block else self.apply_async
639 if not isinstance(ns, dict):
639 if not isinstance(ns, dict):
640 raise TypeError("Must be a dict, not %s"%type(ns))
640 raise TypeError("Must be a dict, not %s"%type(ns))
641 return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets)
641 return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets)
642
642
643 def get(self, key_s):
643 def get(self, key_s):
644 """get object(s) by `key_s` from remote namespace
644 """get object(s) by `key_s` from remote namespace
645
645
646 see `pull` for details.
646 see `pull` for details.
647 """
647 """
648 # block = block if block is not None else self.block
648 # block = block if block is not None else self.block
649 return self.pull(key_s, block=True)
649 return self.pull(key_s, block=True)
650
650
651 def pull(self, names, targets=None, block=None):
651 def pull(self, names, targets=None, block=None):
652 """get object(s) by `name` from remote namespace
652 """get object(s) by `name` from remote namespace
653
653
654 will return one object if it is a key.
654 will return one object if it is a key.
655 can also take a list of keys, in which case it will return a list of objects.
655 can also take a list of keys, in which case it will return a list of objects.
656 """
656 """
657 block = block if block is not None else self.block
657 block = block if block is not None else self.block
658 targets = targets if targets is not None else self.targets
658 targets = targets if targets is not None else self.targets
659 applier = self.apply_sync if block else self.apply_async
659 applier = self.apply_sync if block else self.apply_async
660 if isinstance(names, basestring):
660 if isinstance(names, basestring):
661 pass
661 pass
662 elif isinstance(names, (list,tuple,set)):
662 elif isinstance(names, (list,tuple,set)):
663 for key in names:
663 for key in names:
664 if not isinstance(key, basestring):
664 if not isinstance(key, basestring):
665 raise TypeError("keys must be str, not type %r"%type(key))
665 raise TypeError("keys must be str, not type %r"%type(key))
666 else:
666 else:
667 raise TypeError("names must be strs, not %r"%names)
667 raise TypeError("names must be strs, not %r"%names)
668 return self._really_apply(util._pull, (names,), block=block, targets=targets)
668 return self._really_apply(util._pull, (names,), block=block, targets=targets)
669
669
670 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
670 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
671 """
671 """
672 Partition a Python sequence and send the partitions to a set of engines.
672 Partition a Python sequence and send the partitions to a set of engines.
673 """
673 """
674 block = block if block is not None else self.block
674 block = block if block is not None else self.block
675 track = track if track is not None else self.track
675 track = track if track is not None else self.track
676 targets = targets if targets is not None else self.targets
676 targets = targets if targets is not None else self.targets
677
677
678 mapObject = Map.dists[dist]()
678 mapObject = Map.dists[dist]()
679 nparts = len(targets)
679 nparts = len(targets)
680 msg_ids = []
680 msg_ids = []
681 trackers = []
681 trackers = []
682 for index, engineid in enumerate(targets):
682 for index, engineid in enumerate(targets):
683 partition = mapObject.getPartition(seq, index, nparts)
683 partition = mapObject.getPartition(seq, index, nparts)
684 if flatten and len(partition) == 1:
684 if flatten and len(partition) == 1:
685 ns = {key: partition[0]}
685 ns = {key: partition[0]}
686 else:
686 else:
687 ns = {key: partition}
687 ns = {key: partition}
688 r = self.push(ns, block=False, track=track, targets=engineid)
688 r = self.push(ns, block=False, track=track, targets=engineid)
689 msg_ids.extend(r.msg_ids)
689 msg_ids.extend(r.msg_ids)
690 if track:
690 if track:
691 trackers.append(r._tracker)
691 trackers.append(r._tracker)
692
692
693 if track:
693 if track:
694 tracker = zmq.MessageTracker(*trackers)
694 tracker = zmq.MessageTracker(*trackers)
695 else:
695 else:
696 tracker = None
696 tracker = None
697
697
698 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
698 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
699 if block:
699 if block:
700 r.wait()
700 r.wait()
701 else:
701 else:
702 return r
702 return r
703
703
704 @sync_results
704 @sync_results
705 @save_ids
705 @save_ids
706 def gather(self, key, dist='b', targets=None, block=None):
706 def gather(self, key, dist='b', targets=None, block=None):
707 """
707 """
708 Gather a partitioned sequence on a set of engines as a single local seq.
708 Gather a partitioned sequence on a set of engines as a single local seq.
709 """
709 """
710 block = block if block is not None else self.block
710 block = block if block is not None else self.block
711 targets = targets if targets is not None else self.targets
711 targets = targets if targets is not None else self.targets
712 mapObject = Map.dists[dist]()
712 mapObject = Map.dists[dist]()
713 msg_ids = []
713 msg_ids = []
714
714
715 for index, engineid in enumerate(targets):
715 for index, engineid in enumerate(targets):
716 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
716 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
717
717
718 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
718 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
719
719
720 if block:
720 if block:
721 try:
721 try:
722 return r.get()
722 return r.get()
723 except KeyboardInterrupt:
723 except KeyboardInterrupt:
724 pass
724 pass
725 return r
725 return r
726
726
    def __getitem__(self, key):
        """dict-style access: equivalent to `self.get(key)` (a blocking pull)."""
        return self.get(key)
729
729
    def __setitem__(self,key, value):
        """dict-style assignment: push ``{key: value}`` to the engines (see `update`)."""
        self.update({key:value})
732
732
733 def clear(self, targets=None, block=False):
733 def clear(self, targets=None, block=False):
734 """Clear the remote namespaces on my engines."""
734 """Clear the remote namespaces on my engines."""
735 block = block if block is not None else self.block
735 block = block if block is not None else self.block
736 targets = targets if targets is not None else self.targets
736 targets = targets if targets is not None else self.targets
737 return self.client.clear(targets=targets, block=block)
737 return self.client.clear(targets=targets, block=block)
738
738
739 def kill(self, targets=None, block=True):
739 def kill(self, targets=None, block=True):
740 """Kill my engines."""
740 """Kill my engines."""
741 block = block if block is not None else self.block
741 block = block if block is not None else self.block
742 targets = targets if targets is not None else self.targets
742 targets = targets if targets is not None else self.targets
743 return self.client.kill(targets=targets, block=block)
743 return self.client.kill(targets=targets, block=block)
744
744
745 #----------------------------------------
745 #----------------------------------------
746 # activate for %px,%autopx magics
746 # activate for %px,%autopx magics
747 #----------------------------------------
747 #----------------------------------------
748 def activate(self):
748 def activate(self):
749 """Make this `View` active for parallel magic commands.
749 """Make this `View` active for parallel magic commands.
750
750
751 IPython has a magic command syntax to work with `MultiEngineClient` objects.
751 IPython has a magic command syntax to work with `MultiEngineClient` objects.
752 In a given IPython session there is a single active one. While
752 In a given IPython session there is a single active one. While
753 there can be many `Views` created and used by the user,
753 there can be many `Views` created and used by the user,
754 there is only one active one. The active `View` is used whenever
754 there is only one active one. The active `View` is used whenever
755 the magic commands %px and %autopx are used.
755 the magic commands %px and %autopx are used.
756
756
757 The activate() method is called on a given `View` to make it
757 The activate() method is called on a given `View` to make it
758 active. Once this has been done, the magic commands can be used.
758 active. Once this has been done, the magic commands can be used.
759 """
759 """
760
760
761 try:
761 try:
762 # This is injected into __builtins__.
762 # This is injected into __builtins__.
763 ip = get_ipython()
763 ip = get_ipython()
764 except NameError:
764 except NameError:
765 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
765 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
766 else:
766 else:
767 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
767 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
768 if pmagic is None:
768 if pmagic is None:
769 ip.magic_load_ext('parallelmagic')
769 ip.magic_load_ext('parallelmagic')
770 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
770 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
771
771
772 pmagic.active_view = self
772 pmagic.active_view = self
773
773
774
774
775 @testdec.skip_doctest
775 @testdec.skip_doctest
776 class LoadBalancedView(View):
776 class LoadBalancedView(View):
777 """An load-balancing View that only executes via the Task scheduler.
777 """An load-balancing View that only executes via the Task scheduler.
778
778
779 Load-balanced views can be created with the client's `view` method:
779 Load-balanced views can be created with the client's `view` method:
780
780
781 >>> v = client.load_balanced_view()
781 >>> v = client.load_balanced_view()
782
782
783 or targets can be specified, to restrict the potential destinations:
783 or targets can be specified, to restrict the potential destinations:
784
784
785 >>> v = client.client.load_balanced_view(([1,3])
785 >>> v = client.client.load_balanced_view(([1,3])
786
786
787 which would restrict loadbalancing to between engines 1 and 3.
787 which would restrict loadbalancing to between engines 1 and 3.
788
788
789 """
789 """
790
790
    # Scheduler flags, sent in the task submission subheader (see set_flags):
    follow=Any()        # location-based dependency: Dependency or collection of msg_ids
    after=Any()         # time-based dependency: Dependency or collection of msg_ids
    timeout=CFloat()    # seconds the scheduler waits for dependencies before DependencyTimeout
    retries = CInt(0)   # number of times a task will be retried on failure

    _task_scheme = Any()    # scheduler scheme, copied from client._task_scheme in __init__
    # attribute names that set_flags/temp_flags accept as keywords
    _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
797
798
798 def __init__(self, client=None, socket=None, **flags):
799 def __init__(self, client=None, socket=None, **flags):
799 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
800 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
800 self._task_scheme=client._task_scheme
801 self._task_scheme=client._task_scheme
801
802
802 def _validate_dependency(self, dep):
803 def _validate_dependency(self, dep):
803 """validate a dependency.
804 """validate a dependency.
804
805
805 For use in `set_flags`.
806 For use in `set_flags`.
806 """
807 """
807 if dep is None or isinstance(dep, (str, AsyncResult, Dependency)):
808 if dep is None or isinstance(dep, (str, AsyncResult, Dependency)):
808 return True
809 return True
809 elif isinstance(dep, (list,set, tuple)):
810 elif isinstance(dep, (list,set, tuple)):
810 for d in dep:
811 for d in dep:
811 if not isinstance(d, (str, AsyncResult)):
812 if not isinstance(d, (str, AsyncResult)):
812 return False
813 return False
813 elif isinstance(dep, dict):
814 elif isinstance(dep, dict):
814 if set(dep.keys()) != set(Dependency().as_dict().keys()):
815 if set(dep.keys()) != set(Dependency().as_dict().keys()):
815 return False
816 return False
816 if not isinstance(dep['msg_ids'], list):
817 if not isinstance(dep['msg_ids'], list):
817 return False
818 return False
818 for d in dep['msg_ids']:
819 for d in dep['msg_ids']:
819 if not isinstance(d, str):
820 if not isinstance(d, str):
820 return False
821 return False
821 else:
822 else:
822 return False
823 return False
823
824
824 return True
825 return True
825
826
826 def _render_dependency(self, dep):
827 def _render_dependency(self, dep):
827 """helper for building jsonable dependencies from various input forms."""
828 """helper for building jsonable dependencies from various input forms."""
828 if isinstance(dep, Dependency):
829 if isinstance(dep, Dependency):
829 return dep.as_dict()
830 return dep.as_dict()
830 elif isinstance(dep, AsyncResult):
831 elif isinstance(dep, AsyncResult):
831 return dep.msg_ids
832 return dep.msg_ids
832 elif dep is None:
833 elif dep is None:
833 return []
834 return []
834 else:
835 else:
835 # pass to Dependency constructor
836 # pass to Dependency constructor
836 return list(Dependency(dep))
837 return list(Dependency(dep))
837
838
838 def set_flags(self, **kwargs):
839 def set_flags(self, **kwargs):
839 """set my attribute flags by keyword.
840 """set my attribute flags by keyword.
840
841
841 A View is a wrapper for the Client's apply method, but with attributes
842 A View is a wrapper for the Client's apply method, but with attributes
842 that specify keyword arguments, those attributes can be set by keyword
843 that specify keyword arguments, those attributes can be set by keyword
843 argument with this method.
844 argument with this method.
844
845
845 Parameters
846 Parameters
846 ----------
847 ----------
847
848
848 block : bool
849 block : bool
849 whether to wait for results
850 whether to wait for results
850 track : bool
851 track : bool
851 whether to create a MessageTracker to allow the user to
852 whether to create a MessageTracker to allow the user to
852 safely edit after arrays and buffers during non-copying
853 safely edit after arrays and buffers during non-copying
853 sends.
854 sends.
854 #
855
855 after : Dependency or collection of msg_ids
856 after : Dependency or collection of msg_ids
856 Only for load-balanced execution (targets=None)
857 Only for load-balanced execution (targets=None)
857 Specify a list of msg_ids as a time-based dependency.
858 Specify a list of msg_ids as a time-based dependency.
858 This job will only be run *after* the dependencies
859 This job will only be run *after* the dependencies
859 have been met.
860 have been met.
860
861
861 follow : Dependency or collection of msg_ids
862 follow : Dependency or collection of msg_ids
862 Only for load-balanced execution (targets=None)
863 Only for load-balanced execution (targets=None)
863 Specify a list of msg_ids as a location-based dependency.
864 Specify a list of msg_ids as a location-based dependency.
864 This job will only be run on an engine where this dependency
865 This job will only be run on an engine where this dependency
865 is met.
866 is met.
866
867
867 timeout : float/int or None
868 timeout : float/int or None
868 Only for load-balanced execution (targets=None)
869 Only for load-balanced execution (targets=None)
869 Specify an amount of time (in seconds) for the scheduler to
870 Specify an amount of time (in seconds) for the scheduler to
870 wait for dependencies to be met before failing with a
871 wait for dependencies to be met before failing with a
871 DependencyTimeout.
872 DependencyTimeout.
873
874 retries : int
875 Number of times a task will be retried on failure.
872 """
876 """
873
877
874 super(LoadBalancedView, self).set_flags(**kwargs)
878 super(LoadBalancedView, self).set_flags(**kwargs)
875 for name in ('follow', 'after'):
879 for name in ('follow', 'after'):
876 if name in kwargs:
880 if name in kwargs:
877 value = kwargs[name]
881 value = kwargs[name]
878 if self._validate_dependency(value):
882 if self._validate_dependency(value):
879 setattr(self, name, value)
883 setattr(self, name, value)
880 else:
884 else:
881 raise ValueError("Invalid dependency: %r"%value)
885 raise ValueError("Invalid dependency: %r"%value)
882 if 'timeout' in kwargs:
886 if 'timeout' in kwargs:
883 t = kwargs['timeout']
887 t = kwargs['timeout']
884 if not isinstance(t, (int, long, float, type(None))):
888 if not isinstance(t, (int, long, float, type(None))):
885 raise TypeError("Invalid type for timeout: %r"%type(t))
889 raise TypeError("Invalid type for timeout: %r"%type(t))
886 if t is not None:
890 if t is not None:
887 if t < 0:
891 if t < 0:
888 raise ValueError("Invalid timeout: %s"%t)
892 raise ValueError("Invalid timeout: %s"%t)
889 self.timeout = t
893 self.timeout = t
890
894
891 @sync_results
895 @sync_results
892 @save_ids
896 @save_ids
893 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
897 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
894 after=None, follow=None, timeout=None,
898 after=None, follow=None, timeout=None,
895 targets=None):
899 targets=None, retries=None):
896 """calls f(*args, **kwargs) on a remote engine, returning the result.
900 """calls f(*args, **kwargs) on a remote engine, returning the result.
897
901
898 This method temporarily sets all of `apply`'s flags for a single call.
902 This method temporarily sets all of `apply`'s flags for a single call.
899
903
900 Parameters
904 Parameters
901 ----------
905 ----------
902
906
903 f : callable
907 f : callable
904
908
905 args : list [default: empty]
909 args : list [default: empty]
906
910
907 kwargs : dict [default: empty]
911 kwargs : dict [default: empty]
908
912
909 block : bool [default: self.block]
913 block : bool [default: self.block]
910 whether to block
914 whether to block
911 track : bool [default: self.track]
915 track : bool [default: self.track]
912 whether to ask zmq to track the message, for safe non-copying sends
916 whether to ask zmq to track the message, for safe non-copying sends
913
917
914 !!!!!! TODO: THE REST HERE !!!!
918 !!!!!! TODO: THE REST HERE !!!!
915
919
916 Returns
920 Returns
917 -------
921 -------
918
922
919 if self.block is False:
923 if self.block is False:
920 returns AsyncResult
924 returns AsyncResult
921 else:
925 else:
922 returns actual result of f(*args, **kwargs) on the engine(s)
926 returns actual result of f(*args, **kwargs) on the engine(s)
923 This will be a list of self.targets is also a list (even length 1), or
927 This will be a list of self.targets is also a list (even length 1), or
924 the single result if self.targets is an integer engine id
928 the single result if self.targets is an integer engine id
925 """
929 """
926
930
927 # validate whether we can run
931 # validate whether we can run
928 if self._socket.closed:
932 if self._socket.closed:
929 msg = "Task farming is disabled"
933 msg = "Task farming is disabled"
930 if self._task_scheme == 'pure':
934 if self._task_scheme == 'pure':
931 msg += " because the pure ZMQ scheduler cannot handle"
935 msg += " because the pure ZMQ scheduler cannot handle"
932 msg += " disappearing engines."
936 msg += " disappearing engines."
933 raise RuntimeError(msg)
937 raise RuntimeError(msg)
934
938
935 if self._task_scheme == 'pure':
939 if self._task_scheme == 'pure':
936 # pure zmq scheme doesn't support dependencies
940 # pure zmq scheme doesn't support extra features
937 msg = "Pure ZMQ scheduler doesn't support dependencies"
941 msg = "Pure ZMQ scheduler doesn't support the following flags:"
938 if (follow or after):
942 "follow, after, retries, targets, timeout"
939 # hard fail on DAG dependencies
943 if (follow or after or retries or targets or timeout):
944 # hard fail on Scheduler flags
940 raise RuntimeError(msg)
945 raise RuntimeError(msg)
941 if isinstance(f, dependent):
946 if isinstance(f, dependent):
942 # soft warn on functional dependencies
947 # soft warn on functional dependencies
943 warnings.warn(msg, RuntimeWarning)
948 warnings.warn(msg, RuntimeWarning)
944
949
945 # build args
950 # build args
946 args = [] if args is None else args
951 args = [] if args is None else args
947 kwargs = {} if kwargs is None else kwargs
952 kwargs = {} if kwargs is None else kwargs
948 block = self.block if block is None else block
953 block = self.block if block is None else block
949 track = self.track if track is None else track
954 track = self.track if track is None else track
950 after = self.after if after is None else after
955 after = self.after if after is None else after
956 retries = self.retries if retries is None else retries
951 follow = self.follow if follow is None else follow
957 follow = self.follow if follow is None else follow
952 timeout = self.timeout if timeout is None else timeout
958 timeout = self.timeout if timeout is None else timeout
953 targets = self.targets if targets is None else targets
959 targets = self.targets if targets is None else targets
954
960
961 if not isinstance(retries, int):
962 raise TypeError('retries must be int, not %r'%type(retries))
963
955 if targets is None:
964 if targets is None:
956 idents = []
965 idents = []
957 else:
966 else:
958 idents = self.client._build_targets(targets)[0]
967 idents = self.client._build_targets(targets)[0]
959
968
960 after = self._render_dependency(after)
969 after = self._render_dependency(after)
961 follow = self._render_dependency(follow)
970 follow = self._render_dependency(follow)
962 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
971 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
963
972
964 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
973 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
965 subheader=subheader)
974 subheader=subheader)
966 tracker = None if track is False else msg['tracker']
975 tracker = None if track is False else msg['tracker']
967
976
968 ar = AsyncResult(self.client, msg['msg_id'], fname=f.__name__, targets=None, tracker=tracker)
977 ar = AsyncResult(self.client, msg['msg_id'], fname=f.__name__, targets=None, tracker=tracker)
969
978
970 if block:
979 if block:
971 try:
980 try:
972 return ar.get()
981 return ar.get()
973 except KeyboardInterrupt:
982 except KeyboardInterrupt:
974 pass
983 pass
975 return ar
984 return ar
976
985
977 @spin_after
986 @spin_after
978 @save_ids
987 @save_ids
979 def map(self, f, *sequences, **kwargs):
988 def map(self, f, *sequences, **kwargs):
980 """view.map(f, *sequences, block=self.block, chunksize=1) => list|AsyncMapResult
989 """view.map(f, *sequences, block=self.block, chunksize=1) => list|AsyncMapResult
981
990
982 Parallel version of builtin `map`, load-balanced by this View.
991 Parallel version of builtin `map`, load-balanced by this View.
983
992
984 `block`, and `chunksize` can be specified by keyword only.
993 `block`, and `chunksize` can be specified by keyword only.
985
994
986 Each `chunksize` elements will be a separate task, and will be
995 Each `chunksize` elements will be a separate task, and will be
987 load-balanced. This lets individual elements be available for iteration
996 load-balanced. This lets individual elements be available for iteration
988 as soon as they arrive.
997 as soon as they arrive.
989
998
990 Parameters
999 Parameters
991 ----------
1000 ----------
992
1001
993 f : callable
1002 f : callable
994 function to be mapped
1003 function to be mapped
995 *sequences: one or more sequences of matching length
1004 *sequences: one or more sequences of matching length
996 the sequences to be distributed and passed to `f`
1005 the sequences to be distributed and passed to `f`
997 block : bool
1006 block : bool
998 whether to wait for the result or not [default self.block]
1007 whether to wait for the result or not [default self.block]
999 track : bool
1008 track : bool
1000 whether to create a MessageTracker to allow the user to
1009 whether to create a MessageTracker to allow the user to
1001 safely edit after arrays and buffers during non-copying
1010 safely edit after arrays and buffers during non-copying
1002 sends.
1011 sends.
1003 chunksize : int
1012 chunksize : int
1004 how many elements should be in each task [default 1]
1013 how many elements should be in each task [default 1]
1005
1014
1006 Returns
1015 Returns
1007 -------
1016 -------
1008
1017
1009 if block=False:
1018 if block=False:
1010 AsyncMapResult
1019 AsyncMapResult
1011 An object like AsyncResult, but which reassembles the sequence of results
1020 An object like AsyncResult, but which reassembles the sequence of results
1012 into a single list. AsyncMapResults can be iterated through before all
1021 into a single list. AsyncMapResults can be iterated through before all
1013 results are complete.
1022 results are complete.
1014 else:
1023 else:
1015 the result of map(f,*sequences)
1024 the result of map(f,*sequences)
1016
1025
1017 """
1026 """
1018
1027
1019 # default
1028 # default
1020 block = kwargs.get('block', self.block)
1029 block = kwargs.get('block', self.block)
1021 chunksize = kwargs.get('chunksize', 1)
1030 chunksize = kwargs.get('chunksize', 1)
1022
1031
1023 keyset = set(kwargs.keys())
1032 keyset = set(kwargs.keys())
1024 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1033 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1025 if extra_keys:
1034 if extra_keys:
1026 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1035 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1027
1036
1028 assert len(sequences) > 0, "must have some sequences to map onto!"
1037 assert len(sequences) > 0, "must have some sequences to map onto!"
1029
1038
1030 pf = ParallelFunction(self, f, block=block, chunksize=chunksize)
1039 pf = ParallelFunction(self, f, block=block, chunksize=chunksize)
1031 return pf.map(*sequences)
1040 return pf.map(*sequences)
1032
1041
1033 __all__ = ['LoadBalancedView', 'DirectView'] No newline at end of file
1042 __all__ = ['LoadBalancedView', 'DirectView']
@@ -1,621 +1,665 b''
1 """The Python scheduler for rich scheduling.
1 """The Python scheduler for rich scheduling.
2
2
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 Python Scheduler exists.
5 Python Scheduler exists.
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #----------------------------------------------------------------------
14 #----------------------------------------------------------------------
15 # Imports
15 # Imports
16 #----------------------------------------------------------------------
16 #----------------------------------------------------------------------
17
17
18 from __future__ import print_function
18 from __future__ import print_function
19
19
20 import logging
20 import logging
21 import sys
21 import sys
22
22
23 from datetime import datetime, timedelta
23 from datetime import datetime, timedelta
24 from random import randint, random
24 from random import randint, random
25 from types import FunctionType
25 from types import FunctionType
26
26
27 try:
27 try:
28 import numpy
28 import numpy
29 except ImportError:
29 except ImportError:
30 numpy = None
30 numpy = None
31
31
32 import zmq
32 import zmq
33 from zmq.eventloop import ioloop, zmqstream
33 from zmq.eventloop import ioloop, zmqstream
34
34
35 # local imports
35 # local imports
36 from IPython.external.decorator import decorator
36 from IPython.external.decorator import decorator
37 from IPython.config.loader import Config
37 from IPython.config.loader import Config
38 from IPython.utils.traitlets import Instance, Dict, List, Set, Int
38 from IPython.utils.traitlets import Instance, Dict, List, Set, Int
39
39
40 from IPython.parallel import error
40 from IPython.parallel import error
41 from IPython.parallel.factory import SessionFactory
41 from IPython.parallel.factory import SessionFactory
42 from IPython.parallel.util import connect_logger, local_logger
42 from IPython.parallel.util import connect_logger, local_logger
43
43
44 from .dependency import Dependency
44 from .dependency import Dependency
45
45
46 @decorator
46 @decorator
47 def logged(f,self,*args,**kwargs):
47 def logged(f,self,*args,**kwargs):
48 # print ("#--------------------")
48 # print ("#--------------------")
49 self.log.debug("scheduler::%s(*%s,**%s)"%(f.func_name, args, kwargs))
49 self.log.debug("scheduler::%s(*%s,**%s)"%(f.func_name, args, kwargs))
50 # print ("#--")
50 # print ("#--")
51 return f(self,*args, **kwargs)
51 return f(self,*args, **kwargs)
52
52
53 #----------------------------------------------------------------------
53 #----------------------------------------------------------------------
54 # Chooser functions
54 # Chooser functions
55 #----------------------------------------------------------------------
55 #----------------------------------------------------------------------
56
56
57 def plainrandom(loads):
57 def plainrandom(loads):
58 """Plain random pick."""
58 """Plain random pick."""
59 n = len(loads)
59 n = len(loads)
60 return randint(0,n-1)
60 return randint(0,n-1)
61
61
62 def lru(loads):
62 def lru(loads):
63 """Always pick the front of the line.
63 """Always pick the front of the line.
64
64
65 The content of `loads` is ignored.
65 The content of `loads` is ignored.
66
66
67 Assumes LRU ordering of loads, with oldest first.
67 Assumes LRU ordering of loads, with oldest first.
68 """
68 """
69 return 0
69 return 0
70
70
71 def twobin(loads):
71 def twobin(loads):
72 """Pick two at random, use the LRU of the two.
72 """Pick two at random, use the LRU of the two.
73
73
74 The content of loads is ignored.
74 The content of loads is ignored.
75
75
76 Assumes LRU ordering of loads, with oldest first.
76 Assumes LRU ordering of loads, with oldest first.
77 """
77 """
78 n = len(loads)
78 n = len(loads)
79 a = randint(0,n-1)
79 a = randint(0,n-1)
80 b = randint(0,n-1)
80 b = randint(0,n-1)
81 return min(a,b)
81 return min(a,b)
82
82
83 def weighted(loads):
83 def weighted(loads):
84 """Pick two at random using inverse load as weight.
84 """Pick two at random using inverse load as weight.
85
85
86 Return the less loaded of the two.
86 Return the less loaded of the two.
87 """
87 """
88 # weight 0 a million times more than 1:
88 # weight 0 a million times more than 1:
89 weights = 1./(1e-6+numpy.array(loads))
89 weights = 1./(1e-6+numpy.array(loads))
90 sums = weights.cumsum()
90 sums = weights.cumsum()
91 t = sums[-1]
91 t = sums[-1]
92 x = random()*t
92 x = random()*t
93 y = random()*t
93 y = random()*t
94 idx = 0
94 idx = 0
95 idy = 0
95 idy = 0
96 while sums[idx] < x:
96 while sums[idx] < x:
97 idx += 1
97 idx += 1
98 while sums[idy] < y:
98 while sums[idy] < y:
99 idy += 1
99 idy += 1
100 if weights[idy] > weights[idx]:
100 if weights[idy] > weights[idx]:
101 return idy
101 return idy
102 else:
102 else:
103 return idx
103 return idx
104
104
105 def leastload(loads):
105 def leastload(loads):
106 """Always choose the lowest load.
106 """Always choose the lowest load.
107
107
108 If the lowest load occurs more than once, the first
108 If the lowest load occurs more than once, the first
109 occurance will be used. If loads has LRU ordering, this means
109 occurance will be used. If loads has LRU ordering, this means
110 the LRU of those with the lowest load is chosen.
110 the LRU of those with the lowest load is chosen.
111 """
111 """
112 return loads.index(min(loads))
112 return loads.index(min(loads))
113
113
114 #---------------------------------------------------------------------
114 #---------------------------------------------------------------------
115 # Classes
115 # Classes
116 #---------------------------------------------------------------------
116 #---------------------------------------------------------------------
117 # store empty default dependency:
117 # store empty default dependency:
118 MET = Dependency([])
118 MET = Dependency([])
119
119
120 class TaskScheduler(SessionFactory):
120 class TaskScheduler(SessionFactory):
121 """Python TaskScheduler object.
121 """Python TaskScheduler object.
122
122
123 This is the simplest object that supports msg_id based
123 This is the simplest object that supports msg_id based
124 DAG dependencies. *Only* task msg_ids are checked, not
124 DAG dependencies. *Only* task msg_ids are checked, not
125 msg_ids of jobs submitted via the MUX queue.
125 msg_ids of jobs submitted via the MUX queue.
126
126
127 """
127 """
128
128
129 hwm = Int(0, config=True) # limit number of outstanding tasks
129 hwm = Int(0, config=True) # limit number of outstanding tasks
130
130
131 # input arguments:
131 # input arguments:
132 scheme = Instance(FunctionType, default=leastload) # function for determining the destination
132 scheme = Instance(FunctionType, default=leastload) # function for determining the destination
133 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
133 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
134 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
134 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
135 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
135 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
136 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
136 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
137
137
138 # internals:
138 # internals:
139 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
139 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
140 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
140 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
141 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
141 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
142 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
142 pending = Dict() # dict by engine_uuid of submitted tasks
143 pending = Dict() # dict by engine_uuid of submitted tasks
143 completed = Dict() # dict by engine_uuid of completed tasks
144 completed = Dict() # dict by engine_uuid of completed tasks
144 failed = Dict() # dict by engine_uuid of failed tasks
145 failed = Dict() # dict by engine_uuid of failed tasks
145 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
146 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
146 clients = Dict() # dict by msg_id for who submitted the task
147 clients = Dict() # dict by msg_id for who submitted the task
147 targets = List() # list of target IDENTs
148 targets = List() # list of target IDENTs
148 loads = List() # list of engine loads
149 loads = List() # list of engine loads
149 # full = Set() # set of IDENTs that have HWM outstanding tasks
150 # full = Set() # set of IDENTs that have HWM outstanding tasks
150 all_completed = Set() # set of all completed tasks
151 all_completed = Set() # set of all completed tasks
151 all_failed = Set() # set of all failed tasks
152 all_failed = Set() # set of all failed tasks
152 all_done = Set() # set of all finished tasks=union(completed,failed)
153 all_done = Set() # set of all finished tasks=union(completed,failed)
153 all_ids = Set() # set of all submitted task IDs
154 all_ids = Set() # set of all submitted task IDs
154 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
155 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
155 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
156 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
156
157
157
158
158 def start(self):
159 def start(self):
159 self.engine_stream.on_recv(self.dispatch_result, copy=False)
160 self.engine_stream.on_recv(self.dispatch_result, copy=False)
160 self._notification_handlers = dict(
161 self._notification_handlers = dict(
161 registration_notification = self._register_engine,
162 registration_notification = self._register_engine,
162 unregistration_notification = self._unregister_engine
163 unregistration_notification = self._unregister_engine
163 )
164 )
164 self.notifier_stream.on_recv(self.dispatch_notification)
165 self.notifier_stream.on_recv(self.dispatch_notification)
165 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # 1 Hz
166 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # 1 Hz
166 self.auditor.start()
167 self.auditor.start()
167 self.log.info("Scheduler started...%r"%self)
168 self.log.info("Scheduler started...%r"%self)
168
169
169 def resume_receiving(self):
170 def resume_receiving(self):
170 """Resume accepting jobs."""
171 """Resume accepting jobs."""
171 self.client_stream.on_recv(self.dispatch_submission, copy=False)
172 self.client_stream.on_recv(self.dispatch_submission, copy=False)
172
173
173 def stop_receiving(self):
174 def stop_receiving(self):
174 """Stop accepting jobs while there are no engines.
175 """Stop accepting jobs while there are no engines.
175 Leave them in the ZMQ queue."""
176 Leave them in the ZMQ queue."""
176 self.client_stream.on_recv(None)
177 self.client_stream.on_recv(None)
177
178
178 #-----------------------------------------------------------------------
179 #-----------------------------------------------------------------------
179 # [Un]Registration Handling
180 # [Un]Registration Handling
180 #-----------------------------------------------------------------------
181 #-----------------------------------------------------------------------
181
182
182 def dispatch_notification(self, msg):
183 def dispatch_notification(self, msg):
183 """dispatch register/unregister events."""
184 """dispatch register/unregister events."""
184 idents,msg = self.session.feed_identities(msg)
185 idents,msg = self.session.feed_identities(msg)
185 msg = self.session.unpack_message(msg)
186 msg = self.session.unpack_message(msg)
186 msg_type = msg['msg_type']
187 msg_type = msg['msg_type']
187 handler = self._notification_handlers.get(msg_type, None)
188 handler = self._notification_handlers.get(msg_type, None)
188 if handler is None:
189 if handler is None:
189 raise Exception("Unhandled message type: %s"%msg_type)
190 raise Exception("Unhandled message type: %s"%msg_type)
190 else:
191 else:
191 try:
192 try:
192 handler(str(msg['content']['queue']))
193 handler(str(msg['content']['queue']))
193 except KeyError:
194 except KeyError:
194 self.log.error("task::Invalid notification msg: %s"%msg)
195 self.log.error("task::Invalid notification msg: %s"%msg)
195
196
196 @logged
197 @logged
197 def _register_engine(self, uid):
198 def _register_engine(self, uid):
198 """New engine with ident `uid` became available."""
199 """New engine with ident `uid` became available."""
199 # head of the line:
200 # head of the line:
200 self.targets.insert(0,uid)
201 self.targets.insert(0,uid)
201 self.loads.insert(0,0)
202 self.loads.insert(0,0)
202 # initialize sets
203 # initialize sets
203 self.completed[uid] = set()
204 self.completed[uid] = set()
204 self.failed[uid] = set()
205 self.failed[uid] = set()
205 self.pending[uid] = {}
206 self.pending[uid] = {}
206 if len(self.targets) == 1:
207 if len(self.targets) == 1:
207 self.resume_receiving()
208 self.resume_receiving()
209 # rescan the graph:
210 self.update_graph(None)
208
211
209 def _unregister_engine(self, uid):
212 def _unregister_engine(self, uid):
210 """Existing engine with ident `uid` became unavailable."""
213 """Existing engine with ident `uid` became unavailable."""
211 if len(self.targets) == 1:
214 if len(self.targets) == 1:
212 # this was our only engine
215 # this was our only engine
213 self.stop_receiving()
216 self.stop_receiving()
214
217
215 # handle any potentially finished tasks:
218 # handle any potentially finished tasks:
216 self.engine_stream.flush()
219 self.engine_stream.flush()
217
220
218 self.completed.pop(uid)
221 # don't pop destinations, because they might be used later
219 self.failed.pop(uid)
220 # don't pop destinations, because it might be used later
221 # map(self.destinations.pop, self.completed.pop(uid))
222 # map(self.destinations.pop, self.completed.pop(uid))
222 # map(self.destinations.pop, self.failed.pop(uid))
223 # map(self.destinations.pop, self.failed.pop(uid))
224
225 # prevent this engine from receiving work
223 idx = self.targets.index(uid)
226 idx = self.targets.index(uid)
224 self.targets.pop(idx)
227 self.targets.pop(idx)
225 self.loads.pop(idx)
228 self.loads.pop(idx)
226
229
227 # wait 5 seconds before cleaning up pending jobs, since the results might
230 # wait 5 seconds before cleaning up pending jobs, since the results might
228 # still be incoming
231 # still be incoming
229 if self.pending[uid]:
232 if self.pending[uid]:
230 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
233 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
231 dc.start()
234 dc.start()
235 else:
236 self.completed.pop(uid)
237 self.failed.pop(uid)
238
232
239
233 @logged
240 @logged
234 def handle_stranded_tasks(self, engine):
241 def handle_stranded_tasks(self, engine):
235 """Deal with jobs resident in an engine that died."""
242 """Deal with jobs resident in an engine that died."""
236 lost = self.pending.pop(engine)
243 lost = self.pending[engine]
237
244 for msg_id in lost.keys():
238 for msg_id, (raw_msg, targets, MET, follow, timeout) in lost.iteritems():
245 if msg_id not in self.pending[engine]:
239 self.all_failed.add(msg_id)
246 # prevent double-handling of messages
240 self.all_done.add(msg_id)
247 continue
248
249 raw_msg = lost[msg_id][0]
250
241 idents,msg = self.session.feed_identities(raw_msg, copy=False)
251 idents,msg = self.session.feed_identities(raw_msg, copy=False)
242 msg = self.session.unpack_message(msg, copy=False, content=False)
252 msg = self.session.unpack_message(msg, copy=False, content=False)
243 parent = msg['header']
253 parent = msg['header']
244 idents = [idents[0],engine]+idents[1:]
254 idents = [engine, idents[0]]
245 # print (idents)
255
256 # build fake error reply
246 try:
257 try:
247 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
258 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
248 except:
259 except:
249 content = error.wrap_exception()
260 content = error.wrap_exception()
250 msg = self.session.send(self.client_stream, 'apply_reply', content,
261 msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
251 parent=parent, ident=idents)
262 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
252 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
263 # and dispatch it
253 self.update_graph(msg_id)
264 self.dispatch_result(raw_reply)
265
266 # finally scrub completed/failed lists
267 self.completed.pop(engine)
268 self.failed.pop(engine)
254
269
255
270
256 #-----------------------------------------------------------------------
271 #-----------------------------------------------------------------------
257 # Job Submission
272 # Job Submission
258 #-----------------------------------------------------------------------
273 #-----------------------------------------------------------------------
259 @logged
274 @logged
260 def dispatch_submission(self, raw_msg):
275 def dispatch_submission(self, raw_msg):
261 """Dispatch job submission to appropriate handlers."""
276 """Dispatch job submission to appropriate handlers."""
262 # ensure targets up to date:
277 # ensure targets up to date:
263 self.notifier_stream.flush()
278 self.notifier_stream.flush()
264 try:
279 try:
265 idents, msg = self.session.feed_identities(raw_msg, copy=False)
280 idents, msg = self.session.feed_identities(raw_msg, copy=False)
266 msg = self.session.unpack_message(msg, content=False, copy=False)
281 msg = self.session.unpack_message(msg, content=False, copy=False)
267 except Exception:
282 except Exception:
268 self.log.error("task::Invaid task: %s"%raw_msg, exc_info=True)
283 self.log.error("task::Invaid task: %s"%raw_msg, exc_info=True)
269 return
284 return
270
285
271 # send to monitor
286 # send to monitor
272 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
287 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
273
288
274 header = msg['header']
289 header = msg['header']
275 msg_id = header['msg_id']
290 msg_id = header['msg_id']
276 self.all_ids.add(msg_id)
291 self.all_ids.add(msg_id)
277
292
278 # targets
293 # targets
279 targets = set(header.get('targets', []))
294 targets = set(header.get('targets', []))
295 retries = header.get('retries', 0)
296 self.retries[msg_id] = retries
280
297
281 # time dependencies
298 # time dependencies
282 after = Dependency(header.get('after', []))
299 after = Dependency(header.get('after', []))
283 if after.all:
300 if after.all:
284 if after.success:
301 if after.success:
285 after.difference_update(self.all_completed)
302 after.difference_update(self.all_completed)
286 if after.failure:
303 if after.failure:
287 after.difference_update(self.all_failed)
304 after.difference_update(self.all_failed)
288 if after.check(self.all_completed, self.all_failed):
305 if after.check(self.all_completed, self.all_failed):
289 # recast as empty set, if `after` already met,
306 # recast as empty set, if `after` already met,
290 # to prevent unnecessary set comparisons
307 # to prevent unnecessary set comparisons
291 after = MET
308 after = MET
292
309
293 # location dependencies
310 # location dependencies
294 follow = Dependency(header.get('follow', []))
311 follow = Dependency(header.get('follow', []))
295
312
296 # turn timeouts into datetime objects:
313 # turn timeouts into datetime objects:
297 timeout = header.get('timeout', None)
314 timeout = header.get('timeout', None)
298 if timeout:
315 if timeout:
299 timeout = datetime.now() + timedelta(0,timeout,0)
316 timeout = datetime.now() + timedelta(0,timeout,0)
300
317
301 args = [raw_msg, targets, after, follow, timeout]
318 args = [raw_msg, targets, after, follow, timeout]
302
319
303 # validate and reduce dependencies:
320 # validate and reduce dependencies:
304 for dep in after,follow:
321 for dep in after,follow:
305 # check valid:
322 # check valid:
306 if msg_id in dep or dep.difference(self.all_ids):
323 if msg_id in dep or dep.difference(self.all_ids):
307 self.depending[msg_id] = args
324 self.depending[msg_id] = args
308 return self.fail_unreachable(msg_id, error.InvalidDependency)
325 return self.fail_unreachable(msg_id, error.InvalidDependency)
309 # check if unreachable:
326 # check if unreachable:
310 if dep.unreachable(self.all_completed, self.all_failed):
327 if dep.unreachable(self.all_completed, self.all_failed):
311 self.depending[msg_id] = args
328 self.depending[msg_id] = args
312 return self.fail_unreachable(msg_id)
329 return self.fail_unreachable(msg_id)
313
330
314 if after.check(self.all_completed, self.all_failed):
331 if after.check(self.all_completed, self.all_failed):
315 # time deps already met, try to run
332 # time deps already met, try to run
316 if not self.maybe_run(msg_id, *args):
333 if not self.maybe_run(msg_id, *args):
317 # can't run yet
334 # can't run yet
318 self.save_unmet(msg_id, *args)
335 if msg_id not in self.all_failed:
336 # could have failed as unreachable
337 self.save_unmet(msg_id, *args)
319 else:
338 else:
320 self.save_unmet(msg_id, *args)
339 self.save_unmet(msg_id, *args)
321
340
322 # @logged
341 # @logged
323 def audit_timeouts(self):
342 def audit_timeouts(self):
324 """Audit all waiting tasks for expired timeouts."""
343 """Audit all waiting tasks for expired timeouts."""
325 now = datetime.now()
344 now = datetime.now()
326 for msg_id in self.depending.keys():
345 for msg_id in self.depending.keys():
327 # must recheck, in case one failure cascaded to another:
346 # must recheck, in case one failure cascaded to another:
328 if msg_id in self.depending:
347 if msg_id in self.depending:
329 raw,after,targets,follow,timeout = self.depending[msg_id]
348 raw,after,targets,follow,timeout = self.depending[msg_id]
330 if timeout and timeout < now:
349 if timeout and timeout < now:
331 self.fail_unreachable(msg_id, timeout=True)
350 self.fail_unreachable(msg_id, error.TaskTimeout)
332
351
333 @logged
352 @logged
334 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
353 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
335 """a task has become unreachable, send a reply with an ImpossibleDependency
354 """a task has become unreachable, send a reply with an ImpossibleDependency
336 error."""
355 error."""
337 if msg_id not in self.depending:
356 if msg_id not in self.depending:
338 self.log.error("msg %r already failed!"%msg_id)
357 self.log.error("msg %r already failed!"%msg_id)
339 return
358 return
340 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
359 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
341 for mid in follow.union(after):
360 for mid in follow.union(after):
342 if mid in self.graph:
361 if mid in self.graph:
343 self.graph[mid].remove(msg_id)
362 self.graph[mid].remove(msg_id)
344
363
345 # FIXME: unpacking a message I've already unpacked, but didn't save:
364 # FIXME: unpacking a message I've already unpacked, but didn't save:
346 idents,msg = self.session.feed_identities(raw_msg, copy=False)
365 idents,msg = self.session.feed_identities(raw_msg, copy=False)
347 msg = self.session.unpack_message(msg, copy=False, content=False)
366 msg = self.session.unpack_message(msg, copy=False, content=False)
348 header = msg['header']
367 header = msg['header']
349
368
350 try:
369 try:
351 raise why()
370 raise why()
352 except:
371 except:
353 content = error.wrap_exception()
372 content = error.wrap_exception()
354
373
355 self.all_done.add(msg_id)
374 self.all_done.add(msg_id)
356 self.all_failed.add(msg_id)
375 self.all_failed.add(msg_id)
357
376
358 msg = self.session.send(self.client_stream, 'apply_reply', content,
377 msg = self.session.send(self.client_stream, 'apply_reply', content,
359 parent=header, ident=idents)
378 parent=header, ident=idents)
360 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
379 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
361
380
362 self.update_graph(msg_id, success=False)
381 self.update_graph(msg_id, success=False)
363
382
364 @logged
383 @logged
365 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
384 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
366 """check location dependencies, and run if they are met."""
385 """check location dependencies, and run if they are met."""
367 blacklist = self.blacklist.setdefault(msg_id, set())
386 blacklist = self.blacklist.setdefault(msg_id, set())
368 if follow or targets or blacklist or self.hwm:
387 if follow or targets or blacklist or self.hwm:
369 # we need a can_run filter
388 # we need a can_run filter
370 def can_run(idx):
389 def can_run(idx):
371 # check hwm
390 # check hwm
372 if self.loads[idx] == self.hwm:
391 if self.hwm and self.loads[idx] == self.hwm:
373 return False
392 return False
374 target = self.targets[idx]
393 target = self.targets[idx]
375 # check blacklist
394 # check blacklist
376 if target in blacklist:
395 if target in blacklist:
377 return False
396 return False
378 # check targets
397 # check targets
379 if targets and target not in targets:
398 if targets and target not in targets:
380 return False
399 return False
381 # check follow
400 # check follow
382 return follow.check(self.completed[target], self.failed[target])
401 return follow.check(self.completed[target], self.failed[target])
383
402
384 indices = filter(can_run, range(len(self.targets)))
403 indices = filter(can_run, range(len(self.targets)))
404
385 if not indices:
405 if not indices:
386 # couldn't run
406 # couldn't run
387 if follow.all:
407 if follow.all:
388 # check follow for impossibility
408 # check follow for impossibility
389 dests = set()
409 dests = set()
390 relevant = set()
410 relevant = set()
391 if follow.success:
411 if follow.success:
392 relevant = self.all_completed
412 relevant = self.all_completed
393 if follow.failure:
413 if follow.failure:
394 relevant = relevant.union(self.all_failed)
414 relevant = relevant.union(self.all_failed)
395 for m in follow.intersection(relevant):
415 for m in follow.intersection(relevant):
396 dests.add(self.destinations[m])
416 dests.add(self.destinations[m])
397 if len(dests) > 1:
417 if len(dests) > 1:
418 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
398 self.fail_unreachable(msg_id)
419 self.fail_unreachable(msg_id)
399 return False
420 return False
400 if targets:
421 if targets:
401 # check blacklist+targets for impossibility
422 # check blacklist+targets for impossibility
402 targets.difference_update(blacklist)
423 targets.difference_update(blacklist)
403 if not targets or not targets.intersection(self.targets):
424 if not targets or not targets.intersection(self.targets):
425 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
404 self.fail_unreachable(msg_id)
426 self.fail_unreachable(msg_id)
405 return False
427 return False
406 return False
428 return False
407 else:
429 else:
408 indices = None
430 indices = None
409
431
410 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
432 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
411 return True
433 return True
412
434
413 @logged
435 @logged
414 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
436 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
415 """Save a message for later submission when its dependencies are met."""
437 """Save a message for later submission when its dependencies are met."""
416 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
438 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
417 # track the ids in follow or after, but not those already finished
439 # track the ids in follow or after, but not those already finished
418 for dep_id in after.union(follow).difference(self.all_done):
440 for dep_id in after.union(follow).difference(self.all_done):
419 if dep_id not in self.graph:
441 if dep_id not in self.graph:
420 self.graph[dep_id] = set()
442 self.graph[dep_id] = set()
421 self.graph[dep_id].add(msg_id)
443 self.graph[dep_id].add(msg_id)
422
444
423 @logged
445 @logged
424 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
446 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
425 """Submit a task to any of a subset of our targets."""
447 """Submit a task to any of a subset of our targets."""
426 if indices:
448 if indices:
427 loads = [self.loads[i] for i in indices]
449 loads = [self.loads[i] for i in indices]
428 else:
450 else:
429 loads = self.loads
451 loads = self.loads
430 idx = self.scheme(loads)
452 idx = self.scheme(loads)
431 if indices:
453 if indices:
432 idx = indices[idx]
454 idx = indices[idx]
433 target = self.targets[idx]
455 target = self.targets[idx]
434 # print (target, map(str, msg[:3]))
456 # print (target, map(str, msg[:3]))
435 # send job to the engine
457 # send job to the engine
436 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
458 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
437 self.engine_stream.send_multipart(raw_msg, copy=False)
459 self.engine_stream.send_multipart(raw_msg, copy=False)
438 # update load
460 # update load
439 self.add_job(idx)
461 self.add_job(idx)
440 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
462 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
441 # notify Hub
463 # notify Hub
442 content = dict(msg_id=msg_id, engine_id=target)
464 content = dict(msg_id=msg_id, engine_id=target)
443 self.session.send(self.mon_stream, 'task_destination', content=content,
465 self.session.send(self.mon_stream, 'task_destination', content=content,
444 ident=['tracktask',self.session.session])
466 ident=['tracktask',self.session.session])
445
467
446
468
447 #-----------------------------------------------------------------------
469 #-----------------------------------------------------------------------
448 # Result Handling
470 # Result Handling
449 #-----------------------------------------------------------------------
471 #-----------------------------------------------------------------------
450 @logged
472 @logged
451 def dispatch_result(self, raw_msg):
473 def dispatch_result(self, raw_msg):
452 """dispatch method for result replies"""
474 """dispatch method for result replies"""
453 try:
475 try:
454 idents,msg = self.session.feed_identities(raw_msg, copy=False)
476 idents,msg = self.session.feed_identities(raw_msg, copy=False)
455 msg = self.session.unpack_message(msg, content=False, copy=False)
477 msg = self.session.unpack_message(msg, content=False, copy=False)
456 engine = idents[0]
478 engine = idents[0]
457 idx = self.targets.index(engine)
479 try:
458 self.finish_job(idx)
480 idx = self.targets.index(engine)
481 except ValueError:
482 pass # skip load-update for dead engines
483 else:
484 self.finish_job(idx)
459 except Exception:
485 except Exception:
460 self.log.error("task::Invaid result: %s"%raw_msg, exc_info=True)
486 self.log.error("task::Invaid result: %s"%raw_msg, exc_info=True)
461 return
487 return
462
488
463 header = msg['header']
489 header = msg['header']
490 parent = msg['parent_header']
464 if header.get('dependencies_met', True):
491 if header.get('dependencies_met', True):
465 success = (header['status'] == 'ok')
492 success = (header['status'] == 'ok')
466 self.handle_result(idents, msg['parent_header'], raw_msg, success)
493 msg_id = parent['msg_id']
467 # send to Hub monitor
494 retries = self.retries[msg_id]
468 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
495 if not success and retries > 0:
496 # failed
497 self.retries[msg_id] = retries - 1
498 self.handle_unmet_dependency(idents, parent)
499 else:
500 del self.retries[msg_id]
501 # relay to client and update graph
502 self.handle_result(idents, parent, raw_msg, success)
503 # send to Hub monitor
504 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
469 else:
505 else:
470 self.handle_unmet_dependency(idents, msg['parent_header'])
506 self.handle_unmet_dependency(idents, parent)
471
507
472 @logged
508 @logged
473 def handle_result(self, idents, parent, raw_msg, success=True):
509 def handle_result(self, idents, parent, raw_msg, success=True):
474 """handle a real task result, either success or failure"""
510 """handle a real task result, either success or failure"""
475 # first, relay result to client
511 # first, relay result to client
476 engine = idents[0]
512 engine = idents[0]
477 client = idents[1]
513 client = idents[1]
478 # swap_ids for XREP-XREP mirror
514 # swap_ids for XREP-XREP mirror
479 raw_msg[:2] = [client,engine]
515 raw_msg[:2] = [client,engine]
480 # print (map(str, raw_msg[:4]))
516 # print (map(str, raw_msg[:4]))
481 self.client_stream.send_multipart(raw_msg, copy=False)
517 self.client_stream.send_multipart(raw_msg, copy=False)
482 # now, update our data structures
518 # now, update our data structures
483 msg_id = parent['msg_id']
519 msg_id = parent['msg_id']
484 self.blacklist.pop(msg_id, None)
520 self.blacklist.pop(msg_id, None)
485 self.pending[engine].pop(msg_id)
521 self.pending[engine].pop(msg_id)
486 if success:
522 if success:
487 self.completed[engine].add(msg_id)
523 self.completed[engine].add(msg_id)
488 self.all_completed.add(msg_id)
524 self.all_completed.add(msg_id)
489 else:
525 else:
490 self.failed[engine].add(msg_id)
526 self.failed[engine].add(msg_id)
491 self.all_failed.add(msg_id)
527 self.all_failed.add(msg_id)
492 self.all_done.add(msg_id)
528 self.all_done.add(msg_id)
493 self.destinations[msg_id] = engine
529 self.destinations[msg_id] = engine
494
530
495 self.update_graph(msg_id, success)
531 self.update_graph(msg_id, success)
496
532
497 @logged
533 @logged
498 def handle_unmet_dependency(self, idents, parent):
534 def handle_unmet_dependency(self, idents, parent):
499 """handle an unmet dependency"""
535 """handle an unmet dependency"""
500 engine = idents[0]
536 engine = idents[0]
501 msg_id = parent['msg_id']
537 msg_id = parent['msg_id']
502
538
503 if msg_id not in self.blacklist:
539 if msg_id not in self.blacklist:
504 self.blacklist[msg_id] = set()
540 self.blacklist[msg_id] = set()
505 self.blacklist[msg_id].add(engine)
541 self.blacklist[msg_id].add(engine)
506
542
507 args = self.pending[engine].pop(msg_id)
543 args = self.pending[engine].pop(msg_id)
508 raw,targets,after,follow,timeout = args
544 raw,targets,after,follow,timeout = args
509
545
510 if self.blacklist[msg_id] == targets:
546 if self.blacklist[msg_id] == targets:
511 self.depending[msg_id] = args
547 self.depending[msg_id] = args
512 self.fail_unreachable(msg_id)
548 self.fail_unreachable(msg_id)
513 elif not self.maybe_run(msg_id, *args):
549 elif not self.maybe_run(msg_id, *args):
514 # resubmit failed, put it back in our dependency tree
550 # resubmit failed
515 self.save_unmet(msg_id, *args)
551 if msg_id not in self.all_failed:
552 # put it back in our dependency tree
553 self.save_unmet(msg_id, *args)
516
554
517 if self.hwm:
555 if self.hwm:
518 idx = self.targets.index(engine)
556 try:
519 if self.loads[idx] == self.hwm-1:
557 idx = self.targets.index(engine)
520 self.update_graph(None)
558 except ValueError:
559 pass # skip load-update for dead engines
560 else:
561 if self.loads[idx] == self.hwm-1:
562 self.update_graph(None)
521
563
522
564
523
565
524 @logged
566 @logged
525 def update_graph(self, dep_id=None, success=True):
567 def update_graph(self, dep_id=None, success=True):
526 """dep_id just finished. Update our dependency
568 """dep_id just finished. Update our dependency
527 graph and submit any jobs that just became runable.
569 graph and submit any jobs that just became runable.
528
570
529 Called with dep_id=None to update graph for hwm, but without finishing
571 Called with dep_id=None to update entire graph for hwm, but without finishing
530 a task.
572 a task.
531 """
573 """
532 # print ("\n\n***********")
574 # print ("\n\n***********")
533 # pprint (dep_id)
575 # pprint (dep_id)
534 # pprint (self.graph)
576 # pprint (self.graph)
535 # pprint (self.depending)
577 # pprint (self.depending)
536 # pprint (self.all_completed)
578 # pprint (self.all_completed)
537 # pprint (self.all_failed)
579 # pprint (self.all_failed)
538 # print ("\n\n***********\n\n")
580 # print ("\n\n***********\n\n")
539 # update any jobs that depended on the dependency
581 # update any jobs that depended on the dependency
540 jobs = self.graph.pop(dep_id, [])
582 jobs = self.graph.pop(dep_id, [])
541 # if we have HWM and an engine just become no longer full
583
542 # recheck *all* jobs:
584 # recheck *all* jobs if
543 if self.hwm and any( [ load==self.hwm-1 for load in self.loads]):
585 # a) we have HWM and an engine just become no longer full
586 # or b) dep_id was given as None
587 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
544 jobs = self.depending.keys()
588 jobs = self.depending.keys()
545
589
546 for msg_id in jobs:
590 for msg_id in jobs:
547 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
591 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
548
592
549 if after.unreachable(self.all_completed, self.all_failed) or follow.unreachable(self.all_completed, self.all_failed):
593 if after.unreachable(self.all_completed, self.all_failed) or follow.unreachable(self.all_completed, self.all_failed):
550 self.fail_unreachable(msg_id)
594 self.fail_unreachable(msg_id)
551
595
552 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
596 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
553 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
597 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
554
598
555 self.depending.pop(msg_id)
599 self.depending.pop(msg_id)
556 for mid in follow.union(after):
600 for mid in follow.union(after):
557 if mid in self.graph:
601 if mid in self.graph:
558 self.graph[mid].remove(msg_id)
602 self.graph[mid].remove(msg_id)
559
603
560 #----------------------------------------------------------------------
604 #----------------------------------------------------------------------
561 # methods to be overridden by subclasses
605 # methods to be overridden by subclasses
562 #----------------------------------------------------------------------
606 #----------------------------------------------------------------------
563
607
564 def add_job(self, idx):
608 def add_job(self, idx):
565 """Called after self.targets[idx] just got the job with header.
609 """Called after self.targets[idx] just got the job with header.
566 Override with subclasses. The default ordering is simple LRU.
610 Override with subclasses. The default ordering is simple LRU.
567 The default loads are the number of outstanding jobs."""
611 The default loads are the number of outstanding jobs."""
568 self.loads[idx] += 1
612 self.loads[idx] += 1
569 for lis in (self.targets, self.loads):
613 for lis in (self.targets, self.loads):
570 lis.append(lis.pop(idx))
614 lis.append(lis.pop(idx))
571
615
572
616
573 def finish_job(self, idx):
617 def finish_job(self, idx):
574 """Called after self.targets[idx] just finished a job.
618 """Called after self.targets[idx] just finished a job.
575 Override with subclasses."""
619 Override with subclasses."""
576 self.loads[idx] -= 1
620 self.loads[idx] -= 1
577
621
578
622
579
623
580 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,logname='ZMQ',
624 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,logname='ZMQ',
581 log_addr=None, loglevel=logging.DEBUG, scheme='lru',
625 log_addr=None, loglevel=logging.DEBUG, scheme='lru',
582 identity=b'task'):
626 identity=b'task'):
583 from zmq.eventloop import ioloop
627 from zmq.eventloop import ioloop
584 from zmq.eventloop.zmqstream import ZMQStream
628 from zmq.eventloop.zmqstream import ZMQStream
585
629
586 if config:
630 if config:
587 # unwrap dict back into Config
631 # unwrap dict back into Config
588 config = Config(config)
632 config = Config(config)
589
633
590 ctx = zmq.Context()
634 ctx = zmq.Context()
591 loop = ioloop.IOLoop()
635 loop = ioloop.IOLoop()
592 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
636 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
593 ins.setsockopt(zmq.IDENTITY, identity)
637 ins.setsockopt(zmq.IDENTITY, identity)
594 ins.bind(in_addr)
638 ins.bind(in_addr)
595
639
596 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
640 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
597 outs.setsockopt(zmq.IDENTITY, identity)
641 outs.setsockopt(zmq.IDENTITY, identity)
598 outs.bind(out_addr)
642 outs.bind(out_addr)
599 mons = ZMQStream(ctx.socket(zmq.PUB),loop)
643 mons = ZMQStream(ctx.socket(zmq.PUB),loop)
600 mons.connect(mon_addr)
644 mons.connect(mon_addr)
601 nots = ZMQStream(ctx.socket(zmq.SUB),loop)
645 nots = ZMQStream(ctx.socket(zmq.SUB),loop)
602 nots.setsockopt(zmq.SUBSCRIBE, '')
646 nots.setsockopt(zmq.SUBSCRIBE, '')
603 nots.connect(not_addr)
647 nots.connect(not_addr)
604
648
605 scheme = globals().get(scheme, None)
649 scheme = globals().get(scheme, None)
606 # setup logging
650 # setup logging
607 if log_addr:
651 if log_addr:
608 connect_logger(logname, ctx, log_addr, root="scheduler", loglevel=loglevel)
652 connect_logger(logname, ctx, log_addr, root="scheduler", loglevel=loglevel)
609 else:
653 else:
610 local_logger(logname, loglevel)
654 local_logger(logname, loglevel)
611
655
612 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
656 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
613 mon_stream=mons, notifier_stream=nots,
657 mon_stream=mons, notifier_stream=nots,
614 scheme=scheme, loop=loop, logname=logname,
658 scheme=scheme, loop=loop, logname=logname,
615 config=config)
659 config=config)
616 scheduler.start()
660 scheduler.start()
617 try:
661 try:
618 loop.start()
662 loop.start()
619 except KeyboardInterrupt:
663 except KeyboardInterrupt:
620 print ("interrupted, exiting...", file=sys.__stderr__)
664 print ("interrupted, exiting...", file=sys.__stderr__)
621
665
@@ -1,107 +1,107 b''
1 """toplevel setup/teardown for parallel tests."""
1 """toplevel setup/teardown for parallel tests."""
2
2
3 #-------------------------------------------------------------------------------
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
4 # Copyright (C) 2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14 import os
14 import os
15 import tempfile
15 import tempfile
16 import time
16 import time
17 from subprocess import Popen
17 from subprocess import Popen
18
18
19 from IPython.utils.path import get_ipython_dir
19 from IPython.utils.path import get_ipython_dir
20 from IPython.parallel import Client
20 from IPython.parallel import Client
21 from IPython.parallel.apps.launcher import (LocalProcessLauncher,
21 from IPython.parallel.apps.launcher import (LocalProcessLauncher,
22 ipengine_cmd_argv,
22 ipengine_cmd_argv,
23 ipcontroller_cmd_argv,
23 ipcontroller_cmd_argv,
24 SIGKILL)
24 SIGKILL)
25
25
26 # globals
26 # globals
27 launchers = []
27 launchers = []
28 blackhole = open(os.devnull, 'w')
28 blackhole = open(os.devnull, 'w')
29
29
30 # Launcher class
30 # Launcher class
31 class TestProcessLauncher(LocalProcessLauncher):
31 class TestProcessLauncher(LocalProcessLauncher):
32 """subclass LocalProcessLauncher, to prevent extra sockets and threads being created on Windows"""
32 """subclass LocalProcessLauncher, to prevent extra sockets and threads being created on Windows"""
33 def start(self):
33 def start(self):
34 if self.state == 'before':
34 if self.state == 'before':
35 self.process = Popen(self.args,
35 self.process = Popen(self.args,
36 stdout=blackhole, stderr=blackhole,
36 stdout=blackhole, stderr=blackhole,
37 env=os.environ,
37 env=os.environ,
38 cwd=self.work_dir
38 cwd=self.work_dir
39 )
39 )
40 self.notify_start(self.process.pid)
40 self.notify_start(self.process.pid)
41 self.poll = self.process.poll
41 self.poll = self.process.poll
42 else:
42 else:
43 s = 'The process was already started and has state: %r' % self.state
43 s = 'The process was already started and has state: %r' % self.state
44 raise ProcessStateError(s)
44 raise ProcessStateError(s)
45
45
46 # nose setup/teardown
46 # nose setup/teardown
47
47
48 def setup():
48 def setup():
49 cp = TestProcessLauncher()
49 cp = TestProcessLauncher()
50 cp.cmd_and_args = ipcontroller_cmd_argv + \
50 cp.cmd_and_args = ipcontroller_cmd_argv + \
51 ['--profile', 'iptest', '--log-level', '99', '-r', '--usethreads']
51 ['--profile', 'iptest', '--log-level', '99', '-r']
52 cp.start()
52 cp.start()
53 launchers.append(cp)
53 launchers.append(cp)
54 cluster_dir = os.path.join(get_ipython_dir(), 'cluster_iptest')
54 cluster_dir = os.path.join(get_ipython_dir(), 'cluster_iptest')
55 engine_json = os.path.join(cluster_dir, 'security', 'ipcontroller-engine.json')
55 engine_json = os.path.join(cluster_dir, 'security', 'ipcontroller-engine.json')
56 client_json = os.path.join(cluster_dir, 'security', 'ipcontroller-client.json')
56 client_json = os.path.join(cluster_dir, 'security', 'ipcontroller-client.json')
57 tic = time.time()
57 tic = time.time()
58 while not os.path.exists(engine_json) or not os.path.exists(client_json):
58 while not os.path.exists(engine_json) or not os.path.exists(client_json):
59 if cp.poll() is not None:
59 if cp.poll() is not None:
60 print cp.poll()
60 print cp.poll()
61 raise RuntimeError("The test controller failed to start.")
61 raise RuntimeError("The test controller failed to start.")
62 elif time.time()-tic > 10:
62 elif time.time()-tic > 10:
63 raise RuntimeError("Timeout waiting for the test controller to start.")
63 raise RuntimeError("Timeout waiting for the test controller to start.")
64 time.sleep(0.1)
64 time.sleep(0.1)
65 add_engines(1)
65 add_engines(1)
66
66
67 def add_engines(n=1, profile='iptest'):
67 def add_engines(n=1, profile='iptest'):
68 rc = Client(profile=profile)
68 rc = Client(profile=profile)
69 base = len(rc)
69 base = len(rc)
70 eps = []
70 eps = []
71 for i in range(n):
71 for i in range(n):
72 ep = TestProcessLauncher()
72 ep = TestProcessLauncher()
73 ep.cmd_and_args = ipengine_cmd_argv + ['--profile', profile, '--log-level', '99']
73 ep.cmd_and_args = ipengine_cmd_argv + ['--profile', profile, '--log-level', '99']
74 ep.start()
74 ep.start()
75 launchers.append(ep)
75 launchers.append(ep)
76 eps.append(ep)
76 eps.append(ep)
77 tic = time.time()
77 tic = time.time()
78 while len(rc) < base+n:
78 while len(rc) < base+n:
79 if any([ ep.poll() is not None for ep in eps ]):
79 if any([ ep.poll() is not None for ep in eps ]):
80 raise RuntimeError("A test engine failed to start.")
80 raise RuntimeError("A test engine failed to start.")
81 elif time.time()-tic > 10:
81 elif time.time()-tic > 10:
82 raise RuntimeError("Timeout waiting for engines to connect.")
82 raise RuntimeError("Timeout waiting for engines to connect.")
83 time.sleep(.1)
83 time.sleep(.1)
84 rc.spin()
84 rc.spin()
85 rc.close()
85 rc.close()
86 return eps
86 return eps
87
87
88 def teardown():
88 def teardown():
89 time.sleep(1)
89 time.sleep(1)
90 while launchers:
90 while launchers:
91 p = launchers.pop()
91 p = launchers.pop()
92 if p.poll() is None:
92 if p.poll() is None:
93 try:
93 try:
94 p.stop()
94 p.stop()
95 except Exception, e:
95 except Exception, e:
96 print e
96 print e
97 pass
97 pass
98 if p.poll() is None:
98 if p.poll() is None:
99 time.sleep(.25)
99 time.sleep(.25)
100 if p.poll() is None:
100 if p.poll() is None:
101 try:
101 try:
102 print 'cleaning up test process...'
102 print 'cleaning up test process...'
103 p.signal(SIGKILL)
103 p.signal(SIGKILL)
104 except:
104 except:
105 print "couldn't shutdown process: ", p
105 print "couldn't shutdown process: ", p
106 blackhole.close()
106 blackhole.close()
107
107
@@ -1,452 +1,440 b''
1 """test View objects"""
1 """test View objects"""
2 # -*- coding: utf-8 -*-
2 # -*- coding: utf-8 -*-
3 #-------------------------------------------------------------------------------
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
4 # Copyright (C) 2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14 import sys
14 import sys
15 import time
15 import time
16 from tempfile import mktemp
16 from tempfile import mktemp
17 from StringIO import StringIO
17 from StringIO import StringIO
18
18
19 import zmq
19 import zmq
20
20
21 from IPython import parallel as pmod
21 from IPython import parallel as pmod
22 from IPython.parallel import error
22 from IPython.parallel import error
23 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
23 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
24 from IPython.parallel import LoadBalancedView, DirectView
24 from IPython.parallel import DirectView
25 from IPython.parallel.util import interactive
25 from IPython.parallel.util import interactive
26
26
27 from IPython.parallel.tests import add_engines
27 from IPython.parallel.tests import add_engines
28
28
29 from .clienttest import ClusterTestCase, crash, wait, skip_without
29 from .clienttest import ClusterTestCase, crash, wait, skip_without
30
30
31 def setup():
31 def setup():
32 add_engines(3)
32 add_engines(3)
33
33
34 class TestView(ClusterTestCase):
34 class TestView(ClusterTestCase):
35
35
36 def test_z_crash_task(self):
37 """test graceful handling of engine death (balanced)"""
38 # self.add_engines(1)
39 ar = self.client[-1].apply_async(crash)
40 self.assertRaisesRemote(error.EngineError, ar.get)
41 eid = ar.engine_id
42 tic = time.time()
43 while eid in self.client.ids and time.time()-tic < 5:
44 time.sleep(.01)
45 self.client.spin()
46 self.assertFalse(eid in self.client.ids, "Engine should have died")
47
48 def test_z_crash_mux(self):
36 def test_z_crash_mux(self):
49 """test graceful handling of engine death (direct)"""
37 """test graceful handling of engine death (direct)"""
50 # self.add_engines(1)
38 # self.add_engines(1)
51 eid = self.client.ids[-1]
39 eid = self.client.ids[-1]
52 ar = self.client[eid].apply_async(crash)
40 ar = self.client[eid].apply_async(crash)
53 self.assertRaisesRemote(error.EngineError, ar.get)
41 self.assertRaisesRemote(error.EngineError, ar.get)
54 eid = ar.engine_id
42 eid = ar.engine_id
55 tic = time.time()
43 tic = time.time()
56 while eid in self.client.ids and time.time()-tic < 5:
44 while eid in self.client.ids and time.time()-tic < 5:
57 time.sleep(.01)
45 time.sleep(.01)
58 self.client.spin()
46 self.client.spin()
59 self.assertFalse(eid in self.client.ids, "Engine should have died")
47 self.assertFalse(eid in self.client.ids, "Engine should have died")
60
48
61 def test_push_pull(self):
49 def test_push_pull(self):
62 """test pushing and pulling"""
50 """test pushing and pulling"""
63 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
51 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
64 t = self.client.ids[-1]
52 t = self.client.ids[-1]
65 v = self.client[t]
53 v = self.client[t]
66 push = v.push
54 push = v.push
67 pull = v.pull
55 pull = v.pull
68 v.block=True
56 v.block=True
69 nengines = len(self.client)
57 nengines = len(self.client)
70 push({'data':data})
58 push({'data':data})
71 d = pull('data')
59 d = pull('data')
72 self.assertEquals(d, data)
60 self.assertEquals(d, data)
73 self.client[:].push({'data':data})
61 self.client[:].push({'data':data})
74 d = self.client[:].pull('data', block=True)
62 d = self.client[:].pull('data', block=True)
75 self.assertEquals(d, nengines*[data])
63 self.assertEquals(d, nengines*[data])
76 ar = push({'data':data}, block=False)
64 ar = push({'data':data}, block=False)
77 self.assertTrue(isinstance(ar, AsyncResult))
65 self.assertTrue(isinstance(ar, AsyncResult))
78 r = ar.get()
66 r = ar.get()
79 ar = self.client[:].pull('data', block=False)
67 ar = self.client[:].pull('data', block=False)
80 self.assertTrue(isinstance(ar, AsyncResult))
68 self.assertTrue(isinstance(ar, AsyncResult))
81 r = ar.get()
69 r = ar.get()
82 self.assertEquals(r, nengines*[data])
70 self.assertEquals(r, nengines*[data])
83 self.client[:].push(dict(a=10,b=20))
71 self.client[:].push(dict(a=10,b=20))
84 r = self.client[:].pull(('a','b'), block=True)
72 r = self.client[:].pull(('a','b'), block=True)
85 self.assertEquals(r, nengines*[[10,20]])
73 self.assertEquals(r, nengines*[[10,20]])
86
74
87 def test_push_pull_function(self):
75 def test_push_pull_function(self):
88 "test pushing and pulling functions"
76 "test pushing and pulling functions"
89 def testf(x):
77 def testf(x):
90 return 2.0*x
78 return 2.0*x
91
79
92 t = self.client.ids[-1]
80 t = self.client.ids[-1]
93 v = self.client[t]
81 v = self.client[t]
94 v.block=True
82 v.block=True
95 push = v.push
83 push = v.push
96 pull = v.pull
84 pull = v.pull
97 execute = v.execute
85 execute = v.execute
98 push({'testf':testf})
86 push({'testf':testf})
99 r = pull('testf')
87 r = pull('testf')
100 self.assertEqual(r(1.0), testf(1.0))
88 self.assertEqual(r(1.0), testf(1.0))
101 execute('r = testf(10)')
89 execute('r = testf(10)')
102 r = pull('r')
90 r = pull('r')
103 self.assertEquals(r, testf(10))
91 self.assertEquals(r, testf(10))
104 ar = self.client[:].push({'testf':testf}, block=False)
92 ar = self.client[:].push({'testf':testf}, block=False)
105 ar.get()
93 ar.get()
106 ar = self.client[:].pull('testf', block=False)
94 ar = self.client[:].pull('testf', block=False)
107 rlist = ar.get()
95 rlist = ar.get()
108 for r in rlist:
96 for r in rlist:
109 self.assertEqual(r(1.0), testf(1.0))
97 self.assertEqual(r(1.0), testf(1.0))
110 execute("def g(x): return x*x")
98 execute("def g(x): return x*x")
111 r = pull(('testf','g'))
99 r = pull(('testf','g'))
112 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
100 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
113
101
114 def test_push_function_globals(self):
102 def test_push_function_globals(self):
115 """test that pushed functions have access to globals"""
103 """test that pushed functions have access to globals"""
116 @interactive
104 @interactive
117 def geta():
105 def geta():
118 return a
106 return a
119 # self.add_engines(1)
107 # self.add_engines(1)
120 v = self.client[-1]
108 v = self.client[-1]
121 v.block=True
109 v.block=True
122 v['f'] = geta
110 v['f'] = geta
123 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
111 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
124 v.execute('a=5')
112 v.execute('a=5')
125 v.execute('b=f()')
113 v.execute('b=f()')
126 self.assertEquals(v['b'], 5)
114 self.assertEquals(v['b'], 5)
127
115
128 def test_push_function_defaults(self):
116 def test_push_function_defaults(self):
129 """test that pushed functions preserve default args"""
117 """test that pushed functions preserve default args"""
130 def echo(a=10):
118 def echo(a=10):
131 return a
119 return a
132 v = self.client[-1]
120 v = self.client[-1]
133 v.block=True
121 v.block=True
134 v['f'] = echo
122 v['f'] = echo
135 v.execute('b=f()')
123 v.execute('b=f()')
136 self.assertEquals(v['b'], 10)
124 self.assertEquals(v['b'], 10)
137
125
138 def test_get_result(self):
126 def test_get_result(self):
139 """test getting results from the Hub."""
127 """test getting results from the Hub."""
140 c = pmod.Client(profile='iptest')
128 c = pmod.Client(profile='iptest')
141 # self.add_engines(1)
129 # self.add_engines(1)
142 t = c.ids[-1]
130 t = c.ids[-1]
143 v = c[t]
131 v = c[t]
144 v2 = self.client[t]
132 v2 = self.client[t]
145 ar = v.apply_async(wait, 1)
133 ar = v.apply_async(wait, 1)
146 # give the monitor time to notice the message
134 # give the monitor time to notice the message
147 time.sleep(.25)
135 time.sleep(.25)
148 ahr = v2.get_result(ar.msg_ids)
136 ahr = v2.get_result(ar.msg_ids)
149 self.assertTrue(isinstance(ahr, AsyncHubResult))
137 self.assertTrue(isinstance(ahr, AsyncHubResult))
150 self.assertEquals(ahr.get(), ar.get())
138 self.assertEquals(ahr.get(), ar.get())
151 ar2 = v2.get_result(ar.msg_ids)
139 ar2 = v2.get_result(ar.msg_ids)
152 self.assertFalse(isinstance(ar2, AsyncHubResult))
140 self.assertFalse(isinstance(ar2, AsyncHubResult))
153 c.spin()
141 c.spin()
154 c.close()
142 c.close()
155
143
156 def test_run_newline(self):
144 def test_run_newline(self):
157 """test that run appends newline to files"""
145 """test that run appends newline to files"""
158 tmpfile = mktemp()
146 tmpfile = mktemp()
159 with open(tmpfile, 'w') as f:
147 with open(tmpfile, 'w') as f:
160 f.write("""def g():
148 f.write("""def g():
161 return 5
149 return 5
162 """)
150 """)
163 v = self.client[-1]
151 v = self.client[-1]
164 v.run(tmpfile, block=True)
152 v.run(tmpfile, block=True)
165 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
153 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
166
154
167 def test_apply_tracked(self):
155 def test_apply_tracked(self):
168 """test tracking for apply"""
156 """test tracking for apply"""
169 # self.add_engines(1)
157 # self.add_engines(1)
170 t = self.client.ids[-1]
158 t = self.client.ids[-1]
171 v = self.client[t]
159 v = self.client[t]
172 v.block=False
160 v.block=False
173 def echo(n=1024*1024, **kwargs):
161 def echo(n=1024*1024, **kwargs):
174 with v.temp_flags(**kwargs):
162 with v.temp_flags(**kwargs):
175 return v.apply(lambda x: x, 'x'*n)
163 return v.apply(lambda x: x, 'x'*n)
176 ar = echo(1, track=False)
164 ar = echo(1, track=False)
177 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
165 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
178 self.assertTrue(ar.sent)
166 self.assertTrue(ar.sent)
179 ar = echo(track=True)
167 ar = echo(track=True)
180 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
168 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
181 self.assertEquals(ar.sent, ar._tracker.done)
169 self.assertEquals(ar.sent, ar._tracker.done)
182 ar._tracker.wait()
170 ar._tracker.wait()
183 self.assertTrue(ar.sent)
171 self.assertTrue(ar.sent)
184
172
185 def test_push_tracked(self):
173 def test_push_tracked(self):
186 t = self.client.ids[-1]
174 t = self.client.ids[-1]
187 ns = dict(x='x'*1024*1024)
175 ns = dict(x='x'*1024*1024)
188 v = self.client[t]
176 v = self.client[t]
189 ar = v.push(ns, block=False, track=False)
177 ar = v.push(ns, block=False, track=False)
190 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
178 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
191 self.assertTrue(ar.sent)
179 self.assertTrue(ar.sent)
192
180
193 ar = v.push(ns, block=False, track=True)
181 ar = v.push(ns, block=False, track=True)
194 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
182 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
195 self.assertEquals(ar.sent, ar._tracker.done)
183 self.assertEquals(ar.sent, ar._tracker.done)
196 ar._tracker.wait()
184 ar._tracker.wait()
197 self.assertTrue(ar.sent)
185 self.assertTrue(ar.sent)
198 ar.get()
186 ar.get()
199
187
200 def test_scatter_tracked(self):
188 def test_scatter_tracked(self):
201 t = self.client.ids
189 t = self.client.ids
202 x='x'*1024*1024
190 x='x'*1024*1024
203 ar = self.client[t].scatter('x', x, block=False, track=False)
191 ar = self.client[t].scatter('x', x, block=False, track=False)
204 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
205 self.assertTrue(ar.sent)
193 self.assertTrue(ar.sent)
206
194
207 ar = self.client[t].scatter('x', x, block=False, track=True)
195 ar = self.client[t].scatter('x', x, block=False, track=True)
208 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
196 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
209 self.assertEquals(ar.sent, ar._tracker.done)
197 self.assertEquals(ar.sent, ar._tracker.done)
210 ar._tracker.wait()
198 ar._tracker.wait()
211 self.assertTrue(ar.sent)
199 self.assertTrue(ar.sent)
212 ar.get()
200 ar.get()
213
201
214 def test_remote_reference(self):
202 def test_remote_reference(self):
215 v = self.client[-1]
203 v = self.client[-1]
216 v['a'] = 123
204 v['a'] = 123
217 ra = pmod.Reference('a')
205 ra = pmod.Reference('a')
218 b = v.apply_sync(lambda x: x, ra)
206 b = v.apply_sync(lambda x: x, ra)
219 self.assertEquals(b, 123)
207 self.assertEquals(b, 123)
220
208
221
209
222 def test_scatter_gather(self):
210 def test_scatter_gather(self):
223 view = self.client[:]
211 view = self.client[:]
224 seq1 = range(16)
212 seq1 = range(16)
225 view.scatter('a', seq1)
213 view.scatter('a', seq1)
226 seq2 = view.gather('a', block=True)
214 seq2 = view.gather('a', block=True)
227 self.assertEquals(seq2, seq1)
215 self.assertEquals(seq2, seq1)
228 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
216 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
229
217
230 @skip_without('numpy')
218 @skip_without('numpy')
231 def test_scatter_gather_numpy(self):
219 def test_scatter_gather_numpy(self):
232 import numpy
220 import numpy
233 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
221 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
234 view = self.client[:]
222 view = self.client[:]
235 a = numpy.arange(64)
223 a = numpy.arange(64)
236 view.scatter('a', a)
224 view.scatter('a', a)
237 b = view.gather('a', block=True)
225 b = view.gather('a', block=True)
238 assert_array_equal(b, a)
226 assert_array_equal(b, a)
239
227
240 def test_map(self):
228 def test_map(self):
241 view = self.client[:]
229 view = self.client[:]
242 def f(x):
230 def f(x):
243 return x**2
231 return x**2
244 data = range(16)
232 data = range(16)
245 r = view.map_sync(f, data)
233 r = view.map_sync(f, data)
246 self.assertEquals(r, map(f, data))
234 self.assertEquals(r, map(f, data))
247
235
248 def test_scatterGatherNonblocking(self):
236 def test_scatterGatherNonblocking(self):
249 data = range(16)
237 data = range(16)
250 view = self.client[:]
238 view = self.client[:]
251 view.scatter('a', data, block=False)
239 view.scatter('a', data, block=False)
252 ar = view.gather('a', block=False)
240 ar = view.gather('a', block=False)
253 self.assertEquals(ar.get(), data)
241 self.assertEquals(ar.get(), data)
254
242
255 @skip_without('numpy')
243 @skip_without('numpy')
256 def test_scatter_gather_numpy_nonblocking(self):
244 def test_scatter_gather_numpy_nonblocking(self):
257 import numpy
245 import numpy
258 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
246 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
259 a = numpy.arange(64)
247 a = numpy.arange(64)
260 view = self.client[:]
248 view = self.client[:]
261 ar = view.scatter('a', a, block=False)
249 ar = view.scatter('a', a, block=False)
262 self.assertTrue(isinstance(ar, AsyncResult))
250 self.assertTrue(isinstance(ar, AsyncResult))
263 amr = view.gather('a', block=False)
251 amr = view.gather('a', block=False)
264 self.assertTrue(isinstance(amr, AsyncMapResult))
252 self.assertTrue(isinstance(amr, AsyncMapResult))
265 assert_array_equal(amr.get(), a)
253 assert_array_equal(amr.get(), a)
266
254
267 def test_execute(self):
255 def test_execute(self):
268 view = self.client[:]
256 view = self.client[:]
269 # self.client.debug=True
257 # self.client.debug=True
270 execute = view.execute
258 execute = view.execute
271 ar = execute('c=30', block=False)
259 ar = execute('c=30', block=False)
272 self.assertTrue(isinstance(ar, AsyncResult))
260 self.assertTrue(isinstance(ar, AsyncResult))
273 ar = execute('d=[0,1,2]', block=False)
261 ar = execute('d=[0,1,2]', block=False)
274 self.client.wait(ar, 1)
262 self.client.wait(ar, 1)
275 self.assertEquals(len(ar.get()), len(self.client))
263 self.assertEquals(len(ar.get()), len(self.client))
276 for c in view['c']:
264 for c in view['c']:
277 self.assertEquals(c, 30)
265 self.assertEquals(c, 30)
278
266
279 def test_abort(self):
267 def test_abort(self):
280 view = self.client[-1]
268 view = self.client[-1]
281 ar = view.execute('import time; time.sleep(0.25)', block=False)
269 ar = view.execute('import time; time.sleep(0.25)', block=False)
282 ar2 = view.apply_async(lambda : 2)
270 ar2 = view.apply_async(lambda : 2)
283 ar3 = view.apply_async(lambda : 3)
271 ar3 = view.apply_async(lambda : 3)
284 view.abort(ar2)
272 view.abort(ar2)
285 view.abort(ar3.msg_ids)
273 view.abort(ar3.msg_ids)
286 self.assertRaises(error.TaskAborted, ar2.get)
274 self.assertRaises(error.TaskAborted, ar2.get)
287 self.assertRaises(error.TaskAborted, ar3.get)
275 self.assertRaises(error.TaskAborted, ar3.get)
288
276
289 def test_temp_flags(self):
277 def test_temp_flags(self):
290 view = self.client[-1]
278 view = self.client[-1]
291 view.block=True
279 view.block=True
292 with view.temp_flags(block=False):
280 with view.temp_flags(block=False):
293 self.assertFalse(view.block)
281 self.assertFalse(view.block)
294 self.assertTrue(view.block)
282 self.assertTrue(view.block)
295
283
296 def test_importer(self):
284 def test_importer(self):
297 view = self.client[-1]
285 view = self.client[-1]
298 view.clear(block=True)
286 view.clear(block=True)
299 with view.importer:
287 with view.importer:
300 import re
288 import re
301
289
302 @interactive
290 @interactive
303 def findall(pat, s):
291 def findall(pat, s):
304 # this globals() step isn't necessary in real code
292 # this globals() step isn't necessary in real code
305 # only to prevent a closure in the test
293 # only to prevent a closure in the test
306 re = globals()['re']
294 re = globals()['re']
307 return re.findall(pat, s)
295 return re.findall(pat, s)
308
296
309 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
297 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
310
298
    # parallel magic tests
313 def test_magic_px_blocking(self):
301 def test_magic_px_blocking(self):
314 ip = get_ipython()
302 ip = get_ipython()
315 v = self.client[-1]
303 v = self.client[-1]
316 v.activate()
304 v.activate()
317 v.block=True
305 v.block=True
318
306
319 ip.magic_px('a=5')
307 ip.magic_px('a=5')
320 self.assertEquals(v['a'], 5)
308 self.assertEquals(v['a'], 5)
321 ip.magic_px('a=10')
309 ip.magic_px('a=10')
322 self.assertEquals(v['a'], 10)
310 self.assertEquals(v['a'], 10)
323 sio = StringIO()
311 sio = StringIO()
324 savestdout = sys.stdout
312 savestdout = sys.stdout
325 sys.stdout = sio
313 sys.stdout = sio
326 ip.magic_px('print a')
314 ip.magic_px('print a')
327 sys.stdout = savestdout
315 sys.stdout = savestdout
328 sio.read()
316 sio.read()
329 self.assertTrue('[stdout:%i]'%v.targets in sio.buf)
317 self.assertTrue('[stdout:%i]'%v.targets in sio.buf)
330 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
318 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
331
319
332 def test_magic_px_nonblocking(self):
320 def test_magic_px_nonblocking(self):
333 ip = get_ipython()
321 ip = get_ipython()
334 v = self.client[-1]
322 v = self.client[-1]
335 v.activate()
323 v.activate()
336 v.block=False
324 v.block=False
337
325
338 ip.magic_px('a=5')
326 ip.magic_px('a=5')
339 self.assertEquals(v['a'], 5)
327 self.assertEquals(v['a'], 5)
340 ip.magic_px('a=10')
328 ip.magic_px('a=10')
341 self.assertEquals(v['a'], 10)
329 self.assertEquals(v['a'], 10)
342 sio = StringIO()
330 sio = StringIO()
343 savestdout = sys.stdout
331 savestdout = sys.stdout
344 sys.stdout = sio
332 sys.stdout = sio
345 ip.magic_px('print a')
333 ip.magic_px('print a')
346 sys.stdout = savestdout
334 sys.stdout = savestdout
347 sio.read()
335 sio.read()
348 self.assertFalse('[stdout:%i]'%v.targets in sio.buf)
336 self.assertFalse('[stdout:%i]'%v.targets in sio.buf)
349 ip.magic_px('1/0')
337 ip.magic_px('1/0')
350 ar = v.get_result(-1)
338 ar = v.get_result(-1)
351 self.assertRaisesRemote(ZeroDivisionError, ar.get)
339 self.assertRaisesRemote(ZeroDivisionError, ar.get)
352
340
353 def test_magic_autopx_blocking(self):
341 def test_magic_autopx_blocking(self):
354 ip = get_ipython()
342 ip = get_ipython()
355 v = self.client[-1]
343 v = self.client[-1]
356 v.activate()
344 v.activate()
357 v.block=True
345 v.block=True
358
346
359 sio = StringIO()
347 sio = StringIO()
360 savestdout = sys.stdout
348 savestdout = sys.stdout
361 sys.stdout = sio
349 sys.stdout = sio
362 ip.magic_autopx()
350 ip.magic_autopx()
363 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
351 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
364 ip.run_cell('print b')
352 ip.run_cell('print b')
365 ip.run_cell("b/c")
353 ip.run_cell("b/c")
366 ip.run_code(compile('b*=2', '', 'single'))
354 ip.run_code(compile('b*=2', '', 'single'))
367 ip.magic_autopx()
355 ip.magic_autopx()
368 sys.stdout = savestdout
356 sys.stdout = savestdout
369 sio.read()
357 sio.read()
370 output = sio.buf.strip()
358 output = sio.buf.strip()
371 self.assertTrue(output.startswith('%autopx enabled'))
359 self.assertTrue(output.startswith('%autopx enabled'))
372 self.assertTrue(output.endswith('%autopx disabled'))
360 self.assertTrue(output.endswith('%autopx disabled'))
373 self.assertTrue('RemoteError: ZeroDivisionError' in output)
361 self.assertTrue('RemoteError: ZeroDivisionError' in output)
374 ar = v.get_result(-2)
362 ar = v.get_result(-2)
375 self.assertEquals(v['a'], 5)
363 self.assertEquals(v['a'], 5)
376 self.assertEquals(v['b'], 20)
364 self.assertEquals(v['b'], 20)
377 self.assertRaisesRemote(ZeroDivisionError, ar.get)
365 self.assertRaisesRemote(ZeroDivisionError, ar.get)
378
366
379 def test_magic_autopx_nonblocking(self):
367 def test_magic_autopx_nonblocking(self):
380 ip = get_ipython()
368 ip = get_ipython()
381 v = self.client[-1]
369 v = self.client[-1]
382 v.activate()
370 v.activate()
383 v.block=False
371 v.block=False
384
372
385 sio = StringIO()
373 sio = StringIO()
386 savestdout = sys.stdout
374 savestdout = sys.stdout
387 sys.stdout = sio
375 sys.stdout = sio
388 ip.magic_autopx()
376 ip.magic_autopx()
389 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
377 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
390 ip.run_cell('print b')
378 ip.run_cell('print b')
391 ip.run_cell("b/c")
379 ip.run_cell("b/c")
392 ip.run_code(compile('b*=2', '', 'single'))
380 ip.run_code(compile('b*=2', '', 'single'))
393 ip.magic_autopx()
381 ip.magic_autopx()
394 sys.stdout = savestdout
382 sys.stdout = savestdout
395 sio.read()
383 sio.read()
396 output = sio.buf.strip()
384 output = sio.buf.strip()
397 self.assertTrue(output.startswith('%autopx enabled'))
385 self.assertTrue(output.startswith('%autopx enabled'))
398 self.assertTrue(output.endswith('%autopx disabled'))
386 self.assertTrue(output.endswith('%autopx disabled'))
399 self.assertFalse('ZeroDivisionError' in output)
387 self.assertFalse('ZeroDivisionError' in output)
400 ar = v.get_result(-2)
388 ar = v.get_result(-2)
401 self.assertEquals(v['a'], 5)
389 self.assertEquals(v['a'], 5)
402 self.assertEquals(v['b'], 20)
390 self.assertEquals(v['b'], 20)
403 self.assertRaisesRemote(ZeroDivisionError, ar.get)
391 self.assertRaisesRemote(ZeroDivisionError, ar.get)
404
392
405 def test_magic_result(self):
393 def test_magic_result(self):
406 ip = get_ipython()
394 ip = get_ipython()
407 v = self.client[-1]
395 v = self.client[-1]
408 v.activate()
396 v.activate()
409 v['a'] = 111
397 v['a'] = 111
410 ra = v['a']
398 ra = v['a']
411
399
412 ar = ip.magic_result()
400 ar = ip.magic_result()
413 self.assertEquals(ar.msg_ids, [v.history[-1]])
401 self.assertEquals(ar.msg_ids, [v.history[-1]])
414 self.assertEquals(ar.get(), 111)
402 self.assertEquals(ar.get(), 111)
415 ar = ip.magic_result('-2')
403 ar = ip.magic_result('-2')
416 self.assertEquals(ar.msg_ids, [v.history[-2]])
404 self.assertEquals(ar.msg_ids, [v.history[-2]])
417
405
418 def test_unicode_execute(self):
406 def test_unicode_execute(self):
419 """test executing unicode strings"""
407 """test executing unicode strings"""
420 v = self.client[-1]
408 v = self.client[-1]
421 v.block=True
409 v.block=True
422 code=u"a=u'é'"
410 code=u"a=u'é'"
423 v.execute(code)
411 v.execute(code)
424 self.assertEquals(v['a'], u'é')
412 self.assertEquals(v['a'], u'é')
425
413
426 def test_unicode_apply_result(self):
414 def test_unicode_apply_result(self):
427 """test unicode apply results"""
415 """test unicode apply results"""
428 v = self.client[-1]
416 v = self.client[-1]
429 r = v.apply_sync(lambda : u'é')
417 r = v.apply_sync(lambda : u'é')
430 self.assertEquals(r, u'é')
418 self.assertEquals(r, u'é')
431
419
432 def test_unicode_apply_arg(self):
420 def test_unicode_apply_arg(self):
433 """test passing unicode arguments to apply"""
421 """test passing unicode arguments to apply"""
434 v = self.client[-1]
422 v = self.client[-1]
435
423
436 @interactive
424 @interactive
437 def check_unicode(a, check):
425 def check_unicode(a, check):
438 assert isinstance(a, unicode), "%r is not unicode"%a
426 assert isinstance(a, unicode), "%r is not unicode"%a
439 assert isinstance(check, bytes), "%r is not bytes"%check
427 assert isinstance(check, bytes), "%r is not bytes"%check
440 assert a.encode('utf8') == check, "%s != %s"%(a,check)
428 assert a.encode('utf8') == check, "%s != %s"%(a,check)
441
429
442 for s in [ u'é', u'ßø®∫','asdf'.decode() ]:
430 for s in [ u'é', u'ßø®∫','asdf'.decode() ]:
443 try:
431 try:
444 v.apply_sync(check_unicode, s, s.encode('utf8'))
432 v.apply_sync(check_unicode, s, s.encode('utf8'))
445 except error.RemoteError as e:
433 except error.RemoteError as e:
446 if e.ename == 'AssertionError':
434 if e.ename == 'AssertionError':
447 self.fail(e.evalue)
435 self.fail(e.evalue)
448 else:
436 else:
449 raise e
437 raise e
450
438
451
439
452
440
General Comments 0
You need to be logged in to leave comments. Login now