##// END OF EJS Templates
quiet down scheduler printing, fix dep_id check in update_dependencies
MinRK -
Show More
@@ -1,422 +1,423 b''
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6 """
7 7
8 8 #----------------------------------------------------------------------
9 9 # Imports
10 10 #----------------------------------------------------------------------
11 11
12 12 from __future__ import print_function
13 13 from random import randint,random
14 14
15 15 try:
16 16 import numpy
17 17 except ImportError:
18 18 numpy = None
19 19
20 20 import zmq
21 21 from zmq.eventloop import ioloop, zmqstream
22 22
23 23 # local imports
24 24 from IPython.zmq.log import logger # a Logger object
25 25 from client import Client
26 26 from dependency import Dependency
27 27 import streamsession as ss
28 28
29 29 from IPython.external.decorator import decorator
30 30
@decorator
def logged(f, self, *args, **kwargs):
    """Pass-through decorator marking scheduler entry points.

    The call-tracing prints that used to live here were removed to
    quiet down the scheduler; re-add logging here if call tracing is
    needed during debugging.
    """
    return f(self, *args, **kwargs)
37 37
38 38 #----------------------------------------------------------------------
39 39 # Chooser functions
40 40 #----------------------------------------------------------------------
41 41
def plainrandom(loads):
    """Choose an engine index uniformly at random.

    Only the length of `loads` matters; its contents are ignored.
    """
    last = len(loads) - 1
    return randint(0, last)
46 46
def lru(loads):
    """Pick index 0: the least-recently-used engine.

    `loads` is assumed to be LRU-ordered (oldest first); its contents
    are otherwise ignored.
    """
    return 0
55 55
def twobin(loads):
    """Power-of-two choice: draw two random indices, keep the smaller.

    The values in `loads` are ignored; with LRU ordering (oldest
    first) the smaller index is the LRU of the pair.
    """
    limit = len(loads) - 1
    first = randint(0, limit)
    second = randint(0, limit)
    return first if first < second else second
67 67
def weighted(loads):
    """Pick two engines weighted by inverse load; return the less loaded.

    A load of 0 is weighted roughly a million times more than a load
    of 1 (the 1e-6 offset avoids division by zero).
    """
    inv = 1. / (1e-6 + numpy.array(loads))
    cumulative = inv.cumsum()
    total = cumulative[-1]
    # two independent weighted draws via inverse-CDF sampling
    picks = []
    for _ in range(2):
        threshold = random() * total
        i = 0
        while cumulative[i] < threshold:
            i += 1
        picks.append(i)
    idx, idy = picks
    return idy if inv[idy] > inv[idx] else idx
89 89
def leastload(loads):
    """Return the index of the minimum load.

    Ties go to the earliest occurrence, which under LRU ordering is
    the least-recently-used of the tied engines.
    """
    lowest = min(loads)
    return loads.index(lowest)
98 98
99 99 #---------------------------------------------------------------------
100 100 # Classes
101 101 #---------------------------------------------------------------------
class TaskScheduler(object):
    """Python TaskScheduler object.

    This is the simplest object that supports msg_id based
    DAG dependencies. *Only* task msg_ids are checked, not
    msg_ids of jobs submitted via the MUX queue.

    """

    scheme = leastload   # function for determining the destination
    client_stream = None # client-facing stream
    engine_stream = None # engine-facing stream
    mon_stream = None    # controller-facing stream
    dependencies = None  # dict by msg_id of [ msg_ids that depend on key ]
    depending = None     # dict by msg_id of (msg_id, raw_msg, after, follow)
    pending = None       # dict by engine_uuid of submitted tasks
    completed = None     # dict by engine_uuid of completed tasks
    clients = None       # dict by msg_id for who submitted the task
    targets = None       # list of target IDENTs
    loads = None         # list of engine loads
    all_done = None      # set of all completed tasks
    blacklist = None     # dict by msg_id of locations where a job has encountered UnmetDependency
124 124
125 125
126 126 def __init__(self, client_stream, engine_stream, mon_stream,
127 127 notifier_stream, scheme=None, io_loop=None):
128 128 if io_loop is None:
129 129 io_loop = ioloop.IOLoop.instance()
130 130 self.io_loop = io_loop
131 131 self.client_stream = client_stream
132 132 self.engine_stream = engine_stream
133 133 self.mon_stream = mon_stream
134 134 self.notifier_stream = notifier_stream
135 135
136 136 if scheme is not None:
137 137 self.scheme = scheme
138 138 else:
139 139 self.scheme = TaskScheduler.scheme
140 140
141 141 self.session = ss.StreamSession(username="TaskScheduler")
142 142
143 143 self.dependencies = {}
144 144 self.depending = {}
145 145 self.completed = {}
146 146 self.pending = {}
147 147 self.all_done = set()
148 148 self.blacklist = {}
149 149
150 150 self.targets = []
151 151 self.loads = []
152 152
153 153 engine_stream.on_recv(self.dispatch_result, copy=False)
154 154 self._notification_handlers = dict(
155 155 registration_notification = self._register_engine,
156 156 unregistration_notification = self._unregister_engine
157 157 )
158 158 self.notifier_stream.on_recv(self.dispatch_notification)
159 159
160 160 def resume_receiving(self):
161 161 """Resume accepting jobs."""
162 162 self.client_stream.on_recv(self.dispatch_submission, copy=False)
163 163
164 164 def stop_receiving(self):
165 165 """Stop accepting jobs while there are no engines.
166 166 Leave them in the ZMQ queue."""
167 167 self.client_stream.on_recv(None)
168 168
169 169 #-----------------------------------------------------------------------
170 170 # [Un]Registration Handling
171 171 #-----------------------------------------------------------------------
172 172
173 173 def dispatch_notification(self, msg):
174 174 """dispatch register/unregister events."""
175 175 idents,msg = self.session.feed_identities(msg)
176 176 msg = self.session.unpack_message(msg)
177 177 msg_type = msg['msg_type']
178 178 handler = self._notification_handlers.get(msg_type, None)
179 179 if handler is None:
180 180 raise Exception("Unhandled message type: %s"%msg_type)
181 181 else:
182 182 try:
183 183 handler(str(msg['content']['queue']))
184 184 except KeyError:
185 185 logger.error("task::Invalid notification msg: %s"%msg)
186 186 @logged
187 187 def _register_engine(self, uid):
188 188 """New engine with ident `uid` became available."""
189 189 # head of the line:
190 190 self.targets.insert(0,uid)
191 191 self.loads.insert(0,0)
192 192 # initialize sets
193 193 self.completed[uid] = set()
194 194 self.pending[uid] = {}
195 195 if len(self.targets) == 1:
196 196 self.resume_receiving()
197 197
198 198 def _unregister_engine(self, uid):
199 199 """Existing engine with ident `uid` became unavailable."""
200 200 if len(self.targets) == 1:
201 201 # this was our only engine
202 202 self.stop_receiving()
203 203
204 204 # handle any potentially finished tasks:
205 205 self.engine_stream.flush()
206 206
207 207 self.completed.pop(uid)
208 208 lost = self.pending.pop(uid)
209 209
210 210 idx = self.targets.index(uid)
211 211 self.targets.pop(idx)
212 212 self.loads.pop(idx)
213 213
214 214 self.handle_stranded_tasks(lost)
215 215
216 216 def handle_stranded_tasks(self, lost):
217 217 """Deal with jobs resident in an engine that died."""
218 218 # TODO: resubmit the tasks?
219 219 for msg_id in lost:
220 220 pass
221 221
222 222
223 223 #-----------------------------------------------------------------------
224 224 # Job Submission
225 225 #-----------------------------------------------------------------------
226 226 @logged
227 227 def dispatch_submission(self, raw_msg):
228 228 """Dispatch job submission to appropriate handlers."""
229 229 # ensure targets up to date:
230 230 self.notifier_stream.flush()
231 231 try:
232 232 idents, msg = self.session.feed_identities(raw_msg, copy=False)
233 233 except Exception as e:
234 234 logger.error("task::Invaid msg: %s"%msg)
235 235 return
236 236
237 # send to monitor
238 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
239
237 240 msg = self.session.unpack_message(msg, content=False, copy=False)
238 241 header = msg['header']
239 242 msg_id = header['msg_id']
240 243
241 244 # time dependencies
242 245 after = Dependency(header.get('after', []))
243 246 if after.mode == 'all':
244 247 after.difference_update(self.all_done)
245 248 if after.check(self.all_done):
246 249 # recast as empty set, if `after` already met,
247 250 # to prevent unnecessary set comparisons
248 251 after = Dependency([])
249 252
250 253 # location dependencies
251 254 follow = Dependency(header.get('follow', []))
252 255 if len(after) == 0:
253 256 # time deps already met, try to run
254 257 if not self.maybe_run(msg_id, raw_msg, follow):
255 258 # can't run yet
256 259 self.save_unmet(msg_id, raw_msg, after, follow)
257 260 else:
258 261 self.save_unmet(msg_id, raw_msg, after, follow)
259 # send to monitor
260 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
261 262
262 263 @logged
263 264 def maybe_run(self, msg_id, raw_msg, follow=None):
264 265 """check location dependencies, and run if they are met."""
265 266
266 267 if follow:
267 268 def can_run(idx):
268 269 target = self.targets[idx]
269 270 return target not in self.blacklist.get(msg_id, []) and\
270 271 follow.check(self.completed[target])
271 272
272 273 indices = filter(can_run, range(len(self.targets)))
273 274 if not indices:
274 275 return False
275 276 else:
276 277 indices = None
277 278
278 279 self.submit_task(msg_id, raw_msg, indices)
279 280 return True
280 281
281 282 @logged
282 283 def save_unmet(self, msg_id, msg, after, follow):
283 284 """Save a message for later submission when its dependencies are met."""
284 285 self.depending[msg_id] = (msg_id,msg,after,follow)
285 286 # track the ids in both follow/after, but not those already completed
286 287 for dep_id in after.union(follow).difference(self.all_done):
287 print (dep_id)
288 288 if dep_id not in self.dependencies:
289 289 self.dependencies[dep_id] = set()
290 290 self.dependencies[dep_id].add(msg_id)
291 291
292 292 @logged
293 293 def submit_task(self, msg_id, msg, follow=None, indices=None):
294 294 """Submit a task to any of a subset of our targets."""
295 295 if indices:
296 296 loads = [self.loads[i] for i in indices]
297 297 else:
298 298 loads = self.loads
299 299 idx = self.scheme(loads)
300 300 if indices:
301 301 idx = indices[idx]
302 302 target = self.targets[idx]
303 print (target, map(str, msg[:3]))
303 # print (target, map(str, msg[:3]))
304 304 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
305 305 self.engine_stream.send_multipart(msg, copy=False)
306 306 self.add_job(idx)
307 307 self.pending[target][msg_id] = (msg, follow)
308 308 content = dict(msg_id=msg_id, engine_id=target)
309 309 self.session.send(self.mon_stream, 'task_destination', content=content, ident='tracktask')
310 310
311 311 #-----------------------------------------------------------------------
312 312 # Result Handling
313 313 #-----------------------------------------------------------------------
314 314 @logged
315 315 def dispatch_result(self, raw_msg):
316 316 try:
317 317 idents,msg = self.session.feed_identities(raw_msg, copy=False)
318 318 except Exception as e:
319 319 logger.error("task::Invaid result: %s"%msg)
320 320 return
321 321 msg = self.session.unpack_message(msg, content=False, copy=False)
322 322 header = msg['header']
323 323 if header.get('dependencies_met', True):
324 324 self.handle_result_success(idents, msg['parent_header'], raw_msg)
325 325 # send to monitor
326 326 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
327 327 else:
328 328 self.handle_unmet_dependency(idents, msg['parent_header'])
329 329
330 330 @logged
331 331 def handle_result_success(self, idents, parent, raw_msg):
332 332 # first, relay result to client
333 333 engine = idents[0]
334 334 client = idents[1]
335 335 # swap_ids for XREP-XREP mirror
336 336 raw_msg[:2] = [client,engine]
337 print (map(str, raw_msg[:4]))
337 # print (map(str, raw_msg[:4]))
338 338 self.client_stream.send_multipart(raw_msg, copy=False)
339 339 # now, update our data structures
340 340 msg_id = parent['msg_id']
341 341 self.pending[engine].pop(msg_id)
342 342 self.completed[engine].add(msg_id)
343 343
344 344 self.update_dependencies(msg_id)
345 345
346 346 @logged
347 347 def handle_unmet_dependency(self, idents, parent):
348 348 engine = idents[0]
349 349 msg_id = parent['msg_id']
350 350 if msg_id not in self.blacklist:
351 351 self.blacklist[msg_id] = set()
352 352 self.blacklist[msg_id].add(engine)
353 353 raw_msg,follow = self.pending[engine].pop(msg_id)
354 354 if not self.maybe_run(msg_id, raw_msg, follow):
355 355 # resubmit failed, put it back in our dependency tree
356 356 self.save_unmet(msg_id, raw_msg, Dependency(), follow)
357 357 pass
358 358 @logged
359 359 def update_dependencies(self, dep_id):
360 360 """dep_id just finished. Update our dependency
361 361 table and submit any jobs that just became runable."""
362
362 363 if dep_id not in self.dependencies:
363 364 return
364 365 jobs = self.dependencies.pop(dep_id)
365 366 for job in jobs:
366 367 msg_id, raw_msg, after, follow = self.depending[job]
367 if msg_id in after:
368 after.remove(msg_id)
369 if not after: # time deps met
368 if dep_id in after:
369 after.remove(dep_id)
370 if not after: # time deps met, maybe run
370 371 if self.maybe_run(msg_id, raw_msg, follow):
371 372 self.depending.pop(job)
372 373 for mid in follow:
373 374 if mid in self.dependencies:
374 375 self.dependencies[mid].remove(msg_id)
375 376
376 377 #----------------------------------------------------------------------
377 378 # methods to be overridden by subclasses
378 379 #----------------------------------------------------------------------
379 380
380 381 def add_job(self, idx):
381 382 """Called after self.targets[idx] just got the job with header.
382 383 Override with subclasses. The default ordering is simple LRU.
383 384 The default loads are the number of outstanding jobs."""
384 385 self.loads[idx] += 1
385 386 for lis in (self.targets, self.loads):
386 387 lis.append(lis.pop(idx))
387 388
388 389
389 390 def finish_job(self, idx):
390 391 """Called after self.targets[idx] just finished a job.
391 392 Override with subclasses."""
392 393 self.loads[idx] -= 1
393 394
394 395
395 396
def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, scheme='weighted'):
    """Build the ZMQ streams, construct a TaskScheduler, and run its loop."""
    from zmq.eventloop import ioloop
    from zmq.eventloop.zmqstream import ZMQStream

    context = zmq.Context()
    io_loop = ioloop.IOLoop()

    # resolve the scheme name to one of the chooser functions above
    scheme = globals().get(scheme)

    # client-facing and engine-facing XREP sockets
    ins = ZMQStream(context.socket(zmq.XREP), io_loop)
    ins.bind(in_addr)
    outs = ZMQStream(context.socket(zmq.XREP), io_loop)
    outs.bind(out_addr)
    # publish to the monitor, subscribe to controller notifications
    mons = ZMQStream(context.socket(zmq.PUB), io_loop)
    mons.connect(mon_addr)
    nots = ZMQStream(context.socket(zmq.SUB), io_loop)
    nots.setsockopt(zmq.SUBSCRIBE, '')
    nots.connect(not_addr)

    scheduler = TaskScheduler(ins, outs, mons, nots, scheme, io_loop)

    io_loop.start()
418 419
419 420
if __name__ == '__main__':
    iface = 'tcp://127.0.0.1:%i'
    # NOTE(review): 1236 looks like a typo for 12346 — confirm before changing
    launch_scheduler(iface%12345, iface%1236, iface%12347, iface%12348)
General Comments 0
You need to be logged in to leave comments. Login now