##// END OF EJS Templates
add debug log when a task is added to the queue
MinRK -
Show More
@@ -1,843 +1,844 b''
1 """The Python scheduler for rich scheduling.
1 """The Python scheduler for rich scheduling.
2
2
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 Python Scheduler exists.
5 Python Scheduler exists.
6
6
7 Authors:
7 Authors:
8
8
9 * Min RK
9 * Min RK
10 """
10 """
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2010-2011 The IPython Development Team
12 # Copyright (C) 2010-2011 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 #----------------------------------------------------------------------
18 #----------------------------------------------------------------------
19 # Imports
19 # Imports
20 #----------------------------------------------------------------------
20 #----------------------------------------------------------------------
21
21
22 import heapq
22 import heapq
23 import logging
23 import logging
24 import sys
24 import sys
25 import time
25 import time
26
26
27 from datetime import datetime, timedelta
27 from datetime import datetime, timedelta
28 from random import randint, random
28 from random import randint, random
29 from types import FunctionType
29 from types import FunctionType
30
30
31 try:
31 try:
32 import numpy
32 import numpy
33 except ImportError:
33 except ImportError:
34 numpy = None
34 numpy = None
35
35
36 import zmq
36 import zmq
37 from zmq.eventloop import ioloop, zmqstream
37 from zmq.eventloop import ioloop, zmqstream
38
38
39 # local imports
39 # local imports
40 from IPython.external.decorator import decorator
40 from IPython.external.decorator import decorator
41 from IPython.config.application import Application
41 from IPython.config.application import Application
42 from IPython.config.loader import Config
42 from IPython.config.loader import Config
43 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
43 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
44 from IPython.utils.py3compat import cast_bytes
44 from IPython.utils.py3compat import cast_bytes
45
45
46 from IPython.parallel import error, util
46 from IPython.parallel import error, util
47 from IPython.parallel.factory import SessionFactory
47 from IPython.parallel.factory import SessionFactory
48 from IPython.parallel.util import connect_logger, local_logger
48 from IPython.parallel.util import connect_logger, local_logger
49
49
50 from .dependency import Dependency
50 from .dependency import Dependency
51
51
@decorator
def logged(f,self,*args,**kwargs):
    """Decorator: log every call to the wrapped method at DEBUG level.

    Emits the wrapped function's name and its positional/keyword
    arguments, then delegates to the function unchanged.
    """
    # f.__name__ works on both Python 2 and 3; the old f.func_name
    # attribute is Python-2-only and was removed in Python 3.
    self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs)
    return f(self,*args, **kwargs)
58
58
59 #----------------------------------------------------------------------
59 #----------------------------------------------------------------------
60 # Chooser functions
60 # Chooser functions
61 #----------------------------------------------------------------------
61 #----------------------------------------------------------------------
62
62
def plainrandom(loads):
    """Choose an engine index uniformly at random.

    Only the length of `loads` matters; the load values themselves
    are never examined.
    """
    last = len(loads) - 1
    return randint(0, last)
67
67
def lru(loads):
    """Pick index 0 — the least-recently-used engine.

    `loads` is assumed to be maintained in LRU order (oldest entry
    first), so the head of the list is always the right choice; the
    contents of `loads` are ignored entirely.
    """
    return 0
76
76
def twobin(loads):
    """Draw two indices at random; return the LRU (lower) of the two.

    The load values are ignored; `loads` is assumed to be in LRU
    order with the oldest entry first, so the smaller index is the
    less recently used engine.
    """
    last = len(loads) - 1
    first_pick = randint(0, last)
    second_pick = randint(0, last)
    return first_pick if first_pick < second_pick else second_pick
88
88
def weighted(loads):
    """Pick two engines at random, weighting each by inverse load.

    Return whichever of the two picks has the larger weight, i.e. the
    smaller load; on equal weights the first pick wins.
    """
    # a load of 0 is weighted ~a million times heavier than a load of 1
    inv = 1./(1e-6+numpy.array(loads))
    cumulative = inv.cumsum()
    total = cumulative[-1]
    # two independent uniform draws over the cumulative weight range
    first = random()*total
    second = random()*total

    pick_a = 0
    while cumulative[pick_a] < first:
        pick_a += 1
    pick_b = 0
    while cumulative[pick_b] < second:
        pick_b += 1

    # prefer the less-loaded of the two picks; ties favor the first
    return pick_b if inv[pick_b] > inv[pick_a] else pick_a
110
110
def leastload(loads):
    """Return the index of the minimum load.

    Ties break toward the earliest occurrence, which — given LRU
    ordering of `loads` — is also the least-recently-used of the
    engines sharing that minimum load.
    """
    smallest = min(loads)
    return loads.index(smallest)
