@@ -23,6 +23,7 @@ from __future__ import print_function
 
 import logging
 import sys
+import time
 
 from datetime import datetime, timedelta
 from random import randint, random
@@ -119,9 +120,33 @@ def leastload(loads):
 #---------------------------------------------------------------------
 # Classes
 #---------------------------------------------------------------------
+
+
 # store empty default dependency:
 MET = Dependency([])
 
+
+class Job(object):
+    """Simple container for a job"""
+    def __init__(self, msg_id, raw_msg, idents, msg, header, targets, after, follow, timeout):
+        self.msg_id = msg_id
+        self.raw_msg = raw_msg
+        self.idents = idents
+        self.msg = msg
+        self.header = header
+        self.targets = targets
+        self.after = after
+        self.follow = follow
+        self.timeout = timeout
+
+
+        self.timestamp = time.time()
+        self.blacklist = set()
+
+    @property
+    def dependents(self):
+        return self.follow.union(self.after)
+
 class TaskScheduler(SessionFactory):
     """Python TaskScheduler object.
 
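This hunk is the heart of the change: per-task state that used to travel as a positional five-element list now rides in one named container. Below is a minimal, runnable sketch of how the container behaves, not scheduler code; the ids and dependency sets are made up, and the real `after`/`follow` values are `Dependency` instances (a `set` subclass), not plain sets:

```python
import time

class Job(object):
    """Stripped-down stand-in for the Job class added above."""
    def __init__(self, msg_id, after, follow):
        self.msg_id = msg_id
        self.after = after            # msg_ids that must finish first
        self.follow = follow          # msg_ids whose engine we must follow
        self.timestamp = time.time()  # arrival time, used later for FIFO ordering
        self.blacklist = set()        # engines that bounced this job

    @property
    def dependents(self):
        # every msg_id this job waits on, in either sense
        return self.follow.union(self.after)

job = Job('a3f2', after={'dep-1'}, follow={'dep-2'})
assert job.dependents == {'dep-1', 'dep-2'}
assert not job.blacklist
```

Access by name (`job.timeout`, `job.dependents`) is what lets every later hunk drop the repeated `raw,targets,after,follow,timeout = args` unpacking.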
@@ -168,7 +193,7 @@ class TaskScheduler(SessionFactory):
     graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
     retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
     # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
-    depending = Dict() # dict by msg_id of (raw_msg, targets, after, follow, timeout)
+    depending = Dict() # dict by msg_id of Jobs
     pending = Dict() # dict by engine_uuid of submitted tasks
     completed = Dict() # dict by engine_uuid of completed tasks
     failed = Dict() # dict by engine_uuid of failed tasks
@@ -181,7 +206,7 @@ class TaskScheduler(SessionFactory):
     all_failed = Set() # set of all failed tasks
     all_done = Set() # set of all finished tasks=union(completed,failed)
     all_ids = Set() # set of all submitted task IDs
-    blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
+
     auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
 
     ident = CBytes() # ZMQ identity. This should just be self.session.session
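Dropping the class-level `blacklist` Dict works because the same information now lives on each Job (`self.blacklist = set()` in `__init__`). One consequence, visible further down the diff: the `self.blacklist.pop(msg_id, None)` cleanup in the result handler disappears, since the set is reclaimed together with its Job. A toy sketch of the bookkeeping before and after (ids are hypothetical):

```python
# Before: one scheduler-wide dict, cleaned up by hand on every exit path.
scheduler_blacklist = {}
scheduler_blacklist.setdefault('a3f2', set()).add(b'engine-0')
scheduler_blacklist.pop('a3f2', None)  # forget this and the entry leaks

# After: the set is an attribute, so its lifetime is the job's lifetime.
class Job(object):
    def __init__(self):
        self.blacklist = set()

job = Job()
job.blacklist.add(b'engine-0')
del job  # the blacklist goes with it; no separate cleanup needed
```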
@@ -380,7 +405,10 @@ class TaskScheduler(SessionFactory):
         # which timedelta does not accept
         timeout = datetime.now() + timedelta(0,float(timeout),0)
 
-        args = [raw_msg, targets, after, follow, timeout]
+        job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
+                  header=header, targets=targets, after=after, follow=follow,
+                  timeout=timeout,
+        )
 
         # validate and reduce dependencies:
         for dep in after,follow:
@@ -388,22 +416,22 @@ class TaskScheduler(SessionFactory):
                 continue
             # check valid:
             if msg_id in dep or dep.difference(self.all_ids):
-                self.depending[msg_id] = args
+                self.depending[msg_id] = job
                 return self.fail_unreachable(msg_id, error.InvalidDependency)
             # check if unreachable:
             if dep.unreachable(self.all_completed, self.all_failed):
-                self.depending[msg_id] = args
+                self.depending[msg_id] = job
                 return self.fail_unreachable(msg_id)
 
         if after.check(self.all_completed, self.all_failed):
             # time deps already met, try to run
-            if not self.maybe_run(msg_id, *args):
+            if not self.maybe_run(job):
                 # can't run yet
                 if msg_id not in self.all_failed:
                     # could have failed as unreachable
-                    self.save_unmet(msg_id, *args)
+                    self.save_unmet(job)
         else:
-            self.save_unmet(msg_id, *args)
+            self.save_unmet(job)
 
     def audit_timeouts(self):
         """Audit all waiting tasks for expired timeouts."""
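The `dep.unreachable(...)` guard is what turns a doomed task into an immediate failure instead of one that waits forever. Roughly: a dependency set becomes unreachable once some member has already finished the "wrong" way. Here is a simplified, runnable sketch of that idea; it is not the library's `Dependency` implementation, which also handles all-vs-any modes:

```python
def unreachable(dep, completed, failed, success=True, failure=False):
    """Simplified all-mode test: the dependency is impossible if any
    member already finished in a way we do not accept."""
    bad = set()
    if not failure:
        bad |= dep & failed      # it failed, but we required success
    if not success:
        bad |= dep & completed   # it succeeded, but we required failure
    return bool(bad)

# 'b' already failed, and only successful completion satisfies us:
assert unreachable({'a', 'b'}, completed={'a'}, failed={'b'})
# both finished successfully: still satisfiable
assert not unreachable({'a', 'b'}, completed={'a', 'b'}, failed=set())
```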
@@ -411,8 +439,8 @@ class TaskScheduler(SessionFactory):
         for msg_id in self.depending.keys():
             # must recheck, in case one failure cascaded to another:
             if msg_id in self.depending:
-                raw_msg, targets, after, follow, timeout = self.depending[msg_id]
-                if timeout and timeout < now:
+                job = self.depending[msg_id]
+                if job.timeout and job.timeout < now:
                     self.fail_unreachable(msg_id, error.TaskTimeout)
 
     def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
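With a Job object the periodic audit touches exactly one field instead of unpacking five. A runnable sketch of the expiry test with a made-up deadline; as the earlier hunk shows, the real `timeout` is an absolute `datetime` computed at submission:

```python
from datetime import datetime, timedelta

class Job(object):
    def __init__(self, timeout):
        self.timeout = timeout  # absolute datetime deadline, or None

now = datetime.now()
expired = Job(now - timedelta(seconds=1))   # deadline already passed
patient = Job(None)                         # no timeout requested

for job in (expired, patient):
    if job.timeout and job.timeout < now:
        print("expired -> fail_unreachable(msg_id, error.TaskTimeout)")
```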
@@ -421,15 +449,11 @@ class TaskScheduler(SessionFactory):
         if msg_id not in self.depending:
             self.log.error("msg %r already failed!", msg_id)
             return
-        raw_msg, targets, after, follow, timeout = self.depending.pop(msg_id)
-        for mid in follow.union(after):
+        job = self.depending.pop(msg_id)
+        for mid in job.dependents:
             if mid in self.graph:
                 self.graph[mid].remove(msg_id)
 
-        # FIXME: unpacking a message I've already unpacked, but didn't save:
-        idents,msg = self.session.feed_identities(raw_msg, copy=False)
-        header = self.session.unpack(msg[1].bytes)
-
         try:
             raise why()
         except:
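The deleted FIXME is the clearest payoff: the failure path used to rerun `feed_identities()` and `unpack()` on `raw_msg` just to recover the header and routing idents that `dispatch_submission` had already parsed once. The Job keeps those parsed pieces, so the error reply can be built directly. A stand-in sketch of the parse-once pattern; `parse` here fakes the two session calls:

```python
class Job(object):
    def __init__(self, raw_msg, idents, header):
        self.raw_msg = raw_msg
        self.idents = idents   # ZMQ routing prefix, parsed at submission
        self.header = header   # message header, unpacked at submission

def parse(raw_msg):
    # stand-in for session.feed_identities() + session.unpack()
    idents, header = raw_msg
    return idents, header

raw = ([b'client-1'], {'msg_id': 'a3f2'})
idents, header = parse(raw)          # done once, in dispatch_submission
job = Job(raw, idents, header)

# later, in fail_unreachable: attribute access, no second parse
assert job.header['msg_id'] == 'a3f2'
assert job.idents == [b'client-1']
```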
@@ -439,20 +463,20 @@ class TaskScheduler(SessionFactory):
         self.all_failed.add(msg_id)
 
         msg = self.session.send(self.client_stream, 'apply_reply', content,
-                                parent=header, ident=idents)
-        self.session.send(self.mon_stream, msg, ident=[b'outtask']+idents)
+                                parent=job.header, ident=job.idents)
+        self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
 
         self.update_graph(msg_id, success=False)
 
-    def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
+    def maybe_run(self, job):
         """check location dependencies, and run if they are met."""
+        msg_id = job.msg_id
         self.log.debug("Attempting to assign task %s", msg_id)
         if not self.targets:
             # no engines, definitely can't run
             return False
 
-        blacklist = self.blacklist.setdefault(msg_id, set())
-        if follow or targets or blacklist or self.hwm:
+        if job.follow or job.targets or job.blacklist or self.hwm:
             # we need a can_run filter
             def can_run(idx):
                 # check hwm
@@ -460,56 +484,57 @@ class TaskScheduler(SessionFactory):
                     return False
                 target = self.targets[idx]
                 # check blacklist
-                if target in blacklist:
+                if target in job.blacklist:
                     return False
                 # check targets
-                if targets and target not in targets:
+                if job.targets and target not in job.targets:
                     return False
                 # check follow
-                return follow.check(self.completed[target], self.failed[target])
+                return job.follow.check(self.completed[target], self.failed[target])
 
             indices = filter(can_run, range(len(self.targets)))
 
             if not indices:
                 # couldn't run
-                if follow.all:
+                if job.follow.all:
                     # check follow for impossibility
                     dests = set()
                     relevant = set()
-                    if follow.success:
+                    if job.follow.success:
                         relevant = self.all_completed
-                    if follow.failure:
+                    if job.follow.failure:
                         relevant = relevant.union(self.all_failed)
-                    for m in follow.intersection(relevant):
+                    for m in job.follow.intersection(relevant):
                         dests.add(self.destinations[m])
                     if len(dests) > 1:
-                        self.depending[msg_id] = [raw_msg, targets, after, follow, timeout]
+                        self.depending[msg_id] = job
                         self.fail_unreachable(msg_id)
                         return False
-                if targets:
+                if job.targets:
                     # check blacklist+targets for impossibility
-                    targets.difference_update(blacklist)
-                    if not targets or not targets.intersection(self.targets):
-                        self.depending[msg_id] = [raw_msg, targets, after, follow, timeout]
+                    job.targets.difference_update(job.blacklist)
+                    if not job.targets or not job.targets.intersection(self.targets):
+                        self.depending[msg_id] = job
                         self.fail_unreachable(msg_id)
                         return False
                 return False
         else:
             indices = None
 
-        self.submit_task(msg_id, raw_msg, targets, after, follow, timeout, indices)
+        self.submit_task(job, indices)
         return True
 
-    def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
+    def save_unmet(self, job):
         """Save a message for later submission when its dependencies are met."""
-        self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
+        msg_id = job.msg_id
+        self.depending[msg_id] = job
         # track the ids in follow or after, but not those already finished
-        for dep_id in after.union(follow).difference(self.all_done):
+        for dep_id in job.after.union(job.follow).difference(self.all_done):
             if dep_id not in self.graph:
                 self.graph[dep_id] = set()
             self.graph[dep_id].add(msg_id)
 
-    def submit_task(self, msg_id, raw_msg, targets, after, follow, timeout, indices=None):
+    def submit_task(self, job, indices=None):
         """Submit a task to any of a subset of our targets."""
         if indices:
             loads = [self.loads[i] for i in indices]
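Same filter, new spelling: every piece of per-task state that `can_run` consults (`blacklist`, `targets`, `follow`) now hangs off `job`. A simplified, runnable version of the eligibility test; engine names and loads are invented, and the real follow check calls `Dependency.check` against per-engine completed/failed sets:

```python
loads = [2, 0, 1]                  # outstanding tasks per engine
hwm = 2                            # high-water mark
targets = [b'e0', b'e1', b'e2']
job_blacklist = {b'e1'}
job_targets = set()                # empty means "any engine"

def can_run(idx):
    if hwm and loads[idx] == hwm:
        return False               # engine already full
    target = targets[idx]
    if target in job_blacklist:
        return False               # engine previously raised UnmetDependency
    if job_targets and target not in job_targets:
        return False               # job is restricted to other engines
    return True                    # (real code also checks follow deps here)

assert list(filter(can_run, range(len(targets)))) == [2]
```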
@@ -522,12 +547,12 @@ class TaskScheduler(SessionFactory):
         # print (target, map(str, msg[:3]))
         # send job to the engine
         self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
-        self.engine_stream.send_multipart(raw_msg, copy=False)
+        self.engine_stream.send_multipart(job.raw_msg, copy=False)
         # update load
         self.add_job(idx)
-        self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
+        self.pending[target][job.msg_id] = job
         # notify Hub
-        content = dict(msg_id=msg_id, engine_id=target.decode('ascii'))
+        content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
         self.session.send(self.mon_stream, 'task_destination', content=content,
                         ident=[b'tracktask',self.ident])
 
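`submit_task` narrows `self.loads` to the eligible indices from `can_run`, applies the scheduling scheme, then maps the winner back to a real engine index. A small sketch of that subset-then-pick step with invented loads; `leastload` here is a minimal reimplementation of the scheme named in the hunk headers, for illustration only:

```python
def leastload(loads):
    """Index of the engine with the fewest outstanding tasks."""
    return loads.index(min(loads))

all_loads = [3, 1, 2]
indices = [0, 2]          # eligible engines; engine 1 was filtered out

subset = [all_loads[i] for i in indices]
idx = indices[leastload(subset)]   # map subset position back to engine index
assert idx == 2                    # engine 2: lightest among the eligible
```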
@@ -581,7 +606,6 @@ class TaskScheduler(SessionFactory):
         self.client_stream.send_multipart(raw_msg, copy=False)
         # now, update our data structures
         msg_id = parent['msg_id']
-        self.blacklist.pop(msg_id, None)
         self.pending[engine].pop(msg_id)
         if success:
             self.completed[engine].add(msg_id)
@@ -599,21 +623,17 @@ class TaskScheduler(SessionFactory):
         engine = idents[0]
         msg_id = parent['msg_id']
 
-        if msg_id not in self.blacklist:
-            self.blacklist[msg_id] = set()
-        self.blacklist[msg_id].add(engine)
-
-        args = self.pending[engine].pop(msg_id)
-        raw,targets,after,follow,timeout = args
+        job = self.pending[engine].pop(msg_id)
+        job.blacklist.add(engine)
 
-        if self.blacklist[msg_id] == targets:
-            self.depending[msg_id] = args
+        if job.blacklist == job.targets:
+            self.depending[msg_id] = job
             self.fail_unreachable(msg_id)
-        elif not self.maybe_run(msg_id, *args):
+        elif not self.maybe_run(job):
             # resubmit failed
             if msg_id not in self.all_failed:
                 # put it back in our dependency tree
-                self.save_unmet(msg_id, *args)
+                self.save_unmet(job)
 
         if self.hwm:
             try:
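Resubmission now reads as one line per step: pop the Job back out of `pending`, blacklist the engine that bounced it, and fail fast once every allowed target has bounced it. A toy sketch of that exhaustion test; engine ids are made up, and since empty `targets` means "anywhere", the equality test only fires for constrained jobs:

```python
class Job(object):
    def __init__(self, targets):
        self.targets = set(targets)  # engines this job may run on
        self.blacklist = set()       # engines that raised UnmetDependency

job = Job(targets=[b'e0', b'e1'])
job.blacklist.add(b'e0')
assert job.blacklist != job.targets  # e1 still available: retry

job.blacklist.add(b'e1')
assert job.blacklist == job.targets  # nowhere left: fail_unreachable
```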
@@ -646,21 +666,22 @@ class TaskScheduler(SessionFactory):
         # recheck *all* jobs if
         # a) we have HWM and an engine just become no longer full
         # or b) dep_id was given as None
+
         if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
             jobs = self.depending.keys()
+
+        for msg_id in sorted(jobs, key=lambda msg_id: self.depending[msg_id].timestamp):
+            job = self.depending[msg_id]
 
-        for msg_id in jobs:
-            raw_msg, targets, after, follow, timeout = self.depending[msg_id]
-
-            if after.unreachable(self.all_completed, self.all_failed)\
-                    or follow.unreachable(self.all_completed, self.all_failed):
+            if job.after.unreachable(self.all_completed, self.all_failed)\
+                    or job.follow.unreachable(self.all_completed, self.all_failed):
                 self.fail_unreachable(msg_id)
 
-            elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
-                if self.maybe_run(msg_id, raw_msg, targets, after, follow, timeout):
+            elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
+                if self.maybe_run(job):
 
                     self.depending.pop(msg_id)
-                    for mid in follow.union(after):
+                    for mid in job.dependents:
                         if mid in self.graph:
                             self.graph[mid].remove(msg_id)
 
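Besides the container itself, this last hunk carries the one behavioral change: when waiting jobs are rechecked (for instance after an engine drops below the high-water mark), they are visited oldest-first via the `timestamp` recorded in `Job.__init__`, rather than in arbitrary dict order. A runnable sketch, with short sleeps standing in for real arrival gaps:

```python
import time

class Job(object):
    def __init__(self, msg_id):
        self.msg_id = msg_id
        self.timestamp = time.time()  # recorded once, at submission

depending = {}
for msg_id in ('z', 'a', 'm'):   # arrival order != alphabetical order
    depending[msg_id] = Job(msg_id)
    time.sleep(0.01)

# recheck the oldest submission first, regardless of key order
ordered = sorted(depending, key=lambda m: depending[m].timestamp)
assert ordered == ['z', 'a', 'm']  # FIFO by arrival, not by msg_id
```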