##// END OF EJS Templates
preserve dependency attributes in scheduler...
MinRK -
Show More
@@ -1,697 +1,705 b''
1 """The Python scheduler for rich scheduling.
1 """The Python scheduler for rich scheduling.
2
2
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 Python Scheduler exists.
5 Python Scheduler exists.
6
6
7 Authors:
7 Authors:
8
8
9 * Min RK
9 * Min RK
10 """
10 """
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2010-2011 The IPython Development Team
12 # Copyright (C) 2010-2011 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 #----------------------------------------------------------------------
18 #----------------------------------------------------------------------
19 # Imports
19 # Imports
20 #----------------------------------------------------------------------
20 #----------------------------------------------------------------------
21
21
22 from __future__ import print_function
22 from __future__ import print_function
23
23
24 import logging
24 import logging
25 import sys
25 import sys
26
26
27 from datetime import datetime, timedelta
27 from datetime import datetime, timedelta
28 from random import randint, random
28 from random import randint, random
29 from types import FunctionType
29 from types import FunctionType
30
30
31 try:
31 try:
32 import numpy
32 import numpy
33 except ImportError:
33 except ImportError:
34 numpy = None
34 numpy = None
35
35
36 import zmq
36 import zmq
37 from zmq.eventloop import ioloop, zmqstream
37 from zmq.eventloop import ioloop, zmqstream
38
38
39 # local imports
39 # local imports
40 from IPython.external.decorator import decorator
40 from IPython.external.decorator import decorator
41 from IPython.config.application import Application
41 from IPython.config.application import Application
42 from IPython.config.loader import Config
42 from IPython.config.loader import Config
43 from IPython.utils.traitlets import Instance, Dict, List, Set, Int, Enum
43 from IPython.utils.traitlets import Instance, Dict, List, Set, Int, Enum
44
44
45 from IPython.parallel import error
45 from IPython.parallel import error
46 from IPython.parallel.factory import SessionFactory
46 from IPython.parallel.factory import SessionFactory
47 from IPython.parallel.util import connect_logger, local_logger
47 from IPython.parallel.util import connect_logger, local_logger
48
48
49 from .dependency import Dependency
49 from .dependency import Dependency
50
50
@decorator
def logged(f,self,*args,**kwargs):
    """Decorator: log a scheduler method call at DEBUG level, then invoke it.

    Logs the wrapped function's name and its positional/keyword arguments
    via ``self.log`` before delegating.  (Uses ``f.func_name``, so this is
    Python 2 specific.)
    """
    # print ("#--------------------")
    self.log.debug("scheduler::%s(*%s,**%s)", f.func_name, args, kwargs)
    # print ("#--")
    return f(self,*args, **kwargs)
57
57
58 #----------------------------------------------------------------------
58 #----------------------------------------------------------------------
59 # Chooser functions
59 # Chooser functions
60 #----------------------------------------------------------------------
60 #----------------------------------------------------------------------
61
61
def plainrandom(loads):
    """Choose an engine index uniformly at random.

    The load values themselves are ignored; only ``len(loads)`` matters.
    """
    return randint(0, len(loads) - 1)
66
66
def lru(loads):
    """Choose the least-recently-used engine: always index 0.

    ``loads`` is assumed to be kept in LRU order (oldest first), so the
    head of the list is the LRU engine.  The load values are never
    inspected.
    """
    return 0
75
75
def twobin(loads):
    """Draw two engine indices at random and return the LRU of the pair.

    Load values are ignored; ``loads`` is assumed LRU-ordered (oldest
    first), so the smaller of the two indices is the least recently used.
    """
    last = len(loads) - 1
    first_pick = randint(0, last)
    second_pick = randint(0, last)
    return min(first_pick, second_pick)
