##// END OF EJS Templates
added exec_key and fixed client.shutdown
MinRK -
Show More
@@ -1,905 +1,924 b''
1 1 """A semi-synchronous Client for the ZMQ controller"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
from __future__ import print_function

import os
import time
from getpass import getpass
from pprint import pprint

import zmq
from zmq.eventloop import ioloop, zmqstream

from IPython.external.decorator import decorator
from IPython.zmq import tunnel

import streamsession as ss
# from remotenamespace import RemoteNamespace
from view import DirectView, LoadBalancedView
from dependency import Dependency, depend, require
28 29
29 30 def _push(ns):
30 31 globals().update(ns)
31 32
32 33 def _pull(keys):
33 34 g = globals()
34 35 if isinstance(keys, (list,tuple, set)):
35 36 for key in keys:
36 37 if not g.has_key(key):
37 38 raise NameError("name '%s' is not defined"%key)
38 39 return map(g.get, keys)
39 40 else:
40 41 if not g.has_key(keys):
41 42 raise NameError("name '%s' is not defined"%keys)
42 43 return g.get(keys)
43 44
44 45 def _clear():
45 46 globals().clear()
46 47
def execute(code):
    """Execute the string `code` in this module's global namespace.

    Executed remotely on engines to implement Client.execute/run.
    """
    # exec(code, globals()) is valid on both Python 2 and 3,
    # unlike the py2-only statement form `exec code in globals()`.
    exec(code, globals())
49 50
50 51 #--------------------------------------------------------------------------
51 52 # Decorators for Client methods
52 53 #--------------------------------------------------------------------------
53 54
@decorator
def spinfirst(func, self, *args, **kwargs):
    """Synchronize client state via spin() before invoking the wrapped method."""
    self.spin()
    result = func(self, *args, **kwargs)
    return result
59 60
@decorator
def defaultblock(f, self, *args, **kwargs):
    """Default to self.block; preserve self.block.

    Temporarily sets self.block from the `block` keyword (falling back to
    the current self.block) for the duration of the call, then restores it.
    """
    block = kwargs.get('block', None)
    block = self.block if block is None else block
    saveblock = self.block
    self.block = block
    try:
        # restore self.block even if f raises; the original leaked the
        # temporary value on exception
        ret = f(self, *args, **kwargs)
    finally:
        self.block = saveblock
    return ret
70 71
def remote(client, bound=False, block=None, targets=None):
    """Decorator factory: turn a function into a RemoteFunction tied to `client`.

    Usage::

        @remote(client, block=True)
        def func(a):
            ...
    """
    def wrap(f):
        return RemoteFunction(client, f, bound, block, targets)
    return wrap
82 83
83 84 #--------------------------------------------------------------------------
84 85 # Classes
85 86 #--------------------------------------------------------------------------
86 87
class RemoteFunction(object):
    """Wrap a plain function so calls are forwarded to a Client for
    remote execution.

    The constructor arguments are stored as attributes and consulted on
    every call.
    """

    def __init__(self, client, f, bound=False, block=None, targets=None):
        self.client = client
        self.func = f
        self.block = block
        self.bound = bound
        self.targets = targets

    def __call__(self, *args, **kwargs):
        """Forward the call to `client.apply`, using the stored options."""
        opts = dict(block=self.block, targets=self.targets, bound=self.bound)
        return self.client.apply(self.func, args=args, kwargs=kwargs, **opts)
100 101
101 102
class AbortedTask(object):
    """Lightweight marker recording the msg_id of a task that was aborted."""
    def __init__(self, msg_id):
        # keep the id so consumers can correlate this with their request
        self.msg_id = msg_id
106 107
class ControllerError(Exception):
    """Exception carrying a remote error's type, value, and traceback text."""
    def __init__(self, etype, evalue, tb):
        self.traceback = tb
        self.evalue = evalue
        self.etype = etype
112 113
class Client(object):
    """A semi-synchronous client to the IPython ZMQ controller
    
    Parameters
    ----------
    
    addr : bytes; zmq url, e.g. 'tcp://127.0.0.1:10101'
        The address of the controller's registration socket.
        [Default: 'tcp://127.0.0.1:10101']
    context : zmq.Context
        Pass an existing zmq.Context instance, otherwise the client will create its own
    username : bytes
        set username to be passed to the Session object
    debug : bool
        flag for lots of message printing for debug purposes
    
    #-------------- ssh related args ----------------
    # These are args for configuring the ssh tunnel to be used
    # credentials are used to forward connections over ssh to the Controller
    # Note that the ip given in `addr` needs to be relative to sshserver
    # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
    # and set sshserver as the same machine the Controller is on. However,
    # the only requirement is that sshserver is able to see the Controller
    # (i.e. is within the same trusted network).
    
    sshserver : str
        A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
        If sshkey or password is specified, and this is not, it will default to
        the ip given in addr.
    sshkey : str; path to public ssh key file
        This specifies a key to be used in ssh login, default None.
        Regular default ssh keys will be used without specifying this argument.
    password : str
        Your ssh password to sshserver. Note that if this is left None,
        you will be prompted for it if passwordless key based login is unavailable.
    
    #------- exec authentication args -------
    # If even localhost is untrusted, you can have some protection against
    # unauthorized execution by using a key. Messages are still sent
    # as cleartext, so if someone can snoop your loopback traffic this will
    # not help anything.
    
    exec_key : str
        an authentication key or file containing a key
        default: None
    
    
    Attributes
    ----------
    ids : set of int engine IDs
        requesting the ids attribute always synchronizes
        the registration state. To request ids without synchronization,
        use the semi-private _ids attribute.
    
    history : list of msg_ids
        a list of msg_ids, keeping track of all the execution
        messages you have submitted in order.
    
    outstanding : set of msg_ids
        a set of msg_ids that have been submitted, but whose
        results have not yet been received.
    
    results : dict
        a dict of all our results, keyed by msg_id
    
    block : bool
        determines default behavior when block not specified
        in execution methods
    
    Methods
    -------
    spin : flushes incoming results and registration state changes
        control methods spin, and requesting `ids` also ensures up to date
    
    barrier : wait on one or more msg_ids
    
    execution methods: apply/apply_bound/apply_to
        legacy: execute, run
    
    query methods: queue_status, get_results, purge_results
    
    control methods: clear, abort, shutdown
    
    """
    
    
    _connected=False        # set True once _connect() has completed successfully
    _ssh=False              # whether connections are tunneled over ssh
    _engines=None           # dict mapping engine id (int) -> uuid (bytes)
    _addr='tcp://127.0.0.1:10101'
    _registration_socket=None
    _query_socket=None
    _control_socket=None
    _notification_socket=None
    _mux_socket=None
    _task_socket=None
    block = False           # default blocking behavior for execution methods
    outstanding=None        # set of msg_ids submitted but not yet returned
    results = None          # cache of results, keyed by msg_id
    history = None          # ordered list of all submitted msg_ids
    debug = False
203 215
204 216 def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False,
205 sshserver=None, keyfile=None, password=None, paramiko=None):
217 sshserver=None, sshkey=None, password=None, paramiko=None,
218 exec_key=None,):
206 219 if context is None:
207 220 context = zmq.Context()
208 221 self.context = context
209 222 self._addr = addr
210 self._ssh = bool(sshserver or keyfile or password)
223 self._ssh = bool(sshserver or sshkey or password)
211 224 if self._ssh and sshserver is None:
212 225 # default to the same
213 226 sshserver = addr.split('://')[1].split(':')[0]
214 227 if self._ssh and password is None:
215 if tunnel.try_passwordless_ssh(sshserver, keyfile, paramiko):
228 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
216 229 password=False
217 230 else:
218 231 password = getpass("SSH Password for %s: "%sshserver)
219 ssh_kwargs = dict(keyfile=keyfile, password=password, paramiko=paramiko)
220
232 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
233
234 if os.path.isfile(exec_key):
235 arg = 'keyfile'
236 else:
237 arg = 'key'
238 key_arg = {arg:exec_key}
221 239 if username is None:
222 self.session = ss.StreamSession()
240 self.session = ss.StreamSession(**key_arg)
223 241 else:
224 self.session = ss.StreamSession(username)
242 self.session = ss.StreamSession(username, **key_arg)
225 243 self._registration_socket = self.context.socket(zmq.XREQ)
226 244 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
227 245 if self._ssh:
228 246 tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
229 247 else:
230 248 self._registration_socket.connect(addr)
231 249 self._engines = {}
232 250 self._ids = set()
233 251 self.outstanding=set()
234 252 self.results = {}
235 253 self.history = []
236 254 self.debug = debug
237 255 self.session.debug = debug
238 256
239 257 self._notification_handlers = {'registration_notification' : self._register_engine,
240 258 'unregistration_notification' : self._unregister_engine,
241 259 }
242 260 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
243 261 'apply_reply' : self._handle_apply_reply}
244 262 self._connect(sshserver, ssh_kwargs)
245 263
246 264
247 265 @property
248 266 def ids(self):
249 267 """Always up to date ids property."""
250 268 self._flush_notifications()
251 269 return self._ids
252 270
253 271 def _update_engines(self, engines):
254 272 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
255 273 for k,v in engines.iteritems():
256 274 eid = int(k)
257 275 self._engines[eid] = bytes(v) # force not unicode
258 276 self._ids.add(eid)
259 277
260 278 def _build_targets(self, targets):
261 279 """Turn valid target IDs or 'all' into two lists:
262 280 (int_ids, uuids).
263 281 """
264 282 if targets is None:
265 283 targets = self._ids
266 284 elif isinstance(targets, str):
267 285 if targets.lower() == 'all':
268 286 targets = self._ids
269 287 else:
270 288 raise TypeError("%r not valid str target, must be 'all'"%(targets))
271 289 elif isinstance(targets, int):
272 290 targets = [targets]
273 291 return [self._engines[t] for t in targets], list(targets)
274 292
    def _connect(self, sshserver, ssh_kwargs):
        """setup all our socket connections to the controller. This is called from
        __init__."""
        # idempotent: only the first call does any work
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, addr):
            # route through an ssh tunnel when configured, else connect directly
            if self._ssh:
                return tunnel.tunnel_connection(s, addr, sshserver, **ssh_kwargs)
            else:
                return s.connect(addr)

        # ask the controller which channels exist and where they live
        self.session.send(self._registration_socket, 'connection_request')
        idents,msg = self.session.recv(self._registration_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = ss.Message(msg)
        content = msg.content
        if content.status == 'ok':
            # each reply field holds the address of an optional channel;
            # only connect the ones the controller advertises
            if content.queue:
                self._mux_socket = self.context.socket(zmq.PAIR)
                self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._mux_socket, content.queue)
            if content.task:
                self._task_socket = self.context.socket(zmq.PAIR)
                self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._task_socket, content.task)
            if content.notification:
                # SUB socket subscribed to everything: engine (un)registration events
                self._notification_socket = self.context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
            if content.query:
                self._query_socket = self.context.socket(zmq.PAIR)
                self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self.context.socket(zmq.PAIR)
                self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._control_socket, content.control)
            # seed our engine table with the currently-registered engines
            self._update_engines(dict(content.engines))

        else:
            self._connected = False
            raise Exception("Failed to connect!")
320 338
321 339 #--------------------------------------------------------------------------
322 340 # handlers and callbacks for incoming messages
323 341 #--------------------------------------------------------------------------
324 342
325 343 def _register_engine(self, msg):
326 344 """Register a new engine, and update our connection info."""
327 345 content = msg['content']
328 346 eid = content['id']
329 347 d = {eid : content['queue']}
330 348 self._update_engines(d)
331 349 self._ids.add(int(eid))
332 350
333 351 def _unregister_engine(self, msg):
334 352 """Unregister an engine that has died."""
335 353 content = msg['content']
336 354 eid = int(content['id'])
337 355 if eid in self._ids:
338 356 self._ids.remove(eid)
339 357 self._engines.pop(eid)
340 358
341 359 def _handle_execute_reply(self, msg):
342 360 """Save the reply to an execute_request into our results."""
343 361 parent = msg['parent_header']
344 362 msg_id = parent['msg_id']
345 363 if msg_id not in self.outstanding:
346 364 print("got unknown result: %s"%msg_id)
347 365 else:
348 366 self.outstanding.remove(msg_id)
349 367 self.results[msg_id] = ss.unwrap_exception(msg['content'])
350 368
    def _handle_apply_reply(self, msg):
        """Save the reply to an apply_request into our results."""
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
            content = msg['content']
            # dispatch on the reply status: deserialize on success,
            # wrap aborted tasks, re-wrap remote exceptions
            if content['status'] == 'ok':
                self.results[msg_id] = ss.unserialize_object(msg['buffers'])
            elif content['status'] == 'aborted':
                self.results[msg_id] = AbortedTask(msg_id)
            elif content['status'] == 'resubmitted':
                # TODO: handle resubmission
                pass
            else:
                self.results[msg_id] = ss.unwrap_exception(content)
369 387
370 388 def _flush_notifications(self):
371 389 """Flush notifications of engine registrations waiting
372 390 in ZMQ queue."""
373 391 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
374 392 while msg is not None:
375 393 if self.debug:
376 394 pprint(msg)
377 395 msg = msg[-1]
378 396 msg_type = msg['msg_type']
379 397 handler = self._notification_handlers.get(msg_type, None)
380 398 if handler is None:
381 399 raise Exception("Unhandled message type: %s"%msg.msg_type)
382 400 else:
383 401 handler(msg)
384 402 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
385 403
386 404 def _flush_results(self, sock):
387 405 """Flush task or queue results waiting in ZMQ queue."""
388 406 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
389 407 while msg is not None:
390 408 if self.debug:
391 409 pprint(msg)
392 410 msg = msg[-1]
393 411 msg_type = msg['msg_type']
394 412 handler = self._queue_handlers.get(msg_type, None)
395 413 if handler is None:
396 414 raise Exception("Unhandled message type: %s"%msg.msg_type)
397 415 else:
398 416 handler(msg)
399 417 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
400 418
401 419 def _flush_control(self, sock):
402 420 """Flush replies from the control channel waiting
403 421 in the ZMQ queue.
404 422
405 423 Currently: ignore them."""
406 424 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
407 425 while msg is not None:
408 426 if self.debug:
409 427 pprint(msg)
410 428 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
411 429
412 430 #--------------------------------------------------------------------------
413 431 # getitem
414 432 #--------------------------------------------------------------------------
415 433
    def __getitem__(self, key):
        """Dict access returns DirectView multiplexer objects or,
        if key is None, a LoadBalancedView."""
        if key is None:
            return LoadBalancedView(self)
        if isinstance(key, int):
            if key not in self.ids:
                raise IndexError("No such engine: %i"%key)
            return DirectView(self, key)

        if isinstance(key, slice):
            # translate the slice over positions into a list of actual engine
            # ids, then fall through to the list branch below
            indices = range(len(self.ids))[key]
            ids = sorted(self._ids)
            key = [ ids[i] for i in indices ]

        if isinstance(key, (tuple, list, xrange)):
            _,targets = self._build_targets(list(key))
            return DirectView(self, targets)
        else:
            raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
