scheduler progress
MinRK
@@ -1,640 +1,654 @@
1 #!/usr/bin/env python
2 1 """A semi-synchronous Client for the ZMQ controller"""
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
4 #
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
3 8
4 import time
9 #-----------------------------------------------------------------------------
10 # Imports
11 #-----------------------------------------------------------------------------
5 12
13 import time
6 14 from pprint import pprint
7 15
16 import zmq
17 from zmq.eventloop import ioloop, zmqstream
18
8 19 from IPython.external.decorator import decorator
9 20
10 21 import streamsession as ss
11 import zmq
12 from zmq.eventloop import ioloop, zmqstream
13 22 from remotenamespace import RemoteNamespace
14 23 from view import DirectView
15 24 from dependency import Dependency, depend, require
16 25
17 26 def _push(ns):
18 27 globals().update(ns)
19 28
20 29 def _pull(keys):
21 30 g = globals()
22 31 if isinstance(keys, (list,tuple)):
23 32 return map(g.get, keys)
24 33 else:
25 34 return g.get(keys)
26 35
27 36 def _clear():
28 37 globals().clear()
29 38
30 39 def execute(code):
31 40 exec code in globals()
32 41
33 42 # decorators for methods:
34 43 @decorator
35 44 def spinfirst(f,self,*args,**kwargs):
36 45 self.spin()
37 46 return f(self, *args, **kwargs)
38 47
39 48 @decorator
40 49 def defaultblock(f, self, *args, **kwargs):
41 50 block = kwargs.get('block',None)
42 51 block = self.block if block is None else block
43 52 saveblock = self.block
44 53 self.block = block
45 54 ret = f(self, *args, **kwargs)
46 55 self.block = saveblock
47 56 return ret
48 57
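# Editor's sketch (not part of this changeset): how @defaultblock behaves.
# `Toy` is a hypothetical class; the wrapped method sees self.block set to
# the per-call value, and the saved value is restored afterwards.
#
#     class Toy(object):
#         block = False
#         @defaultblock
#         def method(self, block=None):
#             return self.block        # reflects the per-call override
#
#     t = Toy()
#     t.method()                       # False: falls back to t.block
#     t.method(block=True)             # True: per-call override wins
#     t.block                          # False: restored after the call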
49 58 class AbortedTask(object):
50 59 def __init__(self, msg_id):
51 60 self.msg_id = msg_id
52 61 # @decorator
53 62 # def checktargets(f):
54 63 # @wraps(f)
55 64 # def checked_method(self, *args, **kwargs):
56 65 # self._build_targets(kwargs['targets'])
57 66 # return f(self, *args, **kwargs)
58 67 # return checked_method
59 68
60 69
61 70 # class _ZMQEventLoopThread(threading.Thread):
62 71 #
63 72 # def __init__(self, loop):
64 73 # self.loop = loop
65 74 # threading.Thread.__init__(self)
66 75 #
67 76 # def run(self):
68 77 # self.loop.start()
69 78 #
70 79 class Client(object):
71 80 """A semi-synchronous client to the IPython ZMQ controller
72 81
73 82 Attributes
74 83 ----------
75 84 ids : set
76 85 a set of engine IDs.
77 86 Requesting the ids attribute always synchronizes
78 87 the registration state. To request ids without synchronization,
79 88 use _ids.
80 89
81 90 history : list of msg_ids
82 91 a list of msg_ids, keeping track of all the execution
83 92 messages you have submitted
84 93
85 94 outstanding : set of msg_ids
86 95 a set of msg_ids that have been submitted, but whose
87 96 results have not been received
88 97
89 98 results : dict
90 99 a dict of all our results, keyed by msg_id
91 100
92 101 block : bool
93 102 determines default behavior when block not specified
94 103 in execution methods
95 104
96 105 Methods
97 106 -------
98 107 spin : flushes incoming results and registration state changes
99 108 control methods spin, and requesting `ids` also ensures up to date
100 109
101 110 barrier : wait on one or more msg_ids
102 111
103 112 execution methods: apply/apply_bound/apply_to
104 113 legacy: execute, run
105 114
106 115 query methods: queue_status, get_result
107 116
108 117 control methods: abort, kill
109 118
110 119
111 120
112 121 """
113 122
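# Editor's sketch (not part of this changeset): typical blocking use,
# assuming a controller is already listening; the address is a placeholder.
#
#     c = Client('tcp://127.0.0.1:10101')
#     c.ids                                 # e.g. set([0, 1])
#     c.execute('a = 5', targets='all', block=True)
#     c.pull('a', targets='all')            # -> {0: 5, 1: 5}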
114 123
115 124 _connected=False
116 125 _engines=None
117 126 registration_socket=None
118 127 query_socket=None
119 128 control_socket=None
120 129 notification_socket=None
121 130 queue_socket=None
122 131 task_socket=None
123 132 block = False
124 133 outstanding=None
125 134 results = None
126 135 history = None
127 136 debug = False
128 137
129 138 def __init__(self, addr, context=None, username=None, debug=False):
130 139 if context is None:
131 140 context = zmq.Context()
132 141 self.context = context
133 142 self.addr = addr
134 143 if username is None:
135 144 self.session = ss.StreamSession()
136 145 else:
137 146 self.session = ss.StreamSession(username)
138 147 self.registration_socket = self.context.socket(zmq.PAIR)
139 148 self.registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
140 149 self.registration_socket.connect(addr)
141 150 self._engines = {}
142 151 self._ids = set()
143 152 self.outstanding=set()
144 153 self.results = {}
145 154 self.history = []
146 155 self.debug = debug
147 156 self.session.debug = debug
148 157
149 158 self._notification_handlers = {'registration_notification' : self._register_engine,
150 159 'unregistration_notification' : self._unregister_engine,
151 160 }
152 161 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
153 162 'apply_reply' : self._handle_apply_reply}
154 163 self._connect()
155 164
156 165
157 166 @property
158 167 def ids(self):
159 168 self._flush_notifications()
160 169 return self._ids
161 170
162 171 def _update_engines(self, engines):
163 172 for k,v in engines.iteritems():
164 173 eid = int(k)
165 174 self._engines[eid] = bytes(v) # force not unicode
166 175 self._ids.add(eid)
167 176
168 177 def _build_targets(self, targets):
169 178 if targets is None:
170 179 targets = self._ids
171 180 elif isinstance(targets, str):
172 181 if targets.lower() == 'all':
173 182 targets = self._ids
174 183 else:
175 184 raise TypeError("%r not valid str target, must be 'all'"%(targets))
176 185 elif isinstance(targets, int):
177 186 targets = [targets]
178 187 return [self._engines[t] for t in targets], list(targets)
179 188
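# Editor's sketch: every targets spec normalizes to (queue idents, id list):
#     self._build_targets(None)     # all registered engines
#     self._build_targets('all')    # same as None
#     self._build_targets(0)        # ([<ident of engine 0>], [0])
#     self._build_targets([0, 2])   # engines 0 and 2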
180 189 def _connect(self):
181 190 """setup all our socket connections to the controller"""
182 191 if self._connected:
183 192 return
184 193 self._connected=True
185 194 self.session.send(self.registration_socket, 'connection_request')
186 195 idents,msg = self.session.recv(self.registration_socket,mode=0)
187 196 if self.debug:
188 197 pprint(msg)
189 198 msg = ss.Message(msg)
190 199 content = msg.content
191 200 if content.status == 'ok':
192 201 if content.queue:
193 202 self.queue_socket = self.context.socket(zmq.PAIR)
194 203 self.queue_socket.setsockopt(zmq.IDENTITY, self.session.session)
195 204 self.queue_socket.connect(content.queue)
196 205 if content.task:
197 206 self.task_socket = self.context.socket(zmq.PAIR)
198 207 self.task_socket.setsockopt(zmq.IDENTITY, self.session.session)
199 208 self.task_socket.connect(content.task)
200 209 if content.notification:
201 210 self.notification_socket = self.context.socket(zmq.SUB)
202 211 self.notification_socket.connect(content.notification)
203 212 self.notification_socket.setsockopt(zmq.SUBSCRIBE, "")
204 213 if content.query:
205 214 self.query_socket = self.context.socket(zmq.PAIR)
206 215 self.query_socket.setsockopt(zmq.IDENTITY, self.session.session)
207 216 self.query_socket.connect(content.query)
208 217 if content.control:
209 218 self.control_socket = self.context.socket(zmq.PAIR)
210 219 self.control_socket.setsockopt(zmq.IDENTITY, self.session.session)
211 220 self.control_socket.connect(content.control)
212 221 self._update_engines(dict(content.engines))
213 222
214 223 else:
215 224 self._connected = False
216 225 raise Exception("Failed to connect!")
217 226
218 227 #### handlers and callbacks for incoming messages #######
219 228 def _register_engine(self, msg):
220 229 content = msg['content']
221 230 eid = content['id']
222 231 d = {eid : content['queue']}
223 232 self._update_engines(d)
224 233 self._ids.add(int(eid))
225 234
226 235 def _unregister_engine(self, msg):
227 236 # print 'unregister',msg
228 237 content = msg['content']
229 238 eid = int(content['id'])
230 239 if eid in self._ids:
231 240 self._ids.remove(eid)
232 241 self._engines.pop(eid)
233 242
234 243 def _handle_execute_reply(self, msg):
235 244 # msg_id = msg['msg_id']
236 245 parent = msg['parent_header']
237 246 msg_id = parent['msg_id']
238 247 if msg_id not in self.outstanding:
239 248 print "got unknown result: %s"%msg_id
240 249 else:
241 250 self.outstanding.remove(msg_id)
242 251 self.results[msg_id] = ss.unwrap_exception(msg['content'])
243 252
244 253 def _handle_apply_reply(self, msg):
245 254 # pprint(msg)
246 255 # msg_id = msg['msg_id']
247 256 parent = msg['parent_header']
248 257 msg_id = parent['msg_id']
249 258 if msg_id not in self.outstanding:
250 259 print "got unknown result: %s"%msg_id
251 260 else:
252 261 self.outstanding.remove(msg_id)
253 262 content = msg['content']
254 263 if content['status'] == 'ok':
255 264 self.results[msg_id] = ss.unserialize_object(msg['buffers'])
256 265 elif content['status'] == 'aborted':
257 266 self.results[msg_id] = AbortedTask(msg_id)
258 267 elif content['status'] == 'resubmitted':
259 268 pass # handle resubmission
260 269 else:
261 270 self.results[msg_id] = ss.unwrap_exception(content)
262 271
263 272 def _flush_notifications(self):
264 273 "flush incoming notifications of engine registrations"
265 274 msg = self.session.recv(self.notification_socket, mode=zmq.NOBLOCK)
266 275 while msg is not None:
267 276 if self.debug:
268 277 pprint(msg)
269 278 msg = msg[-1]
270 279 msg_type = msg['msg_type']
271 280 handler = self._notification_handlers.get(msg_type, None)
272 281 if handler is None:
273 282 raise Exception("Unhandled message type: %s"%msg_type)
274 283 else:
275 284 handler(msg)
276 285 msg = self.session.recv(self.notification_socket, mode=zmq.NOBLOCK)
277 286
278 287 def _flush_results(self, sock):
279 288 "flush incoming task or queue results"
280 289 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
281 290 while msg is not None:
282 291 if self.debug:
283 292 pprint(msg)
284 293 msg = msg[-1]
285 294 msg_type = msg['msg_type']
286 295 handler = self._queue_handlers.get(msg_type, None)
287 296 if handler is None:
288 297 raise Exception("Unhandled message type: %s"%msg_type)
289 298 else:
290 299 handler(msg)
291 300 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
292 301
293 302 def _flush_control(self, sock):
294 303 "flush incoming control replies"
295 304 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
296 305 while msg is not None:
297 306 if self.debug:
298 307 pprint(msg)
299 308 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
300 309
301 310 ###### get/setitem ########
302 311
303 312 def __getitem__(self, key):
304 313 if isinstance(key, int):
305 314 if key not in self.ids:
306 315 raise IndexError("No such engine: %i"%key)
307 316 return DirectView(self, key)
308 317
309 318 if isinstance(key, slice):
310 319 indices = range(len(self.ids))[key]
311 320 ids = sorted(self._ids)
312 321 key = [ ids[i] for i in indices ]
313 322 # newkeys = sorted(self._ids)[thekeys[k]]
314 323
315 324 if isinstance(key, (tuple, list, xrange)):
316 325 _,targets = self._build_targets(list(key))
317 326 return DirectView(self, targets)
318 327 else:
319 328 raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
320 329
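# Editor's sketch: indexing a client yields DirectViews (`c` is a
# hypothetical connected Client):
#     c[0]          # DirectView on engine 0
#     c[::2]        # DirectView on every other engine
#     c[[0, 2]]     # DirectView on engines 0 and 2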
321 330 ############ begin real methods #############
322 331
323 332 def spin(self):
324 333 """flush incoming notifications and execution results."""
325 334 if self.notification_socket:
326 335 self._flush_notifications()
327 336 if self.queue_socket:
328 337 self._flush_results(self.queue_socket)
329 338 if self.task_socket:
330 339 self._flush_results(self.task_socket)
331 340 if self.control_socket:
332 341 self._flush_control(self.control_socket)
333 342
334 343 @spinfirst
335 344 def queue_status(self, targets=None, verbose=False):
336 345 """fetch the status of engine queues
337 346
338 347 Parameters
339 348 ----------
340 349 targets : int/str/list of ints/strs
341 350 the engines on which to execute
342 351 default : all
343 352 verbose : bool
344 353 whether to return lengths only, or lists of ids for each element
345 354
346 355 """
347 356 targets = self._build_targets(targets)[1]
348 357 content = dict(targets=targets, verbose=verbose)
349 358 self.session.send(self.query_socket, "queue_request", content=content)
350 359 idents,msg = self.session.recv(self.query_socket, 0)
351 360 if self.debug:
352 361 pprint(msg)
353 362 return msg['content']
354 363
355 364 @spinfirst
356 365 @defaultblock
357 366 def clear(self, targets=None, block=None):
358 367 """clear the namespace in target(s)"""
359 368 targets = self._build_targets(targets)[0]
360 print targets
361 369 for t in targets:
362 370 self.session.send(self.control_socket, 'clear_request', content={},ident=t)
363 371 error = False
364 372 if self.block:
365 373 for i in range(len(targets)):
366 374 idents,msg = self.session.recv(self.control_socket,0)
367 375 if self.debug:
368 376 pprint(msg)
369 377 if msg['content']['status'] != 'ok':
370 378 error = msg['content']
371 379 if error:
372 380 return error
373 381
374 382
375 383 @spinfirst
376 384 @defaultblock
377 385 def abort(self, msg_ids = None, targets=None, block=None):
378 386 """abort the Queues of target(s)"""
379 387 targets = self._build_targets(targets)[0]
380 print targets
381 388 if isinstance(msg_ids, basestring):
382 389 msg_ids = [msg_ids]
383 390 content = dict(msg_ids=msg_ids)
384 391 for t in targets:
385 392 self.session.send(self.control_socket, 'abort_request',
386 393 content=content, ident=t)
387 394 error = False
388 395 if self.block:
389 396 for i in range(len(targets)):
390 397 idents,msg = self.session.recv(self.control_socket,0)
391 398 if self.debug:
392 399 pprint(msg)
393 400 if msg['content']['status'] != 'ok':
394 401 error = msg['content']
395 402 if error:
396 403 return error
397 404
398 405 @spinfirst
399 406 @defaultblock
400 407 def kill(self, targets=None, block=None):
401 408 """Terminates one or more engine processes."""
402 409 targets = self._build_targets(targets)[0]
403 print targets
404 410 for t in targets:
405 411 self.session.send(self.control_socket, 'kill_request', content={},ident=t)
406 412 error = False
407 413 if self.block:
408 414 for i in range(len(targets)):
409 415 idents,msg = self.session.recv(self.control_socket,0)
410 416 if self.debug:
411 417 pprint(msg)
412 418 if msg['content']['status'] != 'ok':
413 419 error = msg['content']
414 420 if error:
415 421 return error
416 422
417 423 @defaultblock
418 424 def execute(self, code, targets='all', block=None):
419 425 """executes `code` on `targets` in blocking or nonblocking manner.
420 426
421 427 Parameters
422 428 ----------
423 429 code : str
424 430 the code string to be executed
425 431 targets : int/str/list of ints/strs
426 432 the engines on which to execute
427 433 default : all
428 434 block : bool
429 435 whether or not to wait until done
430 436 """
431 437 # block = self.block if block is None else block
432 438 # saveblock = self.block
433 439 # self.block = block
434 440 result = self.apply(execute, (code,), targets=targets, block=block, bound=True)
435 441 # self.block = saveblock
436 442 return result
437 443
438 444 def run(self, code, block=None):
439 445 """runs `code` on an engine.
440 446
441 447 Calls to this are load-balanced.
442 448
443 449 Parameters
444 450 ----------
445 451 code : str
446 452 the code string to be executed
447 453 block : bool
448 454 whether or not to wait until done
449 455
450 456 """
451 457 result = self.apply(execute, (code,), targets=None, block=block, bound=False)
452 458 return result
453 459
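# Editor's sketch: execute targets specific engines, run load-balances
# (`c` is a hypothetical connected Client):
#     c.execute('x = 1', targets=[0, 1], block=True)   # both engines
#     mid = c.run('x = x + 1', block=False)            # any one engine
#     c.barrier([mid])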
454 460 def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
455 461 after=None, follow=None):
456 462 """the underlying method for applying functions in a load balanced
457 463 manner."""
458 464 block = block if block is not None else self.block
465 if isinstance(after, Dependency):
466 after = after.as_dict()
467 elif after is None:
468 after = []
469 if isinstance(follow, Dependency):
470 follow = follow.as_dict()
471 elif follow is None:
472 follow = []
473 subheader = dict(after=after, follow=follow)
459 474
460 475 bufs = ss.pack_apply_message(f,args,kwargs)
461 476 content = dict(bound=bound)
462 477 msg = self.session.send(self.task_socket, "apply_request",
463 content=content, buffers=bufs)
478 content=content, buffers=bufs, subheader=subheader)
464 479 msg_id = msg['msg_id']
465 480 self.outstanding.add(msg_id)
466 481 self.history.append(msg_id)
467 482 if block:
468 483 self.barrier(msg_id)
469 484 return self.results[msg_id]
470 485 else:
471 486 return msg_id
472 487
473 488 def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
474 489 after=None, follow=None):
475 490 """Then underlying method for applying functions to specific engines."""
476 491
477 492 block = block if block is not None else self.block
478 493
479 494 queues,targets = self._build_targets(targets)
480 print queues
481 495 bufs = ss.pack_apply_message(f,args,kwargs)
482 496 if isinstance(after, Dependency):
483 497 after = after.as_dict()
484 498 elif after is None:
485 499 after = []
486 500 if isinstance(follow, Dependency):
487 501 follow = follow.as_dict()
488 502 elif follow is None:
489 503 follow = []
490 504 subheader = dict(after=after, follow=follow)
491 505 content = dict(bound=bound)
492 506 msg_ids = []
493 507 for queue in queues:
494 508 msg = self.session.send(self.queue_socket, "apply_request",
495 509 content=content, buffers=bufs,ident=queue, subheader=subheader)
496 510 msg_id = msg['msg_id']
497 511 self.outstanding.add(msg_id)
498 512 self.history.append(msg_id)
499 513 msg_ids.append(msg_id)
500 514 if block:
501 515 self.barrier(msg_ids)
502 516 else:
503 517 if len(msg_ids) == 1:
504 518 return msg_ids[0]
505 519 else:
506 520 return msg_ids
507 521 if len(msg_ids) == 1:
508 522 return self.results[msg_ids[0]]
509 523 else:
510 524 result = {}
511 525 for target,mid in zip(targets, msg_ids):
512 526 result[target] = self.results[mid]
513 527 return result
514 528
515 529 def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
516 530 after=None, follow=None):
517 531 """calls f(*args, **kwargs) on a remote engine(s), returning the result.
518 532
519 533 if self.block is False:
520 534 returns msg_id or list of msg_ids
521 535 else:
522 536 returns actual result of f(*args, **kwargs)
523 537 """
524 538 # enforce types of f, args, kwargs
525 539 args = args if args is not None else []
526 540 kwargs = kwargs if kwargs is not None else {}
527 541 if not callable(f):
528 542 raise TypeError("f must be callable, not %s"%type(f))
529 543 if not isinstance(args, (tuple, list)):
530 544 raise TypeError("args must be tuple or list, not %s"%type(args))
531 545 if not isinstance(kwargs, dict):
532 546 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
533 547
534 548 options = dict(bound=bound, block=block, after=after, follow=follow)
535 549
536 550 if targets is None:
537 551 return self._apply_balanced(f, args, kwargs, **options)
538 552 else:
539 553 return self._apply_direct(f, args, kwargs, targets=targets, **options)
540 554
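# Editor's sketch: after/follow accept a Dependency or a plain list of
# msg_ids (plain lists pass straight into the subheader above); `c` and
# `f` are placeholders:
#     m1 = c.apply(f, (1,), block=False)
#     m2 = c.apply(f, (2,), block=False, after=[m1])   # start after m1 is done
#     c.apply(f, (3,), block=True, follow=[m2])        # run where m2 ran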
541 555 def push(self, ns, targets=None, block=None):
542 556 """push the contents of `ns` into the namespace on `target`"""
543 557 if not isinstance(ns, dict):
544 558 raise TypeError("Must be a dict, not %s"%type(ns))
545 559 result = self.apply(_push, (ns,), targets=targets, block=block,bound=True)
546 560 return result
547 561
548 562 @spinfirst
549 563 def pull(self, keys, targets=None, block=True):
550 564 """pull objects from `target`'s namespace by `keys`"""
551 565
552 566 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
553 567 return result
554 568
555 569 def barrier(self, msg_ids=None, timeout=-1):
556 570 """waits on one or more `msg_ids`, for up to `timeout` seconds.
557 571
558 572 Parameters
559 573 ----------
560 574 msg_ids : int, str, or list of ints and/or strs
561 575 ints are indices to self.history
562 576 strs are msg_ids
563 577 default: wait on all outstanding messages
564 578 timeout : float
565 579 a time in seconds, after which to give up.
566 580 default is -1, which means no timeout
567 581
568 582 Returns
569 583 -------
570 584 True : when all msg_ids are done
571 585 False : timeout reached, msg_ids still outstanding
572 586 """
573 587 tic = time.time()
574 588 if msg_ids is None:
575 589 theids = self.outstanding
576 590 else:
577 591 if isinstance(msg_ids, (int, str)):
578 592 msg_ids = [msg_ids]
579 593 theids = set()
580 594 for msg_id in msg_ids:
581 595 if isinstance(msg_id, int):
582 596 msg_id = self.history[msg_id]
583 597 theids.add(msg_id)
584 598 self.spin()
585 599 while theids.intersection(self.outstanding):
586 600 if timeout >= 0 and ( time.time()-tic ) > timeout:
587 601 break
588 602 time.sleep(1e-3)
589 603 self.spin()
590 604 return len(theids.intersection(self.outstanding)) == 0
591 605
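# Editor's sketch: polling outstanding work with barrier's timeout
# (`c` and `f` are placeholders):
#     mids = [c.apply(f, (i,), block=False) for i in range(4)]
#     while not c.barrier(mids, timeout=0.1):
#         print "still waiting on %i tasks"%len(c.outstanding)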
592 606 @spinfirst
593 607 def get_results(self, msg_ids, status_only=False):
594 608 """returns the result of the execute or task request with `msg_id`"""
595 609 if not isinstance(msg_ids, (list,tuple)):
596 610 msg_ids = [msg_ids]
597 611 theids = []
598 612 for msg_id in msg_ids:
599 613 if isinstance(msg_id, int):
600 614 msg_id = self.history[msg_id]
601 615 theids.append(msg_id)
602 616
603 617 content = dict(msg_ids=theids, status_only=status_only)
604 618 msg = self.session.send(self.query_socket, "result_request", content=content)
605 619 zmq.select([self.query_socket], [], [])
606 620 idents,msg = self.session.recv(self.query_socket, zmq.NOBLOCK)
607 621 if self.debug:
608 622 pprint(msg)
609 623
610 624 # while True:
611 625 # try:
612 626 # except zmq.ZMQError:
613 627 # time.sleep(1e-3)
614 628 # continue
615 629 # else:
616 630 # break
617 631 return msg['content']
618 632
619 633 class AsynClient(Client):
620 634 """An Asynchronous client, using the Tornado Event Loop"""
621 635 io_loop = None
622 636 queue_stream = None
623 637 notifier_stream = None
624 638
625 639 def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
626 640 Client.__init__(self, addr, context, username, debug)
627 641 if io_loop is None:
628 642 io_loop = ioloop.IOLoop.instance()
629 643 self.io_loop = io_loop
630 644
631 645 self.queue_stream = zmqstream.ZMQStream(self.queue_socket, io_loop)
632 646 self.control_stream = zmqstream.ZMQStream(self.control_socket, io_loop)
633 647 self.task_stream = zmqstream.ZMQStream(self.task_socket, io_loop)
634 648 self.notifier_stream = zmqstream.ZMQStream(self.notification_socket, io_loop)
635 649
636 650 def spin(self):
637 651 for stream in (self.queue_stream, self.notifier_stream,
638 652 self.task_stream, self.control_stream):
639 653 stream.flush()
640 654 No newline at end of file
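Editor's note: a minimal sketch (not part of this changeset) of driving the AsynClient above under the Tornado event loop; the address is a placeholder and a running controller is assumed.

    from zmq.eventloop import ioloop

    loop = ioloop.IOLoop.instance()
    ac = AsynClient('tcp://127.0.0.1:10101', io_loop=loop)   # placeholder addr

    def poll():
        ac.spin()   # flush all four streams

    pc = ioloop.PeriodicCallback(poll, 100, loop)   # every 100 ms
    pc.start()
    loop.start()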
@@ -1,921 +1,919 @@
1 1 #!/usr/bin/env python
2 # encoding: utf-8
3
4 2 """The IPython Controller with 0MQ
5 3 This is the master object that handles connections from engines and clients.
6 4 """
7 5 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2009 The IPython Development Team
6 # Copyright (C) 2010 The IPython Development Team
9 7 #
10 8 # Distributed under the terms of the BSD License. The full license is in
11 9 # the file COPYING, distributed as part of this software.
12 10 #-----------------------------------------------------------------------------
13 11
14 12 #-----------------------------------------------------------------------------
15 13 # Imports
16 14 #-----------------------------------------------------------------------------
17 15 from datetime import datetime
18 16 import logging
19 17
20 18 import zmq
21 19 from zmq.eventloop import zmqstream, ioloop
22 20 import uuid
23 21
24 22 # internal:
25 23 from IPython.zmq.log import logger # a Logger object
26 24 from IPython.zmq.entry_point import bind_port
27 25
28 26 from streamsession import Message, wrap_exception
29 27 from entry_point import (make_argument_parser, select_random_ports, split_ports,
30 28 connect_logger)
31 29 # from messages import json # use the same import switches
32 30
33 31 #-----------------------------------------------------------------------------
34 32 # Code
35 33 #-----------------------------------------------------------------------------
36 34
37 35 class ReverseDict(dict):
38 36 """simple double-keyed subset of dict methods."""
39 37
40 38 def __init__(self, *args, **kwargs):
41 39 dict.__init__(self, *args, **kwargs)
42 40 self.reverse = dict()
43 41 for key, value in self.iteritems():
44 42 self.reverse[value] = key
45 43
46 44 def __getitem__(self, key):
47 45 try:
48 46 return dict.__getitem__(self, key)
49 47 except KeyError:
50 48 return self.reverse[key]
51 49
52 50 def __setitem__(self, key, value):
53 51 if key in self.reverse:
54 52 raise KeyError("Can't have key %r on both sides!"%key)
55 53 dict.__setitem__(self, key, value)
56 54 self.reverse[value] = key
57 55
58 56 def pop(self, key):
59 57 value = dict.pop(self, key)
60 58 self.reverse.pop(value)
61 59 return value
62 60
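# Editor's sketch: ReverseDict resolves in either direction; with the
# pop() fix above, removal keeps both mappings in sync:
#     rd = ReverseDict()
#     rd[0] = 'engine-uuid'
#     rd['engine-uuid']     # -> 0 (falls through to the reverse map)
#     rd.pop(0)             # removes both directions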
63 61
64 62 class EngineConnector(object):
65 63 """A simple object for accessing the various zmq connections of an object.
66 64 Attributes are:
67 65 id (int): engine ID
68 66 uuid (str): uuid (unused?)
69 67 queue (str): identity of queue's XREQ socket
70 68 registration (str): identity of registration XREQ socket
71 69 heartbeat (str): identity of heartbeat XREQ socket
72 70 """
73 71 id=0
74 72 queue=None
75 73 control=None
76 74 registration=None
77 75 heartbeat=None
78 76 pending=None
79 77
80 78 def __init__(self, id, queue, registration, control, heartbeat=None):
81 79 logger.info("engine::Engine Connected: %i"%id)
82 80 self.id = id
83 81 self.queue = queue
84 82 self.registration = registration
85 83 self.control = control
86 84 self.heartbeat = heartbeat
87 85
88 86 class Controller(object):
89 87 """The IPython Controller with 0MQ connections
90 88
91 89 Parameters
92 90 ==========
93 91 loop: zmq IOLoop instance
94 92 session: StreamSession object
95 93 <removed> context: zmq context for creating new connections (?)
96 94 registrar: ZMQStream for engine registration requests (XREP)
97 95 clientele: ZMQStream for client connections (XREP)
98 96 not used for jobs, only query/control commands
99 97 queue: ZMQStream for monitoring the command queue (SUB)
100 98 heartbeat: HeartMonitor object checking the pulse of the engines
101 99 db_stream: connection to db for out of memory logging of commands
102 100 NotImplemented
103 101 queue_addr: zmq connection address of the XREP socket for the queue
104 102 hb_addr: zmq connection address of the PUB socket for heartbeats
105 103 task_addr: zmq connection address of the XREQ socket for task queue
106 104 """
107 105 # internal data structures:
108 106 ids=None # engine IDs
109 107 keytable=None
110 108 engines=None
111 109 clients=None
112 110 hearts=None
113 111 pending=None
114 112 results=None
115 113 tasks=None
116 114 completed=None
117 115 mia=None
118 116 incoming_registrations=None
119 117 registration_timeout=None
120 118
121 119 #objects from constructor:
122 120 loop=None
123 121 registrar=None
124 122 clientele=None
125 123 queue=None
126 124 heartbeat=None
127 125 notifier=None
128 126 db=None
129 127 client_addr=None
130 128 engine_addrs=None
131 129
132 130
133 131 def __init__(self, loop, session, queue, registrar, heartbeat, clientele, notifier, db, engine_addrs, client_addrs):
134 132 """
135 133 # universal:
136 134 loop: IOLoop for creating future connections
137 135 session: streamsession for sending serialized data
138 136 # engine:
139 137 queue: ZMQStream for monitoring queue messages
140 138 registrar: ZMQStream for engine registration
141 139 heartbeat: HeartMonitor object for tracking engines
142 140 # client:
143 141 clientele: ZMQStream for client connections
144 142 # extra:
145 143 db: ZMQStream for db connection (NotImplemented)
146 144 engine_addrs: zmq address/protocol dict for engine connections
147 145 client_addrs: zmq address/protocol dict for client connections
148 146 """
149 147 self.ids = set()
150 148 self.keytable={}
151 149 self.incoming_registrations={}
152 150 self.engines = {}
153 151 self.by_ident = {}
154 152 self.clients = {}
155 153 self.hearts = {}
156 154 self.mia = set()
157 155
158 156 # self.sockets = {}
159 157 self.loop = loop
160 158 self.session = session
161 159 self.registrar = registrar
162 160 self.clientele = clientele
163 161 self.queue = queue
164 162 self.heartbeat = heartbeat
165 163 self.notifier = notifier
166 164 self.db = db
167 165
168 166 self.client_addrs = client_addrs
169 167 assert isinstance(client_addrs['queue'], str)
170 168 # self.hb_addrs = hb_addrs
171 169 self.engine_addrs = engine_addrs
172 170 assert isinstance(engine_addrs['queue'], str)
173 171 assert len(engine_addrs['heartbeat']) == 2
174 172
175 173
176 174 # register our callbacks
177 175 self.registrar.on_recv(self.dispatch_register_request)
178 176 self.clientele.on_recv(self.dispatch_client_msg)
179 177 self.queue.on_recv(self.dispatch_queue_traffic)
180 178
181 179 if heartbeat is not None:
182 180 heartbeat.add_heart_failure_handler(self.handle_heart_failure)
183 181 heartbeat.add_new_heart_handler(self.handle_new_heart)
184 182
185 183 if self.db is not None:
186 184 self.db.on_recv(self.dispatch_db)
187 185
188 186 self.client_handlers = {'queue_request': self.queue_status,
189 187 'result_request': self.get_results,
190 188 'purge_request': self.purge_results,
191 189 'resubmit_request': self.resubmit_task,
192 190 }
193 191
194 192 self.registrar_handlers = {'registration_request' : self.register_engine,
195 193 'unregistration_request' : self.unregister_engine,
196 194 'connection_request': self.connection_request,
197 195
198 196 }
199 197 #
200 198 # this is the stuff that will move to DB:
201 199 self.results = {} # completed results
202 200 self.pending = {} # pending messages, keyed by msg_id
203 201 self.queues = {} # pending msg_ids keyed by engine_id
204 202 self.tasks = {} # pending msg_ids submitted as tasks, keyed by client_id
205 203 self.completed = {} # completed msg_ids keyed by engine_id
206 204 self.registration_timeout = max(5000, 2*self.heartbeat.period)
207 205
208 206 logger.info("controller::created controller")
209 207
210 208 def _new_id(self):
211 209 """gemerate a new ID"""
212 210 newid = 0
213 211 incoming = [id[0] for id in self.incoming_registrations.itervalues()]
214 212 # print newid, self.ids, self.incoming_registrations
215 213 while newid in self.ids or newid in incoming:
216 214 newid += 1
217 215 return newid
218 216
219
220 217 #-----------------------------------------------------------------------------
221 218 # message validation
222 219 #-----------------------------------------------------------------------------
220
223 221 def _validate_targets(self, targets):
224 222 """turn any valid targets argument into a list of integer ids"""
225 223 if targets is None:
226 224 # default to all
227 225 targets = self.ids
228 226
229 227 if isinstance(targets, (int,str,unicode)):
230 228 # only one target specified
231 229 targets = [targets]
232 230 _targets = []
233 231 for t in targets:
234 232 # map raw identities to ids
235 233 if isinstance(t, (str,unicode)):
236 234 t = self.by_ident.get(t, t)
237 235 _targets.append(t)
238 236 targets = _targets
239 237 bad_targets = [ t for t in targets if t not in self.ids ]
240 238 if bad_targets:
241 239 raise IndexError("No Such Engine: %r"%bad_targets)
242 240 if not targets:
243 241 raise IndexError("No Engines Registered")
244 242 return targets
245 243
246 244 def _validate_client_msg(self, msg):
247 245 """validates and unpacks headers of a message. Returns False if invalid,
248 246 (ident, header, parent, content)"""
249 247 client_id = msg[0]
250 248 try:
251 249 msg = self.session.unpack_message(msg[1:], content=True)
252 250 except:
253 251 logger.error("client::Invalid Message %s"%msg)
254 252 return False
255 253
256 254 msg_type = msg.get('msg_type', None)
257 255 if msg_type is None:
258 256 return False
259 257 header = msg.get('header')
260 258 # session doesn't handle split content for now:
261 259 return client_id, msg
262 260
263 261
264 262 #-----------------------------------------------------------------------------
265 # dispatch methods (1 per socket)
263 # dispatch methods (1 per stream)
266 264 #-----------------------------------------------------------------------------
267 265
268 266 def dispatch_register_request(self, msg):
269 267 """"""
270 268 logger.debug("registration::dispatch_register_request(%s)"%msg)
271 269 idents,msg = self.session.feed_identities(msg)
272 270 print idents,msg, len(msg)
273 271 try:
274 272 msg = self.session.unpack_message(msg,content=True)
275 273 except Exception, e:
276 274 logger.error("registration::got bad registration message: %s"%msg)
277 275 raise e
279 277
280 278 msg_type = msg['msg_type']
281 279 content = msg['content']
282 280
283 281 handler = self.registrar_handlers.get(msg_type, None)
284 282 if handler is None:
285 283 logger.error("registration::got bad registration message: %s"%msg)
286 284 else:
287 285 handler(idents, msg)
288 286
289 287 def dispatch_queue_traffic(self, msg):
290 288 """all ME and Task queue messages come through here"""
291 289 logger.debug("queue traffic: %s"%msg[:2])
292 290 switch = msg[0]
293 291 idents, msg = self.session.feed_identities(msg[1:])
294 292 if switch == 'in':
295 293 self.save_queue_request(idents, msg)
296 294 elif switch == 'out':
297 295 self.save_queue_result(idents, msg)
298 296 elif switch == 'intask':
299 297 self.save_task_request(idents, msg)
300 298 elif switch == 'outtask':
301 299 self.save_task_result(idents, msg)
302 300 elif switch == 'tracktask':
303 301 self.save_task_destination(idents, msg)
304 302 elif switch in ('incontrol', 'outcontrol'):
305 303 pass
306 304 else:
307 305 logger.error("Invalid message topic: %s"%switch)
308 306
309 307
310 308 def dispatch_client_msg(self, msg):
311 309 """Route messages from clients"""
312 310 idents, msg = self.session.feed_identities(msg)
313 311 client_id = idents[0]
314 312 try:
315 313 msg = self.session.unpack_message(msg, content=True)
316 314 except:
317 315 content = wrap_exception()
318 316 logger.error("Bad Client Message: %s"%msg)
319 317 self.session.send(self.clientele, "controller_error", ident=client_id,
320 318 content=content)
321 319 return
322 320
323 321 # print client_id, header, parent, content
324 322 #switch on message type:
325 323 msg_type = msg['msg_type']
326 324 logger.info("client:: client %s requested %s"%(client_id, msg_type))
327 325 handler = self.client_handlers.get(msg_type, None)
328 326 try:
329 327 assert handler is not None, "Bad Message Type: %s"%msg_type
330 328 except:
331 329 content = wrap_exception()
332 330 logger.error("Bad Message Type: %s"%msg_type)
333 331 self.session.send(self.clientele, "controller_error", ident=client_id,
334 332 content=content)
335 333 return
336 334 else:
337 335 handler(client_id, msg)
338 336
339 337 def dispatch_db(self, msg):
340 338 """"""
341 339 raise NotImplementedError
342 340
343 341 #---------------------------------------------------------------------------
344 342 # handler methods (1 per event)
345 343 #---------------------------------------------------------------------------
346 344
347 345 #----------------------- Heartbeat --------------------------------------
348 346
349 347 def handle_new_heart(self, heart):
350 348 """handler to attach to heartbeater.
351 349 Called when a new heart starts to beat.
352 350 Triggers completion of registration."""
353 351 logger.debug("heartbeat::handle_new_heart(%r)"%heart)
354 352 if heart not in self.incoming_registrations:
355 353 logger.info("heartbeat::ignoring new heart: %r"%heart)
356 354 else:
357 355 self.finish_registration(heart)
358 356
359 357
360 358 def handle_heart_failure(self, heart):
361 359 """handler to attach to heartbeater.
362 360 called when a previously registered heart fails to respond to beat request.
363 361 triggers unregistration"""
364 362 logger.debug("heartbeat::handle_heart_failure(%r)"%heart)
365 363 eid = self.hearts.get(heart, None)
366 364 if eid is None:
367 365 logger.info("heartbeat::ignoring heart failure %r"%heart)
368 366 else:
369 367 queue = self.engines[eid].queue
370 368 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
371 369
372 370 #----------------------- MUX Queue Traffic ------------------------------
373 371
374 372 def save_queue_request(self, idents, msg):
375 373 queue_id, client_id = idents[:2]
376 374
377 375 try:
378 376 msg = self.session.unpack_message(msg, content=False)
379 377 except:
380 378 logger.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg))
381 379 return
382 380
383 381 eid = self.by_ident.get(queue_id, None)
384 382 if eid is None:
385 383 logger.error("queue::target %r not registered"%queue_id)
386 384 logger.debug("queue:: valid are: %s"%(self.by_ident.keys()))
387 385 return
388 386
389 387 header = msg['header']
390 388 msg_id = header['msg_id']
391 389 info = dict(submit=datetime.now(),
392 390 received=None,
393 391 engine=(eid, queue_id))
394 392 self.pending[msg_id] = ( msg, info )
395 393 self.queues[eid][0].append(msg_id)
396 394
397 395 def save_queue_result(self, idents, msg):
398 396 client_id, queue_id = idents[:2]
399 397
400 398 try:
401 399 msg = self.session.unpack_message(msg, content=False)
402 400 except:
403 401 logger.error("queue::engine %r sent invalid message to %r: %s"%(
404 402 queue_id,client_id, msg))
405 403 return
406 404
407 405 eid = self.by_ident.get(queue_id, None)
408 406 if eid is None:
409 407 logger.error("queue::unknown engine %r is sending a reply: "%queue_id)
410 408 logger.debug("queue:: %s"%msg[2:])
411 409 return
412 410
413 411 parent = msg['parent_header']
414 412 if not parent:
415 413 return
416 414 msg_id = parent['msg_id']
417 415 self.results[msg_id] = msg
418 416 if msg_id in self.pending:
419 417 self.pending.pop(msg_id)
420 418 self.queues[eid][0].remove(msg_id)
421 419 self.completed[eid].append(msg_id)
422 420 else:
423 421 logger.debug("queue:: unknown msg finished %s"%msg_id)
424 422
425 423 #--------------------- Task Queue Traffic ------------------------------
426 424
427 425 def save_task_request(self, idents, msg):
428 426 client_id = idents[0]
429 427
430 428 try:
431 429 msg = self.session.unpack_message(msg, content=False)
432 430 except:
433 431 logger.error("task::client %r sent invalid task message: %s"%(
434 432 client_id, msg))
435 433 return
436 434
437 435 header = msg['header']
438 436 msg_id = header['msg_id']
439 437 self.mia.add(msg_id)
440 438 self.pending[msg_id] = msg
441 439 if not self.tasks.has_key(client_id):
442 440 self.tasks[client_id] = []
443 441 self.tasks[client_id].append(msg_id)
444 442
445 443 def save_task_result(self, idents, msg):
446 444 client_id = idents[0]
447 445 try:
448 446 msg = self.session.unpack_message(msg, content=False)
449 447 except:
450 448 logger.error("task::invalid task result message send to %r: %s"%(
451 449 client_id, msg))
452 450 return
453 451
454 452 parent = msg['parent_header']
455 453 if not parent:
456 454 # print msg
457 455 # logger.warn("")
458 456 return
459 457 msg_id = parent['msg_id']
460 458 self.results[msg_id] = msg
461 459 if msg_id in self.pending:
462 460 self.pending.pop(msg_id)
463 461 if msg_id in self.mia:
464 462 self.mia.remove(msg_id)
465 463 else:
466 logger.debug("task:: unknown task %s finished"%msg_id)
464 logger.debug("task::unknown task %s finished"%msg_id)
467 465
468 466 def save_task_destination(self, idents, msg):
469 467 try:
470 468 msg = self.session.unpack_message(msg, content=True)
471 469 except:
472 470 logger.error("task::invalid task tracking message")
473 471 return
474 472 content = msg['content']
475 473 print content
476 474 msg_id = content['msg_id']
477 475 engine_uuid = content['engine_id']
478 476 for eid,queue_id in self.keytable.iteritems():
479 477 if queue_id == engine_uuid:
480 478 break
481 479
482 logger.info("task:: task %s arrived on %s"%(msg_id, eid))
480 logger.info("task::task %s arrived on %s"%(msg_id, eid))
483 481 if msg_id in self.mia:
484 482 self.mia.remove(msg_id)
485 483 else:
486 484 logger.debug("task::task %s not listed as MIA?!"%(msg_id))
487 485 self.tasks.setdefault(engine_uuid, []).append(msg_id)
488 486
489 487 def mia_task_request(self, idents, msg):
490 488 client_id = idents[0]
491 489 content = dict(mia=self.mia,status='ok')
492 490 self.session.send(self.clientele, 'mia_reply', content=content, ident=client_id)
493 491
494 492
495 493
496 494 #-------------------- Registration -----------------------------
497 495
498 496 def connection_request(self, client_id, msg):
499 497 """reply with connection addresses for clients"""
500 498 logger.info("client::client %s connected"%client_id)
501 499 content = dict(status='ok')
502 500 content.update(self.client_addrs)
503 501 jsonable = {}
504 502 for k,v in self.keytable.iteritems():
505 503 jsonable[str(k)] = v
506 504 content['engines'] = jsonable
507 505 self.session.send(self.registrar, 'connection_reply', content, parent=msg, ident=client_id)
508 506
509 507 def register_engine(self, reg, msg):
510 508 """register an engine"""
511 509 content = msg['content']
512 510 try:
513 511 queue = content['queue']
514 512 except KeyError:
515 513 logger.error("registration::queue not specified")
516 514 return
517 515 heart = content.get('heartbeat', None)
518 516 """register a new engine, and create the socket(s) necessary"""
519 517 eid = self._new_id()
520 518 # print (eid, queue, reg, heart)
521 519
522 520 logger.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
523 521
524 522 content = dict(id=eid,status='ok')
525 523 content.update(self.engine_addrs)
526 524 # check if requesting available IDs:
527 525 if queue in self.by_ident:
528 526 content = {'status': 'error', 'reason': "queue_id %r in use"%queue}
529 527 elif heart in self.hearts: # need to check unique hearts?
530 528 content = {'status': 'error', 'reason': "heart_id %r in use"%heart}
531 529 else:
532 530 for h, pack in self.incoming_registrations.iteritems():
533 531 if heart == h:
534 532 content = {'status': 'error', 'reason': "heart_id %r in use"%heart}
535 533 break
536 534 elif queue == pack[1]:
537 535 content = {'status': 'error', 'reason': "queue_id %r in use"%queue}
538 536 break
539 537
540 538 msg = self.session.send(self.registrar, "registration_reply",
541 539 content=content,
542 540 ident=reg)
543 541
544 542 if content['status'] == 'ok':
545 543 if heart in self.heartbeat.hearts:
546 544 # already beating
547 545 self.incoming_registrations[heart] = (eid,queue,reg,None)
548 546 self.finish_registration(heart)
549 547 else:
550 548 purge = lambda : self._purge_stalled_registration(heart)
551 549 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
552 550 dc.start()
553 551 self.incoming_registrations[heart] = (eid,queue,reg,dc)
554 552 else:
555 553 logger.error("registration::registration %i failed: %s"%(eid, content['reason']))
556 554 return eid
557 555
558 556 def unregister_engine(self, ident, msg):
559 557 try:
560 558 eid = msg['content']['id']
561 559 except:
562 560 logger.error("registration::bad engine id for unregistration: %s"%ident)
563 561 return
564 562 logger.info("registration::unregister_engine(%s)"%eid)
565 563 content=dict(id=eid, queue=self.engines[eid].queue)
566 564 self.ids.remove(eid)
567 565 self.keytable.pop(eid)
568 566 ec = self.engines.pop(eid)
569 567 self.hearts.pop(ec.heartbeat)
570 568 self.by_ident.pop(ec.queue)
571 569 self.completed.pop(eid)
572 570 for msg_id in self.queues.pop(eid)[0]:
573 571 msg = self.pending.pop(msg_id)
574 572 ############## TODO: HANDLE IT ################
575 573
576 574 if self.notifier:
577 575 self.session.send(self.notifier, "unregistration_notification", content=content)
578 576
579 577 def finish_registration(self, heart):
580 578 try:
581 579 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
582 580 except KeyError:
583 581 logger.error("registration::tried to finish nonexistant registration")
584 582 return
585 583 logger.info("registration::finished registering engine %i:%r"%(eid,queue))
586 584 if purge is not None:
587 585 purge.stop()
588 586 control = queue
589 587 self.ids.add(eid)
590 588 self.keytable[eid] = queue
591 589 self.engines[eid] = EngineConnector(eid, queue, reg, control, heart)
592 590 self.by_ident[queue] = eid
593 591 self.queues[eid] = ([],[])
594 592 self.completed[eid] = list()
595 593 self.hearts[heart] = eid
596 594 content = dict(id=eid, queue=self.engines[eid].queue)
597 595 if self.notifier:
598 596 self.session.send(self.notifier, "registration_notification", content=content)
599 597
600 598 def _purge_stalled_registration(self, heart):
601 599 if heart in self.incoming_registrations:
602 600 eid = self.incoming_registrations.pop(heart)[0]
603 601 logger.info("registration::purging stalled registration: %i"%eid)
604 602 else:
605 603 pass
606 604
607 605 #------------------- Client Requests -------------------------------
608 606
609 607 def check_load(self, client_id, msg):
610 608 content = msg['content']
611 609 try:
612 610 targets = content['targets']
613 611 targets = self._validate_targets(targets)
614 612 except:
615 613 content = wrap_exception()
616 614 self.session.send(self.clientele, "controller_error",
617 615 content=content, ident=client_id)
618 616 return
619 617
620 618 content = dict(status='ok')
621 619 # loads = {}
622 620 for t in targets:
623 621 content[str(t)] = len(self.queues[t][0])
624 622 self.session.send(self.clientele, "load_reply", content=content, ident=client_id)
625 623
626 624
627 625 def queue_status(self, client_id, msg):
628 626 """handle queue_status request"""
629 627 content = msg['content']
630 628 targets = content['targets']
631 629 try:
632 630 targets = self._validate_targets(targets)
633 631 except:
634 632 content = wrap_exception()
635 633 self.session.send(self.clientele, "controller_error",
636 634 content=content, ident=client_id)
637 635 return
638 636 verbose = content.get('verbose', False)
639 637 content = dict()
640 638 for t in targets:
641 639 queue = self.queues[t]
642 640 completed = self.completed[t]
643 641 if not verbose:
644 642 queue = len(queue)
645 643 completed = len(completed)
646 644 content[str(t)] = {'queue': queue, 'completed': completed }
647 645 # pending
648 646 self.session.send(self.clientele, "queue_reply", content=content, ident=client_id)
649 647
650 648 def purge_results(self, client_id, msg):
651 649 content = msg['content']
652 650 msg_ids = content.get('msg_ids', [])
653 651 reply = dict(status='ok')
654 652 if msg_ids == 'all':
655 653 self.results = {}
656 654 else:
657 655 for msg_id in msg_ids:
658 656 if msg_id in self.results:
659 657 self.results.pop(msg_id)
660 658 else:
661 659 if msg_id in self.pending:
662 660 reply = dict(status='error', reason="msg pending: %r"%msg_id)
663 661 else:
664 662 reply = dict(status='error', reason="No such msg: %r"%msg_id)
665 663 break
666 664 eids = content.get('engine_ids', [])
667 665 for eid in eids:
668 666 if eid not in self.engines:
669 667 reply = dict(status='error', reason="No such engine: %i"%eid)
670 668 break
671 669 msg_ids = self.completed.pop(eid)
672 670 for msg_id in msg_ids:
673 671 self.results.pop(msg_id)
674 672
675 673 self.session.send(self.clientele, 'purge_reply', content=reply, ident=client_id)
676 674
677 675 def resubmit_task(self, client_id, msg, buffers):
678 676 content = msg['content']
679 677 header = msg['header']
680 678
681 679
682 680 msg_ids = content.get('msg_ids', [])
683 681 reply = dict(status='ok')
684 682 if msg_ids == 'all':
685 683 self.results = {}
686 684 else:
687 685 for msg_id in msg_ids:
688 686 if msg_id in self.results:
689 687 self.results.pop(msg_id)
690 688 else:
691 689 if msg_id in self.pending:
692 690 reply = dict(status='error', reason="msg pending: %r"%msg_id)
693 691 else:
694 692 reply = dict(status='error', reason="No such msg: %r"%msg_id)
695 693 break
696 694 eids = content.get('engine_ids', [])
697 695 for eid in eids:
698 696 if eid not in self.engines:
699 697 reply = dict(status='error', reason="No such engine: %i"%eid)
700 698 break
701 699 msg_ids = self.completed.pop(eid)
702 700 for msg_id in msg_ids:
703 701 self.results.pop(msg_id)
704 702
705 703 self.session.send(self.clientele, 'purge_reply', content=reply, ident=client_id)
706 704
707 705 def get_results(self, client_id, msg):
708 706 """get the result of 1 or more messages"""
709 707 content = msg['content']
710 708 msg_ids = set(content['msg_ids'])
711 709 statusonly = content.get('status_only', False)
712 710 pending = []
713 711 completed = []
714 712 content = dict(status='ok')
715 713 content['pending'] = pending
716 714 content['completed'] = completed
717 715 for msg_id in msg_ids:
718 716 if msg_id in self.pending:
719 717 pending.append(msg_id)
720 718 elif msg_id in self.results:
721 719 completed.append(msg_id)
722 720 if not statusonly:
723 721 content[msg_id] = self.results[msg_id]['content']
724 722 else:
725 723 content = dict(status='error')
726 724 content['reason'] = 'no such message: '+msg_id
727 725 break
728 726 self.session.send(self.clientele, "result_reply", content=content,
729 727 parent=msg, ident=client_id)
730 728
731 729
732 730
733 731 ############ OLD METHODS for Python Relay Controller ###################
734 732 def _validate_engine_msg(self, msg):
735 733 """validates and unpacks headers of a message. Returns False if invalid,
736 734 (ident, message)"""
737 735 ident = msg[0]
738 736 try:
739 737 msg = self.session.unpack_message(msg[1:], content=False)
740 738 except:
741 739 logger.error("engine.%s::Invalid Message %s"%(ident, msg))
742 740 return False
743 741
744 742 try:
745 743 eid = msg.header.username
746 744 assert self.engines.has_key(eid)
747 745 except:
748 746 logger.error("engine::Invalid Engine ID %s"%(ident))
749 747 return False
750 748
751 749 return eid, msg
752 750
753 751
754 752 #--------------------
755 753 # Entry Point
756 754 #--------------------
757 755
758 756 def main():
759 757 import time
760 758 from multiprocessing import Process
761 759
762 760 from zmq.eventloop.zmqstream import ZMQStream
763 761 from zmq.devices import ProcessMonitoredQueue
764 762 from zmq.log import handlers
765 763
766 764 import streamsession as session
767 765 import heartmonitor
768 766 from scheduler import launch_scheduler
769 767
770 768 parser = make_argument_parser()
771 769
772 770 parser.add_argument('--client', type=int, metavar='PORT', default=0,
773 771 help='set the XREP port for clients [default: random]')
774 772 parser.add_argument('--notice', type=int, metavar='PORT', default=0,
775 773 help='set the PUB socket for registration notification [default: random]')
776 774 parser.add_argument('--hb', type=str, metavar='PORTS',
777 775 help='set the 2 ports for heartbeats [default: random]')
778 776 parser.add_argument('--ping', type=int, default=3000,
779 777 help='set the heartbeat period in ms [default: 3000]')
780 778 parser.add_argument('--monitor', type=int, metavar='PORT', default=0,
781 779 help='set the SUB port for queue monitoring [default: random]')
782 780 parser.add_argument('--mux', type=str, metavar='PORTS',
783 781 help='set the XREP ports for the MUX queue [default: random]')
784 782 parser.add_argument('--task', type=str, metavar='PORTS',
785 783 help='set the XREP/XREQ ports for the task queue [default: random]')
786 784 parser.add_argument('--control', type=str, metavar='PORTS',
787 785 help='set the XREP ports for the control queue [default: random]')
788 786 parser.add_argument('--scheduler', type=str, default='pure',
789 787 choices = ['pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'],
790 788 help='select the task scheduler [default: pure ZMQ]')
791 789
792 790 args = parser.parse_args()
793 791
794 792 if args.url:
795 793 args.transport,iface = args.url.split('://')
796 794 iface = iface.split(':')
797 795 args.ip = iface[0]
798 796 if iface[1]:
799 797 args.regport = iface[1]
800 798
801 799 iface="%s://%s"%(args.transport,args.ip)+':%i'
802 800
803 801 random_ports = 0
804 802 if args.hb:
805 803 hb = split_ports(args.hb, 2)
806 804 else:
807 805 hb = select_random_ports(2)
808 806 if args.mux:
809 807 mux = split_ports(args.mux, 2)
810 808 else:
811 809 mux = None
812 810 random_ports += 2
813 811 if args.task:
814 812 task = split_ports(args.task, 2)
815 813 else:
816 814 task = None
817 815 random_ports += 2
818 816 if args.control:
819 817 control = split_ports(args.control, 2)
820 818 else:
821 819 control = None
822 820 random_ports += 2
823 821
824 822 ctx = zmq.Context()
825 823 loop = ioloop.IOLoop.instance()
826 824
827 825 # setup logging
828 826 connect_logger(ctx, iface%args.logport, root="controller", loglevel=args.loglevel)
829 827
830 828 # Registrar socket
831 829 reg = ZMQStream(ctx.socket(zmq.XREP), loop)
832 830 regport = bind_port(reg, args.ip, args.regport)
833 831
834 832 ### Engine connections ###
835 833
836 834 # heartbeat
837 835 hpub = ctx.socket(zmq.PUB)
838 836 bind_port(hpub, args.ip, hb[0])
839 837 hrep = ctx.socket(zmq.XREP)
840 838 bind_port(hrep, args.ip, hb[1])
841 839
842 840 hmon = heartmonitor.HeartMonitor(loop, ZMQStream(hpub,loop), ZMQStream(hrep,loop),args.ping)
843 841 hmon.start()
844 842
845 843 ### Client connections ###
846 844 # Clientele socket
847 845 c = ZMQStream(ctx.socket(zmq.XREP), loop)
848 846 cport = bind_port(c, args.ip, args.client)
849 847 # Notifier socket
850 848 n = ZMQStream(ctx.socket(zmq.PUB), loop)
851 849 nport = bind_port(n, args.ip, args.notice)
852 850
853 851 thesession = session.StreamSession(username=args.ident or "controller")
854 852
855 853 ### build and launch the queues ###
856 854
857 855 # monitor socket
858 856 sub = ctx.socket(zmq.SUB)
859 857 sub.setsockopt(zmq.SUBSCRIBE, "")
860 858 monport = bind_port(sub, args.ip, args.monitor)
861 859 sub = ZMQStream(sub, loop)
862 860
863 861 ports = select_random_ports(random_ports)
864 862 # Multiplexer Queue (in a Process)
865 863 if not mux:
866 864 mux = (ports.pop(),ports.pop())
867 865 q = ProcessMonitoredQueue(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
868 866 q.bind_in(iface%mux[0])
869 867 q.bind_out(iface%mux[1])
870 868 q.connect_mon(iface%monport)
871 869 q.daemon=True
872 870 q.start()
873 871
874 872 # Control Queue (in a Process)
875 873 if not control:
876 874 control = (ports.pop(),ports.pop())
877 875 q = ProcessMonitoredQueue(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
878 876 q.bind_in(iface%control[0])
879 877 q.bind_out(iface%control[1])
880 878 q.connect_mon(iface%monport)
881 879 q.daemon=True
882 880 q.start()
883 881
884 882 # Task Queue (in a Process)
885 883 if not task:
886 884 task = (ports.pop(),ports.pop())
887 885 if args.scheduler == 'pure':
888 886 q = ProcessMonitoredQueue(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
889 887 q.bind_in(iface%task[0])
890 888 q.bind_out(iface%task[1])
891 889 q.connect_mon(iface%monport)
892 890 q.daemon=True
893 891 q.start()
894 892 else:
895 893 sargs = (iface%task[0],iface%task[1],iface%monport,iface%nport,args.scheduler)
896 894 print sargs
897 895 p = Process(target=launch_scheduler, args=sargs)
898 896 p.daemon=True
899 897 p.start()
900 898
901 899 time.sleep(.25)
902 900
903 901 # build connection dicts
904 902 engine_addrs = {
905 903 'control' : iface%control[1],
906 904 'queue': iface%mux[1],
907 905 'heartbeat': (iface%hb[0], iface%hb[1]),
908 906 'task' : iface%task[1],
909 907 'monitor' : iface%monport,
910 908 }
911 909
912 910 client_addrs = {
913 911 'control' : iface%control[0],
914 912 'query': iface%cport,
915 913 'queue': iface%mux[0],
916 914 'task' : iface%task[0],
917 915 'notification': iface%nport
918 916 }
919 917 con = Controller(loop, thesession, sub, reg, hmon, c, n, None, engine_addrs, client_addrs)
920 918 loop.start()
921 919
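Editor's note: the MonitoredQueue devices above relay traffic between two bound sockets while mirroring every message, prefixed with its 'in'/'out' topic, to the monitor socket; a standalone sketch with placeholder ports:

    import zmq
    from zmq.devices import ProcessMonitoredQueue

    q = ProcessMonitoredQueue(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
    q.bind_in('tcp://127.0.0.1:5555')       # clients connect here
    q.bind_out('tcp://127.0.0.1:5556')      # engines connect here
    q.connect_mon('tcp://127.0.0.1:5557')   # controller's SUB socket
    q.daemon = True
    q.start()                               # runs in a child process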
@@ -1,401 +1,404 @@
1 1 #----------------------------------------------------------------------
2 2 # Imports
3 3 #----------------------------------------------------------------------
4 4
5 5 from random import randint,random
6 6
7 7 try:
8 8 import numpy
9 9 except ImportError:
10 10 numpy = None
11 11
12 12 import zmq
13 13 from zmq.eventloop import ioloop, zmqstream
14 14
15 15 # local imports
16 16 from IPython.zmq.log import logger # a Logger object
17 17 from client import Client
18 18 from dependency import Dependency
19 19 import streamsession as ss
20 20
21 21 from IPython.external.decorator import decorator
22 22
23 23 @decorator
24 24 def logged(f,self,*args,**kwargs):
25 25 print ("#--------------------")
26 26 print ("%s(*%s,**%s)"%(f.func_name, args, kwargs))
27 print ("#--")
27 28 return f(self,*args, **kwargs)
28 29
29 30 #----------------------------------------------------------------------
30 31 # Chooser functions
31 32 #----------------------------------------------------------------------
32 33
33 34 def plainrandom(loads):
34 35 """Plain random pick."""
35 36 n = len(loads)
36 37 return randint(0,n-1)
37 38
38 39 def lru(loads):
39 40 """Always pick the front of the line.
40 41
41 42 The content of loads is ignored.
42 43
43 44 Assumes LRU ordering of loads, with oldest first.
44 45 """
45 46 return 0
46 47
47 48 def twobin(loads):
48 49 """Pick two at random, use the LRU of the two.
49 50
50 51 The content of loads is ignored.
51 52
52 53 Assumes LRU ordering of loads, with oldest first.
53 54 """
54 55 n = len(loads)
55 56 a = randint(0,n-1)
56 57 b = randint(0,n-1)
57 58 return min(a,b)
58 59
59 60 def weighted(loads):
60 61 """Pick two at random using inverse load as weight.
61 62
62 63 Return the less loaded of the two.
63 64 """
64 65 # give a load of 0 about a million times the weight of a load of 1:
65 66 weights = 1./(1e-6+numpy.array(loads))
66 67 sums = weights.cumsum()
67 68 t = sums[-1]
68 69 x = random()*t
69 70 y = random()*t
70 71 idx = 0
71 72 idy = 0
72 73 while sums[idx] < x:
73 74 idx += 1
74 75 while sums[idy] < y:
75 76 idy += 1
76 77 if weights[idy] > weights[idx]:
77 78 return idy
78 79 else:
79 80 return idx
80 81
81 82 def leastload(loads):
82 83 """Always choose the lowest load.
83 84
84 85 If the lowest load occurs more than once, the first
85 86 occurrence will be used. If loads has LRU ordering, this means
86 87 the LRU of those with the lowest load is chosen.
87 88 """
88 89 return loads.index(min(loads))
89 90
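A quick illustration of how these choosers behave on the same load snapshot (a hypothetical, LRU-ordered list of per-engine queue depths); each returns an index into loads, and only weighted needs numpy. A sketch, assuming the functions above are in scope:

    loads = [2, 0, 1, 0]            # hypothetical loads, oldest engine first
    print(lru(loads))               # 0: always the head of the LRU line
    print(leastload(loads))         # 1: first occurrence of the minimum load
    print(plainrandom(loads))       # a uniform-random index in [0, 3]
    print(twobin(loads))            # min of two uniform-random indices
    if numpy is not None:
        print(weighted(loads))      # usually one of the idle engines (1 or 3)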
90 91 #---------------------------------------------------------------------
91 92 # Classes
92 93 #---------------------------------------------------------------------
93 94 class TaskScheduler(object):
94 """Simple Python TaskScheduler object.
95 """Python TaskScheduler object.
95 96
96 97 This is the simplest object that supports msg_id based
97 98 DAG dependencies. *Only* task msg_ids are checked, not
98 99 msg_ids of jobs submitted via the MUX queue.
99 100
100 101 """
101 102
102 103 scheme = leastload # function for determining the destination
103 104 client_stream = None # client-facing stream
104 105 engine_stream = None # engine-facing stream
105 106 mon_stream = None # controller-facing stream
106 107 dependencies = None # dict by msg_id of [ msg_ids that depend on key ]
107 108 depending = None # dict by msg_id of (msg_id, raw_msg, after, follow)
108 109 pending = None # dict by engine_uuid of submitted tasks
109 110 completed = None # dict by engine_uuid of completed tasks
110 111 clients = None # dict by msg_id for who submitted the task
111 112 targets = None # list of target IDENTs
112 113 loads = None # list of engine loads
113 114 all_done = None # set of all completed tasks
114 115 blacklist = None # dict by msg_id of locations where a job has encountered UnmetDependency
115 116
116 117
117 118 def __init__(self, client_stream, engine_stream, mon_stream,
118 119 notifier_stream, scheme=None, io_loop=None):
119 120 if io_loop is None:
120 121 io_loop = ioloop.IOLoop.instance()
121 122 self.io_loop = io_loop
122 123 self.client_stream = client_stream
123 124 self.engine_stream = engine_stream
124 125 self.mon_stream = mon_stream
125 126 self.notifier_stream = notifier_stream
126 127
127 128 if scheme is not None:
128 129 self.scheme = scheme
129 130 else:
130 131 self.scheme = TaskScheduler.scheme
131 132
132 133 self.session = ss.StreamSession(username="TaskScheduler")
133 134
134 135 self.dependencies = {}
135 136 self.depending = {}
136 137 self.completed = {}
137 138 self.pending = {}
138 139 self.all_done = set()
140 self.blacklist = {}
139 141
140 142 self.targets = []
141 143 self.loads = []
142 144
143 145 engine_stream.on_recv(self.dispatch_result, copy=False)
144 146 self._notification_handlers = dict(
145 147 registration_notification = self._register_engine,
146 148 unregistration_notification = self._unregister_engine
147 149 )
148 150 self.notifier_stream.on_recv(self.dispatch_notification)
149 151
150 152 def resume_receiving(self):
151 153 """resume accepting jobs"""
152 154 self.client_stream.on_recv(self.dispatch_submission, copy=False)
153 155
154 156 def stop_receiving(self):
155 157 self.client_stream.on_recv(None)
156 158
157 159 #-----------------------------------------------------------------------
158 160 # [Un]Registration Handling
159 161 #-----------------------------------------------------------------------
160 162
161 163 def dispatch_notification(self, msg):
162 164 """dispatch register/unregister events."""
163 165 idents,msg = self.session.feed_identities(msg)
164 166 msg = self.session.unpack_message(msg)
165 167 msg_type = msg['msg_type']
166 168 handler = self._notification_handlers.get(msg_type, None)
167 169 if handler is None:
168 170 raise Exception("Unhandled message type: %s"%msg_type)
169 171 else:
170 172 try:
171 173 handler(str(msg['content']['queue']))
172 174 except KeyError:
173 175 logger.error("task::Invalid notification msg: %s"%msg)
174 176 @logged
175 177 def _register_engine(self, uid):
176 178 """new engine became available"""
177 179 # head of the line:
178 180 self.targets.insert(0,uid)
179 181 self.loads.insert(0,0)
180 182 # initialize sets
181 183 self.completed[uid] = set()
182 184 self.pending[uid] = {}
183 185 if len(self.targets) == 1:
184 186 self.resume_receiving()
185 187
186 188 def _unregister_engine(self, uid):
187 189 """existing engine became unavailable"""
188 190 # handle any potentially finished tasks:
189 191 if len(self.targets) == 1:
190 192 self.stop_receiving()
191 193 self.engine_stream.flush()
192 194
193 195 self.completed.pop(uid)
194 196 lost = self.pending.pop(uid)
195 197
196 198 idx = self.targets.index(uid)
197 199 self.targets.pop(idx)
198 200 self.loads.pop(idx)
199 201
200 202 self.handle_stranded_tasks(lost)
201 203
202 204 def handle_stranded_tasks(self, lost):
203 205 """deal with jobs resident in an engine that died."""
204 206 # TODO: resubmit the tasks?
205 207 for msg_id in lost:
206 208 pass
207 209
208 210
209 211 #-----------------------------------------------------------------------
210 212 # Job Submission
211 213 #-----------------------------------------------------------------------
212 214 @logged
213 215 def dispatch_submission(self, raw_msg):
214 216 """dispatch job submission"""
215 217 # ensure targets up to date:
216 218 self.notifier_stream.flush()
217 219 try:
218 220 idents, msg = self.session.feed_identities(raw_msg, copy=False)
219 221 except Exception, e:
220 222 logger.error("task::Invalid msg: %s"%raw_msg)
221 223 return
222 224
223 225 msg = self.session.unpack_message(msg, content=False, copy=False)
224 print idents,msg
225 226 header = msg['header']
226 227 msg_id = header['msg_id']
227 228 after = Dependency(header.get('after', []))
228 229 if after.mode == 'all':
229 230 after.difference_update(self.all_done)
230 231 if after.check(self.all_done):
231 232 # recast as empty set, if we are already met,
232 233 # to prevent the task from being queued as unmet
233 234 after = Dependency([])
234 235
235 236 follow = Dependency(header.get('follow', []))
236 print raw_msg
237 237 if len(after) == 0:
238 238 # time deps already met, try to run
239 239 if not self.maybe_run(msg_id, raw_msg, follow):
240 240 # can't run yet
241 241 self.save_unmet(msg_id, raw_msg, after, follow)
242 242 else:
243 243 self.save_unmet(msg_id, raw_msg, after, follow)
244 244 # send to monitor
245 245 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
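For reference, the dependency information consumed here travels in the task header; a hypothetical sketch of such a header (msg_ids abbreviated), where 'after' lists time dependencies and 'follow' lists location dependencies:

    header = {
        'msg_id' : 'c4577805...',       # this task's id
        'after'  : ['ae39513c...'],     # run only once these msg_ids complete
        'follow' : ['0be91219...'],     # run only on engines that ran these
    }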
246 246 @logged
247 247 def maybe_run(self, msg_id, raw_msg, follow=None):
248 248 """check location dependencies, and run if they are met."""
249 249
250 250 if follow:
251 251 def can_run(idx):
252 252 target = self.targets[idx]
253 253 return target not in self.blacklist.get(msg_id, []) and\
254 254 follow.check(self.completed[target])
255 255
256 256 indices = filter(can_run, range(len(self.targets)))
257 257 if not indices:
258 258 return False
259 259 else:
260 260 indices = None
261 261
262 262 self.submit_task(msg_id, raw_msg, follow, indices)
263 263 return True
264 264
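The can_run filter above is a per-engine set check; a toy, self-contained sketch with made-up ids, approximating a Dependency in 'all' mode with a plain set (Python 2):

    targets = ['engine-a', 'engine-b']
    completed = {'engine-a': set(['m1']), 'engine-b': set()}
    blacklist = {'m2': set(['engine-b'])}   # engines that bounced task m2
    follow = set(['m1'])

    def can_run(idx):
        target = targets[idx]
        return target not in blacklist.get('m2', []) and \
               follow.issubset(completed[target])

    print(filter(can_run, range(len(targets))))  # [0]: only engine-a qualifies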
265 265 @logged
266 266 def save_unmet(self, msg_id, msg, after, follow):
267 267 """Save a message for later submission when its dependencies are met."""
268 268 self.depending[msg_id] = (msg_id,msg,after,follow)
269 269 # track the ids in both follow/after, but not those already completed
270 270 for dep_id in after.union(follow).difference(self.all_done):
271 print dep_id
271 272 if dep_id not in self.dependencies:
272 273 self.dependencies[dep_id] = set()
273 274 self.dependencies[dep_id].add(msg_id)
275
274 276 @logged
275 277 def submit_task(self, msg_id, msg, follow=None, indices=None):
276 278 """submit a task to any of a subset of our targets"""
277 279 if indices:
278 280 loads = [self.loads[i] for i in indices]
279 281 else:
280 282 loads = self.loads
281 283 idx = self.scheme(loads)
282 284 if indices:
283 285 idx = indices[idx]
284 286 target = self.targets[idx]
285 287 print target, map(str, msg[:3])
286 self.engine_stream.socket.send(target, flags=zmq.SNDMORE, copy=False)
287 self.engine_stream.socket.send_multipart(msg, copy=False)
288 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
289 self.engine_stream.send_multipart(msg, copy=False)
288 290 self.add_job(idx)
289 291 self.pending[target][msg_id] = (msg, follow)
290 292
291 293 #-----------------------------------------------------------------------
292 294 # Result Handling
293 295 #-----------------------------------------------------------------------
294 296 @logged
295 297 def dispatch_result(self, raw_msg):
296 298 try:
297 299 idents,msg = self.session.feed_identities(raw_msg, copy=False)
298 300 except Exception, e:
299 301 logger.error("task::Invalid result: %s"%raw_msg)
300 302 return
301 303 msg = self.session.unpack_message(msg, content=False, copy=False)
302 304 header = msg['header']
303 305 if header.get('dependencies_met', True):
304 306 self.handle_result_success(idents, msg['parent_header'], raw_msg)
305 307 # send to monitor
306 308 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
307 309 else:
308 self.handle_unmet_dependency(self, idents, msg['parent_header'])
310 self.handle_unmet_dependency(idents, msg['parent_header'])
309 311
310 312 @logged
311 313 def handle_result_success(self, idents, parent, raw_msg):
312 314 # first, relay result to client
313 315 engine = idents[0]
314 316 client = idents[1]
315 317 # swap_ids for XREP-XREP mirror
316 318 raw_msg[:2] = [client,engine]
317 319 print map(str, raw_msg[:4])
318 320 self.client_stream.send_multipart(raw_msg, copy=False)
319 321 # now, update our data structures
320 322 msg_id = parent['msg_id']
321 323 self.pending[engine].pop(msg_id)
322 324 self.completed[engine].add(msg_id)
323 325
324 326 self.update_dependencies(msg_id)
325 327
326 328 @logged
327 329 def handle_unmet_dependency(self, idents, parent):
328 330 engine = idents[0]
329 331 msg_id = parent['msg_id']
330 332 if msg_id not in self.blacklist:
331 333 self.blacklist[msg_id] = set()
332 334 self.blacklist[msg_id].add(engine)
333 335 raw_msg,follow = self.pending[engine].pop(msg_id)
334 if not self.maybe_run(raw_msg, follow):
336 if not self.maybe_run(msg_id, raw_msg, follow):
335 337 # resubmit failed, put it back in our dependency tree
336 338 self.save_unmet(msg_id, raw_msg, Dependency(), follow)
337 339
338 340 @logged
339 341 def update_dependencies(self, dep_id):
340 342 """dep_id just finished. Update our dependency
341 343 table and submit any jobs that just became runnable."""
342 344 if dep_id not in self.dependencies:
343 345 return
344 346 jobs = self.dependencies.pop(dep_id)
345 347 for job in jobs:
346 348 msg_id, raw_msg, after, follow = self.depending[job]
347 349 if dep_id in after:
348 350 after.remove(dep_id)
349 351 if not after: # time deps met
350 352 if self.maybe_run(msg_id, raw_msg, follow):
351 353 self.depending.pop(job)
352 354 for mid in follow:
353 self.dependencies[mid].remove(msg_id)
355 if mid in self.dependencies:
356 self.dependencies[mid].remove(msg_id)
354 357
355 358 #----------------------------------------------------------------------
356 359 # methods to be overridden by subclasses
357 360 #----------------------------------------------------------------------
358 361
359 362 def add_job(self, idx):
360 363 """Called after self.targets[idx] just got the job with header.
361 364 Override with subclasses. The default ordering is simple LRU.
362 365 The default loads are the number of outstanding jobs."""
363 366 self.loads[idx] += 1
364 367 for lis in (self.targets, self.loads):
365 368 lis.append(lis.pop(idx))
366 369
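The append/pop rotation in add_job is what maintains the LRU ordering the chooser functions assume: the engine that just received work moves to the back of both parallel lists. A standalone sketch:

    targets = ['a', 'b', 'c']
    loads = [0, 0, 0]
    idx = 0                          # engine 'a' just got a job
    loads[idx] += 1
    for lis in (targets, loads):
        lis.append(lis.pop(idx))
    print(targets)                   # ['b', 'c', 'a']
    print(loads)                     # [0, 0, 1]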
367 370
368 371 def finish_job(self, idx):
369 372 """Called after self.targets[idx] just finished a job.
370 373 Override with subclasses."""
371 374 self.loads[idx] -= 1
372 375
373 376
374 377
375 378 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, scheme='weighted'):
376 379 from zmq.eventloop import ioloop
377 380 from zmq.eventloop.zmqstream import ZMQStream
378 381
379 382 ctx = zmq.Context()
380 383 loop = ioloop.IOLoop()
381 384
382 385 scheme = globals().get(scheme)
383 386
384 387 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
385 388 ins.bind(in_addr)
386 389 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
387 390 outs.bind(out_addr)
388 391 mons = ZMQStream(ctx.socket(zmq.PUB),loop)
389 392 mons.connect(mon_addr)
390 393 nots = ZMQStream(ctx.socket(zmq.SUB),loop)
391 394 nots.setsockopt(zmq.SUBSCRIBE, '')
392 395 nots.connect(not_addr)
393 396
394 397 scheduler = TaskScheduler(ins,outs,mons,nots,scheme,loop)
395 398
396 399 loop.start()
397 400
398 401
399 402 if __name__ == '__main__':
400 403 iface = 'tcp://127.0.0.1:%i'
401 404 launch_scheduler(iface%12345,iface%12346,iface%12347,iface%12348)
@@ -1,498 +1,499 b''
1 1 #!/usr/bin/env python
2 2 """edited session.py to work with streams, and move msg_type to the header
3 3 """
4 4
5 5
6 6 import os
7 7 import sys
8 8 import traceback
9 9 import pprint
10 10 import uuid
11 11
12 12 import zmq
13 13 from zmq.utils import jsonapi
14 14 from zmq.eventloop.zmqstream import ZMQStream
15 15
16 16 from IPython.zmq.pickleutil import can, uncan, canSequence, uncanSequence
17 17 from IPython.zmq.newserialized import serialize, unserialize
18 18
19 19 try:
20 20 import cPickle
21 21 pickle = cPickle
22 22 except ImportError:
23 23 cPickle = None
24 24 import pickle
25 25
26 26 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
27 27 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
28 28 if json_name in ('jsonlib', 'jsonlib2'):
29 29 use_json = True
30 30 elif json_name:
31 31 if cPickle is None:
32 32 use_json = True
33 33 else:
34 34 use_json = False
35 35 else:
36 36 use_json = False
37 37
38 38 def squash_unicode(obj):
39 39 if isinstance(obj,dict):
40 40 for key in obj.keys():
41 41 obj[key] = squash_unicode(obj[key])
42 42 if isinstance(key, unicode):
43 43 obj[squash_unicode(key)] = obj.pop(key)
44 44 elif isinstance(obj, list):
45 45 for i,v in enumerate(obj):
46 46 obj[i] = squash_unicode(v)
47 47 elif isinstance(obj, unicode):
48 48 obj = obj.encode('utf8')
49 49 return obj
50 50
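squash_unicode exists because json decodes everything to unicode while zmq and the rest of this code expect byte strings; it recurses through dicts and lists, re-encoding in place. For example (Python 2):

    d = {u'a': [u'x', {u'b': u'y'}]}
    print(squash_unicode(d))         # {'a': ['x', {'b': 'y'}]}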
51 51 if use_json:
52 52 default_packer = jsonapi.dumps
53 53 default_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
54 54 else:
55 55 default_packer = lambda o: pickle.dumps(o,-1)
56 56 default_unpacker = pickle.loads
57 57
58 58
59 59 DELIM="<IDS|MSG>"
60 60
61 61 def wrap_exception():
62 62 etype, evalue, tb = sys.exc_info()
63 63 tb = traceback.format_exception(etype, evalue, tb)
64 64 exc_content = {
65 65 u'status' : u'error',
66 66 u'traceback' : tb,
67 67 u'etype' : unicode(etype),
68 68 u'evalue' : unicode(evalue)
69 69 }
70 70 return exc_content
71 71
72 72 class KernelError(Exception):
73 73 pass
74 74
75 75 def unwrap_exception(content):
76 76 err = KernelError(content['etype'], content['evalue'])
77 77 err.evalue = content['evalue']
78 78 err.etype = content['etype']
79 79 err.traceback = ''.join(content['traceback'])
80 80 return err
81 81
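wrap_exception and unwrap_exception round-trip an error across the wire as plain message content; a sketch of the pair in use, with the reconstructed KernelError carrying etype/evalue/traceback as attributes:

    try:
        1/0
    except ZeroDivisionError:
        content = wrap_exception()   # {'status': 'error', 'etype': ..., ...}

    err = unwrap_exception(content)
    print(err.etype)                 # the stringified exception type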
82 82
83 83 class Message(object):
84 84 """A simple message object that maps dict keys to attributes.
85 85
86 86 A Message can be created from a dict and a dict from a Message instance
87 87 simply by calling dict(msg_obj)."""
88 88
89 89 def __init__(self, msg_dict):
90 90 dct = self.__dict__
91 91 for k, v in dict(msg_dict).iteritems():
92 92 if isinstance(v, dict):
93 93 v = Message(v)
94 94 dct[k] = v
95 95
96 96 # Having this iterator lets dict(msg_obj) work out of the box.
97 97 def __iter__(self):
98 98 return iter(self.__dict__.iteritems())
99 99
100 100 def __repr__(self):
101 101 return repr(self.__dict__)
102 102
103 103 def __str__(self):
104 104 return pprint.pformat(self.__dict__)
105 105
106 106 def __contains__(self, k):
107 107 return k in self.__dict__
108 108
109 109 def __getitem__(self, k):
110 110 return self.__dict__[k]
111 111
112 112
113 113 def msg_header(msg_id, msg_type, username, session):
114 114 return locals()
115 115 # return {
116 116 # 'msg_id' : msg_id,
117 117 # 'msg_type': msg_type,
118 118 # 'username' : username,
119 119 # 'session' : session
120 120 # }
121 121
122 122
123 123 def extract_header(msg_or_header):
124 124 """Given a message or header, return the header."""
125 125 if not msg_or_header:
126 126 return {}
127 127 try:
128 128 # See if msg_or_header is the entire message.
129 129 h = msg_or_header['header']
130 130 except KeyError:
131 131 try:
132 132 # See if msg_or_header is just the header
133 133 h = msg_or_header['msg_id']
134 134 except KeyError:
135 135 raise
136 136 else:
137 137 h = msg_or_header
138 138 if not isinstance(h, dict):
139 139 h = dict(h)
140 140 return h
141 141
142 142 def rekey(dikt):
143 143 """rekey a dict that has been forced to use str keys where there should be
144 144 ints by json. This belongs in the jsonutil added by fperez."""
145 145 for k in dikt.iterkeys():
146 146 if isinstance(k, str):
147 147 ik=fk=None
148 148 try:
149 149 ik = int(k)
150 150 except ValueError:
151 151 try:
152 152 fk = float(k)
153 153 except ValueError:
154 154 continue
155 155 if ik is not None:
156 156 nk = ik
157 157 else:
158 158 nk = fk
159 159 if nk in dikt:
160 160 raise KeyError("already have key %r"%nk)
161 161 dikt[nk] = dikt.pop(k)
162 162 return dikt
163 163
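rekey in action on a dict that came back from json with stringified numeric keys (values are hypothetical):

    d = {'0': 'engine-a', '1': 'engine-b', 'name': 'cluster'}
    print(rekey(d))    # {0: 'engine-a', 1: 'engine-b', 'name': 'cluster'}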
164 164 def serialize_object(obj, threshold=64e-6):
165 165 """serialize an object into a list of sendable buffers.
166 166
167 167 Returns: (pmd, bufs)
168 168 where pmd is the pickled metadata wrapper, and bufs
169 169 is a list of data buffers"""
170 171 # the default threshold of 64e-6 corresponds to 64 B (sizes are in MB)
171 171 databuffers = []
172 172 if isinstance(obj, (list, tuple)):
173 173 clist = canSequence(obj)
174 174 slist = map(serialize, clist)
175 175 for s in slist:
176 176 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
177 177 databuffers.append(s.getData())
178 178 s.data = None
179 179 return pickle.dumps(slist,-1), databuffers
180 180 elif isinstance(obj, dict):
181 181 sobj = {}
182 182 for k in sorted(obj.iterkeys()):
183 183 s = serialize(can(obj[k]))
184 184 if s.getDataSize() > threshold:
185 185 databuffers.append(s.getData())
186 186 s.data = None
187 187 sobj[k] = s
188 188 return pickle.dumps(sobj,-1),databuffers
189 189 else:
190 190 s = serialize(can(obj))
191 191 if s.getDataSize() > threshold:
192 192 databuffers.append(s.getData())
193 193 s.data = None
194 194 return pickle.dumps(s,-1),databuffers
195 195
196 196
197 197 def unserialize_object(bufs):
198 198 """reconstruct an object serialized by serialize_object from data buffers"""
199 199 bufs = list(bufs)
200 200 sobj = pickle.loads(bufs.pop(0))
201 201 if isinstance(sobj, (list, tuple)):
202 202 for s in sobj:
203 203 if s.data is None:
204 204 s.data = bufs.pop(0)
205 205 return uncanSequence(map(unserialize, sobj))
206 206 elif isinstance(sobj, dict):
207 207 newobj = {}
208 208 for k in sorted(sobj.iterkeys()):
209 209 s = sobj[k]
210 210 if s.data is None:
211 211 s.data = bufs.pop(0)
212 212 newobj[k] = uncan(unserialize(s))
213 213 return newobj
214 214 else:
215 215 if sobj.data is None:
216 216 sobj.data = bufs.pop(0)
217 217 return uncan(unserialize(sobj))
218 218
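A round trip through this pair shows the (pickled metadata, buffers) split: small values travel inside the pickle, while anything over threshold (or any buffer/ndarray) rides as a separate buffer for zero-copy sends. A sketch, assuming the default threshold:

    obj = {'small': 42, 'big': 'x'*1024}
    pmd, bufs = serialize_object(obj)   # the big string typically lands in bufs
    print(unserialize_object([pmd]+bufs) == obj)   # True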
219 219 def pack_apply_message(f, args, kwargs, threshold=64e-6):
220 220 """pack up a function, args, and kwargs to be sent over the wire
221 221 as a series of buffers. Any object whose data is larger than `threshold`
222 222 will not have their data copied (currently only numpy arrays support zero-copy)"""
223 223 msg = [pickle.dumps(can(f),-1)]
224 224 databuffers = [] # for large objects
225 225 sargs, bufs = serialize_object(args,threshold)
226 226 msg.append(sargs)
227 227 databuffers.extend(bufs)
228 228 skwargs, bufs = serialize_object(kwargs,threshold)
229 229 msg.append(skwargs)
230 230 databuffers.extend(bufs)
231 231 msg.extend(databuffers)
232 232 return msg
233 233
234 234 def unpack_apply_message(bufs, g=None, copy=True):
235 235 """unpack f,args,kwargs from buffers packed by pack_apply_message()
236 236 Returns: original f,args,kwargs"""
237 237 bufs = list(bufs) # allow us to pop
238 238 assert len(bufs) >= 3, "not enough buffers!"
239 239 if not copy:
240 240 for i in range(3):
241 241 bufs[i] = bufs[i].bytes
242 242 cf = pickle.loads(bufs.pop(0))
243 243 sargs = list(pickle.loads(bufs.pop(0)))
244 244 skwargs = dict(pickle.loads(bufs.pop(0)))
245 245 # print sargs, skwargs
246 246 f = uncan(cf, g)
247 247 for sa in sargs:
248 248 if sa.data is None:
249 249 m = bufs.pop(0)
250 250 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
251 251 if copy:
252 252 sa.data = buffer(m)
253 253 else:
254 254 sa.data = m.buffer
255 255 else:
256 256 if copy:
257 257 sa.data = m
258 258 else:
259 259 sa.data = m.bytes
260 260
261 261 args = uncanSequence(map(unserialize, sargs), g)
262 262 kwargs = {}
263 263 for k in sorted(skwargs.iterkeys()):
264 264 sa = skwargs[k]
265 265 if sa.data is None:
266 266 sa.data = bufs.pop(0)
267 267 kwargs[k] = uncan(unserialize(sa), g)
268 268
269 269 return f,args,kwargs
270 270
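The apply path uses these two helpers symmetrically: the client packs f/args/kwargs into a flat buffer list, the engine unpacks and calls. A minimal single-process sketch, assuming a picklable, closure-free function:

    def add(a, b=0):
        return a + b

    bufs = pack_apply_message(add, (1,), {'b': 2})
    f, args, kwargs = unpack_apply_message(bufs)
    print(f(*args, **kwargs))        # 3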
271 271 class StreamSession(object):
272 272 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
273 273 debug=False
274 274 def __init__(self, username=None, session=None, packer=None, unpacker=None):
275 275 if username is None:
276 276 username = os.environ.get('USER','username')
277 277 self.username = username
278 278 if session is None:
279 279 self.session = str(uuid.uuid4())
280 280 else:
281 281 self.session = session
282 282 self.msg_id = str(uuid.uuid4())
283 283 if packer is None:
284 284 self.pack = default_packer
285 285 else:
286 286 if not callable(packer):
287 287 raise TypeError("packer must be callable, not %s"%type(packer))
288 288 self.pack = packer
289 289
290 290 if unpacker is None:
291 291 self.unpack = default_unpacker
292 292 else:
293 293 if not callable(unpacker):
294 294 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
295 295 self.unpack = unpacker
296 296
297 297 self.none = self.pack({})
298 298
299 299 def msg_header(self, msg_type):
300 300 h = msg_header(self.msg_id, msg_type, self.username, self.session)
301 301 self.msg_id = str(uuid.uuid4())
302 302 return h
303 303
304 304 def msg(self, msg_type, content=None, parent=None, subheader=None):
305 305 msg = {}
306 306 msg['header'] = self.msg_header(msg_type)
307 307 msg['msg_id'] = msg['header']['msg_id']
308 308 msg['parent_header'] = {} if parent is None else extract_header(parent)
309 309 msg['msg_type'] = msg_type
310 310 msg['content'] = {} if content is None else content
311 311 sub = {} if subheader is None else subheader
312 312 msg['header'].update(sub)
313 313 return msg
314 314
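The resulting message dict has a fixed shape; a sketch of what msg('execute_request', content=dict(code='a=5')) produces, with ids abbreviated (note that msg_type lives in the header, per this module's docstring):

    {'header' : {'msg_id': '...', 'msg_type': 'execute_request',
                 'username': '...', 'session': '...'},
     'msg_id' : '...',               # same as header['msg_id']
     'parent_header' : {},
     'msg_type' : 'execute_request',
     'content' : {'code': 'a=5'}}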
315 315 def send(self, stream, msg_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
316 316 """Build and send a message via stream or socket.
317 317
318 318 Parameters
319 319 ----------
320 320
321 321 msg_type : str or Message/dict
322 322 Normally, msg_type will be the type string for a new message to
323 323 build; if a Message or dict is passed instead, it is treated as
324 324 an already-built message and relayed as-is (content/parent ignored).
325 325
326 326 Returns
327 327 -------
328 328 (msg,sent) : tuple
329 329 msg : Message
330 330 the nice wrapped dict-like object containing the headers
331 331
332 332 """
333 333 if isinstance(msg_type, (Message, dict)):
334 334 # we got a Message, not a msg_type
335 335 # don't build a new Message
336 336 msg = msg_type
337 337 content = msg['content']
338 338 else:
339 339 msg = self.msg(msg_type, content, parent, subheader)
340 340 buffers = [] if buffers is None else buffers
341 341 to_send = []
342 342 if isinstance(ident, list):
343 343 # accept list of idents
344 344 to_send.extend(ident)
345 345 elif ident is not None:
346 346 to_send.append(ident)
347 347 to_send.append(DELIM)
348 348 to_send.append(self.pack(msg['header']))
349 349 to_send.append(self.pack(msg['parent_header']))
350 350 # if parent is None:
351 351 # to_send.append(self.none)
352 352 # else:
353 353 # to_send.append(self.pack(dict(parent)))
354 354 if content is None:
355 355 content = self.none
356 356 elif isinstance(content, dict):
357 357 content = self.pack(content)
358 358 elif isinstance(content, str):
359 359 # content is already packed, as in a relayed message
360 360 pass
361 361 else:
362 362 raise TypeError("Content incorrect type: %s"%type(content))
363 363 to_send.append(content)
364 364 flag = 0
365 365 if buffers:
366 366 flag = zmq.SNDMORE
367 367 stream.send_multipart(to_send, flag, copy=False)
368 368 for b in buffers[:-1]:
369 369 stream.send(b, flag, copy=False)
370 370 if buffers:
371 371 stream.send(buffers[-1], copy=False)
372 372 omsg = Message(msg)
373 373 if self.debug:
374 374 pprint.pprint(omsg)
375 375 pprint.pprint(to_send)
376 376 pprint.pprint(buffers)
377 377 # return both the msg object and the buffers
378 378 return omsg
379 379
380 380 def send_raw(self, stream, msg, flags=0, copy=True, idents=None):
381 381 """send a raw message via idents.
382 382
383 383 Parameters
384 384 ----------
385 385 msg : list of sendable buffers"""
386 386 to_send = []
387 387 if isinstance(idents, str):
388 388 idents = [idents]
389 389 if idents is not None:
390 390 to_send.extend(idents)
391 391 to_send.append(DELIM)
392 392 to_send.extend(msg)
393 393 stream.send_multipart(to_send, flags, copy=copy)
394 394
395 395 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
396 396 """receives and unpacks a message
397 397 returns [idents], msg"""
398 398 if isinstance(socket, ZMQStream):
399 399 socket = socket.socket
400 400 try:
401 401 msg = socket.recv_multipart(mode)
402 402 except zmq.ZMQError, e:
403 403 if e.errno == zmq.EAGAIN:
404 404 # We can convert EAGAIN to None as we know in this case
405 405 # recv_multipart won't return None.
406 406 return None
407 407 else:
408 408 raise
409 409 # return an actual Message object
410 410 # determine the number of idents by trying to unpack them.
411 411 # this is terrible:
412 412 idents, msg = self.feed_identities(msg, copy)
413 413 try:
414 414 return idents, self.unpack_message(msg, content=content, copy=copy)
415 415 except Exception, e:
416 416 print idents, msg
417 417 # TODO: handle it
418 418 raise e
419 419
420 420 def feed_identities(self, msg, copy=True):
421 421 """This is a completely horrible thing, but it strips the zmq
422 422 ident prefixes off of a message. It will break if any identities
423 423 are unpackable by self.unpack."""
424 424 msg = list(msg)
425 425 idents = []
426 426 while len(msg) > 3:
427 427 if copy:
428 428 s = msg[0]
429 429 else:
430 430 s = msg[0].bytes
431 431 if s == DELIM:
432 432 msg.pop(0)
433 433 break
434 434 else:
435 435 idents.append(s)
436 436 msg.pop(0)
437 437
438 438 return idents, msg
439 439
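What feed_identities consumes is the raw multipart frame list: everything before the DELIM sentinel is zmq routing identity, everything after is the message proper (header, parent_header, content, then any data buffers). A sketch with fabricated frames:

    frames = ['engine-uuid', 'client-uuid', DELIM,
              'packed-header', 'packed-parent', 'packed-content']
    # idents, msg = session.feed_identities(frames)
    # idents -> ['engine-uuid', 'client-uuid']
    # msg    -> ['packed-header', 'packed-parent', 'packed-content']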
440 440 def unpack_message(self, msg, content=True, copy=True):
441 """return a message object from the format
441 """Return a message object from the format
442 442 sent by self.send.
443 443
444 parameters:
444 Parameters:
445 -----------
445 446
446 447 content : bool (True)
447 448 whether to unpack the content dict (True),
448 449 or leave it serialized (False)
449 450
450 451 copy : bool (True)
451 452 whether to return the bytes (True),
452 453 or the non-copying Message object in each place (False)
453 454
454 455 """
455 456 if len(msg) < 3:
456 457 raise TypeError("malformed message, must have at least 3 elements")
457 458 message = {}
458 459 if not copy:
459 460 for i in range(3):
460 461 msg[i] = msg[i].bytes
461 462 message['header'] = self.unpack(msg[0])
462 463 message['msg_type'] = message['header']['msg_type']
463 464 message['parent_header'] = self.unpack(msg[1])
464 465 if content:
465 466 message['content'] = self.unpack(msg[2])
466 467 else:
467 468 message['content'] = msg[2]
468 469
469 470 # message['buffers'] = msg[3:]
470 471 # else:
471 472 # message['header'] = self.unpack(msg[0].bytes)
472 473 # message['msg_type'] = message['header']['msg_type']
473 474 # message['parent_header'] = self.unpack(msg[1].bytes)
474 475 # if content:
475 476 # message['content'] = self.unpack(msg[2].bytes)
476 477 # else:
477 478 # message['content'] = msg[2].bytes
478 479
479 480 message['buffers'] = msg[3:]# [ m.buffer for m in msg[3:] ]
480 481 return message
481 482
482 483
483 484
484 485 def test_msg2obj():
485 486 am = dict(x=1)
486 487 ao = Message(am)
487 488 assert ao.x == am['x']
488 489
489 490 am['y'] = dict(z=1)
490 491 ao = Message(am)
491 492 assert ao.y.z == am['y']['z']
492 493
493 494 k1, k2 = 'y', 'z'
494 495 assert ao[k1][k2] == am[k1][k2]
495 496
496 497 am2 = dict(ao)
497 498 assert am['x'] == am2['x']
498 499 assert am['y']['z'] == am2['y']['z']