##// END OF EJS Templates
tasks on engines when they die fail instead of hang...
MinRK -
Show More
@@ -1,526 +1,545 b''
1 """The Python scheduler for rich scheduling.
1 """The Python scheduler for rich scheduling.
2
2
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 Python Scheduler exists.
5 Python Scheduler exists.
6 """
6 """
7
7
8 #----------------------------------------------------------------------
8 #----------------------------------------------------------------------
9 # Imports
9 # Imports
10 #----------------------------------------------------------------------
10 #----------------------------------------------------------------------
11
11
12 from __future__ import print_function
12 from __future__ import print_function
13 import sys
13 import sys
14 import logging
14 import logging
15 from random import randint, random
15 from random import randint, random
16 from types import FunctionType
16 from types import FunctionType
17 from datetime import datetime, timedelta
17 from datetime import datetime, timedelta
18 try:
18 try:
19 import numpy
19 import numpy
20 except ImportError:
20 except ImportError:
21 numpy = None
21 numpy = None
22
22
23 import zmq
23 import zmq
24 from zmq.eventloop import ioloop, zmqstream
24 from zmq.eventloop import ioloop, zmqstream
25
25
26 # local imports
26 # local imports
27 from IPython.external.decorator import decorator
27 from IPython.external.decorator import decorator
28 # from IPython.config.configurable import Configurable
28 # from IPython.config.configurable import Configurable
29 from IPython.utils.traitlets import Instance, Dict, List, Set
29 from IPython.utils.traitlets import Instance, Dict, List, Set
30
30
31 import error
31 import error
32 # from client import Client
32 # from client import Client
33 from dependency import Dependency
33 from dependency import Dependency
34 import streamsession as ss
34 import streamsession as ss
35 from entry_point import connect_logger, local_logger
35 from entry_point import connect_logger, local_logger
36 from factory import SessionFactory
36 from factory import SessionFactory
37
37
38
38
@decorator
def logged(f, self, *args, **kwargs):
    """Debug-log every call to the wrapped scheduler method before running it."""
    self.log.debug("scheduler::%s(*%s,**%s)"%(f.func_name, args, kwargs))
    return f(self, *args, **kwargs)
45
45
46 #----------------------------------------------------------------------
46 #----------------------------------------------------------------------
47 # Chooser functions
47 # Chooser functions
48 #----------------------------------------------------------------------
48 #----------------------------------------------------------------------
49
49
def plainrandom(loads):
    """Pick an engine index uniformly at random; load values are ignored."""
    return randint(0, len(loads) - 1)
54
54
def lru(loads):
    """Least-recently-used choice: always take index 0.

    `loads` is assumed to be maintained in LRU order (oldest engine
    first), so its values are never inspected.
    """
    return 0
63
63
def twobin(loads):
    """Power-of-two-choices: sample two indices, keep the older one.

    Load values are ignored; LRU ordering (oldest first) is assumed,
    so the smaller of the two sampled indices is the less recently
    used engine.
    """
    last = len(loads) - 1
    first = randint(0, last)
    second = randint(0, last)
    return min(first, second)
75
75
def weighted(loads):
    """Draw two engines weighted by inverse load; return the lighter one."""
    # an idle engine (load 0) is ~a million times more likely than load 1
    inv = 1./(1e-6+numpy.array(loads))
    cum = inv.cumsum()
    total = cum[-1]
    # two independent weighted draws over the cumulative distribution
    first = random()*total
    second = random()*total
    i = 0
    while cum[i] < first:
        i += 1
    j = 0
    while cum[j] < second:
        j += 1
    # keep whichever draw landed on the less loaded engine (ties -> first draw)
    return j if inv[j] > inv[i] else i
97
97
def leastload(loads):
    """Choose the index of the minimum load.

    Ties go to the earliest occurrence, which under LRU ordering is
    the least recently used of the tied engines.
    """
    return loads.index(min(loads))
106
106
#---------------------------------------------------------------------
# Classes
#---------------------------------------------------------------------
# store empty default dependency:
MET = Dependency([])