437 455
438 456 #--------------------------------------------------------------------------
439 457 # Begin public methods
440 458 #--------------------------------------------------------------------------
441 459
442 460 def spin(self):
443 461 """Flush any registration notifications and execution results
444 462 waiting in the ZMQ queue.
445 463 """
446 464 if self._notification_socket:
447 465 self._flush_notifications()
448 466 if self._mux_socket:
449 467 self._flush_results(self._mux_socket)
450 468 if self._task_socket:
451 469 self._flush_results(self._task_socket)
452 470 if self._control_socket:
453 471 self._flush_control(self._control_socket)
454 472
    def barrier(self, msg_ids=None, timeout=-1):
        """waits on one or more `msg_ids`, for up to `timeout` seconds.

        Parameters
        ----------
        msg_ids : int, str, or list of ints and/or strs
            ints are indices to self.history
            strs are msg_ids
            default: wait on all outstanding messages
        timeout : float
            a time in seconds, after which to give up.
            default is -1, which means no timeout

        Returns
        -------
        True : when all msg_ids are done
        False : timeout reached, some msg_ids still outstanding
        """
        tic = time.time()
        if msg_ids is None:
            theids = self.outstanding
        else:
            if isinstance(msg_ids, (int, str)):
                msg_ids = [msg_ids]
            # normalize: history indices (ints) become their msg_id strings
            theids = set()
            for msg_id in msg_ids:
                if isinstance(msg_id, int):
                    msg_id = self.history[msg_id]
                theids.add(msg_id)
        self.spin()
        # poll at ~1ms granularity until done or timed out
        while theids.intersection(self.outstanding):
            if timeout >= 0 and ( time.time()-tic ) > timeout:
                break
            time.sleep(1e-3)
            self.spin()
        return len(theids.intersection(self.outstanding)) == 0
491 509
492 510 #--------------------------------------------------------------------------
493 511 # Control methods
494 512 #--------------------------------------------------------------------------
495 513
496 514 @spinfirst
497 515 @defaultblock
498 516 def clear(self, targets=None, block=None):
499 517 """Clear the namespace in target(s)."""
500 518 targets = self._build_targets(targets)[0]
501 519 for t in targets:
502 520 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
503 521 error = False
504 522 if self.block:
505 523 for i in range(len(targets)):
506 524 idents,msg = self.session.recv(self._control_socket,0)
507 525 if self.debug:
508 526 pprint(msg)
509 527 if msg['content']['status'] != 'ok':
510 528 error = ss.unwrap_exception(msg['content'])
511 529 if error:
512 530 return error
513 531
514 532
515 533 @spinfirst
516 534 @defaultblock
517 535 def abort(self, msg_ids = None, targets=None, block=None):
518 536 """Abort the execution queues of target(s)."""
519 537 targets = self._build_targets(targets)[0]
520 538 if isinstance(msg_ids, basestring):
521 539 msg_ids = [msg_ids]
522 540 content = dict(msg_ids=msg_ids)
523 541 for t in targets:
524 542 self.session.send(self._control_socket, 'abort_request',
525 543 content=content, ident=t)
526 544 error = False
527 545 if self.block:
528 546 for i in range(len(targets)):
529 547 idents,msg = self.session.recv(self._control_socket,0)
530 548 if self.debug:
531 549 pprint(msg)
532 550 if msg['content']['status'] != 'ok':
533 551 error = ss.unwrap_exception(msg['content'])
534 552 if error:
535 553 return error
536 554
537 555 @spinfirst
538 556 @defaultblock
539 def kill(self, targets=None, block=None):
557 def shutdown(self, targets=None, restart=False, block=None):
540 558 """Terminates one or more engine processes."""
541 559 targets = self._build_targets(targets)[0]
542 560 for t in targets:
543 self.session.send(self._control_socket, 'kill_request', content={},ident=t)
561 self.session.send(self._control_socket, 'shutdown_request',
562 content={'restart':restart},ident=t)
544 563 error = False
545 564 if self.block:
546 565 for i in range(len(targets)):
547 566 idents,msg = self.session.recv(self._control_socket,0)
548 567 if self.debug:
549 568 pprint(msg)
550 569 if msg['content']['status'] != 'ok':
551 570 error = ss.unwrap_exception(msg['content'])
552 571 if error:
553 572 return error
554 573
555 574 #--------------------------------------------------------------------------
556 575 # Execution methods
557 576 #--------------------------------------------------------------------------
558 577
    @defaultblock
    def execute(self, code, targets='all', block=None):
        """Executes `code` on `targets` in blocking or nonblocking manner.

        Parameters
        ----------
        code : str
            the code string to be executed
        targets : int/str/list of ints/strs
            the engines on which to execute
            default : all
        block : bool
            whether or not to wait until done to return
            default: self.block
        """
        # blocking is handled by the @defaultblock decorator; delegate to
        # apply with the module-level `execute` helper run in the engine namespace
        result = self.apply(execute, (code,), targets=targets, block=block, bound=True)
        return result
580 599
581 600 def run(self, code, block=None):
582 601 """Runs `code` on an engine.
583 602
584 603 Calls to this are load-balanced.
585 604
586 605 Parameters
587 606 ----------
588 607 code : str
589 608 the code string to be executed
590 609 block : bool
591 610 whether or not to wait until done
592 611
593 612 """
594 613 result = self.apply(execute, (code,), targets=None, block=block, bound=False)
595 614 return result
596 615
    def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
                        after=None, follow=None):
        """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.

        This is the central execution command for the client.

        Parameters
        ----------

        f : function
            The fuction to be called remotely
        args : tuple/list
            The positional arguments passed to `f`
        kwargs : dict
            The keyword arguments passed to `f`
        bound : bool (default: True)
            Whether to execute in the Engine(s) namespace, or in a clean
            namespace not affecting the engine.
        block : bool (default: self.block)
            Whether to wait for the result, or return immediately.
            False:
                returns msg_id(s)
                if multiple targets:
                    list of ids
            True:
                returns actual result(s) of f(*args, **kwargs)
                if multiple targets:
                    dict of results, by engine ID
        targets : int,list of ints, 'all', None
            Specify the destination of the job.
            if None:
                Submit via Task queue for load-balancing.
            if 'all':
                Run on all active engines
            if list:
                Run on each specified engine
            if int:
                Run on single engine

        after : Dependency or collection of msg_ids
            Only for load-balanced execution (targets=None)
            Specify a list of msg_ids as a time-based dependency.
            This job will only be run *after* the dependencies
            have been met.

        follow : Dependency or collection of msg_ids
            Only for load-balanced execution (targets=None)
            Specify a list of msg_ids as a location-based dependency.
            This job will only be run on an engine where this dependency
            is met.

        Returns
        -------
        if block is False:
            if single target:
                return msg_id
            else:
                return list of msg_ids
                ? (should this be dict like block=True) ?
        else:
            if single target:
                return result of f(*args, **kwargs)
            else:
                return dict of results, keyed by engine
        """

        # defaults:
        block = block if block is not None else self.block
        args = args if args is not None else []
        kwargs = kwargs if kwargs is not None else {}

        # enforce types of f, args, kwargs before serializing anything
        if not callable(f):
            raise TypeError("f must be callable, not %s"%type(f))
        if not isinstance(args, (tuple, list)):
            raise TypeError("args must be tuple or list, not %s"%type(args))
        if not isinstance(kwargs, dict):
            raise TypeError("kwargs must be dict, not %s"%type(kwargs))

        options = dict(bound=bound, block=block, after=after, follow=follow)

        # targets=None means load-balanced via the task queue;
        # anything else goes directly to specific engines via the MUX queue
        if targets is None:
            return self._apply_balanced(f, args, kwargs, **options)
        else:
            return self._apply_direct(f, args, kwargs, targets=targets, **options)
682 701
    def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
                            after=None, follow=None):
        """The underlying method for applying functions in a load balanced
        manner, via the task queue."""
        # normalize dependencies: Dependency objects serialize to dicts,
        # None becomes an empty (always-met) dependency
        if isinstance(after, Dependency):
            after = after.as_dict()
        elif after is None:
            after = []
        if isinstance(follow, Dependency):
            follow = follow.as_dict()
        elif follow is None:
            follow = []
        subheader = dict(after=after, follow=follow)

        bufs = ss.pack_apply_message(f,args,kwargs)
        content = dict(bound=bound)
        msg = self.session.send(self._task_socket, "apply_request",
                content=content, buffers=bufs, subheader=subheader)
        msg_id = msg['msg_id']
        # track the submission before optionally blocking on it
        self.outstanding.add(msg_id)
        self.history.append(msg_id)
        if block:
            self.barrier(msg_id)
            return self.results[msg_id]
        else:
            return msg_id
709 728
    def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
                            after=None, follow=None):
        """The underlying method for applying functions to specific engines
        via the MUX queue."""

        queues,targets = self._build_targets(targets)
        bufs = ss.pack_apply_message(f,args,kwargs)
        # normalize dependencies, as in _apply_balanced
        if isinstance(after, Dependency):
            after = after.as_dict()
        elif after is None:
            after = []
        if isinstance(follow, Dependency):
            follow = follow.as_dict()
        elif follow is None:
            follow = []
        subheader = dict(after=after, follow=follow)
        content = dict(bound=bound)
        msg_ids = []
        # one request per target engine, addressed by queue identity
        for queue in queues:
            msg = self.session.send(self._mux_socket, "apply_request",
                    content=content, buffers=bufs,ident=queue, subheader=subheader)
            msg_id = msg['msg_id']
            self.outstanding.add(msg_id)
            self.history.append(msg_id)
            msg_ids.append(msg_id)
        if block:
            self.barrier(msg_ids)
        else:
            # non-blocking: hand back the id(s) immediately
            if len(msg_ids) == 1:
                return msg_ids[0]
            else:
                return msg_ids
        # blocking path: single result for one target, dict by engine id otherwise
        if len(msg_ids) == 1:
            return self.results[msg_ids[0]]
        else:
            result = {}
            for target,mid in zip(targets, msg_ids):
                result[target] = self.results[mid]
            return result
749 768
750 769 #--------------------------------------------------------------------------
751 770 # Data movement
752 771 #--------------------------------------------------------------------------
753 772
754 773 @defaultblock
755 774 def push(self, ns, targets=None, block=None):
756 775 """Push the contents of `ns` into the namespace on `target`"""
757 776 if not isinstance(ns, dict):
758 777 raise TypeError("Must be a dict, not %s"%type(ns))
759 778 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True)
760 779 return result
761 780
762 781 @defaultblock
763 782 def pull(self, keys, targets=None, block=True):
764 783 """Pull objects from `target`'s namespace by `keys`"""
765 784 if isinstance(keys, str):
766 785 pass
767 786 elif isistance(keys, (list,tuple,set)):
768 787 for key in keys:
769 788 if not isinstance(key, str):
770 789 raise TypeError
771 790 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
772 791 return result
773 792
774 793 #--------------------------------------------------------------------------
775 794 # Query methods
776 795 #--------------------------------------------------------------------------
777 796
778 797 @spinfirst
779 798 def get_results(self, msg_ids, status_only=False):
780 799 """Returns the result of the execute or task request with `msg_ids`.
781 800
782 801 Parameters
783 802 ----------
784 803 msg_ids : list of ints or msg_ids
785 804 if int:
786 805 Passed as index to self.history for convenience.
787 806 status_only : bool (default: False)
788 807 if False:
789 808 return the actual results
790 809 """
791 810 if not isinstance(msg_ids, (list,tuple)):
792 811 msg_ids = [msg_ids]
793 812 theids = []
794 813 for msg_id in msg_ids:
795 814 if isinstance(msg_id, int):
796 815 msg_id = self.history[msg_id]
797 816 if not isinstance(msg_id, str):
798 817 raise TypeError("msg_ids must be str, not %r"%msg_id)
799 818 theids.append(msg_id)
800 819
801 820 completed = []
802 821 local_results = {}
803 822 for msg_id in list(theids):
804 823 if msg_id in self.results:
805 824 completed.append(msg_id)
806 825 local_results[msg_id] = self.results[msg_id]
807 826 theids.remove(msg_id)
808 827
809 828 if theids: # some not locally cached
810 829 content = dict(msg_ids=theids, status_only=status_only)
811 830 msg = self.session.send(self._query_socket, "result_request", content=content)
812 831 zmq.select([self._query_socket], [], [])
813 832 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
814 833 if self.debug:
815 834 pprint(msg)
816 835 content = msg['content']
817 836 if content['status'] != 'ok':
818 837 raise ss.unwrap_exception(content)
819 838 else:
820 839 content = dict(completed=[],pending=[])
821 840 if not status_only:
822 841 # load cached results into result:
823 842 content['completed'].extend(completed)
824 843 content.update(local_results)
825 844 # update cache with results:
826 845 for msg_id in msg_ids:
827 846 if msg_id in content['completed']:
828 847 self.results[msg_id] = content[msg_id]
829 848 return content
830 849
831 850 @spinfirst
832 851 def queue_status(self, targets=None, verbose=False):
833 852 """Fetch the status of engine queues.
834 853
835 854 Parameters
836 855 ----------
837 856 targets : int/str/list of ints/strs
838 857 the engines on which to execute
839 858 default : all
840 859 verbose : bool
841 860 whether to return lengths only, or lists of ids for each element
842 861 """
843 862 targets = self._build_targets(targets)[1]
844 863 content = dict(targets=targets, verbose=verbose)
845 864 self.session.send(self._query_socket, "queue_request", content=content)
846 865 idents,msg = self.session.recv(self._query_socket, 0)
847 866 if self.debug:
848 867 pprint(msg)
849 868 content = msg['content']
850 869 status = content.pop('status')
851 870 if status != 'ok':
852 871 raise ss.unwrap_exception(content)
853 872 return content
854 873
855 874 @spinfirst
856 875 def purge_results(self, msg_ids=[], targets=[]):
857 876 """Tell the controller to forget results.
858 877
859 878 Individual results can be purged by msg_id, or the entire
860 879 history of specific targets can
861 880
862 881 Parameters
863 882 ----------
864 883 targets : int/str/list of ints/strs
865 884 the targets
866 885 default : None
867 886 """
868 887 if not targets and not msg_ids:
869 888 raise ValueError
870 889 if targets:
871 890 targets = self._build_targets(targets)[1]
872 891 content = dict(targets=targets, msg_ids=msg_ids)
873 892 self.session.send(self._query_socket, "purge_request", content=content)
874 893 idents, msg = self.session.recv(self._query_socket, 0)
875 894 if self.debug:
876 895 pprint(msg)
877 896 content = msg['content']
878 897 if content['status'] != 'ok':
879 898 raise ss.unwrap_exception(content)
880 899
class AsynClient(Client):
    """An Asynchronous client, using the Tornado Event Loop.
    !!!unfinished!!!"""
    io_loop = None
    _queue_stream = None
    _notification_stream = None   # was `_notifier_stream`, which nothing ever assigned
    _task_stream = None
    _control_stream = None

    def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
        Client.__init__(self, addr, context, username, debug)
        if io_loop is None:
            io_loop = ioloop.IOLoop.instance()
        self.io_loop = io_loop

        # wrap the base class's sockets in ZMQStreams on the event loop
        self._queue_stream = zmqstream.ZMQStream(self._mux_socket, io_loop)
        self._control_stream = zmqstream.ZMQStream(self._control_socket, io_loop)
        self._task_stream = zmqstream.ZMQStream(self._task_socket, io_loop)
        self._notification_stream = zmqstream.ZMQStream(self._notification_socket, io_loop)

    def spin(self):
        """Flush all streams.

        The previous version referenced non-existent attributes
        (self.queue_stream, self.notifier_stream, ...), so any call
        raised AttributeError; use the underscore-prefixed stream
        attributes actually assigned in __init__.
        """
        for stream in (self._queue_stream, self._notification_stream,
                       self._task_stream, self._control_stream):
            stream.flush()
