use deque instead of heapq
MinRK -
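The commit replaces the heap-ordered job queue with a collections.deque: jobs are appended in submission order, which already matches their timestamps, so plain appends keep the queue ordered without heap maintenance. Below is a minimal standalone sketch of the before/after pattern (illustrative only, not part of the commit; this Job is simplified from the one in the diff):

import heapq
import time
from collections import deque

class Job(object):
    def __init__(self, msg_id):
        self.msg_id = msg_id
        self.timestamp = time.time()
        self.removed = False              # lazy-delete marker, used in both versions
    def __lt__(self, other):              # needed by heapq and by sorted()
        return self.timestamp < other.timestamp

# before: heap-ordered list
heap = []
heapq.heappush(heap, Job('a'))
oldest = heapq.heappop(heap)

# after: deque in arrival order
queue = deque()
queue.append(Job('a'))                    # save_unmet appends new jobs
oldest = queue.popleft()                  # update_graph pops the oldest first
queue.extendleft([oldest])                # and can put back a job that did not run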
@@ -1,846 +1,849 @@
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6
7 7 Authors:
8 8
9 9 * Min RK
10 10 """
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2010-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #----------------------------------------------------------------------
19 19 # Imports
20 20 #----------------------------------------------------------------------
21 21
22 import heapq
23 22 import logging
24 23 import sys
25 24 import time
26 25
27 from datetime import datetime, timedelta
26 from collections import deque
27 from datetime import datetime
28 28 from random import randint, random
29 29 from types import FunctionType
30 30
31 31 try:
32 32 import numpy
33 33 except ImportError:
34 34 numpy = None
35 35
36 36 import zmq
37 37 from zmq.eventloop import ioloop, zmqstream
38 38
39 39 # local imports
40 40 from IPython.external.decorator import decorator
41 41 from IPython.config.application import Application
42 42 from IPython.config.loader import Config
43 43 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
44 44 from IPython.utils.py3compat import cast_bytes
45 45
46 46 from IPython.parallel import error, util
47 47 from IPython.parallel.factory import SessionFactory
48 48 from IPython.parallel.util import connect_logger, local_logger
49 49
50 50 from .dependency import Dependency
51 51
52 52 @decorator
53 53 def logged(f,self,*args,**kwargs):
54 54 # print ("#--------------------")
55 55 self.log.debug("scheduler::%s(*%s,**%s)", f.func_name, args, kwargs)
56 56 # print ("#--")
57 57 return f(self,*args, **kwargs)
58 58
59 59 #----------------------------------------------------------------------
60 60 # Chooser functions
61 61 #----------------------------------------------------------------------
62 62
63 63 def plainrandom(loads):
64 64 """Plain random pick."""
65 65 n = len(loads)
66 66 return randint(0,n-1)
67 67
68 68 def lru(loads):
69 69 """Always pick the front of the line.
70 70
71 71 The content of `loads` is ignored.
72 72
73 73 Assumes LRU ordering of loads, with oldest first.
74 74 """
75 75 return 0
76 76
77 77 def twobin(loads):
78 78 """Pick two at random, use the LRU of the two.
79 79
80 80 The content of loads is ignored.
81 81
82 82 Assumes LRU ordering of loads, with oldest first.
83 83 """
84 84 n = len(loads)
85 85 a = randint(0,n-1)
86 86 b = randint(0,n-1)
87 87 return min(a,b)
88 88
89 89 def weighted(loads):
90 90 """Pick two at random using inverse load as weight.
91 91
92 92 Return the less loaded of the two.
93 93 """
94 94 # weight 0 a million times more than 1:
95 95 weights = 1./(1e-6+numpy.array(loads))
96 96 sums = weights.cumsum()
97 97 t = sums[-1]
98 98 x = random()*t
99 99 y = random()*t
100 100 idx = 0
101 101 idy = 0
102 102 while sums[idx] < x:
103 103 idx += 1
104 104 while sums[idy] < y:
105 105 idy += 1
106 106 if weights[idy] > weights[idx]:
107 107 return idy
108 108 else:
109 109 return idx
110 110
111 111 def leastload(loads):
112 112 """Always choose the lowest load.
113 113
114 114 If the lowest load occurs more than once, the first
114 114 occurrence will be used. If loads has LRU ordering, this means
116 116 the LRU of those with the lowest load is chosen.
117 117 """
118 118 return loads.index(min(loads))
119 119
120 120 #---------------------------------------------------------------------
121 121 # Classes
122 122 #---------------------------------------------------------------------
123 123
124 124
125 125 # store empty default dependency:
126 126 MET = Dependency([])
127 127
128 128
129 129 class Job(object):
130 130 """Simple container for a job"""
131 131 def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
132 132 targets, after, follow, timeout):
133 133 self.msg_id = msg_id
134 134 self.raw_msg = raw_msg
135 135 self.idents = idents
136 136 self.msg = msg
137 137 self.header = header
138 138 self.metadata = metadata
139 139 self.targets = targets
140 140 self.after = after
141 141 self.follow = follow
142 142 self.timeout = timeout
143 self.removed = False # used for lazy-delete in heap-sorted queue
143 self.removed = False # used for lazy-delete from sorted queue
144 144
145 145 self.timestamp = time.time()
146 146 self.blacklist = set()
147 147
148 148 def __lt__(self, other):
149 149 return self.timestamp < other.timestamp
150 150
151 151 def __cmp__(self, other):
152 152 return cmp(self.timestamp, other.timestamp)
153 153
154 154 @property
155 155 def dependents(self):
156 156 return self.follow.union(self.after)
157 157
158 158 class TaskScheduler(SessionFactory):
159 159 """Python TaskScheduler object.
160 160
161 161 This is the simplest object that supports msg_id based
162 162 DAG dependencies. *Only* task msg_ids are checked, not
163 163 msg_ids of jobs submitted via the MUX queue.
164 164
165 165 """
166 166
167 167 hwm = Integer(1, config=True,
168 168 help="""specify the High Water Mark (HWM) for the downstream
169 169 socket in the Task scheduler. This is the maximum number
170 170 of allowed outstanding tasks on each engine.
171 171
172 172 The default (1) means that only one task can be outstanding on each
173 173 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
174 174 engines continue to be assigned tasks while they are working,
175 175 effectively hiding network latency behind computation, but can result
176 176 in an imbalance of work when submitting many heterogeneous tasks all at
177 177 once. Any positive value greater than one is a compromise between the
178 178 two.
179 179
180 180 """
181 181 )
182 182 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
183 183 'leastload', config=True, allow_none=False,
184 184 help="""select the task scheduler scheme [default: Python LRU]
185 185 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
186 186 )
187 187 def _scheme_name_changed(self, old, new):
188 188 self.log.debug("Using scheme %r"%new)
189 189 self.scheme = globals()[new]
190 190
191 191 # input arguments:
192 192 scheme = Instance(FunctionType) # function for determining the destination
193 193 def _scheme_default(self):
194 194 return leastload
195 195 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
196 196 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
197 197 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
198 198 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
199 199 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
200 200
201 201 # internals:
202 queue = List() # heap-sorted list of Jobs
202 queue = Instance(deque) # sorted list of Jobs
203 def _queue_default(self):
204 return deque()
203 205 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
204 206 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
205 207 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
206 208 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
207 209 pending = Dict() # dict by engine_uuid of submitted tasks
208 210 completed = Dict() # dict by engine_uuid of completed tasks
209 211 failed = Dict() # dict by engine_uuid of failed tasks
210 212 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
211 213 clients = Dict() # dict by msg_id for who submitted the task
212 214 targets = List() # list of target IDENTs
213 215 loads = List() # list of engine loads
214 216 # full = Set() # set of IDENTs that have HWM outstanding tasks
215 217 all_completed = Set() # set of all completed tasks
216 218 all_failed = Set() # set of all failed tasks
217 219 all_done = Set() # set of all finished tasks=union(completed,failed)
218 220 all_ids = Set() # set of all submitted task IDs
219 221
220 222 ident = CBytes() # ZMQ identity. This should just be self.session.session
221 223 # but ensure Bytes
222 224 def _ident_default(self):
223 225 return self.session.bsession
224 226
225 227 def start(self):
226 228 self.query_stream.on_recv(self.dispatch_query_reply)
227 229 self.session.send(self.query_stream, "connection_request", {})
228 230
229 231 self.engine_stream.on_recv(self.dispatch_result, copy=False)
230 232 self.client_stream.on_recv(self.dispatch_submission, copy=False)
231 233
232 234 self._notification_handlers = dict(
233 235 registration_notification = self._register_engine,
234 236 unregistration_notification = self._unregister_engine
235 237 )
236 238 self.notifier_stream.on_recv(self.dispatch_notification)
237 239 self.log.info("Scheduler started [%s]" % self.scheme_name)
238 240
239 241 def resume_receiving(self):
240 242 """Resume accepting jobs."""
241 243 self.client_stream.on_recv(self.dispatch_submission, copy=False)
242 244
243 245 def stop_receiving(self):
244 246 """Stop accepting jobs while there are no engines.
245 247 Leave them in the ZMQ queue."""
246 248 self.client_stream.on_recv(None)
247 249
248 250 #-----------------------------------------------------------------------
249 251 # [Un]Registration Handling
250 252 #-----------------------------------------------------------------------
251 253
252 254
253 255 def dispatch_query_reply(self, msg):
254 256 """handle reply to our initial connection request"""
255 257 try:
256 258 idents,msg = self.session.feed_identities(msg)
257 259 except ValueError:
258 260 self.log.warn("task::Invalid Message: %r",msg)
259 261 return
260 262 try:
261 263 msg = self.session.unserialize(msg)
262 264 except ValueError:
263 265 self.log.warn("task::Unauthorized message from: %r"%idents)
264 266 return
265 267
266 268 content = msg['content']
267 269 for uuid in content.get('engines', {}).values():
268 270 self._register_engine(cast_bytes(uuid))
269 271
270 272
271 273 @util.log_errors
272 274 def dispatch_notification(self, msg):
273 275 """dispatch register/unregister events."""
274 276 try:
275 277 idents,msg = self.session.feed_identities(msg)
276 278 except ValueError:
277 279 self.log.warn("task::Invalid Message: %r",msg)
278 280 return
279 281 try:
280 282 msg = self.session.unserialize(msg)
281 283 except ValueError:
282 284 self.log.warn("task::Unauthorized message from: %r"%idents)
283 285 return
284 286
285 287 msg_type = msg['header']['msg_type']
286 288
287 289 handler = self._notification_handlers.get(msg_type, None)
288 290 if handler is None:
289 291 self.log.error("Unhandled message type: %r"%msg_type)
290 292 else:
291 293 try:
292 294 handler(cast_bytes(msg['content']['uuid']))
293 295 except Exception:
294 296 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
295 297
296 298 def _register_engine(self, uid):
297 299 """New engine with ident `uid` became available."""
298 300 # head of the line:
299 301 self.targets.insert(0,uid)
300 302 self.loads.insert(0,0)
301 303
302 304 # initialize sets
303 305 self.completed[uid] = set()
304 306 self.failed[uid] = set()
305 307 self.pending[uid] = {}
306 308
307 309 # rescan the graph:
308 310 self.update_graph(None)
309 311
310 312 def _unregister_engine(self, uid):
311 313 """Existing engine with ident `uid` became unavailable."""
312 314 if len(self.targets) == 1:
313 315 # this was our only engine
314 316 pass
315 317
316 318 # handle any potentially finished tasks:
317 319 self.engine_stream.flush()
318 320
319 321 # don't pop destinations, because they might be used later
320 322 # map(self.destinations.pop, self.completed.pop(uid))
321 323 # map(self.destinations.pop, self.failed.pop(uid))
322 324
323 325 # prevent this engine from receiving work
324 326 idx = self.targets.index(uid)
325 327 self.targets.pop(idx)
326 328 self.loads.pop(idx)
327 329
328 330 # wait 5 seconds before cleaning up pending jobs, since the results might
329 331 # still be incoming
330 332 if self.pending[uid]:
331 333 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
332 334 dc.start()
333 335 else:
334 336 self.completed.pop(uid)
335 337 self.failed.pop(uid)
336 338
337 339
338 340 def handle_stranded_tasks(self, engine):
339 341 """Deal with jobs resident in an engine that died."""
340 342 lost = self.pending[engine]
341 343 for msg_id in lost.keys():
342 344 if msg_id not in self.pending[engine]:
343 345 # prevent double-handling of messages
344 346 continue
345 347
346 348 raw_msg = lost[msg_id].raw_msg
347 349 idents,msg = self.session.feed_identities(raw_msg, copy=False)
348 350 parent = self.session.unpack(msg[1].bytes)
349 351 idents = [engine, idents[0]]
350 352
351 353 # build fake error reply
352 354 try:
353 355 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
354 356 except:
355 357 content = error.wrap_exception()
356 358 # build fake metadata
357 359 md = dict(
358 360 status=u'error',
359 361 engine=engine,
360 362 date=datetime.now(),
361 363 )
362 364 msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
363 365 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
364 366 # and dispatch it
365 367 self.dispatch_result(raw_reply)
366 368
367 369 # finally scrub completed/failed lists
368 370 self.completed.pop(engine)
369 371 self.failed.pop(engine)
370 372
371 373
372 374 #-----------------------------------------------------------------------
373 375 # Job Submission
374 376 #-----------------------------------------------------------------------
375 377
376 378
377 379 @util.log_errors
378 380 def dispatch_submission(self, raw_msg):
379 381 """Dispatch job submission to appropriate handlers."""
380 382 # ensure targets up to date:
381 383 self.notifier_stream.flush()
382 384 try:
383 385 idents, msg = self.session.feed_identities(raw_msg, copy=False)
384 386 msg = self.session.unserialize(msg, content=False, copy=False)
385 387 except Exception:
386 388 self.log.error("task::Invalid task msg: %r"%raw_msg, exc_info=True)
387 389 return
388 390
389 391
390 392 # send to monitor
391 393 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
392 394
393 395 header = msg['header']
394 396 md = msg['metadata']
395 397 msg_id = header['msg_id']
396 398 self.all_ids.add(msg_id)
397 399
398 400 # get targets as a set of bytes objects
399 401 # from a list of unicode objects
400 402 targets = md.get('targets', [])
401 403 targets = map(cast_bytes, targets)
402 404 targets = set(targets)
403 405
404 406 retries = md.get('retries', 0)
405 407 self.retries[msg_id] = retries
406 408
407 409 # time dependencies
408 410 after = md.get('after', None)
409 411 if after:
410 412 after = Dependency(after)
411 413 if after.all:
412 414 if after.success:
413 415 after = Dependency(after.difference(self.all_completed),
414 416 success=after.success,
415 417 failure=after.failure,
416 418 all=after.all,
417 419 )
418 420 if after.failure:
419 421 after = Dependency(after.difference(self.all_failed),
420 422 success=after.success,
421 423 failure=after.failure,
422 424 all=after.all,
423 425 )
424 426 if after.check(self.all_completed, self.all_failed):
425 427 # recast as empty set, if `after` already met,
426 428 # to prevent unnecessary set comparisons
427 429 after = MET
428 430 else:
429 431 after = MET
430 432
431 433 # location dependencies
432 434 follow = Dependency(md.get('follow', []))
433 435
434 436 # turn timeouts into datetime objects:
435 437 timeout = md.get('timeout', None)
436 438 if timeout:
437 439 timeout = time.time() + float(timeout)
438 440
439 441 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
440 442 header=header, targets=targets, after=after, follow=follow,
441 443 timeout=timeout, metadata=md,
442 444 )
443 445 if timeout:
444 446 # schedule timeout callback
445 447 self.loop.add_timeout(timeout, lambda : self.job_timeout(job))
446 448
447 449 # validate and reduce dependencies:
448 450 for dep in after,follow:
449 451 if not dep: # empty dependency
450 452 continue
451 453 # check valid:
452 454 if msg_id in dep or dep.difference(self.all_ids):
453 455 self.queue_map[msg_id] = job
454 456 return self.fail_unreachable(msg_id, error.InvalidDependency)
455 457 # check if unreachable:
456 458 if dep.unreachable(self.all_completed, self.all_failed):
457 459 self.queue_map[msg_id] = job
458 460 return self.fail_unreachable(msg_id)
459 461
460 462 if after.check(self.all_completed, self.all_failed):
461 463 # time deps already met, try to run
462 464 if not self.maybe_run(job):
463 465 # can't run yet
464 466 if msg_id not in self.all_failed:
465 467 # could have failed as unreachable
466 468 self.save_unmet(job)
467 469 else:
468 470 self.save_unmet(job)
469 471
470 472 def job_timeout(self, job):
471 473 """callback for a job's timeout.
472 474
473 475 The job may or may not have been run at this point.
474 476 """
475 if job.timeout >= (time.time() + 1):
477 now = time.time()
478 if job.timeout >= (now + 1):
476 479 self.log.warn("task %s timeout fired prematurely: %s > %s",
477 480 job.msg_id, job.timeout, now
478 481 )
479 482 if job.msg_id in self.queue_map:
480 483 # still waiting, but ran out of time
481 484 self.log.info("task %r timed out", job.msg_id)
482 485 self.fail_unreachable(job.msg_id, error.TaskTimeout)
483 486
484 487 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
485 488 """a task has become unreachable, send a reply with an ImpossibleDependency
486 489 error."""
487 490 if msg_id not in self.queue_map:
488 491 self.log.error("task %r already failed!", msg_id)
489 492 return
490 493 job = self.queue_map.pop(msg_id)
491 494 # lazy-delete from the queue
492 495 job.removed = True
493 496 for mid in job.dependents:
494 497 if mid in self.graph:
495 498 self.graph[mid].remove(msg_id)
496 499
497 500 try:
498 501 raise why()
499 502 except:
500 503 content = error.wrap_exception()
501 504 self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])
502 505
503 506 self.all_done.add(msg_id)
504 507 self.all_failed.add(msg_id)
505 508
506 509 msg = self.session.send(self.client_stream, 'apply_reply', content,
507 510 parent=job.header, ident=job.idents)
508 511 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
509 512
510 513 self.update_graph(msg_id, success=False)
511 514
512 515 def available_engines(self):
513 516 """return a list of available engine indices based on HWM"""
514 517 if not self.hwm:
515 518 return range(len(self.targets))
516 519 available = []
517 520 for idx in range(len(self.targets)):
518 521 if self.loads[idx] < self.hwm:
519 522 available.append(idx)
520 523 return available
521 524
522 525 def maybe_run(self, job):
523 526 """check location dependencies, and run if they are met."""
524 527 msg_id = job.msg_id
525 528 self.log.debug("Attempting to assign task %s", msg_id)
526 529 available = self.available_engines()
527 530 if not available:
528 531 # no engines, definitely can't run
529 532 return False
530 533
531 534 if job.follow or job.targets or job.blacklist or self.hwm:
532 535 # we need a can_run filter
533 536 def can_run(idx):
534 537 # check hwm
535 538 if self.hwm and self.loads[idx] == self.hwm:
536 539 return False
537 540 target = self.targets[idx]
538 541 # check blacklist
539 542 if target in job.blacklist:
540 543 return False
541 544 # check targets
542 545 if job.targets and target not in job.targets:
543 546 return False
544 547 # check follow
545 548 return job.follow.check(self.completed[target], self.failed[target])
546 549
547 550 indices = filter(can_run, available)
548 551
549 552 if not indices:
550 553 # couldn't run
551 554 if job.follow.all:
552 555 # check follow for impossibility
553 556 dests = set()
554 557 relevant = set()
555 558 if job.follow.success:
556 559 relevant = self.all_completed
557 560 if job.follow.failure:
558 561 relevant = relevant.union(self.all_failed)
559 562 for m in job.follow.intersection(relevant):
560 563 dests.add(self.destinations[m])
561 564 if len(dests) > 1:
562 565 self.queue_map[msg_id] = job
563 566 self.fail_unreachable(msg_id)
564 567 return False
565 568 if job.targets:
566 569 # check blacklist+targets for impossibility
567 570 job.targets.difference_update(job.blacklist)
568 571 if not job.targets or not job.targets.intersection(self.targets):
569 572 self.queue_map[msg_id] = job
570 573 self.fail_unreachable(msg_id)
571 574 return False
572 575 return False
573 576 else:
574 577 indices = None
575 578
576 579 self.submit_task(job, indices)
577 580 return True
578 581
579 582 def save_unmet(self, job):
580 583 """Save a message for later submission when its dependencies are met."""
581 584 msg_id = job.msg_id
582 585 self.log.debug("Adding task %s to the queue", msg_id)
583 586 self.queue_map[msg_id] = job
584 heapq.heappush(self.queue, job)
587 self.queue.append(job)
585 588 # track the ids in follow or after, but not those already finished
586 589 for dep_id in job.after.union(job.follow).difference(self.all_done):
587 590 if dep_id not in self.graph:
588 591 self.graph[dep_id] = set()
589 592 self.graph[dep_id].add(msg_id)
590 593
591 594 def submit_task(self, job, indices=None):
592 595 """Submit a task to any of a subset of our targets."""
593 596 if indices:
594 597 loads = [self.loads[i] for i in indices]
595 598 else:
596 599 loads = self.loads
597 600 idx = self.scheme(loads)
598 601 if indices:
599 602 idx = indices[idx]
600 603 target = self.targets[idx]
601 604 # print (target, map(str, msg[:3]))
602 605 # send job to the engine
603 606 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
604 607 self.engine_stream.send_multipart(job.raw_msg, copy=False)
605 608 # update load
606 609 self.add_job(idx)
607 610 self.pending[target][job.msg_id] = job
608 611 # notify Hub
609 612 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
610 613 self.session.send(self.mon_stream, 'task_destination', content=content,
611 614 ident=[b'tracktask',self.ident])
612 615
613 616
614 617 #-----------------------------------------------------------------------
615 618 # Result Handling
616 619 #-----------------------------------------------------------------------
617 620
618 621
619 622 @util.log_errors
620 623 def dispatch_result(self, raw_msg):
621 624 """dispatch method for result replies"""
622 625 try:
623 626 idents,msg = self.session.feed_identities(raw_msg, copy=False)
624 627 msg = self.session.unserialize(msg, content=False, copy=False)
625 628 engine = idents[0]
626 629 try:
627 630 idx = self.targets.index(engine)
628 631 except ValueError:
629 632 pass # skip load-update for dead engines
630 633 else:
631 634 self.finish_job(idx)
632 635 except Exception:
633 636 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
634 637 return
635 638
636 639 md = msg['metadata']
637 640 parent = msg['parent_header']
638 641 if md.get('dependencies_met', True):
639 642 success = (md['status'] == 'ok')
640 643 msg_id = parent['msg_id']
641 644 retries = self.retries[msg_id]
642 645 if not success and retries > 0:
643 646 # failed
644 647 self.retries[msg_id] = retries - 1
645 648 self.handle_unmet_dependency(idents, parent)
646 649 else:
647 650 del self.retries[msg_id]
648 651 # relay to client and update graph
649 652 self.handle_result(idents, parent, raw_msg, success)
650 653 # send to Hub monitor
651 654 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
652 655 else:
653 656 self.handle_unmet_dependency(idents, parent)
654 657
655 658 def handle_result(self, idents, parent, raw_msg, success=True):
656 659 """handle a real task result, either success or failure"""
657 660 # first, relay result to client
658 661 engine = idents[0]
659 662 client = idents[1]
660 663 # swap_ids for ROUTER-ROUTER mirror
661 664 raw_msg[:2] = [client,engine]
662 665 # print (map(str, raw_msg[:4]))
663 666 self.client_stream.send_multipart(raw_msg, copy=False)
664 667 # now, update our data structures
665 668 msg_id = parent['msg_id']
666 669 self.pending[engine].pop(msg_id)
667 670 if success:
668 671 self.completed[engine].add(msg_id)
669 672 self.all_completed.add(msg_id)
670 673 else:
671 674 self.failed[engine].add(msg_id)
672 675 self.all_failed.add(msg_id)
673 676 self.all_done.add(msg_id)
674 677 self.destinations[msg_id] = engine
675 678
676 679 self.update_graph(msg_id, success)
677 680
678 681 def handle_unmet_dependency(self, idents, parent):
679 682 """handle an unmet dependency"""
680 683 engine = idents[0]
681 684 msg_id = parent['msg_id']
682 685
683 686 job = self.pending[engine].pop(msg_id)
684 687 job.blacklist.add(engine)
685 688
686 689 if job.blacklist == job.targets:
687 690 self.queue_map[msg_id] = job
688 691 self.fail_unreachable(msg_id)
689 692 elif not self.maybe_run(job):
690 693 # resubmit failed
691 694 if msg_id not in self.all_failed:
692 695 # put it back in our dependency tree
693 696 self.save_unmet(job)
694 697
695 698 if self.hwm:
696 699 try:
697 700 idx = self.targets.index(engine)
698 701 except ValueError:
699 702 pass # skip load-update for dead engines
700 703 else:
701 704 if self.loads[idx] == self.hwm-1:
702 705 self.update_graph(None)
703 706
704 707 def update_graph(self, dep_id=None, success=True):
705 708 """dep_id just finished. Update our dependency
706 709 graph and submit any jobs that just became runnable.
707 710
708 711 Called with dep_id=None to update entire graph for hwm, but without finishing a task.
709 712 """
710 713 # print ("\n\n***********")
711 714 # pprint (dep_id)
712 715 # pprint (self.graph)
713 716 # pprint (self.queue_map)
714 717 # pprint (self.all_completed)
715 718 # pprint (self.all_failed)
716 719 # print ("\n\n***********\n\n")
717 720 # update any jobs that depended on the dependency
718 721 msg_ids = self.graph.pop(dep_id, [])
719 722
720 723 # recheck *all* jobs if
721 724 # a) we have HWM and an engine just become no longer full
722 725 # or b) dep_id was given as None
723 726
724 727 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
725 728 jobs = self.queue
726 729 using_queue = True
727 730 else:
728 731 using_queue = False
729 jobs = heapq.heapify([ self.queue_map[msg_id] for msg_id in msg_ids ])
732 jobs = deque(sorted( self.queue_map[msg_id] for msg_id in msg_ids ))
730 733
731 734 to_restore = []
732 735 while jobs:
733 job = heapq.heappop(jobs)
736 job = jobs.popleft()
734 737 if job.removed:
735 738 continue
736 739 msg_id = job.msg_id
737 740
738 741 put_it_back = True
739 742
740 743 if job.after.unreachable(self.all_completed, self.all_failed)\
741 744 or job.follow.unreachable(self.all_completed, self.all_failed):
742 745 self.fail_unreachable(msg_id)
743 746 put_it_back = False
744 747
745 748 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
746 749 if self.maybe_run(job):
747 750 put_it_back = False
748 751 self.queue_map.pop(msg_id)
749 752 for mid in job.dependents:
750 753 if mid in self.graph:
751 754 self.graph[mid].remove(msg_id)
752 755
753 756 # abort the loop if we just filled up all of our engines.
754 757 # avoids an O(N) operation in situation of full queue,
755 758 # where graph update is triggered as soon as an engine becomes
756 759 # non-full, and all tasks after the first are checked,
757 760 # even though they can't run.
758 761 if not self.available_engines():
759 762 break
760 763
761 764 if using_queue and put_it_back:
762 765 # popped a job from the queue but it neither ran nor failed,
763 766 # so we need to put it back when we are done
767 # make sure to_restore preserves the same ordering
764 768 to_restore.append(job)
765 769
766 770 # put back any tasks we popped but didn't run
767 for job in to_restore:
768 heapq.heappush(self.queue, job)
769
771 if using_queue:
772 self.queue.extendleft(to_restore)
770 773
771 774 #----------------------------------------------------------------------
772 775 # methods to be overridden by subclasses
773 776 #----------------------------------------------------------------------
774 777
775 778 def add_job(self, idx):
776 779 """Called after self.targets[idx] just got the job with header.
777 780 Override with subclasses. The default ordering is simple LRU.
778 781 The default loads are the number of outstanding jobs."""
779 782 self.loads[idx] += 1
780 783 for lis in (self.targets, self.loads):
781 784 lis.append(lis.pop(idx))
782 785
783 786
784 787 def finish_job(self, idx):
785 788 """Called after self.targets[idx] just finished a job.
786 789 Override with subclasses."""
787 790 self.loads[idx] -= 1
788 791
789 792
790 793
791 794 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
792 795 logname='root', log_url=None, loglevel=logging.DEBUG,
793 796 identity=b'task', in_thread=False):
794 797
795 798 ZMQStream = zmqstream.ZMQStream
796 799 loglevel = logging.DEBUG
797 800 if config:
798 801 # unwrap dict back into Config
799 802 config = Config(config)
800 803
801 804 if in_thread:
802 805 # use instance() to get the same Context/Loop as our parent
803 806 ctx = zmq.Context.instance()
804 807 loop = ioloop.IOLoop.instance()
805 808 else:
806 809 # in a process, don't use instance()
807 810 # for safety with multiprocessing
808 811 ctx = zmq.Context()
809 812 loop = ioloop.IOLoop()
810 813 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
811 814 ins.setsockopt(zmq.IDENTITY, identity + b'_in')
812 815 ins.bind(in_addr)
813 816
814 817 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
815 818 outs.setsockopt(zmq.IDENTITY, identity + b'_out')
816 819 outs.bind(out_addr)
817 820 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
818 821 mons.connect(mon_addr)
819 822 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
820 823 nots.setsockopt(zmq.SUBSCRIBE, b'')
821 824 nots.connect(not_addr)
822 825
823 826 querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
824 827 querys.connect(reg_addr)
825 828
826 829 # setup logging.
827 830 if in_thread:
828 831 log = Application.instance().log
829 832 else:
830 833 if log_url:
831 834 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
832 835 else:
833 836 log = local_logger(logname, loglevel)
834 837
835 838 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
836 839 mon_stream=mons, notifier_stream=nots,
837 840 query_stream=querys,
838 841 loop=loop, log=log,
839 842 config=config)
840 843 scheduler.start()
841 844 if not in_thread:
842 845 try:
843 846 loop.start()
844 847 except KeyboardInterrupt:
845 848 scheduler.log.critical("Interrupted, exiting...")
846 849