use per-timeout callback, rather than audit for timeouts
MinRK
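This commit drops the scheduler's 2-second PeriodicCallback that audited every waiting task for expired timeouts, and instead registers one IOLoop.add_timeout callback per task that carries a timeout: the timeout is converted to an absolute deadline (time.time() + timeout), and the callback only fails the task if it is still waiting in queue_map when it fires. A minimal, standalone sketch of the two patterns for comparison (the stub names queue_map, submit and fail_unreachable here are illustrative stand-ins, not the scheduler's actual API):

    import time
    from zmq.eventloop import ioloop

    loop = ioloop.IOLoop.instance()
    queue_map = {}                    # msg_id -> deadline, stands in for TaskScheduler.queue_map

    def fail_unreachable(msg_id):     # stand-in for TaskScheduler.fail_unreachable
        queue_map.pop(msg_id, None)
        print("task %r timed out" % msg_id)

    # old approach: wake up every 2 seconds and scan all waiting tasks
    def audit_timeouts():
        now = time.time()
        for msg_id in list(queue_map):
            if queue_map[msg_id] < now:
                fail_unreachable(msg_id)
    auditor = ioloop.PeriodicCallback(audit_timeouts, 2e3, loop)  # 2000 ms period

    # new approach: schedule exactly one callback at each task's own deadline
    def submit(msg_id, timeout):
        deadline = time.time() + float(timeout)   # absolute deadline, as in the diff below
        queue_map[msg_id] = deadline
        loop.add_timeout(deadline, lambda: job_timeout(msg_id))

    def job_timeout(msg_id):
        # the task may already have run; only fail it if it is still waiting
        if msg_id in queue_map:
            fail_unreachable(msg_id)

With the per-task callback there is no periodic scan of the whole queue, and a timeout is acted on when it actually expires rather than up to two seconds late; passing a wall-clock deadline to add_timeout mirrors what the new code in the diff does.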
@@ -1,844 +1,846 b''
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6
7 7 Authors:
8 8
9 9 * Min RK
10 10 """
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2010-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #----------------------------------------------------------------------
19 19 # Imports
20 20 #----------------------------------------------------------------------
21 21
22 22 import heapq
23 23 import logging
24 24 import sys
25 25 import time
26 26
27 27 from datetime import datetime, timedelta
28 28 from random import randint, random
29 29 from types import FunctionType
30 30
31 31 try:
32 32 import numpy
33 33 except ImportError:
34 34 numpy = None
35 35
36 36 import zmq
37 37 from zmq.eventloop import ioloop, zmqstream
38 38
39 39 # local imports
40 40 from IPython.external.decorator import decorator
41 41 from IPython.config.application import Application
42 42 from IPython.config.loader import Config
43 43 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
44 44 from IPython.utils.py3compat import cast_bytes
45 45
46 46 from IPython.parallel import error, util
47 47 from IPython.parallel.factory import SessionFactory
48 48 from IPython.parallel.util import connect_logger, local_logger
49 49
50 50 from .dependency import Dependency
51 51
52 52 @decorator
53 53 def logged(f,self,*args,**kwargs):
54 54 # print ("#--------------------")
55 55 self.log.debug("scheduler::%s(*%s,**%s)", f.func_name, args, kwargs)
56 56 # print ("#--")
57 57 return f(self,*args, **kwargs)
58 58
59 59 #----------------------------------------------------------------------
60 60 # Chooser functions
61 61 #----------------------------------------------------------------------
62 62
63 63 def plainrandom(loads):
64 64 """Plain random pick."""
65 65 n = len(loads)
66 66 return randint(0,n-1)
67 67
68 68 def lru(loads):
69 69 """Always pick the front of the line.
70 70
71 71 The content of `loads` is ignored.
72 72
73 73 Assumes LRU ordering of loads, with oldest first.
74 74 """
75 75 return 0
76 76
77 77 def twobin(loads):
78 78 """Pick two at random, use the LRU of the two.
79 79
80 80 The content of loads is ignored.
81 81
82 82 Assumes LRU ordering of loads, with oldest first.
83 83 """
84 84 n = len(loads)
85 85 a = randint(0,n-1)
86 86 b = randint(0,n-1)
87 87 return min(a,b)
88 88
89 89 def weighted(loads):
90 90 """Pick two at random using inverse load as weight.
91 91
92 92 Return the less loaded of the two.
93 93 """
94 94 # weight 0 a million times more than 1:
95 95 weights = 1./(1e-6+numpy.array(loads))
96 96 sums = weights.cumsum()
97 97 t = sums[-1]
98 98 x = random()*t
99 99 y = random()*t
100 100 idx = 0
101 101 idy = 0
102 102 while sums[idx] < x:
103 103 idx += 1
104 104 while sums[idy] < y:
105 105 idy += 1
106 106 if weights[idy] > weights[idx]:
107 107 return idy
108 108 else:
109 109 return idx
110 110
111 111 def leastload(loads):
112 112 """Always choose the lowest load.
113 113
114 114 If the lowest load occurs more than once, the first
115 115 occurance will be used. If loads has LRU ordering, this means
116 116 the LRU of those with the lowest load is chosen.
117 117 """
118 118 return loads.index(min(loads))
119 119
120 120 #---------------------------------------------------------------------
121 121 # Classes
122 122 #---------------------------------------------------------------------
123 123
124 124
125 125 # store empty default dependency:
126 126 MET = Dependency([])
127 127
128 128
129 129 class Job(object):
130 130 """Simple container for a job"""
131 131 def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
132 132 targets, after, follow, timeout):
133 133 self.msg_id = msg_id
134 134 self.raw_msg = raw_msg
135 135 self.idents = idents
136 136 self.msg = msg
137 137 self.header = header
138 138 self.metadata = metadata
139 139 self.targets = targets
140 140 self.after = after
141 141 self.follow = follow
142 142 self.timeout = timeout
143 143 self.removed = False # used for lazy-delete in heap-sorted queue
144 144
145 145 self.timestamp = time.time()
146 146 self.blacklist = set()
147 147
148 148 def __lt__(self, other):
149 149 return self.timestamp < other.timestamp
150 150
151 151 def __cmp__(self, other):
152 152 return cmp(self.timestamp, other.timestamp)
153 153
154 154 @property
155 155 def dependents(self):
156 156 return self.follow.union(self.after)
157 157
158 158 class TaskScheduler(SessionFactory):
159 159 """Python TaskScheduler object.
160 160
161 161 This is the simplest object that supports msg_id based
162 162 DAG dependencies. *Only* task msg_ids are checked, not
163 163 msg_ids of jobs submitted via the MUX queue.
164 164
165 165 """
166 166
167 167 hwm = Integer(1, config=True,
168 168 help="""specify the High Water Mark (HWM) for the downstream
169 169 socket in the Task scheduler. This is the maximum number
170 170 of allowed outstanding tasks on each engine.
171 171
172 172 The default (1) means that only one task can be outstanding on each
173 173 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
174 174 engines continue to be assigned tasks while they are working,
175 175 effectively hiding network latency behind computation, but can result
176 176 in an imbalance of work when submitting many heterogeneous tasks all at
177 177 once. Any positive value greater than one is a compromise between the
178 178 two.
179 179
180 180 """
181 181 )
182 182 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
183 183 'leastload', config=True, allow_none=False,
184 184 help="""select the task scheduler scheme [default: leastload]
185 185 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
186 186 )
187 187 def _scheme_name_changed(self, old, new):
188 188 self.log.debug("Using scheme %r"%new)
189 189 self.scheme = globals()[new]
190 190
191 191 # input arguments:
192 192 scheme = Instance(FunctionType) # function for determining the destination
193 193 def _scheme_default(self):
194 194 return leastload
195 195 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
196 196 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
197 197 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
198 198 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
199 199 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
200 200
201 201 # internals:
202 202 queue = List() # heap-sorted list of Jobs
203 203 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
204 204 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
205 205 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
206 206 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
207 207 pending = Dict() # dict by engine_uuid of submitted tasks
208 208 completed = Dict() # dict by engine_uuid of completed tasks
209 209 failed = Dict() # dict by engine_uuid of failed tasks
210 210 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
211 211 clients = Dict() # dict by msg_id for who submitted the task
212 212 targets = List() # list of target IDENTs
213 213 loads = List() # list of engine loads
214 214 # full = Set() # set of IDENTs that have HWM outstanding tasks
215 215 all_completed = Set() # set of all completed tasks
216 216 all_failed = Set() # set of all failed tasks
217 217 all_done = Set() # set of all finished tasks=union(completed,failed)
218 218 all_ids = Set() # set of all submitted task IDs
219 219
220 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
221
222 220 ident = CBytes() # ZMQ identity. This should just be self.session.session
223 221 # but ensure Bytes
224 222 def _ident_default(self):
225 223 return self.session.bsession
226 224
227 225 def start(self):
228 226 self.query_stream.on_recv(self.dispatch_query_reply)
229 227 self.session.send(self.query_stream, "connection_request", {})
230 228
231 229 self.engine_stream.on_recv(self.dispatch_result, copy=False)
232 230 self.client_stream.on_recv(self.dispatch_submission, copy=False)
233 231
234 232 self._notification_handlers = dict(
235 233 registration_notification = self._register_engine,
236 234 unregistration_notification = self._unregister_engine
237 235 )
238 236 self.notifier_stream.on_recv(self.dispatch_notification)
239 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # 1 Hz
240 self.auditor.start()
241 237 self.log.info("Scheduler started [%s]"%self.scheme_name)
242 238
243 239 def resume_receiving(self):
244 240 """Resume accepting jobs."""
245 241 self.client_stream.on_recv(self.dispatch_submission, copy=False)
246 242
247 243 def stop_receiving(self):
248 244 """Stop accepting jobs while there are no engines.
249 245 Leave them in the ZMQ queue."""
250 246 self.client_stream.on_recv(None)
251 247
252 248 #-----------------------------------------------------------------------
253 249 # [Un]Registration Handling
254 250 #-----------------------------------------------------------------------
255 251
256 252
257 253 def dispatch_query_reply(self, msg):
258 254 """handle reply to our initial connection request"""
259 255 try:
260 256 idents,msg = self.session.feed_identities(msg)
261 257 except ValueError:
262 258 self.log.warn("task::Invalid Message: %r",msg)
263 259 return
264 260 try:
265 261 msg = self.session.unserialize(msg)
266 262 except ValueError:
267 263 self.log.warn("task::Unauthorized message from: %r"%idents)
268 264 return
269 265
270 266 content = msg['content']
271 267 for uuid in content.get('engines', {}).values():
272 268 self._register_engine(cast_bytes(uuid))
273 269
274 270
275 271 @util.log_errors
276 272 def dispatch_notification(self, msg):
277 273 """dispatch register/unregister events."""
278 274 try:
279 275 idents,msg = self.session.feed_identities(msg)
280 276 except ValueError:
281 277 self.log.warn("task::Invalid Message: %r",msg)
282 278 return
283 279 try:
284 280 msg = self.session.unserialize(msg)
285 281 except ValueError:
286 282 self.log.warn("task::Unauthorized message from: %r"%idents)
287 283 return
288 284
289 285 msg_type = msg['header']['msg_type']
290 286
291 287 handler = self._notification_handlers.get(msg_type, None)
292 288 if handler is None:
293 289 self.log.error("Unhandled message type: %r"%msg_type)
294 290 else:
295 291 try:
296 292 handler(cast_bytes(msg['content']['uuid']))
297 293 except Exception:
298 294 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
299 295
300 296 def _register_engine(self, uid):
301 297 """New engine with ident `uid` became available."""
302 298 # head of the line:
303 299 self.targets.insert(0,uid)
304 300 self.loads.insert(0,0)
305 301
306 302 # initialize sets
307 303 self.completed[uid] = set()
308 304 self.failed[uid] = set()
309 305 self.pending[uid] = {}
310 306
311 307 # rescan the graph:
312 308 self.update_graph(None)
313 309
314 310 def _unregister_engine(self, uid):
315 311 """Existing engine with ident `uid` became unavailable."""
316 312 if len(self.targets) == 1:
317 313 # this was our only engine
318 314 pass
319 315
320 316 # handle any potentially finished tasks:
321 317 self.engine_stream.flush()
322 318
323 319 # don't pop destinations, because they might be used later
324 320 # map(self.destinations.pop, self.completed.pop(uid))
325 321 # map(self.destinations.pop, self.failed.pop(uid))
326 322
327 323 # prevent this engine from receiving work
328 324 idx = self.targets.index(uid)
329 325 self.targets.pop(idx)
330 326 self.loads.pop(idx)
331 327
332 328 # wait 5 seconds before cleaning up pending jobs, since the results might
333 329 # still be incoming
334 330 if self.pending[uid]:
335 331 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
336 332 dc.start()
337 333 else:
338 334 self.completed.pop(uid)
339 335 self.failed.pop(uid)
340 336
341 337
342 338 def handle_stranded_tasks(self, engine):
343 339 """Deal with jobs resident in an engine that died."""
344 340 lost = self.pending[engine]
345 341 for msg_id in lost.keys():
346 342 if msg_id not in self.pending[engine]:
347 343 # prevent double-handling of messages
348 344 continue
349 345
350 346 raw_msg = lost[msg_id].raw_msg
351 347 idents,msg = self.session.feed_identities(raw_msg, copy=False)
352 348 parent = self.session.unpack(msg[1].bytes)
353 349 idents = [engine, idents[0]]
354 350
355 351 # build fake error reply
356 352 try:
357 353 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
358 354 except:
359 355 content = error.wrap_exception()
360 356 # build fake metadata
361 357 md = dict(
362 358 status=u'error',
363 359 engine=engine,
364 360 date=datetime.now(),
365 361 )
366 362 msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
367 363 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
368 364 # and dispatch it
369 365 self.dispatch_result(raw_reply)
370 366
371 367 # finally scrub completed/failed lists
372 368 self.completed.pop(engine)
373 369 self.failed.pop(engine)
374 370
375 371
376 372 #-----------------------------------------------------------------------
377 373 # Job Submission
378 374 #-----------------------------------------------------------------------
379 375
380 376
381 377 @util.log_errors
382 378 def dispatch_submission(self, raw_msg):
383 379 """Dispatch job submission to appropriate handlers."""
384 380 # ensure targets up to date:
385 381 self.notifier_stream.flush()
386 382 try:
387 383 idents, msg = self.session.feed_identities(raw_msg, copy=False)
388 384 msg = self.session.unserialize(msg, content=False, copy=False)
389 385 except Exception:
390 386             self.log.error("task::Invalid task msg: %r"%raw_msg, exc_info=True)
391 387 return
392 388
393 389
394 390 # send to monitor
395 391 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
396 392
397 393 header = msg['header']
398 394 md = msg['metadata']
399 395 msg_id = header['msg_id']
400 396 self.all_ids.add(msg_id)
401 397
402 398 # get targets as a set of bytes objects
403 399 # from a list of unicode objects
404 400 targets = md.get('targets', [])
405 401 targets = map(cast_bytes, targets)
406 402 targets = set(targets)
407 403
408 404 retries = md.get('retries', 0)
409 405 self.retries[msg_id] = retries
410 406
411 407 # time dependencies
412 408 after = md.get('after', None)
413 409 if after:
414 410 after = Dependency(after)
415 411 if after.all:
416 412 if after.success:
417 413 after = Dependency(after.difference(self.all_completed),
418 414 success=after.success,
419 415 failure=after.failure,
420 416 all=after.all,
421 417 )
422 418 if after.failure:
423 419 after = Dependency(after.difference(self.all_failed),
424 420 success=after.success,
425 421 failure=after.failure,
426 422 all=after.all,
427 423 )
428 424 if after.check(self.all_completed, self.all_failed):
429 425 # recast as empty set, if `after` already met,
430 426 # to prevent unnecessary set comparisons
431 427 after = MET
432 428 else:
433 429 after = MET
434 430
435 431 # location dependencies
436 432 follow = Dependency(md.get('follow', []))
437 433
438 434 # turn timeouts into datetime objects:
439 435 timeout = md.get('timeout', None)
440 436 if timeout:
441 # cast to float, because jsonlib returns floats as decimal.Decimal,
442 # which timedelta does not accept
443 timeout = datetime.now() + timedelta(0,float(timeout),0)
437 timeout = time.time() + float(timeout)
444 438
445 439 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
446 440 header=header, targets=targets, after=after, follow=follow,
447 441 timeout=timeout, metadata=md,
448 442 )
443 if timeout:
444 # schedule timeout callback
445 self.loop.add_timeout(timeout, lambda : self.job_timeout(job))
449 446
450 447 # validate and reduce dependencies:
451 448 for dep in after,follow:
452 449 if not dep: # empty dependency
453 450 continue
454 451 # check valid:
455 452 if msg_id in dep or dep.difference(self.all_ids):
456 453 self.queue_map[msg_id] = job
457 454 return self.fail_unreachable(msg_id, error.InvalidDependency)
458 455 # check if unreachable:
459 456 if dep.unreachable(self.all_completed, self.all_failed):
460 457 self.queue_map[msg_id] = job
461 458 return self.fail_unreachable(msg_id)
462 459
463 460 if after.check(self.all_completed, self.all_failed):
464 461 # time deps already met, try to run
465 462 if not self.maybe_run(job):
466 463 # can't run yet
467 464 if msg_id not in self.all_failed:
468 465 # could have failed as unreachable
469 466 self.save_unmet(job)
470 467 else:
471 468 self.save_unmet(job)
472 469
473 def audit_timeouts(self):
474 """Audit all waiting tasks for expired timeouts."""
475 now = datetime.now()
476 for msg_id in self.queue_map.keys():
477 # must recheck, in case one failure cascaded to another:
478 if msg_id in self.queue_map:
479 job = self.queue_map[msg_id]
480 if job.timeout and job.timeout < now:
481 self.fail_unreachable(msg_id, error.TaskTimeout)
470 def job_timeout(self, job):
471 """callback for a job's timeout.
472
473 The job may or may not have been run at this point.
474 """
475 if job.timeout >= (time.time() + 1):
476 self.log.warn("task %s timeout fired prematurely: %s > %s",
477                 job.msg_id, job.timeout, time.time()
478 )
479 if job.msg_id in self.queue_map:
480 # still waiting, but ran out of time
481 self.log.info("task %r timed out", job.msg_id)
482 self.fail_unreachable(job.msg_id, error.TaskTimeout)
482 483
483 484 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
484 485 """a task has become unreachable, send a reply with an ImpossibleDependency
485 486 error."""
486 487 if msg_id not in self.queue_map:
487 self.log.error("msg %r already failed!", msg_id)
488 self.log.error("task %r already failed!", msg_id)
488 489 return
489 490 job = self.queue_map.pop(msg_id)
490 491 # lazy-delete from the queue
491 492 job.removed = True
492 493 for mid in job.dependents:
493 494 if mid in self.graph:
494 495 self.graph[mid].remove(msg_id)
495 496
496 497 try:
497 498 raise why()
498 499 except:
499 500 content = error.wrap_exception()
501 self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])
500 502
501 503 self.all_done.add(msg_id)
502 504 self.all_failed.add(msg_id)
503 505
504 506 msg = self.session.send(self.client_stream, 'apply_reply', content,
505 507 parent=job.header, ident=job.idents)
506 508 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
507 509
508 510 self.update_graph(msg_id, success=False)
509 511
510 512 def available_engines(self):
511 513 """return a list of available engine indices based on HWM"""
512 514 if not self.hwm:
513 515 return range(len(self.targets))
514 516 available = []
515 517 for idx in range(len(self.targets)):
516 518 if self.loads[idx] < self.hwm:
517 519 available.append(idx)
518 520 return available
519 521
520 522 def maybe_run(self, job):
521 523 """check location dependencies, and run if they are met."""
522 524 msg_id = job.msg_id
523 525 self.log.debug("Attempting to assign task %s", msg_id)
524 526 available = self.available_engines()
525 527 if not available:
526 528 # no engines, definitely can't run
527 529 return False
528 530
529 531 if job.follow or job.targets or job.blacklist or self.hwm:
530 532 # we need a can_run filter
531 533 def can_run(idx):
532 534 # check hwm
533 535 if self.hwm and self.loads[idx] == self.hwm:
534 536 return False
535 537 target = self.targets[idx]
536 538 # check blacklist
537 539 if target in job.blacklist:
538 540 return False
539 541 # check targets
540 542 if job.targets and target not in job.targets:
541 543 return False
542 544 # check follow
543 545 return job.follow.check(self.completed[target], self.failed[target])
544 546
545 547 indices = filter(can_run, available)
546 548
547 549 if not indices:
548 550 # couldn't run
549 551 if job.follow.all:
550 552 # check follow for impossibility
551 553 dests = set()
552 554 relevant = set()
553 555 if job.follow.success:
554 556 relevant = self.all_completed
555 557 if job.follow.failure:
556 558 relevant = relevant.union(self.all_failed)
557 559 for m in job.follow.intersection(relevant):
558 560 dests.add(self.destinations[m])
559 561 if len(dests) > 1:
560 562 self.queue_map[msg_id] = job
561 563 self.fail_unreachable(msg_id)
562 564 return False
563 565 if job.targets:
564 566 # check blacklist+targets for impossibility
565 567 job.targets.difference_update(job.blacklist)
566 568 if not job.targets or not job.targets.intersection(self.targets):
567 569 self.queue_map[msg_id] = job
568 570 self.fail_unreachable(msg_id)
569 571 return False
570 572 return False
571 573 else:
572 574 indices = None
573 575
574 576 self.submit_task(job, indices)
575 577 return True
576 578
577 579 def save_unmet(self, job):
578 580 """Save a message for later submission when its dependencies are met."""
579 581 msg_id = job.msg_id
580 582 self.log.debug("Adding task %s to the queue", msg_id)
581 583 self.queue_map[msg_id] = job
582 584 heapq.heappush(self.queue, job)
583 585 # track the ids in follow or after, but not those already finished
584 586 for dep_id in job.after.union(job.follow).difference(self.all_done):
585 587 if dep_id not in self.graph:
586 588 self.graph[dep_id] = set()
587 589 self.graph[dep_id].add(msg_id)
588 590
589 591 def submit_task(self, job, indices=None):
590 592 """Submit a task to any of a subset of our targets."""
591 593 if indices:
592 594 loads = [self.loads[i] for i in indices]
593 595 else:
594 596 loads = self.loads
595 597 idx = self.scheme(loads)
596 598 if indices:
597 599 idx = indices[idx]
598 600 target = self.targets[idx]
599 601 # print (target, map(str, msg[:3]))
600 602 # send job to the engine
601 603 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
602 604 self.engine_stream.send_multipart(job.raw_msg, copy=False)
603 605 # update load
604 606 self.add_job(idx)
605 607 self.pending[target][job.msg_id] = job
606 608 # notify Hub
607 609 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
608 610 self.session.send(self.mon_stream, 'task_destination', content=content,
609 611 ident=[b'tracktask',self.ident])
610 612
611 613
612 614 #-----------------------------------------------------------------------
613 615 # Result Handling
614 616 #-----------------------------------------------------------------------
615 617
616 618
617 619 @util.log_errors
618 620 def dispatch_result(self, raw_msg):
619 621 """dispatch method for result replies"""
620 622 try:
621 623 idents,msg = self.session.feed_identities(raw_msg, copy=False)
622 624 msg = self.session.unserialize(msg, content=False, copy=False)
623 625 engine = idents[0]
624 626 try:
625 627 idx = self.targets.index(engine)
626 628 except ValueError:
627 629 pass # skip load-update for dead engines
628 630 else:
629 631 self.finish_job(idx)
630 632 except Exception:
631 633             self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
632 634 return
633 635
634 636 md = msg['metadata']
635 637 parent = msg['parent_header']
636 638 if md.get('dependencies_met', True):
637 639 success = (md['status'] == 'ok')
638 640 msg_id = parent['msg_id']
639 641 retries = self.retries[msg_id]
640 642 if not success and retries > 0:
641 643 # failed
642 644 self.retries[msg_id] = retries - 1
643 645 self.handle_unmet_dependency(idents, parent)
644 646 else:
645 647 del self.retries[msg_id]
646 648 # relay to client and update graph
647 649 self.handle_result(idents, parent, raw_msg, success)
648 650 # send to Hub monitor
649 651 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
650 652 else:
651 653 self.handle_unmet_dependency(idents, parent)
652 654
653 655 def handle_result(self, idents, parent, raw_msg, success=True):
654 656 """handle a real task result, either success or failure"""
655 657 # first, relay result to client
656 658 engine = idents[0]
657 659 client = idents[1]
658 660 # swap_ids for ROUTER-ROUTER mirror
659 661 raw_msg[:2] = [client,engine]
660 662 # print (map(str, raw_msg[:4]))
661 663 self.client_stream.send_multipart(raw_msg, copy=False)
662 664 # now, update our data structures
663 665 msg_id = parent['msg_id']
664 666 self.pending[engine].pop(msg_id)
665 667 if success:
666 668 self.completed[engine].add(msg_id)
667 669 self.all_completed.add(msg_id)
668 670 else:
669 671 self.failed[engine].add(msg_id)
670 672 self.all_failed.add(msg_id)
671 673 self.all_done.add(msg_id)
672 674 self.destinations[msg_id] = engine
673 675
674 676 self.update_graph(msg_id, success)
675 677
676 678 def handle_unmet_dependency(self, idents, parent):
677 679 """handle an unmet dependency"""
678 680 engine = idents[0]
679 681 msg_id = parent['msg_id']
680 682
681 683 job = self.pending[engine].pop(msg_id)
682 684 job.blacklist.add(engine)
683 685
684 686 if job.blacklist == job.targets:
685 687 self.queue_map[msg_id] = job
686 688 self.fail_unreachable(msg_id)
687 689 elif not self.maybe_run(job):
688 690 # resubmit failed
689 691 if msg_id not in self.all_failed:
690 692 # put it back in our dependency tree
691 693 self.save_unmet(job)
692 694
693 695 if self.hwm:
694 696 try:
695 697 idx = self.targets.index(engine)
696 698 except ValueError:
697 699 pass # skip load-update for dead engines
698 700 else:
699 701 if self.loads[idx] == self.hwm-1:
700 702 self.update_graph(None)
701 703
702 704 def update_graph(self, dep_id=None, success=True):
703 705 """dep_id just finished. Update our dependency
704 706 graph and submit any jobs that just became runnable.
705 707
706 708 Called with dep_id=None to update entire graph for hwm, but without finishing a task.
707 709 """
708 710 # print ("\n\n***********")
709 711 # pprint (dep_id)
710 712 # pprint (self.graph)
711 713 # pprint (self.queue_map)
712 714 # pprint (self.all_completed)
713 715 # pprint (self.all_failed)
714 716 # print ("\n\n***********\n\n")
715 717 # update any jobs that depended on the dependency
716 718 msg_ids = self.graph.pop(dep_id, [])
717 719
718 720 # recheck *all* jobs if
719 721 # a) we have HWM and an engine just become no longer full
720 722 # or b) dep_id was given as None
721 723
722 724 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
723 725 jobs = self.queue
724 726 using_queue = True
725 727 else:
726 728 using_queue = False
727 729             jobs = sorted(self.queue_map[msg_id] for msg_id in msg_ids) # a sorted list is a valid heap (heapq.heapify returns None)
728 730
729 731 to_restore = []
730 732 while jobs:
731 733 job = heapq.heappop(jobs)
732 734 if job.removed:
733 735 continue
734 736 msg_id = job.msg_id
735 737
736 738 put_it_back = True
737 739
738 740 if job.after.unreachable(self.all_completed, self.all_failed)\
739 741 or job.follow.unreachable(self.all_completed, self.all_failed):
740 742 self.fail_unreachable(msg_id)
741 743 put_it_back = False
742 744
743 745 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
744 746 if self.maybe_run(job):
745 747 put_it_back = False
746 748 self.queue_map.pop(msg_id)
747 749 for mid in job.dependents:
748 750 if mid in self.graph:
749 751 self.graph[mid].remove(msg_id)
750 752
751 753 # abort the loop if we just filled up all of our engines.
752 754 # avoids an O(N) operation in situation of full queue,
753 755 # where graph update is triggered as soon as an engine becomes
754 756 # non-full, and all tasks after the first are checked,
755 757 # even though they can't run.
756 758 if not self.available_engines():
757 759 break
758 760
759 761 if using_queue and put_it_back:
760 762 # popped a job from the queue but it neither ran nor failed,
761 763 # so we need to put it back when we are done
762 764 to_restore.append(job)
763 765
764 766 # put back any tasks we popped but didn't run
765 767 for job in to_restore:
766 768 heapq.heappush(self.queue, job)
767 769
768 770
769 771 #----------------------------------------------------------------------
770 772 # methods to be overridden by subclasses
771 773 #----------------------------------------------------------------------
772 774
773 775 def add_job(self, idx):
774 776 """Called after self.targets[idx] just got the job with header.
775 777 Override with subclasses. The default ordering is simple LRU.
776 778 The default loads are the number of outstanding jobs."""
777 779 self.loads[idx] += 1
778 780 for lis in (self.targets, self.loads):
779 781 lis.append(lis.pop(idx))
780 782
781 783
782 784 def finish_job(self, idx):
783 785 """Called after self.targets[idx] just finished a job.
784 786 Override with subclasses."""
785 787 self.loads[idx] -= 1
786 788
787 789
788 790
789 791 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
790 792 logname='root', log_url=None, loglevel=logging.DEBUG,
791 793 identity=b'task', in_thread=False):
792 794
793 795 ZMQStream = zmqstream.ZMQStream
794
796 loglevel = logging.DEBUG
795 797 if config:
796 798 # unwrap dict back into Config
797 799 config = Config(config)
798 800
799 801 if in_thread:
800 802 # use instance() to get the same Context/Loop as our parent
801 803 ctx = zmq.Context.instance()
802 804 loop = ioloop.IOLoop.instance()
803 805 else:
804 806 # in a process, don't use instance()
805 807 # for safety with multiprocessing
806 808 ctx = zmq.Context()
807 809 loop = ioloop.IOLoop()
808 810 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
809 811 ins.setsockopt(zmq.IDENTITY, identity + b'_in')
810 812 ins.bind(in_addr)
811 813
812 814 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
813 815 outs.setsockopt(zmq.IDENTITY, identity + b'_out')
814 816 outs.bind(out_addr)
815 817 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
816 818 mons.connect(mon_addr)
817 819 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
818 820 nots.setsockopt(zmq.SUBSCRIBE, b'')
819 821 nots.connect(not_addr)
820 822
821 823 querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
822 824 querys.connect(reg_addr)
823 825
824 826 # setup logging.
825 827 if in_thread:
826 828 log = Application.instance().log
827 829 else:
828 830 if log_url:
829 831 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
830 832 else:
831 833 log = local_logger(logname, loglevel)
832 834
833 835 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
834 836 mon_stream=mons, notifier_stream=nots,
835 837 query_stream=querys,
836 838 loop=loop, log=log,
837 839 config=config)
838 840 scheduler.start()
839 841 if not in_thread:
840 842 try:
841 843 loop.start()
842 844 except KeyboardInterrupt:
843 845 scheduler.log.critical("Interrupted, exiting...")
844 846
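For reference, the scheduling behaviour described in the help strings above is configurable: hwm and scheme_name are both config=True traits on TaskScheduler, so they can be set from an IPython controller configuration file. A minimal sketch of such a fragment, with illustrative values rather than recommendations:

    # ipcontroller_config.py (sketch)
    c = get_config()

    # 0 = no limit on outstanding tasks per engine; the default of 1 keeps
    # exactly one task outstanding per engine (see the hwm help string above)
    c.TaskScheduler.hwm = 0

    # any of: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin', 'leastload'
    c.TaskScheduler.scheme_name = 'twobin'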