905 924
@@ -1,953 +1,957 b''
1 1 #!/usr/bin/env python
2 2 """The IPython Controller with 0MQ
3 3 This is the master object that handles connections from engines and clients,
4 4 and monitors traffic through the various queues.
5 5 """
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2010 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16 from __future__ import print_function
17 17
18 import os
18 19 from datetime import datetime
19 20 import logging
20 21
21 22 import zmq
22 23 from zmq.eventloop import zmqstream, ioloop
23 24 import uuid
24 25
25 26 # internal:
26 27 from IPython.zmq.log import logger # a Logger object
27 28 from IPython.zmq.entry_point import bind_port
28 29
29 30 from streamsession import Message, wrap_exception
30 31 from entry_point import (make_base_argument_parser, select_random_ports, split_ports,
31 connect_logger, parse_url, signal_children)
32 connect_logger, parse_url, signal_children, generate_exec_key)
32 33
33 34 #-----------------------------------------------------------------------------
34 35 # Code
35 36 #-----------------------------------------------------------------------------
36 37
37 38 def _passer(*args, **kwargs):
38 39 return
39 40
40 41 class ReverseDict(dict):
41 42 """simple double-keyed subset of dict methods."""
42 43
43 44 def __init__(self, *args, **kwargs):
44 45 dict.__init__(self, *args, **kwargs)
45 46 self.reverse = dict()
46 47 for key, value in self.iteritems():
47 48 self.reverse[value] = key
48 49
49 50 def __getitem__(self, key):
50 51 try:
51 52 return dict.__getitem__(self, key)
52 53 except KeyError:
53 54 return self.reverse[key]
54 55
55 56 def __setitem__(self, key, value):
56 57 if key in self.reverse:
57 58 raise KeyError("Can't have key %r on both sides!"%key)
58 59 dict.__setitem__(self, key, value)
59 60 self.reverse[value] = key
60 61
61 62 def pop(self, key):
62 63 value = dict.pop(self, key)
63 64 self.d1.pop(value)
64 65 return value
65 66
66 67
67 68 class EngineConnector(object):
68 69 """A simple object for accessing the various zmq connections of an object.
69 70 Attributes are:
70 71 id (int): engine ID
71 72 uuid (str): uuid (unused?)
72 73 queue (str): identity of queue's XREQ socket
73 74 registration (str): identity of registration XREQ socket
74 75 heartbeat (str): identity of heartbeat XREQ socket
75 76 """
76 77 id=0
77 78 queue=None
78 79 control=None
79 80 registration=None
80 81 heartbeat=None
81 82 pending=None
82 83
83 84 def __init__(self, id, queue, registration, control, heartbeat=None):
84 85 logger.info("engine::Engine Connected: %i"%id)
85 86 self.id = id
86 87 self.queue = queue
87 88 self.registration = registration
88 89 self.control = control
89 90 self.heartbeat = heartbeat
90 91
91 92 class Controller(object):
92 93 """The IPython Controller with 0MQ connections
93 94
94 95 Parameters
95 96 ==========
96 97 loop: zmq IOLoop instance
97 98 session: StreamSession object
98 99 <removed> context: zmq context for creating new connections (?)
99 100 queue: ZMQStream for monitoring the command queue (SUB)
100 101 registrar: ZMQStream for engine registration requests (XREP)
101 102 heartbeat: HeartMonitor object checking the pulse of the engines
102 103 clientele: ZMQStream for client connections (XREP)
103 104 not used for jobs, only query/control commands
104 105 notifier: ZMQStream for broadcasting engine registration changes (PUB)
105 106 db: connection to db for out of memory logging of commands
106 107 NotImplemented
107 108 engine_addrs: dict of zmq connection information for engines to connect
108 109 to the queues.
109 110 client_addrs: dict of zmq connection information for engines to connect
110 111 to the queues.
111 112 """
112 113 # internal data structures:
113 114 ids=None # engine IDs
114 115 keytable=None
115 116 engines=None
116 117 clients=None
117 118 hearts=None
118 119 pending=None
119 120 results=None
120 121 tasks=None
121 122 completed=None
122 123 mia=None
123 124 incoming_registrations=None
124 125 registration_timeout=None
125 126
126 127 #objects from constructor:
127 128 loop=None
128 129 registrar=None
129 130 clientelle=None
130 131 queue=None
131 132 heartbeat=None
132 133 notifier=None
133 134 db=None
134 135 client_addr=None
135 136 engine_addrs=None
136 137
137 138
138 139 def __init__(self, loop, session, queue, registrar, heartbeat, clientele, notifier, db, engine_addrs, client_addrs):
139 140 """
140 141 # universal:
141 142 loop: IOLoop for creating future connections
142 143 session: streamsession for sending serialized data
143 144 # engine:
144 145 queue: ZMQStream for monitoring queue messages
145 146 registrar: ZMQStream for engine registration
146 147 heartbeat: HeartMonitor object for tracking engines
147 148 # client:
148 149 clientele: ZMQStream for client connections
149 150 # extra:
150 151 db: ZMQStream for db connection (NotImplemented)
151 152 engine_addrs: zmq address/protocol dict for engine connections
152 153 client_addrs: zmq address/protocol dict for client connections
153 154 """
154 155 self.ids = set()
155 156 self.keytable={}
156 157 self.incoming_registrations={}
157 158 self.engines = {}
158 159 self.by_ident = {}
159 160 self.clients = {}
160 161 self.hearts = {}
161 162 self.mia = set()
162 163
163 164 # self.sockets = {}
164 165 self.loop = loop
165 166 self.session = session
166 167 self.registrar = registrar
167 168 self.clientele = clientele
168 169 self.queue = queue
169 170 self.heartbeat = heartbeat
170 171 self.notifier = notifier
171 172 self.db = db
172 173
173 174 # validate connection dicts:
174 175 self.client_addrs = client_addrs
175 176 assert isinstance(client_addrs['queue'], str)
176 177 assert isinstance(client_addrs['control'], str)
177 178 # self.hb_addrs = hb_addrs
178 179 self.engine_addrs = engine_addrs
179 180 assert isinstance(engine_addrs['queue'], str)
180 181 assert isinstance(client_addrs['control'], str)
181 182 assert len(engine_addrs['heartbeat']) == 2
182 183
183 184 # register our callbacks
184 185 self.registrar.on_recv(self.dispatch_register_request)
185 186 self.clientele.on_recv(self.dispatch_client_msg)
186 187 self.queue.on_recv(self.dispatch_queue_traffic)
187 188
188 189 if heartbeat is not None:
189 190 heartbeat.add_heart_failure_handler(self.handle_heart_failure)
190 191 heartbeat.add_new_heart_handler(self.handle_new_heart)
191 192
192 193 self.queue_handlers = { 'in' : self.save_queue_request,
193 194 'out': self.save_queue_result,
194 195 'intask': self.save_task_request,
195 196 'outtask': self.save_task_result,
196 197 'tracktask': self.save_task_destination,
197 198 'incontrol': _passer,
198 199 'outcontrol': _passer,
199 200 }
200 201
201 202 self.client_handlers = {'queue_request': self.queue_status,
202 203 'result_request': self.get_results,
203 204 'purge_request': self.purge_results,
204 205 'load_request': self.check_load,
205 206 'resubmit_request': self.resubmit_task,
206 207 }
207 208
208 209 self.registrar_handlers = {'registration_request' : self.register_engine,
209 210 'unregistration_request' : self.unregister_engine,
210 211 'connection_request': self.connection_request,
211 212 }
212 213 #
213 214 # this is the stuff that will move to DB:
214 215 self.results = {} # completed results
215 216 self.pending = {} # pending messages, keyed by msg_id
216 217 self.queues = {} # pending msg_ids keyed by engine_id
217 218 self.tasks = {} # pending msg_ids submitted as tasks, keyed by client_id
218 219 self.completed = {} # completed msg_ids keyed by engine_id
219 220 self.registration_timeout = max(5000, 2*self.heartbeat.period)
220 221
221 222 logger.info("controller::created controller")
222 223
223 224 def _new_id(self):
224 225 """gemerate a new ID"""
225 226 newid = 0
226 227 incoming = [id[0] for id in self.incoming_registrations.itervalues()]
227 228 # print newid, self.ids, self.incoming_registrations
228 229 while newid in self.ids or newid in incoming:
229 230 newid += 1
230 231 return newid
231 232
232 233 #-----------------------------------------------------------------------------
233 234 # message validation
234 235 #-----------------------------------------------------------------------------
235 236
236 237 def _validate_targets(self, targets):
237 238 """turn any valid targets argument into a list of integer ids"""
238 239 if targets is None:
239 240 # default to all
240 241 targets = self.ids
241 242
242 243 if isinstance(targets, (int,str,unicode)):
243 244 # only one target specified
244 245 targets = [targets]
245 246 _targets = []
246 247 for t in targets:
247 248 # map raw identities to ids
248 249 if isinstance(t, (str,unicode)):
249 250 t = self.by_ident.get(t, t)
250 251 _targets.append(t)
251 252 targets = _targets
252 253 bad_targets = [ t for t in targets if t not in self.ids ]
253 254 if bad_targets:
254 255 raise IndexError("No Such Engine: %r"%bad_targets)
255 256 if not targets:
256 257 raise IndexError("No Engines Registered")
257 258 return targets
258 259
259 260 def _validate_client_msg(self, msg):
260 261 """validates and unpacks headers of a message. Returns False if invalid,
261 262 (ident, header, parent, content)"""
262 263 client_id = msg[0]
263 264 try:
264 265 msg = self.session.unpack_message(msg[1:], content=True)
265 266 except:
266 267 logger.error("client::Invalid Message %s"%msg)
267 268 return False
268 269
269 270 msg_type = msg.get('msg_type', None)
270 271 if msg_type is None:
271 272 return False
272 273 header = msg.get('header')
273 274 # session doesn't handle split content for now:
274 275 return client_id, msg
275 276
276 277
277 278 #-----------------------------------------------------------------------------
278 279 # dispatch methods (1 per stream)
279 280 #-----------------------------------------------------------------------------
280 281
281 282 def dispatch_register_request(self, msg):
282 283 """"""
283 284 logger.debug("registration::dispatch_register_request(%s)"%msg)
284 285 idents,msg = self.session.feed_identities(msg)
285 286 if not idents:
286 logger.error("Bad Queue Message: %s"%msg)
287 logger.error("Bad Queue Message: %s"%msg, exc_info=True)
287 288 return
288 289 try:
289 290 msg = self.session.unpack_message(msg,content=True)
290 except Exception as e:
291 logger.error("registration::got bad registration message: %s"%msg)
292 raise e
291 except:
292 logger.error("registration::got bad registration message: %s"%msg, exc_info=True)
293 293 return
294 294
295 295 msg_type = msg['msg_type']
296 296 content = msg['content']
297 297
298 298 handler = self.registrar_handlers.get(msg_type, None)
299 299 if handler is None:
300 300 logger.error("registration::got bad registration message: %s"%msg)
301 301 else:
302 302 handler(idents, msg)
303 303
304 304 def dispatch_queue_traffic(self, msg):
305 305 """all ME and Task queue messages come through here"""
306 306 logger.debug("queue traffic: %s"%msg[:2])
307 307 switch = msg[0]
308 308 idents, msg = self.session.feed_identities(msg[1:])
309 309 if not idents:
310 310 logger.error("Bad Queue Message: %s"%msg)
311 311 return
312 312 handler = self.queue_handlers.get(switch, None)
313 313 if handler is not None:
314 314 handler(idents, msg)
315 315 else:
316 316 logger.error("Invalid message topic: %s"%switch)
317 317
318 318
319 319 def dispatch_client_msg(self, msg):
320 320 """Route messages from clients"""
321 321 idents, msg = self.session.feed_identities(msg)
322 322 if not idents:
323 323 logger.error("Bad Client Message: %s"%msg)
324 324 return
325 325 try:
326 326 msg = self.session.unpack_message(msg, content=True)
327 327 except:
328 328 content = wrap_exception()
329 logger.error("Bad Client Message: %s"%msg)
329 logger.error("Bad Client Message: %s"%msg, exc_info=True)
330 330 self.session.send(self.clientele, "controller_error", ident=client_id,
331 331 content=content)
332 332 return
333 333
334 334 # print client_id, header, parent, content
335 335 #switch on message type:
336 336 msg_type = msg['msg_type']
337 337 logger.info("client:: client %s requested %s"%(client_id, msg_type))
338 338 handler = self.client_handlers.get(msg_type, None)
339 339 try:
340 340 assert handler is not None, "Bad Message Type: %s"%msg_type
341 341 except:
342 342 content = wrap_exception()
343 logger.error("Bad Message Type: %s"%msg_type)
343 logger.error("Bad Message Type: %s"%msg_type, exc_info=True)
344 344 self.session.send(self.clientele, "controller_error", ident=client_id,
345 345 content=content)
346 346 return
347 347 else:
348 348 handler(client_id, msg)
349 349
350 350 def dispatch_db(self, msg):
351 351 """"""
352 352 raise NotImplementedError
353 353
354 354 #---------------------------------------------------------------------------
355 355 # handler methods (1 per event)
356 356 #---------------------------------------------------------------------------
357 357
358 358 #----------------------- Heartbeat --------------------------------------
359 359
360 360 def handle_new_heart(self, heart):
361 361 """handler to attach to heartbeater.
362 362 Called when a new heart starts to beat.
363 363 Triggers completion of registration."""
364 364 logger.debug("heartbeat::handle_new_heart(%r)"%heart)
365 365 if heart not in self.incoming_registrations:
366 366 logger.info("heartbeat::ignoring new heart: %r"%heart)
367 367 else:
368 368 self.finish_registration(heart)
369 369
370 370
371 371 def handle_heart_failure(self, heart):
372 372 """handler to attach to heartbeater.
373 373 called when a previously registered heart fails to respond to beat request.
374 374 triggers unregistration"""
375 375 logger.debug("heartbeat::handle_heart_failure(%r)"%heart)
376 376 eid = self.hearts.get(heart, None)
377 377 queue = self.engines[eid].queue
378 378 if eid is None:
379 379 logger.info("heartbeat::ignoring heart failure %r"%heart)
380 380 else:
381 381 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
382 382
383 383 #----------------------- MUX Queue Traffic ------------------------------
384 384
385 385 def save_queue_request(self, idents, msg):
386 386 if len(idents) < 2:
387 387 logger.error("invalid identity prefix: %s"%idents)
388 388 return
389 389 queue_id, client_id = idents[:2]
390 390 try:
391 391 msg = self.session.unpack_message(msg, content=False)
392 392 except:
393 logger.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg))
393 logger.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg), exc_info=True)
394 394 return
395 395
396 396 eid = self.by_ident.get(queue_id, None)
397 397 if eid is None:
398 398 logger.error("queue::target %r not registered"%queue_id)
399 399 logger.debug("queue:: valid are: %s"%(self.by_ident.keys()))
400 400 return
401 401
402 402 header = msg['header']
403 403 msg_id = header['msg_id']
404 404 info = dict(submit=datetime.now(),
405 405 received=None,
406 406 engine=(eid, queue_id))
407 407 self.pending[msg_id] = ( msg, info )
408 408 self.queues[eid].append(msg_id)
409 409
410 410 def save_queue_result(self, idents, msg):
411 411 if len(idents) < 2:
412 412 logger.error("invalid identity prefix: %s"%idents)
413 413 return
414 414
415 415 client_id, queue_id = idents[:2]
416 416 try:
417 417 msg = self.session.unpack_message(msg, content=False)
418 418 except:
419 419 logger.error("queue::engine %r sent invalid message to %r: %s"%(
420 queue_id,client_id, msg))
420 queue_id,client_id, msg), exc_info=True)
421 421 return
422 422
423 423 eid = self.by_ident.get(queue_id, None)
424 424 if eid is None:
425 425 logger.error("queue::unknown engine %r is sending a reply: "%queue_id)
426 426 logger.debug("queue:: %s"%msg[2:])
427 427 return
428 428
429 429 parent = msg['parent_header']
430 430 if not parent:
431 431 return
432 432 msg_id = parent['msg_id']
433 433 self.results[msg_id] = msg
434 434 if msg_id in self.pending:
435 435 self.pending.pop(msg_id)
436 436 self.queues[eid].remove(msg_id)
437 437 self.completed[eid].append(msg_id)
438 438 else:
439 439 logger.debug("queue:: unknown msg finished %s"%msg_id)
440 440
441 441 #--------------------- Task Queue Traffic ------------------------------
442 442
443 443 def save_task_request(self, idents, msg):
444 444 """Save the submission of a task."""
445 445 client_id = idents[0]
446 446
447 447 try:
448 448 msg = self.session.unpack_message(msg, content=False)
449 449 except:
450 450 logger.error("task::client %r sent invalid task message: %s"%(
451 client_id, msg))
451 client_id, msg), exc_info=True)
452 452 return
453 453
454 454 header = msg['header']
455 455 msg_id = header['msg_id']
456 456 self.mia.add(msg_id)
457 457 info = dict(submit=datetime.now(),
458 458 received=None,
459 459 engine=None)
460 460 self.pending[msg_id] = (msg, info)
461 461 if not self.tasks.has_key(client_id):
462 462 self.tasks[client_id] = []
463 463 self.tasks[client_id].append(msg_id)
464 464
465 465 def save_task_result(self, idents, msg):
466 466 """save the result of a completed task."""
467 467 client_id = idents[0]
468 468 try:
469 469 msg = self.session.unpack_message(msg, content=False)
470 470 except:
471 471 logger.error("task::invalid task result message send to %r: %s"%(
472 472 client_id, msg))
473 473 return
474 474
475 475 parent = msg['parent_header']
476 476 if not parent:
477 477 # print msg
478 478 logger.warn("Task %r had no parent!"%msg)
479 479 return
480 480 msg_id = parent['msg_id']
481 481 self.results[msg_id] = msg
482 482
483 483 header = msg['header']
484 484 engine_uuid = header.get('engine', None)
485 485 eid = self.by_ident.get(engine_uuid, None)
486 486
487 487 if msg_id in self.pending:
488 488 self.pending.pop(msg_id)
489 489 if msg_id in self.mia:
490 490 self.mia.remove(msg_id)
491 491 if eid is not None and msg_id in self.tasks[eid]:
492 492 self.completed[eid].append(msg_id)
493 493 self.tasks[eid].remove(msg_id)
494 494 else:
495 495 logger.debug("task::unknown task %s finished"%msg_id)
496 496
497 497 def save_task_destination(self, idents, msg):
498 498 try:
499 499 msg = self.session.unpack_message(msg, content=True)
500 500 except:
501 501 logger.error("task::invalid task tracking message")
502 502 return
503 503 content = msg['content']
504 504 print (content)
505 505 msg_id = content['msg_id']
506 506 engine_uuid = content['engine_id']
507 507 eid = self.by_ident[engine_uuid]
508 508
509 509 logger.info("task::task %s arrived on %s"%(msg_id, eid))
510 510 if msg_id in self.mia:
511 511 self.mia.remove(msg_id)
512 512 else:
513 513 logger.debug("task::task %s not listed as MIA?!"%(msg_id))
514 514
515 515 self.tasks[eid].append(msg_id)
516 516 self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
517 517
518 518 def mia_task_request(self, idents, msg):
519 519 client_id = idents[0]
520 520 content = dict(mia=self.mia,status='ok')
521 521 self.session.send('mia_reply', content=content, idents=client_id)
522 522
523 523
524 524
525 525 #-------------------------------------------------------------------------
526 526 # Registration requests
527 527 #-------------------------------------------------------------------------
528 528
529 529 def connection_request(self, client_id, msg):
530 530 """Reply with connection addresses for clients."""
531 531 logger.info("client::client %s connected"%client_id)
532 532 content = dict(status='ok')
533 533 content.update(self.client_addrs)
534 534 jsonable = {}
535 535 for k,v in self.keytable.iteritems():
536 536 jsonable[str(k)] = v
537 537 content['engines'] = jsonable
538 538 self.session.send(self.registrar, 'connection_reply', content, parent=msg, ident=client_id)
539 539
540 540 def register_engine(self, reg, msg):
541 541 """Register a new engine."""
542 542 content = msg['content']
543 543 try:
544 544 queue = content['queue']
545 545 except KeyError:
546 546 logger.error("registration::queue not specified")
547 547 return
548 548 heart = content.get('heartbeat', None)
549 549 """register a new engine, and create the socket(s) necessary"""
550 550 eid = self._new_id()
551 551 # print (eid, queue, reg, heart)
552 552
553 553 logger.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
554 554
555 555 content = dict(id=eid,status='ok')
556 556 content.update(self.engine_addrs)
557 557 # check if requesting available IDs:
558 558 if queue in self.by_ident:
559 559 try:
560 560 raise KeyError("queue_id %r in use"%queue)
561 561 except:
562 562 content = wrap_exception()
563 563 elif heart in self.hearts: # need to check unique hearts?
564 564 try:
565 565 raise KeyError("heart_id %r in use"%heart)
566 566 except:
567 567 content = wrap_exception()
568 568 else:
569 569 for h, pack in self.incoming_registrations.iteritems():
570 570 if heart == h:
571 571 try:
572 572 raise KeyError("heart_id %r in use"%heart)
573 573 except:
574 574 content = wrap_exception()
575 575 break
576 576 elif queue == pack[1]:
577 577 try:
578 578 raise KeyError("queue_id %r in use"%queue)
579 579 except:
580 580 content = wrap_exception()
581 581 break
582 582
583 583 msg = self.session.send(self.registrar, "registration_reply",
584 584 content=content,
585 585 ident=reg)
586 586
587 587 if content['status'] == 'ok':
588 588 if heart in self.heartbeat.hearts:
589 589 # already beating
590 590 self.incoming_registrations[heart] = (eid,queue,reg,None)
591 591 self.finish_registration(heart)
592 592 else:
593 593 purge = lambda : self._purge_stalled_registration(heart)
594 594 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
595 595 dc.start()
596 596 self.incoming_registrations[heart] = (eid,queue,reg,dc)
597 597 else:
598 598 logger.error("registration::registration %i failed: %s"%(eid, content['evalue']))
599 599 return eid
600 600
601 601 def unregister_engine(self, ident, msg):
602 602 """Unregister an engine that explicitly requested to leave."""
603 603 try:
604 604 eid = msg['content']['id']
605 605 except:
606 606 logger.error("registration::bad engine id for unregistration: %s"%ident)
607 607 return
608 608 logger.info("registration::unregister_engine(%s)"%eid)
609 609 content=dict(id=eid, queue=self.engines[eid].queue)
610 610 self.ids.remove(eid)
611 611 self.keytable.pop(eid)
612 612 ec = self.engines.pop(eid)
613 613 self.hearts.pop(ec.heartbeat)
614 614 self.by_ident.pop(ec.queue)
615 615 self.completed.pop(eid)
616 616 for msg_id in self.queues.pop(eid):
617 617 msg = self.pending.pop(msg_id)
618 618 ############## TODO: HANDLE IT ################
619 619
620 620 if self.notifier:
621 621 self.session.send(self.notifier, "unregistration_notification", content=content)
622 622
623 623 def finish_registration(self, heart):
624 624 """Second half of engine registration, called after our HeartMonitor
625 625 has received a beat from the Engine's Heart."""
626 626 try:
627 627 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
628 628 except KeyError:
629 629 logger.error("registration::tried to finish nonexistant registration")
630 630 return
631 631 logger.info("registration::finished registering engine %i:%r"%(eid,queue))
632 632 if purge is not None:
633 633 purge.stop()
634 634 control = queue
635 635 self.ids.add(eid)
636 636 self.keytable[eid] = queue
637 637 self.engines[eid] = EngineConnector(eid, queue, reg, control, heart)
638 638 self.by_ident[queue] = eid
639 639 self.queues[eid] = list()
640 640 self.tasks[eid] = list()
641 641 self.completed[eid] = list()
642 642 self.hearts[heart] = eid
643 643 content = dict(id=eid, queue=self.engines[eid].queue)
644 644 if self.notifier:
645 645 self.session.send(self.notifier, "registration_notification", content=content)
646 646
647 647 def _purge_stalled_registration(self, heart):
648 648 if heart in self.incoming_registrations:
649 649 eid = self.incoming_registrations.pop(heart)[0]
650 650 logger.info("registration::purging stalled registration: %i"%eid)
651 651 else:
652 652 pass
653 653
654 654 #-------------------------------------------------------------------------
655 655 # Client Requests
656 656 #-------------------------------------------------------------------------
657 657
658 658 def check_load(self, client_id, msg):
659 659 content = msg['content']
660 660 try:
661 661 targets = content['targets']
662 662 targets = self._validate_targets(targets)
663 663 except:
664 664 content = wrap_exception()
665 665 self.session.send(self.clientele, "controller_error",
666 666 content=content, ident=client_id)
667 667 return
668 668
669 669 content = dict(status='ok')
670 670 # loads = {}
671 671 for t in targets:
672 672 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
673 673 self.session.send(self.clientele, "load_reply", content=content, ident=client_id)
674 674
675 675
676 676 def queue_status(self, client_id, msg):
677 677 """Return the Queue status of one or more targets.
678 678 if verbose: return the msg_ids
679 679 else: return len of each type.
680 680 keys: queue (pending MUX jobs)
681 681 tasks (pending Task jobs)
682 682 completed (finished jobs from both queues)"""
683 683 content = msg['content']
684 684 targets = content['targets']
685 685 try:
686 686 targets = self._validate_targets(targets)
687 687 except:
688 688 content = wrap_exception()
689 689 self.session.send(self.clientele, "controller_error",
690 690 content=content, ident=client_id)
691 691 return
692 692 verbose = content.get('verbose', False)
693 693 content = dict(status='ok')
694 694 for t in targets:
695 695 queue = self.queues[t]
696 696 completed = self.completed[t]
697 697 tasks = self.tasks[t]
698 698 if not verbose:
699 699 queue = len(queue)
700 700 completed = len(completed)
701 701 tasks = len(tasks)
702 702 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
703 703 # pending
704 704 self.session.send(self.clientele, "queue_reply", content=content, ident=client_id)
705 705
706 706 def purge_results(self, client_id, msg):
707 707 """Purge results from memory. This method is more valuable before we move
708 708 to a DB based message storage mechanism."""
709 709 content = msg['content']
710 710 msg_ids = content.get('msg_ids', [])
711 711 reply = dict(status='ok')
712 712 if msg_ids == 'all':
713 713 self.results = {}
714 714 else:
715 715 for msg_id in msg_ids:
716 716 if msg_id in self.results:
717 717 self.results.pop(msg_id)
718 718 else:
719 719 if msg_id in self.pending:
720 720 try:
721 721 raise IndexError("msg pending: %r"%msg_id)
722 722 except:
723 723 reply = wrap_exception()
724 724 else:
725 725 try:
726 726 raise IndexError("No such msg: %r"%msg_id)
727 727 except:
728 728 reply = wrap_exception()
729 729 break
730 730 eids = content.get('engine_ids', [])
731 731 for eid in eids:
732 732 if eid not in self.engines:
733 733 try:
734 734 raise IndexError("No such engine: %i"%eid)
735 735 except:
736 736 reply = wrap_exception()
737 737 break
738 738 msg_ids = self.completed.pop(eid)
739 739 for msg_id in msg_ids:
740 740 self.results.pop(msg_id)
741 741
742 742 self.sesison.send(self.clientele, 'purge_reply', content=reply, ident=client_id)
743 743
744 744 def resubmit_task(self, client_id, msg, buffers):
745 745 """Resubmit a task."""
746 746 raise NotImplementedError
747 747
748 748 def get_results(self, client_id, msg):
749 749 """Get the result of 1 or more messages."""
750 750 content = msg['content']
751 751 msg_ids = set(content['msg_ids'])
752 752 statusonly = content.get('status_only', False)
753 753 pending = []
754 754 completed = []
755 755 content = dict(status='ok')
756 756 content['pending'] = pending
757 757 content['completed'] = completed
758 758 for msg_id in msg_ids:
759 759 if msg_id in self.pending:
760 760 pending.append(msg_id)
761 761 elif msg_id in self.results:
762 762 completed.append(msg_id)
763 763 if not statusonly:
764 764 content[msg_id] = self.results[msg_id]['content']
765 765 else:
766 766 try:
767 767 raise KeyError('No such message: '+msg_id)
768 768 except:
769 769 content = wrap_exception()
770 770 break
771 771 self.session.send(self.clientele, "result_reply", content=content,
772 772 parent=msg, ident=client_id)
773 773
774 774
775 775 #-------------------------------------------------------------------------
776 776 # Entry Point
777 777 #-------------------------------------------------------------------------
778 778
779 779 def make_argument_parser():
780 780 """Make an argument parser"""
781 781 parser = make_base_argument_parser()
782 782
783 783 parser.add_argument('--client', type=int, metavar='PORT', default=0,
784 784 help='set the XREP port for clients [default: random]')
785 785 parser.add_argument('--notice', type=int, metavar='PORT', default=0,
786 786 help='set the PUB socket for registration notification [default: random]')
787 787 parser.add_argument('--hb', type=str, metavar='PORTS',
788 788 help='set the 2 ports for heartbeats [default: random]')
789 789 parser.add_argument('--ping', type=int, default=3000,
790 790 help='set the heartbeat period in ms [default: 3000]')
791 791 parser.add_argument('--monitor', type=int, metavar='PORT', default=0,
792 792 help='set the SUB port for queue monitoring [default: random]')
793 793 parser.add_argument('--mux', type=str, metavar='PORTS',
794 794 help='set the XREP ports for the MUX queue [default: random]')
795 795 parser.add_argument('--task', type=str, metavar='PORTS',
796 796 help='set the XREP/XREQ ports for the task queue [default: random]')
797 797 parser.add_argument('--control', type=str, metavar='PORTS',
798 798 help='set the XREP ports for the control queue [default: random]')
799 799 parser.add_argument('--scheduler', type=str, default='pure',
800 800 choices = ['pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'],
801 801 help='select the task scheduler [default: pure ZMQ]')
802 802
803 803 return parser
804 804
805 805 def main():
806 806 import time
807 807 from multiprocessing import Process
808 808
809 809 from zmq.eventloop.zmqstream import ZMQStream
810 810 from zmq.devices import ProcessMonitoredQueue
811 811 from zmq.log import handlers
812 812
813 813 import streamsession as session
814 814 import heartmonitor
815 815 from scheduler import launch_scheduler
816 816
817 817 parser = make_argument_parser()
818 818
819 819 args = parser.parse_args()
820 820 parse_url(args)
821 821
822 822 iface="%s://%s"%(args.transport,args.ip)+':%i'
823 823
824 824 random_ports = 0
825 825 if args.hb:
826 826 hb = split_ports(args.hb, 2)
827 827 else:
828 828 hb = select_random_ports(2)
829 829 if args.mux:
830 830 mux = split_ports(args.mux, 2)
831 831 else:
832 832 mux = None
833 833 random_ports += 2
834 834 if args.task:
835 835 task = split_ports(args.task, 2)
836 836 else:
837 837 task = None
838 838 random_ports += 2
839 839 if args.control:
840 840 control = split_ports(args.control, 2)
841 841 else:
842 842 control = None
843 843 random_ports += 2
844 844
845 845 ctx = zmq.Context()
846 846 loop = ioloop.IOLoop.instance()
847 847
848 848 # setup logging
849 849 connect_logger(ctx, iface%args.logport, root="controller", loglevel=args.loglevel)
850 850
851 851 # Registrar socket
852 852 reg = ZMQStream(ctx.socket(zmq.XREP), loop)
853 853 regport = bind_port(reg, args.ip, args.regport)
854 854
855 855 ### Engine connections ###
856 856
857 857 # heartbeat
858 858 hpub = ctx.socket(zmq.PUB)
859 859 bind_port(hpub, args.ip, hb[0])
860 860 hrep = ctx.socket(zmq.XREP)
861 861 bind_port(hrep, args.ip, hb[1])
862 862
863 863 hmon = heartmonitor.HeartMonitor(loop, ZMQStream(hpub,loop), ZMQStream(hrep,loop),args.ping)
864 864 hmon.start()
865 865
866 866 ### Client connections ###
867 867 # Clientele socket
868 868 c = ZMQStream(ctx.socket(zmq.XREP), loop)
869 869 cport = bind_port(c, args.ip, args.client)
870 870 # Notifier socket
871 871 n = ZMQStream(ctx.socket(zmq.PUB), loop)
872 872 nport = bind_port(n, args.ip, args.notice)
873 873
874 thesession = session.StreamSession(username=args.ident or "controller")
874 ### Key File ###
875 if args.execkey and not os.path.isfile(args.execkey):
876 generate_exec_key(args.execkey)
877
878 thesession = session.StreamSession(username=args.ident or "controller", keyfile=args.execkey)
875 879
876 880 ### build and launch the queues ###
877 881
878 882 # monitor socket
879 883 sub = ctx.socket(zmq.SUB)
880 884 sub.setsockopt(zmq.SUBSCRIBE, "")
881 885 monport = bind_port(sub, args.ip, args.monitor)
882 886 sub = ZMQStream(sub, loop)
883 887
884 888 ports = select_random_ports(random_ports)
885 889 children = []
886 890 # Multiplexer Queue (in a Process)
887 891 if not mux:
888 892 mux = (ports.pop(),ports.pop())
889 893 q = ProcessMonitoredQueue(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
890 894 q.bind_in(iface%mux[0])
891 895 q.bind_out(iface%mux[1])
892 896 q.connect_mon(iface%monport)
893 897 q.daemon=True
894 898 q.start()
895 899 children.append(q.launcher)
896 900
897 901 # Control Queue (in a Process)
898 902 if not control:
899 903 control = (ports.pop(),ports.pop())
900 904 q = ProcessMonitoredQueue(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
901 905 q.bind_in(iface%control[0])
902 906 q.bind_out(iface%control[1])
903 907 q.connect_mon(iface%monport)
904 908 q.daemon=True
905 909 q.start()
906 910 children.append(q.launcher)
907 911 # Task Queue (in a Process)
908 912 if not task:
909 913 task = (ports.pop(),ports.pop())
910 914 if args.scheduler == 'pure':
911 915 q = ProcessMonitoredQueue(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
912 916 q.bind_in(iface%task[0])
913 917 q.bind_out(iface%task[1])
914 918 q.connect_mon(iface%monport)
915 919 q.daemon=True
916 920 q.start()
917 921 children.append(q.launcher)
918 922 else:
919 923 sargs = (iface%task[0],iface%task[1],iface%monport,iface%nport,args.scheduler)
920 924 print (sargs)
921 925 q = Process(target=launch_scheduler, args=sargs)
922 926 q.daemon=True
923 927 q.start()
924 928 children.append(q)
925 929
926 930 time.sleep(.25)
927 931
928 932 # build connection dicts
929 933 engine_addrs = {
930 934 'control' : iface%control[1],
931 935 'queue': iface%mux[1],
932 936 'heartbeat': (iface%hb[0], iface%hb[1]),
933 937 'task' : iface%task[1],
934 938 'monitor' : iface%monport,
935 939 }
936 940
937 941 client_addrs = {
938 942 'control' : iface%control[0],
939 943 'query': iface%cport,
940 944 'queue': iface%mux[0],
941 945 'task' : iface%task[0],
942 946 'notification': iface%nport
943 947 }
944 948 signal_children(children)
945 949 con = Controller(loop, thesession, sub, reg, hmon, c, n, None, engine_addrs, client_addrs)
946 950 dc = ioloop.DelayedCallback(lambda : print("Controller started..."), 100, loop)
947 951 loop.start()
948 952
949 953
950 954
951 955
952 956 if __name__ == '__main__':
953 957 main()
@@ -1,132 +1,133 b''
1 1 #!/usr/bin/env python
2 2 """A simple engine that talks to a controller over 0MQ.
3 3 it handles registration, etc. and launches a kernel
4 4 connected to the Controller's queue(s).
5 5 """
6 6 from __future__ import print_function
7 7 import sys
8 8 import time
9 9 import traceback
10 10 import uuid
11 11 from pprint import pprint
12 12
13 13 import zmq
14 14 from zmq.eventloop import ioloop, zmqstream
15 15
16 16 from IPython.utils.traitlets import HasTraits
17 17 from IPython.utils.localinterfaces import LOCALHOST
18 18
19 19 from streamsession import Message, StreamSession
20 20 from client import Client
21 21 from streamkernel import Kernel, make_kernel
22 22 import heartmonitor
23 23 from entry_point import make_base_argument_parser, connect_logger, parse_url
24 24 # import taskthread
25 25 # from log import logger
26 26
27 27
28 28 def printer(*msg):
29 29 pprint(msg)
30 30
31 31 class Engine(object):
32 32 """IPython engine"""
33 33
34 34 id=None
35 35 context=None
36 36 loop=None
37 37 session=None
38 38 ident=None
39 39 registrar=None
40 40 heart=None
41 41 kernel=None
42 42
43 def __init__(self, context, loop, session, registrar, client, ident=None, heart_id=None):
43 def __init__(self, context, loop, session, registrar, client=None, ident=None):
44 44 self.context = context
45 45 self.loop = loop
46 46 self.session = session
47 47 self.registrar = registrar
48 48 self.client = client
49 49 self.ident = ident if ident else str(uuid.uuid4())
50 50 self.registrar.on_send(printer)
51 51
52 52 def register(self):
53 53
54 54 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
55 55 self.registrar.on_recv(self.complete_registration)
56 # print (self.session.key)
56 57 self.session.send(self.registrar, "registration_request",content=content)
57 58
58 59 def complete_registration(self, msg):
59 60 # print msg
60 61 idents,msg = self.session.feed_identities(msg)
61 62 msg = Message(self.session.unpack_message(msg))
62 63 if msg.content.status == 'ok':
63 64 self.session.username = str(msg.content.id)
64 65 queue_addr = msg.content.queue
65 66 shell_addrs = [str(queue_addr)]
66 67 control_addr = str(msg.content.control)
67 68 task_addr = msg.content.task
68 69 if task_addr:
69 70 shell_addrs.append(str(task_addr))
70 71
71 72 hb_addrs = msg.content.heartbeat
72 73 # ioloop.DelayedCallback(self.heart.start, 1000, self.loop).start()
73 74
74 75 # placeholder for no, since pub isn't hooked up:
75 76 sub = self.context.socket(zmq.SUB)
76 77 sub = zmqstream.ZMQStream(sub, self.loop)
77 78 sub.on_recv(lambda *a: None)
78 79 port = sub.bind_to_random_port("tcp://%s"%LOCALHOST)
79 80 iopub_addr = "tcp://%s:%i"%(LOCALHOST,12345)
80
81 81 make_kernel(self.ident, control_addr, shell_addrs, iopub_addr, hb_addrs,
82 client_addr=None, loop=self.loop, context=self.context)
82 client_addr=None, loop=self.loop, context=self.context, key=self.session.key)
83 83
84 84 else:
85 85 # logger.error("Registration Failed: %s"%msg)
86 86 raise Exception("Registration Failed: %s"%msg)
87 87
88 88 # logger.info("engine::completed registration with id %s"%self.session.username)
89 89
90 90 print (msg)
91 91
92 92 def unregister(self):
93 93 self.session.send(self.registrar, "unregistration_request", content=dict(id=int(self.session.username)))
94 94 time.sleep(1)
95 95 sys.exit(0)
96 96
97 97 def start(self):
98 98 print ("registering")
99 99 self.register()
100 100
101 101
102 102
103 103 def main():
104 104
105 105 parser = make_base_argument_parser()
106 106
107 107 args = parser.parse_args()
108 108
109 109 parse_url(args)
110 110
111 111 iface="%s://%s"%(args.transport,args.ip)+':%i'
112 112
113 113 loop = ioloop.IOLoop.instance()
114 session = StreamSession()
114 session = StreamSession(keyfile=args.execkey)
115 # print (session.key)
115 116 ctx = zmq.Context()
116 117
117 118 # setup logging
118 119 connect_logger(ctx, iface%args.logport, root="engine", loglevel=args.loglevel)
119 120
120 121 reg_conn = iface % args.regport
121 122 print (reg_conn)
122 123 print ("Starting the engine...", file=sys.__stderr__)
123 124
124 125 reg = ctx.socket(zmq.PAIR)
125 126 reg.connect(reg_conn)
126 127 reg = zmqstream.ZMQStream(reg, loop)
127 client = Client(reg_conn)
128 client = None
128 129
129 130 e = Engine(ctx, loop, session, reg, client, args.ident)
130 131 dc = ioloop.DelayedCallback(e.start, 100, loop)
131 132 dc.start()
132 133 loop.start() No newline at end of file
@@ -1,100 +1,116 b''
1 1 """ Defines helper functions for creating kernel entry points and process
2 2 launchers.
3 3 """
4 4
5 5 # Standard library imports.
6 6 import logging
7 7 import atexit
8 8 import sys
9 9 import os
10 import stat
10 11 import socket
11 12 from subprocess import Popen, PIPE
12 13 from signal import signal, SIGINT, SIGABRT, SIGTERM
13 14 try:
14 15 from signal import SIGKILL
15 16 except ImportError:
16 17 SIGKILL=None
17 18
18 19 # System library imports.
19 20 import zmq
20 21 from zmq.log import handlers
21 22 # Local imports.
22 23 from IPython.core.ultratb import FormattedTB
23 24 from IPython.external.argparse import ArgumentParser
24 25 from IPython.zmq.log import logger
25 26
26 27 def split_ports(s, n):
27 28 """Parser helper for multiport strings"""
28 29 if not s:
29 30 return tuple([0]*n)
30 31 ports = map(int, s.split(','))
31 32 if len(ports) != n:
32 33 raise ValueError
33 34 return ports
34 35
35 36 def select_random_ports(n):
36 """Selects and return n random ports that are open."""
37 """Selects and return n random ports that are available."""
37 38 ports = []
38 39 for i in xrange(n):
39 40 sock = socket.socket()
40 41 sock.bind(('', 0))
41 42 ports.append(sock)
42 43 for i, sock in enumerate(ports):
43 44 port = sock.getsockname()[1]
44 45 sock.close()
45 46 ports[i] = port
46 47 return ports
47 48
48 49 def parse_url(args):
50 """Ensure args.url contains full transport://interface:port"""
49 51 if args.url:
50 52 iface = args.url.split('://',1)
51 53 if len(args) == 2:
52 54 args.transport,iface = iface
53 55 iface = iface.split(':')
54 56 args.ip = iface[0]
55 57 if iface[1]:
56 58 args.regport = iface[1]
57 59 args.url = "%s://%s:%i"%(args.transport, args.ip,args.regport)
58 60
59 61 def signal_children(children):
62 """Relay interupt/term signals to children, for more solid process cleanup."""
60 63 def terminate_children(sig, frame):
61 64 for child in children:
62 65 child.terminate()
63 66 # sys.exit(sig)
64 67 for sig in (SIGINT, SIGABRT, SIGTERM):
65 68 signal(sig, terminate_children)
66 69
70 def generate_exec_key(keyfile):
71 import uuid
72 newkey = str(uuid.uuid4())
73 with open(keyfile, 'w') as f:
74 # f.write('ipython-key ')
75 f.write(newkey)
76 # set user-only RW permissions (0600)
77 # this will have no effect on Windows
78 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
79
80
67 81 def make_base_argument_parser():
68 82 """ Creates an ArgumentParser for the generic arguments supported by all
69 83 ipcluster entry points.
70 84 """
71 85 parser = ArgumentParser()
72 86 parser.add_argument('--ip', type=str, default='127.0.0.1',
73 87 help='set the controller\'s IP address [default: local]')
74 88 parser.add_argument('--transport', type=str, default='tcp',
75 89 help='set the transport to use [default: tcp]')
76 90 parser.add_argument('--regport', type=int, metavar='PORT', default=10101,
77 91 help='set the XREP port for registration [default: 10101]')
78 92 parser.add_argument('--logport', type=int, metavar='PORT', default=20202,
79 93 help='set the PUB port for logging [default: 10201]')
80 94 parser.add_argument('--loglevel', type=int, metavar='LEVEL', default=logging.DEBUG,
81 95 help='set the log level [default: DEBUG]')
82 96 parser.add_argument('--ident', type=str,
83 97 help='set the ZMQ identity [default: random]')
84 98 parser.add_argument('--packer', type=str, default='json',
85 99 choices=['json','pickle'],
86 100 help='set the message format method [default: json]')
87 101 parser.add_argument('--url', type=str,
88 102 help='set transport,ip,regport in one arg, e.g. tcp://127.0.0.1:10101')
103 parser.add_argument('--execkey', type=str,
104 help="File containing key for authenticating requests.")
89 105
90 106 return parser
91 107
92 108
93 109 def connect_logger(context, iface, root="ip", loglevel=logging.DEBUG):
94 110 lsock = context.socket(zmq.PUB)
95 111 lsock.connect(iface)
96 112 handler = handlers.PUBHandler(lsock)
97 113 handler.setLevel(loglevel)
98 114 handler.root_topic = root
99 115 logger.addHandler(handler)
100 116 No newline at end of file
@@ -1,93 +1,93 b''
1 1 #!/usr/bin/env python
2 2 from __future__ import print_function
3 3 import sys,os
4 4 import time
5 5 from subprocess import Popen, PIPE
6 6
7 7 from entry_point import parse_url
8 8 from controller import make_argument_parser
9 9
10 10 def _filter_arg(flag, args):
11 11 filtered = []
12 12 if flag in args:
13 13 filtered.append(flag)
14 14 idx = args.index(flag)
15 15 if len(args) > idx+1:
16 16 if not args[idx+1].startswith('-'):
17 17 filtered.append(args[idx+1])
18 18 return filtered
19 19
20 20 def filter_args(flags, args=sys.argv[1:]):
21 21 filtered = []
22 22 for flag in flags:
23 23 if isinstance(flag, (list,tuple)):
24 24 for f in flag:
25 25 filtered.extend(_filter_arg(f, args))
26 26 else:
27 27 filtered.extend(_filter_arg(flag, args))
28 28 return filtered
29 29
30 30 def _strip_arg(flag, args):
31 31 while flag in args:
32 32 idx = args.index(flag)
33 33 args.pop(idx)
34 34 if len(args) > idx:
35 35 if not args[idx].startswith('-'):
36 36 args.pop(idx)
37 37
38 38 def strip_args(flags, args=sys.argv[1:]):
39 39 args = list(args)
40 40 for flag in flags:
41 41 if isinstance(flag, (list,tuple)):
42 42 for f in flag:
43 43 _strip_arg(f, args)
44 44 else:
45 45 _strip_arg(flag, args)
46 46 return args
47 47
48 48
49 49 def launch_process(mod, args):
50 50 """Launch a controller or engine in a subprocess."""
51 51 code = "from IPython.zmq.parallel.%s import main;main()"%mod
52 52 arguments = [ sys.executable, '-c', code ] + args
53 53 blackholew = file(os.devnull, 'w')
54 54 blackholer = file(os.devnull, 'r')
55 55
56 56 proc = Popen(arguments, stdin=blackholer, stdout=blackholew, stderr=PIPE)
57 57 return proc
58 58
59 59 def main():
60 60 parser = make_argument_parser()
61 61 parser.add_argument('--n', '-n', type=int, default=1,
62 62 help="The number of engines to start.")
63 63 args = parser.parse_args()
64 64 parse_url(args)
65 65
66 66 controller_args = strip_args([('--n','-n')])
67 67 engine_args = filter_args(['--url', '--regport', '--logport', '--ip',
68 '--transport','--loglevel','--packer'])+['--ident']
68 '--transport','--loglevel','--packer', '--execkey'])+['--ident']
69 69
70 70 controller = launch_process('controller', controller_args)
71 71 for i in range(10):
72 72 time.sleep(.1)
73 73 if controller.poll() is not None:
74 74 print("Controller failed to launch:")
75 75 print (controller.stderr.read())
76 76 sys.exit(255)
77 77
78 78 print("Launched Controller at %s"%args.url)
79 79 engines = [ launch_process('engine', engine_args+['engine-%i'%i]) for i in range(args.n) ]
80 80 print("%i Engines started"%args.n)
81 81
82 82 def wait_quietly(p):
83 83 try:
84 84 p.wait()
85 85 except KeyboardInterrupt:
86 86 pass
87 87
88 88 wait_quietly(controller)
89 89 map(wait_quietly, engines)
90 90 print ("Engines cleaned up.")
91 91
92 92 if __name__ == '__main__':
93 93 main() No newline at end of file
@@ -1,413 +1,423 b''
1 1 #!/usr/bin/env python
2 2 """
3 3 Kernel adapted from kernel.py to use ZMQ Streams
4 4 """
5 5
6 6 #-----------------------------------------------------------------------------
7 7 # Imports
8 8 #-----------------------------------------------------------------------------
9 9
10 10 # Standard library imports.
11 11 from __future__ import print_function
12 12 import __builtin__
13 13 from code import CommandCompiler
14 14 import os
15 15 import sys
16 16 import time
17 17 import traceback
18 18 from signal import SIGTERM, SIGKILL
19 19 from pprint import pprint
20 20
21 21 # System library imports.
22 22 import zmq
23 23 from zmq.eventloop import ioloop, zmqstream
24 24
25 25 # Local imports.
26 26 from IPython.utils.traitlets import HasTraits, Instance, List
27 27 from IPython.zmq.completer import KernelCompleter
28 28
29 29 from streamsession import StreamSession, Message, extract_header, serialize_object,\
30 30 unpack_apply_message
31 31 from dependency import UnmetDependency
32 32 import heartmonitor
33 33 from client import Client
34 34
35 35 def printer(*args):
36 36 pprint(args)
37 37
38 38 #-----------------------------------------------------------------------------
39 39 # Main kernel class
40 40 #-----------------------------------------------------------------------------
41 41
42 42 class Kernel(HasTraits):
43 43
44 44 #---------------------------------------------------------------------------
45 45 # Kernel interface
46 46 #---------------------------------------------------------------------------
47 47
48 48 session = Instance(StreamSession)
49 49 shell_streams = Instance(list)
50 50 control_stream = Instance(zmqstream.ZMQStream)
51 51 task_stream = Instance(zmqstream.ZMQStream)
52 52 iopub_stream = Instance(zmqstream.ZMQStream)
53 53 client = Instance(Client)
54 54
55 55 def __init__(self, **kwargs):
56 56 super(Kernel, self).__init__(**kwargs)
57 57 self.identity = self.shell_streams[0].getsockopt(zmq.IDENTITY)
58 58 self.user_ns = {}
59 59 self.history = []
60 60 self.compiler = CommandCompiler()
61 61 self.completer = KernelCompleter(self.user_ns)
62 62 self.aborted = set()
63 63
64 64 # Build dict of handlers for message types
65 65 self.shell_handlers = {}
66 66 self.control_handlers = {}
67 67 for msg_type in ['execute_request', 'complete_request', 'apply_request',
68 68 'clear_request']:
69 69 self.shell_handlers[msg_type] = getattr(self, msg_type)
70 70
71 71 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
72 72 self.control_handlers[msg_type] = getattr(self, msg_type)
73 73
74 74 #-------------------- control handlers -----------------------------
75 75 def abort_queues(self):
76 76 for stream in self.shell_streams:
77 77 if stream:
78 78 self.abort_queue(stream)
79 79
80 80 def abort_queue(self, stream):
81 81 while True:
82 82 try:
83 83 msg = self.session.recv(stream, zmq.NOBLOCK,content=True)
84 84 except zmq.ZMQError as e:
85 85 if e.errno == zmq.EAGAIN:
86 86 break
87 87 else:
88 88 return
89 89 else:
90 90 if msg is None:
91 91 return
92 92 else:
93 93 idents,msg = msg
94 94
95 95 # assert self.reply_socketly_socket.rcvmore(), "Unexpected missing message part."
96 96 # msg = self.reply_socket.recv_json()
97 97 print ("Aborting:", file=sys.__stdout__)
98 98 print (Message(msg), file=sys.__stdout__)
99 99 msg_type = msg['msg_type']
100 100 reply_type = msg_type.split('_')[0] + '_reply'
101 101 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
102 102 # self.reply_socket.send(ident,zmq.SNDMORE)
103 103 # self.reply_socket.send_json(reply_msg)
104 104 reply_msg = self.session.send(stream, reply_type,
105 105 content={'status' : 'aborted'}, parent=msg, ident=idents)[0]
106 106 print(Message(reply_msg), file=sys.__stdout__)
107 107 # We need to wait a bit for requests to come in. This can probably
108 108 # be set shorter for true asynchronous clients.
109 109 time.sleep(0.05)
110 110
111 111 def abort_request(self, stream, ident, parent):
112 112 """abort a specifig msg by id"""
113 113 msg_ids = parent['content'].get('msg_ids', None)
114 114 if isinstance(msg_ids, basestring):
115 115 msg_ids = [msg_ids]
116 116 if not msg_ids:
117 117 self.abort_queues()
118 118 for mid in msg_ids:
119 119 self.aborted.add(str(mid))
120 120
121 121 content = dict(status='ok')
122 122 reply_msg = self.session.send(stream, 'abort_reply', content=content,
123 123 parent=parent, ident=ident)[0]
124 124 print(Message(reply_msg), file=sys.__stdout__)
125 125
126 126 def shutdown_request(self, stream, ident, parent):
127 127 """kill ourself. This should really be handled in an external process"""
128 128 self.abort_queues()
129 129 content = dict(parent['content'])
130 msg = self.session.send(self.reply_socket, 'shutdown_reply',
131 content, parent, ident)
132 msg = self.session.send(self.pub_socket, 'shutdown_reply',
133 content, parent, ident)
130 msg = self.session.send(stream, 'shutdown_reply',
131 content=content, parent=parent, ident=ident)
132 # msg = self.session.send(self.pub_socket, 'shutdown_reply',
133 # content, parent, ident)
134 134 # print >> sys.__stdout__, msg
135 135 time.sleep(0.1)
136 136 sys.exit(0)
137 137
138 138 def dispatch_control(self, msg):
139 139 idents,msg = self.session.feed_identities(msg, copy=False)
140 msg = self.session.unpack_message(msg, content=True, copy=False)
140 try:
141 msg = self.session.unpack_message(msg, content=True, copy=False)
142 except:
143 logger.error("Invalid Message", exc_info=True)
144 return
141 145
142 146 header = msg['header']
143 147 msg_id = header['msg_id']
144 148
145 149 handler = self.control_handlers.get(msg['msg_type'], None)
146 150 if handler is None:
147 151 print ("UNKNOWN CONTROL MESSAGE TYPE:", msg, file=sys.__stderr__)
148 152 else:
149 153 handler(self.control_stream, idents, msg)
150 154
151 155
152 156 #-------------------- queue helpers ------------------------------
153 157
154 158 def check_dependencies(self, dependencies):
155 159 if not dependencies:
156 160 return True
157 161 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
158 162 anyorall = dependencies[0]
159 163 dependencies = dependencies[1]
160 164 else:
161 165 anyorall = 'all'
162 166 results = self.client.get_results(dependencies,status_only=True)
163 167 if results['status'] != 'ok':
164 168 return False
165 169
166 170 if anyorall == 'any':
167 171 if not results['completed']:
168 172 return False
169 173 else:
170 174 if results['pending']:
171 175 return False
172 176
173 177 return True
174 178
175 179 def check_aborted(self, msg_id):
176 180 return msg_id in self.aborted
177 181
178 182 #-------------------- queue handlers -----------------------------
179 183
180 184 def clear_request(self, stream, idents, parent):
181 185 """Clear our namespace."""
182 186 self.user_ns = {}
183 187 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
184 188 content = dict(status='ok'))
185 189
186 190 def execute_request(self, stream, ident, parent):
187 191 try:
188 192 code = parent[u'content'][u'code']
189 193 except:
190 194 print("Got bad msg: ", file=sys.__stderr__)
191 195 print(Message(parent), file=sys.__stderr__)
192 196 return
193 197 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
194 198 # self.iopub_stream.send(pyin_msg)
195 199 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
196 200 try:
197 201 comp_code = self.compiler(code, '<zmq-kernel>')
198 202 # allow for not overriding displayhook
199 203 if hasattr(sys.displayhook, 'set_parent'):
200 204 sys.displayhook.set_parent(parent)
201 205 exec comp_code in self.user_ns, self.user_ns
202 206 except:
203 207 # result = u'error'
204 208 etype, evalue, tb = sys.exc_info()
205 209 tb = traceback.format_exception(etype, evalue, tb)
206 210 exc_content = {
207 211 u'status' : u'error',
208 212 u'traceback' : tb,
209 213 u'etype' : unicode(etype),
210 214 u'evalue' : unicode(evalue)
211 215 }
212 216 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
213 217 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent)
214 218 reply_content = exc_content
215 219 else:
216 220 reply_content = {'status' : 'ok'}
217 221 # reply_msg = self.session.msg(u'execute_reply', reply_content, parent)
218 222 # self.reply_socket.send(ident, zmq.SNDMORE)
219 223 # self.reply_socket.send_json(reply_msg)
220 224 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent, ident=ident)
221 225 print(Message(reply_msg), file=sys.__stdout__)
222 226 if reply_msg['content']['status'] == u'error':
223 227 self.abort_queues()
224 228
225 229 def complete_request(self, stream, ident, parent):
226 230 matches = {'matches' : self.complete(parent),
227 231 'status' : 'ok'}
228 232 completion_msg = self.session.send(stream, 'complete_reply',
229 233 matches, parent, ident)
230 234 # print >> sys.__stdout__, completion_msg
231 235
232 236 def complete(self, msg):
233 237 return self.completer.complete(msg.content.line, msg.content.text)
234 238
235 239 def apply_request(self, stream, ident, parent):
236 240 print (parent)
237 241 try:
238 242 content = parent[u'content']
239 243 bufs = parent[u'buffers']
240 244 msg_id = parent['header']['msg_id']
241 245 bound = content.get('bound', False)
242 246 except:
243 247 print("Got bad msg: ", file=sys.__stderr__)
244 248 print(Message(parent), file=sys.__stderr__)
245 249 return
246 250 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
247 251 # self.iopub_stream.send(pyin_msg)
248 252 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
249 253 sub = {'dependencies_met' : True, 'engine' : self.identity}
250 254 try:
251 255 # allow for not overriding displayhook
252 256 if hasattr(sys.displayhook, 'set_parent'):
253 257 sys.displayhook.set_parent(parent)
254 258 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
255 259 if bound:
256 260 working = self.user_ns
257 261 suffix = str(msg_id).replace("-","")
258 262 prefix = "_"
259 263
260 264 else:
261 265 working = dict()
262 266 suffix = prefix = "_" # prevent keyword collisions with lambda
263 267 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
264 268 # if f.fun
265 269 fname = prefix+f.func_name.strip('<>')+suffix
266 270 argname = prefix+"args"+suffix
267 271 kwargname = prefix+"kwargs"+suffix
268 272 resultname = prefix+"result"+suffix
269 273
270 274 ns = { fname : f, argname : args, kwargname : kwargs }
271 275 # print ns
272 276 working.update(ns)
273 277 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
274 278 exec code in working, working
275 279 result = working.get(resultname)
276 280 # clear the namespace
277 281 if bound:
278 282 for key in ns.iterkeys():
279 283 self.user_ns.pop(key)
280 284 else:
281 285 del working
282 286
283 287 packed_result,buf = serialize_object(result)
284 288 result_buf = [packed_result]+buf
285 289 except:
286 290 result = u'error'
287 291 etype, evalue, tb = sys.exc_info()
288 292 tb = traceback.format_exception(etype, evalue, tb)
289 293 exc_content = {
290 294 u'status' : u'error',
291 295 u'traceback' : tb,
292 296 u'etype' : unicode(etype),
293 297 u'evalue' : unicode(evalue)
294 298 }
295 299 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
296 300 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent)
297 301 reply_content = exc_content
298 302 result_buf = []
299 303
300 304 if etype is UnmetDependency:
301 305 sub['dependencies_met'] = False
302 306 else:
303 307 reply_content = {'status' : 'ok'}
304 308 # reply_msg = self.session.msg(u'execute_reply', reply_content, parent)
305 309 # self.reply_socket.send(ident, zmq.SNDMORE)
306 310 # self.reply_socket.send_json(reply_msg)
307 311 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
308 312 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
309 313 print(Message(reply_msg), file=sys.__stdout__)
310 314 # if reply_msg['content']['status'] == u'error':
311 315 # self.abort_queues()
312 316
313 317 def dispatch_queue(self, stream, msg):
314 318 self.control_stream.flush()
315 319 idents,msg = self.session.feed_identities(msg, copy=False)
316 msg = self.session.unpack_message(msg, content=True, copy=False)
320 try:
321 msg = self.session.unpack_message(msg, content=True, copy=False)
322 except:
323 logger.error("Invalid Message", exc_info=True)
324 return
325
317 326
318 327 header = msg['header']
319 328 msg_id = header['msg_id']
320 329 if self.check_aborted(msg_id):
321 330 self.aborted.remove(msg_id)
322 331 # is it safe to assume a msg_id will not be resubmitted?
323 332 reply_type = msg['msg_type'].split('_')[0] + '_reply'
324 333 reply_msg = self.session.send(stream, reply_type,
325 334 content={'status' : 'aborted'}, parent=msg, ident=idents)
326 335 return
327 336 handler = self.shell_handlers.get(msg['msg_type'], None)
328 337 if handler is None:
329 338 print ("UNKNOWN MESSAGE TYPE:", msg, file=sys.__stderr__)
330 339 else:
331 340 handler(stream, idents, msg)
332 341
333 342 def start(self):
334 343 #### stream mode:
335 344 if self.control_stream:
336 345 self.control_stream.on_recv(self.dispatch_control, copy=False)
337 346 self.control_stream.on_err(printer)
338 347
339 348 for s in self.shell_streams:
340 349 s.on_recv(lambda msg:
341 350 self.dispatch_queue(s, msg), copy=False)
342 351 s.on_err(printer)
343 352
344 353 if self.iopub_stream:
345 354 self.iopub_stream.on_err(printer)
346 355 self.iopub_stream.on_send(printer)
347 356
348 357 #### while True mode:
349 358 # while True:
350 359 # idle = True
351 360 # try:
352 361 # msg = self.shell_stream.socket.recv_multipart(
353 362 # zmq.NOBLOCK, copy=False)
354 363 # except zmq.ZMQError, e:
355 364 # if e.errno != zmq.EAGAIN:
356 365 # raise e
357 366 # else:
358 367 # idle=False
359 368 # self.dispatch_queue(self.shell_stream, msg)
360 369 #
361 370 # if not self.task_stream.empty():
362 371 # idle=False
363 372 # msg = self.task_stream.recv_multipart()
364 373 # self.dispatch_queue(self.task_stream, msg)
365 374 # if idle:
366 375 # # don't busywait
367 376 # time.sleep(1e-3)
368 377
369 378 def make_kernel(identity, control_addr, shell_addrs, iopub_addr, hb_addrs,
370 client_addr=None, loop=None, context=None):
379 client_addr=None, loop=None, context=None, key=None):
371 380 # create loop, context, and session:
372 381 if loop is None:
373 382 loop = ioloop.IOLoop.instance()
374 383 if context is None:
375 384 context = zmq.Context()
376 385 c = context
377 session = StreamSession()
386 session = StreamSession(key=key)
387 # print (session.key)
378 388 print (control_addr, shell_addrs, iopub_addr, hb_addrs)
379 389
380 390 # create Control Stream
381 391 control_stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
382 392 control_stream.setsockopt(zmq.IDENTITY, identity)
383 393 control_stream.connect(control_addr)
384 394
385 395 # create Shell Streams (MUX, Task, etc.):
386 396 shell_streams = []
387 397 for addr in shell_addrs:
388 398 stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
389 399 stream.setsockopt(zmq.IDENTITY, identity)
390 400 stream.connect(addr)
391 401 shell_streams.append(stream)
392 402
393 403 # create iopub stream:
394 404 iopub_stream = zmqstream.ZMQStream(c.socket(zmq.PUB), loop)
395 405 iopub_stream.setsockopt(zmq.IDENTITY, identity)
396 406 iopub_stream.connect(iopub_addr)
397 407
398 408 # launch heartbeat
399 409 heart = heartmonitor.Heart(*map(str, hb_addrs), heart_id=identity)
400 410 heart.start()
401 411
402 412 # create (optional) Client
403 413 if client_addr:
404 414 client = Client(client_addr, username=identity)
405 415 else:
406 416 client = None
407 417
408 418 kernel = Kernel(session=session, control_stream=control_stream,
409 419 shell_streams=shell_streams, iopub_stream=iopub_stream,
410 420 client=client)
411 421 kernel.start()
412 422 return loop, c
413 423
@@ -1,503 +1,530 b''
1 1 #!/usr/bin/env python
2 2 """edited session.py to work with streams, and move msg_type to the header
3 3 """
4 4
5 5
6 6 import os
7 7 import sys
8 8 import traceback
9 9 import pprint
10 10 import uuid
11 11 from datetime import datetime
12 12
13 13 import zmq
14 14 from zmq.utils import jsonapi
15 15 from zmq.eventloop.zmqstream import ZMQStream
16 16
17 17 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
18 18 from IPython.utils.newserialized import serialize, unserialize
19 19
20 20 try:
21 21 import cPickle
22 22 pickle = cPickle
23 23 except:
24 24 cPickle = None
25 25 import pickle
26 26
# packer priority: jsonlib[2], cPickle, simplejson/json, pickle
json_name = jsonapi.jsonmod.__name__ if jsonapi.jsonmod else ''
if json_name in ('jsonlib', 'jsonlib2'):
    # jsonlib/jsonlib2 are fast enough to prefer unconditionally
    use_json = True
