diff --git a/IPython/parallel/client/view.py b/IPython/parallel/client/view.py
index 647729d..2235941 100644
--- a/IPython/parallel/client/view.py
+++ b/IPython/parallel/client/view.py
@@ -19,7 +19,7 @@ from types import ModuleType
 import zmq
 
 from IPython.testing import decorators as testdec
-from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat
+from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat, CInt
 
 from IPython.external.decorator import decorator
 
@@ -791,9 +791,10 @@ class LoadBalancedView(View):
     follow=Any()
     after=Any()
     timeout=CFloat()
+    retries = CInt(0)
 
     _task_scheme = Any()
-    _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout'])
+    _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
 
     def __init__(self, client=None, socket=None, **flags):
         super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
@@ -851,7 +852,7 @@ class LoadBalancedView(View):
             whether to create a MessageTracker to allow the user to
             safely edit after arrays and buffers during non-copying
             sends.
-        #
+
         after : Dependency or collection of msg_ids
             Only for load-balanced execution (targets=None)
             Specify a list of msg_ids as a time-based dependency.
@@ -869,6 +870,9 @@ class LoadBalancedView(View):
             Specify an amount of time (in seconds) for the scheduler to
             wait for dependencies to be met before failing with a
             DependencyTimeout.
+
+        retries : int
+            Number of times a task will be retried on failure.
         """
 
         super(LoadBalancedView, self).set_flags(**kwargs)
@@ -892,7 +896,7 @@ class LoadBalancedView(View):
     @save_ids
     def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
                         after=None, follow=None, timeout=None,
-                        targets=None):
+                        targets=None, retries=None):
         """calls f(*args, **kwargs) on a remote engine, returning the result.
 
         This method temporarily sets all of `apply`'s flags for a single call.
@@ -933,10 +937,11 @@ class LoadBalancedView(View):
             raise RuntimeError(msg)
 
         if self._task_scheme == 'pure':
-            # pure zmq scheme doesn't support dependencies
-            msg = "Pure ZMQ scheduler doesn't support dependencies"
-            if (follow or after):
-                # hard fail on DAG dependencies
+            # pure zmq scheme doesn't support extra features
+            msg = ("Pure ZMQ scheduler doesn't support the following flags: "
+                "follow, after, retries, targets, timeout")
+            if (follow or after or retries or targets or timeout):
+                # hard fail on Scheduler flags
                 raise RuntimeError(msg)
             if isinstance(f, dependent):
                 # soft warn on functional dependencies
@@ -948,10 +953,14 @@ class LoadBalancedView(View):
         block = self.block if block is None else block
         track = self.track if track is None else track
         after = self.after if after is None else after
+        retries = self.retries if retries is None else retries
         follow = self.follow if follow is None else follow
         timeout = self.timeout if timeout is None else timeout
         targets = self.targets if targets is None else targets
 
+        if not isinstance(retries, int):
+            raise TypeError('retries must be int, not %r'%type(retries))
+
         if targets is None:
             idents = []
         else:
@@ -959,7 +968,7 @@ class LoadBalancedView(View):
         after = self._render_dependency(after)
         follow = self._render_dependency(follow)
 
-        subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
+        subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
 
         msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
                                 subheader=subheader)
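
The new `retries` flag rides alongside `follow`, `after`, and `timeout`: it is a
per-view trait that can also be overridden per call, and it names the number of
times the scheduler will resubmit a failed task before relaying the failure to
the client. A minimal usage sketch (a hypothetical session; the `flaky` task and
the profile name are illustrative, and a running controller and engines are
assumed):

    from IPython.parallel import Client

    rc = Client(profile='iptest')
    view = rc.load_balanced_view()

    def flaky():
        # illustrative task that fails nondeterministically
        import random
        assert random.random() > 0.3

    # persistent flag: tasks submitted via this view may be retried twice
    view.retries = 2
    ar = view.apply_async(flaky)

    # or per call, via the temp_flags context manager used in the new tests
    with view.temp_flags(retries=5, timeout=10.):
        result = view.apply_sync(flaky)
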
diff --git a/IPython/parallel/controller/scheduler.py b/IPython/parallel/controller/scheduler.py
index 05da905..147700c 100644
--- a/IPython/parallel/controller/scheduler.py
+++ b/IPython/parallel/controller/scheduler.py
@@ -137,6 +137,7 @@ class TaskScheduler(SessionFactory):
 
     # internals:
     graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
+    retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
     # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
     depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
     pending = Dict() # dict by engine_uuid of submitted tasks
@@ -205,6 +206,8 @@ class TaskScheduler(SessionFactory):
         self.pending[uid] = {}
         if len(self.targets) == 1:
             self.resume_receiving()
+        # rescan the graph:
+        self.update_graph(None)
 
     def _unregister_engine(self, uid):
         """Existing engine with ident `uid` became unavailable."""
@@ -215,11 +218,11 @@ class TaskScheduler(SessionFactory):
         # handle any potentially finished tasks:
         self.engine_stream.flush()
 
-        self.completed.pop(uid)
-        self.failed.pop(uid)
-        # don't pop destinations, because it might be used later
+        # don't pop destinations, because they might be used later
         # map(self.destinations.pop, self.completed.pop(uid))
         # map(self.destinations.pop, self.failed.pop(uid))
+
+        # prevent this engine from receiving work
         idx = self.targets.index(uid)
         self.targets.pop(idx)
         self.loads.pop(idx)
@@ -229,28 +232,40 @@ class TaskScheduler(SessionFactory):
         if self.pending[uid]:
             dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
             dc.start()
+        else:
+            self.completed.pop(uid)
+            self.failed.pop(uid)
+
 
     @logged
     def handle_stranded_tasks(self, engine):
         """Deal with jobs resident in an engine that died."""
-        lost = self.pending.pop(engine)
-
-        for msg_id, (raw_msg, targets, MET, follow, timeout) in lost.iteritems():
-            self.all_failed.add(msg_id)
-            self.all_done.add(msg_id)
+        lost = self.pending[engine]
+        for msg_id in lost.keys():
+            if msg_id not in self.pending[engine]:
+                # prevent double-handling of messages
+                continue
+
+            raw_msg = lost[msg_id][0]
+            idents,msg = self.session.feed_identities(raw_msg, copy=False)
             msg = self.session.unpack_message(msg, copy=False, content=False)
             parent = msg['header']
-            idents = [idents[0],engine]+idents[1:]
-            # print (idents)
+            idents = [engine, idents[0]]
+
+            # build fake error reply
             try:
                 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
             except:
                 content = error.wrap_exception()
-            msg = self.session.send(self.client_stream, 'apply_reply', content,
-                                parent=parent, ident=idents)
-            self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
-            self.update_graph(msg_id)
+            msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
+            raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
+            # and dispatch it
+            self.dispatch_result(raw_reply)
+
+        # finally scrub completed/failed lists
+        self.completed.pop(engine)
+        self.failed.pop(engine)
 
 
     #-----------------------------------------------------------------------
@@ -277,6 +292,8 @@ class TaskScheduler(SessionFactory):
 
         # targets
         targets = set(header.get('targets', []))
+        retries = header.get('retries', 0)
+        self.retries[msg_id] = retries
 
         # time dependencies
         after = Dependency(header.get('after', []))
@@ -315,7 +332,9 @@ class TaskScheduler(SessionFactory):
             # time deps already met, try to run
             if not self.maybe_run(msg_id, *args):
                 # can't run yet
-                self.save_unmet(msg_id, *args)
+                if msg_id not in self.all_failed:
+                    # could have failed as unreachable
+                    self.save_unmet(msg_id, *args)
         else:
             self.save_unmet(msg_id, *args)
 
@@ -328,7 +347,7 @@ class TaskScheduler(SessionFactory):
         if msg_id in self.depending:
             raw,after,targets,follow,timeout = self.depending[msg_id]
             if timeout and timeout < now:
-                self.fail_unreachable(msg_id, timeout=True)
+                self.fail_unreachable(msg_id, error.TaskTimeout)
 
     @logged
     def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
@@ -369,7 +388,7 @@ class TaskScheduler(SessionFactory):
         # we need a can_run filter
         def can_run(idx):
             # check hwm
-            if self.loads[idx] == self.hwm:
+            if self.hwm and self.loads[idx] == self.hwm:
                 return False
             target = self.targets[idx]
             # check blacklist
@@ -382,6 +401,7 @@ class TaskScheduler(SessionFactory):
             return follow.check(self.completed[target], self.failed[target])
 
         indices = filter(can_run, range(len(self.targets)))
+
         if not indices:
             # couldn't run
             if follow.all:
@@ -395,12 +415,14 @@ class TaskScheduler(SessionFactory):
                 for m in follow.intersection(relevant):
                     dests.add(self.destinations[m])
                 if len(dests) > 1:
+                    self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
                     self.fail_unreachable(msg_id)
                     return False
             if targets:
                 # check blacklist+targets for impossibility
                 targets.difference_update(blacklist)
                 if not targets or not targets.intersection(self.targets):
+                    self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
                     self.fail_unreachable(msg_id)
                     return False
             return False
@@ -454,20 +476,34 @@ class TaskScheduler(SessionFactory):
             idents,msg = self.session.feed_identities(raw_msg, copy=False)
             msg = self.session.unpack_message(msg, content=False, copy=False)
             engine = idents[0]
-            idx = self.targets.index(engine)
-            self.finish_job(idx)
+            try:
+                idx = self.targets.index(engine)
+            except ValueError:
+                pass # skip load-update for dead engines
+            else:
+                self.finish_job(idx)
         except Exception:
             self.log.error("task::Invaid result: %s"%raw_msg, exc_info=True)
             return
 
         header = msg['header']
+        parent = msg['parent_header']
         if header.get('dependencies_met', True):
             success = (header['status'] == 'ok')
-            self.handle_result(idents, msg['parent_header'], raw_msg, success)
-            # send to Hub monitor
-            self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
+            msg_id = parent['msg_id']
+            retries = self.retries[msg_id]
+            if not success and retries > 0:
+                # failed
+                self.retries[msg_id] = retries - 1
+                self.handle_unmet_dependency(idents, parent)
+            else:
+                del self.retries[msg_id]
+                # relay to client and update graph
+                self.handle_result(idents, parent, raw_msg, success)
+                # send to Hub monitor
+                self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
         else:
-            self.handle_unmet_dependency(idents, msg['parent_header'])
+            self.handle_unmet_dependency(idents, parent)
 
     @logged
     def handle_result(self, idents, parent, raw_msg, success=True):
@@ -511,13 +547,19 @@ class TaskScheduler(SessionFactory):
             self.depending[msg_id] = args
             self.fail_unreachable(msg_id)
         elif not self.maybe_run(msg_id, *args):
-            # resubmit failed, put it back in our dependency tree
-            self.save_unmet(msg_id, *args)
+            # resubmit failed
+            if msg_id not in self.all_failed:
+                # put it back in our dependency tree
+                self.save_unmet(msg_id, *args)
 
         if self.hwm:
-            idx = self.targets.index(engine)
-            if self.loads[idx] == self.hwm-1:
-                self.update_graph(None)
+            try:
+                idx = self.targets.index(engine)
+            except ValueError:
+                pass # skip load-update for dead engines
+            else:
+                if self.loads[idx] == self.hwm-1:
+                    self.update_graph(None)
 
 
@@ -526,7 +568,7 @@ class TaskScheduler(SessionFactory):
         """dep_id just finished. Update our dependency graph and submit any
         jobs that just became runable.
 
-        Called with dep_id=None to update graph for hwm, but without finishing
+        Called with dep_id=None to update entire graph for hwm, but without finishing
         a task.
         """
         # print ("\n\n***********")
@@ -538,9 +580,11 @@ class TaskScheduler(SessionFactory):
        # print ("\n\n***********\n\n")
         # update any jobs that depended on the dependency
         jobs = self.graph.pop(dep_id, [])
-        # if we have HWM and an engine just become no longer full
-        # recheck *all* jobs:
-        if self.hwm and any( [ load==self.hwm-1 for load in self.loads]):
+
+        # recheck *all* jobs if
+        # a) we have HWM and an engine just became no longer full,
+        # or b) dep_id was given as None
+        if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
             jobs = self.depending.keys()
 
         for msg_id in jobs:
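
Stripped of the messaging details, the scheduler-side accounting added above is
a decrement-and-resubmit pattern: one budget per msg_id, recorded at submission,
decremented on each failed reply, with the reply relayed to the client only once
the task succeeds or the budget is exhausted. A simplified model of that flow
(the `resubmit` and `relay` callables are illustrative stand-ins for
`handle_unmet_dependency` and `handle_result`):

    class RetryTracker(object):
        """Simplified sketch of TaskScheduler's retry bookkeeping."""

        def __init__(self):
            self.retries = {}  # msg_id -> retries remaining (non-negative ints)

        def submit(self, msg_id, retries=0):
            # mirrors dispatch_submission: record the budget on arrival
            self.retries[msg_id] = retries

        def handle_reply(self, msg_id, success, resubmit, relay):
            # mirrors dispatch_result: retry failures while budget remains
            remaining = self.retries[msg_id]
            if not success and remaining > 0:
                self.retries[msg_id] = remaining - 1
                resubmit(msg_id)           # back through the scheduler
            else:
                del self.retries[msg_id]   # done with this task either way
                relay(msg_id, success)     # deliver the final reply to the client
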
diff --git a/IPython/parallel/tests/__init__.py b/IPython/parallel/tests/__init__.py
index 0b1e013..ae517ed 100644
--- a/IPython/parallel/tests/__init__.py
+++ b/IPython/parallel/tests/__init__.py
@@ -48,7 +48,7 @@ class TestProcessLauncher(LocalProcessLauncher):
 def setup():
     cp = TestProcessLauncher()
     cp.cmd_and_args = ipcontroller_cmd_argv + \
-                ['--profile', 'iptest', '--log-level', '99', '-r', '--usethreads']
+                ['--profile', 'iptest', '--log-level', '99', '-r']
     cp.start()
     launchers.append(cp)
     cluster_dir = os.path.join(get_ipython_dir(), 'cluster_iptest')
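
Dropping `--usethreads` here looks deliberate: with in-process threaded
schedulers the task queue falls back to the pure ZMQ scheme, which (per the
view.py hunk above) rejects follow, after, retries, targets, and timeout, so
the new LoadBalancedView tests presumably need the Python TaskScheduler
running in its own process.
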
diff --git a/IPython/parallel/tests/test_lbview.py b/IPython/parallel/tests/test_lbview.py
new file mode 100644
index 0000000..14a8211
--- /dev/null
+++ b/IPython/parallel/tests/test_lbview.py
@@ -0,0 +1,120 @@
+"""test LoadBalancedView objects"""
+# -*- coding: utf-8 -*-
+#-------------------------------------------------------------------------------
+#  Copyright (C) 2011 The IPython Development Team
+#
+#  Distributed under the terms of the BSD License. The full license is in
+#  the file COPYING, distributed as part of this software.
+#-------------------------------------------------------------------------------
+
+#-------------------------------------------------------------------------------
+# Imports
+#-------------------------------------------------------------------------------
+
+import sys
+import time
+
+import zmq
+
+from IPython import parallel as pmod
+from IPython.parallel import error
+
+from IPython.parallel.tests import add_engines
+
+from .clienttest import ClusterTestCase, crash, wait, skip_without
+
+def setup():
+    add_engines(3)
+
+class TestLoadBalancedView(ClusterTestCase):
+
+    def setUp(self):
+        ClusterTestCase.setUp(self)
+        self.view = self.client.load_balanced_view()
+
+    def test_z_crash_task(self):
+        """test graceful handling of engine death (balanced)"""
+        # self.add_engines(1)
+        ar = self.view.apply_async(crash)
+        self.assertRaisesRemote(error.EngineError, ar.get)
+        eid = ar.engine_id
+        tic = time.time()
+        while eid in self.client.ids and time.time()-tic < 5:
+            time.sleep(.01)
+            self.client.spin()
+        self.assertFalse(eid in self.client.ids, "Engine should have died")
+
+    def test_map(self):
+        def f(x):
+            return x**2
+        data = range(16)
+        r = self.view.map_sync(f, data)
+        self.assertEquals(r, map(f, data))
+
+    def test_abort(self):
+        view = self.view
+        ar = self.client[:].apply_async(time.sleep, .5)
+        ar2 = view.apply_async(lambda : 2)
+        ar3 = view.apply_async(lambda : 3)
+        view.abort(ar2)
+        view.abort(ar3.msg_ids)
+        self.assertRaises(error.TaskAborted, ar2.get)
+        self.assertRaises(error.TaskAborted, ar3.get)
+
+    def test_retries(self):
+        add_engines(3)
+        view = self.view
+        view.timeout = 1 # prevent hang if this doesn't behave
+        def fail():
+            assert False
+        for r in range(len(self.client)-1):
+            with view.temp_flags(retries=r):
+                self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
+
+        with view.temp_flags(retries=len(self.client), timeout=0.25):
+            self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
+
+    def test_invalid_dependency(self):
+        view = self.view
+        with view.temp_flags(after='12345'):
+            self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
+
+    def test_impossible_dependency(self):
+        if len(self.client) < 2:
+            add_engines(2)
+        view = self.client.load_balanced_view()
+        ar1 = view.apply_async(lambda : 1)
+        ar1.get()
+        e1 = ar1.engine_id
+        e2 = e1
+        while e2 == e1:
+            ar2 = view.apply_async(lambda : 1)
+            ar2.get()
+            e2 = ar2.engine_id
+
+        with view.temp_flags(follow=[ar1, ar2]):
+            self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
+
+
+    def test_follow(self):
+        ar = self.view.apply_async(lambda : 1)
+        ar.get()
+        ars = []
+        first_id = ar.engine_id
+
+        self.view.follow = ar
+        for i in range(5):
+            ars.append(self.view.apply_async(lambda : 1))
+        self.view.wait(ars)
+        for ar in ars:
+            self.assertEquals(ar.engine_id, first_id)
+
+    def test_after(self):
+        view = self.view
+        ar = view.apply_async(time.sleep, 0.5)
+        with view.temp_flags(after=ar):
+            ar2 = view.apply_async(lambda : 1)
+
+        ar.wait()
+        ar2.wait()
+        self.assertTrue(ar2.started > ar.completed)
diff --git a/IPython/parallel/tests/test_view.py b/IPython/parallel/tests/test_view.py
index 8d0f3f7..8eeb596 100644
--- a/IPython/parallel/tests/test_view.py
+++ b/IPython/parallel/tests/test_view.py
@@ -21,7 +21,7 @@ import zmq
 from IPython import parallel as pmod
 from IPython.parallel import error
 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
-from IPython.parallel import LoadBalancedView, DirectView
+from IPython.parallel import DirectView
 from IPython.parallel.util import interactive
 
 from IPython.parallel.tests import add_engines
@@ -33,18 +33,6 @@ def setup():
 
 class TestView(ClusterTestCase):
 
-    def test_z_crash_task(self):
-        """test graceful handling of engine death (balanced)"""
-        # self.add_engines(1)
-        ar = self.client[-1].apply_async(crash)
-        self.assertRaisesRemote(error.EngineError, ar.get)
-        eid = ar.engine_id
-        tic = time.time()
-        while eid in self.client.ids and time.time()-tic < 5:
-            time.sleep(.01)
-            self.client.spin()
-        self.assertFalse(eid in self.client.ids, "Engine should have died")
-
     def test_z_crash_mux(self):
         """test graceful handling of engine death (direct)"""
         # self.add_engines(1)