##// END OF EJS Templates
add retries flag to LoadBalancedView...
MinRK -
Show More
@@ -0,0 +1,120 b''
1 """test LoadBalancedView objects"""
2 # -*- coding: utf-8 -*-
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
5 #
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
9
10 #-------------------------------------------------------------------------------
11 # Imports
12 #-------------------------------------------------------------------------------
13
14 import sys
15 import time
16
17 import zmq
18
19 from IPython import parallel as pmod
20 from IPython.parallel import error
21
22 from IPython.parallel.tests import add_engines
23
24 from .clienttest import ClusterTestCase, crash, wait, skip_without
25
def setup():
    # nose module-level setup: make sure at least 3 engines are running
    # before any test in this module executes (test_retries needs >= 3).
    add_engines(3)
28
class TestLoadBalancedView(ClusterTestCase):
    """Tests for LoadBalancedView: crash handling, map, abort, retries,
    and the dependency flags (after/follow/timeout)."""

    def setUp(self):
        ClusterTestCase.setUp(self)
        self.view = self.client.load_balanced_view()

    def test_z_crash_task(self):
        """test graceful handling of engine death (balanced)"""
        # self.add_engines(1)
        ar = self.view.apply_async(crash)
        self.assertRaisesRemote(error.EngineError, ar.get)
        eid = ar.engine_id
        # give the dead engine up to 5s to disappear from the client's ids
        deadline = time.time() + 5
        while eid in self.client.ids and time.time() < deadline:
            time.sleep(.01)
            self.client.spin()
        self.assertFalse(eid in self.client.ids, "Engine should have died")

    def test_map(self):
        """map_sync should agree with a local map of the same function"""
        def f(x):
            return x**2
        data = range(16)
        remote = self.view.map_sync(f, data)
        local = [f(x) for x in data]
        self.assertEquals(remote, local)

    def test_abort(self):
        """aborted tasks raise TaskAborted on get()"""
        view = self.view
        # keep all engines busy so the next two tasks stay queued
        busy = self.client[:].apply_async(time.sleep, .5)
        pending_a = view.apply_async(lambda : 2)
        pending_b = view.apply_async(lambda : 3)
        # abort accepts an AsyncResult or a list of msg_ids
        view.abort(pending_a)
        view.abort(pending_b.msg_ids)
        self.assertRaises(error.TaskAborted, pending_a.get)
        self.assertRaises(error.TaskAborted, pending_b.get)

    def test_retries(self):
        """tasks that always fail exhaust their retries, then raise"""
        add_engines(3)
        view = self.view
        view.timeout = 1 # prevent hang if this doesn't behave
        def fail():
            assert False
        # with fewer retries than engines, the failure propagates
        for nretries in range(len(self.client)-1):
            with view.temp_flags(retries=nretries):
                self.assertRaisesRemote(AssertionError, view.apply_sync, fail)

        # more retries than engines + short timeout -> TaskTimeout
        with view.temp_flags(retries=len(self.client), timeout=0.25):
            self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)

    def test_invalid_dependency(self):
        """a malformed `after` dependency is rejected remotely"""
        view = self.view
        with view.temp_flags(after='12345'):
            self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)

    def test_impossible_dependency(self):
        """following two results from different engines can't be satisfied"""
        if len(self.client) < 2:
            add_engines(2)
        view = self.client.load_balanced_view()
        first = view.apply_async(lambda : 1)
        first.get()
        # keep submitting until a task lands on a *different* engine
        other = first
        while other.engine_id == first.engine_id:
            other = view.apply_async(lambda : 1)
            other.get()

        with view.temp_flags(follow=[first, other]):
            self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)


    def test_follow(self):
        """tasks with a follow dependency all run on the same engine"""
        seed = self.view.apply_async(lambda : 1)
        seed.get()
        first_id = seed.engine_id

        self.view.follow = seed
        followers = [self.view.apply_async(lambda : 1) for i in range(5)]
        self.view.wait(followers)
        for ar in followers:
            self.assertEquals(ar.engine_id, first_id)

    def test_after(self):
        """a task with an `after` dependency starts after it completes"""
        view = self.view
        sleeper = view.apply_async(time.sleep, 0.5)
        with view.temp_flags(after=sleeper):
            dependent_ar = view.apply_async(lambda : 1)

        sleeper.wait()
        dependent_ar.wait()
        self.assertTrue(dependent_ar.started > sleeper.completed)