class TaskScheduler(SessionFactory):
    """Python TaskScheduler object.

    This is the simplest object that supports msg_id based
    DAG dependencies. *Only* task msg_ids are checked, not
    msg_ids of jobs submitted via the MUX queue.

    """

    # input arguments:
    scheme = Instance(FunctionType, default=leastload) # function for determining the destination
    client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
    engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
    notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
    mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream

    # internals:
    dependencies = Dict() # dict by msg_id of { msg_ids that depend on key }
    depending = Dict() # dict by msg_id of [raw_msg, after, follow, timeout] (see save_unmet)
    pending = Dict() # dict by engine_uuid of submitted tasks: msg_id -> (raw_msg, follow)
    completed = Dict() # dict by engine_uuid of completed tasks
    failed = Dict() # dict by engine_uuid of failed tasks
    destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
    clients = Dict() # dict by msg_id for who submitted the task
    targets = List() # list of target IDENTs, kept in LRU order (oldest first)
    loads = List() # list of engine loads, parallel to `targets`
    all_completed = Set() # set of all completed tasks
    all_failed = Set() # set of all failed tasks
    all_done = Set() # set of all finished tasks=union(completed,failed)
    blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
    auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback') # periodic timeout sweep (see start)
144
144
145
145
146 def start(self):
146 def start(self):
147 self.engine_stream.on_recv(self.dispatch_result, copy=False)
147 self.engine_stream.on_recv(self.dispatch_result, copy=False)
148 self._notification_handlers = dict(
148 self._notification_handlers = dict(
149 registration_notification = self._register_engine,
149 registration_notification = self._register_engine,
150 unregistration_notification = self._unregister_engine
150 unregistration_notification = self._unregister_engine
151 )
151 )
152 self.notifier_stream.on_recv(self.dispatch_notification)
152 self.notifier_stream.on_recv(self.dispatch_notification)
153 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 1e3, self.loop) # 1 Hz
153 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # 1 Hz
154 self.auditor.start()
154 self.auditor.start()
155 self.log.info("Scheduler started...%r"%self)
155 self.log.info("Scheduler started...%r"%self)
156
156
    def resume_receiving(self):
        """Resume accepting jobs: reattach the client-stream submission handler."""
        self.client_stream.on_recv(self.dispatch_submission, copy=False)
160
160
    def stop_receiving(self):
        """Stop accepting jobs while there are no engines.
        Leave them in the ZMQ queue.

        Submissions simply buffer in the stream until
        `resume_receiving` reattaches the handler.
        """
        self.client_stream.on_recv(None)