elif json_name:
    # plain simplejson/json: only preferred when cPickle is unavailable
    use_json = (cPickle is None)
else:
    use_json = False
38 38
39 39 def squash_unicode(obj):
40 40 if isinstance(obj,dict):
41 41 for key in obj.keys():
42 42 obj[key] = squash_unicode(obj[key])
43 43 if isinstance(key, unicode):
44 44 obj[squash_unicode(key)] = obj.pop(key)
45 45 elif isinstance(obj, list):
46 46 for i,v in enumerate(obj):
47 47 obj[i] = squash_unicode(v)
48 48 elif isinstance(obj, unicode):
49 49 obj = obj.encode('utf8')
50 50 return obj
51 51
if use_json:
    default_packer = jsonapi.dumps
    def default_unpacker(s):
        # json turns everything into unicode; squash back to utf8 bytes
        return squash_unicode(jsonapi.loads(s))
else:
    def default_packer(o):
        return pickle.dumps(o, -1)
    default_unpacker = pickle.loads


# delimiter frame separating zmq routing identities from message frames
DELIM = "<IDS|MSG>"
61 61
def wrap_exception():
    """Capture the exception currently being handled as an error-content dict.

    Must be called from inside an ``except`` block; returns the dict form
    used on the wire ('status', 'traceback', 'etype', 'evalue').
    """
    etype, evalue, tb = sys.exc_info()
    formatted = traceback.format_exception(etype, evalue, tb)
    return {
        'status' : 'error',
        'traceback' : str(formatted),
        'etype' : str(etype),
        'evalue' : str(evalue),
    }