87
87
def weighted(loads):
    """Draw two indices with probability proportional to inverse load.

    Each draw picks index i with weight 1/(loads[i] + 1e-6) — so an idle
    engine (load 0) is ~1e6 times more likely than one with load 1 — and
    the less-loaded (larger-weight) of the two draws is returned.
    """
    # weight 0 a million times more than 1:
    weights = 1. / (1e-6 + numpy.array(loads))
    cumulative = weights.cumsum()
    total = cumulative[-1]

    def _draw():
        # inverse-CDF sampling: walk the cumulative weights until we
        # pass a uniform sample in [0, total)
        target = random() * total
        i = 0
        while cumulative[i] < target:
            i += 1
        return i

    idx = _draw()
    idy = _draw()
    return idy if weights[idy] > weights[idx] else idx
109
109
def leastload(loads):
    """Return the index of the minimum load.

    Ties go to the earliest occurrence; if ``loads`` has LRU ordering
    (oldest first), that is the least recently used of the tied engines.
    """
    return min(range(len(loads)), key=loads.__getitem__)
118
118
#---------------------------------------------------------------------
# Classes
#---------------------------------------------------------------------
# store empty default dependency:
# MET is a sentinel: an always-satisfied empty Dependency.  Tasks whose
# time dependencies are already met are recast to MET so later checks
# can skip real set comparisons.
MET = Dependency([])
124
124
class TaskScheduler(SessionFactory):
    """Python TaskScheduler object.

    This is the simplest object that supports msg_id based
    DAG dependencies. *Only* task msg_ids are checked, not
    msg_ids of jobs submitted via the MUX queue.

    """

    hwm = Int(0, config=True, shortname='hwm',
        help="""specify the High Water Mark (HWM) for the downstream
        socket in the Task scheduler. This is the maximum number
        of allowed outstanding tasks on each engine."""
    )
    scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
        'leastload', config=True, shortname='scheme', allow_none=False,
        help="""select the task scheduler scheme [default: Python LRU]
        Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
    )
    def _scheme_name_changed(self, old, new):
        # traitlets change handler: swap in the module-level chooser
        # function whose name matches the newly selected scheme
        self.log.debug("Using scheme %r"%new)
        self.scheme = globals()[new]

    # input arguments:
    scheme = Instance(FunctionType) # function for determining the destination
    def _scheme_default(self):
        # default chooser: pick the engine with the lowest load
        return leastload
    client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
    engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
    notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
    mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream

    # internals:
    graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
    retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
    # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
    # NOTE: entries are (raw_msg, targets, after, follow, timeout) — see
    # save_unmet/fail_unreachable for the canonical ordering.
    depending = Dict() # dict by msg_id of (raw_msg, targets, after, follow, timeout)
    pending = Dict() # dict by engine_uuid of submitted tasks
    completed = Dict() # dict by engine_uuid of completed tasks
    failed = Dict() # dict by engine_uuid of failed tasks
    destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
    clients = Dict() # dict by msg_id for who submitted the task
    targets = List() # list of target IDENTs
    loads = List() # list of engine loads
    # full = Set() # set of IDENTs that have HWM outstanding tasks
    all_completed = Set() # set of all completed tasks
    all_failed = Set() # set of all failed tasks
    all_done = Set() # set of all finished tasks=union(completed,failed)
    all_ids = Set() # set of all submitted task IDs
    blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
    auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
176
176
177
177
    def start(self):
        """Begin scheduling: wire stream handlers and start the timeout auditor."""
        self.engine_stream.on_recv(self.dispatch_result, copy=False)
        # route hub notifications by msg_type (consumed by dispatch_notification)
        self._notification_handlers = dict(
            registration_notification = self._register_engine,
            unregistration_notification = self._unregister_engine
        )
        self.notifier_stream.on_recv(self.dispatch_notification)
        # audit waiting tasks for expired timeouts every 2000 ms
        self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # every 2s
        self.auditor.start()
        self.log.info("Scheduler started [%s]"%self.scheme_name)
188
188
    def resume_receiving(self):
        """Resume accepting jobs.

        Re-attaches dispatch_submission to the client stream; called when
        the first engine registers.
        """
        self.client_stream.on_recv(self.dispatch_submission, copy=False)
192
192
    def stop_receiving(self):
        """Stop accepting jobs while there are no engines.
        Leave them in the ZMQ queue."""
        # detaching the handler leaves submissions buffered in the socket
        self.client_stream.on_recv(None)
197
197
198 #-----------------------------------------------------------------------
198 #-----------------------------------------------------------------------
199 # [Un]Registration Handling
199 # [Un]Registration Handling
200 #-----------------------------------------------------------------------
200 #-----------------------------------------------------------------------
201
201
    def dispatch_notification(self, msg):
        """dispatch register/unregister events.

        Malformed or unauthorized messages are logged and dropped; valid
        ones are routed by msg_type through the handlers registered in
        start() (_register_engine / _unregister_engine).
        """
        try:
            idents,msg = self.session.feed_identities(msg)
        except ValueError:
            self.log.warn("task::Invalid Message: %r"%msg)
            return
        try:
            msg = self.session.unpack_message(msg)
        except ValueError:
            self.log.warn("task::Unauthorized message from: %r"%idents)
            return

        msg_type = msg['msg_type']

        handler = self._notification_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("Unhandled message type: %r"%msg_type)
        else:
            try:
                # handlers take the engine's queue identity as a str
                handler(str(msg['content']['queue']))
            except KeyError:
                self.log.error("task::Invalid notification msg: %r"%msg)
225
225
    def _register_engine(self, uid):
        """New engine with ident `uid` became available."""
        # head of the line (LRU position), with zero load:
        self.targets.insert(0,uid)
        self.loads.insert(0,0)
        # initialize sets
        self.completed[uid] = set()
        self.failed[uid] = set()
        self.pending[uid] = {}
        if len(self.targets) == 1:
            # this is the first engine; start accepting client submissions
            self.resume_receiving()
        # rescan the graph:
        self.update_graph(None)
239
239
    def _unregister_engine(self, uid):
        """Existing engine with ident `uid` became unavailable."""
        if len(self.targets) == 1:
            # this was our only engine
            self.stop_receiving()

        # handle any potentially finished tasks:
        self.engine_stream.flush()

        # don't pop destinations, because they might be used later
        # map(self.destinations.pop, self.completed.pop(uid))
        # map(self.destinations.pop, self.failed.pop(uid))

        # prevent this engine from receiving work
        idx = self.targets.index(uid)
        self.targets.pop(idx)
        self.loads.pop(idx)

        # wait 5 seconds before cleaning up pending jobs, since the results might
        # still be incoming
        if self.pending[uid]:
            dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
            dc.start()
        else:
            # nothing outstanding: scrub the engine's bookkeeping immediately
            self.completed.pop(uid)
            self.failed.pop(uid)
266
266
267
267
    def handle_stranded_tasks(self, engine):
        """Deal with jobs resident in an engine that died.

        For each task still pending on `engine`, fabricates an
        'apply_reply' carrying an EngineError and feeds it through
        dispatch_result, so clients see a failure rather than hanging.
        """
        lost = self.pending[engine]
        for msg_id in lost.keys():
            if msg_id not in self.pending[engine]:
                # prevent double-handling of messages
                # (dispatch_result below may remove entries as we iterate)
                continue

            raw_msg = lost[msg_id][0]
            idents,msg = self.session.feed_identities(raw_msg, copy=False)
            parent = self.session.unpack(msg[1].bytes)
            # reply must look like it came from the dead engine
            idents = [engine, idents[0]]

            # build fake error reply
            try:
                raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
            except:
                content = error.wrap_exception()
            msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
            raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
            # and dispatch it
            self.dispatch_result(raw_reply)

        # finally scrub completed/failed lists
        self.completed.pop(engine)
        self.failed.pop(engine)
294
294
295
295
296 #-----------------------------------------------------------------------
296 #-----------------------------------------------------------------------
297 # Job Submission
297 # Job Submission
298 #-----------------------------------------------------------------------
298 #-----------------------------------------------------------------------
    def dispatch_submission(self, raw_msg):
        """Dispatch job submission to appropriate handlers.

        Parses the task header (targets, retries, after/follow
        dependencies, timeout), validates the dependencies, and either
        submits the task, queues it as unmet, or fails it as unreachable.
        """
        # ensure targets up to date:
        self.notifier_stream.flush()
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unpack_message(msg, content=False, copy=False)
        except Exception:
            self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
            return


        # send to monitor
        self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)

        header = msg['header']
        msg_id = header['msg_id']
        self.all_ids.add(msg_id)

        # targets
        targets = set(header.get('targets', []))
        retries = header.get('retries', 0)
        self.retries[msg_id] = retries

        # time dependencies
        after = header.get('after', None)
        if after:
            after = Dependency(after)
            if after.all:
                # reduce the dependency set by removing ids that have
                # already reached the relevant state.  Rebuild via
                # Dependency(...) so the success/failure/all attributes
                # are preserved (Dependency.difference returns a plain set).
                if after.success:
                    after = Dependency(after.difference(self.all_completed),
                                success=after.success,
                                failure=after.failure,
                                all=after.all,
                    )
                if after.failure:
                    after = Dependency(after.difference(self.all_failed),
                                success=after.success,
                                failure=after.failure,
                                all=after.all,
                    )
            if after.check(self.all_completed, self.all_failed):
                # recast as empty set, if `after` already met,
                # to prevent unnecessary set comparisons
                after = MET
        else:
            after = MET

        # location dependencies
        follow = Dependency(header.get('follow', []))

        # turn timeouts into datetime objects:
        timeout = header.get('timeout', None)
        if timeout:
            timeout = datetime.now() + timedelta(0,timeout,0)

        # canonical ordering stored in self.depending / self.pending
        args = [raw_msg, targets, after, follow, timeout]

        # validate and reduce dependencies:
        for dep in after,follow:
            if not dep: # empty dependency
                continue
            # check valid:
            if msg_id in dep or dep.difference(self.all_ids):
                # depends on itself, or on an id we've never seen
                self.depending[msg_id] = args
                return self.fail_unreachable(msg_id, error.InvalidDependency)
            # check if unreachable:
            if dep.unreachable(self.all_completed, self.all_failed):
                self.depending[msg_id] = args
                return self.fail_unreachable(msg_id)

        if after.check(self.all_completed, self.all_failed):
            # time deps already met, try to run
            if not self.maybe_run(msg_id, *args):
                # can't run yet
                if msg_id not in self.all_failed:
                    # could have failed as unreachable
                    self.save_unmet(msg_id, *args)
        else:
            self.save_unmet(msg_id, *args)
371
379
372 def audit_timeouts(self):
380 def audit_timeouts(self):
373 """Audit all waiting tasks for expired timeouts."""
381 """Audit all waiting tasks for expired timeouts."""
374 now = datetime.now()
382 now = datetime.now()
375 for msg_id in self.depending.keys():
383 for msg_id in self.depending.keys():
376 # must recheck, in case one failure cascaded to another:
384 # must recheck, in case one failure cascaded to another:
377 if msg_id in self.depending:
385 if msg_id in self.depending:
378 raw,after,targets,follow,timeout = self.depending[msg_id]
386 raw,after,targets,follow,timeout = self.depending[msg_id]
379 if timeout and timeout < now:
387 if timeout and timeout < now:
380 self.fail_unreachable(msg_id, error.TaskTimeout)
388 self.fail_unreachable(msg_id, error.TaskTimeout)
381
389
    def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
        """a task has become unreachable, send a reply with an ImpossibleDependency
        error.

        Removes the task from `depending` and from the dependency graph,
        records it in all_done/all_failed, sends the error reply to the
        client (and a copy to the Hub monitor), then re-scans dependents.
        """
        if msg_id not in self.depending:
            self.log.error("msg %r already failed!", msg_id)
            return
        # entries are (raw_msg, targets, after, follow, timeout)
        raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
        for mid in follow.union(after):
            if mid in self.graph:
                # drop this task's reverse edge from each dependency
                self.graph[mid].remove(msg_id)

        # FIXME: unpacking a message I've already unpacked, but didn't save:
        idents,msg = self.session.feed_identities(raw_msg, copy=False)
        header = self.session.unpack(msg[1].bytes)

        try:
            raise why()
        except:
            content = error.wrap_exception()

        self.all_done.add(msg_id)
        self.all_failed.add(msg_id)

        msg = self.session.send(self.client_stream, 'apply_reply', content,
                                                parent=header, ident=idents)
        self.session.send(self.mon_stream, msg, ident=['outtask']+idents)

        # cascade: dependents of this task may now be unreachable too
        self.update_graph(msg_id, success=False)
410
418
    def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
        """check location dependencies, and run if they are met.

        Returns True if the task was submitted.  Returns False if no
        engine can currently accept it — including the case where the
        task was just failed as unreachable.
        """
        blacklist = self.blacklist.setdefault(msg_id, set())
        if follow or targets or blacklist or self.hwm:
            # we need a can_run filter
            def can_run(idx):
                # check hwm
                if self.hwm and self.loads[idx] == self.hwm:
                    return False
                target = self.targets[idx]
                # check blacklist
                if target in blacklist:
                    return False
                # check targets
                if targets and target not in targets:
                    return False
                # check follow
                return follow.check(self.completed[target], self.failed[target])

            indices = filter(can_run, range(len(self.targets)))

            if not indices:
                # couldn't run
                if follow.all:
                    # check follow for impossibility
                    dests = set()
                    relevant = set()
                    if follow.success:
                        relevant = self.all_completed
                    if follow.failure:
                        relevant = relevant.union(self.all_failed)
                    for m in follow.intersection(relevant):
                        dests.add(self.destinations[m])
                    if len(dests) > 1:
                        # follow.all requires every followed task on one
                        # engine; results live on >1 engine, so impossible
                        self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
                        self.fail_unreachable(msg_id)
                        return False
                if targets:
                    # check blacklist+targets for impossibility
                    targets.difference_update(blacklist)
                    if not targets or not targets.intersection(self.targets):
                        # every allowed target is blacklisted or gone
                        self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
                        self.fail_unreachable(msg_id)
                        return False
                return False
        else:
            # no constraints: every engine is eligible
            indices = None

        self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
        return True
461
469
def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
    """Save a message for later submission when its dependencies are met."""
    self.depending[msg_id] = [raw_msg, targets, after, follow, timeout]
    # register this job under every dependency that has not yet finished,
    # so completion of any of them triggers a recheck in update_graph
    unfinished = after.union(follow).difference(self.all_done)
    for dep_id in unfinished:
        self.graph.setdefault(dep_id, set()).add(msg_id)
def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
    """Submit a task to any of a subset of our targets."""
    # restrict the load vector to the candidate engines, when a subset is given
    candidate_loads = [self.loads[i] for i in indices] if indices else self.loads
    chosen = self.scheme(candidate_loads)
    if indices:
        # map the position within the subset back to a global engine index
        chosen = indices[chosen]
    target = self.targets[chosen]
    # ship the job to the chosen engine
    stream = self.engine_stream
    stream.send(target, flags=zmq.SNDMORE, copy=False)
    stream.send_multipart(raw_msg, copy=False)
    # account for the new outstanding job on that engine
    self.add_job(chosen)
    self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
    # tell the Hub where the task was sent
    self.session.send(
        self.mon_stream, 'task_destination',
        content=dict(msg_id=msg_id, engine_id=target),
        ident=['tracktask', self.session.session],
    )
493
501
494 #-----------------------------------------------------------------------
502 #-----------------------------------------------------------------------
495 # Result Handling
503 # Result Handling
496 #-----------------------------------------------------------------------
504 #-----------------------------------------------------------------------
def dispatch_result(self, raw_msg):
    """Dispatch method for result replies coming back from engines.

    Unpacks the reply, updates the engine's load, and either relays the
    result to the client (retrying failures while retries remain) or
    treats it as an unmet dependency.
    """
    try:
        idents,msg = self.session.feed_identities(raw_msg, copy=False)
        msg = self.session.unpack_message(msg, content=False, copy=False)
        engine = idents[0]
        try:
            idx = self.targets.index(engine)
        except ValueError:
            pass # skip load-update for dead engines
        else:
            self.finish_job(idx)
    except Exception:
        # fixed typo: "Invaid" -> "Invalid"
        self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
        return

    header = msg['header']
    parent = msg['parent_header']
    if header.get('dependencies_met', True):
        success = (header['status'] == 'ok')
        msg_id = parent['msg_id']
        retries = self.retries[msg_id]
        if not success and retries > 0:
            # failed, but retries remain: decrement and resubmit
            self.retries[msg_id] = retries - 1
            self.handle_unmet_dependency(idents, parent)
        else:
            del self.retries[msg_id]
            # relay to client and update graph
            self.handle_result(idents, parent, raw_msg, success)
            # send to Hub monitor
            self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
    else:
        # the engine reported its local dependencies were not met
        self.handle_unmet_dependency(idents, parent)
def handle_result(self, idents, parent, raw_msg, success=True):
    """handle a real task result, either success or failure"""
    engine, client = idents[0], idents[1]
    # swap engine/client identities so the XREP-XREP mirror routes
    # the reply back to the requesting client
    raw_msg[:2] = [client, engine]
    self.client_stream.send_multipart(raw_msg, copy=False)

    # now bring our bookkeeping up to date
    msg_id = parent['msg_id']
    self.blacklist.pop(msg_id, None)
    self.pending[engine].pop(msg_id)

    # pick the per-engine and global sets matching the outcome
    per_engine, overall = (
        (self.completed, self.all_completed) if success
        else (self.failed, self.all_failed)
    )
    per_engine[engine].add(msg_id)
    overall.add(msg_id)
    self.all_done.add(msg_id)
    self.destinations[msg_id] = engine

    self.update_graph(msg_id, success)
def handle_unmet_dependency(self, idents, parent):
    """handle an unmet dependency"""
    engine = idents[0]
    msg_id = parent['msg_id']

    # remember that this engine could not run the task, so it is not
    # chosen again for the same msg_id
    if msg_id not in self.blacklist:
        self.blacklist[msg_id] = set()
    self.blacklist[msg_id].add(engine)

    # reclaim the task's saved submission arguments
    args = self.pending[engine].pop(msg_id)
    raw,targets,after,follow,timeout = args

    if self.blacklist[msg_id] == targets:
        # every allowed target has refused this task: it is unreachable
        self.depending[msg_id] = args
        self.fail_unreachable(msg_id)
    elif not self.maybe_run(msg_id, *args):
        # resubmit failed
        if msg_id not in self.all_failed:
            # put it back in our dependency tree
            self.save_unmet(msg_id, *args)

    if self.hwm:
        try:
            idx = self.targets.index(engine)
        except ValueError:
            pass # skip load-update for dead engines
        else:
            if self.loads[idx] == self.hwm-1:
                # this engine just dropped below the high-water mark,
                # so previously blocked jobs may now be schedulable
                self.update_graph(None)
def update_graph(self, dep_id=None, success=True):
    """dep_id just finished. Update our dependency
    graph and submit any jobs that just became runable.

    Called with dep_id=None to update entire graph for hwm, but without finishing
    a task.
    """
    # jobs that were waiting (directly) on dep_id
    jobs = self.graph.pop(dep_id, [])

    # recheck *all* jobs if
    # a) we have HWM and an engine just became no longer full
    # or b) dep_id was given as None
    if dep_id is None or (self.hwm and any(load == self.hwm - 1 for load in self.loads)):
        # snapshot the keys: the loop below pops entries from
        # self.depending, so we must not iterate the live view
        # (RuntimeError on Python 3 dict views)
        jobs = list(self.depending.keys())

    for msg_id in jobs:
        raw_msg, targets, after, follow, timeout = self.depending[msg_id]

        if after.unreachable(self.all_completed, self.all_failed)\
                or follow.unreachable(self.all_completed, self.all_failed):
            self.fail_unreachable(msg_id)

        elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
            if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):

                self.depending.pop(msg_id)
                # the job was submitted: drop it from the dependency graph
                for mid in follow.union(after):
                    if mid in self.graph:
                        self.graph[mid].remove(msg_id)
626 #----------------------------------------------------------------------
634 #----------------------------------------------------------------------
627 # methods to be overridden by subclasses
635 # methods to be overridden by subclasses
628 #----------------------------------------------------------------------
636 #----------------------------------------------------------------------
629
637
def add_job(self, idx):
    """Called after self.targets[idx] just got the job with header.
    Override with subclasses. The default ordering is simple LRU.
    The default loads are the number of outstanding jobs."""
    # one more outstanding job on this engine
    self.loads[idx] = self.loads[idx] + 1
    # rotate the engine to the back of the LRU order,
    # keeping targets and loads aligned
    self.targets.append(self.targets.pop(idx))
    self.loads.append(self.loads.pop(idx))
def finish_job(self, idx):
    """Called after self.targets[idx] just finished a job.
    Override with subclasses."""
    # one fewer outstanding job on this engine
    self.loads[idx] = self.loads[idx] - 1
def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
                     logname='root', log_url=None, loglevel=logging.DEBUG,
                     identity=b'task', in_thread=False):
    """Create, wire up, and start a TaskScheduler.

    Parameters
    ----------
    in_addr, out_addr : str
        ZMQ addresses to bind for the client-facing and engine-facing
        XREP sockets.
    mon_addr : str
        ZMQ address to connect the PUB monitoring socket to (the Hub).
    not_addr : str
        ZMQ address to connect the SUB notification socket to.
    config : dict or None
        Plain dict that is re-wrapped into a Config if given.
    logname, log_url, loglevel :
        Logging setup; a zmq logger is used when log_url is given,
        otherwise a local logger.
    identity : bytes
        ZMQ IDENTITY set on both XREP sockets.
    in_thread : bool
        If True, reuse the parent's shared Context/IOLoop instances and
        logger; if False, create private ones (safe for multiprocessing)
        and run the loop here until interrupted.
    """
    ZMQStream = zmqstream.ZMQStream

    if config:
        # unwrap dict back into Config
        config = Config(config)

    if in_thread:
        # use instance() to get the same Context/Loop as our parent
        ctx = zmq.Context.instance()
        loop = ioloop.IOLoop.instance()
    else:
        # in a process, don't use instance()
        # for safety with multiprocessing
        ctx = zmq.Context()
        loop = ioloop.IOLoop()
    # client-facing socket
    ins = ZMQStream(ctx.socket(zmq.XREP),loop)
    ins.setsockopt(zmq.IDENTITY, identity)
    ins.bind(in_addr)

    # engine-facing socket
    outs = ZMQStream(ctx.socket(zmq.XREP),loop)
    outs.setsockopt(zmq.IDENTITY, identity)
    outs.bind(out_addr)
    # monitoring publisher (to the Hub)
    mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
    mons.connect(mon_addr)
    # notification subscriber (subscribe to everything)
    nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
    nots.setsockopt(zmq.SUBSCRIBE, b'')
    nots.connect(not_addr)

    # setup logging.
    if in_thread:
        log = Application.instance().log
    else:
        if log_url:
            log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
        else:
            log = local_logger(logname, loglevel)

    scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
                            mon_stream=mons, notifier_stream=nots,
                            loop=loop, log=log,
                            config=config)
    scheduler.start()
    if not in_thread:
        # blocking: run the event loop in this process until interrupted
        try:
            loop.start()
        except KeyboardInterrupt:
            print ("interrupted, exiting...", file=sys.__stderr__)
General Comments 0
You need to be logged in to leave comments. Login now