165
165
166 #-----------------------------------------------------------------------
166 #-----------------------------------------------------------------------
167 # [Un]Registration Handling
167 # [Un]Registration Handling
168 #-----------------------------------------------------------------------
168 #-----------------------------------------------------------------------
169
169
170 def dispatch_notification(self, msg):
170 def dispatch_notification(self, msg):
171 """dispatch register/unregister events."""
171 """dispatch register/unregister events."""
172 idents,msg = self.session.feed_identities(msg)
172 idents,msg = self.session.feed_identities(msg)
173 msg = self.session.unpack_message(msg)
173 msg = self.session.unpack_message(msg)
174 msg_type = msg['msg_type']
174 msg_type = msg['msg_type']
175 handler = self._notification_handlers.get(msg_type, None)
175 handler = self._notification_handlers.get(msg_type, None)
176 if handler is None:
176 if handler is None:
177 raise Exception("Unhandled message type: %s"%msg_type)
177 raise Exception("Unhandled message type: %s"%msg_type)
178 else:
178 else:
179 try:
179 try:
180 handler(str(msg['content']['queue']))
180 handler(str(msg['content']['queue']))
181 except KeyError:
181 except KeyError:
182 self.log.error("task::Invalid notification msg: %s"%msg)
182 self.log.error("task::Invalid notification msg: %s"%msg)
183
183
184 @logged
184 @logged
185 def _register_engine(self, uid):
185 def _register_engine(self, uid):
186 """New engine with ident `uid` became available."""
186 """New engine with ident `uid` became available."""
187 # head of the line:
187 # head of the line:
188 self.targets.insert(0,uid)
188 self.targets.insert(0,uid)
189 self.loads.insert(0,0)
189 self.loads.insert(0,0)
190 # initialize sets
190 # initialize sets
191 self.completed[uid] = set()
191 self.completed[uid] = set()
192 self.failed[uid] = set()
192 self.failed[uid] = set()
193 self.pending[uid] = {}
193 self.pending[uid] = {}
194 if len(self.targets) == 1:
194 if len(self.targets) == 1:
195 self.resume_receiving()
195 self.resume_receiving()
196
196
197 def _unregister_engine(self, uid):
197 def _unregister_engine(self, uid):
198 """Existing engine with ident `uid` became unavailable."""
198 """Existing engine with ident `uid` became unavailable."""
199 if len(self.targets) == 1:
199 if len(self.targets) == 1:
200 # this was our only engine
200 # this was our only engine
201 self.stop_receiving()
201 self.stop_receiving()
202
202
203 # handle any potentially finished tasks:
203 # handle any potentially finished tasks:
204 self.engine_stream.flush()
204 self.engine_stream.flush()
205
205
206 self.completed.pop(uid)
206 self.completed.pop(uid)
207 self.failed.pop(uid)
207 self.failed.pop(uid)
208 # don't pop destinations, because it might be used later
208 # don't pop destinations, because it might be used later
209 # map(self.destinations.pop, self.completed.pop(uid))
209 # map(self.destinations.pop, self.completed.pop(uid))
210 # map(self.destinations.pop, self.failed.pop(uid))
210 # map(self.destinations.pop, self.failed.pop(uid))
211
211
212 lost = self.pending.pop(uid)
213
214 idx = self.targets.index(uid)
212 idx = self.targets.index(uid)
215 self.targets.pop(idx)
213 self.targets.pop(idx)
216 self.loads.pop(idx)
214 self.loads.pop(idx)
217
215
218 self.handle_stranded_tasks(lost)
216 # wait 5 seconds before cleaning up pending jobs, since the results might
217 # still be incoming
218 if self.pending[uid]:
219 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
220 dc.start()
219
221
220 def handle_stranded_tasks(self, lost):
222 @logged
223 def handle_stranded_tasks(self, engine):
221 """Deal with jobs resident in an engine that died."""
224 """Deal with jobs resident in an engine that died."""
222 # TODO: resubmit the tasks?
225 lost = self.pending.pop(engine)
223 for msg_id in lost:
226
224 pass
227 for msg_id, (raw_msg,follow) in lost.iteritems():
228 self.all_failed.add(msg_id)
229 self.all_done.add(msg_id)
230 idents,msg = self.session.feed_identities(raw_msg, copy=False)
231 msg = self.session.unpack_message(msg, copy=False, content=False)
232 parent = msg['header']
233 idents = [idents[0],engine]+idents[1:]
234 print (idents)
235 try:
236 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
237 except:
238 content = ss.wrap_exception()
239 msg = self.session.send(self.client_stream, 'apply_reply', content,
240 parent=parent, ident=idents)
241 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
242 self.update_dependencies(msg_id)
225
243
226
244
227 #-----------------------------------------------------------------------
245 #-----------------------------------------------------------------------
228 # Job Submission
246 # Job Submission
229 #-----------------------------------------------------------------------
247 #-----------------------------------------------------------------------
230 @logged
248 @logged
231 def dispatch_submission(self, raw_msg):
249 def dispatch_submission(self, raw_msg):
232 """Dispatch job submission to appropriate handlers."""
250 """Dispatch job submission to appropriate handlers."""
233 # ensure targets up to date:
251 # ensure targets up to date:
234 self.notifier_stream.flush()
252 self.notifier_stream.flush()
235 try:
253 try:
236 idents, msg = self.session.feed_identities(raw_msg, copy=False)
254 idents, msg = self.session.feed_identities(raw_msg, copy=False)
237 except Exception as e:
255 except Exception as e:
238 self.log.error("task::Invaid msg: %s"%msg)
256 self.log.error("task::Invaid msg: %s"%msg)
239 return
257 return
240
258
241 # send to monitor
259 # send to monitor
242 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
260 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
243
261
244 msg = self.session.unpack_message(msg, content=False, copy=False)
262 msg = self.session.unpack_message(msg, content=False, copy=False)
245 header = msg['header']
263 header = msg['header']
246 msg_id = header['msg_id']
264 msg_id = header['msg_id']
247
265
248 # time dependencies
266 # time dependencies
249 after = Dependency(header.get('after', []))
267 after = Dependency(header.get('after', []))
250 if after.mode == 'all':
268 if after.mode == 'all':
251 after.difference_update(self.all_completed)
269 after.difference_update(self.all_completed)
252 if not after.success_only:
270 if not after.success_only:
253 after.difference_update(self.all_failed)
271 after.difference_update(self.all_failed)
254 if after.check(self.all_completed, self.all_failed):
272 if after.check(self.all_completed, self.all_failed):
255 # recast as empty set, if `after` already met,
273 # recast as empty set, if `after` already met,
256 # to prevent unnecessary set comparisons
274 # to prevent unnecessary set comparisons
257 after = MET
275 after = MET
258
276
259 # location dependencies
277 # location dependencies
260 follow = Dependency(header.get('follow', []))
278 follow = Dependency(header.get('follow', []))
261 # check if unreachable:
279 # check if unreachable:
262 if after.unreachable(self.all_failed) or follow.unreachable(self.all_failed):
280 if after.unreachable(self.all_failed) or follow.unreachable(self.all_failed):
263 self.depending[msg_id] = [raw_msg,MET,MET,None]
281 self.depending[msg_id] = [raw_msg,MET,MET,None]
264 return self.fail_unreachable(msg_id)
282 return self.fail_unreachable(msg_id)
265
283
266 # turn timeouts into datetime objects:
284 # turn timeouts into datetime objects:
267 timeout = header.get('timeout', None)
285 timeout = header.get('timeout', None)
268 if timeout:
286 if timeout:
269 timeout = datetime.now() + timedelta(0,timeout,0)
287 timeout = datetime.now() + timedelta(0,timeout,0)
270
288
271 if after.check(self.all_completed, self.all_failed):
289 if after.check(self.all_completed, self.all_failed):
272 # time deps already met, try to run
290 # time deps already met, try to run
273 if not self.maybe_run(msg_id, raw_msg, follow):
291 if not self.maybe_run(msg_id, raw_msg, follow):
274 # can't run yet
292 # can't run yet
275 self.save_unmet(msg_id, raw_msg, after, follow, timeout)
293 self.save_unmet(msg_id, raw_msg, after, follow, timeout)
276 else:
294 else:
277 self.save_unmet(msg_id, raw_msg, after, follow, timeout)
295 self.save_unmet(msg_id, raw_msg, after, follow, timeout)
278
296
279 @logged
297 @logged
280 def audit_timeouts(self):
298 def audit_timeouts(self):
281 """Audit all waiting tasks for expired timeouts."""
299 """Audit all waiting tasks for expired timeouts."""
282 now = datetime.now()
300 now = datetime.now()
283 for msg_id in self.depending.keys():
301 for msg_id in self.depending.keys():
284 # must recheck, in case one failure cascaded to another:
302 # must recheck, in case one failure cascaded to another:
285 if msg_id in self.depending:
303 if msg_id in self.depending:
286 raw,after,follow,timeout = self.depending[msg_id]
304 raw,after,follow,timeout = self.depending[msg_id]
287 if timeout and timeout < now:
305 if timeout and timeout < now:
288 self.fail_unreachable(msg_id, timeout=True)
306 self.fail_unreachable(msg_id, timeout=True)
289
307
290 @logged
308 @logged
291 def fail_unreachable(self, msg_id, timeout=False):
309 def fail_unreachable(self, msg_id, timeout=False):
292 """a message has become unreachable"""
310 """a message has become unreachable"""
293 if msg_id not in self.depending:
311 if msg_id not in self.depending:
294 self.log.error("msg %r already failed!"%msg_id)
312 self.log.error("msg %r already failed!"%msg_id)
295 return
313 return
296 raw_msg, after, follow, timeout = self.depending.pop(msg_id)
314 raw_msg, after, follow, timeout = self.depending.pop(msg_id)
297 for mid in follow.union(after):
315 for mid in follow.union(after):
298 if mid in self.dependencies:
316 if mid in self.dependencies:
299 self.dependencies[mid].remove(msg_id)
317 self.dependencies[mid].remove(msg_id)
300
318
301 # FIXME: unpacking a message I've already unpacked, but didn't save:
319 # FIXME: unpacking a message I've already unpacked, but didn't save:
302 idents,msg = self.session.feed_identities(raw_msg, copy=False)
320 idents,msg = self.session.feed_identities(raw_msg, copy=False)
303 msg = self.session.unpack_message(msg, copy=False, content=False)
321 msg = self.session.unpack_message(msg, copy=False, content=False)
304 header = msg['header']
322 header = msg['header']
305
323
306 impossible = error.DependencyTimeout if timeout else error.ImpossibleDependency
324 impossible = error.DependencyTimeout if timeout else error.ImpossibleDependency
307
325
308 try:
326 try:
309 raise impossible()
327 raise impossible()
310 except:
328 except:
311 content = ss.wrap_exception()
329 content = ss.wrap_exception()
312
330
313 self.all_done.add(msg_id)
331 self.all_done.add(msg_id)
314 self.all_failed.add(msg_id)
332 self.all_failed.add(msg_id)
315
333
316 msg = self.session.send(self.client_stream, 'apply_reply', content,
334 msg = self.session.send(self.client_stream, 'apply_reply', content,
317 parent=header, ident=idents)
335 parent=header, ident=idents)
318 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
336 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
319
337
320 self.update_dependencies(msg_id, success=False)
338 self.update_dependencies(msg_id, success=False)
321
339
322 @logged
340 @logged
323 def maybe_run(self, msg_id, raw_msg, follow=None):
341 def maybe_run(self, msg_id, raw_msg, follow=None):
324 """check location dependencies, and run if they are met."""
342 """check location dependencies, and run if they are met."""
325
343
326 if follow:
344 if follow:
327 def can_run(idx):
345 def can_run(idx):
328 target = self.targets[idx]
346 target = self.targets[idx]
329 return target not in self.blacklist.get(msg_id, []) and\
347 return target not in self.blacklist.get(msg_id, []) and\
330 follow.check(self.completed[target], self.failed[target])
348 follow.check(self.completed[target], self.failed[target])
331
349
332 indices = filter(can_run, range(len(self.targets)))
350 indices = filter(can_run, range(len(self.targets)))
333 if not indices:
351 if not indices:
334 # TODO evaluate unmeetable follow dependencies
352 # TODO evaluate unmeetable follow dependencies
335 if follow.mode == 'all':
353 if follow.mode == 'all':
336 dests = set()
354 dests = set()
337 relevant = self.all_completed if follow.success_only else self.all_done
355 relevant = self.all_completed if follow.success_only else self.all_done
338 for m in follow.intersection(relevant):
356 for m in follow.intersection(relevant):
339 dests.add(self.destinations[m])
357 dests.add(self.destinations[m])
340 if len(dests) > 1:
358 if len(dests) > 1:
341 self.fail_unreachable(msg_id)
359 self.fail_unreachable(msg_id)
342
360
343
361
344 return False
362 return False
345 else:
363 else:
346 indices = None
364 indices = None
347
365
348 self.submit_task(msg_id, raw_msg, indices)
366 self.submit_task(msg_id, raw_msg, indices)
349 return True
367 return True
350
368
351 @logged
369 @logged
352 def save_unmet(self, msg_id, raw_msg, after, follow, timeout):
370 def save_unmet(self, msg_id, raw_msg, after, follow, timeout):
353 """Save a message for later submission when its dependencies are met."""
371 """Save a message for later submission when its dependencies are met."""
354 self.depending[msg_id] = [raw_msg,after,follow,timeout]
372 self.depending[msg_id] = [raw_msg,after,follow,timeout]
355 # track the ids in follow or after, but not those already finished
373 # track the ids in follow or after, but not those already finished
356 for dep_id in after.union(follow).difference(self.all_done):
374 for dep_id in after.union(follow).difference(self.all_done):
357 if dep_id not in self.dependencies:
375 if dep_id not in self.dependencies:
358 self.dependencies[dep_id] = set()
376 self.dependencies[dep_id] = set()
359 self.dependencies[dep_id].add(msg_id)
377 self.dependencies[dep_id].add(msg_id)
360
378
361 @logged
379 @logged
362 def submit_task(self, msg_id, msg, follow=None, indices=None):
380 def submit_task(self, msg_id, raw_msg, follow=None, indices=None):
363 """Submit a task to any of a subset of our targets."""
381 """Submit a task to any of a subset of our targets."""
364 if indices:
382 if indices:
365 loads = [self.loads[i] for i in indices]
383 loads = [self.loads[i] for i in indices]
366 else:
384 else:
367 loads = self.loads
385 loads = self.loads
368 idx = self.scheme(loads)
386 idx = self.scheme(loads)
369 if indices:
387 if indices:
370 idx = indices[idx]
388 idx = indices[idx]
371 target = self.targets[idx]
389 target = self.targets[idx]
372 # print (target, map(str, msg[:3]))
390 # print (target, map(str, msg[:3]))
373 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
391 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
374 self.engine_stream.send_multipart(msg, copy=False)
392 self.engine_stream.send_multipart(raw_msg, copy=False)
375 self.add_job(idx)
393 self.add_job(idx)
376 self.pending[target][msg_id] = (msg, follow)
394 self.pending[target][msg_id] = (raw_msg, follow)
377 content = dict(msg_id=msg_id, engine_id=target)
395 content = dict(msg_id=msg_id, engine_id=target)
378 self.session.send(self.mon_stream, 'task_destination', content=content,
396 self.session.send(self.mon_stream, 'task_destination', content=content,
379 ident=['tracktask',self.session.session])
397 ident=['tracktask',self.session.session])
380
398
381 #-----------------------------------------------------------------------
399 #-----------------------------------------------------------------------
382 # Result Handling
400 # Result Handling
383 #-----------------------------------------------------------------------
401 #-----------------------------------------------------------------------
384 @logged
402 @logged
385 def dispatch_result(self, raw_msg):
403 def dispatch_result(self, raw_msg):
386 try:
404 try:
387 idents,msg = self.session.feed_identities(raw_msg, copy=False)
405 idents,msg = self.session.feed_identities(raw_msg, copy=False)
388 except Exception as e:
406 except Exception as e:
389 self.log.error("task::Invaid result: %s"%msg)
407 self.log.error("task::Invaid result: %s"%msg)
390 return
408 return
391 msg = self.session.unpack_message(msg, content=False, copy=False)
409 msg = self.session.unpack_message(msg, content=False, copy=False)
392 header = msg['header']
410 header = msg['header']
393 if header.get('dependencies_met', True):
411 if header.get('dependencies_met', True):
394 success = (header['status'] == 'ok')
412 success = (header['status'] == 'ok')
395 self.handle_result(idents, msg['parent_header'], raw_msg, success)
413 self.handle_result(idents, msg['parent_header'], raw_msg, success)
396 # send to Hub monitor
414 # send to Hub monitor
397 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
415 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
398 else:
416 else:
399 self.handle_unmet_dependency(idents, msg['parent_header'])
417 self.handle_unmet_dependency(idents, msg['parent_header'])
400
418
    @logged
    def handle_result(self, idents, parent, raw_msg, success=True):
        """Relay a finished task's reply to the client and update scheduler state."""
        # first, relay result to client
        engine = idents[0]
        client = idents[1]
        # swap_ids for XREP-XREP mirror
        raw_msg[:2] = [client,engine]
        # print (map(str, raw_msg[:4]))
        self.client_stream.send_multipart(raw_msg, copy=False)
        # now, update our data structures
        msg_id = parent['msg_id']
        # the task is done, so any engine blacklist for it is obsolete
        self.blacklist.pop(msg_id, None)
        self.pending[engine].pop(msg_id)
        if success:
            self.completed[engine].add(msg_id)
            self.all_completed.add(msg_id)
        else:
            self.failed[engine].add(msg_id)
            self.all_failed.add(msg_id)
        self.all_done.add(msg_id)
        # remember where it ran, for follow dependencies
        self.destinations[msg_id] = engine

        self.update_dependencies(msg_id, success)