72 72
class KernelError(Exception):
    """Exception type used to represent an error reported by a kernel."""


def unwrap_exception(content):
    """Rebuild a KernelError from an error-content dict (see wrap_exception)."""
    err = KernelError(content['etype'], content['evalue'])
    # expose the pieces as attributes for convenient inspection
    for attr in ('evalue', 'etype'):
        setattr(err, attr, content[attr])
    err.traceback = ''.join(content['traceback'])
    return err
82 82
83 83
class Message(object):
    """A simple message object that maps dict keys to attributes.

    A Message can be created from a dict and a dict from a Message instance
    simply by calling dict(msg_obj). Nested dicts become nested Messages.
    """

    def __init__(self, msg_dict):
        dct = self.__dict__
        # items() instead of py2-only iteritems(): identical behavior on
        # Python 2, and the class also works under Python 3
        for k, v in dict(msg_dict).items():
            if isinstance(v, dict):
                v = Message(v)
            dct[k] = v

    # Having this iterator lets dict(msg_obj) work out of the box.
    def __iter__(self):
        return iter(self.__dict__.items())

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return pprint.pformat(self.__dict__)

    def __contains__(self, k):
        return k in self.__dict__

    def __getitem__(self, k):
        return self.__dict__[k]
113 113
def msg_header(msg_id, msg_type, username, session):
    """Return a fresh header dict, stamped with the current ISO date."""
    return {
        'msg_id' : msg_id,
        'msg_type' : msg_type,
        'username' : username,
        'session' : session,
        'date' : datetime.now().isoformat(),
    }
