##// END OF EJS Templates
Simplify structure of a Job in the TaskScheduler...
MinRK -
Show More
@@ -23,6 +23,7 b' from __future__ import print_function'
23
23
24 import logging
24 import logging
25 import sys
25 import sys
26 import time
26
27
27 from datetime import datetime, timedelta
28 from datetime import datetime, timedelta
28 from random import randint, random
29 from random import randint, random
@@ -119,9 +120,33 b' def leastload(loads):'
119 #---------------------------------------------------------------------
120 #---------------------------------------------------------------------
120 # Classes
121 # Classes
121 #---------------------------------------------------------------------
122 #---------------------------------------------------------------------
123
124
122 # store empty default dependency:
125 # store empty default dependency:
123 MET = Dependency([])
126 MET = Dependency([])
124
127
128
class Job(object):
    """Simple container for a job.

    Bundles everything the scheduler keeps about one submitted task so it
    can be stored and passed around as a single object instead of a
    positional tuple.

    Attributes
    ----------
    msg_id :
        Identifier of the task; used as the key in the scheduler's dicts.
    raw_msg :
        The original multipart message, forwarded unchanged to an engine
        when the task is submitted.
    idents :
        Routing identities of the request, reused when replying.
    msg :
        The message with the identities split off.
    header :
        Unpacked message header, used as the ``parent`` of replies.
    targets :
        Engines the task is restricted to; an empty/falsy value means no
        restriction.
    after :
        Dependency that must be satisfied before the task may run at all.
    follow :
        Dependency that constrains which engine the task may run on.
    timeout :
        Deadline compared against the current time by the scheduler, or a
        falsy value for no deadline.
    timestamp :
        ``time.time()`` at construction; lets waiting jobs be ordered by
        age.
    blacklist :
        Set of engines to skip for this job.  Each job gets its own set.
    """

    def __init__(self, msg_id, raw_msg, idents, msg, header, targets, after,
                 follow, timeout):
        self.msg_id = msg_id
        self.raw_msg = raw_msg
        self.idents = idents
        self.msg = msg
        self.header = header
        self.targets = targets
        self.after = after
        self.follow = follow
        self.timeout = timeout

        # Bookkeeping owned by the job itself rather than supplied by the
        # caller: creation time and a fresh, per-instance blacklist.
        self.timestamp = time.time()
        self.blacklist = set()

    def __repr__(self):
        # Only msg_id is shown: raw_msg may hold large message frames.
        return "<%s msg_id=%r>" % (type(self).__name__, self.msg_id)

    @property
    def dependents(self):
        """Union of the ids named in ``follow`` and in ``after``."""
        return self.follow.union(self.after)