423
442
424 @logged
443 @logged
425 def handle_unmet_dependency(self, idents, parent):
444 def handle_unmet_dependency(self, idents, parent):
426 engine = idents[0]
445 engine = idents[0]
427 msg_id = parent['msg_id']
446 msg_id = parent['msg_id']
428 if msg_id not in self.blacklist:
447 if msg_id not in self.blacklist:
429 self.blacklist[msg_id] = set()
448 self.blacklist[msg_id] = set()
430 self.blacklist[msg_id].add(engine)
449 self.blacklist[msg_id].add(engine)
431 raw_msg,follow,timeout = self.pending[engine].pop(msg_id)
450 raw_msg,follow,timeout = self.pending[engine].pop(msg_id)
432 if not self.maybe_run(msg_id, raw_msg, follow):
451 if not self.maybe_run(msg_id, raw_msg, follow):
433 # resubmit failed, put it back in our dependency tree
452 # resubmit failed, put it back in our dependency tree
434 self.save_unmet(msg_id, raw_msg, MET, follow, timeout)
453 self.save_unmet(msg_id, raw_msg, MET, follow, timeout)
435 pass
454 pass
436
455
    @logged
    def update_dependencies(self, dep_id, success=True):
        """dep_id just finished. Update our dependency
        table and submit any jobs that just became runable.

        Called on every completion, failure, and unreachable-failure,
        so one failure can cascade through fail_unreachable recursively.
        """
        if dep_id not in self.dependencies:
            return
        jobs = self.dependencies.pop(dep_id)

        for msg_id in jobs:
            raw_msg, after, follow, timeout = self.depending[msg_id]

            # a failure may have made this job impossible to satisfy
            if after.unreachable(self.all_failed) or follow.unreachable(self.all_failed):
                self.fail_unreachable(msg_id)

            elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
                # mark time deps as satisfied, then try for a placement
                self.depending[msg_id][1] = MET
                if self.maybe_run(msg_id, raw_msg, follow):
                    # submitted: drop it from the wait tables entirely
                    self.depending.pop(msg_id)
                    for mid in follow.union(after):
                        if mid in self.dependencies:
                            self.dependencies[mid].remove(msg_id)
