##// END OF EJS Templates
use heap-sorted queue...
MinRK -
Show More
@@ -1,813 +1,843 b''
1 """The Python scheduler for rich scheduling.
1 """The Python scheduler for rich scheduling.
2
2
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 Python Scheduler exists.
5 Python Scheduler exists.
6
6
7 Authors:
7 Authors:
8
8
9 * Min RK
9 * Min RK
10 """
10 """
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2010-2011 The IPython Development Team
12 # Copyright (C) 2010-2011 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 #----------------------------------------------------------------------
18 #----------------------------------------------------------------------
19 # Imports
19 # Imports
20 #----------------------------------------------------------------------
20 #----------------------------------------------------------------------
21
21
22 from __future__ import print_function
22 import heapq
23
24 import logging
23 import logging
25 import sys
24 import sys
26 import time
25 import time
27
26
28 from datetime import datetime, timedelta
27 from datetime import datetime, timedelta
29 from random import randint, random
28 from random import randint, random
30 from types import FunctionType
29 from types import FunctionType
31
30
32 try:
31 try:
33 import numpy
32 import numpy
34 except ImportError:
33 except ImportError:
35 numpy = None
34 numpy = None
36
35
37 import zmq
36 import zmq
38 from zmq.eventloop import ioloop, zmqstream
37 from zmq.eventloop import ioloop, zmqstream
39
38
40 # local imports
39 # local imports
41 from IPython.external.decorator import decorator
40 from IPython.external.decorator import decorator
42 from IPython.config.application import Application
41 from IPython.config.application import Application
43 from IPython.config.loader import Config
42 from IPython.config.loader import Config
44 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
43 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
45 from IPython.utils.py3compat import cast_bytes
44 from IPython.utils.py3compat import cast_bytes
46
45
47 from IPython.parallel import error, util
46 from IPython.parallel import error, util
48 from IPython.parallel.factory import SessionFactory
47 from IPython.parallel.factory import SessionFactory
49 from IPython.parallel.util import connect_logger, local_logger
48 from IPython.parallel.util import connect_logger, local_logger
50
49
51 from .dependency import Dependency
50 from .dependency import Dependency
52
51
@decorator
def logged(f, self, *args, **kwargs):
    """Log a scheduler method call at DEBUG level, then delegate to it.

    Decorator for TaskScheduler methods: emits the wrapped function's
    name and arguments via self.log before calling it unchanged.
    """
    # f.__name__ is equivalent to the Python-2-only f.func_name and also
    # works on Python 3, where the func_* attributes were removed.
    self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs)
    return f(self, *args, **kwargs)
59
58
60 #----------------------------------------------------------------------
59 #----------------------------------------------------------------------
61 # Chooser functions
60 # Chooser functions
62 #----------------------------------------------------------------------
61 #----------------------------------------------------------------------
63
62
def plainrandom(loads):
    """Pick an engine index uniformly at random.

    Only the length of *loads* matters; the load values themselves are
    ignored.
    """
    return randint(0, len(loads) - 1)
68
67
def lru(loads):
    """Always pick the front of the line.

    The content of *loads* is ignored; index 0 is returned
    unconditionally.  Assumes the caller maintains *loads* in LRU
    order with the oldest first, so position 0 is the
    least-recently-used engine.
    """
    return 0
77
76
def twobin(loads):
    """Draw two candidates at random and keep the older (LRU) one.

    Load values are ignored.  Assumes LRU ordering of *loads* with the
    oldest first, so the smaller of the two indices wins.
    """
    limit = len(loads) - 1
    first = randint(0, limit)
    second = randint(0, limit)
    return min(first, second)
89
88
def weighted(loads):
    """Sample two engines weighted by inverse load; return the lighter one.

    Each engine is drawn with probability proportional to 1/load.  The
    1e-6 epsilon keeps zero loads finite, so an idle engine is weighted
    roughly a million times more heavily than one with load 1.
    """
    weights = 1. / (1e-6 + numpy.array(loads))
    sums = weights.cumsum()
    total = sums[-1]
    # two independent draws on the cumulative-weight line
    x = random() * total
    y = random() * total
    idx = 0
    idy = 0
    while sums[idx] < x:
        idx += 1
    while sums[idy] < y:
        idy += 1
    # keep whichever draw landed on the less loaded engine
    return idy if weights[idy] > weights[idx] else idx
111
110
def leastload(loads):
    """Return the index of the smallest load.

    Ties go to the earliest occurrence; if *loads* has LRU ordering,
    that is the least-recently-used of the tied engines.
    """
    return min(range(len(loads)), key=loads.__getitem__)