149
125 class TaskScheduler(SessionFactory):
150 class TaskScheduler(SessionFactory):
126 """Python TaskScheduler object.
151 """Python TaskScheduler object.
127
152
@@ -168,7 +193,7 b' class TaskScheduler(SessionFactory):'
168 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
193 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
169 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
194 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
170 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
195 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
171 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
196 depending = Dict() # dict by msg_id of Jobs
172 pending = Dict() # dict by engine_uuid of submitted tasks
197 pending = Dict() # dict by engine_uuid of submitted tasks
173 completed = Dict() # dict by engine_uuid of completed tasks
198 completed = Dict() # dict by engine_uuid of completed tasks
174 failed = Dict() # dict by engine_uuid of failed tasks
199 failed = Dict() # dict by engine_uuid of failed tasks
@@ -181,7 +206,7 b' class TaskScheduler(SessionFactory):'
181 all_failed = Set() # set of all failed tasks
206 all_failed = Set() # set of all failed tasks
182 all_done = Set() # set of all finished tasks=union(completed,failed)
207 all_done = Set() # set of all finished tasks=union(completed,failed)
183 all_ids = Set() # set of all submitted task IDs
208 all_ids = Set() # set of all submitted task IDs
184 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
209
185 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
210 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
186
211
187 ident = CBytes() # ZMQ identity. This should just be self.session.session
212 ident = CBytes() # ZMQ identity. This should just be self.session.session
@@ -380,7 +405,10 b' class TaskScheduler(SessionFactory):'
380 # which timedelta does not accept
405 # which timedelta does not accept
381 timeout = datetime.now() + timedelta(0,float(timeout),0)
406 timeout = datetime.now() + timedelta(0,float(timeout),0)
382
407
383 args = [raw_msg, targets, after, follow, timeout]
408 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
409 header=header, targets=targets, after=after, follow=follow,
410 timeout=timeout,
411 )
384
412
385 # validate and reduce dependencies:
413 # validate and reduce dependencies:
386 for dep in after,follow:
414 for dep in after,follow:
@@ -388,22 +416,22 b' class TaskScheduler(SessionFactory):'
388 continue
416 continue
389 # check valid:
417 # check valid:
390 if msg_id in dep or dep.difference(self.all_ids):
418 if msg_id in dep or dep.difference(self.all_ids):
391 self.depending[msg_id] = args
419 self.depending[msg_id] = job
392 return self.fail_unreachable(msg_id, error.InvalidDependency)
420 return self.fail_unreachable(msg_id, error.InvalidDependency)
393 # check if unreachable:
421 # check if unreachable:
394 if dep.unreachable(self.all_completed, self.all_failed):
422 if dep.unreachable(self.all_completed, self.all_failed):
395 self.depending[msg_id] = args
423 self.depending[msg_id] = job
396 return self.fail_unreachable(msg_id)
424 return self.fail_unreachable(msg_id)
397
425
398 if after.check(self.all_completed, self.all_failed):
426 if after.check(self.all_completed, self.all_failed):
399 # time deps already met, try to run
427 # time deps already met, try to run
400 if not self.maybe_run(msg_id, *args):
428 if not self.maybe_run(job):
401 # can't run yet
429 # can't run yet
402 if msg_id not in self.all_failed:
430 if msg_id not in self.all_failed:
403 # could have failed as unreachable
431 # could have failed as unreachable
404 self.save_unmet(msg_id, *args)
432 self.save_unmet(job)
405 else:
433 else:
406 self.save_unmet(msg_id, *args)
434 self.save_unmet(job)
407
435
408 def audit_timeouts(self):
436 def audit_timeouts(self):
409 """Audit all waiting tasks for expired timeouts."""
437 """Audit all waiting tasks for expired timeouts."""
@@ -411,8 +439,8 b' class TaskScheduler(SessionFactory):'
411 for msg_id in self.depending.keys():
439 for msg_id in self.depending.keys():
412 # must recheck, in case one failure cascaded to another:
440 # must recheck, in case one failure cascaded to another:
413 if msg_id in self.depending:
441 if msg_id in self.depending:
414 raw,after,targets,follow,timeout = self.depending[msg_id]
442 job = self.depending[msg_id]
415 if timeout and timeout < now:
443 if job.timeout and job.timeout < now:
416 self.fail_unreachable(msg_id, error.TaskTimeout)
444 self.fail_unreachable(msg_id, error.TaskTimeout)
417
445
418 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
446 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
@@ -421,15 +449,11 b' class TaskScheduler(SessionFactory):'
421 if msg_id not in self.depending:
449 if msg_id not in self.depending:
422 self.log.error("msg %r already failed!", msg_id)
450 self.log.error("msg %r already failed!", msg_id)
423 return
451 return
424 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
452 job = self.depending.pop(msg_id)
425 for mid in follow.union(after):
453 for mid in job.dependents:
426 if mid in self.graph:
454 if mid in self.graph:
427 self.graph[mid].remove(msg_id)
455 self.graph[mid].remove(msg_id)
428
456
429 # FIXME: unpacking a message I've already unpacked, but didn't save:
430 idents,msg = self.session.feed_identities(raw_msg, copy=False)
431 header = self.session.unpack(msg[1].bytes)
432
433 try:
457 try:
434 raise why()
458 raise why()
435 except:
459 except:
@@ -439,20 +463,20 b' class TaskScheduler(SessionFactory):'
439 self.all_failed.add(msg_id)
463 self.all_failed.add(msg_id)
440
464
441 msg = self.session.send(self.client_stream, 'apply_reply', content,
465 msg = self.session.send(self.client_stream, 'apply_reply', content,
442 parent=header, ident=idents)
466 parent=job.header, ident=job.idents)
443 self.session.send(self.mon_stream, msg, ident=[b'outtask']+idents)
467 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
444
468
445 self.update_graph(msg_id, success=False)
469 self.update_graph(msg_id, success=False)
446
470
447 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
471 def maybe_run(self, job):
448 """check location dependencies, and run if they are met."""
472 """check location dependencies, and run if they are met."""
473 msg_id = job.msg_id
449 self.log.debug("Attempting to assign task %s", msg_id)
474 self.log.debug("Attempting to assign task %s", msg_id)
450 if not self.targets:
475 if not self.targets:
451 # no engines, definitely can't run
476 # no engines, definitely can't run
452 return False
477 return False
453
478
454 blacklist = self.blacklist.setdefault(msg_id, set())
479 if job.follow or job.targets or job.blacklist or self.hwm:
455 if follow or targets or blacklist or self.hwm:
456 # we need a can_run filter
480 # we need a can_run filter
457 def can_run(idx):
481 def can_run(idx):
458 # check hwm
482 # check hwm
@@ -460,56 +484,57 b' class TaskScheduler(SessionFactory):'
460 return False
484 return False
461 target = self.targets[idx]
485 target = self.targets[idx]
462 # check blacklist
486 # check blacklist
463 if target in blacklist:
487 if target in job.blacklist:
464 return False
488 return False
465 # check targets
489 # check targets
466 if targets and target not in targets:
490 if job.targets and target not in job.targets:
467 return False
491 return False
468 # check follow
492 # check follow
469 return follow.check(self.completed[target], self.failed[target])
493 return job.follow.check(self.completed[target], self.failed[target])
470
494
471 indices = filter(can_run, range(len(self.targets)))
495 indices = filter(can_run, range(len(self.targets)))
472
496
473 if not indices:
497 if not indices:
474 # couldn't run
498 # couldn't run
475 if follow.all:
499 if job.follow.all:
476 # check follow for impossibility
500 # check follow for impossibility
477 dests = set()
501 dests = set()
478 relevant = set()
502 relevant = set()
479 if follow.success:
503 if job.follow.success:
480 relevant = self.all_completed
504 relevant = self.all_completed
481 if follow.failure:
505 if job.follow.failure:
482 relevant = relevant.union(self.all_failed)
506 relevant = relevant.union(self.all_failed)
483 for m in follow.intersection(relevant):
507 for m in job.follow.intersection(relevant):
484 dests.add(self.destinations[m])
508 dests.add(self.destinations[m])
485 if len(dests) > 1:
509 if len(dests) > 1:
486 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
510 self.depending[msg_id] = job
487 self.fail_unreachable(msg_id)
511 self.fail_unreachable(msg_id)
488 return False
512 return False
489 if targets:
513 if job.targets:
490 # check blacklist+targets for impossibility
514 # check blacklist+targets for impossibility
491 targets.difference_update(blacklist)
515 job.targets.difference_update(blacklist)
492 if not targets or not targets.intersection(self.targets):
516 if not job.targets or not job.targets.intersection(self.targets):
493 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
517 self.depending[msg_id] = job
494 self.fail_unreachable(msg_id)
518 self.fail_unreachable(msg_id)
495 return False
519 return False
496 return False
520 return False
497 else:
521 else:
498 indices = None
522 indices = None
499
523
500 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
524 self.submit_task(job, indices)
501 return True
525 return True
502
526
503 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
527 def save_unmet(self, job):
504 """Save a message for later submission when its dependencies are met."""
528 """Save a message for later submission when its dependencies are met."""
505 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
529 msg_id = job.msg_id
530 self.depending[msg_id] = job
506 # track the ids in follow or after, but not those already finished
531 # track the ids in follow or after, but not those already finished
507 for dep_id in after.union(follow).difference(self.all_done):
532 for dep_id in job.after.union(job.follow).difference(self.all_done):
508 if dep_id not in self.graph:
533 if dep_id not in self.graph:
509 self.graph[dep_id] = set()
534 self.graph[dep_id] = set()
510 self.graph[dep_id].add(msg_id)
535 self.graph[dep_id].add(msg_id)
511
536
512 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
537 def submit_task(self, job, indices=None):
513 """Submit a task to any of a subset of our targets."""
538 """Submit a task to any of a subset of our targets."""
514 if indices:
539 if indices:
515 loads = [self.loads[i] for i in indices]
540 loads = [self.loads[i] for i in indices]
@@ -522,12 +547,12 b' class TaskScheduler(SessionFactory):'
522 # print (target, map(str, msg[:3]))
547 # print (target, map(str, msg[:3]))
523 # send job to the engine
548 # send job to the engine
524 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
549 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
525 self.engine_stream.send_multipart(raw_msg, copy=False)
550 self.engine_stream.send_multipart(job.raw_msg, copy=False)
526 # update load
551 # update load
527 self.add_job(idx)
552 self.add_job(idx)
528 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
553 self.pending[target][job.msg_id] = job
529 # notify Hub
554 # notify Hub
530 content = dict(msg_id=msg_id, engine_id=target.decode('ascii'))
555 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
531 self.session.send(self.mon_stream, 'task_destination', content=content,
556 self.session.send(self.mon_stream, 'task_destination', content=content,
532 ident=[b'tracktask',self.ident])
557 ident=[b'tracktask',self.ident])
533
558
@@ -581,7 +606,6 b' class TaskScheduler(SessionFactory):'
581 self.client_stream.send_multipart(raw_msg, copy=False)
606 self.client_stream.send_multipart(raw_msg, copy=False)
582 # now, update our data structures
607 # now, update our data structures
583 msg_id = parent['msg_id']
608 msg_id = parent['msg_id']
584 self.blacklist.pop(msg_id, None)
585 self.pending[engine].pop(msg_id)
609 self.pending[engine].pop(msg_id)
586 if success:
610 if success:
587 self.completed[engine].add(msg_id)
611 self.completed[engine].add(msg_id)
@@ -599,21 +623,17 b' class TaskScheduler(SessionFactory):'
599 engine = idents[0]
623 engine = idents[0]
600 msg_id = parent['msg_id']
624 msg_id = parent['msg_id']
601
625
602 if msg_id not in self.blacklist:
626 job = self.pending[engine].pop(msg_id)
603 self.blacklist[msg_id] = set()
627 job.blacklist.add(engine)
604 self.blacklist[msg_id].add(engine)
605
606 args = self.pending[engine].pop(msg_id)
607 raw,targets,after,follow,timeout = args
608
628
609 if self.blacklist[msg_id] == targets:
629 if job.blacklist == job.targets:
610 self.depending[msg_id] = args
630 self.depending[msg_id] = job
611 self.fail_unreachable(msg_id)
631 self.fail_unreachable(msg_id)
612 elif not self.maybe_run(msg_id, *args):
632 elif not self.maybe_run(job):
613 # resubmit failed
633 # resubmit failed
614 if msg_id not in self.all_failed:
634 if msg_id not in self.all_failed:
615 # put it back in our dependency tree
635 # put it back in our dependency tree
616 self.save_unmet(msg_id, *args)
636 self.save_unmet(job)
617
637
618 if self.hwm:
638 if self.hwm:
619 try:
639 try:
@@ -646,21 +666,22 b' class TaskScheduler(SessionFactory):'
646 # recheck *all* jobs if
666 # recheck *all* jobs if
647 # a) we have HWM and an engine has just become no longer full
667 # a) we have HWM and an engine has just become no longer full
648 # or b) dep_id was given as None
668 # or b) dep_id was given as None
669
649 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
670 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
650 jobs = self.depending.keys()
671 jobs = self.depending.keys()
672
673 for msg_id in sorted(jobs, key=lambda msg_id: self.depending[msg_id].timestamp):
674 job = self.depending[msg_id]
651
675
652 for msg_id in jobs:
676 if job.after.unreachable(self.all_completed, self.all_failed)\
653 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
677 or job.follow.unreachable(self.all_completed, self.all_failed):
654
655 if after.unreachable(self.all_completed, self.all_failed)\
656 or follow.unreachable(self.all_completed, self.all_failed):
657 self.fail_unreachable(msg_id)
678 self.fail_unreachable(msg_id)
658
679
659 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
680 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
660 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
681 if self.maybe_run(job):
661
682
662 self.depending.pop(msg_id)
683 self.depending.pop(msg_id)
663 for mid in follow.union(after):
684 for mid in job.dependents:
664 if mid in self.graph:
685 if mid in self.graph:
665 self.graph[mid].remove(msg_id)
686 self.graph[mid].remove(msg_id)
666
687
General Comments 0
You need to be logged in to leave comments. Login now