469
488
470 #----------------------------------------------------------------------
489 #----------------------------------------------------------------------
471 # methods to be overridden by subclasses
490 # methods to be overridden by subclasses
472 #----------------------------------------------------------------------
491 #----------------------------------------------------------------------
473
492
def add_job(self, idx):
    """Called after self.targets[idx] just got the job with header.
    Override with subclasses. The default ordering is simple LRU.
    The default loads are the number of outstanding jobs."""
    # one more outstanding job on this engine
    self.loads[idx] = self.loads[idx] + 1
    # rotate the engine to the back of both lists (LRU order)
    for seq in (self.targets, self.loads):
        entry = seq.pop(idx)
        seq.append(entry)
482
501
def finish_job(self, idx):
    """Called after self.targets[idx] just finished a job.
    Override with subclasses."""
    # one fewer outstanding job on this engine
    self.loads[idx] = self.loads[idx] - 1
def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, logname='ZMQ', log_addr=None, loglevel=logging.DEBUG, scheme='weighted'):
    """Build the ZMQ streams for a TaskScheduler and run its event loop.

    Binds XREP streams for the client (in_addr) and engine (out_addr)
    sides, connects a PUB stream to the monitor and a SUB stream (all
    topics) to the notifier, configures logging, then starts the
    scheduler and blocks in the loop until KeyboardInterrupt.
    """
    from zmq.eventloop import ioloop
    from zmq.eventloop.zmqstream import ZMQStream

    context = zmq.Context()
    loop = ioloop.IOLoop()

    # router socket facing task clients
    client_stream = ZMQStream(context.socket(zmq.XREP), loop)
    client_stream.bind(in_addr)
    # router socket facing the engines
    engine_stream = ZMQStream(context.socket(zmq.XREP), loop)
    engine_stream.bind(out_addr)
    # publish scheduler traffic to the monitor
    mon_stream = ZMQStream(context.socket(zmq.PUB), loop)
    mon_stream.connect(mon_addr)
    # subscribe to every notification topic
    notifier_stream = ZMQStream(context.socket(zmq.SUB), loop)
    notifier_stream.setsockopt(zmq.SUBSCRIBE, '')
    notifier_stream.connect(not_addr)

    # resolve the scheme *name* to the module-level callable of that name
    # (falls back to None, letting TaskScheduler use its default)
    scheme = globals().get(scheme, None)
    # setup logging
    if log_addr:
        connect_logger(logname, context, log_addr, root="scheduler", loglevel=loglevel)
    else:
        local_logger(logname, loglevel)

    scheduler = TaskScheduler(client_stream=client_stream, engine_stream=engine_stream,
                              mon_stream=mon_stream, notifier_stream=notifier_stream,
                              scheme=scheme, loop=loop, logname=logname)
    scheduler.start()
    try:
        loop.start()
    except KeyboardInterrupt:
        print ("interrupted, exiting...", file=sys.__stderr__)