117 117
def extract_header(msg_or_header):
    """Given a message or header, return the header (always as a plain dict)."""
    if not msg_or_header:
        return {}
    try:
        # full message: the header lives under the 'header' key
        h = msg_or_header['header']
    except KeyError:
        # otherwise it must itself look like a header;
        # let the KeyError propagate when it has no msg_id
        msg_or_header['msg_id']
        h = msg_or_header
    return h if isinstance(h, dict) else dict(h)
136 136
def rekey(dikt):
    """Rekey a dict that has been forced to use str keys where there should be
    ints by json. This belongs in the jsonutil added by fperez.

    Mutates `dikt` in place and returns it.

    Raises
    ------
    KeyError
        if a converted key collides with a key already present.
    """
    # iterate over a snapshot of the keys: the loop pops/inserts keys, and
    # mutating a dict while iterating iterkeys()/keys() is unsafe
    for k in list(dikt):
        if isinstance(k, str):
            ik = fk = None
            try:
                ik = int(k)
            except ValueError:
                try:
                    fk = float(k)
                except ValueError:
                    # not numeric at all; leave the key as-is
                    continue
            nk = ik if ik is not None else fk
            if nk in dikt:
                raise KeyError("already have key %r"%nk)
            dikt[nk] = dikt.pop(k)
    return dikt
