##// END OF EJS Templates
Simplify structure of a Job in the TaskScheduler...
MinRK -
Show More
@@ -1,738 +1,759 b''
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6
7 7 Authors:
8 8
9 9 * Min RK
10 10 """
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2010-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #----------------------------------------------------------------------
19 19 # Imports
20 20 #----------------------------------------------------------------------
21 21
22 22 from __future__ import print_function
23 23
24 24 import logging
25 25 import sys
26 import time
26 27
27 28 from datetime import datetime, timedelta
28 29 from random import randint, random
29 30 from types import FunctionType
30 31
31 32 try:
32 33 import numpy
33 34 except ImportError:
34 35 numpy = None
35 36
36 37 import zmq
37 38 from zmq.eventloop import ioloop, zmqstream
38 39
39 40 # local imports
40 41 from IPython.external.decorator import decorator
41 42 from IPython.config.application import Application
42 43 from IPython.config.loader import Config
43 44 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
44 45
45 46 from IPython.parallel import error
46 47 from IPython.parallel.factory import SessionFactory
47 48 from IPython.parallel.util import connect_logger, local_logger, asbytes
48 49
49 50 from .dependency import Dependency
50 51
@decorator
def logged(f, self, *args, **kwargs):
    """Decorator: log every wrapped scheduler call at debug level, then run it.

    Expects to wrap methods of an object exposing ``self.log``.
    """
    # f.__name__ is equivalent to the legacy py2-only f.func_name alias,
    # and keeps this working under Python 3 as well.
    self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs)
    return f(self, *args, **kwargs)
57 58
58 59 #----------------------------------------------------------------------
59 60 # Chooser functions
60 61 #----------------------------------------------------------------------
61 62
def plainrandom(loads):
    """Pick an engine index uniformly at random.

    Only the length of `loads` matters, not its contents.
    """
    return randint(0, len(loads) - 1)
66 67
def lru(loads):
    """Always pick the front of the line.

    The content of `loads` is ignored entirely; the list is assumed to be
    in LRU order (oldest first), so index 0 is the least recently used.
    """
    return 0
75 76
def twobin(loads):
    """Pick two indices at random, keep the least-recently-used of the two.

    The content of `loads` is ignored; LRU ordering (oldest first) is
    assumed, so the smaller index is the older engine.
    """
    last = len(loads) - 1
    return min(randint(0, last), randint(0, last))
87 88
def weighted(loads):
    """Pick two engines by inverse-load-weighted roulette selection.

    Return the index of the less loaded of the two picks.
    """
    # weight 0 a million times more than 1:
    weights = 1./(1e-6+numpy.array(loads))
    sums = weights.cumsum()
    total = sums[-1]
    # two roulette draws, resolved in draw order (same RNG consumption
    # as two independent picks)
    picks = []
    for draw in (random() * total, random() * total):
        i = 0
        while sums[i] < draw:
            i += 1
        picks.append(i)
    idx, idy = picks
    return idy if weights[idy] > weights[idx] else idx
109 110
def leastload(loads):
    """Always choose the index of the lowest load.

    Ties resolve to the first occurrence, so with LRU-ordered loads the
    least recently used of the equally-loaded engines wins.
    """
    return loads.index(min(loads))
118 119
119 120 #---------------------------------------------------------------------
120 121 # Classes
121 122 #---------------------------------------------------------------------
123
124
# Shared empty dependency: jobs whose time dependencies are already
# satisfied all reuse this single instance to avoid set comparisons.
MET = Dependency([])
128
class Job(object):
    """Simple container for a job.

    Bundles everything the scheduler tracks per task: the wire message,
    routing identities, targeting constraints, and dependency state.
    """

    def __init__(self, msg_id, raw_msg, idents, msg, header, targets, after, follow, timeout):
        self.msg_id = msg_id
        self.raw_msg = raw_msg
        self.idents = idents
        self.msg = msg
        self.header = header
        self.targets = targets
        self.after = after
        self.follow = follow
        self.timeout = timeout

        # submission time, used to keep resubmission in FIFO order
        self.timestamp = time.time()
        # engines that raised UnmetDependency for this job
        self.blacklist = set()

    @property
    def dependents(self):
        """Every msg_id this job depends on (location + time deps)."""
        return self.follow.union(self.after)
149
class TaskScheduler(SessionFactory):
    """Python TaskScheduler object.

    This is the simplest object that supports msg_id based
    DAG dependencies. *Only* task msg_ids are checked, not
    msg_ids of jobs submitted via the MUX queue.

    """

    hwm = Integer(1, config=True,
        help="""specify the High Water Mark (HWM) for the downstream
        socket in the Task scheduler. This is the maximum number
        of allowed outstanding tasks on each engine.

        The default (1) means that only one task can be outstanding on each
        engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
        engines continue to be assigned tasks while they are working,
        effectively hiding network latency behind computation, but can result
        in an imbalance of work when submitting many heterogenous tasks all at
        once. Any positive value greater than one is a compromise between the
        two.

        """
    )
    scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
        'leastload', config=True, allow_none=False,
        help="""select the task scheduler scheme [default: Python LRU]
        Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
    )

    def _scheme_name_changed(self, old, new):
        # chooser functions live at module level; look the new one up by name
        self.log.debug("Using scheme %r"%new)
        self.scheme = globals()[new]

    # input arguments:
    scheme = Instance(FunctionType)  # function for determining the destination
    def _scheme_default(self):
        return leastload

    client_stream = Instance(zmqstream.ZMQStream)    # client-facing stream
    engine_stream = Instance(zmqstream.ZMQStream)    # engine-facing stream
    notifier_stream = Instance(zmqstream.ZMQStream)  # hub-facing sub stream
    mon_stream = Instance(zmqstream.ZMQStream)       # hub-facing pub stream

    # internals:
    graph = Dict()         # dict by msg_id of [ msg_ids that depend on key ]
    retries = Dict()       # dict by msg_id of retries remaining (non-neg ints)
    # waiting = List()     # list of msg_ids ready to run, but haven't due to HWM
    depending = Dict()     # dict by msg_id of Jobs
    pending = Dict()       # dict by engine_uuid of submitted tasks
    completed = Dict()     # dict by engine_uuid of completed tasks
    failed = Dict()        # dict by engine_uuid of failed tasks
    destinations = Dict()  # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
    clients = Dict()       # dict by msg_id for who submitted the task
    targets = List()       # list of target IDENTs
    loads = List()         # list of engine loads
    # full = Set()         # set of IDENTs that have HWM outstanding tasks
    all_completed = Set()  # set of all completed tasks
    all_failed = Set()     # set of all failed tasks
    all_done = Set()       # set of all finished tasks=union(completed,failed)
    all_ids = Set()        # set of all submitted task IDs

    auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')

    # ZMQ identity. This should just be self.session.session, but ensure Bytes
    ident = CBytes()
    def _ident_default(self):
        return self.session.bsession