523
542
if __name__ == '__main__':
    iface = 'tcp://127.0.0.1:%i'
    # BUG FIX: the second port was '1236', breaking the intended
    # contiguous 12345-12348 test port block; use 12346.
    launch_scheduler(iface%12345,iface%12346,iface%12347,iface%12348)
@@ -1,79 +1,79 b''
1 """some generic utilities"""
1 """some generic utilities"""
2 import re
2 import re
3
3
class ReverseDict(dict):
    """simple double-keyed subset of dict methods.

    Lookups fall through to a value->key reverse mapping, so
    ``d[key] -> value`` and ``d[value] -> key`` both work.
    """

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # reverse mapping (value -> key), kept in sync with the dict
        self._reverse = dict()
        # items() instead of the Python-2-only iteritems(), so the
        # class works unchanged on Python 2 and Python 3
        for key, value in self.items():
            self._reverse[value] = key

    def __getitem__(self, key):
        # forward lookup first, then fall back to the reverse mapping
        try:
            return dict.__getitem__(self, key)
        except KeyError:
            return self._reverse[key]

    def __setitem__(self, key, value):
        # a key may not collide with an existing *value*, or reverse
        # lookups would become ambiguous
        if key in self._reverse:
            raise KeyError("Can't have key %r on both sides!"%key)
        dict.__setitem__(self, key, value)
        self._reverse[value] = key

    def pop(self, key):
        """Remove *key* (a forward key) from both mappings; return its value."""
        value = dict.pop(self, key)
        self._reverse.pop(value)
        return value

    def get(self, key, default=None):
        """Two-sided lookup returning *default* when neither side has *key*."""
        try:
            return self[key]
        except KeyError:
            return default
