load header with engine id when engine dies in TaskScheduler...
MinRK
@@ -1,726 +1,732 @@
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6
7 7 Authors:
8 8
9 9 * Min RK
10 10 """
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2010-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #----------------------------------------------------------------------
19 19 # Imports
20 20 #----------------------------------------------------------------------
21 21
22 22 from __future__ import print_function
23 23
24 24 import logging
25 25 import sys
26 26
27 27 from datetime import datetime, timedelta
28 28 from random import randint, random
29 29 from types import FunctionType
30 30
31 31 try:
32 32 import numpy
33 33 except ImportError:
34 34 numpy = None
35 35
36 36 import zmq
37 37 from zmq.eventloop import ioloop, zmqstream
38 38
39 39 # local imports
40 40 from IPython.external.decorator import decorator
41 41 from IPython.config.application import Application
42 42 from IPython.config.loader import Config
43 43 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
44 44
45 45 from IPython.parallel import error
46 46 from IPython.parallel.factory import SessionFactory
47 47 from IPython.parallel.util import connect_logger, local_logger, asbytes
48 48
49 49 from .dependency import Dependency
50 50
51 51 @decorator
52 52 def logged(f,self,*args,**kwargs):
53 53 # print ("#--------------------")
54 54 self.log.debug("scheduler::%s(*%s,**%s)", f.func_name, args, kwargs)
55 55 # print ("#--")
56 56 return f(self,*args, **kwargs)
57 57
#----------------------------------------------------------------------
# Chooser functions
#----------------------------------------------------------------------

def plainrandom(loads):
    """Plain random pick."""
    n = len(loads)
    return randint(0,n-1)

def lru(loads):
    """Always pick the front of the line.

    The content of `loads` is ignored.

    Assumes LRU ordering of loads, with oldest first.
    """
    return 0

def twobin(loads):
    """Pick two at random, use the LRU of the two.

    The content of loads is ignored.

    Assumes LRU ordering of loads, with oldest first.
    """
    n = len(loads)
    a = randint(0,n-1)
    b = randint(0,n-1)
    return min(a,b)

def weighted(loads):
    """Pick two at random using inverse load as weight.

    Return the less loaded of the two.
    """
    # weight 0 a million times more than 1:
    weights = 1./(1e-6+numpy.array(loads))
    sums = weights.cumsum()
    t = sums[-1]
    x = random()*t
    y = random()*t
    idx = 0
    idy = 0
    while sums[idx] < x:
        idx += 1
    while sums[idy] < y:
        idy += 1
    if weights[idy] > weights[idx]:
        return idy
    else:
        return idx

def leastload(loads):
    """Always choose the lowest load.

    If the lowest load occurs more than once, the first
    occurrence will be used. If loads has LRU ordering, this means
    the LRU of those with the lowest load is chosen.
    """
    return loads.index(min(loads))

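# Illustrative doctest-style sketch (not part of the original module) of how
# the chooser functions are used: each takes a list of per-engine loads and
# returns the index of the chosen engine. The loads below are made up.
#
#     >>> loads = [2, 0, 5, 1]
#     >>> leastload(loads)                   # lowest load wins
#     1
#     >>> lru(loads)                         # front of the LRU-ordered line
#     0
#     >>> 0 <= twobin(loads) < len(loads)    # random, but always a valid index
#     True
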
#---------------------------------------------------------------------
# Classes
#---------------------------------------------------------------------
# store empty default dependency:
MET = Dependency([])

class TaskScheduler(SessionFactory):
    """Python TaskScheduler object.

    This is the simplest object that supports msg_id based
    DAG dependencies. *Only* task msg_ids are checked, not
    msg_ids of jobs submitted via the MUX queue.

    """

    hwm = Integer(1, config=True,
        help="""specify the High Water Mark (HWM) for the downstream
        socket in the Task scheduler. This is the maximum number
        of allowed outstanding tasks on each engine.

        The default (1) means that only one task can be outstanding on each
        engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
        engines continue to be assigned tasks while they are working,
        effectively hiding network latency behind computation, but can result
        in an imbalance of work when submitting many heterogeneous tasks all at
        once. Any positive value greater than one is a compromise between the
        two.

        """
    )
    scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
        'leastload', config=True, allow_none=False,
        help="""select the task scheduler scheme [default: leastload]
        Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin', 'leastload'"""
    )
    def _scheme_name_changed(self, old, new):
        self.log.debug("Using scheme %r"%new)
        self.scheme = globals()[new]
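
    # For illustration (not in the original source): these traits are normally
    # set via the IPython config system, e.g. in an ipcontroller_config.py
    # (values here are arbitrary examples):
    #
    #     c = get_config()
    #     c.TaskScheduler.hwm = 0                # no limit on outstanding tasks
    #     c.TaskScheduler.scheme_name = 'twobin'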

    # input arguments:
    scheme = Instance(FunctionType) # function for determining the destination
    def _scheme_default(self):
        return leastload

    client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
    engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
    notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
    mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream

    # internals:
    graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
    retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
    # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
    depending = Dict() # dict by msg_id of [raw_msg, targets, after, follow, timeout]
    pending = Dict() # dict by engine_uuid of submitted tasks
    completed = Dict() # dict by engine_uuid of completed tasks
    failed = Dict() # dict by engine_uuid of failed tasks
    destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
    clients = Dict() # dict by msg_id for who submitted the task
    targets = List() # list of target IDENTs
    loads = List() # list of engine loads
    # full = Set() # set of IDENTs that have HWM outstanding tasks
    all_completed = Set() # set of all completed tasks
    all_failed = Set() # set of all failed tasks
    all_done = Set() # set of all finished tasks=union(completed,failed)
    all_ids = Set() # set of all submitted task IDs
    blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
    auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')

    ident = CBytes() # ZMQ identity. This should just be self.session.session
                     # but ensure Bytes
    def _ident_default(self):
        return self.session.bsession

    def start(self):
        self.engine_stream.on_recv(self.dispatch_result, copy=False)
        self._notification_handlers = dict(
            registration_notification = self._register_engine,
            unregistration_notification = self._unregister_engine
        )
        self.notifier_stream.on_recv(self.dispatch_notification)
        self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # every 2 seconds
        self.auditor.start()
        self.log.info("Scheduler started [%s]"%self.scheme_name)

    def resume_receiving(self):
        """Resume accepting jobs."""
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

    def stop_receiving(self):
        """Stop accepting jobs while there are no engines.
        Leave them in the ZMQ queue."""
        self.client_stream.on_recv(None)

    #-----------------------------------------------------------------------
    # [Un]Registration Handling
    #-----------------------------------------------------------------------

    def dispatch_notification(self, msg):
        """dispatch register/unregister events."""
        try:
            idents,msg = self.session.feed_identities(msg)
        except ValueError:
            self.log.warn("task::Invalid Message: %r",msg)
            return
        try:
            msg = self.session.unserialize(msg)
        except ValueError:
            self.log.warn("task::Unauthorized message from: %r"%idents)
            return

        msg_type = msg['header']['msg_type']

        handler = self._notification_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("Unhandled message type: %r"%msg_type)
        else:
            try:
                handler(asbytes(msg['content']['queue']))
            except Exception:
                self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)

    def _register_engine(self, uid):
        """New engine with ident `uid` became available."""
        # head of the line:
        self.targets.insert(0,uid)
        self.loads.insert(0,0)

        # initialize sets
        self.completed[uid] = set()
        self.failed[uid] = set()
        self.pending[uid] = {}
        if len(self.targets) == 1:
            self.resume_receiving()
        # rescan the graph:
        self.update_graph(None)

    def _unregister_engine(self, uid):
        """Existing engine with ident `uid` became unavailable."""
        if len(self.targets) == 1:
            # this was our only engine
            self.stop_receiving()

        # handle any potentially finished tasks:
        self.engine_stream.flush()

        # don't pop destinations, because they might be used later
        # map(self.destinations.pop, self.completed.pop(uid))
        # map(self.destinations.pop, self.failed.pop(uid))

        # prevent this engine from receiving work
        idx = self.targets.index(uid)
        self.targets.pop(idx)
        self.loads.pop(idx)

        # wait 5 seconds before cleaning up pending jobs, since the results might
        # still be incoming
        if self.pending[uid]:
            dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
            dc.start()
        else:
            self.completed.pop(uid)
            self.failed.pop(uid)


    def handle_stranded_tasks(self, engine):
        """Deal with jobs resident in an engine that died."""
        lost = self.pending[engine]
        for msg_id in lost.keys():
            if msg_id not in self.pending[engine]:
                # prevent double-handling of messages
                continue

            raw_msg = lost[msg_id][0]
            idents,msg = self.session.feed_identities(raw_msg, copy=False)
            parent = self.session.unpack(msg[1].bytes)
            idents = [engine, idents[0]]

            # build fake error reply
            try:
                raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
            except:
                content = error.wrap_exception()
-            msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
+            # build fake header
+            header = dict(
+                status='error',
+                engine=engine,
+                date=datetime.now(),
+            )
+            msg = self.session.msg('apply_reply', content, parent=parent, subheader=header)
            raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
            # and dispatch it
            self.dispatch_result(raw_reply)

        # finally scrub completed/failed lists
        self.completed.pop(engine)
        self.failed.pop(engine)

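    # Sketch (not in the original source) of the fake reply built above. The
    # point of this commit is that the subheader now records *which* engine
    # died, so downstream consumers can attribute the failure. Values here are
    # illustrative:
    #
    #     subheader == {'status': 'error',
    #                   'engine': engine,            # the dead engine's ident
    #                   'date': datetime.now()}      # time of the fake reply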

    #-----------------------------------------------------------------------
    # Job Submission
    #-----------------------------------------------------------------------
    def dispatch_submission(self, raw_msg):
        """Dispatch job submission to appropriate handlers."""
        # ensure targets up to date:
        self.notifier_stream.flush()
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unserialize(msg, content=False, copy=False)
        except Exception:
            self.log.error("task::Invalid task msg: %r"%raw_msg, exc_info=True)
            return

        # send to monitor
        self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)

        header = msg['header']
        msg_id = header['msg_id']
        self.all_ids.add(msg_id)

        # get targets as a set of bytes objects
        # from a list of unicode objects
        targets = header.get('targets', [])
        targets = map(asbytes, targets)
        targets = set(targets)

        retries = header.get('retries', 0)
        self.retries[msg_id] = retries

        # time dependencies
        after = header.get('after', None)
        if after:
            after = Dependency(after)
            if after.all:
                if after.success:
                    after = Dependency(after.difference(self.all_completed),
                                success=after.success,
                                failure=after.failure,
                                all=after.all,
                    )
                if after.failure:
                    after = Dependency(after.difference(self.all_failed),
                                success=after.success,
                                failure=after.failure,
                                all=after.all,
                    )
            if after.check(self.all_completed, self.all_failed):
                # recast as empty set, if `after` already met,
                # to prevent unnecessary set comparisons
                after = MET
        else:
            after = MET

        # location dependencies
        follow = Dependency(header.get('follow', []))

        # turn timeouts into datetime objects:
        timeout = header.get('timeout', None)
        if timeout:
            # cast to float, because jsonlib returns floats as decimal.Decimal,
            # which timedelta does not accept
            timeout = datetime.now() + timedelta(0,float(timeout),0)

        args = [raw_msg, targets, after, follow, timeout]

        # validate and reduce dependencies:
        for dep in after,follow:
            if not dep: # empty dependency
                continue
            # check valid:
            if msg_id in dep or dep.difference(self.all_ids):
                self.depending[msg_id] = args
                return self.fail_unreachable(msg_id, error.InvalidDependency)
            # check if unreachable:
            if dep.unreachable(self.all_completed, self.all_failed):
                self.depending[msg_id] = args
                return self.fail_unreachable(msg_id)

        if after.check(self.all_completed, self.all_failed):
            # time deps already met, try to run
            if not self.maybe_run(msg_id, *args):
                # can't run yet
                if msg_id not in self.all_failed:
                    # could have failed as unreachable
                    self.save_unmet(msg_id, *args)
        else:
            self.save_unmet(msg_id, *args)

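    # For context (not in the original source), a hedged sketch of how the
    # header fields read above are typically set from the client side with a
    # LoadBalancedView; `job_a`/`job_b` are hypothetical functions:
    #
    #     from IPython.parallel import Client
    #     rc = Client()
    #     view = rc.load_balanced_view()
    #     ar1 = view.apply(job_a)
    #     with view.temp_flags(after=[ar1], retries=2, timeout=60):
    #         ar2 = view.apply(job_b)   # runs only after ar1, with 2 retries
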
    def audit_timeouts(self):
        """Audit all waiting tasks for expired timeouts."""
        now = datetime.now()
        for msg_id in self.depending.keys():
            # must recheck, in case one failure cascaded to another:
            if msg_id in self.depending:
                raw_msg,targets,after,follow,timeout = self.depending[msg_id]
                if timeout and timeout < now:
                    self.fail_unreachable(msg_id, error.TaskTimeout)

    def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
        """a task has become unreachable, send a reply with an ImpossibleDependency
        error."""
        if msg_id not in self.depending:
            self.log.error("msg %r already failed!", msg_id)
            return
        raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
        for mid in follow.union(after):
            if mid in self.graph:
                self.graph[mid].remove(msg_id)

        # FIXME: unpacking a message I've already unpacked, but didn't save:
        idents,msg = self.session.feed_identities(raw_msg, copy=False)
        header = self.session.unpack(msg[1].bytes)

        try:
            raise why()
        except:
            content = error.wrap_exception()

        self.all_done.add(msg_id)
        self.all_failed.add(msg_id)

        msg = self.session.send(self.client_stream, 'apply_reply', content,
                                parent=header, ident=idents)
        self.session.send(self.mon_stream, msg, ident=[b'outtask']+idents)

        self.update_graph(msg_id, success=False)

    def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
        """check location dependencies, and run if they are met."""
        blacklist = self.blacklist.setdefault(msg_id, set())
        if follow or targets or blacklist or self.hwm:
            # we need a can_run filter
            def can_run(idx):
                # check hwm
                if self.hwm and self.loads[idx] == self.hwm:
                    return False
                target = self.targets[idx]
                # check blacklist
                if target in blacklist:
                    return False
                # check targets
                if targets and target not in targets:
                    return False
                # check follow
                return follow.check(self.completed[target], self.failed[target])

            indices = filter(can_run, range(len(self.targets)))

            if not indices:
                # couldn't run
                if follow.all:
                    # check follow for impossibility
                    dests = set()
                    relevant = set()
                    if follow.success:
                        relevant = self.all_completed
                    if follow.failure:
                        relevant = relevant.union(self.all_failed)
                    for m in follow.intersection(relevant):
                        dests.add(self.destinations[m])
                    if len(dests) > 1:
                        self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
                        self.fail_unreachable(msg_id)
                        return False
                if targets:
                    # check blacklist+targets for impossibility
                    targets.difference_update(blacklist)
                    if not targets or not targets.intersection(self.targets):
                        self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
                        self.fail_unreachable(msg_id)
                        return False
                return False
        else:
            indices = None

        self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
        return True

    def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
        """Save a message for later submission when its dependencies are met."""
        self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
        # track the ids in follow or after, but not those already finished
        for dep_id in after.union(follow).difference(self.all_done):
            if dep_id not in self.graph:
                self.graph[dep_id] = set()
            self.graph[dep_id].add(msg_id)

    def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
        """Submit a task to any of a subset of our targets."""
        if indices:
            loads = [self.loads[i] for i in indices]
        else:
            loads = self.loads
        idx = self.scheme(loads)
        if indices:
            idx = indices[idx]
        target = self.targets[idx]
        # print (target, map(str, msg[:3]))
        # send job to the engine
        self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
        self.engine_stream.send_multipart(raw_msg, copy=False)
        # update load
        self.add_job(idx)
        self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
        # notify Hub
        content = dict(msg_id=msg_id, engine_id=target.decode('ascii'))
        self.session.send(self.mon_stream, 'task_destination', content=content,
                        ident=[b'tracktask',self.ident])


    #-----------------------------------------------------------------------
    # Result Handling
    #-----------------------------------------------------------------------
    def dispatch_result(self, raw_msg):
        """dispatch method for result replies"""
        try:
            idents,msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unserialize(msg, content=False, copy=False)
            engine = idents[0]
            try:
                idx = self.targets.index(engine)
            except ValueError:
                pass # skip load-update for dead engines
            else:
                self.finish_job(idx)
        except Exception:
            self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
            return

        header = msg['header']
        parent = msg['parent_header']
        if header.get('dependencies_met', True):
            success = (header['status'] == 'ok')
            msg_id = parent['msg_id']
            retries = self.retries[msg_id]
            if not success and retries > 0:
                # failed
                self.retries[msg_id] = retries - 1
                self.handle_unmet_dependency(idents, parent)
            else:
                del self.retries[msg_id]
                # relay to client and update graph
                self.handle_result(idents, parent, raw_msg, success)
                # send to Hub monitor
                self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
        else:
            self.handle_unmet_dependency(idents, parent)

    def handle_result(self, idents, parent, raw_msg, success=True):
        """handle a real task result, either success or failure"""
        # first, relay result to client
        engine = idents[0]
        client = idents[1]
        # swap_ids for XREP-XREP mirror
        raw_msg[:2] = [client,engine]
        # print (map(str, raw_msg[:4]))
        self.client_stream.send_multipart(raw_msg, copy=False)
        # now, update our data structures
        msg_id = parent['msg_id']
        self.blacklist.pop(msg_id, None)
        self.pending[engine].pop(msg_id)
        if success:
            self.completed[engine].add(msg_id)
            self.all_completed.add(msg_id)
        else:
            self.failed[engine].add(msg_id)
            self.all_failed.add(msg_id)
        self.all_done.add(msg_id)
        self.destinations[msg_id] = engine

        self.update_graph(msg_id, success)

    def handle_unmet_dependency(self, idents, parent):
        """handle an unmet dependency"""
        engine = idents[0]
        msg_id = parent['msg_id']

        if msg_id not in self.blacklist:
            self.blacklist[msg_id] = set()
        self.blacklist[msg_id].add(engine)

        args = self.pending[engine].pop(msg_id)
        raw,targets,after,follow,timeout = args

        if self.blacklist[msg_id] == targets:
            self.depending[msg_id] = args
            self.fail_unreachable(msg_id)
        elif not self.maybe_run(msg_id, *args):
            # resubmit failed
            if msg_id not in self.all_failed:
                # put it back in our dependency tree
                self.save_unmet(msg_id, *args)

        if self.hwm:
            try:
                idx = self.targets.index(engine)
            except ValueError:
                pass # skip load-update for dead engines
            else:
                if self.loads[idx] == self.hwm-1:
                    self.update_graph(None)


    def update_graph(self, dep_id=None, success=True):
        """dep_id just finished. Update our dependency
        graph and submit any jobs that just became runnable.

        Called with dep_id=None to update the entire graph for hwm, but without
        finishing a task.
        """
        # print ("\n\n***********")
        # pprint (dep_id)
        # pprint (self.graph)
        # pprint (self.depending)
        # pprint (self.all_completed)
        # pprint (self.all_failed)
        # print ("\n\n***********\n\n")
        # update any jobs that depended on the dependency
        jobs = self.graph.pop(dep_id, [])

        # recheck *all* jobs if
        # a) we have HWM and an engine just became no longer full
        # or b) dep_id was given as None
        if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
            jobs = self.depending.keys()

        for msg_id in jobs:
            raw_msg, targets, after, follow, timeout = self.depending[msg_id]

            if after.unreachable(self.all_completed, self.all_failed)\
                    or follow.unreachable(self.all_completed, self.all_failed):
                self.fail_unreachable(msg_id)

            elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
                if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):

                    self.depending.pop(msg_id)
                    for mid in follow.union(after):
                        if mid in self.graph:
                            self.graph[mid].remove(msg_id)

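    # Shape of the bookkeeping above (not in the original source), with
    # made-up ids: if task 'c' was submitted with after=['a'] and
    # follow=['b'], then until 'a' and 'b' finish:
    #
    #     self.graph     == {'a': set(['c']), 'b': set(['c'])}
    #     self.depending == {'c': [raw_msg, targets, after, follow, timeout]}
    #
    # When a dependency finishes, its entry is popped from self.graph and
    # each dependent msg_id is rechecked via maybe_run.
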
    #----------------------------------------------------------------------
    # methods to be overridden by subclasses
    #----------------------------------------------------------------------

    def add_job(self, idx):
        """Called after self.targets[idx] just got the job with header.
        Override in subclasses. The default ordering is simple LRU.
        The default loads are the number of outstanding jobs."""
        self.loads[idx] += 1
        for lis in (self.targets, self.loads):
            lis.append(lis.pop(idx))


    def finish_job(self, idx):
        """Called after self.targets[idx] just finished a job.
        Override in subclasses."""
        self.loads[idx] -= 1



def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
                        logname='root', log_url=None, loglevel=logging.DEBUG,
                        identity=b'task', in_thread=False):

    ZMQStream = zmqstream.ZMQStream

    if config:
        # unwrap dict back into Config
        config = Config(config)

    if in_thread:
        # use instance() to get the same Context/Loop as our parent
        ctx = zmq.Context.instance()
        loop = ioloop.IOLoop.instance()
    else:
        # in a process, don't use instance()
        # for safety with multiprocessing
        ctx = zmq.Context()
        loop = ioloop.IOLoop()
    ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
    ins.setsockopt(zmq.IDENTITY, identity)
    ins.bind(in_addr)

    outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
    outs.setsockopt(zmq.IDENTITY, identity)
    outs.bind(out_addr)
    mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
    mons.connect(mon_addr)
    nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
    nots.setsockopt(zmq.SUBSCRIBE, b'')
    nots.connect(not_addr)

    # setup logging.
    if in_thread:
        log = Application.instance().log
    else:
        if log_url:
            log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
        else:
            log = local_logger(logname, loglevel)

    scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
                            mon_stream=mons, notifier_stream=nots,
                            loop=loop, log=log,
                            config=config)
    scheduler.start()
    if not in_thread:
        try:
            loop.start()
        except KeyboardInterrupt:
            print ("interrupted, exiting...", file=sys.__stderr__)

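# Hedged usage sketch (not part of the original module): wiring up a scheduler
# by hand. The TCP endpoints are arbitrary placeholders; in practice
# ipcontroller constructs these addresses itself.
#
#     launch_scheduler('tcp://127.0.0.1:5555',   # in_addr: client-facing ROUTER
#                      'tcp://127.0.0.1:5556',   # out_addr: engine-facing ROUTER
#                      'tcp://127.0.0.1:5557',   # mon_addr: monitor PUB
#                      'tcp://127.0.0.1:5558',   # not_addr: notification SUB
#                      in_thread=False)          # blocks in loop.start()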