192 217 def start(self):
193 218 self.engine_stream.on_recv(self.dispatch_result, copy=False)
194 219 self.client_stream.on_recv(self.dispatch_submission, copy=False)
195 220
196 221 self._notification_handlers = dict(
197 222 registration_notification = self._register_engine,
198 223 unregistration_notification = self._unregister_engine
199 224 )
200 225 self.notifier_stream.on_recv(self.dispatch_notification)
201 226 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # 1 Hz
202 227 self.auditor.start()
203 228 self.log.info("Scheduler started [%s]"%self.scheme_name)
204 229
205 230 def resume_receiving(self):
206 231 """Resume accepting jobs."""
207 232 self.client_stream.on_recv(self.dispatch_submission, copy=False)
208 233
209 234 def stop_receiving(self):
210 235 """Stop accepting jobs while there are no engines.
211 236 Leave them in the ZMQ queue."""
212 237 self.client_stream.on_recv(None)
213 238
214 239 #-----------------------------------------------------------------------
215 240 # [Un]Registration Handling
216 241 #-----------------------------------------------------------------------
217 242
218 243 def dispatch_notification(self, msg):
219 244 """dispatch register/unregister events."""
220 245 try:
221 246 idents,msg = self.session.feed_identities(msg)
222 247 except ValueError:
223 248 self.log.warn("task::Invalid Message: %r",msg)
224 249 return
225 250 try:
226 251 msg = self.session.unserialize(msg)
227 252 except ValueError:
228 253 self.log.warn("task::Unauthorized message from: %r"%idents)
229 254 return
230 255
231 256 msg_type = msg['header']['msg_type']
232 257
233 258 handler = self._notification_handlers.get(msg_type, None)
234 259 if handler is None:
235 260 self.log.error("Unhandled message type: %r"%msg_type)
236 261 else:
237 262 try:
238 263 handler(asbytes(msg['content']['queue']))
239 264 except Exception:
240 265 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
241 266
242 267 def _register_engine(self, uid):
243 268 """New engine with ident `uid` became available."""
244 269 # head of the line:
245 270 self.targets.insert(0,uid)
246 271 self.loads.insert(0,0)
247 272
248 273 # initialize sets
249 274 self.completed[uid] = set()
250 275 self.failed[uid] = set()
251 276 self.pending[uid] = {}
252 277
253 278 # rescan the graph:
254 279 self.update_graph(None)
255 280
256 281 def _unregister_engine(self, uid):
257 282 """Existing engine with ident `uid` became unavailable."""
258 283 if len(self.targets) == 1:
259 284 # this was our only engine
260 285 pass
261 286
262 287 # handle any potentially finished tasks:
263 288 self.engine_stream.flush()
264 289
265 290 # don't pop destinations, because they might be used later
266 291 # map(self.destinations.pop, self.completed.pop(uid))
267 292 # map(self.destinations.pop, self.failed.pop(uid))
268 293
269 294 # prevent this engine from receiving work
270 295 idx = self.targets.index(uid)
271 296 self.targets.pop(idx)
272 297 self.loads.pop(idx)
273 298
274 299 # wait 5 seconds before cleaning up pending jobs, since the results might
275 300 # still be incoming
276 301 if self.pending[uid]:
277 302 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
278 303 dc.start()
279 304 else:
280 305 self.completed.pop(uid)
281 306 self.failed.pop(uid)
282 307
283 308
284 309 def handle_stranded_tasks(self, engine):
285 310 """Deal with jobs resident in an engine that died."""
286 311 lost = self.pending[engine]
287 312 for msg_id in lost.keys():
288 313 if msg_id not in self.pending[engine]:
289 314 # prevent double-handling of messages
290 315 continue
291 316
292 317 raw_msg = lost[msg_id][0]
293 318 idents,msg = self.session.feed_identities(raw_msg, copy=False)
294 319 parent = self.session.unpack(msg[1].bytes)
295 320 idents = [engine, idents[0]]
296 321
297 322 # build fake error reply
298 323 try:
299 324 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
300 325 except:
301 326 content = error.wrap_exception()
302 327 # build fake header
303 328 header = dict(
304 329 status='error',
305 330 engine=engine,
306 331 date=datetime.now(),
307 332 )
308 333 msg = self.session.msg('apply_reply', content, parent=parent, subheader=header)
309 334 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
310 335 # and dispatch it
311 336 self.dispatch_result(raw_reply)
312 337
313 338 # finally scrub completed/failed lists
314 339 self.completed.pop(engine)
315 340 self.failed.pop(engine)
316 341
317 342
318 343 #-----------------------------------------------------------------------
319 344 # Job Submission
320 345 #-----------------------------------------------------------------------
321 346 def dispatch_submission(self, raw_msg):
322 347 """Dispatch job submission to appropriate handlers."""
323 348 # ensure targets up to date:
324 349 self.notifier_stream.flush()
325 350 try:
326 351 idents, msg = self.session.feed_identities(raw_msg, copy=False)
327 352 msg = self.session.unserialize(msg, content=False, copy=False)
328 353 except Exception:
329 354 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
330 355 return
331 356
332 357
333 358 # send to monitor
334 359 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
335 360
336 361 header = msg['header']
337 362 msg_id = header['msg_id']
338 363 self.all_ids.add(msg_id)
339 364
340 365 # get targets as a set of bytes objects
341 366 # from a list of unicode objects
342 367 targets = header.get('targets', [])
343 368 targets = map(asbytes, targets)
344 369 targets = set(targets)
345 370
346 371 retries = header.get('retries', 0)
347 372 self.retries[msg_id] = retries
348 373
349 374 # time dependencies
350 375 after = header.get('after', None)
351 376 if after:
352 377 after = Dependency(after)
353 378 if after.all:
354 379 if after.success:
355 380 after = Dependency(after.difference(self.all_completed),
356 381 success=after.success,
357 382 failure=after.failure,
358 383 all=after.all,
359 384 )
360 385 if after.failure:
361 386 after = Dependency(after.difference(self.all_failed),
362 387 success=after.success,
363 388 failure=after.failure,
364 389 all=after.all,
365 390 )
366 391 if after.check(self.all_completed, self.all_failed):
367 392 # recast as empty set, if `after` already met,
368 393 # to prevent unnecessary set comparisons
369 394 after = MET
370 395 else:
371 396 after = MET
372 397
373 398 # location dependencies
374 399 follow = Dependency(header.get('follow', []))
375 400
376 401 # turn timeouts into datetime objects:
377 402 timeout = header.get('timeout', None)
378 403 if timeout:
379 404 # cast to float, because jsonlib returns floats as decimal.Decimal,
380 405 # which timedelta does not accept
381 406 timeout = datetime.now() + timedelta(0,float(timeout),0)
382 407
383 args = [raw_msg, targets, after, follow, timeout]
408 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
409 header=header, targets=targets, after=after, follow=follow,
410 timeout=timeout,
411 )
384 412
385 413 # validate and reduce dependencies:
386 414 for dep in after,follow:
387 415 if not dep: # empty dependency
388 416 continue
389 417 # check valid:
390 418 if msg_id in dep or dep.difference(self.all_ids):
391 self.depending[msg_id] = args
419 self.depending[msg_id] = job
392 420 return self.fail_unreachable(msg_id, error.InvalidDependency)
393 421 # check if unreachable:
394 422 if dep.unreachable(self.all_completed, self.all_failed):
395 self.depending[msg_id] = args
423 self.depending[msg_id] = job
396 424 return self.fail_unreachable(msg_id)
397 425
398 426 if after.check(self.all_completed, self.all_failed):
399 427 # time deps already met, try to run
400 if not self.maybe_run(msg_id, *args):
428 if not self.maybe_run(job):
401 429 # can't run yet
402 430 if msg_id not in self.all_failed:
403 431 # could have failed as unreachable
404 self.save_unmet(msg_id, *args)
432 self.save_unmet(job)
405 433 else:
406 self.save_unmet(msg_id, *args)
434 self.save_unmet(job)
407 435
408 436 def audit_timeouts(self):
409 437 """Audit all waiting tasks for expired timeouts."""
410 438 now = datetime.now()
411 439 for msg_id in self.depending.keys():
412 440 # must recheck, in case one failure cascaded to another:
413 441 if msg_id in self.depending:
414 raw,after,targets,follow,timeout = self.depending[msg_id]
415 if timeout and timeout < now:
442 job = self.depending[msg_id]
443 if job.timeout and job.timeout < now:
416 444 self.fail_unreachable(msg_id, error.TaskTimeout)
417 445
418 446 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
419 447 """a task has become unreachable, send a reply with an ImpossibleDependency
420 448 error."""
421 449 if msg_id not in self.depending:
422 450 self.log.error("msg %r already failed!", msg_id)
423 451 return
424 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
425 for mid in follow.union(after):
452 job = self.depending.pop(msg_id)
453 for mid in job.dependents:
426 454 if mid in self.graph:
427 455 self.graph[mid].remove(msg_id)
428 456
429 # FIXME: unpacking a message I've already unpacked, but didn't save:
430 idents,msg = self.session.feed_identities(raw_msg, copy=False)
431 header = self.session.unpack(msg[1].bytes)
432
433 457 try:
434 458 raise why()
435 459 except:
436 460 content = error.wrap_exception()
437 461
438 462 self.all_done.add(msg_id)
439 463 self.all_failed.add(msg_id)
440 464
441 465 msg = self.session.send(self.client_stream, 'apply_reply', content,
442 parent=header, ident=idents)
443 self.session.send(self.mon_stream, msg, ident=[b'outtask']+idents)
466 parent=job.header, ident=job.idents)
467 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
444 468
445 469 self.update_graph(msg_id, success=False)
446 470
447 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
471 def maybe_run(self, job):
448 472 """check location dependencies, and run if they are met."""
473 msg_id = job.msg_id
449 474 self.log.debug("Attempting to assign task %s", msg_id)
450 475 if not self.targets:
451 476 # no engines, definitely can't run
452 477 return False
453 478
454 blacklist = self.blacklist.setdefault(msg_id, set())
455 if follow or targets or blacklist or self.hwm:
479 if job.follow or job.targets or job.blacklist or self.hwm:
456 480 # we need a can_run filter
457 481 def can_run(idx):
458 482 # check hwm
459 483 if self.hwm and self.loads[idx] == self.hwm:
460 484 return False
461 485 target = self.targets[idx]
462 486 # check blacklist
463 if target in blacklist:
487 if target in job.blacklist:
464 488 return False
465 489 # check targets
466 if targets and target not in targets:
490 if job.targets and target not in job.targets:
467 491 return False
468 492 # check follow
469 return follow.check(self.completed[target], self.failed[target])
493 return job.follow.check(self.completed[target], self.failed[target])
470 494
471 495 indices = filter(can_run, range(len(self.targets)))
472 496
473 497 if not indices:
474 498 # couldn't run
475 if follow.all:
499 if job.follow.all:
476 500 # check follow for impossibility
477 501 dests = set()
478 502 relevant = set()
479 if follow.success:
503 if job.follow.success:
480 504 relevant = self.all_completed
481 if follow.failure:
505 if job.follow.failure:
482 506 relevant = relevant.union(self.all_failed)
483 for m in follow.intersection(relevant):
507 for m in job.follow.intersection(relevant):
484 508 dests.add(self.destinations[m])
485 509 if len(dests) > 1:
486 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
510 self.depending[msg_id] = job
487 511 self.fail_unreachable(msg_id)
488 512 return False
489 if targets:
513 if job.targets:
490 514 # check blacklist+targets for impossibility
491 targets.difference_update(blacklist)
492 if not targets or not targets.intersection(self.targets):
493 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
515 job.targets.difference_update(blacklist)
516 if not job.targets or not job.targets.intersection(self.targets):
517 self.depending[msg_id] = job
494 518 self.fail_unreachable(msg_id)
495 519 return False
496 520 return False
497 521 else:
498 522 indices = None
499 523
500 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
524 self.submit_task(job, indices)
501 525 return True
502 526
503 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
527 def save_unmet(self, job):
504 528 """Save a message for later submission when its dependencies are met."""
505 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
529 msg_id = job.msg_id
530 self.depending[msg_id] = job
506 531 # track the ids in follow or after, but not those already finished
507 for dep_id in after.union(follow).difference(self.all_done):
532 for dep_id in job.after.union(job.follow).difference(self.all_done):
508 533 if dep_id not in self.graph:
509 534 self.graph[dep_id] = set()
510 535 self.graph[dep_id].add(msg_id)
511 536
512 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
537 def submit_task(self, job, indices=None):
513 538 """Submit a task to any of a subset of our targets."""
514 539 if indices:
515 540 loads = [self.loads[i] for i in indices]
516 541 else:
517 542 loads = self.loads
518 543 idx = self.scheme(loads)
519 544 if indices:
520 545 idx = indices[idx]
521 546 target = self.targets[idx]
522 547 # print (target, map(str, msg[:3]))
523 548 # send job to the engine
524 549 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
525 self.engine_stream.send_multipart(raw_msg, copy=False)
550 self.engine_stream.send_multipart(job.raw_msg, copy=False)
526 551 # update load
527 552 self.add_job(idx)
528 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
553 self.pending[target][job.msg_id] = job
529 554 # notify Hub
530 content = dict(msg_id=msg_id, engine_id=target.decode('ascii'))
555 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
531 556 self.session.send(self.mon_stream, 'task_destination', content=content,
532 557 ident=[b'tracktask',self.ident])
533 558
534 559
535 560 #-----------------------------------------------------------------------
536 561 # Result Handling
537 562 #-----------------------------------------------------------------------
538 563 def dispatch_result(self, raw_msg):
539 564 """dispatch method for result replies"""
540 565 try:
541 566 idents,msg = self.session.feed_identities(raw_msg, copy=False)
542 567 msg = self.session.unserialize(msg, content=False, copy=False)
543 568 engine = idents[0]
544 569 try:
545 570 idx = self.targets.index(engine)
546 571 except ValueError:
547 572 pass # skip load-update for dead engines
548 573 else:
549 574 self.finish_job(idx)
550 575 except Exception:
551 576 self.log.error("task::Invaid result: %r", raw_msg, exc_info=True)
552 577 return
553 578
554 579 header = msg['header']
555 580 parent = msg['parent_header']
556 581 if header.get('dependencies_met', True):
557 582 success = (header['status'] == 'ok')
558 583 msg_id = parent['msg_id']
559 584 retries = self.retries[msg_id]
560 585 if not success and retries > 0:
561 586 # failed
562 587 self.retries[msg_id] = retries - 1
563 588 self.handle_unmet_dependency(idents, parent)
564 589 else:
565 590 del self.retries[msg_id]
566 591 # relay to client and update graph
567 592 self.handle_result(idents, parent, raw_msg, success)
568 593 # send to Hub monitor
569 594 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
570 595 else:
571 596 self.handle_unmet_dependency(idents, parent)
572 597
573 598 def handle_result(self, idents, parent, raw_msg, success=True):
574 599 """handle a real task result, either success or failure"""
575 600 # first, relay result to client
576 601 engine = idents[0]
577 602 client = idents[1]
578 603 # swap_ids for XREP-XREP mirror
579 604 raw_msg[:2] = [client,engine]
580 605 # print (map(str, raw_msg[:4]))
581 606 self.client_stream.send_multipart(raw_msg, copy=False)
582 607 # now, update our data structures
583 608 msg_id = parent['msg_id']
584 self.blacklist.pop(msg_id, None)
585 609 self.pending[engine].pop(msg_id)
586 610 if success:
587 611 self.completed[engine].add(msg_id)
588 612 self.all_completed.add(msg_id)
589 613 else:
590 614 self.failed[engine].add(msg_id)
591 615 self.all_failed.add(msg_id)
592 616 self.all_done.add(msg_id)
593 617 self.destinations[msg_id] = engine
594 618
595 619 self.update_graph(msg_id, success)
596 620
597 621 def handle_unmet_dependency(self, idents, parent):
598 622 """handle an unmet dependency"""
599 623 engine = idents[0]
600 624 msg_id = parent['msg_id']
601 625
602 if msg_id not in self.blacklist:
603 self.blacklist[msg_id] = set()
604 self.blacklist[msg_id].add(engine)
605
606 args = self.pending[engine].pop(msg_id)
607 raw,targets,after,follow,timeout = args
626 job = self.pending[engine].pop(msg_id)
627 job.blacklist.add(engine)
608 628
609 if self.blacklist[msg_id] == targets:
610 self.depending[msg_id] = args
629 if job.blacklist == job.targets:
630 self.depending[msg_id] = job
611 631 self.fail_unreachable(msg_id)
612 elif not self.maybe_run(msg_id, *args):
632 elif not self.maybe_run(job):
613 633 # resubmit failed
614 634 if msg_id not in self.all_failed:
615 635 # put it back in our dependency tree
616 self.save_unmet(msg_id, *args)
636 self.save_unmet(job)
617 637
618 638 if self.hwm:
619 639 try:
620 640 idx = self.targets.index(engine)
621 641 except ValueError:
622 642 pass # skip load-update for dead engines
623 643 else:
624 644 if self.loads[idx] == self.hwm-1:
625 645 self.update_graph(None)
626 646
627 647
628 648
629 649 def update_graph(self, dep_id=None, success=True):
630 650 """dep_id just finished. Update our dependency
631 651 graph and submit any jobs that just became runable.
632 652
633 653 Called with dep_id=None to update entire graph for hwm, but without finishing
634 654 a task.
635 655 """
636 656 # print ("\n\n***********")
637 657 # pprint (dep_id)
638 658 # pprint (self.graph)
639 659 # pprint (self.depending)
640 660 # pprint (self.all_completed)
641 661 # pprint (self.all_failed)
642 662 # print ("\n\n***********\n\n")
643 663 # update any jobs that depended on the dependency
644 664 jobs = self.graph.pop(dep_id, [])
645 665
646 666 # recheck *all* jobs if
647 667 # a) we have HWM and an engine just become no longer full
648 668 # or b) dep_id was given as None
669
649 670 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
650 671 jobs = self.depending.keys()
651 672
652 for msg_id in jobs:
653 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
673 for msg_id in sorted(jobs, key=lambda msg_id: self.depending[msg_id].timestamp):
674 job = self.depending[msg_id]
654 675
655 if after.unreachable(self.all_completed, self.all_failed)\
656 or follow.unreachable(self.all_completed, self.all_failed):
676 if job.after.unreachable(self.all_completed, self.all_failed)\
677 or job.follow.unreachable(self.all_completed, self.all_failed):
657 678 self.fail_unreachable(msg_id)
658 679
659 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
660 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
680 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
681 if self.maybe_run(job):
661 682
662 683 self.depending.pop(msg_id)
663 for mid in follow.union(after):
684 for mid in job.dependents:
664 685 if mid in self.graph:
665 686 self.graph[mid].remove(msg_id)
666 687
667 688 #----------------------------------------------------------------------
668 689 # methods to be overridden by subclasses
669 690 #----------------------------------------------------------------------
670 691
671 692 def add_job(self, idx):
672 693 """Called after self.targets[idx] just got the job with header.
673 694 Override with subclasses. The default ordering is simple LRU.
674 695 The default loads are the number of outstanding jobs."""
675 696 self.loads[idx] += 1
676 697 for lis in (self.targets, self.loads):
677 698 lis.append(lis.pop(idx))
678 699
679 700
680 701 def finish_job(self, idx):
681 702 """Called after self.targets[idx] just finished a job.
682 703 Override with subclasses."""
683 704 self.loads[idx] -= 1
684 705
685 706
686 707
def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
                        logname='root', log_url=None, loglevel=logging.DEBUG,
                        identity=b'task', in_thread=False):
    """Build a TaskScheduler wired to the given addresses and start it.

    When `in_thread` is True, reuse the parent's zmq Context/IOLoop and
    logger; otherwise create process-local ones (safe for multiprocessing)
    and block on the loop until interrupted.
    """
    ZMQStream = zmqstream.ZMQStream

    if config:
        # unwrap dict back into Config
        config = Config(config)

    if in_thread:
        # use instance() to get the same Context/Loop as our parent
        ctx = zmq.Context.instance()
        loop = ioloop.IOLoop.instance()
    else:
        # in a process, don't use instance()
        # for safety with multiprocessing
        ctx = zmq.Context()
        loop = ioloop.IOLoop()

    # client-facing ROUTER
    ins = ZMQStream(ctx.socket(zmq.ROUTER), loop)
    ins.setsockopt(zmq.IDENTITY, identity)
    ins.bind(in_addr)

    # engine-facing ROUTER
    outs = ZMQStream(ctx.socket(zmq.ROUTER), loop)
    outs.setsockopt(zmq.IDENTITY, identity)
    outs.bind(out_addr)

    # hub-facing PUB (monitor) and SUB (notifications)
    mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
    mons.connect(mon_addr)
    nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB), loop)
    nots.setsockopt(zmq.SUBSCRIBE, b'')
    nots.connect(not_addr)

    # setup logging
    if in_thread:
        log = Application.instance().log
    elif log_url:
        log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
    else:
        log = local_logger(logname, loglevel)

    scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
                            mon_stream=mons, notifier_stream=nots,
                            loop=loop, log=log,
                            config=config)
    scheduler.start()
    if not in_thread:
        try:
            loop.start()
        except KeyboardInterrupt:
            scheduler.log.critical("Interrupted, exiting...")
738 759
General Comments 0
You need to be logged in to leave comments. Login now