36
try:
    # basestring only exists on Python 2; fall back to str on Python 3
    _string_types = basestring
except NameError:
    _string_types = str


def validate_url(url):
    """validate a url for zeromq

    Returns True for a well-formed url; raises TypeError for a
    non-string argument and AssertionError for a malformed url.
    Only tcp urls are validated in depth.
    """
    if not isinstance(url, _string_types):
        raise TypeError("url must be a string, not %r"%type(url))
    url = url.lower()

    proto_addr = url.split('://')
    assert len(proto_addr) == 2, 'Invalid url: %r'%url
    proto, addr = proto_addr
    assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto

    # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
    # author: Remi Sabourin
    pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')

    if proto == 'tcp':
        lis = addr.split(':')
        assert len(lis) == 2, 'Invalid url: %r'%url
        addr,s_port = lis
        try:
            port = int(s_port)
        except ValueError:
            # BUG FIX: this used %(port, url), but 'port' is unbound when
            # int() fails, so a NameError masked the intended AssertionError
            raise AssertionError("Invalid port %r in url: %r"%(s_port, url))

        assert pat.match(addr) is not None, 'Invalid url: %r'%url

    else:
        # only validate tcp urls currently
        pass

    return True
68
68
69
69
def validate_url_container(container):
    """validate a potentially nested collection of urls."""
    # a bare string is a single url: validate it directly
    if isinstance(container, basestring):
        return validate_url(container)
    # for dicts, the urls live in the values
    if isinstance(container, dict):
        container = container.itervalues()

    # recurse into lists/tuples/generators of urls (or of containers)
    for item in container:
        validate_url_container(item)
General Comments 0
You need to be logged in to leave comments. Login now