119
119
120 #---------------------------------------------------------------------
120 #---------------------------------------------------------------------
121 # Classes
121 # Classes
122 #---------------------------------------------------------------------
122 #---------------------------------------------------------------------
123
123
124
124
# store empty default dependency:
# MET is an already-satisfied (empty) Dependency, used as a shared
# sentinel meaning "no outstanding time dependencies" so that trivially
# met dependencies skip unnecessary set comparisons.
MET = Dependency([])
127
127
128
128
class Job(object):
    """Lightweight record for one submitted task and its scheduling state."""

    def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
            targets, after, follow, timeout):
        # everything below is carried verbatim from the submission
        self.msg_id = msg_id
        self.raw_msg = raw_msg
        self.idents = idents
        self.msg = msg
        self.header = header
        self.metadata = metadata
        self.targets = targets
        self.after = after
        self.follow = follow
        self.timeout = timeout
        # lazy-delete marker: removed entries stay in the heap-sorted
        # queue and are skipped when they surface
        self.removed = False

        # submission time; gives the job its FIFO rank in the heap
        self.timestamp = time.time()
        # engine idents to avoid for this job; populated by the
        # scheduler, not here (TODO confirm exact usage at call sites)
        self.blacklist = set()

    def __lt__(self, other):
        # heapq orders with `<`: earlier submissions sort first
        return self.timestamp < other.timestamp

    def __cmp__(self, other):
        # Python 2 comparison fallback (`cmp` does not exist on Python 3)
        return cmp(self.timestamp, other.timestamp)

    @property
    def dependents(self):
        """Union of this job's follow- and after-dependency msg_ids."""
        return self.follow.union(self.after)