@@ -19,7 +19,7 b' from types import ModuleType'
19 19 import zmq
20 20
21 21 from IPython.testing import decorators as testdec
22 from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat
22 from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat, CInt
23 23
24 24 from IPython.external.decorator import decorator
25 25
@@ -791,9 +791,10 b' class LoadBalancedView(View):'
791 791 follow=Any()
792 792 after=Any()
793 793 timeout=CFloat()
794 retries = CInt(0)
794 795
795 796 _task_scheme = Any()
796 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout'])
797 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
797 798
798 799 def __init__(self, client=None, socket=None, **flags):
799 800 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
@@ -851,7 +852,7 b' class LoadBalancedView(View):'
851 852 whether to create a MessageTracker to allow the user to
852 853 safely edit after arrays and buffers during non-copying
853 854 sends.
854 #
855
855 856 after : Dependency or collection of msg_ids
856 857 Only for load-balanced execution (targets=None)
857 858 Specify a list of msg_ids as a time-based dependency.
@@ -869,6 +870,9 b' class LoadBalancedView(View):'
869 870 Specify an amount of time (in seconds) for the scheduler to
870 871 wait for dependencies to be met before failing with a
871 872 DependencyTimeout.
873
874 retries : int
875 Number of times a task will be retried on failure.
872 876 """
873 877
874 878 super(LoadBalancedView, self).set_flags(**kwargs)
@@ -892,7 +896,7 b' class LoadBalancedView(View):'
892 896 @save_ids
893 897 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
894 898 after=None, follow=None, timeout=None,
895 targets=None):
899 targets=None, retries=None):
896 900 """calls f(*args, **kwargs) on a remote engine, returning the result.
897 901
898 902 This method temporarily sets all of `apply`'s flags for a single call.
@@ -933,10 +937,11 b' class LoadBalancedView(View):'
933 937 raise RuntimeError(msg)
934 938
935 939 if self._task_scheme == 'pure':
936 # pure zmq scheme doesn't support dependencies
937 msg = "Pure ZMQ scheduler doesn't support dependencies"
938 if (follow or after):
939 # hard fail on DAG dependencies
940 # pure zmq scheme doesn't support extra features
941 msg = "Pure ZMQ scheduler doesn't support the following flags:"
942 "follow, after, retries, targets, timeout"
943 if (follow or after or retries or targets or timeout):
944 # hard fail on Scheduler flags
940 945 raise RuntimeError(msg)
941 946 if isinstance(f, dependent):
942 947 # soft warn on functional dependencies
@@ -948,10 +953,14 b' class LoadBalancedView(View):'
948 953 block = self.block if block is None else block
949 954 track = self.track if track is None else track
950 955 after = self.after if after is None else after
956 retries = self.retries if retries is None else retries
951 957 follow = self.follow if follow is None else follow
952 958 timeout = self.timeout if timeout is None else timeout
953 959 targets = self.targets if targets is None else targets
954 960
961 if not isinstance(retries, int):
962 raise TypeError('retries must be int, not %r'%type(retries))
963
955 964 if targets is None:
956 965 idents = []
957 966 else:
@@ -959,7 +968,7 b' class LoadBalancedView(View):'
959 968
960 969 after = self._render_dependency(after)
961 970 follow = self._render_dependency(follow)
962 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
971 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
963 972
964 973 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
965 974 subheader=subheader)
@@ -137,6 +137,7 b' class TaskScheduler(SessionFactory):'
137 137
138 138 # internals:
139 139 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
140 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
140 141 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
141 142 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
142 143 pending = Dict() # dict by engine_uuid of submitted tasks
@@ -205,6 +206,8 b' class TaskScheduler(SessionFactory):'
205 206 self.pending[uid] = {}
206 207 if len(self.targets) == 1:
207 208 self.resume_receiving()
209 # rescan the graph:
210 self.update_graph(None)
208 211
209 212 def _unregister_engine(self, uid):
210 213 """Existing engine with ident `uid` became unavailable."""
@@ -215,11 +218,11 b' class TaskScheduler(SessionFactory):'
215 218 # handle any potentially finished tasks:
216 219 self.engine_stream.flush()
217 220
218 self.completed.pop(uid)
219 self.failed.pop(uid)
220 # don't pop destinations, because it might be used later
221 # don't pop destinations, because they might be used later
221 222 # map(self.destinations.pop, self.completed.pop(uid))
222 223 # map(self.destinations.pop, self.failed.pop(uid))
224
225 # prevent this engine from receiving work
223 226 idx = self.targets.index(uid)
224 227 self.targets.pop(idx)
225 228 self.loads.pop(idx)
@@ -229,28 +232,40 b' class TaskScheduler(SessionFactory):'
229 232 if self.pending[uid]:
230 233 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
231 234 dc.start()
235 else:
236 self.completed.pop(uid)
237 self.failed.pop(uid)
238
232 239
233 240 @logged
234 241 def handle_stranded_tasks(self, engine):
235 242 """Deal with jobs resident in an engine that died."""
236 lost = self.pending.pop(engine)
237
238 for msg_id, (raw_msg, targets, MET, follow, timeout) in lost.iteritems():
239 self.all_failed.add(msg_id)
240 self.all_done.add(msg_id)
243 lost = self.pending[engine]
244 for msg_id in lost.keys():
245 if msg_id not in self.pending[engine]:
246 # prevent double-handling of messages
247 continue
248
249 raw_msg = lost[msg_id][0]
250
241 251 idents,msg = self.session.feed_identities(raw_msg, copy=False)
242 252 msg = self.session.unpack_message(msg, copy=False, content=False)
243 253 parent = msg['header']
244 idents = [idents[0],engine]+idents[1:]
245 # print (idents)
254 idents = [engine, idents[0]]
255
256 # build fake error reply
246 257 try:
247 258 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
248 259 except:
249 260 content = error.wrap_exception()
250 msg = self.session.send(self.client_stream, 'apply_reply', content,
251 parent=parent, ident=idents)
252 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
253 self.update_graph(msg_id)
261 msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
262 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
263 # and dispatch it
264 self.dispatch_result(raw_reply)
265
266 # finally scrub completed/failed lists
267 self.completed.pop(engine)
268 self.failed.pop(engine)
254 269
255 270
256 271 #-----------------------------------------------------------------------
@@ -277,6 +292,8 b' class TaskScheduler(SessionFactory):'
277 292
278 293 # targets
279 294 targets = set(header.get('targets', []))
295 retries = header.get('retries', 0)
296 self.retries[msg_id] = retries
280 297
281 298 # time dependencies
282 299 after = Dependency(header.get('after', []))
@@ -315,7 +332,9 b' class TaskScheduler(SessionFactory):'
315 332 # time deps already met, try to run
316 333 if not self.maybe_run(msg_id, *args):
317 334 # can't run yet
318 self.save_unmet(msg_id, *args)
335 if msg_id not in self.all_failed:
336 # could have failed as unreachable
337 self.save_unmet(msg_id, *args)
319 338 else:
320 339 self.save_unmet(msg_id, *args)
321 340
@@ -328,7 +347,7 b' class TaskScheduler(SessionFactory):'
328 347 if msg_id in self.depending:
329 348 raw,after,targets,follow,timeout = self.depending[msg_id]
330 349 if timeout and timeout < now:
331 self.fail_unreachable(msg_id, timeout=True)
350 self.fail_unreachable(msg_id, error.TaskTimeout)
332 351
333 352 @logged
334 353 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
@@ -369,7 +388,7 b' class TaskScheduler(SessionFactory):'
369 388 # we need a can_run filter
370 389 def can_run(idx):
371 390 # check hwm
372 if self.loads[idx] == self.hwm:
391 if self.hwm and self.loads[idx] == self.hwm:
373 392 return False
374 393 target = self.targets[idx]
375 394 # check blacklist
@@ -382,6 +401,7 b' class TaskScheduler(SessionFactory):'
382 401 return follow.check(self.completed[target], self.failed[target])
383 402
384 403 indices = filter(can_run, range(len(self.targets)))
404
385 405 if not indices:
386 406 # couldn't run
387 407 if follow.all:
@@ -395,12 +415,14 b' class TaskScheduler(SessionFactory):'
395 415 for m in follow.intersection(relevant):
396 416 dests.add(self.destinations[m])
397 417 if len(dests) > 1:
418 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
398 419 self.fail_unreachable(msg_id)
399 420 return False
400 421 if targets:
401 422 # check blacklist+targets for impossibility
402 423 targets.difference_update(blacklist)
403 424 if not targets or not targets.intersection(self.targets):
425 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
404 426 self.fail_unreachable(msg_id)
405 427 return False
406 428 return False
@@ -454,20 +476,34 b' class TaskScheduler(SessionFactory):'
454 476 idents,msg = self.session.feed_identities(raw_msg, copy=False)
455 477 msg = self.session.unpack_message(msg, content=False, copy=False)
456 478 engine = idents[0]
457 idx = self.targets.index(engine)
458 self.finish_job(idx)
479 try:
480 idx = self.targets.index(engine)
481 except ValueError:
482 pass # skip load-update for dead engines
483 else:
484 self.finish_job(idx)
459 485 except Exception:
460 486 self.log.error("task::Invaid result: %s"%raw_msg, exc_info=True)
461 487 return
462 488
463 489 header = msg['header']
490 parent = msg['parent_header']
464 491 if header.get('dependencies_met', True):
465 492 success = (header['status'] == 'ok')
466 self.handle_result(idents, msg['parent_header'], raw_msg, success)
467 # send to Hub monitor
468 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
493 msg_id = parent['msg_id']
494 retries = self.retries[msg_id]
495 if not success and retries > 0:
496 # failed
497 self.retries[msg_id] = retries - 1
498 self.handle_unmet_dependency(idents, parent)
499 else:
500 del self.retries[msg_id]
501 # relay to client and update graph
502 self.handle_result(idents, parent, raw_msg, success)
503 # send to Hub monitor
504 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
469 505 else:
470 self.handle_unmet_dependency(idents, msg['parent_header'])
506 self.handle_unmet_dependency(idents, parent)
471 507
472 508 @logged
473 509 def handle_result(self, idents, parent, raw_msg, success=True):
@@ -511,13 +547,19 b' class TaskScheduler(SessionFactory):'
511 547 self.depending[msg_id] = args
512 548 self.fail_unreachable(msg_id)
513 549 elif not self.maybe_run(msg_id, *args):
514 # resubmit failed, put it back in our dependency tree
515 self.save_unmet(msg_id, *args)
550 # resubmit failed
551 if msg_id not in self.all_failed:
552 # put it back in our dependency tree
553 self.save_unmet(msg_id, *args)
516 554
517 555 if self.hwm:
518 idx = self.targets.index(engine)
519 if self.loads[idx] == self.hwm-1:
520 self.update_graph(None)
556 try:
557 idx = self.targets.index(engine)
558 except ValueError:
559 pass # skip load-update for dead engines
560 else:
561 if self.loads[idx] == self.hwm-1:
562 self.update_graph(None)
521 563
522 564
523 565
@@ -526,7 +568,7 b' class TaskScheduler(SessionFactory):'
526 568 """dep_id just finished. Update our dependency
527 569 graph and submit any jobs that just became runable.
528 570
529 Called with dep_id=None to update graph for hwm, but without finishing
571 Called with dep_id=None to update entire graph for hwm, but without finishing
530 572 a task.
531 573 """
532 574 # print ("\n\n***********")
@@ -538,9 +580,11 b' class TaskScheduler(SessionFactory):'
538 580 # print ("\n\n***********\n\n")
539 581 # update any jobs that depended on the dependency
540 582 jobs = self.graph.pop(dep_id, [])
541 # if we have HWM and an engine just become no longer full
542 # recheck *all* jobs:
543 if self.hwm and any( [ load==self.hwm-1 for load in self.loads]):
583
584 # recheck *all* jobs if
585 # a) we have HWM and an engine just become no longer full
586 # or b) dep_id was given as None
587 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
544 588 jobs = self.depending.keys()
545 589
546 590 for msg_id in jobs:
@@ -48,7 +48,7 b' class TestProcessLauncher(LocalProcessLauncher):'
48 48 def setup():
49 49 cp = TestProcessLauncher()
50 50 cp.cmd_and_args = ipcontroller_cmd_argv + \
51 ['--profile', 'iptest', '--log-level', '99', '-r', '--usethreads']
51 ['--profile', 'iptest', '--log-level', '99', '-r']
52 52 cp.start()
53 53 launchers.append(cp)
54 54 cluster_dir = os.path.join(get_ipython_dir(), 'cluster_iptest')
@@ -21,7 +21,7 b' import zmq'
21 21 from IPython import parallel as pmod
22 22 from IPython.parallel import error
23 23 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
24 from IPython.parallel import LoadBalancedView, DirectView
24 from IPython.parallel import DirectView
25 25 from IPython.parallel.util import interactive
26 26
27 27 from IPython.parallel.tests import add_engines
@@ -33,18 +33,6 b' def setup():'
33 33
34 34 class TestView(ClusterTestCase):
35 35
36 def test_z_crash_task(self):
37 """test graceful handling of engine death (balanced)"""
38 # self.add_engines(1)
39 ar = self.client[-1].apply_async(crash)
40 self.assertRaisesRemote(error.EngineError, ar.get)
41 eid = ar.engine_id
42 tic = time.time()
43 while eid in self.client.ids and time.time()-tic < 5:
44 time.sleep(.01)
45 self.client.spin()
46 self.assertFalse(eid in self.client.ids, "Engine should have died")
47
48 36 def test_z_crash_mux(self):
49 37 """test graceful handling of engine death (direct)"""
50 38 # self.add_engines(1)
General Comments 0
You need to be logged in to leave comments. Login now