158 158
def serialize_object(obj, threshold=64e-6):
    """Serialize an object into a list of sendable buffers.

    Parameters
    ----------
    obj : object
        The object to be serialized
    threshold : float
        The threshold for not double-pickling the content.

    Returns
    -------
    ('pmd', [bufs]) :
        where pmd is the pickled metadata wrapper,
        bufs is a list of data buffers
    """
    databuffers = []

    def strip(s, always=False):
        # pull the payload out of a serialized piece when it is large
        # (or `always`, for buffer/ndarray types), recording it separately
        if always or s.getDataSize() > threshold:
            databuffers.append(s.getData())
            s.data = None
        return s

    if isinstance(obj, (list, tuple)):
        pieces = [serialize(c) for c in canSequence(obj)]
        slist = [strip(s, s.typeDescriptor in ('buffer', 'ndarray'))
                 for s in pieces]
        return pickle.dumps(slist, -1), databuffers
    elif isinstance(obj, dict):
        sobj = {}
        for k in sorted(obj):
            sobj[k] = strip(serialize(can(obj[k])))
        return pickle.dumps(sobj, -1), databuffers
    else:
        return pickle.dumps(strip(serialize(can(obj))), -1), databuffers
201 201
202 202
def unserialize_object(bufs):
    """reconstruct an object serialized by serialize_object from data buffers"""
    remaining = list(bufs)
    sobj = pickle.loads(remaining.pop(0))

    def refill(s):
        # re-attach the next data buffer to any piece whose payload
        # was stripped out by serialize_object
        if s.data is None:
            s.data = remaining.pop(0)
        return s

    if isinstance(sobj, (list, tuple)):
        return uncanSequence([unserialize(refill(s)) for s in sobj])
    elif isinstance(sobj, dict):
        # same sorted-key order as serialization, so buffers line up
        return dict((k, uncan(unserialize(refill(sobj[k]))))
                    for k in sorted(sobj))
    else:
        return uncan(unserialize(refill(sobj)))