120
119
121 #---------------------------------------------------------------------
120 #---------------------------------------------------------------------
122 # Classes
121 # Classes
123 #---------------------------------------------------------------------
122 #---------------------------------------------------------------------
124
123
125
124
# Shared sentinel: an empty (therefore already-satisfied) Dependency,
# reused for every job that has no time ('after') constraint.
MET = Dependency([])
128
127
129
128
class Job(object):
    """Simple container for a job"""

    def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
            targets, after, follow, timeout):
        # stash every constructor argument on the instance unchanged
        self.msg_id = msg_id
        self.raw_msg = raw_msg
        self.idents = idents
        self.msg = msg
        self.header = header
        self.metadata = metadata
        self.targets = targets
        self.after = after
        self.follow = follow
        self.timeout = timeout

        # lazy-delete flag: the heap-sorted queue marks entries as
        # removed instead of re-heapifying on every deletion
        self.removed = False

        self.timestamp = time.time()  # submission time; orders the heap
        self.blacklist = set()        # engine idents to avoid (presumably
                                      # filled by the scheduler on failure)

    def __lt__(self, other):
        # earlier submissions sort first
        return self.timestamp < other.timestamp

    def __cmp__(self, other):
        # Python 2 fallback comparison, consistent with __lt__
        return cmp(self.timestamp, other.timestamp)

    @property
    def dependents(self):
        """Union of the location ('follow') and time ('after') deps."""
        return self.after.union(self.follow)
152
157
153 class TaskScheduler(SessionFactory):
158 class TaskScheduler(SessionFactory):
154 """Python TaskScheduler object.
159 """Python TaskScheduler object.
155
160
156 This is the simplest object that supports msg_id based
161 This is the simplest object that supports msg_id based
157 DAG dependencies. *Only* task msg_ids are checked, not
162 DAG dependencies. *Only* task msg_ids are checked, not
158 msg_ids of jobs submitted via the MUX queue.
163 msg_ids of jobs submitted via the MUX queue.
159
164
160 """
165 """
161
166
162 hwm = Integer(1, config=True,
167 hwm = Integer(1, config=True,
163 help="""specify the High Water Mark (HWM) for the downstream
168 help="""specify the High Water Mark (HWM) for the downstream
164 socket in the Task scheduler. This is the maximum number
169 socket in the Task scheduler. This is the maximum number
165 of allowed outstanding tasks on each engine.
170 of allowed outstanding tasks on each engine.
166
171
167 The default (1) means that only one task can be outstanding on each
172 The default (1) means that only one task can be outstanding on each
168 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
173 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
169 engines continue to be assigned tasks while they are working,
174 engines continue to be assigned tasks while they are working,
170 effectively hiding network latency behind computation, but can result
175 effectively hiding network latency behind computation, but can result
171 in an imbalance of work when submitting many heterogenous tasks all at
176 in an imbalance of work when submitting many heterogenous tasks all at
172 once. Any positive value greater than one is a compromise between the
177 once. Any positive value greater than one is a compromise between the
173 two.
178 two.
174
179
175 """
180 """
176 )
181 )
177 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
182 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
178 'leastload', config=True, allow_none=False,
183 'leastload', config=True, allow_none=False,
179 help="""select the task scheduler scheme [default: Python LRU]
184 help="""select the task scheduler scheme [default: Python LRU]
180 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
185 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
181 )
186 )
182 def _scheme_name_changed(self, old, new):
187 def _scheme_name_changed(self, old, new):
183 self.log.debug("Using scheme %r"%new)
188 self.log.debug("Using scheme %r"%new)
184 self.scheme = globals()[new]
189 self.scheme = globals()[new]
185
190
186 # input arguments:
191 # input arguments:
187 scheme = Instance(FunctionType) # function for determining the destination
192 scheme = Instance(FunctionType) # function for determining the destination
188 def _scheme_default(self):
193 def _scheme_default(self):
189 return leastload
194 return leastload
190 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
195 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
191 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
196 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
192 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
197 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
193 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
198 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
194 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
199 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
195
200
196 # internals:
201 # internals:
202 queue = List() # heap-sorted list of Jobs
203 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
197 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
204 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
198 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
205 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
199 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
206 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
200 depending = Dict() # dict by msg_id of Jobs
201 pending = Dict() # dict by engine_uuid of submitted tasks
207 pending = Dict() # dict by engine_uuid of submitted tasks
202 completed = Dict() # dict by engine_uuid of completed tasks
208 completed = Dict() # dict by engine_uuid of completed tasks
203 failed = Dict() # dict by engine_uuid of failed tasks
209 failed = Dict() # dict by engine_uuid of failed tasks
204 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
210 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
205 clients = Dict() # dict by msg_id for who submitted the task
211 clients = Dict() # dict by msg_id for who submitted the task
206 targets = List() # list of target IDENTs
212 targets = List() # list of target IDENTs
207 loads = List() # list of engine loads
213 loads = List() # list of engine loads
208 # full = Set() # set of IDENTs that have HWM outstanding tasks
214 # full = Set() # set of IDENTs that have HWM outstanding tasks
209 all_completed = Set() # set of all completed tasks
215 all_completed = Set() # set of all completed tasks
210 all_failed = Set() # set of all failed tasks
216 all_failed = Set() # set of all failed tasks
211 all_done = Set() # set of all finished tasks=union(completed,failed)
217 all_done = Set() # set of all finished tasks=union(completed,failed)
212 all_ids = Set() # set of all submitted task IDs
218 all_ids = Set() # set of all submitted task IDs
213
219
214 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
220 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
215
221
216 ident = CBytes() # ZMQ identity. This should just be self.session.session
222 ident = CBytes() # ZMQ identity. This should just be self.session.session
217 # but ensure Bytes
223 # but ensure Bytes
218 def _ident_default(self):
224 def _ident_default(self):
219 return self.session.bsession
225 return self.session.bsession
220
226
    def start(self):
        """Wire up all streams and begin scheduling."""
        # ask the Hub which engines are already registered
        self.query_stream.on_recv(self.dispatch_query_reply)
        self.session.send(self.query_stream, "connection_request", {})

        # results come back from engines; new tasks arrive from clients
        self.engine_stream.on_recv(self.dispatch_result, copy=False)
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

        # route Hub notifications to engine (un)registration handlers
        self._notification_handlers = dict(
            registration_notification = self._register_engine,
            unregistration_notification = self._unregister_engine
        )
        self.notifier_stream.on_recv(self.dispatch_notification)
        # check for expired task timeouts every 2s (period is in ms)
        self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop)
        self.auditor.start()
        self.log.info("Scheduler started [%s]"%self.scheme_name)