157
157
class TaskScheduler(SessionFactory):
    """Python TaskScheduler object.

    This is the simplest object that supports msg_id based
    DAG dependencies. *Only* task msg_ids are checked, not
    msg_ids of jobs submitted via the MUX queue.

    """

    hwm = Integer(1, config=True,
        help="""specify the High Water Mark (HWM) for the downstream
        socket in the Task scheduler. This is the maximum number
        of allowed outstanding tasks on each engine.

        The default (1) means that only one task can be outstanding on each
        engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
        engines continue to be assigned tasks while they are working,
        effectively hiding network latency behind computation, but can result
        in an imbalance of work when submitting many heterogenous tasks all at
        once. Any positive value greater than one is a compromise between the
        two.

        """
    )
    scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
        'leastload', config=True, allow_none=False,
        help="""select the task scheduler scheme [default: Python LRU]
        Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
    )
    def _scheme_name_changed(self, old, new):
        """Traitlets change handler: swap in the chooser function.

        Looks the new scheme up by name among the module-level chooser
        functions ('pure' is presumably handled before this class is
        instantiated — TODO confirm against the launcher).
        """
        self.log.debug("Using scheme %r"%new)
        self.scheme = globals()[new]

    # input arguments:
    scheme = Instance(FunctionType) # function for determining the destination
    def _scheme_default(self):
        """Default chooser: always pick the least-loaded engine."""
        return leastload
    client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
    engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
    notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
    mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
    query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream

    # internals:
    queue = List() # heap-sorted list of Jobs
    queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
    graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
    retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
    # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
    pending = Dict() # dict by engine_uuid of submitted tasks
    completed = Dict() # dict by engine_uuid of completed tasks
    failed = Dict() # dict by engine_uuid of failed tasks
    destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
    clients = Dict() # dict by msg_id for who submitted the task
    targets = List() # list of target IDENTs
    loads = List() # list of engine loads
    # full = Set() # set of IDENTs that have HWM outstanding tasks
    all_completed = Set() # set of all completed tasks
    all_failed = Set() # set of all failed tasks
    all_done = Set() # set of all finished tasks=union(completed,failed)
    all_ids = Set() # set of all submitted task IDs

    # periodic callback that expires timed-out tasks (created in start())
    auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')

    ident = CBytes() # ZMQ identity. This should just be self.session.session
                     # but ensure Bytes
    def _ident_default(self):
        """Default ZMQ identity: the bytes form of the session id."""
        return self.session.bsession
226
226
    def start(self):
        """Wire up all streams and begin scheduling.

        Registration order matters: the Hub query goes out first so
        already-registered engines are learned, then engine/client/
        notification streams get their handlers, and finally the
        timeout auditor starts ticking.
        """
        # ask the Hub which engines are already registered
        self.query_stream.on_recv(self.dispatch_query_reply)
        self.session.send(self.query_stream, "connection_request", {})

        self.engine_stream.on_recv(self.dispatch_result, copy=False)
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

        self._notification_handlers = dict(
            registration_notification = self._register_engine,
            unregistration_notification = self._unregister_engine
        )
        self.notifier_stream.on_recv(self.dispatch_notification)
        # audit task timeouts every 2000 ms (i.e. at 0.5 Hz)
        self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop)
        self.auditor.start()
        self.log.info("Scheduler started [%s]"%self.scheme_name)
242
242
    def resume_receiving(self):
        """Resume accepting jobs."""
        # re-attach the submission handler dropped by stop_receiving()
        self.client_stream.on_recv(self.dispatch_submission, copy=False)
246
246
    def stop_receiving(self):
        """Stop accepting jobs while there are no engines.
        Leave them in the ZMQ queue."""
        # clearing the handler leaves submissions buffered in ZMQ until
        # resume_receiving() re-attaches it
        self.client_stream.on_recv(None)
251
251
252 #-----------------------------------------------------------------------
252 #-----------------------------------------------------------------------
253 # [Un]Registration Handling
253 # [Un]Registration Handling
254 #-----------------------------------------------------------------------
254 #-----------------------------------------------------------------------
255
255
256
256
257 def dispatch_query_reply(self, msg):
257 def dispatch_query_reply(self, msg):
258 """handle reply to our initial connection request"""
258 """handle reply to our initial connection request"""
259 try:
259 try:
260 idents,msg = self.session.feed_identities(msg)
260 idents,msg = self.session.feed_identities(msg)
261 except ValueError:
261 except ValueError:
262 self.log.warn("task::Invalid Message: %r",msg)
262 self.log.warn("task::Invalid Message: %r",msg)
263 return
263 return
264 try:
264 try:
265 msg = self.session.unserialize(msg)
265 msg = self.session.unserialize(msg)
266 except ValueError:
266 except ValueError:
267 self.log.warn("task::Unauthorized message from: %r"%idents)
267 self.log.warn("task::Unauthorized message from: %r"%idents)
268 return
268 return
269
269
270 content = msg['content']
270 content = msg['content']
271 for uuid in content.get('engines', {}).values():
271 for uuid in content.get('engines', {}).values():
272 self._register_engine(cast_bytes(uuid))
272 self._register_engine(cast_bytes(uuid))
273
273
274
274
275 @util.log_errors
275 @util.log_errors
276 def dispatch_notification(self, msg):
276 def dispatch_notification(self, msg):
277 """dispatch register/unregister events."""
277 """dispatch register/unregister events."""
278 try:
278 try:
279 idents,msg = self.session.feed_identities(msg)
279 idents,msg = self.session.feed_identities(msg)
280 except ValueError:
280 except ValueError:
281 self.log.warn("task::Invalid Message: %r",msg)
281 self.log.warn("task::Invalid Message: %r",msg)
282 return
282 return
283 try:
283 try:
284 msg = self.session.unserialize(msg)
284 msg = self.session.unserialize(msg)
285 except ValueError:
285 except ValueError:
286 self.log.warn("task::Unauthorized message from: %r"%idents)
286 self.log.warn("task::Unauthorized message from: %r"%idents)
287 return
287 return
288
288
289 msg_type = msg['header']['msg_type']
289 msg_type = msg['header']['msg_type']
290
290
291 handler = self._notification_handlers.get(msg_type, None)
291 handler = self._notification_handlers.get(msg_type, None)
292 if handler is None:
292 if handler is None:
293 self.log.error("Unhandled message type: %r"%msg_type)
293 self.log.error("Unhandled message type: %r"%msg_type)
294 else:
294 else:
295 try:
295 try:
296 handler(cast_bytes(msg['content']['uuid']))
296 handler(cast_bytes(msg['content']['uuid']))
297 except Exception:
297 except Exception:
298 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
298 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
299
299
300 def _register_engine(self, uid):
300 def _register_engine(self, uid):
301 """New engine with ident `uid` became available."""
301 """New engine with ident `uid` became available."""
302 # head of the line:
302 # head of the line:
303 self.targets.insert(0,uid)
303 self.targets.insert(0,uid)
304 self.loads.insert(0,0)
304 self.loads.insert(0,0)
305
305
306 # initialize sets
306 # initialize sets
307 self.completed[uid] = set()
307 self.completed[uid] = set()
308 self.failed[uid] = set()
308 self.failed[uid] = set()
309 self.pending[uid] = {}
309 self.pending[uid] = {}
310
310
311 # rescan the graph:
311 # rescan the graph:
312 self.update_graph(None)
312 self.update_graph(None)
313
313
    def _unregister_engine(self, uid):
        """Existing engine with ident `uid` became unavailable.

        Removes the engine from the scheduling pool and schedules
        cleanup of any tasks still pending on it.
        """
        # NOTE(review): this branch is a no-op placeholder — nothing
        # special currently happens when the last engine leaves
        if len(self.targets) == 1:
            # this was our only engine
            pass

        # handle any potentially finished tasks:
        self.engine_stream.flush()

        # don't pop destinations, because they might be used later
        # map(self.destinations.pop, self.completed.pop(uid))
        # map(self.destinations.pop, self.failed.pop(uid))

        # prevent this engine from receiving work
        idx = self.targets.index(uid)
        self.targets.pop(idx)
        self.loads.pop(idx)

        # wait 5 seconds before cleaning up pending jobs, since the results might
        # still be incoming
        if self.pending[uid]:
            dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
            dc.start()
        else:
            # nothing pending: scrub the per-engine bookkeeping now
            self.completed.pop(uid)
            self.failed.pop(uid)
340
340
341
341
    def handle_stranded_tasks(self, engine):
        """Deal with jobs resident in an engine that died.

        For each task still pending on the dead engine, fabricate an
        `apply_reply` carrying an EngineError and feed it through the
        normal result path, so clients and bookkeeping see a failure.
        """
        lost = self.pending[engine]
        for msg_id in lost.keys():
            if msg_id not in self.pending[engine]:
                # prevent double-handling of messages
                # (dispatch_result below may mutate self.pending[engine])
                continue

            raw_msg = lost[msg_id].raw_msg
            idents,msg = self.session.feed_identities(raw_msg, copy=False)
            parent = self.session.unpack(msg[1].bytes)
            # reply must appear to come from the dead engine
            idents = [engine, idents[0]]

            # build fake error reply
            # raise/except so wrap_exception captures a real traceback
            try:
                raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
            except:
                content = error.wrap_exception()
            # build fake metadata
            md = dict(
                status=u'error',
                engine=engine,
                date=datetime.now(),
            )
            msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
            # NOTE(review): map() returns a list on Python 2; this code
            # presumably predates Python 3 support — confirm before porting
            raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
            # and dispatch it
            self.dispatch_result(raw_reply)

        # finally scrub completed/failed lists
        self.completed.pop(engine)
        self.failed.pop(engine)
374
374
375
375
376 #-----------------------------------------------------------------------
376 #-----------------------------------------------------------------------
377 # Job Submission
377 # Job Submission
378 #-----------------------------------------------------------------------
378 #-----------------------------------------------------------------------
379
379
380
380
    @util.log_errors
    def dispatch_submission(self, raw_msg):
        """Dispatch job submission to appropriate handlers.

        Unpacks the client message, records bookkeeping, resolves time
        (`after`) and location (`follow`) dependencies, and then either
        runs the job immediately or saves it as unmet.
        """
        # ensure targets up to date:
        self.notifier_stream.flush()
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unserialize(msg, content=False, copy=False)
        except Exception:
            # NOTE(review): "Invaid" is a typo in the log text (runtime
            # string left untouched here)
            self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
            return


        # send to monitor
        self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)

        header = msg['header']
        md = msg['metadata']
        msg_id = header['msg_id']
        self.all_ids.add(msg_id)

        # get targets as a set of bytes objects
        # from a list of unicode objects
        targets = md.get('targets', [])
        targets = map(cast_bytes, targets)
        targets = set(targets)

        # retries remaining for this task (non-negative int)
        retries = md.get('retries', 0)
        self.retries[msg_id] = retries

        # time dependencies
        after = md.get('after', None)
        if after:
            after = Dependency(after)
            if after.all:
                # drop already-finished msg_ids so later checks shrink
                if after.success:
                    after = Dependency(after.difference(self.all_completed),
                                success=after.success,
                                failure=after.failure,
                                all=after.all,
                    )
                if after.failure:
                    after = Dependency(after.difference(self.all_failed),
                                success=after.success,
                                failure=after.failure,
                                all=after.all,
                    )
            if after.check(self.all_completed, self.all_failed):
                # recast as empty set, if `after` already met,
                # to prevent unnecessary set comparisons
                after = MET
        else:
            after = MET

        # location dependencies
        follow = Dependency(md.get('follow', []))

        # turn timeouts into datetime objects:
        timeout = md.get('timeout', None)
        if timeout:
            # cast to float, because jsonlib returns floats as decimal.Decimal,
            # which timedelta does not accept
            timeout = datetime.now() + timedelta(0,float(timeout),0)

        job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
                header=header, targets=targets, after=after, follow=follow,
                timeout=timeout, metadata=md,
        )

        # validate and reduce dependencies:
        for dep in after,follow:
            if not dep: # empty dependency
                continue
            # check valid:
            # a job depending on itself, or on an unknown msg_id, can never run
            if msg_id in dep or dep.difference(self.all_ids):
                self.queue_map[msg_id] = job
                return self.fail_unreachable(msg_id, error.InvalidDependency)
            # check if unreachable:
            if dep.unreachable(self.all_completed, self.all_failed):
                self.queue_map[msg_id] = job
                return self.fail_unreachable(msg_id)

        if after.check(self.all_completed, self.all_failed):
            # time deps already met, try to run
            if not self.maybe_run(job):
                # can't run yet
                if msg_id not in self.all_failed:
                    # could have failed as unreachable
                    self.save_unmet(job)
        else:
            self.save_unmet(job)
472
472
473 def audit_timeouts(self):
473 def audit_timeouts(self):
474 """Audit all waiting tasks for expired timeouts."""
474 """Audit all waiting tasks for expired timeouts."""
475 now = datetime.now()
475 now = datetime.now()
476 for msg_id in self.queue_map.keys():
476 for msg_id in self.queue_map.keys():
477 # must recheck, in case one failure cascaded to another:
477 # must recheck, in case one failure cascaded to another:
478 if msg_id in self.queue_map:
478 if msg_id in self.queue_map:
479 job = self.queue_map[msg_id]
479 job = self.queue_map[msg_id]
480 if job.timeout and job.timeout < now:
480 if job.timeout and job.timeout < now:
481 self.fail_unreachable(msg_id, error.TaskTimeout)
481 self.fail_unreachable(msg_id, error.TaskTimeout)
482
482
    def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
        """a task has become unreachable, send a reply with an ImpossibleDependency
        error.

        Removes the job from the queue, detaches it from the dependency
        graph, records it as failed, and sends the error reply to both
        the client and the Hub monitor.
        """
        if msg_id not in self.queue_map:
            self.log.error("msg %r already failed!", msg_id)
            return
        job = self.queue_map.pop(msg_id)
        # lazy-delete from the queue
        job.removed = True
        # unhook this job from everything it was waiting on
        for mid in job.dependents:
            if mid in self.graph:
                self.graph[mid].remove(msg_id)

        # raise/except so wrap_exception captures a real traceback
        try:
            raise why()
        except:
            content = error.wrap_exception()

        self.all_done.add(msg_id)
        self.all_failed.add(msg_id)

        # reply to the client, and mirror the reply to the monitor
        msg = self.session.send(self.client_stream, 'apply_reply', content,
                                parent=job.header, ident=job.idents)
        self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)

        # propagate the failure to jobs that depended on this one
        self.update_graph(msg_id, success=False)
509
509
510 def available_engines(self):
510 def available_engines(self):
511 """return a list of available engine indices based on HWM"""
511 """return a list of available engine indices based on HWM"""
512 if not self.hwm:
512 if not self.hwm:
513 return range(len(self.targets))
513 return range(len(self.targets))
514 available = []
514 available = []
515 for idx in range(len(self.targets)):
515 for idx in range(len(self.targets)):
516 if self.loads[idx] < self.hwm:
516 if self.loads[idx] < self.hwm:
517 available.append(idx)
517 available.append(idx)
518 return available
518 return available
519
519
520 def maybe_run(self, job):
520 def maybe_run(self, job):
521 """check location dependencies, and run if they are met."""
521 """check location dependencies, and run if they are met."""
522 msg_id = job.msg_id
522 msg_id = job.msg_id
523 self.log.debug("Attempting to assign task %s", msg_id)
523 self.log.debug("Attempting to assign task %s", msg_id)
524 available = self.available_engines()
524 available = self.available_engines()
525 if not available:
525 if not available:
526 # no engines, definitely can't run
526 # no engines, definitely can't run
527 return False
527 return False
528
528
529 if job.follow or job.targets or job.blacklist or self.hwm:
529 if job.follow or job.targets or job.blacklist or self.hwm:
530 # we need a can_run filter
530 # we need a can_run filter
531 def can_run(idx):
531 def can_run(idx):
532 # check hwm
532 # check hwm
533 if self.hwm and self.loads[idx] == self.hwm:
533 if self.hwm and self.loads[idx] == self.hwm:
534 return False
534 return False
535 target = self.targets[idx]
535 target = self.targets[idx]
536 # check blacklist
536 # check blacklist
537 if target in job.blacklist:
537 if target in job.blacklist:
538 return False
538 return False
539 # check targets
539 # check targets
540 if job.targets and target not in job.targets:
540 if job.targets and target not in job.targets:
541 return False
541 return False
542 # check follow
542 # check follow
543 return job.follow.check(self.completed[target], self.failed[target])
543 return job.follow.check(self.completed[target], self.failed[target])
544
544
545 indices = filter(can_run, available)
545 indices = filter(can_run, available)
546
546
547 if not indices:
547 if not indices:
548 # couldn't run
548 # couldn't run
549 if job.follow.all:
549 if job.follow.all:
550 # check follow for impossibility
550 # check follow for impossibility
551 dests = set()
551 dests = set()
552 relevant = set()
552 relevant = set()
553 if job.follow.success:
553 if job.follow.success:
554 relevant = self.all_completed
554 relevant = self.all_completed
555 if job.follow.failure:
555 if job.follow.failure:
556 relevant = relevant.union(self.all_failed)
556 relevant = relevant.union(self.all_failed)
557 for m in job.follow.intersection(relevant):
557 for m in job.follow.intersection(relevant):
558 dests.add(self.destinations[m])
558 dests.add(self.destinations[m])
559 if len(dests) > 1:
559 if len(dests) > 1:
560 self.queue_map[msg_id] = job
560 self.queue_map[msg_id] = job
561 self.fail_unreachable(msg_id)
561 self.fail_unreachable(msg_id)
562 return False
562 return False
563 if job.targets:
563 if job.targets:
564 # check blacklist+targets for impossibility
564 # check blacklist+targets for impossibility
565 job.targets.difference_update(job.blacklist)
565 job.targets.difference_update(job.blacklist)
566 if not job.targets or not job.targets.intersection(self.targets):
566 if not job.targets or not job.targets.intersection(self.targets):
567 self.queue_map[msg_id] = job
567 self.queue_map[msg_id] = job
568 self.fail_unreachable(msg_id)
568 self.fail_unreachable(msg_id)
569 return False
569 return False
570 return False
570 return False
571 else:
571 else:
572 indices = None
572 indices = None
573
573
574 self.submit_task(job, indices)
574 self.submit_task(job, indices)
575 return True
575 return True
576
576
577 def save_unmet(self, job):
577 def save_unmet(self, job):
578 """Save a message for later submission when its dependencies are met."""
578 """Save a message for later submission when its dependencies are met."""
579 msg_id = job.msg_id
579 msg_id = job.msg_id
580 self.log.debug("Adding task %s to the queue", msg_id)
580 self.queue_map[msg_id] = job
581 self.queue_map[msg_id] = job
581 heapq.heappush(self.queue, job)
582 heapq.heappush(self.queue, job)
582 # track the ids in follow or after, but not those already finished
583 # track the ids in follow or after, but not those already finished
583 for dep_id in job.after.union(job.follow).difference(self.all_done):
584 for dep_id in job.after.union(job.follow).difference(self.all_done):
584 if dep_id not in self.graph:
585 if dep_id not in self.graph:
585 self.graph[dep_id] = set()
586 self.graph[dep_id] = set()
586 self.graph[dep_id].add(msg_id)
587 self.graph[dep_id].add(msg_id)
587
588
588 def submit_task(self, job, indices=None):
589 def submit_task(self, job, indices=None):
589 """Submit a task to any of a subset of our targets."""
590 """Submit a task to any of a subset of our targets."""
590 if indices:
591 if indices:
591 loads = [self.loads[i] for i in indices]
592 loads = [self.loads[i] for i in indices]
592 else:
593 else:
593 loads = self.loads
594 loads = self.loads
594 idx = self.scheme(loads)
595 idx = self.scheme(loads)
595 if indices:
596 if indices:
596 idx = indices[idx]
597 idx = indices[idx]
597 target = self.targets[idx]
598 target = self.targets[idx]
598 # print (target, map(str, msg[:3]))
599 # print (target, map(str, msg[:3]))
599 # send job to the engine
600 # send job to the engine
600 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
601 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
601 self.engine_stream.send_multipart(job.raw_msg, copy=False)
602 self.engine_stream.send_multipart(job.raw_msg, copy=False)
602 # update load
603 # update load
603 self.add_job(idx)
604 self.add_job(idx)
604 self.pending[target][job.msg_id] = job
605 self.pending[target][job.msg_id] = job
605 # notify Hub
606 # notify Hub
606 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
607 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
607 self.session.send(self.mon_stream, 'task_destination', content=content,
608 self.session.send(self.mon_stream, 'task_destination', content=content,
608 ident=[b'tracktask',self.ident])
609 ident=[b'tracktask',self.ident])
609
610
610
611
611 #-----------------------------------------------------------------------
612 #-----------------------------------------------------------------------
612 # Result Handling
613 # Result Handling
613 #-----------------------------------------------------------------------
614 #-----------------------------------------------------------------------
614
615
615
616
616 @util.log_errors
617 @util.log_errors
617 def dispatch_result(self, raw_msg):
618 def dispatch_result(self, raw_msg):
618 """dispatch method for result replies"""
619 """dispatch method for result replies"""
619 try:
620 try:
620 idents,msg = self.session.feed_identities(raw_msg, copy=False)
621 idents,msg = self.session.feed_identities(raw_msg, copy=False)
621 msg = self.session.unserialize(msg, content=False, copy=False)
622 msg = self.session.unserialize(msg, content=False, copy=False)
622 engine = idents[0]
623 engine = idents[0]
623 try:
624 try:
624 idx = self.targets.index(engine)
625 idx = self.targets.index(engine)
625 except ValueError:
626 except ValueError:
626 pass # skip load-update for dead engines
627 pass # skip load-update for dead engines
627 else:
628 else:
628 self.finish_job(idx)
629 self.finish_job(idx)
629 except Exception:
630 except Exception:
630 self.log.error("task::Invaid result: %r", raw_msg, exc_info=True)
631 self.log.error("task::Invaid result: %r", raw_msg, exc_info=True)
631 return
632 return
632
633
633 md = msg['metadata']
634 md = msg['metadata']
634 parent = msg['parent_header']
635 parent = msg['parent_header']
635 if md.get('dependencies_met', True):
636 if md.get('dependencies_met', True):
636 success = (md['status'] == 'ok')
637 success = (md['status'] == 'ok')
637 msg_id = parent['msg_id']
638 msg_id = parent['msg_id']
638 retries = self.retries[msg_id]
639 retries = self.retries[msg_id]
639 if not success and retries > 0:
640 if not success and retries > 0:
640 # failed
641 # failed
641 self.retries[msg_id] = retries - 1
642 self.retries[msg_id] = retries - 1
642 self.handle_unmet_dependency(idents, parent)
643 self.handle_unmet_dependency(idents, parent)
643 else:
644 else:
644 del self.retries[msg_id]
645 del self.retries[msg_id]
645 # relay to client and update graph
646 # relay to client and update graph
646 self.handle_result(idents, parent, raw_msg, success)
647 self.handle_result(idents, parent, raw_msg, success)
647 # send to Hub monitor
648 # send to Hub monitor
648 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
649 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
649 else:
650 else:
650 self.handle_unmet_dependency(idents, parent)
651 self.handle_unmet_dependency(idents, parent)
651
652
652 def handle_result(self, idents, parent, raw_msg, success=True):
653 def handle_result(self, idents, parent, raw_msg, success=True):
653 """handle a real task result, either success or failure"""
654 """handle a real task result, either success or failure"""
654 # first, relay result to client
655 # first, relay result to client
655 engine = idents[0]
656 engine = idents[0]
656 client = idents[1]
657 client = idents[1]
657 # swap_ids for ROUTER-ROUTER mirror
658 # swap_ids for ROUTER-ROUTER mirror
658 raw_msg[:2] = [client,engine]
659 raw_msg[:2] = [client,engine]
659 # print (map(str, raw_msg[:4]))
660 # print (map(str, raw_msg[:4]))
660 self.client_stream.send_multipart(raw_msg, copy=False)
661 self.client_stream.send_multipart(raw_msg, copy=False)
661 # now, update our data structures
662 # now, update our data structures
662 msg_id = parent['msg_id']
663 msg_id = parent['msg_id']
663 self.pending[engine].pop(msg_id)
664 self.pending[engine].pop(msg_id)
664 if success:
665 if success:
665 self.completed[engine].add(msg_id)
666 self.completed[engine].add(msg_id)
666 self.all_completed.add(msg_id)
667 self.all_completed.add(msg_id)
667 else:
668 else:
668 self.failed[engine].add(msg_id)
669 self.failed[engine].add(msg_id)
669 self.all_failed.add(msg_id)
670 self.all_failed.add(msg_id)
670 self.all_done.add(msg_id)
671 self.all_done.add(msg_id)
671 self.destinations[msg_id] = engine
672 self.destinations[msg_id] = engine
672
673
673 self.update_graph(msg_id, success)
674 self.update_graph(msg_id, success)
674
675
675 def handle_unmet_dependency(self, idents, parent):
676 def handle_unmet_dependency(self, idents, parent):
676 """handle an unmet dependency"""
677 """handle an unmet dependency"""
677 engine = idents[0]
678 engine = idents[0]
678 msg_id = parent['msg_id']
679 msg_id = parent['msg_id']
679
680
680 job = self.pending[engine].pop(msg_id)
681 job = self.pending[engine].pop(msg_id)
681 job.blacklist.add(engine)
682 job.blacklist.add(engine)
682
683
683 if job.blacklist == job.targets:
684 if job.blacklist == job.targets:
684 self.queue_map[msg_id] = job
685 self.queue_map[msg_id] = job
685 self.fail_unreachable(msg_id)
686 self.fail_unreachable(msg_id)
686 elif not self.maybe_run(job):
687 elif not self.maybe_run(job):
687 # resubmit failed
688 # resubmit failed
688 if msg_id not in self.all_failed:
689 if msg_id not in self.all_failed:
689 # put it back in our dependency tree
690 # put it back in our dependency tree
690 self.save_unmet(job)
691 self.save_unmet(job)
691
692
692 if self.hwm:
693 if self.hwm:
693 try:
694 try:
694 idx = self.targets.index(engine)
695 idx = self.targets.index(engine)
695 except ValueError:
696 except ValueError:
696 pass # skip load-update for dead engines
697 pass # skip load-update for dead engines
697 else:
698 else:
698 if self.loads[idx] == self.hwm-1:
699 if self.loads[idx] == self.hwm-1:
699 self.update_graph(None)
700 self.update_graph(None)
700
701
701 def update_graph(self, dep_id=None, success=True):
702 def update_graph(self, dep_id=None, success=True):
702 """dep_id just finished. Update our dependency
703 """dep_id just finished. Update our dependency
703 graph and submit any jobs that just became runnable.
704 graph and submit any jobs that just became runnable.
704
705
705 Called with dep_id=None to update entire graph for hwm, but without finishing a task.
706 Called with dep_id=None to update entire graph for hwm, but without finishing a task.
706 """
707 """
707 # print ("\n\n***********")
708 # print ("\n\n***********")
708 # pprint (dep_id)
709 # pprint (dep_id)
709 # pprint (self.graph)
710 # pprint (self.graph)
710 # pprint (self.queue_map)
711 # pprint (self.queue_map)
711 # pprint (self.all_completed)
712 # pprint (self.all_completed)
712 # pprint (self.all_failed)
713 # pprint (self.all_failed)
713 # print ("\n\n***********\n\n")
714 # print ("\n\n***********\n\n")
714 # update any jobs that depended on the dependency
715 # update any jobs that depended on the dependency
715 msg_ids = self.graph.pop(dep_id, [])
716 msg_ids = self.graph.pop(dep_id, [])
716
717
717 # recheck *all* jobs if
718 # recheck *all* jobs if
718 # a) we have HWM and an engine just become no longer full
719 # a) we have HWM and an engine just become no longer full
719 # or b) dep_id was given as None
720 # or b) dep_id was given as None
720
721
721 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
722 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
722 jobs = self.queue
723 jobs = self.queue
723 using_queue = True
724 using_queue = True
724 else:
725 else:
725 using_queue = False
726 using_queue = False
726 jobs = heapq.heapify([ self.queue_map[msg_id] for msg_id in msg_ids ])
727 jobs = heapq.heapify([ self.queue_map[msg_id] for msg_id in msg_ids ])
727
728
728 to_restore = []
729 to_restore = []
729 while jobs:
730 while jobs:
730 job = heapq.heappop(jobs)
731 job = heapq.heappop(jobs)
731 if job.removed:
732 if job.removed:
732 continue
733 continue
733 msg_id = job.msg_id
734 msg_id = job.msg_id
734
735
735 put_it_back = True
736 put_it_back = True
736
737
737 if job.after.unreachable(self.all_completed, self.all_failed)\
738 if job.after.unreachable(self.all_completed, self.all_failed)\
738 or job.follow.unreachable(self.all_completed, self.all_failed):
739 or job.follow.unreachable(self.all_completed, self.all_failed):
739 self.fail_unreachable(msg_id)
740 self.fail_unreachable(msg_id)
740 put_it_back = False
741 put_it_back = False
741
742
742 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
743 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
743 if self.maybe_run(job):
744 if self.maybe_run(job):
744 put_it_back = False
745 put_it_back = False
745 self.queue_map.pop(msg_id)
746 self.queue_map.pop(msg_id)
746 for mid in job.dependents:
747 for mid in job.dependents:
747 if mid in self.graph:
748 if mid in self.graph:
748 self.graph[mid].remove(msg_id)
749 self.graph[mid].remove(msg_id)
749
750
750 # abort the loop if we just filled up all of our engines.
751 # abort the loop if we just filled up all of our engines.
751 # avoids an O(N) operation in situation of full queue,
752 # avoids an O(N) operation in situation of full queue,
752 # where graph update is triggered as soon as an engine becomes
753 # where graph update is triggered as soon as an engine becomes
753 # non-full, and all tasks after the first are checked,
754 # non-full, and all tasks after the first are checked,
754 # even though they can't run.
755 # even though they can't run.
755 if not self.available_engines():
756 if not self.available_engines():
756 break
757 break
757
758
758 if using_queue and put_it_back:
759 if using_queue and put_it_back:
759 # popped a job from the queue but it neither ran nor failed,
760 # popped a job from the queue but it neither ran nor failed,
760 # so we need to put it back when we are done
761 # so we need to put it back when we are done
761 to_restore.append(job)
762 to_restore.append(job)
762
763
763 # put back any tasks we popped but didn't run
764 # put back any tasks we popped but didn't run
764 for job in to_restore:
765 for job in to_restore:
765 heapq.heappush(self.queue, job)
766 heapq.heappush(self.queue, job)
766
767
767
768
768 #----------------------------------------------------------------------
769 #----------------------------------------------------------------------
769 # methods to be overridden by subclasses
770 # methods to be overridden by subclasses
770 #----------------------------------------------------------------------
771 #----------------------------------------------------------------------
771
772
772 def add_job(self, idx):
773 def add_job(self, idx):
773 """Called after self.targets[idx] just got the job with header.
774 """Called after self.targets[idx] just got the job with header.
774 Override with subclasses. The default ordering is simple LRU.
775 Override with subclasses. The default ordering is simple LRU.
775 The default loads are the number of outstanding jobs."""
776 The default loads are the number of outstanding jobs."""
776 self.loads[idx] += 1
777 self.loads[idx] += 1
777 for lis in (self.targets, self.loads):
778 for lis in (self.targets, self.loads):
778 lis.append(lis.pop(idx))
779 lis.append(lis.pop(idx))
779
780
780
781
781 def finish_job(self, idx):
782 def finish_job(self, idx):
782 """Called after self.targets[idx] just finished a job.
783 """Called after self.targets[idx] just finished a job.
783 Override with subclasses."""
784 Override with subclasses."""
784 self.loads[idx] -= 1
785 self.loads[idx] -= 1
785
786
786
787
787
788
def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
                        logname='root', log_url=None, loglevel=logging.DEBUG,
                        identity=b'task', in_thread=False):
    """Build the zmq plumbing, create a TaskScheduler, and start it.

    When *in_thread* is True the scheduler shares the parent's zmq Context
    and IOLoop; otherwise it creates its own and runs the loop until
    interrupted.
    """
    ZMQStream = zmqstream.ZMQStream

    if config:
        # unwrap dict back into Config
        config = Config(config)

    if in_thread:
        # use instance() to get the same Context/Loop as our parent
        ctx = zmq.Context.instance()
        loop = ioloop.IOLoop.instance()
    else:
        # in a process, don't use instance()
        # for safety with multiprocessing
        ctx = zmq.Context()
        loop = ioloop.IOLoop()

    client_stream = ZMQStream(ctx.socket(zmq.ROUTER), loop)
    client_stream.setsockopt(zmq.IDENTITY, identity + b'_in')
    client_stream.bind(in_addr)

    engine_stream = ZMQStream(ctx.socket(zmq.ROUTER), loop)
    engine_stream.setsockopt(zmq.IDENTITY, identity + b'_out')
    engine_stream.bind(out_addr)

    mon_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
    mon_stream.connect(mon_addr)

    notifier_stream = zmqstream.ZMQStream(ctx.socket(zmq.SUB), loop)
    notifier_stream.setsockopt(zmq.SUBSCRIBE, b'')
    notifier_stream.connect(not_addr)

    query_stream = ZMQStream(ctx.socket(zmq.DEALER), loop)
    query_stream.connect(reg_addr)

    # setup logging.
    if in_thread:
        log = Application.instance().log
    elif log_url:
        log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
    else:
        log = local_logger(logname, loglevel)

    scheduler = TaskScheduler(client_stream=client_stream, engine_stream=engine_stream,
                            mon_stream=mon_stream, notifier_stream=notifier_stream,
                            query_stream=query_stream,
                            loop=loop, log=log,
                            config=config)
    scheduler.start()
    if not in_thread:
        try:
            loop.start()
        except KeyboardInterrupt:
            scheduler.log.critical("Interrupted, exiting...")
843
844
General Comments 0
You need to be logged in to leave comments. Login now