224 224
def pack_apply_message(f, args, kwargs, threshold=64e-6):
    """pack up a function, args, and kwargs to be sent over the wire
    as a series of buffers. Any object whose data is larger than `threshold`
    will not have their data copied (currently only numpy arrays support zero-copy)"""
    packed_f = pickle.dumps(can(f), -1)
    sargs, arg_bufs = serialize_object(args, threshold)
    skwargs, kwarg_bufs = serialize_object(kwargs, threshold)
    # wire order: pickled f, args metadata, kwargs metadata,
    # then the large data buffers (args first, then kwargs)
    return [packed_f, sargs, skwargs] + arg_bufs + kwarg_bufs
239 239
def unpack_apply_message(bufs, g=None, copy=True):
    """unpack f,args,kwargs from buffers packed by pack_apply_message()
    Returns: original f,args,kwargs"""
    bufs = list(bufs)  # allow us to pop
    assert len(bufs) >= 3, "not enough buffers!"
    if not copy:
        # first three frames are zmq Message objects; extract their bytes
        for i in range(3):
            bufs[i] = bufs[i].bytes
    cf = pickle.loads(bufs.pop(0))
    sargs = list(pickle.loads(bufs.pop(0)))
    skwargs = dict(pickle.loads(bufs.pop(0)))
    f = uncan(cf, g)
    # re-attach stripped data buffers to the positional-arg pieces
    for spec in sargs:
        if spec.data is None:
            raw = bufs.pop(0)
            if spec.getTypeDescriptor() in ('buffer', 'ndarray'):
                spec.data = buffer(raw) if copy else raw.buffer
            else:
                spec.data = raw if copy else raw.bytes

    args = uncanSequence(map(unserialize, sargs), g)
    kwargs = {}
    # sorted key order matches the packing side, so buffers line up
    for k in sorted(skwargs):
        spec = skwargs[k]
        if spec.data is None:
            spec.data = bufs.pop(0)
        kwargs[k] = uncan(unserialize(spec), g)

    return f, args, kwargs
276 276
277 277 class StreamSession(object):
278 278 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
279 279 debug=False
280 def __init__(self, username=None, session=None, packer=None, unpacker=None):
280 key=None
281
282 def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
281 283 if username is None:
282 284 username = os.environ.get('USER','username')
283 285 self.username = username
284 286 if session is None:
285 287 self.session = str(uuid.uuid4())
286 288 else:
287 289 self.session = session
288 290 self.msg_id = str(uuid.uuid4())
289 291 if packer is None:
290 292 self.pack = default_packer
291 293 else:
292 294 if not callable(packer):
293 295 raise TypeError("packer must be callable, not %s"%type(packer))
294 296 self.pack = packer
295 297
296 298 if unpacker is None:
297 299 self.unpack = default_unpacker
298 300 else:
299 301 if not callable(unpacker):
300 302 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
301 303 self.unpack = unpacker
302 304
305 if key is not None and keyfile is not None:
306 raise TypeError("Must specify key OR keyfile, not both")
307 if keyfile is not None:
308 with open(keyfile) as f:
309 self.key = f.read().strip()
310 else:
311 self.key = key
312 # print key, keyfile, self.key
303 313 self.none = self.pack({})
304 314
305 315 def msg_header(self, msg_type):
306 316 h = msg_header(self.msg_id, msg_type, self.username, self.session)
307 317 self.msg_id = str(uuid.uuid4())
308 318 return h
309 319
310 320 def msg(self, msg_type, content=None, parent=None, subheader=None):
311 321 msg = {}
312 322 msg['header'] = self.msg_header(msg_type)
313 323 msg['msg_id'] = msg['header']['msg_id']
314 324 msg['parent_header'] = {} if parent is None else extract_header(parent)
315 325 msg['msg_type'] = msg_type
316 326 msg['content'] = {} if content is None else content
317 327 sub = {} if subheader is None else subheader
318 328 msg['header'].update(sub)
319 329 return msg
320 330
331 def check_key(self, msg_or_header):
332 """Check that a message's header has the right key"""
333 if self.key is None:
334 return True
335 header = extract_header(msg_or_header)
336 return header.get('key', None) == self.key
337
338
321 339 def send(self, stream, msg_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
322 340 """Build and send a message via stream or socket.
323 341
324 342 Parameters
325 343 ----------
326 344
327 345 stream : zmq.Socket or ZMQStream
328 346 the socket-like object used to send the data
329 347 msg_type : str or Message/dict
330 348 Normally, msg_type will be
331 349
332 350
333 351
334 352 Returns
335 353 -------
336 354 (msg,sent) : tuple
337 355 msg : Message
338 356 the nice wrapped dict-like object containing the headers
339 357
340 358 """
341 359 if isinstance(msg_type, (Message, dict)):
342 360 # we got a Message, not a msg_type
343 361 # don't build a new Message
344 362 msg = msg_type
345 363 content = msg['content']
346 364 else:
347 365 msg = self.msg(msg_type, content, parent, subheader)
348 366 buffers = [] if buffers is None else buffers
349 367 to_send = []
350 368 if isinstance(ident, list):
351 369 # accept list of idents
352 370 to_send.extend(ident)
353 371 elif ident is not None:
354 372 to_send.append(ident)
355 373 to_send.append(DELIM)
374 if self.key is not None:
375 to_send.append(self.key)
356 376 to_send.append(self.pack(msg['header']))
357 377 to_send.append(self.pack(msg['parent_header']))
358 378
359 379 if content is None:
360 380 content = self.none
361 381 elif isinstance(content, dict):
362 382 content = self.pack(content)
363 383 elif isinstance(content, str):
364 384 # content is already packed, as in a relayed message
365 385 pass
366 386 else:
367 387 raise TypeError("Content incorrect type: %s"%type(content))
368 388 to_send.append(content)
369 389 flag = 0
370 390 if buffers:
371 391 flag = zmq.SNDMORE
372 392 stream.send_multipart(to_send, flag, copy=False)
373 393 for b in buffers[:-1]:
374 394 stream.send(b, flag, copy=False)
375 395 if buffers:
376 396 stream.send(buffers[-1], copy=False)
377 397 omsg = Message(msg)
378 398 if self.debug:
379 399 pprint.pprint(omsg)
380 400 pprint.pprint(to_send)
381 401 pprint.pprint(buffers)
382 402 return omsg
383 403
384 404 def send_raw(self, stream, msg, flags=0, copy=True, idents=None):
385 405 """Send a raw message via idents.
386 406
387 407 Parameters
388 408 ----------
389 409 msg : list of sendable buffers"""
390 410 to_send = []
391 411 if isinstance(ident, str):
392 412 ident = [ident]
393 413 if ident is not None:
394 414 to_send.extend(ident)
395 415 to_send.append(DELIM)
416 if self.key is not None:
417 to_send.append(self.key)
396 418 to_send.extend(msg)
397 419 stream.send_multipart(msg, flags, copy=copy)
398 420
399 421 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
400 422 """receives and unpacks a message
401 423 returns [idents], msg"""
402 424 if isinstance(socket, ZMQStream):
403 425 socket = socket.socket
404 426 try:
405 427 msg = socket.recv_multipart(mode)
406 428 except zmq.ZMQError as e:
407 429 if e.errno == zmq.EAGAIN:
408 430 # We can convert EAGAIN to None as we know in this case
409 431 # recv_json won't return None.
410 432 return None
411 433 else:
412 434 raise
413 435 # return an actual Message object
414 436 # determine the number of idents by trying to unpack them.
415 437 # this is terrible:
416 438 idents, msg = self.feed_identities(msg, copy)
417 439 try:
418 440 return idents, self.unpack_message(msg, content=content, copy=copy)
419 441 except Exception as e:
420 442 print (idents, msg)
421 443 # TODO: handle it
422 444 raise e
423 445
424 446 def feed_identities(self, msg, copy=True):
425 447 """This is a completely horrible thing, but it strips the zmq
426 448 ident prefixes off of a message. It will break if any identities
427 449 are unpackable by self.unpack."""
428 450 msg = list(msg)
429 451 idents = []
430 452 while len(msg) > 3:
431 453 if copy:
432 454 s = msg[0]
433 455 else:
434 456 s = msg[0].bytes
435 457 if s == DELIM:
436 458 msg.pop(0)
437 459 break
438 460 else:
439 461 idents.append(s)
440 462 msg.pop(0)
441 463
442 464 return idents, msg
443 465
444 466 def unpack_message(self, msg, content=True, copy=True):
445 467 """Return a message object from the format
446 468 sent by self.send.
447 469
448 470 Parameters:
449 471 -----------
450 472
451 473 content : bool (True)
452 474 whether to unpack the content dict (True),
453 475 or leave it serialized (False)
454 476
455 477 copy : bool (True)
456 478 whether to return the bytes (True),
457 479 or the non-copying Message object in each place (False)
458 480
459 481 """
460 if not len(msg) >= 3:
461 raise TypeError("malformed message, must have at least 3 elements")
482 ikey = int(self.key is not None)
483 minlen = 3 + ikey
484 if not len(msg) >= minlen:
485 raise TypeError("malformed message, must have at least %i elements"%minlen)
462 486 message = {}
463 487 if not copy:
464 for i in range(3):
488 for i in range(minlen):
465 489 msg[i] = msg[i].bytes
466 message['header'] = self.unpack(msg[0])
490 if ikey:
491 if not self.key == msg[0]:
492 raise KeyError("Invalid Session Key: %s"%msg[0])
493 message['header'] = self.unpack(msg[ikey+0])
467 494 message['msg_type'] = message['header']['msg_type']
468 message['parent_header'] = self.unpack(msg[1])
495 message['parent_header'] = self.unpack(msg[ikey+1])
469 496 if content:
470 message['content'] = self.unpack(msg[2])
497 message['content'] = self.unpack(msg[ikey+2])
471 498 else:
472 message['content'] = msg[2]
499 message['content'] = msg[ikey+2]
473 500
474 501 # message['buffers'] = msg[3:]
475 502 # else:
476 503 # message['header'] = self.unpack(msg[0].bytes)
477 504 # message['msg_type'] = message['header']['msg_type']
478 505 # message['parent_header'] = self.unpack(msg[1].bytes)
479 506 # if content:
480 507 # message['content'] = self.unpack(msg[2].bytes)
481 508 # else:
482 509 # message['content'] = msg[2].bytes
483 510
484 message['buffers'] = msg[3:]# [ m.buffer for m in msg[3:] ]
511 message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
485 512 return message
486 513
487 514
488 515
def test_msg2obj():
    """Sanity checks: Message attribute access mirrors the source dict."""
    src = dict(x=1)
    msg = Message(src)
    assert msg.x == src['x']

    # nested dicts become nested Messages
    src['y'] = dict(z=1)
    msg = Message(src)
    assert msg.y.z == src['y']['z']

    outer, inner = 'y', 'z'
    assert msg[outer][inner] == src[outer][inner]

    # round-trip back to a plain dict
    back = dict(msg)
    assert src['x'] == back['x']
    assert src['y']['z'] == back['y']['z']
General Comments 0
You need to be logged in to leave comments. Login now