236
242
    def resume_receiving(self):
        """Resume accepting jobs."""
        # re-attach the submission handler removed by stop_receiving()
        self.client_stream.on_recv(self.dispatch_submission, copy=False)
240
246
    def stop_receiving(self):
        """Stop accepting jobs while there are no engines.
        Leave them in the ZMQ queue."""
        # clearing the recv callback leaves submissions buffered in the socket
        self.client_stream.on_recv(None)
245
251
246 #-----------------------------------------------------------------------
252 #-----------------------------------------------------------------------
247 # [Un]Registration Handling
253 # [Un]Registration Handling
248 #-----------------------------------------------------------------------
254 #-----------------------------------------------------------------------
249
255
250
256
    def dispatch_query_reply(self, msg):
        """handle reply to our initial connection request"""
        # strip routing identities from the raw multipart message
        try:
            idents,msg = self.session.feed_identities(msg)
        except ValueError:
            self.log.warn("task::Invalid Message: %r",msg)
            return
        # deserialize (and implicitly authenticate) the message
        try:
            msg = self.session.unserialize(msg)
        except ValueError:
            self.log.warn("task::Unauthorized message from: %r"%idents)
            return

        content = msg['content']
        # register every engine the Hub already knows about
        for uuid in content.get('engines', {}).values():
            self._register_engine(cast_bytes(uuid))
267
273
268
274
    @util.log_errors
    def dispatch_notification(self, msg):
        """dispatch register/unregister events."""
        # strip routing identities from the raw multipart message
        try:
            idents,msg = self.session.feed_identities(msg)
        except ValueError:
            self.log.warn("task::Invalid Message: %r",msg)
            return
        # deserialize (and implicitly authenticate) the message
        try:
            msg = self.session.unserialize(msg)
        except ValueError:
            self.log.warn("task::Unauthorized message from: %r"%idents)
            return

        msg_type = msg['header']['msg_type']

        # route by msg_type to _register_engine / _unregister_engine
        handler = self._notification_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("Unhandled message type: %r"%msg_type)
        else:
            try:
                # handlers take the engine uuid as bytes
                handler(cast_bytes(msg['content']['uuid']))
            except Exception:
                self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
293
299
    def _register_engine(self, uid):
        """New engine with ident `uid` became available."""
        # head of the line: new engines are LRU-first with zero load
        self.targets.insert(0,uid)
        self.loads.insert(0,0)

        # initialize per-engine bookkeeping sets
        self.completed[uid] = set()
        self.failed[uid] = set()
        self.pending[uid] = {}

        # rescan the graph: pending jobs may now be runnable
        self.update_graph(None)
307
313
    def _unregister_engine(self, uid):
        """Existing engine with ident `uid` became unavailable."""
        if len(self.targets) == 1:
            # this was our only engine
            # NOTE(review): this branch is a no-op — presumably a
            # placeholder for pausing submissions; confirm intent.
            pass

        # handle any potentially finished tasks:
        self.engine_stream.flush()

        # don't pop destinations, because they might be used later
        # map(self.destinations.pop, self.completed.pop(uid))
        # map(self.destinations.pop, self.failed.pop(uid))

        # prevent this engine from receiving work
        idx = self.targets.index(uid)
        self.targets.pop(idx)
        self.loads.pop(idx)

        # wait 5 seconds before cleaning up pending jobs, since the results might
        # still be incoming
        if self.pending[uid]:
            dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
            dc.start()
        else:
            # nothing outstanding: scrub the per-engine sets immediately
            self.completed.pop(uid)
            self.failed.pop(uid)
334
340
335
341
    def handle_stranded_tasks(self, engine):
        """Deal with jobs resident in an engine that died.

        For each task still pending on the dead engine, fabricate an
        'apply_reply' carrying an EngineError and feed it through
        dispatch_result, so clients see a normal (failed) reply.
        """
        lost = self.pending[engine]
        for msg_id in lost.keys():
            if msg_id not in self.pending[engine]:
                # prevent double-handling of messages
                # (dispatch_result below may mutate self.pending[engine])
                continue

            raw_msg = lost[msg_id].raw_msg
            idents,msg = self.session.feed_identities(raw_msg, copy=False)
            parent = self.session.unpack(msg[1].bytes)
            # reply route: engine ident first, then the client ident
            idents = [engine, idents[0]]

            # build fake error reply
            try:
                raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
            except:
                # wrap_exception captures the traceback of the raise above
                content = error.wrap_exception()
            # build fake metadata
            md = dict(
                status=u'error',
                engine=engine,
                date=datetime.now(),
            )
            msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
            raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
            # and dispatch it as if it had arrived from the engine
            self.dispatch_result(raw_reply)

        # finally scrub completed/failed lists
        self.completed.pop(engine)
        self.failed.pop(engine)
