adjust Scheduler timeout logic...
MinRK
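
In short: instead of converting `timeout` to an absolute deadline and scheduling the callback in `dispatch_submission()`, the timeout now stays a relative number of seconds and the callback is scheduled in `save_unmet()`, i.e. only when a job actually has to wait in the queue. Each `Job` carries a `timeout_id` counter that is bumped every time a callback is scheduled, and `job_timeout()` ignores any callback whose `timeout_id` no longer matches, so a stale timer from an earlier queuing of the same job cannot fire it.

A minimal, self-contained sketch of that guard pattern (the `ToyScheduler` class and its print-based failure path are illustrative stand-ins, not the scheduler's API; the real code calls `self.fail_unreachable(msg_id, error.TaskTimeout)` and keeps many more fields on `Job`):

    import time
    from zmq.eventloop import ioloop

    class Job(object):
        """Reduced Job: only the fields the timeout logic touches."""
        def __init__(self, msg_id, timeout):
            self.msg_id = msg_id
            self.timeout = timeout    # relative seconds, not an absolute deadline
            self.timeout_id = 0       # bumped every time a timeout is (re)scheduled

    class ToyScheduler(object):
        def __init__(self, loop):
            self.loop = loop
            self.queue_map = {}       # msg_id -> Job, jobs still waiting to run

        def save_unmet(self, job):
            """Queue a job whose dependencies aren't met yet (timeout part only)."""
            self.queue_map[job.msg_id] = job
            if job.timeout:
                # bump the counter so any previously scheduled callback goes stale
                timeout_id = job.timeout_id = job.timeout_id + 1
                self.loop.add_timeout(time.time() + job.timeout,
                                      lambda: self.job_timeout(job, timeout_id))

        def job_timeout(self, job, timeout_id):
            """Fail the job only if it is still waiting when its timer fires."""
            if job.timeout_id != timeout_id:
                return  # an older scheduling of this job; ignore the stale callback
            if job.msg_id in self.queue_map:
                # the real scheduler calls fail_unreachable(msg_id, error.TaskTimeout)
                print("task %r timed out" % self.queue_map.pop(job.msg_id).msg_id)

    if __name__ == '__main__':
        loop = ioloop.IOLoop.instance()
        sched = ToyScheduler(loop)
        sched.save_unmet(Job('abc', timeout=0.5))   # never dispatched, so it waits
        loop.add_timeout(time.time() + 1, loop.stop)
        loop.start()

Because the timer is only armed when a job enters the wait queue (and re-armed each time it goes back), a job that is dispatched immediately never pays for a timeout callback, and a re-queued job gets a fresh relative deadline instead of inheriting the original absolute one.
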
@@ -1,852 +1,860 @@
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6
7 7 Authors:
8 8
9 9 * Min RK
10 10 """
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2010-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #----------------------------------------------------------------------
19 19 # Imports
20 20 #----------------------------------------------------------------------
21 21
22 22 import logging
23 23 import sys
24 24 import time
25 25
26 26 from collections import deque
27 27 from datetime import datetime
28 28 from random import randint, random
29 29 from types import FunctionType
30 30
31 31 try:
32 32 import numpy
33 33 except ImportError:
34 34 numpy = None
35 35
36 36 import zmq
37 37 from zmq.eventloop import ioloop, zmqstream
38 38
39 39 # local imports
40 40 from IPython.external.decorator import decorator
41 41 from IPython.config.application import Application
42 42 from IPython.config.loader import Config
43 43 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
44 44 from IPython.utils.py3compat import cast_bytes
45 45
46 46 from IPython.parallel import error, util
47 47 from IPython.parallel.factory import SessionFactory
48 48 from IPython.parallel.util import connect_logger, local_logger
49 49
50 50 from .dependency import Dependency
51 51
52 52 @decorator
53 53 def logged(f,self,*args,**kwargs):
54 54 # print ("#--------------------")
55 55 self.log.debug("scheduler::%s(*%s,**%s)", f.func_name, args, kwargs)
56 56 # print ("#--")
57 57 return f(self,*args, **kwargs)
58 58
59 59 #----------------------------------------------------------------------
60 60 # Chooser functions
61 61 #----------------------------------------------------------------------
62 62
63 63 def plainrandom(loads):
64 64 """Plain random pick."""
65 65 n = len(loads)
66 66 return randint(0,n-1)
67 67
68 68 def lru(loads):
69 69 """Always pick the front of the line.
70 70
71 71 The content of `loads` is ignored.
72 72
73 73 Assumes LRU ordering of loads, with oldest first.
74 74 """
75 75 return 0
76 76
77 77 def twobin(loads):
78 78 """Pick two at random, use the LRU of the two.
79 79
80 80 The content of loads is ignored.
81 81
82 82 Assumes LRU ordering of loads, with oldest first.
83 83 """
84 84 n = len(loads)
85 85 a = randint(0,n-1)
86 86 b = randint(0,n-1)
87 87 return min(a,b)
88 88
89 89 def weighted(loads):
90 90 """Pick two at random using inverse load as weight.
91 91
92 92 Return the less loaded of the two.
93 93 """
94 94 # weight 0 a million times more than 1:
95 95 weights = 1./(1e-6+numpy.array(loads))
96 96 sums = weights.cumsum()
97 97 t = sums[-1]
98 98 x = random()*t
99 99 y = random()*t
100 100 idx = 0
101 101 idy = 0
102 102 while sums[idx] < x:
103 103 idx += 1
104 104 while sums[idy] < y:
105 105 idy += 1
106 106 if weights[idy] > weights[idx]:
107 107 return idy
108 108 else:
109 109 return idx
110 110
111 111 def leastload(loads):
112 112 """Always choose the lowest load.
113 113
114 114 If the lowest load occurs more than once, the first
115 115 occurrence will be used. If loads has LRU ordering, this means
116 116 the LRU of those with the lowest load is chosen.
117 117 """
118 118 return loads.index(min(loads))
119 119
120 120 #---------------------------------------------------------------------
121 121 # Classes
122 122 #---------------------------------------------------------------------
123 123
124 124
125 125 # store empty default dependency:
126 126 MET = Dependency([])
127 127
128 128
129 129 class Job(object):
130 130 """Simple container for a job"""
131 131 def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
132 132 targets, after, follow, timeout):
133 133 self.msg_id = msg_id
134 134 self.raw_msg = raw_msg
135 135 self.idents = idents
136 136 self.msg = msg
137 137 self.header = header
138 138 self.metadata = metadata
139 139 self.targets = targets
140 140 self.after = after
141 141 self.follow = follow
142 142 self.timeout = timeout
143 self.removed = False # used for lazy-delete from sorted queue
144 143
144 self.removed = False # used for lazy-delete from sorted queue
145 145 self.timestamp = time.time()
146 self.timeout_id = 0
146 147 self.blacklist = set()
147 148
148 149 def __lt__(self, other):
149 150 return self.timestamp < other.timestamp
150 151
151 152 def __cmp__(self, other):
152 153 return cmp(self.timestamp, other.timestamp)
153 154
154 155 @property
155 156 def dependents(self):
156 157 return self.follow.union(self.after)
157 158
159
158 160 class TaskScheduler(SessionFactory):
159 161 """Python TaskScheduler object.
160 162
161 163 This is the simplest object that supports msg_id based
162 164 DAG dependencies. *Only* task msg_ids are checked, not
163 165 msg_ids of jobs submitted via the MUX queue.
164 166
165 167 """
166 168
167 169 hwm = Integer(1, config=True,
168 170 help="""specify the High Water Mark (HWM) for the downstream
169 171 socket in the Task scheduler. This is the maximum number
170 172 of allowed outstanding tasks on each engine.
171 173
172 174 The default (1) means that only one task can be outstanding on each
173 175 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
174 176 engines continue to be assigned tasks while they are working,
175 177 effectively hiding network latency behind computation, but can result
176 178 in an imbalance of work when submitting many heterogeneous tasks all at
177 179 once. Any positive value greater than one is a compromise between the
178 180 two.
179 181
180 182 """
181 183 )
182 184 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
183 185 'leastload', config=True, allow_none=False,
184 186 help="""select the task scheduler scheme [default: Python LRU]
185 187 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
186 188 )
187 189 def _scheme_name_changed(self, old, new):
188 190 self.log.debug("Using scheme %r"%new)
189 191 self.scheme = globals()[new]
190 192
191 193 # input arguments:
192 194 scheme = Instance(FunctionType) # function for determining the destination
193 195 def _scheme_default(self):
194 196 return leastload
195 197 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
196 198 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
197 199 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
198 200 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
199 201 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
200 202
201 203 # internals:
202 204 queue = Instance(deque) # sorted list of Jobs
203 205 def _queue_default(self):
204 206 return deque()
205 207 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
206 208 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
207 209 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
208 210 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
209 211 pending = Dict() # dict by engine_uuid of submitted tasks
210 212 completed = Dict() # dict by engine_uuid of completed tasks
211 213 failed = Dict() # dict by engine_uuid of failed tasks
212 214 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
213 215 clients = Dict() # dict by msg_id for who submitted the task
214 216 targets = List() # list of target IDENTs
215 217 loads = List() # list of engine loads
216 218 # full = Set() # set of IDENTs that have HWM outstanding tasks
217 219 all_completed = Set() # set of all completed tasks
218 220 all_failed = Set() # set of all failed tasks
219 221 all_done = Set() # set of all finished tasks=union(completed,failed)
220 222 all_ids = Set() # set of all submitted task IDs
221 223
222 224 ident = CBytes() # ZMQ identity. This should just be self.session.session
223 225 # but ensure Bytes
224 226 def _ident_default(self):
225 227 return self.session.bsession
226 228
227 229 def start(self):
228 230 self.query_stream.on_recv(self.dispatch_query_reply)
229 231 self.session.send(self.query_stream, "connection_request", {})
230 232
231 233 self.engine_stream.on_recv(self.dispatch_result, copy=False)
232 234 self.client_stream.on_recv(self.dispatch_submission, copy=False)
233 235
234 236 self._notification_handlers = dict(
235 237 registration_notification = self._register_engine,
236 238 unregistration_notification = self._unregister_engine
237 239 )
238 240 self.notifier_stream.on_recv(self.dispatch_notification)
239 241 self.log.info("Scheduler started [%s]" % self.scheme_name)
240 242
241 243 def resume_receiving(self):
242 244 """Resume accepting jobs."""
243 245 self.client_stream.on_recv(self.dispatch_submission, copy=False)
244 246
245 247 def stop_receiving(self):
246 248 """Stop accepting jobs while there are no engines.
247 249 Leave them in the ZMQ queue."""
248 250 self.client_stream.on_recv(None)
249 251
250 252 #-----------------------------------------------------------------------
251 253 # [Un]Registration Handling
252 254 #-----------------------------------------------------------------------
253 255
254 256
255 257 def dispatch_query_reply(self, msg):
256 258 """handle reply to our initial connection request"""
257 259 try:
258 260 idents,msg = self.session.feed_identities(msg)
259 261 except ValueError:
260 262 self.log.warn("task::Invalid Message: %r",msg)
261 263 return
262 264 try:
263 265 msg = self.session.unserialize(msg)
264 266 except ValueError:
265 267 self.log.warn("task::Unauthorized message from: %r"%idents)
266 268 return
267 269
268 270 content = msg['content']
269 271 for uuid in content.get('engines', {}).values():
270 272 self._register_engine(cast_bytes(uuid))
271 273
272 274
273 275 @util.log_errors
274 276 def dispatch_notification(self, msg):
275 277 """dispatch register/unregister events."""
276 278 try:
277 279 idents,msg = self.session.feed_identities(msg)
278 280 except ValueError:
279 281 self.log.warn("task::Invalid Message: %r",msg)
280 282 return
281 283 try:
282 284 msg = self.session.unserialize(msg)
283 285 except ValueError:
284 286 self.log.warn("task::Unauthorized message from: %r"%idents)
285 287 return
286 288
287 289 msg_type = msg['header']['msg_type']
288 290
289 291 handler = self._notification_handlers.get(msg_type, None)
290 292 if handler is None:
291 293 self.log.error("Unhandled message type: %r"%msg_type)
292 294 else:
293 295 try:
294 296 handler(cast_bytes(msg['content']['uuid']))
295 297 except Exception:
296 298 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
297 299
298 300 def _register_engine(self, uid):
299 301 """New engine with ident `uid` became available."""
300 302 # head of the line:
301 303 self.targets.insert(0,uid)
302 304 self.loads.insert(0,0)
303 305
304 306 # initialize sets
305 307 self.completed[uid] = set()
306 308 self.failed[uid] = set()
307 309 self.pending[uid] = {}
308 310
309 311 # rescan the graph:
310 312 self.update_graph(None)
311 313
312 314 def _unregister_engine(self, uid):
313 315 """Existing engine with ident `uid` became unavailable."""
314 316 if len(self.targets) == 1:
315 317 # this was our only engine
316 318 pass
317 319
318 320 # handle any potentially finished tasks:
319 321 self.engine_stream.flush()
320 322
321 323 # don't pop destinations, because they might be used later
322 324 # map(self.destinations.pop, self.completed.pop(uid))
323 325 # map(self.destinations.pop, self.failed.pop(uid))
324 326
325 327 # prevent this engine from receiving work
326 328 idx = self.targets.index(uid)
327 329 self.targets.pop(idx)
328 330 self.loads.pop(idx)
329 331
330 332 # wait 5 seconds before cleaning up pending jobs, since the results might
331 333 # still be incoming
332 334 if self.pending[uid]:
333 335 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
334 336 dc.start()
335 337 else:
336 338 self.completed.pop(uid)
337 339 self.failed.pop(uid)
338 340
339 341
340 342 def handle_stranded_tasks(self, engine):
341 343 """Deal with jobs resident in an engine that died."""
342 344 lost = self.pending[engine]
343 345 for msg_id in lost.keys():
344 346 if msg_id not in self.pending[engine]:
345 347 # prevent double-handling of messages
346 348 continue
347 349
348 350 raw_msg = lost[msg_id].raw_msg
349 351 idents,msg = self.session.feed_identities(raw_msg, copy=False)
350 352 parent = self.session.unpack(msg[1].bytes)
351 353 idents = [engine, idents[0]]
352 354
353 355 # build fake error reply
354 356 try:
355 357 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
356 358 except:
357 359 content = error.wrap_exception()
358 360 # build fake metadata
359 361 md = dict(
360 362 status=u'error',
361 363 engine=engine.decode('ascii'),
362 364 date=datetime.now(),
363 365 )
364 366 msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
365 367 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
366 368 # and dispatch it
367 369 self.dispatch_result(raw_reply)
368 370
369 371 # finally scrub completed/failed lists
370 372 self.completed.pop(engine)
371 373 self.failed.pop(engine)
372 374
373 375
374 376 #-----------------------------------------------------------------------
375 377 # Job Submission
376 378 #-----------------------------------------------------------------------
377 379
378 380
379 381 @util.log_errors
380 382 def dispatch_submission(self, raw_msg):
381 383 """Dispatch job submission to appropriate handlers."""
382 384 # ensure targets up to date:
383 385 self.notifier_stream.flush()
384 386 try:
385 387 idents, msg = self.session.feed_identities(raw_msg, copy=False)
386 388 msg = self.session.unserialize(msg, content=False, copy=False)
387 389 except Exception:
388 390 self.log.error("task::Invalid task msg: %r"%raw_msg, exc_info=True)
389 391 return
390 392
391 393
392 394 # send to monitor
393 395 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
394 396
395 397 header = msg['header']
396 398 md = msg['metadata']
397 399 msg_id = header['msg_id']
398 400 self.all_ids.add(msg_id)
399 401
400 402 # get targets as a set of bytes objects
401 403 # from a list of unicode objects
402 404 targets = md.get('targets', [])
403 405 targets = map(cast_bytes, targets)
404 406 targets = set(targets)
405 407
406 408 retries = md.get('retries', 0)
407 409 self.retries[msg_id] = retries
408 410
409 411 # time dependencies
410 412 after = md.get('after', None)
411 413 if after:
412 414 after = Dependency(after)
413 415 if after.all:
414 416 if after.success:
415 417 after = Dependency(after.difference(self.all_completed),
416 418 success=after.success,
417 419 failure=after.failure,
418 420 all=after.all,
419 421 )
420 422 if after.failure:
421 423 after = Dependency(after.difference(self.all_failed),
422 424 success=after.success,
423 425 failure=after.failure,
424 426 all=after.all,
425 427 )
426 428 if after.check(self.all_completed, self.all_failed):
427 429 # recast as empty set, if `after` already met,
428 430 # to prevent unnecessary set comparisons
429 431 after = MET
430 432 else:
431 433 after = MET
432 434
433 435 # location dependencies
434 436 follow = Dependency(md.get('follow', []))
435 437
436 # turn timeouts into datetime objects:
437 438 timeout = md.get('timeout', None)
438 439 if timeout:
439 timeout = time.time() + float(timeout)
440 timeout = float(timeout)
440 441
441 442 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
442 443 header=header, targets=targets, after=after, follow=follow,
443 444 timeout=timeout, metadata=md,
444 445 )
445 if timeout:
446 # schedule timeout callback
447 self.loop.add_timeout(timeout, lambda : self.job_timeout(job))
448
449 446 # validate and reduce dependencies:
450 447 for dep in after,follow:
451 448 if not dep: # empty dependency
452 449 continue
453 450 # check valid:
454 451 if msg_id in dep or dep.difference(self.all_ids):
455 452 self.queue_map[msg_id] = job
456 453 return self.fail_unreachable(msg_id, error.InvalidDependency)
457 454 # check if unreachable:
458 455 if dep.unreachable(self.all_completed, self.all_failed):
459 456 self.queue_map[msg_id] = job
460 457 return self.fail_unreachable(msg_id)
461 458
462 459 if after.check(self.all_completed, self.all_failed):
463 460 # time deps already met, try to run
464 461 if not self.maybe_run(job):
465 462 # can't run yet
466 463 if msg_id not in self.all_failed:
467 464 # could have failed as unreachable
468 465 self.save_unmet(job)
469 466 else:
470 467 self.save_unmet(job)
471 468
472 def job_timeout(self, job):
469 def job_timeout(self, job, timeout_id):
473 470 """callback for a job's timeout.
474 471
475 472 The job may or may not have been run at this point.
476 473 """
474 if job.timeout_id != timeout_id:
475 # not the most recent call
476 return
477 477 now = time.time()
478 478 if job.timeout >= (now + 1):
479 479 self.log.warn("task %s timeout fired prematurely: %s > %s",
480 480 job.msg_id, job.timeout, now
481 481 )
482 482 if job.msg_id in self.queue_map:
483 483 # still waiting, but ran out of time
484 484 self.log.info("task %r timed out", job.msg_id)
485 485 self.fail_unreachable(job.msg_id, error.TaskTimeout)
486 486
487 487 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
488 488 """a task has become unreachable, send a reply with an ImpossibleDependency
489 489 error."""
490 490 if msg_id not in self.queue_map:
491 491 self.log.error("task %r already failed!", msg_id)
492 492 return
493 493 job = self.queue_map.pop(msg_id)
494 494 # lazy-delete from the queue
495 495 job.removed = True
496 496 for mid in job.dependents:
497 497 if mid in self.graph:
498 498 self.graph[mid].remove(msg_id)
499 499
500 500 try:
501 501 raise why()
502 502 except:
503 503 content = error.wrap_exception()
504 504 self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])
505 505
506 506 self.all_done.add(msg_id)
507 507 self.all_failed.add(msg_id)
508 508
509 509 msg = self.session.send(self.client_stream, 'apply_reply', content,
510 510 parent=job.header, ident=job.idents)
511 511 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
512 512
513 513 self.update_graph(msg_id, success=False)
514 514
515 515 def available_engines(self):
516 516 """return a list of available engine indices based on HWM"""
517 517 if not self.hwm:
518 518 return range(len(self.targets))
519 519 available = []
520 520 for idx in range(len(self.targets)):
521 521 if self.loads[idx] < self.hwm:
522 522 available.append(idx)
523 523 return available
524 524
525 525 def maybe_run(self, job):
526 526 """check location dependencies, and run if they are met."""
527 527 msg_id = job.msg_id
528 528 self.log.debug("Attempting to assign task %s", msg_id)
529 529 available = self.available_engines()
530 530 if not available:
531 531 # no engines, definitely can't run
532 532 return False
533 533
534 534 if job.follow or job.targets or job.blacklist or self.hwm:
535 535 # we need a can_run filter
536 536 def can_run(idx):
537 537 # check hwm
538 538 if self.hwm and self.loads[idx] == self.hwm:
539 539 return False
540 540 target = self.targets[idx]
541 541 # check blacklist
542 542 if target in job.blacklist:
543 543 return False
544 544 # check targets
545 545 if job.targets and target not in job.targets:
546 546 return False
547 547 # check follow
548 548 return job.follow.check(self.completed[target], self.failed[target])
549 549
550 550 indices = filter(can_run, available)
551 551
552 552 if not indices:
553 553 # couldn't run
554 554 if job.follow.all:
555 555 # check follow for impossibility
556 556 dests = set()
557 557 relevant = set()
558 558 if job.follow.success:
559 559 relevant = self.all_completed
560 560 if job.follow.failure:
561 561 relevant = relevant.union(self.all_failed)
562 562 for m in job.follow.intersection(relevant):
563 563 dests.add(self.destinations[m])
564 564 if len(dests) > 1:
565 565 self.queue_map[msg_id] = job
566 566 self.fail_unreachable(msg_id)
567 567 return False
568 568 if job.targets:
569 569 # check blacklist+targets for impossibility
570 570 job.targets.difference_update(job.blacklist)
571 571 if not job.targets or not job.targets.intersection(self.targets):
572 572 self.queue_map[msg_id] = job
573 573 self.fail_unreachable(msg_id)
574 574 return False
575 575 return False
576 576 else:
577 577 indices = None
578 578
579 579 self.submit_task(job, indices)
580 580 return True
581 581
582 582 def save_unmet(self, job):
583 583 """Save a message for later submission when its dependencies are met."""
584 584 msg_id = job.msg_id
585 585 self.log.debug("Adding task %s to the queue", msg_id)
586 586 self.queue_map[msg_id] = job
587 587 self.queue.append(job)
588 588 # track the ids in follow or after, but not those already finished
589 589 for dep_id in job.after.union(job.follow).difference(self.all_done):
590 590 if dep_id not in self.graph:
591 591 self.graph[dep_id] = set()
592 592 self.graph[dep_id].add(msg_id)
593
594 # schedule timeout callback
595 if job.timeout:
596 timeout_id = job.timeout_id = job.timeout_id + 1
597 self.loop.add_timeout(time.time() + job.timeout,
598 lambda : self.job_timeout(job, timeout_id)
599 )
600
593 601
594 602 def submit_task(self, job, indices=None):
595 603 """Submit a task to any of a subset of our targets."""
596 604 if indices:
597 605 loads = [self.loads[i] for i in indices]
598 606 else:
599 607 loads = self.loads
600 608 idx = self.scheme(loads)
601 609 if indices:
602 610 idx = indices[idx]
603 611 target = self.targets[idx]
604 612 # print (target, map(str, msg[:3]))
605 613 # send job to the engine
606 614 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
607 615 self.engine_stream.send_multipart(job.raw_msg, copy=False)
608 616 # update load
609 617 self.add_job(idx)
610 618 self.pending[target][job.msg_id] = job
611 619 # notify Hub
612 620 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
613 621 self.session.send(self.mon_stream, 'task_destination', content=content,
614 622 ident=[b'tracktask',self.ident])
615 623
616 624
617 625 #-----------------------------------------------------------------------
618 626 # Result Handling
619 627 #-----------------------------------------------------------------------
620 628
621 629
622 630 @util.log_errors
623 631 def dispatch_result(self, raw_msg):
624 632 """dispatch method for result replies"""
625 633 try:
626 634 idents,msg = self.session.feed_identities(raw_msg, copy=False)
627 635 msg = self.session.unserialize(msg, content=False, copy=False)
628 636 engine = idents[0]
629 637 try:
630 638 idx = self.targets.index(engine)
631 639 except ValueError:
632 640 pass # skip load-update for dead engines
633 641 else:
634 642 self.finish_job(idx)
635 643 except Exception:
636 self.log.error("task::Invaid result: %r", raw_msg, exc_info=True)
644 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
637 645 return
638 646
639 647 md = msg['metadata']
640 648 parent = msg['parent_header']
641 649 if md.get('dependencies_met', True):
642 650 success = (md['status'] == 'ok')
643 651 msg_id = parent['msg_id']
644 652 retries = self.retries[msg_id]
645 653 if not success and retries > 0:
646 654 # failed
647 655 self.retries[msg_id] = retries - 1
648 656 self.handle_unmet_dependency(idents, parent)
649 657 else:
650 658 del self.retries[msg_id]
651 659 # relay to client and update graph
652 660 self.handle_result(idents, parent, raw_msg, success)
653 661 # send to Hub monitor
654 662 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
655 663 else:
656 664 self.handle_unmet_dependency(idents, parent)
657 665
658 666 def handle_result(self, idents, parent, raw_msg, success=True):
659 667 """handle a real task result, either success or failure"""
660 668 # first, relay result to client
661 669 engine = idents[0]
662 670 client = idents[1]
663 671 # swap_ids for ROUTER-ROUTER mirror
664 672 raw_msg[:2] = [client,engine]
665 673 # print (map(str, raw_msg[:4]))
666 674 self.client_stream.send_multipart(raw_msg, copy=False)
667 675 # now, update our data structures
668 676 msg_id = parent['msg_id']
669 677 self.pending[engine].pop(msg_id)
670 678 if success:
671 679 self.completed[engine].add(msg_id)
672 680 self.all_completed.add(msg_id)
673 681 else:
674 682 self.failed[engine].add(msg_id)
675 683 self.all_failed.add(msg_id)
676 684 self.all_done.add(msg_id)
677 685 self.destinations[msg_id] = engine
678 686
679 687 self.update_graph(msg_id, success)
680 688
681 689 def handle_unmet_dependency(self, idents, parent):
682 690 """handle an unmet dependency"""
683 691 engine = idents[0]
684 692 msg_id = parent['msg_id']
685 693
686 694 job = self.pending[engine].pop(msg_id)
687 695 job.blacklist.add(engine)
688 696
689 697 if job.blacklist == job.targets:
690 698 self.queue_map[msg_id] = job
691 699 self.fail_unreachable(msg_id)
692 700 elif not self.maybe_run(job):
693 701 # resubmit failed
694 702 if msg_id not in self.all_failed:
695 703 # put it back in our dependency tree
696 704 self.save_unmet(job)
697 705
698 706 if self.hwm:
699 707 try:
700 708 idx = self.targets.index(engine)
701 709 except ValueError:
702 710 pass # skip load-update for dead engines
703 711 else:
704 712 if self.loads[idx] == self.hwm-1:
705 713 self.update_graph(None)
706 714
707 715 def update_graph(self, dep_id=None, success=True):
708 716 """dep_id just finished. Update our dependency
709 717 graph and submit any jobs that just became runnable.
710 718
711 719 Called with dep_id=None to update entire graph for hwm, but without finishing a task.
712 720 """
713 721 # print ("\n\n***********")
714 722 # pprint (dep_id)
715 723 # pprint (self.graph)
716 724 # pprint (self.queue_map)
717 725 # pprint (self.all_completed)
718 726 # pprint (self.all_failed)
719 727 # print ("\n\n***********\n\n")
720 728 # update any jobs that depended on the dependency
721 729 msg_ids = self.graph.pop(dep_id, [])
722 730
723 731 # recheck *all* jobs if
724 732 # a) we have HWM and an engine just stopped being full
725 733 # or b) dep_id was given as None
726 734
727 735 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
728 736 jobs = self.queue
729 737 using_queue = True
730 738 else:
731 739 using_queue = False
732 740 jobs = deque(sorted( self.queue_map[msg_id] for msg_id in msg_ids ))
733 741
734 742 to_restore = []
735 743 while jobs:
736 744 job = jobs.popleft()
737 745 if job.removed:
738 746 continue
739 747 msg_id = job.msg_id
740 748
741 749 put_it_back = True
742 750
743 751 if job.after.unreachable(self.all_completed, self.all_failed)\
744 752 or job.follow.unreachable(self.all_completed, self.all_failed):
745 753 self.fail_unreachable(msg_id)
746 754 put_it_back = False
747 755
748 756 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
749 757 if self.maybe_run(job):
750 758 put_it_back = False
751 759 self.queue_map.pop(msg_id)
752 760 for mid in job.dependents:
753 761 if mid in self.graph:
754 762 self.graph[mid].remove(msg_id)
755 763
756 764 # abort the loop if we just filled up all of our engines.
757 765 # avoids an O(N) operation in situation of full queue,
758 766 # where graph update is triggered as soon as an engine becomes
759 767 # non-full, and all tasks after the first are checked,
760 768 # even though they can't run.
761 769 if not self.available_engines():
762 770 break
763 771
764 772 if using_queue and put_it_back:
765 773 # popped a job from the queue but it neither ran nor failed,
766 774 # so we need to put it back when we are done
767 775 # make sure to_restore preserves the same ordering
768 776 to_restore.append(job)
769 777
770 778 # put back any tasks we popped but didn't run
771 779 if using_queue:
772 780 self.queue.extendleft(to_restore)
773 781
774 782 #----------------------------------------------------------------------
775 783 # methods to be overridden by subclasses
776 784 #----------------------------------------------------------------------
777 785
778 786 def add_job(self, idx):
779 787 """Called after self.targets[idx] just got the job with header.
780 788 Override with subclasses. The default ordering is simple LRU.
781 789 The default loads are the number of outstanding jobs."""
782 790 self.loads[idx] += 1
783 791 for lis in (self.targets, self.loads):
784 792 lis.append(lis.pop(idx))
785 793
786 794
787 795 def finish_job(self, idx):
788 796 """Called after self.targets[idx] just finished a job.
789 797 Override with subclasses."""
790 798 self.loads[idx] -= 1
791 799
792 800
793 801
794 802 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
795 803 logname='root', log_url=None, loglevel=logging.DEBUG,
796 804 identity=b'task', in_thread=False):
797 805
798 806 ZMQStream = zmqstream.ZMQStream
799 807
800 808 if config:
801 809 # unwrap dict back into Config
802 810 config = Config(config)
803 811
804 812 if in_thread:
805 813 # use instance() to get the same Context/Loop as our parent
806 814 ctx = zmq.Context.instance()
807 815 loop = ioloop.IOLoop.instance()
808 816 else:
809 817 # in a process, don't use instance()
810 818 # for safety with multiprocessing
811 819 ctx = zmq.Context()
812 820 loop = ioloop.IOLoop()
813 821 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
814 822 util.set_hwm(ins, 0)
815 823 ins.setsockopt(zmq.IDENTITY, identity + b'_in')
816 824 ins.bind(in_addr)
817 825
818 826 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
819 827 util.set_hwm(outs, 0)
820 828 outs.setsockopt(zmq.IDENTITY, identity + b'_out')
821 829 outs.bind(out_addr)
822 830 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
823 831 util.set_hwm(mons, 0)
824 832 mons.connect(mon_addr)
825 833 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
826 834 nots.setsockopt(zmq.SUBSCRIBE, b'')
827 835 nots.connect(not_addr)
828 836
829 837 querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
830 838 querys.connect(reg_addr)
831 839
832 840 # setup logging.
833 841 if in_thread:
834 842 log = Application.instance().log
835 843 else:
836 844 if log_url:
837 845 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
838 846 else:
839 847 log = local_logger(logname, loglevel)
840 848
841 849 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
842 850 mon_stream=mons, notifier_stream=nots,
843 851 query_stream=querys,
844 852 loop=loop, log=log,
845 853 config=config)
846 854 scheduler.start()
847 855 if not in_thread:
848 856 try:
849 857 loop.start()
850 858 except KeyboardInterrupt:
851 859 scheduler.log.critical("Interrupted, exiting...")
852 860