368
374
369
375
370 #-----------------------------------------------------------------------
376 #-----------------------------------------------------------------------
371 # Job Submission
377 # Job Submission
372 #-----------------------------------------------------------------------
378 #-----------------------------------------------------------------------
373
379
374
380
    @util.log_errors
    def dispatch_submission(self, raw_msg):
        """Dispatch job submission to appropriate handlers.

        Parses the raw client message, builds a Job with its time
        ('after') and location ('follow') dependencies, then either
        runs it immediately, queues it as unmet, or fails it as
        unreachable.
        """
        # ensure targets up to date:
        self.notifier_stream.flush()
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unserialize(msg, content=False, copy=False)
        except Exception:
            self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
            return

        # send to monitor
        self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)

        header = msg['header']
        md = msg['metadata']
        msg_id = header['msg_id']
        self.all_ids.add(msg_id)

        # get targets as a set of bytes objects
        # from a list of unicode objects
        targets = md.get('targets', [])
        targets = map(cast_bytes, targets)
        targets = set(targets)

        # remaining resubmission attempts for this task
        retries = md.get('retries', 0)
        self.retries[msg_id] = retries

        # time dependencies
        after = md.get('after', None)
        if after:
            after = Dependency(after)
            if after.all:
                # drop already-finished msg_ids so the remaining set is
                # only what we still have to wait for
                if after.success:
                    after = Dependency(after.difference(self.all_completed),
                                success=after.success,
                                failure=after.failure,
                                all=after.all,
                    )
                if after.failure:
                    after = Dependency(after.difference(self.all_failed),
                                success=after.success,
                                failure=after.failure,
                                all=after.all,
                    )
            if after.check(self.all_completed, self.all_failed):
                # recast as empty set, if `after` already met,
                # to prevent unnecessary set comparisons
                after = MET
        else:
            after = MET

        # location dependencies
        follow = Dependency(md.get('follow', []))

        # turn timeouts into datetime objects:
        timeout = md.get('timeout', None)
        if timeout:
            # cast to float, because jsonlib returns floats as decimal.Decimal,
            # which timedelta does not accept
            timeout = datetime.now() + timedelta(0,float(timeout),0)

        job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
                header=header, targets=targets, after=after, follow=follow,
                timeout=timeout, metadata=md,
        )

        # validate and reduce dependencies:
        for dep in after,follow:
            if not dep: # empty dependency
                continue
            # check valid: a task cannot depend on itself or on unknown ids
            if msg_id in dep or dep.difference(self.all_ids):
                # register before failing so fail_unreachable can find it
                self.queue_map[msg_id] = job
                return self.fail_unreachable(msg_id, error.InvalidDependency)
            # check if unreachable:
            if dep.unreachable(self.all_completed, self.all_failed):
                self.queue_map[msg_id] = job
                return self.fail_unreachable(msg_id)

        if after.check(self.all_completed, self.all_failed):
            # time deps already met, try to run
            if not self.maybe_run(job):
                # can't run yet
                if msg_id not in self.all_failed:
                    # could have failed as unreachable
                    self.save_unmet(job)
        else:
            # time deps not yet met: park the job
            self.save_unmet(job)
466
472
    def audit_timeouts(self):
        """Audit all waiting tasks for expired timeouts.

        Runs periodically (see the auditor PeriodicCallback in start);
        any queued job whose timeout has passed is failed with
        TaskTimeout.
        """
        now = datetime.now()
        for msg_id in self.queue_map.keys():
            # must recheck, in case one failure cascaded to another:
            if msg_id in self.queue_map:
                job = self.queue_map[msg_id]
                if job.timeout and job.timeout < now:
                    self.fail_unreachable(msg_id, error.TaskTimeout)
476
482
    def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
        """a task has become unreachable, send a reply with an ImpossibleDependency
        error.

        Removes the job from the queue map, detaches it from the
        dependency graph, records it as failed, and sends the error
        reply to both the client and the monitor.
        """
        if msg_id not in self.queue_map:
            self.log.error("msg %r already failed!", msg_id)
            return
        job = self.queue_map.pop(msg_id)
        # lazy-delete from the queue: flag it; the heap entry is skipped later
        job.removed = True
        # drop this msg_id from the graph entries of everything it waited on
        for mid in job.dependents:
            if mid in self.graph:
                self.graph[mid].remove(msg_id)

        # raise/catch so wrap_exception captures a real traceback
        try:
            raise why()
        except:
            content = error.wrap_exception()

        self.all_done.add(msg_id)
        self.all_failed.add(msg_id)

        msg = self.session.send(self.client_stream, 'apply_reply', content,
                                                        parent=job.header, ident=job.idents)
        self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)

        # propagate the failure to jobs depending on this one
        self.update_graph(msg_id, success=False)
501
509
502 def available_engines(self):
510 def available_engines(self):
503 """return a list of available engine indices based on HWM"""
511 """return a list of available engine indices based on HWM"""
504 if not self.hwm:
512 if not self.hwm:
505 return range(len(self.targets))
513 return range(len(self.targets))
506 available = []
514 available = []
507 for idx in range(len(self.targets)):
515 for idx in range(len(self.targets)):
508 if self.loads[idx] < self.hwm:
516 if self.loads[idx] < self.hwm:
509 available.append(idx)
517 available.append(idx)
510 return available
518 return available
511
519
512 def maybe_run(self, job):
520 def maybe_run(self, job):
513 """check location dependencies, and run if they are met."""
521 """check location dependencies, and run if they are met."""
514 msg_id = job.msg_id
522 msg_id = job.msg_id
515 self.log.debug("Attempting to assign task %s", msg_id)
523 self.log.debug("Attempting to assign task %s", msg_id)
516 available = self.available_engines()
524 available = self.available_engines()
517 if not available:
525 if not available:
518 # no engines, definitely can't run
526 # no engines, definitely can't run
519 return False
527 return False
520
528
521 if job.follow or job.targets or job.blacklist or self.hwm:
529 if job.follow or job.targets or job.blacklist or self.hwm:
522 # we need a can_run filter
530 # we need a can_run filter
523 def can_run(idx):
531 def can_run(idx):
524 # check hwm
532 # check hwm
525 if self.hwm and self.loads[idx] == self.hwm:
533 if self.hwm and self.loads[idx] == self.hwm:
526 return False
534 return False
527 target = self.targets[idx]
535 target = self.targets[idx]
528 # check blacklist
536 # check blacklist
529 if target in job.blacklist:
537 if target in job.blacklist:
530 return False
538 return False
531 # check targets
539 # check targets
532 if job.targets and target not in job.targets:
540 if job.targets and target not in job.targets:
533 return False
541 return False
534 # check follow
542 # check follow
535 return job.follow.check(self.completed[target], self.failed[target])
543 return job.follow.check(self.completed[target], self.failed[target])
536
544
537 indices = filter(can_run, available)
545 indices = filter(can_run, available)
538
546
539 if not indices:
547 if not indices:
540 # couldn't run
548 # couldn't run
541 if job.follow.all:
549 if job.follow.all:
542 # check follow for impossibility
550 # check follow for impossibility
543 dests = set()
551 dests = set()
544 relevant = set()
552 relevant = set()
545 if job.follow.success:
553 if job.follow.success:
546 relevant = self.all_completed
554 relevant = self.all_completed
547 if job.follow.failure:
555 if job.follow.failure:
548 relevant = relevant.union(self.all_failed)
556 relevant = relevant.union(self.all_failed)
549 for m in job.follow.intersection(relevant):
557 for m in job.follow.intersection(relevant):
550 dests.add(self.destinations[m])
558 dests.add(self.destinations[m])
551 if len(dests) > 1:
559 if len(dests) > 1:
552 self.depending[msg_id] = job
560 self.queue_map[msg_id] = job
553 self.fail_unreachable(msg_id)
561 self.fail_unreachable(msg_id)
554 return False
562 return False
555 if job.targets:
563 if job.targets:
556 # check blacklist+targets for impossibility
564 # check blacklist+targets for impossibility
557 job.targets.difference_update(job.blacklist)
565 job.targets.difference_update(job.blacklist)
558 if not job.targets or not job.targets.intersection(self.targets):
566 if not job.targets or not job.targets.intersection(self.targets):
559 self.depending[msg_id] = job
567 self.queue_map[msg_id] = job
560 self.fail_unreachable(msg_id)
568 self.fail_unreachable(msg_id)
561 return False
569 return False
562 return False
570 return False
563 else:
571 else:
564 indices = None
572 indices = None
565
573
566 self.submit_task(job, indices)
574 self.submit_task(job, indices)
567 return True
575 return True
568
576
569 def save_unmet(self, job):
577 def save_unmet(self, job):
570 """Save a message for later submission when its dependencies are met."""
578 """Save a message for later submission when its dependencies are met."""
571 msg_id = job.msg_id
579 msg_id = job.msg_id
572 self.depending[msg_id] = job
580 self.queue_map[msg_id] = job
581 heapq.heappush(self.queue, job)
573 # track the ids in follow or after, but not those already finished
582 # track the ids in follow or after, but not those already finished
574 for dep_id in job.after.union(job.follow).difference(self.all_done):
583 for dep_id in job.after.union(job.follow).difference(self.all_done):
575 if dep_id not in self.graph:
584 if dep_id not in self.graph:
576 self.graph[dep_id] = set()
585 self.graph[dep_id] = set()
577 self.graph[dep_id].add(msg_id)
586 self.graph[dep_id].add(msg_id)
578
587
579 def submit_task(self, job, indices=None):
588 def submit_task(self, job, indices=None):
580 """Submit a task to any of a subset of our targets."""
589 """Submit a task to any of a subset of our targets."""
581 if indices:
590 if indices:
582 loads = [self.loads[i] for i in indices]
591 loads = [self.loads[i] for i in indices]
583 else:
592 else:
584 loads = self.loads
593 loads = self.loads
585 idx = self.scheme(loads)
594 idx = self.scheme(loads)
586 if indices:
595 if indices:
587 idx = indices[idx]
596 idx = indices[idx]
588 target = self.targets[idx]
597 target = self.targets[idx]
589 # print (target, map(str, msg[:3]))
598 # print (target, map(str, msg[:3]))
590 # send job to the engine
599 # send job to the engine
591 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
600 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
592 self.engine_stream.send_multipart(job.raw_msg, copy=False)
601 self.engine_stream.send_multipart(job.raw_msg, copy=False)
593 # update load
602 # update load
594 self.add_job(idx)
603 self.add_job(idx)
595 self.pending[target][job.msg_id] = job
604 self.pending[target][job.msg_id] = job
596 # notify Hub
605 # notify Hub
597 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
606 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
598 self.session.send(self.mon_stream, 'task_destination', content=content,
607 self.session.send(self.mon_stream, 'task_destination', content=content,
599 ident=[b'tracktask',self.ident])
608 ident=[b'tracktask',self.ident])
600
609
601
610
602 #-----------------------------------------------------------------------
611 #-----------------------------------------------------------------------
603 # Result Handling
612 # Result Handling
604 #-----------------------------------------------------------------------
613 #-----------------------------------------------------------------------
605
614
606
615
607 @util.log_errors
616 @util.log_errors
608 def dispatch_result(self, raw_msg):
617 def dispatch_result(self, raw_msg):
609 """dispatch method for result replies"""
618 """dispatch method for result replies"""
610 try:
619 try:
611 idents,msg = self.session.feed_identities(raw_msg, copy=False)
620 idents,msg = self.session.feed_identities(raw_msg, copy=False)
612 msg = self.session.unserialize(msg, content=False, copy=False)
621 msg = self.session.unserialize(msg, content=False, copy=False)
613 engine = idents[0]
622 engine = idents[0]
614 try:
623 try:
615 idx = self.targets.index(engine)
624 idx = self.targets.index(engine)
616 except ValueError:
625 except ValueError:
617 pass # skip load-update for dead engines
626 pass # skip load-update for dead engines
618 else:
627 else:
619 self.finish_job(idx)
628 self.finish_job(idx)
620 except Exception:
629 except Exception:
621 self.log.error("task::Invaid result: %r", raw_msg, exc_info=True)
630 self.log.error("task::Invaid result: %r", raw_msg, exc_info=True)
622 return
631 return
623
632
624 md = msg['metadata']
633 md = msg['metadata']
625 parent = msg['parent_header']
634 parent = msg['parent_header']
626 if md.get('dependencies_met', True):
635 if md.get('dependencies_met', True):
627 success = (md['status'] == 'ok')
636 success = (md['status'] == 'ok')
628 msg_id = parent['msg_id']
637 msg_id = parent['msg_id']
629 retries = self.retries[msg_id]
638 retries = self.retries[msg_id]
630 if not success and retries > 0:
639 if not success and retries > 0:
631 # failed
640 # failed
632 self.retries[msg_id] = retries - 1
641 self.retries[msg_id] = retries - 1
633 self.handle_unmet_dependency(idents, parent)
642 self.handle_unmet_dependency(idents, parent)
634 else:
643 else:
635 del self.retries[msg_id]
644 del self.retries[msg_id]
636 # relay to client and update graph
645 # relay to client and update graph
637 self.handle_result(idents, parent, raw_msg, success)
646 self.handle_result(idents, parent, raw_msg, success)
638 # send to Hub monitor
647 # send to Hub monitor
639 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
648 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
640 else:
649 else:
641 self.handle_unmet_dependency(idents, parent)
650 self.handle_unmet_dependency(idents, parent)
642
651
643 def handle_result(self, idents, parent, raw_msg, success=True):
652 def handle_result(self, idents, parent, raw_msg, success=True):
644 """handle a real task result, either success or failure"""
653 """handle a real task result, either success or failure"""
645 # first, relay result to client
654 # first, relay result to client
646 engine = idents[0]
655 engine = idents[0]
647 client = idents[1]
656 client = idents[1]
648 # swap_ids for ROUTER-ROUTER mirror
657 # swap_ids for ROUTER-ROUTER mirror
649 raw_msg[:2] = [client,engine]
658 raw_msg[:2] = [client,engine]
650 # print (map(str, raw_msg[:4]))
659 # print (map(str, raw_msg[:4]))
651 self.client_stream.send_multipart(raw_msg, copy=False)
660 self.client_stream.send_multipart(raw_msg, copy=False)
652 # now, update our data structures
661 # now, update our data structures
653 msg_id = parent['msg_id']
662 msg_id = parent['msg_id']
654 self.pending[engine].pop(msg_id)
663 self.pending[engine].pop(msg_id)
655 if success:
664 if success:
656 self.completed[engine].add(msg_id)
665 self.completed[engine].add(msg_id)
657 self.all_completed.add(msg_id)
666 self.all_completed.add(msg_id)
658 else:
667 else:
659 self.failed[engine].add(msg_id)
668 self.failed[engine].add(msg_id)
660 self.all_failed.add(msg_id)
669 self.all_failed.add(msg_id)
661 self.all_done.add(msg_id)
670 self.all_done.add(msg_id)
662 self.destinations[msg_id] = engine
671 self.destinations[msg_id] = engine
663
672
664 self.update_graph(msg_id, success)
673 self.update_graph(msg_id, success)
665
674
666 def handle_unmet_dependency(self, idents, parent):
675 def handle_unmet_dependency(self, idents, parent):
667 """handle an unmet dependency"""
676 """handle an unmet dependency"""
668 engine = idents[0]
677 engine = idents[0]
669 msg_id = parent['msg_id']
678 msg_id = parent['msg_id']
670
679
671 job = self.pending[engine].pop(msg_id)
680 job = self.pending[engine].pop(msg_id)
672 job.blacklist.add(engine)
681 job.blacklist.add(engine)
673
682
674 if job.blacklist == job.targets:
683 if job.blacklist == job.targets:
675 self.depending[msg_id] = job
684 self.queue_map[msg_id] = job
676 self.fail_unreachable(msg_id)
685 self.fail_unreachable(msg_id)
677 elif not self.maybe_run(job):
686 elif not self.maybe_run(job):
678 # resubmit failed
687 # resubmit failed
679 if msg_id not in self.all_failed:
688 if msg_id not in self.all_failed:
680 # put it back in our dependency tree
689 # put it back in our dependency tree
681 self.save_unmet(job)
690 self.save_unmet(job)
682
691
683 if self.hwm:
692 if self.hwm:
684 try:
693 try:
685 idx = self.targets.index(engine)
694 idx = self.targets.index(engine)
686 except ValueError:
695 except ValueError:
687 pass # skip load-update for dead engines
696 pass # skip load-update for dead engines
688 else:
697 else:
689 if self.loads[idx] == self.hwm-1:
698 if self.loads[idx] == self.hwm-1:
690 self.update_graph(None)
699 self.update_graph(None)
691
700
692 def update_graph(self, dep_id=None, success=True):
701 def update_graph(self, dep_id=None, success=True):
693 """dep_id just finished. Update our dependency
702 """dep_id just finished. Update our dependency
694 graph and submit any jobs that just became runnable.
703 graph and submit any jobs that just became runnable.
695
704
696 Called with dep_id=None to update entire graph for hwm, but without finishing a task.
705 Called with dep_id=None to update entire graph for hwm, but without finishing a task.
697 """
706 """
698 # print ("\n\n***********")
707 # print ("\n\n***********")
699 # pprint (dep_id)
708 # pprint (dep_id)
700 # pprint (self.graph)
709 # pprint (self.graph)
701 # pprint (self.depending)
710 # pprint (self.queue_map)
702 # pprint (self.all_completed)
711 # pprint (self.all_completed)
703 # pprint (self.all_failed)
712 # pprint (self.all_failed)
704 # print ("\n\n***********\n\n")
713 # print ("\n\n***********\n\n")
705 # update any jobs that depended on the dependency
714 # update any jobs that depended on the dependency
706 jobs = self.graph.pop(dep_id, [])
715 msg_ids = self.graph.pop(dep_id, [])
707
716
708 # recheck *all* jobs if
717 # recheck *all* jobs if
709 # a) we have HWM and an engine just become no longer full
718 # a) we have HWM and an engine just become no longer full
710 # or b) dep_id was given as None
719 # or b) dep_id was given as None
711
720
712 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
721 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
713 jobs = self.depending.keys()
722 jobs = self.queue
723 using_queue = True
724 else:
725 using_queue = False
726 jobs = heapq.heapify([ self.queue_map[msg_id] for msg_id in msg_ids ])
714
727
715 for msg_id in sorted(jobs, key=lambda msg_id: self.depending[msg_id].timestamp):
728 to_restore = []
716 job = self.depending[msg_id]
729 while jobs:
717
730 job = heapq.heappop(jobs)
731 if job.removed:
732 continue
733 msg_id = job.msg_id
734
735 put_it_back = True
736
718 if job.after.unreachable(self.all_completed, self.all_failed)\
737 if job.after.unreachable(self.all_completed, self.all_failed)\
719 or job.follow.unreachable(self.all_completed, self.all_failed):
738 or job.follow.unreachable(self.all_completed, self.all_failed):
720 self.fail_unreachable(msg_id)
739 self.fail_unreachable(msg_id)
740 put_it_back = False
721
741
722 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
742 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
723 if self.maybe_run(job):
743 if self.maybe_run(job):
724
744 put_it_back = False
725 self.depending.pop(msg_id)
745 self.queue_map.pop(msg_id)
726 for mid in job.dependents:
746 for mid in job.dependents:
727 if mid in self.graph:
747 if mid in self.graph:
728 self.graph[mid].remove(msg_id)
748 self.graph[mid].remove(msg_id)
729
749
730 # abort the loop if we just filled up all of our engines.
750 # abort the loop if we just filled up all of our engines.
731 # avoids an O(N) operation in situation of full queue,
751 # avoids an O(N) operation in situation of full queue,
732 # where graph update is triggered as soon as an engine becomes
752 # where graph update is triggered as soon as an engine becomes
733 # non-full, and all tasks after the first are checked,
753 # non-full, and all tasks after the first are checked,
734 # even though they can't run.
754 # even though they can't run.
735 if not self.available_engines():
755 if not self.available_engines():
736 return
756 break
757
758 if using_queue and put_it_back:
759 # popped a job from the queue but it neither ran nor failed,
760 # so we need to put it back when we are done
761 to_restore.append(job)
762
763 # put back any tasks we popped but didn't run
764 for job in to_restore:
765 heapq.heappush(self.queue, job)
766
737
767
738 #----------------------------------------------------------------------
768 #----------------------------------------------------------------------
739 # methods to be overridden by subclasses
769 # methods to be overridden by subclasses
740 #----------------------------------------------------------------------
770 #----------------------------------------------------------------------
741
771
742 def add_job(self, idx):
772 def add_job(self, idx):
743 """Called after self.targets[idx] just got the job with header.
773 """Called after self.targets[idx] just got the job with header.
744 Override with subclasses. The default ordering is simple LRU.
774 Override with subclasses. The default ordering is simple LRU.
745 The default loads are the number of outstanding jobs."""
775 The default loads are the number of outstanding jobs."""
746 self.loads[idx] += 1
776 self.loads[idx] += 1
747 for lis in (self.targets, self.loads):
777 for lis in (self.targets, self.loads):
748 lis.append(lis.pop(idx))
778 lis.append(lis.pop(idx))
749
779
750
780
751 def finish_job(self, idx):
781 def finish_job(self, idx):
752 """Called after self.targets[idx] just finished a job.
782 """Called after self.targets[idx] just finished a job.
753 Override with subclasses."""
783 Override with subclasses."""
754 self.loads[idx] -= 1
784 self.loads[idx] -= 1
755
785
756
786
757
787
758 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
788 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
759 logname='root', log_url=None, loglevel=logging.DEBUG,
789 logname='root', log_url=None, loglevel=logging.DEBUG,
760 identity=b'task', in_thread=False):
790 identity=b'task', in_thread=False):
761
791
762 ZMQStream = zmqstream.ZMQStream
792 ZMQStream = zmqstream.ZMQStream
763
793
764 if config:
794 if config:
765 # unwrap dict back into Config
795 # unwrap dict back into Config
766 config = Config(config)
796 config = Config(config)
767
797
768 if in_thread:
798 if in_thread:
769 # use instance() to get the same Context/Loop as our parent
799 # use instance() to get the same Context/Loop as our parent
770 ctx = zmq.Context.instance()
800 ctx = zmq.Context.instance()
771 loop = ioloop.IOLoop.instance()
801 loop = ioloop.IOLoop.instance()
772 else:
802 else:
773 # in a process, don't use instance()
803 # in a process, don't use instance()
774 # for safety with multiprocessing
804 # for safety with multiprocessing
775 ctx = zmq.Context()
805 ctx = zmq.Context()
776 loop = ioloop.IOLoop()
806 loop = ioloop.IOLoop()
777 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
807 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
778 ins.setsockopt(zmq.IDENTITY, identity + b'_in')
808 ins.setsockopt(zmq.IDENTITY, identity + b'_in')
779 ins.bind(in_addr)
809 ins.bind(in_addr)
780
810
781 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
811 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
782 outs.setsockopt(zmq.IDENTITY, identity + b'_out')
812 outs.setsockopt(zmq.IDENTITY, identity + b'_out')
783 outs.bind(out_addr)
813 outs.bind(out_addr)
784 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
814 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
785 mons.connect(mon_addr)
815 mons.connect(mon_addr)
786 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
816 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
787 nots.setsockopt(zmq.SUBSCRIBE, b'')
817 nots.setsockopt(zmq.SUBSCRIBE, b'')
788 nots.connect(not_addr)
818 nots.connect(not_addr)
789
819
790 querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
820 querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
791 querys.connect(reg_addr)
821 querys.connect(reg_addr)
792
822
793 # setup logging.
823 # setup logging.
794 if in_thread:
824 if in_thread:
795 log = Application.instance().log
825 log = Application.instance().log
796 else:
826 else:
797 if log_url:
827 if log_url:
798 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
828 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
799 else:
829 else:
800 log = local_logger(logname, loglevel)
830 log = local_logger(logname, loglevel)
801
831
802 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
832 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
803 mon_stream=mons, notifier_stream=nots,
833 mon_stream=mons, notifier_stream=nots,
804 query_stream=querys,
834 query_stream=querys,
805 loop=loop, log=log,
835 loop=loop, log=log,
806 config=config)
836 config=config)
807 scheduler.start()
837 scheduler.start()
808 if not in_thread:
838 if not in_thread:
809 try:
839 try:
810 loop.start()
840 loop.start()
811 except KeyboardInterrupt:
841 except KeyboardInterrupt:
812 scheduler.log.critical("Interrupted, exiting...")
842 scheduler.log.critical("Interrupted, exiting...")
813
843
General Comments 0
You need to be logged in to leave comments. Login now