add Session.bsession trait for session id as bytes
MinRK
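A minimal sketch of what the new trait gives callers (the exact traitlet definition in IPython.zmq.session is assumed; the usage is taken from the call sites changed below):

    from IPython.zmq.session import Session

    s = Session()
    s.session    # u'...'  unicode session id (existing trait)
    s.bsession   # b'...'  the same id as bytes (assumed to mirror .session)
    # sockets can now use s.bsession directly as a zmq IDENTITY,
    # replacing util.asbytes(s.session) at each call site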
IPython/parallel/client/client.py
@@ -1,1435 +1,1435 @@
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 21 import time
22 22 import warnings
23 23 from datetime import datetime
24 24 from getpass import getpass
25 25 from pprint import pprint
26 26
27 27 pjoin = os.path.join
28 28
29 29 import zmq
30 30 # from zmq.eventloop import ioloop, zmqstream
31 31
32 32 from IPython.config.configurable import MultipleInstanceError
33 33 from IPython.core.application import BaseIPythonApplication
34 34
35 35 from IPython.utils.jsonutil import rekey
36 36 from IPython.utils.localinterfaces import LOCAL_IPS
37 37 from IPython.utils.path import get_ipython_dir
38 38 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
39 39 Dict, List, Bool, Set)
40 40 from IPython.external.decorator import decorator
41 41 from IPython.external.ssh import tunnel
42 42
43 43 from IPython.parallel import error
44 44 from IPython.parallel import util
45 45
46 46 from IPython.zmq.session import Session, Message
47 47
48 48 from .asyncresult import AsyncResult, AsyncHubResult
49 49 from IPython.core.profiledir import ProfileDir, ProfileDirError
50 50 from .view import DirectView, LoadBalancedView
51 51
52 52 if sys.version_info[0] >= 3:
53 53 # xrange is used in a couple of 'isinstance' tests in py2
54 54 # should be just 'range' in 3k
55 55 xrange = range
56 56
57 57 #--------------------------------------------------------------------------
58 58 # Decorators for Client methods
59 59 #--------------------------------------------------------------------------
60 60
61 61 @decorator
62 62 def spin_first(f, self, *args, **kwargs):
63 63 """Call spin() to sync state prior to calling the method."""
64 64 self.spin()
65 65 return f(self, *args, **kwargs)
66 66
67 67
68 68 #--------------------------------------------------------------------------
69 69 # Classes
70 70 #--------------------------------------------------------------------------
71 71
72 72 class Metadata(dict):
73 73 """Subclass of dict for initializing metadata values.
74 74
75 75 Attribute access works on keys.
76 76
77 77 These objects have a strict set of keys - errors will raise if you try
78 78 to add new keys.
79 79 """
80 80 def __init__(self, *args, **kwargs):
81 81 dict.__init__(self)
82 82 md = {'msg_id' : None,
83 83 'submitted' : None,
84 84 'started' : None,
85 85 'completed' : None,
86 86 'received' : None,
87 87 'engine_uuid' : None,
88 88 'engine_id' : None,
89 89 'follow' : None,
90 90 'after' : None,
91 91 'status' : None,
92 92
93 93 'pyin' : None,
94 94 'pyout' : None,
95 95 'pyerr' : None,
96 96 'stdout' : '',
97 97 'stderr' : '',
98 98 }
99 99 self.update(md)
100 100 self.update(dict(*args, **kwargs))
101 101
102 102 def __getattr__(self, key):
103 103 """getattr aliased to getitem"""
104 104 if key in self.iterkeys():
105 105 return self[key]
106 106 else:
107 107 raise AttributeError(key)
108 108
109 109 def __setattr__(self, key, value):
110 110 """setattr aliased to setitem, with strict"""
111 111 if key in self.iterkeys():
112 112 self[key] = value
113 113 else:
114 114 raise AttributeError(key)
115 115
116 116 def __setitem__(self, key, value):
117 117 """strict static key enforcement"""
118 118 if key in self.iterkeys():
119 119 dict.__setitem__(self, key, value)
120 120 else:
121 121 raise KeyError(key)
122 122
123 123
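A quick sketch of the strict-key behavior this class enforces (interactive use; values are illustrative):

    md = Metadata(status='ok')
    md.status           # 'ok'   attribute access aliases item access
    md['stdout']        # ''     every known key is pre-initialized
    md.engine_id = 0    # fine: 'engine_id' is in the fixed key set
    md['bogus'] = 1     # raises KeyError: unknown keys are rejected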
124 124 class Client(HasTraits):
125 125 """A semi-synchronous client to the IPython ZMQ cluster
126 126
127 127 Parameters
128 128 ----------
129 129
130 130 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
131 131 Connection information for the Hub's registration. If a json connector
132 132 file is given, then likely no further configuration is necessary.
133 133 [Default: use profile]
134 134 profile : bytes
135 135 The name of the Cluster profile to be used to find connector information.
136 136 If run from an IPython application, the default profile will be the same
137 137 as the running application, otherwise it will be 'default'.
138 138 context : zmq.Context
139 139 Pass an existing zmq.Context instance, otherwise the client will create its own.
140 140 debug : bool
141 141 flag for lots of message printing for debug purposes
142 142 timeout : int/float
143 143 time (in seconds) to wait for connection replies from the Hub
144 144 [Default: 10]
145 145
146 146 #-------------- session related args ----------------
147 147
148 148 config : Config object
149 149 If specified, this will be relayed to the Session for configuration
150 150 username : str
151 151 set username for the session object
152 152 packer : str (import_string) or callable
153 153 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
154 154 function to serialize messages. Must support same input as
155 155 JSON, and output must be bytes.
156 156 You can pass a callable directly as `pack`
157 157 unpacker : str (import_string) or callable
158 158 The inverse of packer. Only necessary if packer is specified as *not* one
159 159 of 'json' or 'pickle'.
160 160
161 161 #-------------- ssh related args ----------------
162 162 # These are args for configuring the ssh tunnel to be used
163 163 # credentials are used to forward connections over ssh to the Controller
164 164 # Note that the ip given in `addr` needs to be relative to sshserver
165 165 # The most basic case is to leave addr pointing to localhost (127.0.0.1),
166 166 # and set sshserver to the machine the Controller is on. However,
167 167 # the only requirement is that sshserver is able to see the Controller
168 168 # (i.e. is within the same trusted network).
169 169
170 170 sshserver : str
171 171 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
172 172 If keyfile or password is specified, and this is not, it will default to
173 173 the ip given in addr.
174 174 sshkey : str; path to ssh private key file
175 175 This specifies a key to be used in ssh login, default None.
176 176 Regular default ssh keys will be used without specifying this argument.
177 177 password : str
178 178 Your ssh password to sshserver. Note that if this is left None,
179 179 you will be prompted for it if passwordless key based login is unavailable.
180 180 paramiko : bool
181 181 flag for whether to use paramiko instead of shell ssh for tunneling.
182 182 [default: True on win32, False otherwise]
183 183
184 184 ------- exec authentication args -------
185 185 If even localhost is untrusted, you can have some protection against
186 186 unauthorized execution by signing messages with HMAC digests.
187 187 Messages are still sent as cleartext, so if someone can snoop your
188 188 loopback traffic this will not protect your privacy, but will prevent
189 189 unauthorized execution.
190 190
191 191 exec_key : str
192 192 an authentication key or file containing a key
193 193 default: None
194 194
195 195
196 196 Attributes
197 197 ----------
198 198
199 199 ids : list of int engine IDs
200 200 requesting the ids attribute always synchronizes
201 201 the registration state. To request ids without synchronization,
202 202 use the semi-private _ids attribute.
203 203
204 204 history : list of msg_ids
205 205 a list of msg_ids, keeping track of all the execution
206 206 messages you have submitted in order.
207 207
208 208 outstanding : set of msg_ids
209 209 a set of msg_ids that have been submitted, but whose
210 210 results have not yet been received.
211 211
212 212 results : dict
213 213 a dict of all our results, keyed by msg_id
214 214
215 215 block : bool
216 216 determines default behavior when block not specified
217 217 in execution methods
218 218
219 219 Methods
220 220 -------
221 221
222 222 spin
223 223 flushes incoming results and registration state changes
224 224 control methods spin, and requesting `ids` also ensures the state is up to date
225 225
226 226 wait
227 227 wait on one or more msg_ids
228 228
229 229 execution methods
230 230 apply
231 231 legacy: execute, run
232 232
233 233 data movement
234 234 push, pull, scatter, gather
235 235
236 236 query methods
237 237 queue_status, get_result, purge, result_status
238 238
239 239 control methods
240 240 abort, shutdown
241 241
242 242 """
243 243
244 244
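Typical construction, per the parameters documented above (a sketch; paths and hostnames are placeholders):

    from IPython.parallel import Client

    rc = Client()                                   # default profile's ipcontroller-client.json
    rc = Client(profile='mycluster')                # or a named profile
    rc = Client('/path/to/ipcontroller-client.json',
                sshserver='me@login.example.com')   # or an explicit file, tunneled over ssh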
245 245 block = Bool(False)
246 246 outstanding = Set()
247 247 results = Instance('collections.defaultdict', (dict,))
248 248 metadata = Instance('collections.defaultdict', (Metadata,))
249 249 history = List()
250 250 debug = Bool(False)
251 251
252 252 profile=Unicode()
253 253 def _profile_default(self):
254 254 if BaseIPythonApplication.initialized():
255 255 # an IPython app *might* be running, try to get its profile
256 256 try:
257 257 return BaseIPythonApplication.instance().profile
258 258 except (AttributeError, MultipleInstanceError):
259 259 # could be a *different* subclass of config.Application,
260 260 # which would raise one of these two errors.
261 261 return u'default'
262 262 else:
263 263 return u'default'
264 264
265 265
266 266 _outstanding_dict = Instance('collections.defaultdict', (set,))
267 267 _ids = List()
268 268 _connected=Bool(False)
269 269 _ssh=Bool(False)
270 270 _context = Instance('zmq.Context')
271 271 _config = Dict()
272 272 _engines=Instance(util.ReverseDict, (), {})
273 273 # _hub_socket=Instance('zmq.Socket')
274 274 _query_socket=Instance('zmq.Socket')
275 275 _control_socket=Instance('zmq.Socket')
276 276 _iopub_socket=Instance('zmq.Socket')
277 277 _notification_socket=Instance('zmq.Socket')
278 278 _mux_socket=Instance('zmq.Socket')
279 279 _task_socket=Instance('zmq.Socket')
280 280 _task_scheme=Unicode()
281 281 _closed = False
282 282 _ignored_control_replies=Int(0)
283 283 _ignored_hub_replies=Int(0)
284 284
285 285 def __new__(self, *args, **kw):
286 286 # don't raise on positional args
287 287 return HasTraits.__new__(self, **kw)
288 288
289 289 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
290 290 context=None, debug=False, exec_key=None,
291 291 sshserver=None, sshkey=None, password=None, paramiko=None,
292 292 timeout=10, **extra_args
293 293 ):
294 294 if profile:
295 295 super(Client, self).__init__(debug=debug, profile=profile)
296 296 else:
297 297 super(Client, self).__init__(debug=debug)
298 298 if context is None:
299 299 context = zmq.Context.instance()
300 300 self._context = context
301 301
302 302 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
303 303 if self._cd is not None:
304 304 if url_or_file is None:
305 305 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
306 306 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
307 307 " Please specify at least one of url_or_file or profile."
308 308
309 309 try:
310 310 util.validate_url(url_or_file)
311 311 except AssertionError:
312 312 if not os.path.exists(url_or_file):
313 313 if self._cd:
314 314 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
315 315 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
316 316 with open(url_or_file) as f:
317 317 cfg = json.loads(f.read())
318 318 else:
319 319 cfg = {'url':url_or_file}
320 320
321 321 # sync defaults from args, json:
322 322 if sshserver:
323 323 cfg['ssh'] = sshserver
324 324 if exec_key:
325 325 cfg['exec_key'] = exec_key
326 326 exec_key = cfg['exec_key']
327 327 location = cfg.setdefault('location', None)
328 328 cfg['url'] = util.disambiguate_url(cfg['url'], location)
329 329 url = cfg['url']
330 330 proto,addr,port = util.split_url(url)
331 331 if location is not None and addr == '127.0.0.1':
332 332 # location specified, and connection is expected to be local
333 333 if location not in LOCAL_IPS and not sshserver:
334 334 # load ssh from JSON *only* if the controller is not on
335 335 # this machine
336 336 sshserver=cfg['ssh']
337 337 if location not in LOCAL_IPS and not sshserver:
338 338 # warn if no ssh specified, but SSH is probably needed
339 339 # This is only a warning, because the most likely cause
340 340 # is a local Controller on a laptop whose IP is dynamic
341 341 warnings.warn("""
342 342 Controller appears to be listening on localhost, but not on this machine.
343 343 If this is true, you should specify Client(...,sshserver='you@%s')
344 344 or instruct your controller to listen on an external IP."""%location,
345 345 RuntimeWarning)
346 346 elif not sshserver:
347 347 # otherwise sync with cfg
348 348 sshserver = cfg['ssh']
349 349
350 350 self._config = cfg
351 351
352 352 self._ssh = bool(sshserver or sshkey or password)
353 353 if self._ssh and sshserver is None:
354 354 # default to ssh via localhost
355 355 sshserver = url.split('://')[1].split(':')[0]
356 356 if self._ssh and password is None:
357 357 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
358 358 password=False
359 359 else:
360 360 password = getpass("SSH Password for %s: "%sshserver)
361 361 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
362 362
363 363 # configure and construct the session
364 364 if exec_key is not None:
365 365 if os.path.isfile(exec_key):
366 366 extra_args['keyfile'] = exec_key
367 367 else:
368 368 exec_key = util.asbytes(exec_key)
369 369 extra_args['key'] = exec_key
370 370 self.session = Session(**extra_args)
371 371
372 372 self._query_socket = self._context.socket(zmq.DEALER)
373 self._query_socket.setsockopt(zmq.IDENTITY, util.asbytes(self.session.session))
373 self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
374 374 if self._ssh:
375 375 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
376 376 else:
377 377 self._query_socket.connect(url)
378 378
379 379 self.session.debug = self.debug
380 380
381 381 self._notification_handlers = {'registration_notification' : self._register_engine,
382 382 'unregistration_notification' : self._unregister_engine,
383 383 'shutdown_notification' : lambda msg: self.close(),
384 384 }
385 385 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
386 386 'apply_reply' : self._handle_apply_reply}
387 387 self._connect(sshserver, ssh_kwargs, timeout)
388 388
389 389 def __del__(self):
390 390 """cleanup sockets, but _not_ context."""
391 391 self.close()
392 392
393 393 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
394 394 if ipython_dir is None:
395 395 ipython_dir = get_ipython_dir()
396 396 if profile_dir is not None:
397 397 try:
398 398 self._cd = ProfileDir.find_profile_dir(profile_dir)
399 399 return
400 400 except ProfileDirError:
401 401 pass
402 402 elif profile is not None:
403 403 try:
404 404 self._cd = ProfileDir.find_profile_dir_by_name(
405 405 ipython_dir, profile)
406 406 return
407 407 except ProfileDirError:
408 408 pass
409 409 self._cd = None
410 410
411 411 def _update_engines(self, engines):
412 412 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
413 413 for k,v in engines.iteritems():
414 414 eid = int(k)
415 415 self._engines[eid] = v
416 416 self._ids.append(eid)
417 417 self._ids = sorted(self._ids)
418 418 if sorted(self._engines.keys()) != range(len(self._engines)) and \
419 419 self._task_scheme == 'pure' and self._task_socket:
420 420 self._stop_scheduling_tasks()
421 421
422 422 def _stop_scheduling_tasks(self):
423 423 """Stop scheduling tasks because an engine has been unregistered
424 424 from a pure ZMQ scheduler.
425 425 """
426 426 self._task_socket.close()
427 427 self._task_socket = None
428 428 msg = "An engine has been unregistered, and we are using pure " +\
429 429 "ZMQ task scheduling. Task farming will be disabled."
430 430 if self.outstanding:
431 431 msg += " If you were running tasks when this happened, " +\
432 432 "some `outstanding` msg_ids may never resolve."
433 433 warnings.warn(msg, RuntimeWarning)
434 434
435 435 def _build_targets(self, targets):
436 436 """Turn valid target IDs or 'all' into two lists:
437 437 (int_ids, uuids).
438 438 """
439 439 if not self._ids:
440 440 # flush notification socket if no engines yet, just in case
441 441 if not self.ids:
442 442 raise error.NoEnginesRegistered("Can't build targets without any engines")
443 443
444 444 if targets is None:
445 445 targets = self._ids
446 446 elif isinstance(targets, basestring):
447 447 if targets.lower() == 'all':
448 448 targets = self._ids
449 449 else:
450 450 raise TypeError("%r not valid str target, must be 'all'"%(targets))
451 451 elif isinstance(targets, int):
452 452 if targets < 0:
453 453 targets = self.ids[targets]
454 454 if targets not in self._ids:
455 455 raise IndexError("No such engine: %i"%targets)
456 456 targets = [targets]
457 457
458 458 if isinstance(targets, slice):
459 459 indices = range(len(self._ids))[targets]
460 460 ids = self.ids
461 461 targets = [ ids[i] for i in indices ]
462 462
463 463 if not isinstance(targets, (tuple, list, xrange)):
464 464 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
465 465
466 466 return [util.asbytes(self._engines[t]) for t in targets], list(targets)
467 467
468 468 def _connect(self, sshserver, ssh_kwargs, timeout):
469 469 """setup all our socket connections to the cluster. This is called from
470 470 __init__."""
471 471
472 472 # Maybe allow reconnecting?
473 473 if self._connected:
474 474 return
475 475 self._connected=True
476 476
477 477 def connect_socket(s, url):
478 478 url = util.disambiguate_url(url, self._config['location'])
479 479 if self._ssh:
480 480 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
481 481 else:
482 482 return s.connect(url)
483 483
484 484 self.session.send(self._query_socket, 'connection_request')
485 485 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
486 486 poller = zmq.Poller()
487 487 poller.register(self._query_socket, zmq.POLLIN)
488 488 # poll expects milliseconds, timeout is seconds
489 489 evts = poller.poll(timeout*1000)
490 490 if not evts:
491 491 raise error.TimeoutError("Hub connection request timed out")
492 492 idents,msg = self.session.recv(self._query_socket,mode=0)
493 493 if self.debug:
494 494 pprint(msg)
495 495 msg = Message(msg)
496 496 content = msg.content
497 497 self._config['registration'] = dict(content)
498 498 if content.status == 'ok':
499 ident = util.asbytes(self.session.session)
499 ident = self.session.bsession
500 500 if content.mux:
501 501 self._mux_socket = self._context.socket(zmq.DEALER)
502 502 self._mux_socket.setsockopt(zmq.IDENTITY, ident)
503 503 connect_socket(self._mux_socket, content.mux)
504 504 if content.task:
505 505 self._task_scheme, task_addr = content.task
506 506 self._task_socket = self._context.socket(zmq.DEALER)
507 507 self._task_socket.setsockopt(zmq.IDENTITY, ident)
508 508 connect_socket(self._task_socket, task_addr)
509 509 if content.notification:
510 510 self._notification_socket = self._context.socket(zmq.SUB)
511 511 connect_socket(self._notification_socket, content.notification)
512 512 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
513 513 # if content.query:
514 514 # self._query_socket = self._context.socket(zmq.DEALER)
515 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
515 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
516 516 # connect_socket(self._query_socket, content.query)
517 517 if content.control:
518 518 self._control_socket = self._context.socket(zmq.DEALER)
519 519 self._control_socket.setsockopt(zmq.IDENTITY, ident)
520 520 connect_socket(self._control_socket, content.control)
521 521 if content.iopub:
522 522 self._iopub_socket = self._context.socket(zmq.SUB)
523 523 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
524 524 self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
525 525 connect_socket(self._iopub_socket, content.iopub)
526 526 self._update_engines(dict(content.engines))
527 527 else:
528 528 self._connected = False
529 529 raise Exception("Failed to connect!")
530 530
531 531 #--------------------------------------------------------------------------
532 532 # handlers and callbacks for incoming messages
533 533 #--------------------------------------------------------------------------
534 534
535 535 def _unwrap_exception(self, content):
536 536 """unwrap exception, and remap engine_id to int."""
537 537 e = error.unwrap_exception(content)
538 538 # print e.traceback
539 539 if e.engine_info:
540 540 e_uuid = e.engine_info['engine_uuid']
541 541 eid = self._engines[e_uuid]
542 542 e.engine_info['engine_id'] = eid
543 543 return e
544 544
545 545 def _extract_metadata(self, header, parent, content):
546 546 md = {'msg_id' : parent['msg_id'],
547 547 'received' : datetime.now(),
548 548 'engine_uuid' : header.get('engine', None),
549 549 'follow' : parent.get('follow', []),
550 550 'after' : parent.get('after', []),
551 551 'status' : content['status'],
552 552 }
553 553
554 554 if md['engine_uuid'] is not None:
555 555 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
556 556
557 557 if 'date' in parent:
558 558 md['submitted'] = parent['date']
559 559 if 'started' in header:
560 560 md['started'] = header['started']
561 561 if 'date' in header:
562 562 md['completed'] = header['date']
563 563 return md
564 564
565 565 def _register_engine(self, msg):
566 566 """Register a new engine, and update our connection info."""
567 567 content = msg['content']
568 568 eid = content['id']
569 569 d = {eid : content['queue']}
570 570 self._update_engines(d)
571 571
572 572 def _unregister_engine(self, msg):
573 573 """Unregister an engine that has died."""
574 574 content = msg['content']
575 575 eid = int(content['id'])
576 576 if eid in self._ids:
577 577 self._ids.remove(eid)
578 578 uuid = self._engines.pop(eid)
579 579
580 580 self._handle_stranded_msgs(eid, uuid)
581 581
582 582 if self._task_socket and self._task_scheme == 'pure':
583 583 self._stop_scheduling_tasks()
584 584
585 585 def _handle_stranded_msgs(self, eid, uuid):
586 586 """Handle messages known to be on an engine when the engine unregisters.
587 587
588 588 It is possible that this will fire prematurely - that is, an engine will
589 589 go down after completing a result, and the client will be notified
590 590 of the unregistration and later receive the successful result.
591 591 """
592 592
593 593 outstanding = self._outstanding_dict[uuid]
594 594
595 595 for msg_id in list(outstanding):
596 596 if msg_id in self.results:
597 597 # we already have this result
598 598 continue
599 599 try:
600 600 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
601 601 except:
602 602 content = error.wrap_exception()
603 603 # build a fake message:
604 604 parent = {}
605 605 header = {}
606 606 parent['msg_id'] = msg_id
607 607 header['engine'] = uuid
608 608 header['date'] = datetime.now()
609 609 msg = dict(parent_header=parent, header=header, content=content)
610 610 self._handle_apply_reply(msg)
611 611
612 612 def _handle_execute_reply(self, msg):
613 613 """Save the reply to an execute_request into our results.
614 614
615 615 execute messages are never actually used. apply is used instead.
616 616 """
617 617
618 618 parent = msg['parent_header']
619 619 msg_id = parent['msg_id']
620 620 if msg_id not in self.outstanding:
621 621 if msg_id in self.history:
622 622 print ("got stale result: %s"%msg_id)
623 623 else:
624 624 print ("got unknown result: %s"%msg_id)
625 625 else:
626 626 self.outstanding.remove(msg_id)
627 627 self.results[msg_id] = self._unwrap_exception(msg['content'])
628 628
629 629 def _handle_apply_reply(self, msg):
630 630 """Save the reply to an apply_request into our results."""
631 631 parent = msg['parent_header']
632 632 msg_id = parent['msg_id']
633 633 if msg_id not in self.outstanding:
634 634 if msg_id in self.history:
635 635 print ("got stale result: %s"%msg_id)
636 636 print self.results[msg_id]
637 637 print msg
638 638 else:
639 639 print ("got unknown result: %s"%msg_id)
640 640 else:
641 641 self.outstanding.remove(msg_id)
642 642 content = msg['content']
643 643 header = msg['header']
644 644
645 645 # construct metadata:
646 646 md = self.metadata[msg_id]
647 647 md.update(self._extract_metadata(header, parent, content))
648 648 # is this redundant?
649 649 self.metadata[msg_id] = md
650 650
651 651 e_outstanding = self._outstanding_dict[md['engine_uuid']]
652 652 if msg_id in e_outstanding:
653 653 e_outstanding.remove(msg_id)
654 654
655 655 # construct result:
656 656 if content['status'] == 'ok':
657 657 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
658 658 elif content['status'] == 'aborted':
659 659 self.results[msg_id] = error.TaskAborted(msg_id)
660 660 elif content['status'] == 'resubmitted':
661 661 # TODO: handle resubmission
662 662 pass
663 663 else:
664 664 self.results[msg_id] = self._unwrap_exception(content)
665 665
666 666 def _flush_notifications(self):
667 667 """Flush notifications of engine registrations waiting
668 668 in ZMQ queue."""
669 669 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
670 670 while msg is not None:
671 671 if self.debug:
672 672 pprint(msg)
673 673 msg_type = msg['header']['msg_type']
674 674 handler = self._notification_handlers.get(msg_type, None)
675 675 if handler is None:
676 676 raise Exception("Unhandled message type: %s"%msg_type)
677 677 else:
678 678 handler(msg)
679 679 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
680 680
681 681 def _flush_results(self, sock):
682 682 """Flush task or queue results waiting in ZMQ queue."""
683 683 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
684 684 while msg is not None:
685 685 if self.debug:
686 686 pprint(msg)
687 687 msg_type = msg['header']['msg_type']
688 688 handler = self._queue_handlers.get(msg_type, None)
689 689 if handler is None:
690 690 raise Exception("Unhandled message type: %s"%msg_type)
691 691 else:
692 692 handler(msg)
693 693 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
694 694
695 695 def _flush_control(self, sock):
696 696 """Flush replies from the control channel waiting
697 697 in the ZMQ queue.
698 698
699 699 Currently: ignore them."""
700 700 if self._ignored_control_replies <= 0:
701 701 return
702 702 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
703 703 while msg is not None:
704 704 self._ignored_control_replies -= 1
705 705 if self.debug:
706 706 pprint(msg)
707 707 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
708 708
709 709 def _flush_ignored_control(self):
710 710 """flush ignored control replies"""
711 711 while self._ignored_control_replies > 0:
712 712 self.session.recv(self._control_socket)
713 713 self._ignored_control_replies -= 1
714 714
715 715 def _flush_ignored_hub_replies(self):
716 716 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
717 717 while msg is not None:
718 718 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
719 719
720 720 def _flush_iopub(self, sock):
721 721 """Flush replies from the iopub channel waiting
722 722 in the ZMQ queue.
723 723 """
724 724 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
725 725 while msg is not None:
726 726 if self.debug:
727 727 pprint(msg)
728 728 parent = msg['parent_header']
729 729 msg_id = parent['msg_id']
730 730 content = msg['content']
731 731 header = msg['header']
732 732 msg_type = msg['header']['msg_type']
733 733
734 734 # init metadata:
735 735 md = self.metadata[msg_id]
736 736
737 737 if msg_type == 'stream':
738 738 name = content['name']
739 739 s = md[name] or ''
740 740 md[name] = s + content['data']
741 741 elif msg_type == 'pyerr':
742 742 md.update({'pyerr' : self._unwrap_exception(content)})
743 743 elif msg_type == 'pyin':
744 744 md.update({'pyin' : content['code']})
745 745 else:
746 746 md.update({msg_type : content.get('data', '')})
747 747
748 748 # redundant?
749 749 self.metadata[msg_id] = md
750 750
751 751 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
752 752
753 753 #--------------------------------------------------------------------------
754 754 # len, getitem
755 755 #--------------------------------------------------------------------------
756 756
757 757 def __len__(self):
758 758 """len(client) returns # of engines."""
759 759 return len(self.ids)
760 760
761 761 def __getitem__(self, key):
762 762 """index access returns DirectView multiplexer objects
763 763
764 764 Must be int, slice, or list/tuple/xrange of ints"""
765 765 if not isinstance(key, (int, slice, tuple, list, xrange)):
766 766 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
767 767 else:
768 768 return self.direct_view(key)
769 769
770 770 #--------------------------------------------------------------------------
771 771 # Begin public methods
772 772 #--------------------------------------------------------------------------
773 773
774 774 @property
775 775 def ids(self):
776 776 """Always up-to-date ids property."""
777 777 self._flush_notifications()
778 778 # always copy:
779 779 return list(self._ids)
780 780
781 781 def close(self):
782 782 if self._closed:
783 783 return
784 784 snames = filter(lambda n: n.endswith('socket'), dir(self))
785 785 for socket in map(lambda name: getattr(self, name), snames):
786 786 if isinstance(socket, zmq.Socket) and not socket.closed:
787 787 socket.close()
788 788 self._closed = True
789 789
790 790 def spin(self):
791 791 """Flush any registration notifications and execution results
792 792 waiting in the ZMQ queue.
793 793 """
794 794 if self._notification_socket:
795 795 self._flush_notifications()
796 796 if self._mux_socket:
797 797 self._flush_results(self._mux_socket)
798 798 if self._task_socket:
799 799 self._flush_results(self._task_socket)
800 800 if self._control_socket:
801 801 self._flush_control(self._control_socket)
802 802 if self._iopub_socket:
803 803 self._flush_iopub(self._iopub_socket)
804 804 if self._query_socket:
805 805 self._flush_ignored_hub_replies()
806 806
807 807 def wait(self, jobs=None, timeout=-1):
808 808 """waits on one or more `jobs`, for up to `timeout` seconds.
809 809
810 810 Parameters
811 811 ----------
812 812
813 813 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
814 814 ints are indices to self.history
815 815 strs are msg_ids
816 816 default: wait on all outstanding messages
817 817 timeout : float
818 818 a time in seconds, after which to give up.
819 819 default is -1, which means no timeout
820 820
821 821 Returns
822 822 -------
823 823
824 824 True : when all msg_ids are done
825 825 False : timeout reached, some msg_ids still outstanding
826 826 """
827 827 tic = time.time()
828 828 if jobs is None:
829 829 theids = self.outstanding
830 830 else:
831 831 if isinstance(jobs, (int, basestring, AsyncResult)):
832 832 jobs = [jobs]
833 833 theids = set()
834 834 for job in jobs:
835 835 if isinstance(job, int):
836 836 # index access
837 837 job = self.history[job]
838 838 elif isinstance(job, AsyncResult):
839 839 map(theids.add, job.msg_ids)
840 840 continue
841 841 theids.add(job)
842 842 if not theids.intersection(self.outstanding):
843 843 return True
844 844 self.spin()
845 845 while theids.intersection(self.outstanding):
846 846 if timeout >= 0 and ( time.time()-tic ) > timeout:
847 847 break
848 848 time.sleep(1e-3)
849 849 self.spin()
850 850 return len(theids.intersection(self.outstanding)) == 0
851 851
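A usage sketch for wait() (illustrative only; rc is a connected Client as above):

    ar = rc[:].apply_async(lambda x: x*x, 5)   # submit via a DirectView
    rc.wait([ar], timeout=10)                  # block up to 10s; False if still outstanding
    rc.wait()                                  # no arguments: wait on all outstanding msg_ids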
852 852 #--------------------------------------------------------------------------
853 853 # Control methods
854 854 #--------------------------------------------------------------------------
855 855
856 856 @spin_first
857 857 def clear(self, targets=None, block=None):
858 858 """Clear the namespace in target(s)."""
859 859 block = self.block if block is None else block
860 860 targets = self._build_targets(targets)[0]
861 861 for t in targets:
862 862 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
863 863 error = False
864 864 if block:
865 865 self._flush_ignored_control()
866 866 for i in range(len(targets)):
867 867 idents,msg = self.session.recv(self._control_socket,0)
868 868 if self.debug:
869 869 pprint(msg)
870 870 if msg['content']['status'] != 'ok':
871 871 error = self._unwrap_exception(msg['content'])
872 872 else:
873 873 self._ignored_control_replies += len(targets)
874 874 if error:
875 875 raise error
876 876
877 877
878 878 @spin_first
879 879 def abort(self, jobs=None, targets=None, block=None):
880 880 """Abort specific jobs from the execution queues of target(s).
881 881
882 882 This is a mechanism to prevent jobs that have already been submitted
883 883 from executing.
884 884
885 885 Parameters
886 886 ----------
887 887
888 888 jobs : msg_id, list of msg_ids, or AsyncResult
889 889 The jobs to be aborted
890 890
891 891
892 892 """
893 893 block = self.block if block is None else block
894 894 targets = self._build_targets(targets)[0]
895 895 msg_ids = []
896 896 if isinstance(jobs, (basestring,AsyncResult)):
897 897 jobs = [jobs]
898 898 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
899 899 if bad_ids:
900 900 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
901 901 for j in jobs:
902 902 if isinstance(j, AsyncResult):
903 903 msg_ids.extend(j.msg_ids)
904 904 else:
905 905 msg_ids.append(j)
906 906 content = dict(msg_ids=msg_ids)
907 907 for t in targets:
908 908 self.session.send(self._control_socket, 'abort_request',
909 909 content=content, ident=t)
910 910 error = False
911 911 if block:
912 912 self._flush_ignored_control()
913 913 for i in range(len(targets)):
914 914 idents,msg = self.session.recv(self._control_socket,0)
915 915 if self.debug:
916 916 pprint(msg)
917 917 if msg['content']['status'] != 'ok':
918 918 error = self._unwrap_exception(msg['content'])
919 919 else:
920 920 self._ignored_control_replies += len(targets)
921 921 if error:
922 922 raise error
923 923
924 924 @spin_first
925 925 def shutdown(self, targets=None, restart=False, hub=False, block=None):
926 926 """Terminates one or more engine processes, optionally including the hub."""
927 927 block = self.block if block is None else block
928 928 if hub:
929 929 targets = 'all'
930 930 targets = self._build_targets(targets)[0]
931 931 for t in targets:
932 932 self.session.send(self._control_socket, 'shutdown_request',
933 933 content={'restart':restart},ident=t)
934 934 error = False
935 935 if block or hub:
936 936 self._flush_ignored_control()
937 937 for i in range(len(targets)):
938 938 idents,msg = self.session.recv(self._control_socket, 0)
939 939 if self.debug:
940 940 pprint(msg)
941 941 if msg['content']['status'] != 'ok':
942 942 error = self._unwrap_exception(msg['content'])
943 943 else:
944 944 self._ignored_control_replies += len(targets)
945 945
946 946 if hub:
947 947 time.sleep(0.25)
948 948 self.session.send(self._query_socket, 'shutdown_request')
949 949 idents,msg = self.session.recv(self._query_socket, 0)
950 950 if self.debug:
951 951 pprint(msg)
952 952 if msg['content']['status'] != 'ok':
953 953 error = self._unwrap_exception(msg['content'])
954 954
955 955 if error:
956 956 raise error
957 957
958 958 #--------------------------------------------------------------------------
959 959 # Execution related methods
960 960 #--------------------------------------------------------------------------
961 961
962 962 def _maybe_raise(self, result):
963 963 """wrapper for maybe raising an exception if apply failed."""
964 964 if isinstance(result, error.RemoteError):
965 965 raise result
966 966
967 967 return result
968 968
969 969 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
970 970 ident=None):
971 971 """construct and send an apply message via a socket.
972 972
973 973 This is the principal method with which all engine execution is performed by views.
974 974 """
975 975
976 976 assert not self._closed, "cannot use me anymore, I'm closed!"
977 977 # defaults:
978 978 args = args if args is not None else []
979 979 kwargs = kwargs if kwargs is not None else {}
980 980 subheader = subheader if subheader is not None else {}
981 981
982 982 # validate arguments
983 983 if not callable(f):
984 984 raise TypeError("f must be callable, not %s"%type(f))
985 985 if not isinstance(args, (tuple, list)):
986 986 raise TypeError("args must be tuple or list, not %s"%type(args))
987 987 if not isinstance(kwargs, dict):
988 988 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
989 989 if not isinstance(subheader, dict):
990 990 raise TypeError("subheader must be dict, not %s"%type(subheader))
991 991
992 992 bufs = util.pack_apply_message(f,args,kwargs)
993 993
994 994 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
995 995 subheader=subheader, track=track)
996 996
997 997 msg_id = msg['header']['msg_id']
998 998 self.outstanding.add(msg_id)
999 999 if ident:
1000 1000 # possibly routed to a specific engine
1001 1001 if isinstance(ident, list):
1002 1002 ident = ident[-1]
1003 1003 if ident in self._engines.values():
1004 1004 # save for later, in case of engine death
1005 1005 self._outstanding_dict[ident].add(msg_id)
1006 1006 self.history.append(msg_id)
1007 1007 self.metadata[msg_id]['submitted'] = datetime.now()
1008 1008
1009 1009 return msg
1010 1010
1011 1011 #--------------------------------------------------------------------------
1012 1012 # construct a View object
1013 1013 #--------------------------------------------------------------------------
1014 1014
1015 1015 def load_balanced_view(self, targets=None):
1016 1016 """construct a DirectView object.
1017 1017
1018 1018 If no arguments are specified, create a LoadBalancedView
1019 1019 using all engines.
1020 1020
1021 1021 Parameters
1022 1022 ----------
1023 1023
1024 1024 targets: list,slice,int,etc. [default: use all engines]
1025 1025 The subset of engines across which to load-balance
1026 1026 """
1027 1027 if targets == 'all':
1028 1028 targets = None
1029 1029 if targets is not None:
1030 1030 targets = self._build_targets(targets)[1]
1031 1031 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1032 1032
1033 1033 def direct_view(self, targets='all'):
1034 1034 """construct a DirectView object.
1035 1035
1036 1036 If no targets are specified, create a DirectView
1037 1037 using all engines.
1038 1038
1039 1039 Parameters
1040 1040 ----------
1041 1041
1042 1042 targets: list,slice,int,etc. [default: use all engines]
1043 1043 The engines to use for the View
1044 1044 """
1045 1045 single = isinstance(targets, int)
1046 1046 # allow 'all' to be lazily evaluated at each execution
1047 1047 if targets != 'all':
1048 1048 targets = self._build_targets(targets)[1]
1049 1049 if single:
1050 1050 targets = targets[0]
1051 1051 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1052 1052
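A sketch of the two view constructors above (engine ids are illustrative):

    dv  = rc.direct_view('all')           # DirectView over all engines, resolved lazily
    dv0 = rc[0]                           # __getitem__ is sugar for direct_view
    evens = rc[::2]                       # slices work too
    lv  = rc.load_balanced_view()         # LoadBalancedView over all engines
    lv2 = rc.load_balanced_view([0, 1])   # or restricted to a subset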
1053 1053 #--------------------------------------------------------------------------
1054 1054 # Query methods
1055 1055 #--------------------------------------------------------------------------
1056 1056
1057 1057 @spin_first
1058 1058 def get_result(self, indices_or_msg_ids=None, block=None):
1059 1059 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1060 1060
1061 1061 If the client already has the results, no request to the Hub will be made.
1062 1062
1063 1063 This is a convenient way to construct AsyncResult objects, which are wrappers
1064 1064 that include metadata about execution, and allow for awaiting results that
1065 1065 were not submitted by this Client.
1066 1066
1067 1067 It can also be a convenient way to retrieve the metadata associated with
1068 1068 blocking execution, since it always retrieves
1069 1069
1070 1070 Examples
1071 1071 --------
1072 1072 ::
1073 1073
1074 1074 In [10]: r = client.apply()
1075 1075
1076 1076 Parameters
1077 1077 ----------
1078 1078
1079 1079 indices_or_msg_ids : integer history index, str msg_id, or list of either
1080 1080 The history indices or msg_ids of the results to be retrieved
1081 1081
1082 1082 block : bool
1083 1083 Whether to wait for the result to be done
1084 1084
1085 1085 Returns
1086 1086 -------
1087 1087
1088 1088 AsyncResult
1089 1089 A single AsyncResult object will always be returned.
1090 1090
1091 1091 AsyncHubResult
1092 1092 A subclass of AsyncResult that retrieves results from the Hub
1093 1093
1094 1094 """
1095 1095 block = self.block if block is None else block
1096 1096 if indices_or_msg_ids is None:
1097 1097 indices_or_msg_ids = -1
1098 1098
1099 1099 if not isinstance(indices_or_msg_ids, (list,tuple)):
1100 1100 indices_or_msg_ids = [indices_or_msg_ids]
1101 1101
1102 1102 theids = []
1103 1103 for id in indices_or_msg_ids:
1104 1104 if isinstance(id, int):
1105 1105 id = self.history[id]
1106 1106 if not isinstance(id, basestring):
1107 1107 raise TypeError("indices must be str or int, not %r"%id)
1108 1108 theids.append(id)
1109 1109
1110 1110 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1111 1111 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1112 1112
1113 1113 if remote_ids:
1114 1114 ar = AsyncHubResult(self, msg_ids=theids)
1115 1115 else:
1116 1116 ar = AsyncResult(self, msg_ids=theids)
1117 1117
1118 1118 if block:
1119 1119 ar.wait()
1120 1120
1121 1121 return ar
1122 1122
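For example (a sketch, assuming prior submissions in rc.history):

    ar = rc.get_result(-1, block=False)   # AsyncResult for the most recent submission
    ar.get()                              # wait for, and return, the result
    md = rc.metadata[ar.msg_ids[0]]       # timing/engine metadata gathered along the way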
1123 1123 @spin_first
1124 1124 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1125 1125 """Resubmit one or more tasks.
1126 1126
1127 1127 In-flight tasks may not be resubmitted.
1128 1128
1129 1129 Parameters
1130 1130 ----------
1131 1131
1132 1132 indices_or_msg_ids : integer history index, str msg_id, or list of either
1133 1133 The history indices or msg_ids of the results to be retrieved
1134 1134
1135 1135 block : bool
1136 1136 Whether to wait for the result to be done
1137 1137
1138 1138 Returns
1139 1139 -------
1140 1140
1141 1141 AsyncHubResult
1142 1142 A subclass of AsyncResult that retrieves results from the Hub
1143 1143
1144 1144 """
1145 1145 block = self.block if block is None else block
1146 1146 if indices_or_msg_ids is None:
1147 1147 indices_or_msg_ids = -1
1148 1148
1149 1149 if not isinstance(indices_or_msg_ids, (list,tuple)):
1150 1150 indices_or_msg_ids = [indices_or_msg_ids]
1151 1151
1152 1152 theids = []
1153 1153 for id in indices_or_msg_ids:
1154 1154 if isinstance(id, int):
1155 1155 id = self.history[id]
1156 1156 if not isinstance(id, basestring):
1157 1157 raise TypeError("indices must be str or int, not %r"%id)
1158 1158 theids.append(id)
1159 1159
1160 1160 for msg_id in theids:
1161 1161 self.outstanding.discard(msg_id)
1162 1162 if msg_id in self.history:
1163 1163 self.history.remove(msg_id)
1164 1164 self.results.pop(msg_id, None)
1165 1165 self.metadata.pop(msg_id, None)
1166 1166 content = dict(msg_ids = theids)
1167 1167
1168 1168 self.session.send(self._query_socket, 'resubmit_request', content)
1169 1169
1170 1170 zmq.select([self._query_socket], [], [])
1171 1171 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1172 1172 if self.debug:
1173 1173 pprint(msg)
1174 1174 content = msg['content']
1175 1175 if content['status'] != 'ok':
1176 1176 raise self._unwrap_exception(content)
1177 1177
1178 1178 ar = AsyncHubResult(self, msg_ids=theids)
1179 1179
1180 1180 if block:
1181 1181 ar.wait()
1182 1182
1183 1183 return ar
1184 1184
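For example (a sketch; ar_old is a prior AsyncResult whose tasks have finished):

    ar = rc.resubmit(ar_old.msg_ids)   # re-run completed tasks by msg_id
    ar = rc.resubmit()                 # or resubmit the most recent history entry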
1185 1185 @spin_first
1186 1186 def result_status(self, msg_ids, status_only=True):
1187 1187 """Check on the status of the result(s) of the apply request with `msg_ids`.
1188 1188
1189 1189 If status_only is False, then the actual results will be retrieved, else
1190 1190 only the status of the results will be checked.
1191 1191
1192 1192 Parameters
1193 1193 ----------
1194 1194
1195 1195 msg_ids : list of msg_ids
1196 1196 if int:
1197 1197 Passed as index to self.history for convenience.
1198 1198 status_only : bool (default: True)
1199 1199 if False:
1200 1200 Retrieve the actual results of completed tasks.
1201 1201
1202 1202 Returns
1203 1203 -------
1204 1204
1205 1205 results : dict
1206 1206 There will always be the keys 'pending' and 'completed', which will
1207 1207 be lists of msg_ids that are incomplete or complete. If `status_only`
1208 1208 is False, then completed results will be keyed by their `msg_id`.
1209 1209 """
1210 1210 if not isinstance(msg_ids, (list,tuple)):
1211 1211 msg_ids = [msg_ids]
1212 1212
1213 1213 theids = []
1214 1214 for msg_id in msg_ids:
1215 1215 if isinstance(msg_id, int):
1216 1216 msg_id = self.history[msg_id]
1217 1217 if not isinstance(msg_id, basestring):
1218 1218 raise TypeError("msg_ids must be str, not %r"%msg_id)
1219 1219 theids.append(msg_id)
1220 1220
1221 1221 completed = []
1222 1222 local_results = {}
1223 1223
1224 1224 # comment this block out to temporarily disable local shortcut:
1225 1225 for msg_id in list(theids): # iterate over a copy; msg_ids are removed from theids below
1226 1226 if msg_id in self.results:
1227 1227 completed.append(msg_id)
1228 1228 local_results[msg_id] = self.results[msg_id]
1229 1229 theids.remove(msg_id)
1230 1230
1231 1231 if theids: # some not locally cached
1232 1232 content = dict(msg_ids=theids, status_only=status_only)
1233 1233 msg = self.session.send(self._query_socket, "result_request", content=content)
1234 1234 zmq.select([self._query_socket], [], [])
1235 1235 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1236 1236 if self.debug:
1237 1237 pprint(msg)
1238 1238 content = msg['content']
1239 1239 if content['status'] != 'ok':
1240 1240 raise self._unwrap_exception(content)
1241 1241 buffers = msg['buffers']
1242 1242 else:
1243 1243 content = dict(completed=[],pending=[])
1244 1244
1245 1245 content['completed'].extend(completed)
1246 1246
1247 1247 if status_only:
1248 1248 return content
1249 1249
1250 1250 failures = []
1251 1251 # load cached results into result:
1252 1252 content.update(local_results)
1253 1253
1254 1254 # update cache with results:
1255 1255 for msg_id in sorted(theids):
1256 1256 if msg_id in content['completed']:
1257 1257 rec = content[msg_id]
1258 1258 parent = rec['header']
1259 1259 header = rec['result_header']
1260 1260 rcontent = rec['result_content']
1261 1261 iodict = rec['io']
1262 1262 if isinstance(rcontent, str):
1263 1263 rcontent = self.session.unpack(rcontent)
1264 1264
1265 1265 md = self.metadata[msg_id]
1266 1266 md.update(self._extract_metadata(header, parent, rcontent))
1267 1267 md.update(iodict)
1268 1268
1269 1269 if rcontent['status'] == 'ok':
1270 1270 res,buffers = util.unserialize_object(buffers)
1271 1271 else:
1272 1272 print rcontent
1273 1273 res = self._unwrap_exception(rcontent)
1274 1274 failures.append(res)
1275 1275
1276 1276 self.results[msg_id] = res
1277 1277 content[msg_id] = res
1278 1278
1279 1279 if len(theids) == 1 and failures:
1280 1280 raise failures[0]
1281 1281
1282 1282 error.collect_exceptions(failures, "result_status")
1283 1283 return content
1284 1284
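For example (a sketch; the reply shape follows the docstring above):

    st = rc.result_status(rc.history[-5:])                      # {'completed': [...], 'pending': [...]}
    res = rc.result_status(rc.history[-5:], status_only=False)  # completed results also keyed by msg_id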
1285 1285 @spin_first
1286 1286 def queue_status(self, targets='all', verbose=False):
1287 1287 """Fetch the status of engine queues.
1288 1288
1289 1289 Parameters
1290 1290 ----------
1291 1291
1292 1292 targets : int/str/list of ints/strs
1293 1293 the engines whose states are to be queried.
1294 1294 default : all
1295 1295 verbose : bool
1296 1296 Whether to return lengths only, or lists of ids for each element
1297 1297 """
1298 1298 engine_ids = self._build_targets(targets)[1]
1299 1299 content = dict(targets=engine_ids, verbose=verbose)
1300 1300 self.session.send(self._query_socket, "queue_request", content=content)
1301 1301 idents,msg = self.session.recv(self._query_socket, 0)
1302 1302 if self.debug:
1303 1303 pprint(msg)
1304 1304 content = msg['content']
1305 1305 status = content.pop('status')
1306 1306 if status != 'ok':
1307 1307 raise self._unwrap_exception(content)
1308 1308 content = rekey(content)
1309 1309 if isinstance(targets, int):
1310 1310 return content[targets]
1311 1311 else:
1312 1312 return content
1313 1313
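For example (a sketch; the reply shape is indicative, not verbatim):

    rc.queue_status()                  # e.g. {0: {'queue': 0, 'tasks': 1, 'completed': 4}, ...}
    rc.queue_status(0, verbose=True)   # lists of msg_ids instead of counts, for engine 0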
1314 1314 @spin_first
1315 1315 def purge_results(self, jobs=[], targets=[]):
1316 1316 """Tell the Hub to forget results.
1317 1317
1318 1318 Individual results can be purged by msg_id, or the entire
1319 1319 history of specific targets can be purged.
1320 1320
1321 1321 Use `purge_results('all')` to scrub everything from the Hub's db.
1322 1322
1323 1323 Parameters
1324 1324 ----------
1325 1325
1326 1326 jobs : str or list of str or AsyncResult objects
1327 1327 the msg_ids whose results should be forgotten.
1328 1328 targets : int/str/list of ints/strs
1329 1329 The targets, by int_id, whose entire history is to be purged.
1330 1330
1331 1331 default : None
1332 1332 """
1333 1333 if not targets and not jobs:
1334 1334 raise ValueError("Must specify at least one of `targets` and `jobs`")
1335 1335 if targets:
1336 1336 targets = self._build_targets(targets)[1]
1337 1337
1338 1338 # construct msg_ids from jobs
1339 1339 if jobs == 'all':
1340 1340 msg_ids = jobs
1341 1341 else:
1342 1342 msg_ids = []
1343 1343 if isinstance(jobs, (basestring,AsyncResult)):
1344 1344 jobs = [jobs]
1345 1345 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1346 1346 if bad_ids:
1347 1347 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1348 1348 for j in jobs:
1349 1349 if isinstance(j, AsyncResult):
1350 1350 msg_ids.extend(j.msg_ids)
1351 1351 else:
1352 1352 msg_ids.append(j)
1353 1353
1354 1354 content = dict(engine_ids=targets, msg_ids=msg_ids)
1355 1355 self.session.send(self._query_socket, "purge_request", content=content)
1356 1356 idents, msg = self.session.recv(self._query_socket, 0)
1357 1357 if self.debug:
1358 1358 pprint(msg)
1359 1359 content = msg['content']
1360 1360 if content['status'] != 'ok':
1361 1361 raise self._unwrap_exception(content)
1362 1362
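For example (a sketch; ar is an AsyncResult from an earlier submission):

    rc.purge_results(jobs=ar)          # forget specific results, by AsyncResult or msg_id
    rc.purge_results(targets=[0, 1])   # or the full history of engines 0 and 1
    rc.purge_results('all')            # scrub everything from the Hub's db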
1363 1363 @spin_first
1364 1364 def hub_history(self):
1365 1365 """Get the Hub's history
1366 1366
1367 1367 Just like the Client, the Hub has a history, which is a list of msg_ids.
1368 1368 This will contain the history of all clients, and, depending on configuration,
1369 1369 may contain history across multiple cluster sessions.
1370 1370
1371 1371 Any msg_id returned here is a valid argument to `get_result`.
1372 1372
1373 1373 Returns
1374 1374 -------
1375 1375
1376 1376 msg_ids : list of strs
1377 1377 list of all msg_ids, ordered by task submission time.
1378 1378 """
1379 1379
1380 1380 self.session.send(self._query_socket, "history_request", content={})
1381 1381 idents, msg = self.session.recv(self._query_socket, 0)
1382 1382
1383 1383 if self.debug:
1384 1384 pprint(msg)
1385 1385 content = msg['content']
1386 1386 if content['status'] != 'ok':
1387 1387 raise self._unwrap_exception(content)
1388 1388 else:
1389 1389 return content['history']
1390 1390
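For example (a sketch):

    msg_ids = rc.hub_history()        # every msg_id the Hub has recorded, oldest first
    ar = rc.get_result(msg_ids[-1])   # any of them is a valid argument to get_result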
1391 1391 @spin_first
1392 1392 def db_query(self, query, keys=None):
1393 1393 """Query the Hub's TaskRecord database
1394 1394
1395 1395 This will return a list of task record dicts that match `query`
1396 1396
1397 1397 Parameters
1398 1398 ----------
1399 1399
1400 1400 query : mongodb query dict
1401 1401 The search dict. See mongodb query docs for details.
1402 1402 keys : list of strs [optional]
1403 1403 The subset of keys to be returned. The default is to fetch everything but buffers.
1404 1404 'msg_id' will *always* be included.
1405 1405 """
1406 1406 if isinstance(keys, basestring):
1407 1407 keys = [keys]
1408 1408 content = dict(query=query, keys=keys)
1409 1409 self.session.send(self._query_socket, "db_request", content=content)
1410 1410 idents, msg = self.session.recv(self._query_socket, 0)
1411 1411 if self.debug:
1412 1412 pprint(msg)
1413 1413 content = msg['content']
1414 1414 if content['status'] != 'ok':
1415 1415 raise self._unwrap_exception(content)
1416 1416
1417 1417 records = content['records']
1418 1418
1419 1419 buffer_lens = content['buffer_lens']
1420 1420 result_buffer_lens = content['result_buffer_lens']
1421 1421 buffers = msg['buffers']
1422 1422 has_bufs = buffer_lens is not None
1423 1423 has_rbufs = result_buffer_lens is not None
1424 1424 for i,rec in enumerate(records):
1425 1425 # relink buffers
1426 1426 if has_bufs:
1427 1427 blen = buffer_lens[i]
1428 1428 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1429 1429 if has_rbufs:
1430 1430 blen = result_buffer_lens[i]
1431 1431 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1432 1432
1433 1433 return records
1434 1434
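For example (a sketch; the query syntax follows the mongodb-style dicts described above):

    recs = rc.db_query({'completed': None})   # records for still-unfinished tasks
    recs = rc.db_query({'msg_id': {'$in': rc.hub_history()[-10:]}},
                       keys=['msg_id', 'completed', 'engine_uuid'])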
1435 1435 __all__ = [ 'Client' ]
IPython/parallel/controller/hub.py
@@ -1,1290 +1,1290 @@
1 1 """The IPython Controller Hub with 0MQ
2 2 This is the master object that handles connections from engines and clients,
3 3 and monitors traffic through the various queues.
4 4
5 5 Authors:
6 6
7 7 * Min RK
8 8 """
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (C) 2010 The IPython Development Team
11 11 #
12 12 # Distributed under the terms of the BSD License. The full license is in
13 13 # the file COPYING, distributed as part of this software.
14 14 #-----------------------------------------------------------------------------
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Imports
18 18 #-----------------------------------------------------------------------------
19 19 from __future__ import print_function
20 20
21 21 import sys
22 22 import time
23 23 from datetime import datetime
24 24
25 25 import zmq
26 26 from zmq.eventloop import ioloop
27 27 from zmq.eventloop.zmqstream import ZMQStream
28 28
29 29 # internal:
30 30 from IPython.utils.importstring import import_item
31 31 from IPython.utils.traitlets import (
32 32 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
33 33 )
34 34
35 35 from IPython.parallel import error, util
36 36 from IPython.parallel.factory import RegistrationFactory
37 37
38 38 from IPython.zmq.session import SessionFactory
39 39
40 40 from .heartmonitor import HeartMonitor
41 41
42 42 #-----------------------------------------------------------------------------
43 43 # Code
44 44 #-----------------------------------------------------------------------------
45 45
46 46 def _passer(*args, **kwargs):
47 47 return
48 48
49 49 def _printer(*args, **kwargs):
50 50 print (args)
51 51 print (kwargs)
52 52
53 53 def empty_record():
54 54 """Return an empty dict with all record keys."""
55 55 return {
56 56 'msg_id' : None,
57 57 'header' : None,
58 58 'content': None,
59 59 'buffers': None,
60 60 'submitted': None,
61 61 'client_uuid' : None,
62 62 'engine_uuid' : None,
63 63 'started': None,
64 64 'completed': None,
65 65 'resubmitted': None,
66 66 'result_header' : None,
67 67 'result_content' : None,
68 68 'result_buffers' : None,
69 69 'queue' : None,
70 70 'pyin' : None,
71 71 'pyout': None,
72 72 'pyerr': None,
73 73 'stdout': '',
74 74 'stderr': '',
75 75 }
76 76
77 77 def init_record(msg):
78 78 """Initialize a TaskRecord based on a request."""
79 79 header = msg['header']
80 80 return {
81 81 'msg_id' : header['msg_id'],
82 82 'header' : header,
83 83 'content': msg['content'],
84 84 'buffers': msg['buffers'],
85 85 'submitted': header['date'],
86 86 'client_uuid' : None,
87 87 'engine_uuid' : None,
88 88 'started': None,
89 89 'completed': None,
90 90 'resubmitted': None,
91 91 'result_header' : None,
92 92 'result_content' : None,
93 93 'result_buffers' : None,
94 94 'queue' : None,
95 95 'pyin' : None,
96 96 'pyout': None,
97 97 'pyerr': None,
98 98 'stdout': '',
99 99 'stderr': '',
100 100 }
101 101
102 102
103 103 class EngineConnector(HasTraits):
104 104 """A simple object for accessing the various zmq connections of an object.
105 105 Attributes are:
106 106 id (int): engine ID
107 107 uuid (str): uuid (unused?)
108 108 queue (str): identity of queue's XREQ socket
109 109 registration (str): identity of registration XREQ socket
110 110 heartbeat (str): identity of heartbeat XREQ socket
111 111 """
112 112 id=Int(0)
113 113 queue=CBytes()
114 114 control=CBytes()
115 115 registration=CBytes()
116 116 heartbeat=CBytes()
117 117 pending=Set()
118 118
119 119 class HubFactory(RegistrationFactory):
120 120 """The Configurable for setting up a Hub."""
121 121
122 122 # port-pairs for monitoredqueues:
123 123 hb = Tuple(Int,Int,config=True,
124 124 help="""XREQ/SUB Port pair for Engine heartbeats""")
125 125 def _hb_default(self):
126 126 return tuple(util.select_random_ports(2))
127 127
128 128 mux = Tuple(Int,Int,config=True,
129 129 help="""Engine/Client Port pair for MUX queue""")
130 130
131 131 def _mux_default(self):
132 132 return tuple(util.select_random_ports(2))
133 133
134 134 task = Tuple(Int,Int,config=True,
135 135 help="""Engine/Client Port pair for Task queue""")
136 136 def _task_default(self):
137 137 return tuple(util.select_random_ports(2))
138 138
139 139 control = Tuple(Int,Int,config=True,
140 140 help="""Engine/Client Port pair for Control queue""")
141 141
142 142 def _control_default(self):
143 143 return tuple(util.select_random_ports(2))
144 144
145 145 iopub = Tuple(Int,Int,config=True,
146 146 help="""Engine/Client Port pair for IOPub relay""")
147 147
148 148 def _iopub_default(self):
149 149 return tuple(util.select_random_ports(2))
150 150
151 151 # single ports:
152 152 mon_port = Int(config=True,
153 153 help="""Monitor (SUB) port for queue traffic""")
154 154
155 155 def _mon_port_default(self):
156 156 return util.select_random_ports(1)[0]
157 157
158 158 notifier_port = Int(config=True,
159 159 help="""PUB port for sending engine status notifications""")
160 160
161 161 def _notifier_port_default(self):
162 162 return util.select_random_ports(1)[0]
163 163
164 164 engine_ip = Unicode('127.0.0.1', config=True,
165 165 help="IP on which to listen for engine connections. [default: loopback]")
166 166 engine_transport = Unicode('tcp', config=True,
167 167 help="0MQ transport for engine connections. [default: tcp]")
168 168
169 169 client_ip = Unicode('127.0.0.1', config=True,
170 170 help="IP on which to listen for client connections. [default: loopback]")
171 171 client_transport = Unicode('tcp', config=True,
172 172 help="0MQ transport for client connections. [default : tcp]")
173 173
174 174 monitor_ip = Unicode('127.0.0.1', config=True,
175 175 help="IP on which to listen for monitor messages. [default: loopback]")
176 176 monitor_transport = Unicode('tcp', config=True,
177 177 help="0MQ transport for monitor messages. [default : tcp]")
178 178
179 179 monitor_url = Unicode('')
180 180
181 181 db_class = DottedObjectName('IPython.parallel.controller.dictdb.DictDB',
182 182 config=True, help="""The class to use for the DB backend""")
183 183
184 184 # not configurable
185 185 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
186 186 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
187 187
188 188 def _ip_changed(self, name, old, new):
189 189 self.engine_ip = new
190 190 self.client_ip = new
191 191 self.monitor_ip = new
192 192 self._update_monitor_url()
193 193
194 194 def _update_monitor_url(self):
195 195 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
196 196
197 197 def _transport_changed(self, name, old, new):
198 198 self.engine_transport = new
199 199 self.client_transport = new
200 200 self.monitor_transport = new
201 201 self._update_monitor_url()
202 202
203 203 def __init__(self, **kwargs):
204 204 super(HubFactory, self).__init__(**kwargs)
205 205 self._update_monitor_url()
206 206
207 207
208 208 def construct(self):
209 209 self.init_hub()
210 210
211 211 def start(self):
212 212 self.heartmonitor.start()
213 213 self.log.info("Heartmonitor started")
214 214
215 215 def init_hub(self):
216 216 """construct"""
217 217 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
218 218 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
219 219
220 220 ctx = self.context
221 221 loop = self.loop
222 222
223 223 # Registrar socket
224 224 q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
225 225 q.bind(client_iface % self.regport)
226 226 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
227 227 if self.client_ip != self.engine_ip:
228 228 q.bind(engine_iface % self.regport)
229 229 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
230 230
231 231 ### Engine connections ###
232 232
233 233 # heartbeat
234 234 hpub = ctx.socket(zmq.PUB)
235 235 hpub.bind(engine_iface % self.hb[0])
236 236 hrep = ctx.socket(zmq.ROUTER)
237 237 hrep.bind(engine_iface % self.hb[1])
238 238 self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
239 239 pingstream=ZMQStream(hpub,loop),
240 240 pongstream=ZMQStream(hrep,loop)
241 241 )
242 242
243 243 ### Client connections ###
244 244 # Notifier socket
245 245 n = ZMQStream(ctx.socket(zmq.PUB), loop)
246 246 n.bind(client_iface%self.notifier_port)
247 247
248 248 ### build and launch the queues ###
249 249
250 250 # monitor socket
251 251 sub = ctx.socket(zmq.SUB)
252 252 sub.setsockopt(zmq.SUBSCRIBE, b"")
253 253 sub.bind(self.monitor_url)
254 254 sub.bind('inproc://monitor')
255 255 sub = ZMQStream(sub, loop)
256 256
257 257 # connect the db
258 258 self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
259 259 # cdir = self.config.Global.cluster_dir
260 260 self.db = import_item(str(self.db_class))(session=self.session.session,
261 261 config=self.config, log=self.log)
262 262 time.sleep(.25)
263 263 try:
264 264 scheme = self.config.TaskScheduler.scheme_name
265 265 except AttributeError:
266 266 from .scheduler import TaskScheduler
267 267 scheme = TaskScheduler.scheme_name.get_default_value()
268 268 # build connection dicts
269 269 self.engine_info = {
270 270 'control' : engine_iface%self.control[1],
271 271 'mux': engine_iface%self.mux[1],
272 272 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
273 273 'task' : engine_iface%self.task[1],
274 274 'iopub' : engine_iface%self.iopub[1],
275 275 # 'monitor' : engine_iface%self.mon_port,
276 276 }
277 277
278 278 self.client_info = {
279 279 'control' : client_iface%self.control[0],
280 280 'mux': client_iface%self.mux[0],
281 281 'task' : (scheme, client_iface%self.task[0]),
282 282 'iopub' : client_iface%self.iopub[0],
283 283 'notification': client_iface%self.notifier_port
284 284 }
285 285 self.log.debug("Hub engine addrs: %s"%self.engine_info)
286 286 self.log.debug("Hub client addrs: %s"%self.client_info)
287 287
288 288 # resubmit stream
289 289 r = ZMQStream(ctx.socket(zmq.DEALER), loop)
290 290 url = util.disambiguate_url(self.client_info['task'][-1])
291 r.setsockopt(zmq.IDENTITY, util.asbytes(self.session.session))
291 r.setsockopt(zmq.IDENTITY, self.session.bsession)
292 292 r.connect(url)
293 293
294 294 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
295 295 query=q, notifier=n, resubmit=r, db=self.db,
296 296 engine_info=self.engine_info, client_info=self.client_info,
297 297 log=self.log)
298 298
299 299
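For orientation, the two connection dicts built in init_hub come out looking roughly like the sketch below; the ports are invented here, since by default each pair is chosen by util.select_random_ports:

    engine_info = {
        'control'  : 'tcp://127.0.0.1:10101',
        'mux'      : 'tcp://127.0.0.1:10102',
        'heartbeat': ('tcp://127.0.0.1:10103',   # ping (PUB) ...
                      'tcp://127.0.0.1:10104'),  # ... and pong (ROUTER)
        'task'     : 'tcp://127.0.0.1:10105',
        'iopub'    : 'tcp://127.0.0.1:10106',
    }
    client_info = {
        'control'     : 'tcp://127.0.0.1:10107',
        'mux'         : 'tcp://127.0.0.1:10108',
        'task'        : ('leastload', 'tcp://127.0.0.1:10109'),  # (scheme, addr)
        'iopub'       : 'tcp://127.0.0.1:10110',
        'notification': 'tcp://127.0.0.1:10111',
    }
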
300 300 class Hub(SessionFactory):
301 301 """The IPython Controller Hub with 0MQ connections
302 302
303 303 Parameters
304 304 ==========
305 305 loop: zmq IOLoop instance
306 306 session: Session object
307 307 <removed> context: zmq context for creating new connections (?)
308 308 queue: ZMQStream for monitoring the command queue (SUB)
309 309 query: ZMQStream for engine registration and client queries requests (XREP)
310 310 heartbeat: HeartMonitor object checking the pulse of the engines
311 311 notifier: ZMQStream for broadcasting engine registration changes (PUB)
312 312 db: connection to db for out of memory logging of commands
313 313 NotImplemented
314 314 engine_info: dict of zmq connection information for engines to connect
315 315 to the queues.
316 316 client_info: dict of zmq connection information for clients to connect
317 317 to the queues.
318 318 """
319 319 # internal data structures:
320 320 ids=Set() # engine IDs
321 321 keytable=Dict()
322 322 by_ident=Dict()
323 323 engines=Dict()
324 324 clients=Dict()
325 325 hearts=Dict()
326 326 pending=Set()
327 327 queues=Dict() # pending msg_ids keyed by engine_id
328 328 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
329 329 completed=Dict() # completed msg_ids keyed by engine_id
330 330 all_completed=Set() # set of all completed msg_ids
331 331 dead_engines=Set() # set of uuids of engines that have unregistered or died
332 332 unassigned=Set() # set of task msg_ids not yet assigned a destination
333 333 incoming_registrations=Dict()
334 334 registration_timeout=Int()
335 335 _idcounter=Int(0)
336 336
337 337 # objects from constructor:
338 338 query=Instance(ZMQStream)
339 339 monitor=Instance(ZMQStream)
340 340 notifier=Instance(ZMQStream)
341 341 resubmit=Instance(ZMQStream)
342 342 heartmonitor=Instance(HeartMonitor)
343 343 db=Instance(object)
344 344 client_info=Dict()
345 345 engine_info=Dict()
346 346
347 347
348 348 def __init__(self, **kwargs):
349 349 """
350 350 # universal:
351 351 loop: IOLoop for creating future connections
352 352 session: streamsession for sending serialized data
353 353 # engine:
354 354 queue: ZMQStream for monitoring queue messages
355 355 query: ZMQStream for engine+client registration and client requests
356 356 heartbeat: HeartMonitor object for tracking engines
357 357 # extra:
358 358 db: ZMQStream for db connection (NotImplemented)
359 359 engine_info: zmq address/protocol dict for engine connections
360 360 client_info: zmq address/protocol dict for client connections
361 361 """
362 362
363 363 super(Hub, self).__init__(**kwargs)
364 364 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
365 365
366 366 # validate connection dicts:
367 367 for k,v in self.client_info.iteritems():
368 368 if k == 'task':
369 369 util.validate_url_container(v[1])
370 370 else:
371 371 util.validate_url_container(v)
372 372 # util.validate_url_container(self.client_info)
373 373 util.validate_url_container(self.engine_info)
374 374
375 375 # register our callbacks
376 376 self.query.on_recv(self.dispatch_query)
377 377 self.monitor.on_recv(self.dispatch_monitor_traffic)
378 378
379 379 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
380 380 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
381 381
382 382 self.monitor_handlers = {b'in' : self.save_queue_request,
383 383 b'out': self.save_queue_result,
384 384 b'intask': self.save_task_request,
385 385 b'outtask': self.save_task_result,
386 386 b'tracktask': self.save_task_destination,
387 387 b'incontrol': _passer,
388 388 b'outcontrol': _passer,
389 389 b'iopub': self.save_iopub_message,
390 390 }
391 391
392 392 self.query_handlers = {'queue_request': self.queue_status,
393 393 'result_request': self.get_results,
394 394 'history_request': self.get_history,
395 395 'db_request': self.db_query,
396 396 'purge_request': self.purge_results,
397 397 'load_request': self.check_load,
398 398 'resubmit_request': self.resubmit_task,
399 399 'shutdown_request': self.shutdown_request,
400 400 'registration_request' : self.register_engine,
401 401 'unregistration_request' : self.unregister_engine,
402 402 'connection_request': self.connection_request,
403 403 }
404 404
405 405 # ignore resubmit replies
406 406 self.resubmit.on_recv(lambda msg: None, copy=False)
407 407
408 408 self.log.info("hub::created hub")
409 409
410 410 @property
411 411 def _next_id(self):
412 412 """gemerate a new ID.
413 413
414 414 No longer reuse old ids, just count from 0."""
415 415 newid = self._idcounter
416 416 self._idcounter += 1
417 417 return newid
418 418 # newid = 0
419 419 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
420 420 # # print newid, self.ids, self.incoming_registrations
421 421 # while newid in self.ids or newid in incoming:
422 422 # newid += 1
423 423 # return newid
424 424
425 425 #-----------------------------------------------------------------------------
426 426 # message validation
427 427 #-----------------------------------------------------------------------------
428 428
429 429 def _validate_targets(self, targets):
430 430 """turn any valid targets argument into a list of integer ids"""
431 431 if targets is None:
432 432 # default to all
433 433 targets = self.ids
434 434
435 435 if isinstance(targets, (int,str,unicode)):
436 436 # only one target specified
437 437 targets = [targets]
438 438 _targets = []
439 439 for t in targets:
440 440 # map raw identities to ids
441 441 if isinstance(t, (str,unicode)):
442 442 t = self.by_ident.get(t, t)
443 443 _targets.append(t)
444 444 targets = _targets
445 445 bad_targets = [ t for t in targets if t not in self.ids ]
446 446 if bad_targets:
447 447 raise IndexError("No Such Engine: %r"%bad_targets)
448 448 if not targets:
449 449 raise IndexError("No Engines Registered")
450 450 return targets
451 451
452 452 #-----------------------------------------------------------------------------
453 453 # dispatch methods (1 per stream)
454 454 #-----------------------------------------------------------------------------
455 455
456 456
457 457 def dispatch_monitor_traffic(self, msg):
458 458 """all ME and Task queue messages come through here, as well as
459 459 IOPub traffic."""
460 460 self.log.debug("monitor traffic: %r"%msg[:2])
461 461 switch = msg[0]
462 462 try:
463 463 idents, msg = self.session.feed_identities(msg[1:])
464 464 except ValueError:
465 465 idents=[]
466 466 if not idents:
467 467 self.log.error("Bad Monitor Message: %r"%msg)
468 468 return
469 469 handler = self.monitor_handlers.get(switch, None)
470 470 if handler is not None:
471 471 handler(idents, msg)
472 472 else:
473 473 self.log.error("Invalid monitor topic: %r"%switch)
474 474
475 475
476 476 def dispatch_query(self, msg):
477 477 """Route registration requests and queries from clients."""
478 478 try:
479 479 idents, msg = self.session.feed_identities(msg)
480 480 except ValueError:
481 481 idents = []
482 482 if not idents:
483 483 self.log.error("Bad Query Message: %r"%msg)
484 484 return
485 485 client_id = idents[0]
486 486 try:
487 487 msg = self.session.unserialize(msg, content=True)
488 488 except Exception:
489 489 content = error.wrap_exception()
490 490 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
491 491 self.session.send(self.query, "hub_error", ident=client_id,
492 492 content=content)
493 493 return
494 494 # print client_id, header, parent, content
495 495 #switch on message type:
496 496 msg_type = msg['header']['msg_type']
497 497 self.log.info("client::client %r requested %r"%(client_id, msg_type))
498 498 handler = self.query_handlers.get(msg_type, None)
499 499 try:
500 500 assert handler is not None, "Bad Message Type: %r"%msg_type
501 501 except:
502 502 content = error.wrap_exception()
503 503 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
504 504 self.session.send(self.query, "hub_error", ident=client_id,
505 505 content=content)
506 506 return
507 507
508 508 else:
509 509 handler(idents, msg)
510 510
511 511 def dispatch_db(self, msg):
512 512 """"""
513 513 raise NotImplementedError
514 514
515 515 #---------------------------------------------------------------------------
516 516 # handler methods (1 per event)
517 517 #---------------------------------------------------------------------------
518 518
519 519 #----------------------- Heartbeat --------------------------------------
520 520
521 521 def handle_new_heart(self, heart):
522 522 """handler to attach to heartbeater.
523 523 Called when a new heart starts to beat.
524 524 Triggers completion of registration."""
525 525 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
526 526 if heart not in self.incoming_registrations:
527 527 self.log.info("heartbeat::ignoring new heart: %r"%heart)
528 528 else:
529 529 self.finish_registration(heart)
530 530
531 531
532 532 def handle_heart_failure(self, heart):
533 533 """handler to attach to heartbeater.
534 534 called when a previously registered heart fails to respond to beat request.
535 535 triggers unregistration"""
536 536 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
537 537 eid = self.hearts.get(heart, None)
538 538 if eid is None:
539 539 self.log.info("heartbeat::ignoring heart failure %r"%heart)
540 540 else:
541 541 queue = self.engines[eid].queue
542 542 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
543 543
544 544 #----------------------- MUX Queue Traffic ------------------------------
545 545
546 546 def save_queue_request(self, idents, msg):
547 547 if len(idents) < 2:
548 548 self.log.error("invalid identity prefix: %r"%idents)
549 549 return
550 550 queue_id, client_id = idents[:2]
551 551 try:
552 552 msg = self.session.unserialize(msg)
553 553 except Exception:
554 554 self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
555 555 return
556 556
557 557 eid = self.by_ident.get(queue_id, None)
558 558 if eid is None:
559 559 self.log.error("queue::target %r not registered"%queue_id)
560 560 self.log.debug("queue:: valid are: %r"%(self.by_ident.keys()))
561 561 return
562 562 record = init_record(msg)
563 563 msg_id = record['msg_id']
564 564 # Unicode in records
565 565 record['engine_uuid'] = queue_id.decode('ascii')
566 566 record['client_uuid'] = client_id.decode('ascii')
567 567 record['queue'] = 'mux'
568 568
569 569 try:
570 570 # it's possible iopub arrived first:
571 571 existing = self.db.get_record(msg_id)
572 572 for key,evalue in existing.iteritems():
573 573 rvalue = record.get(key, None)
574 574 if evalue and rvalue and evalue != rvalue:
575 575 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
576 576 elif evalue and not rvalue:
577 577 record[key] = evalue
578 578 try:
579 579 self.db.update_record(msg_id, record)
580 580 except Exception:
581 581 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
582 582 except KeyError:
583 583 try:
584 584 self.db.add_record(msg_id, record)
585 585 except Exception:
586 586 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
587 587
588 588
589 589 self.pending.add(msg_id)
590 590 self.queues[eid].append(msg_id)
591 591
592 592 def save_queue_result(self, idents, msg):
593 593 if len(idents) < 2:
594 594 self.log.error("invalid identity prefix: %r"%idents)
595 595 return
596 596
597 597 client_id, queue_id = idents[:2]
598 598 try:
599 599 msg = self.session.unserialize(msg)
600 600 except Exception:
601 601 self.log.error("queue::engine %r sent invalid message to %r: %r"%(
602 602 queue_id,client_id, msg), exc_info=True)
603 603 return
604 604
605 605 eid = self.by_ident.get(queue_id, None)
606 606 if eid is None:
607 607 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
608 608 return
609 609
610 610 parent = msg['parent_header']
611 611 if not parent:
612 612 return
613 613 msg_id = parent['msg_id']
614 614 if msg_id in self.pending:
615 615 self.pending.remove(msg_id)
616 616 self.all_completed.add(msg_id)
617 617 self.queues[eid].remove(msg_id)
618 618 self.completed[eid].append(msg_id)
619 619 elif msg_id not in self.all_completed:
620 620 # it could be a result from a dead engine that died before delivering the
621 621 # result
622 622 self.log.warn("queue:: unknown msg finished %r"%msg_id)
623 623 return
624 624 # update record anyway, because the unregistration could have been premature
625 625 rheader = msg['header']
626 626 completed = rheader['date']
627 627 started = rheader.get('started', None)
628 628 result = {
629 629 'result_header' : rheader,
630 630 'result_content': msg['content'],
631 631 'started' : started,
632 632 'completed' : completed
633 633 }
634 634
635 635 result['result_buffers'] = msg['buffers']
636 636 try:
637 637 self.db.update_record(msg_id, result)
638 638 except Exception:
639 639 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
640 640
641 641
642 642 #--------------------- Task Queue Traffic ------------------------------
643 643
644 644 def save_task_request(self, idents, msg):
645 645 """Save the submission of a task."""
646 646 client_id = idents[0]
647 647
648 648 try:
649 649 msg = self.session.unserialize(msg)
650 650 except Exception:
651 651 self.log.error("task::client %r sent invalid task message: %r"%(
652 652 client_id, msg), exc_info=True)
653 653 return
654 654 record = init_record(msg)
655 655
656 656 record['client_uuid'] = client_id.decode('ascii') # Unicode in records
657 657 record['queue'] = 'task'
658 658 header = msg['header']
659 659 msg_id = header['msg_id']
660 660 self.pending.add(msg_id)
661 661 self.unassigned.add(msg_id)
662 662 try:
663 663 # it's possible iopub arrived first:
664 664 existing = self.db.get_record(msg_id)
665 665 if existing['resubmitted']:
666 666 for key in ('submitted', 'client_uuid', 'buffers'):
667 667 # don't clobber these keys on resubmit
668 668 # submitted and client_uuid should be different
669 669 # and buffers might be big, and shouldn't have changed
670 670 record.pop(key)
671 671 # still check content and header, which should not change,
672 672 # and are not as expensive to compare as buffers
673 673
674 674 for key,evalue in existing.iteritems():
675 675 if key.endswith('buffers'):
676 676 # don't compare buffers
677 677 continue
678 678 rvalue = record.get(key, None)
679 679 if evalue and rvalue and evalue != rvalue:
680 680 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
681 681 elif evalue and not rvalue:
682 682 record[key] = evalue
683 683 try:
684 684 self.db.update_record(msg_id, record)
685 685 except Exception:
686 686 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
687 687 except KeyError:
688 688 try:
689 689 self.db.add_record(msg_id, record)
690 690 except Exception:
691 691 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
692 692 except Exception:
693 693 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
694 694
695 695 def save_task_result(self, idents, msg):
696 696 """save the result of a completed task."""
697 697 client_id = idents[0]
698 698 try:
699 699 msg = self.session.unserialize(msg)
700 700 except Exception:
701 701 self.log.error("task::invalid task result message send to %r: %r"%(
702 702 client_id, msg), exc_info=True)
703 703 return
704 704
705 705 parent = msg['parent_header']
706 706 if not parent:
707 707 # print msg
708 708 self.log.warn("Task %r had no parent!"%msg)
709 709 return
710 710 msg_id = parent['msg_id']
711 711 if msg_id in self.unassigned:
712 712 self.unassigned.remove(msg_id)
713 713
714 714 header = msg['header']
715 715 engine_uuid = header.get('engine', None)
716 716 eid = self.by_ident.get(engine_uuid, None)
717 717
718 718 if msg_id in self.pending:
719 719 self.pending.remove(msg_id)
720 720 self.all_completed.add(msg_id)
721 721 if eid is not None:
722 722 self.completed[eid].append(msg_id)
723 723 if msg_id in self.tasks[eid]:
724 724 self.tasks[eid].remove(msg_id)
725 725 completed = header['date']
726 726 started = header.get('started', None)
727 727 result = {
728 728 'result_header' : header,
729 729 'result_content': msg['content'],
730 730 'started' : started,
731 731 'completed' : completed,
732 732 'engine_uuid': engine_uuid
733 733 }
734 734
735 735 result['result_buffers'] = msg['buffers']
736 736 try:
737 737 self.db.update_record(msg_id, result)
738 738 except Exception:
739 739 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
740 740
741 741 else:
742 742 self.log.debug("task::unknown task %r finished"%msg_id)
743 743
744 744 def save_task_destination(self, idents, msg):
745 745 try:
746 746 msg = self.session.unserialize(msg, content=True)
747 747 except Exception:
748 748 self.log.error("task::invalid task tracking message", exc_info=True)
749 749 return
750 750 content = msg['content']
751 751 # print (content)
752 752 msg_id = content['msg_id']
753 753 engine_uuid = content['engine_id']
754 754 eid = self.by_ident[util.asbytes(engine_uuid)]
755 755
756 756 self.log.info("task::task %r arrived on %r"%(msg_id, eid))
757 757 if msg_id in self.unassigned:
758 758 self.unassigned.remove(msg_id)
759 759 # else:
760 760 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
761 761
762 762 self.tasks[eid].append(msg_id)
763 763 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
764 764 try:
765 765 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
766 766 except Exception:
767 767 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
768 768
769 769
770 770 def mia_task_request(self, idents, msg):
771 771 raise NotImplementedError
772 772 client_id = idents[0]
773 773 # content = dict(mia=self.mia,status='ok')
774 774 # self.session.send('mia_reply', content=content, idents=client_id)
775 775
776 776
777 777 #--------------------- IOPub Traffic ------------------------------
778 778
779 779 def save_iopub_message(self, topics, msg):
780 780 """save an iopub message into the db"""
781 781 # print (topics)
782 782 try:
783 783 msg = self.session.unserialize(msg, content=True)
784 784 except Exception:
785 785 self.log.error("iopub::invalid IOPub message", exc_info=True)
786 786 return
787 787
788 788 parent = msg['parent_header']
789 789 if not parent:
790 790 self.log.error("iopub::invalid IOPub message: %r"%msg)
791 791 return
792 792 msg_id = parent['msg_id']
793 793 msg_type = msg['header']['msg_type']
794 794 content = msg['content']
795 795
796 796 # ensure msg_id is in db
797 797 try:
798 798 rec = self.db.get_record(msg_id)
799 799 except KeyError:
800 800 rec = empty_record()
801 801 rec['msg_id'] = msg_id
802 802 self.db.add_record(msg_id, rec)
803 803 # stream
804 804 d = {}
805 805 if msg_type == 'stream':
806 806 name = content['name']
807 807 s = rec[name] or ''
808 808 d[name] = s + content['data']
809 809
810 810 elif msg_type == 'pyerr':
811 811 d['pyerr'] = content
812 812 elif msg_type == 'pyin':
813 813 d['pyin'] = content['code']
814 814 else:
815 815 d[msg_type] = content.get('data', '')
816 816
817 817 try:
818 818 self.db.update_record(msg_id, d)
819 819 except Exception:
820 820 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
821 821
822 822
823 823
824 824 #-------------------------------------------------------------------------
825 825 # Registration requests
826 826 #-------------------------------------------------------------------------
827 827
828 828 def connection_request(self, client_id, msg):
829 829 """Reply with connection addresses for clients."""
830 830 self.log.info("client::client %r connected"%client_id)
831 831 content = dict(status='ok')
832 832 content.update(self.client_info)
833 833 jsonable = {}
834 834 for k,v in self.keytable.iteritems():
835 835 if v not in self.dead_engines:
836 836 jsonable[str(k)] = v.decode('ascii')
837 837 content['engines'] = jsonable
838 838 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
839 839
840 840 def register_engine(self, reg, msg):
841 841 """Register a new engine."""
842 842 content = msg['content']
843 843 try:
844 844 queue = util.asbytes(content['queue'])
845 845 except KeyError:
846 846 self.log.error("registration::queue not specified", exc_info=True)
847 847 return
848 848 heart = content.get('heartbeat', None)
849 849 if heart:
850 850 heart = util.asbytes(heart)
851 851 """register a new engine, and create the socket(s) necessary"""
852 852 eid = self._next_id
853 853 # print (eid, queue, reg, heart)
854 854
855 855 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
856 856
857 857 content = dict(id=eid,status='ok')
858 858 content.update(self.engine_info)
859 859 # check for identity collisions with registered or pending engines:
860 860 if queue in self.by_ident:
861 861 try:
862 862 raise KeyError("queue_id %r in use"%queue)
863 863 except:
864 864 content = error.wrap_exception()
865 865 self.log.error("queue_id %r in use"%queue, exc_info=True)
866 866 elif heart in self.hearts: # need to check unique hearts?
867 867 try:
868 868 raise KeyError("heart_id %r in use"%heart)
869 869 except:
870 870 self.log.error("heart_id %r in use"%heart, exc_info=True)
871 871 content = error.wrap_exception()
872 872 else:
873 873 for h, pack in self.incoming_registrations.iteritems():
874 874 if heart == h:
875 875 try:
876 876 raise KeyError("heart_id %r in use"%heart)
877 877 except:
878 878 self.log.error("heart_id %r in use"%heart, exc_info=True)
879 879 content = error.wrap_exception()
880 880 break
881 881 elif queue == pack[1]:
882 882 try:
883 883 raise KeyError("queue_id %r in use"%queue)
884 884 except:
885 885 self.log.error("queue_id %r in use"%queue, exc_info=True)
886 886 content = error.wrap_exception()
887 887 break
888 888
889 889 msg = self.session.send(self.query, "registration_reply",
890 890 content=content,
891 891 ident=reg)
892 892
893 893 if content['status'] == 'ok':
894 894 if heart in self.heartmonitor.hearts:
895 895 # already beating
896 896 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
897 897 self.finish_registration(heart)
898 898 else:
899 899 purge = lambda : self._purge_stalled_registration(heart)
900 900 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
901 901 dc.start()
902 902 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
903 903 else:
904 904 self.log.error("registration::registration %i failed: %r"%(eid, content['evalue']))
905 905 return eid
906 906
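To make the handshake concrete, the content dicts exchanged look roughly like this (the identities and id below are invented for illustration):

    registration_request_content = {
        'queue'    : 'engine-queue-uuid',   # required: identity of the engine's queue socket
        'heartbeat': 'engine-heart-uuid',   # optional: identity of its heartbeat socket
    }

    registration_reply_content = {
        'id'    : 0,      # eid assigned via Hub._next_id
        'status': 'ok',   # or an error dict from error.wrap_exception()
        # ...plus every key of self.engine_info (control, mux, heartbeat, task, iopub)
    }
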
907 907 def unregister_engine(self, ident, msg):
908 908 """Unregister an engine that explicitly requested to leave."""
909 909 try:
910 910 eid = msg['content']['id']
911 911 except:
912 912 self.log.error("registration::bad engine id for unregistration: %r"%ident, exc_info=True)
913 913 return
914 914 self.log.info("registration::unregister_engine(%r)"%eid)
915 915 # print (eid)
916 916 uuid = self.keytable[eid]
917 917 content=dict(id=eid, queue=uuid.decode('ascii'))
918 918 self.dead_engines.add(uuid)
919 919 # self.ids.remove(eid)
920 920 # uuid = self.keytable.pop(eid)
921 921 #
922 922 # ec = self.engines.pop(eid)
923 923 # self.hearts.pop(ec.heartbeat)
924 924 # self.by_ident.pop(ec.queue)
925 925 # self.completed.pop(eid)
926 926 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
927 927 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
928 928 dc.start()
929 929 ############## TODO: HANDLE IT ################
930 930
931 931 if self.notifier:
932 932 self.session.send(self.notifier, "unregistration_notification", content=content)
933 933
934 934 def _handle_stranded_msgs(self, eid, uuid):
935 935 """Handle messages known to be on an engine when the engine unregisters.
936 936
937 937 It is possible that this will fire prematurely - that is, an engine will
938 938 go down after completing a result, and the client will be notified
939 939 that the result failed and later receive the actual result.
940 940 """
941 941
942 942 outstanding = self.queues[eid]
943 943
944 944 for msg_id in outstanding:
945 945 self.pending.remove(msg_id)
946 946 self.all_completed.add(msg_id)
947 947 try:
948 948 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
949 949 except:
950 950 content = error.wrap_exception()
951 951 # build a fake header:
952 952 header = {}
953 953 header['engine'] = uuid
954 954 header['date'] = datetime.now()
955 955 rec = dict(result_content=content, result_header=header, result_buffers=[])
956 956 rec['completed'] = header['date']
957 957 rec['engine_uuid'] = uuid
958 958 try:
959 959 self.db.update_record(msg_id, rec)
960 960 except Exception:
961 961 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
962 962
963 963
964 964 def finish_registration(self, heart):
965 965 """Second half of engine registration, called after our HeartMonitor
966 966 has received a beat from the Engine's Heart."""
967 967 try:
968 968 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
969 969 except KeyError:
970 970 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
971 971 return
972 972 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
973 973 if purge is not None:
974 974 purge.stop()
975 975 control = queue
976 976 self.ids.add(eid)
977 977 self.keytable[eid] = queue
978 978 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
979 979 control=control, heartbeat=heart)
980 980 self.by_ident[queue] = eid
981 981 self.queues[eid] = list()
982 982 self.tasks[eid] = list()
983 983 self.completed[eid] = list()
984 984 self.hearts[heart] = eid
985 985 content = dict(id=eid, queue=self.engines[eid].queue.decode('ascii'))
986 986 if self.notifier:
987 987 self.session.send(self.notifier, "registration_notification", content=content)
988 988 self.log.info("engine::Engine Connected: %i"%eid)
989 989
990 990 def _purge_stalled_registration(self, heart):
991 991 if heart in self.incoming_registrations:
992 992 eid = self.incoming_registrations.pop(heart)[0]
993 993 self.log.info("registration::purging stalled registration: %i"%eid)
994 994 else:
995 995 pass
996 996
997 997 #-------------------------------------------------------------------------
998 998 # Client Requests
999 999 #-------------------------------------------------------------------------
1000 1000
1001 1001 def shutdown_request(self, client_id, msg):
1002 1002 """handle shutdown request."""
1003 1003 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1004 1004 # also notify other clients of shutdown
1005 1005 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1006 1006 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1007 1007 dc.start()
1008 1008
1009 1009 def _shutdown(self):
1010 1010 self.log.info("hub::hub shutting down.")
1011 1011 time.sleep(0.1)
1012 1012 sys.exit(0)
1013 1013
1014 1014
1015 1015 def check_load(self, client_id, msg):
1016 1016 content = msg['content']
1017 1017 try:
1018 1018 targets = content['targets']
1019 1019 targets = self._validate_targets(targets)
1020 1020 except:
1021 1021 content = error.wrap_exception()
1022 1022 self.session.send(self.query, "hub_error",
1023 1023 content=content, ident=client_id)
1024 1024 return
1025 1025
1026 1026 content = dict(status='ok')
1027 1027 # loads = {}
1028 1028 for t in targets:
1029 1029 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1030 1030 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1031 1031
1032 1032
1033 1033 def queue_status(self, client_id, msg):
1034 1034 """Return the Queue status of one or more targets.
1035 1035 if verbose: return the msg_ids
1036 1036 else: return len of each type.
1037 1037 keys: queue (pending MUX jobs)
1038 1038 tasks (pending Task jobs)
1039 1039 completed (finished jobs from both queues)"""
1040 1040 content = msg['content']
1041 1041 targets = content['targets']
1042 1042 try:
1043 1043 targets = self._validate_targets(targets)
1044 1044 except:
1045 1045 content = error.wrap_exception()
1046 1046 self.session.send(self.query, "hub_error",
1047 1047 content=content, ident=client_id)
1048 1048 return
1049 1049 verbose = content.get('verbose', False)
1050 1050 content = dict(status='ok')
1051 1051 for t in targets:
1052 1052 queue = self.queues[t]
1053 1053 completed = self.completed[t]
1054 1054 tasks = self.tasks[t]
1055 1055 if not verbose:
1056 1056 queue = len(queue)
1057 1057 completed = len(completed)
1058 1058 tasks = len(tasks)
1059 1059 content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1060 1060 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1061 1061 # print (content)
1062 1062 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1063 1063
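A rough sketch of a non-verbose queue_reply (engine ids become string keys; the counts here are invented):

    queue_reply_content = {
        'status'    : 'ok',
        '0'         : {'queue': 2, 'completed': 40, 'tasks': 1},
        '1'         : {'queue': 0, 'completed': 38, 'tasks': 3},
        'unassigned': 5,   # or the list of msg_ids when verbose=True
    }
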
1064 1064 def purge_results(self, client_id, msg):
1065 1065 """Purge results from memory. This method is more valuable before we move
1066 1066 to a DB based message storage mechanism."""
1067 1067 content = msg['content']
1068 1068 self.log.info("Dropping records with %s", content)
1069 1069 msg_ids = content.get('msg_ids', [])
1070 1070 reply = dict(status='ok')
1071 1071 if msg_ids == 'all':
1072 1072 try:
1073 1073 self.db.drop_matching_records(dict(completed={'$ne':None}))
1074 1074 except Exception:
1075 1075 reply = error.wrap_exception()
1076 1076 else:
1077 1077 pending = filter(lambda m: m in self.pending, msg_ids)
1078 1078 if pending:
1079 1079 try:
1080 1080 raise IndexError("msg pending: %r"%pending[0])
1081 1081 except:
1082 1082 reply = error.wrap_exception()
1083 1083 else:
1084 1084 try:
1085 1085 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1086 1086 except Exception:
1087 1087 reply = error.wrap_exception()
1088 1088
1089 1089 if reply['status'] == 'ok':
1090 1090 eids = content.get('engine_ids', [])
1091 1091 for eid in eids:
1092 1092 if eid not in self.engines:
1093 1093 try:
1094 1094 raise IndexError("No such engine: %i"%eid)
1095 1095 except:
1096 1096 reply = error.wrap_exception()
1097 1097 break
1098 1098 uid = self.engines[eid].queue
1099 1099 try:
1100 1100 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1101 1101 except Exception:
1102 1102 reply = error.wrap_exception()
1103 1103 break
1104 1104
1105 1105 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1106 1106
1107 1107 def resubmit_task(self, client_id, msg):
1108 1108 """Resubmit one or more tasks."""
1109 1109 def finish(reply):
1110 1110 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1111 1111
1112 1112 content = msg['content']
1113 1113 msg_ids = content['msg_ids']
1114 1114 reply = dict(status='ok')
1115 1115 try:
1116 1116 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1117 1117 'header', 'content', 'buffers'])
1118 1118 except Exception:
1119 1119 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1120 1120 return finish(error.wrap_exception())
1121 1121
1122 1122 # validate msg_ids
1123 1123 found_ids = [ rec['msg_id'] for rec in records ]
1124 1124 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1125 1125 if len(records) > len(msg_ids):
1126 1126 try:
1127 1127 raise RuntimeError("DB appears to be in an inconsistent state."
1128 1128 "More matching records were found than should exist")
1129 1129 except Exception:
1130 1130 return finish(error.wrap_exception())
1131 1131 elif len(records) < len(msg_ids):
1132 1132 missing = [ m for m in msg_ids if m not in found_ids ]
1133 1133 try:
1134 1134 raise KeyError("No such msg(s): %r"%missing)
1135 1135 except KeyError:
1136 1136 return finish(error.wrap_exception())
1137 1137 elif invalid_ids:
1138 1138 msg_id = invalid_ids[0]
1139 1139 try:
1140 1140 raise ValueError("Task %r appears to be inflight"%(msg_id))
1141 1141 except Exception:
1142 1142 return finish(error.wrap_exception())
1143 1143
1144 1144 # clear the existing records
1145 1145 now = datetime.now()
1146 1146 rec = empty_record()
1147 1147 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1148 1148 rec['resubmitted'] = now
1149 1149 rec['queue'] = 'task'
1150 1150 rec['client_uuid'] = client_id[0]
1151 1151 try:
1152 1152 for msg_id in msg_ids:
1153 1153 self.all_completed.discard(msg_id)
1154 1154 self.db.update_record(msg_id, rec)
1155 1155 except Exception:
1156 1156 self.log.error('db::db error updating record', exc_info=True)
1157 1157 reply = error.wrap_exception()
1158 1158 else:
1159 1159 # send the messages
1160 1160 for rec in records:
1161 1161 header = rec['header']
1162 1162 # include resubmitted in header to prevent digest collision
1163 1163 header['resubmitted'] = now
1164 1164 msg = self.session.msg(header['msg_type'])
1165 1165 msg['content'] = rec['content']
1166 1166 msg['header'] = header
1167 1167 msg['header']['msg_id'] = rec['msg_id']
1168 1168 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1169 1169
1170 1170 finish(dict(status='ok'))
1171 1171
1172 1172
1173 1173 def _extract_record(self, rec):
1174 1174 """decompose a TaskRecord dict into subsection of reply for get_result"""
1175 1175 io_dict = {}
1176 1176 for key in 'pyin pyout pyerr stdout stderr'.split():
1177 1177 io_dict[key] = rec[key]
1178 1178 content = { 'result_content': rec['result_content'],
1179 1179 'header': rec['header'],
1180 1180 'result_header' : rec['result_header'],
1181 1181 'io' : io_dict,
1182 1182 }
1183 1183 if rec['result_buffers']:
1184 1184 buffers = map(bytes, rec['result_buffers'])
1185 1185 else:
1186 1186 buffers = []
1187 1187
1188 1188 return content, buffers
1189 1189
1190 1190 def get_results(self, client_id, msg):
1191 1191 """Get the result of 1 or more messages."""
1192 1192 content = msg['content']
1193 1193 msg_ids = sorted(set(content['msg_ids']))
1194 1194 statusonly = content.get('status_only', False)
1195 1195 pending = []
1196 1196 completed = []
1197 1197 content = dict(status='ok')
1198 1198 content['pending'] = pending
1199 1199 content['completed'] = completed
1200 1200 buffers = []
1201 1201 if not statusonly:
1202 1202 try:
1203 1203 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1204 1204 # turn match list into dict, for faster lookup
1205 1205 records = {}
1206 1206 for rec in matches:
1207 1207 records[rec['msg_id']] = rec
1208 1208 except Exception:
1209 1209 content = error.wrap_exception()
1210 1210 self.session.send(self.query, "result_reply", content=content,
1211 1211 parent=msg, ident=client_id)
1212 1212 return
1213 1213 else:
1214 1214 records = {}
1215 1215 for msg_id in msg_ids:
1216 1216 if msg_id in self.pending:
1217 1217 pending.append(msg_id)
1218 1218 elif msg_id in self.all_completed:
1219 1219 completed.append(msg_id)
1220 1220 if not statusonly:
1221 1221 c,bufs = self._extract_record(records[msg_id])
1222 1222 content[msg_id] = c
1223 1223 buffers.extend(bufs)
1224 1224 elif msg_id in records:
1225 1225 if records[msg_id]['completed']:
1226 1226 completed.append(msg_id)
1227 1227 c,bufs = self._extract_record(records[msg_id])
1228 1228 content[msg_id] = c
1229 1229 buffers.extend(bufs)
1230 1230 else:
1231 1231 pending.append(msg_id)
1232 1232 else:
1233 1233 try:
1234 1234 raise KeyError('No such message: '+msg_id)
1235 1235 except:
1236 1236 content = error.wrap_exception()
1237 1237 break
1238 1238 self.session.send(self.query, "result_reply", content=content,
1239 1239 parent=msg, ident=client_id,
1240 1240 buffers=buffers)
1241 1241
1242 1242 def get_history(self, client_id, msg):
1243 1243 """Get a list of all msg_ids in our DB records"""
1244 1244 try:
1245 1245 msg_ids = self.db.get_history()
1246 1246 except Exception as e:
1247 1247 content = error.wrap_exception()
1248 1248 else:
1249 1249 content = dict(status='ok', history=msg_ids)
1250 1250
1251 1251 self.session.send(self.query, "history_reply", content=content,
1252 1252 parent=msg, ident=client_id)
1253 1253
1254 1254 def db_query(self, client_id, msg):
1255 1255 """Perform a raw query on the task record database."""
1256 1256 content = msg['content']
1257 1257 query = content.get('query', {})
1258 1258 keys = content.get('keys', None)
1259 1259 buffers = []
1260 1260 empty = list()
1261 1261 try:
1262 1262 records = self.db.find_records(query, keys)
1263 1263 except Exception as e:
1264 1264 content = error.wrap_exception()
1265 1265 else:
1266 1266 # extract buffers from reply content:
1267 1267 if keys is not None:
1268 1268 buffer_lens = [] if 'buffers' in keys else None
1269 1269 result_buffer_lens = [] if 'result_buffers' in keys else None
1270 1270 else:
1271 1271 buffer_lens = []
1272 1272 result_buffer_lens = []
1273 1273
1274 1274 for rec in records:
1275 1275 # buffers may be None, so double check
1276 1276 if buffer_lens is not None:
1277 1277 b = rec.pop('buffers', empty) or empty
1278 1278 buffer_lens.append(len(b))
1279 1279 buffers.extend(b)
1280 1280 if result_buffer_lens is not None:
1281 1281 rb = rec.pop('result_buffers', empty) or empty
1282 1282 result_buffer_lens.append(len(rb))
1283 1283 buffers.extend(rb)
1284 1284 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1285 1285 result_buffer_lens=result_buffer_lens)
1286 1286 # self.log.debug (content)
1287 1287 self.session.send(self.query, "db_reply", content=content,
1288 1288 parent=msg, ident=client_id,
1289 1289 buffers=buffers)
1290 1290
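A hedged example of the request content db_query expects; the operators follow the Mongo-like syntax used elsewhere in this file ($in, $ne), and keys restricts which record fields come back:

    db_request_content = {
        'query': {'completed': {'$ne': None}},            # only finished tasks
        'keys' : ['msg_id', 'completed', 'engine_uuid'],  # omit for full records
    }
    # reply: dict(status='ok', records=..., buffer_lens=..., result_buffer_lens=...),
    # with any buffers flattened into the message's buffer list for the client to re-slice.
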
@@ -1,714 +1,714 b''
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6
7 7 Authors:
8 8
9 9 * Min RK
10 10 """
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2010-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #----------------------------------------------------------------------
19 19 # Imports
20 20 #----------------------------------------------------------------------
21 21
22 22 from __future__ import print_function
23 23
24 24 import logging
25 25 import sys
26 26
27 27 from datetime import datetime, timedelta
28 28 from random import randint, random
29 29 from types import FunctionType
30 30
31 31 try:
32 32 import numpy
33 33 except ImportError:
34 34 numpy = None
35 35
36 36 import zmq
37 37 from zmq.eventloop import ioloop, zmqstream
38 38
39 39 # local imports
40 40 from IPython.external.decorator import decorator
41 41 from IPython.config.application import Application
42 42 from IPython.config.loader import Config
43 43 from IPython.utils.traitlets import Instance, Dict, List, Set, Int, Enum, CBytes
44 44
45 45 from IPython.parallel import error
46 46 from IPython.parallel.factory import SessionFactory
47 47 from IPython.parallel.util import connect_logger, local_logger, asbytes
48 48
49 49 from .dependency import Dependency
50 50
51 51 @decorator
52 52 def logged(f,self,*args,**kwargs):
53 53 # print ("#--------------------")
54 54 self.log.debug("scheduler::%s(*%s,**%s)", f.func_name, args, kwargs)
55 55 # print ("#--")
56 56 return f(self,*args, **kwargs)
57 57
58 58 #----------------------------------------------------------------------
59 59 # Chooser functions
60 60 #----------------------------------------------------------------------
61 61
62 62 def plainrandom(loads):
63 63 """Plain random pick."""
64 64 n = len(loads)
65 65 return randint(0,n-1)
66 66
67 67 def lru(loads):
68 68 """Always pick the front of the line.
69 69
70 70 The content of `loads` is ignored.
71 71
72 72 Assumes LRU ordering of loads, with oldest first.
73 73 """
74 74 return 0
75 75
76 76 def twobin(loads):
77 77 """Pick two at random, use the LRU of the two.
78 78
79 79 The content of loads is ignored.
80 80
81 81 Assumes LRU ordering of loads, with oldest first.
82 82 """
83 83 n = len(loads)
84 84 a = randint(0,n-1)
85 85 b = randint(0,n-1)
86 86 return min(a,b)
87 87
88 88 def weighted(loads):
89 89 """Pick two at random using inverse load as weight.
90 90
91 91 Return the less loaded of the two.
92 92 """
93 93 # a load of 0 gets ~a million times the weight of a load of 1:
94 94 weights = 1./(1e-6+numpy.array(loads))
95 95 sums = weights.cumsum()
96 96 t = sums[-1]
97 97 x = random()*t
98 98 y = random()*t
99 99 idx = 0
100 100 idy = 0
101 101 while sums[idx] < x:
102 102 idx += 1
103 103 while sums[idy] < y:
104 104 idy += 1
105 105 if weights[idy] > weights[idx]:
106 106 return idy
107 107 else:
108 108 return idx
109 109
110 110 def leastload(loads):
111 111 """Always choose the lowest load.
112 112
113 113 If the lowest load occurs more than once, the first
114 114 occurrence will be used. If loads has LRU ordering, this means
115 115 the LRU of those with the lowest load is chosen.
116 116 """
117 117 return loads.index(min(loads))
118 118
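All the choosers share one contract: take a list of engine loads (LRU-ordered, oldest first) and return the index of the chosen engine. A quick check against the deterministic ones:

    loads = [3, 0, 2, 0]
    assert lru(loads) == 0          # always the front of the line
    assert leastload(loads) == 1    # first occurrence of the minimum load
    idx = twobin(loads)             # random, but always a valid index
    assert 0 <= idx < len(loads)
    # weighted(loads) works the same way but requires numpy to be importable
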
119 119 #---------------------------------------------------------------------
120 120 # Classes
121 121 #---------------------------------------------------------------------
122 122 # store empty default dependency:
123 123 MET = Dependency([])
124 124
125 125 class TaskScheduler(SessionFactory):
126 126 """Python TaskScheduler object.
127 127
128 128 This is the simplest object that supports msg_id based
129 129 DAG dependencies. *Only* task msg_ids are checked, not
130 130 msg_ids of jobs submitted via the MUX queue.
131 131
132 132 """
133 133
134 134 hwm = Int(0, config=True, shortname='hwm',
135 135 help="""specify the High Water Mark (HWM) for the downstream
136 136 socket in the Task scheduler. This is the maximum number
137 137 of allowed outstanding tasks on each engine."""
138 138 )
139 139 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
140 140 'leastload', config=True, shortname='scheme', allow_none=False,
141 141 help="""select the task scheduler scheme [default: Python LRU]
142 142 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
143 143 )
144 144 def _scheme_name_changed(self, old, new):
145 145 self.log.debug("Using scheme %r"%new)
146 146 self.scheme = globals()[new]
147 147
148 148 # input arguments:
149 149 scheme = Instance(FunctionType) # function for determining the destination
150 150 def _scheme_default(self):
151 151 return leastload
152 152 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
153 153 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
154 154 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
155 155 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
156 156
157 157 # internals:
158 158 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
159 159 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
160 160 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
161 161 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
162 162 pending = Dict() # dict by engine_uuid of submitted tasks
163 163 completed = Dict() # dict by engine_uuid of completed tasks
164 164 failed = Dict() # dict by engine_uuid of failed tasks
165 165 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
166 166 clients = Dict() # dict by msg_id for who submitted the task
167 167 targets = List() # list of target IDENTs
168 168 loads = List() # list of engine loads
169 169 # full = Set() # set of IDENTs that have HWM outstanding tasks
170 170 all_completed = Set() # set of all completed tasks
171 171 all_failed = Set() # set of all failed tasks
172 172 all_done = Set() # set of all finished tasks=union(completed,failed)
173 173 all_ids = Set() # set of all submitted task IDs
174 174 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
175 175 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
176 176
177 177 ident = CBytes() # ZMQ identity. This should just be self.session.bsession,
178 178 # the session id as bytes
179 179 def _ident_default(self):
180 return asbytes(self.session.session)
180 return self.session.bsession
181 181
182 182 def start(self):
183 183 self.engine_stream.on_recv(self.dispatch_result, copy=False)
184 184 self._notification_handlers = dict(
185 185 registration_notification = self._register_engine,
186 186 unregistration_notification = self._unregister_engine
187 187 )
188 188 self.notifier_stream.on_recv(self.dispatch_notification)
189 189 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # every 2s
190 190 self.auditor.start()
191 191 self.log.info("Scheduler started [%s]"%self.scheme_name)
192 192
193 193 def resume_receiving(self):
194 194 """Resume accepting jobs."""
195 195 self.client_stream.on_recv(self.dispatch_submission, copy=False)
196 196
197 197 def stop_receiving(self):
198 198 """Stop accepting jobs while there are no engines.
199 199 Leave them in the ZMQ queue."""
200 200 self.client_stream.on_recv(None)
201 201
202 202 #-----------------------------------------------------------------------
203 203 # [Un]Registration Handling
204 204 #-----------------------------------------------------------------------
205 205
206 206 def dispatch_notification(self, msg):
207 207 """dispatch register/unregister events."""
208 208 try:
209 209 idents,msg = self.session.feed_identities(msg)
210 210 except ValueError:
211 211 self.log.warn("task::Invalid Message: %r",msg)
212 212 return
213 213 try:
214 214 msg = self.session.unserialize(msg)
215 215 except ValueError:
216 216 self.log.warn("task::Unauthorized message from: %r"%idents)
217 217 return
218 218
219 219 msg_type = msg['header']['msg_type']
220 220
221 221 handler = self._notification_handlers.get(msg_type, None)
222 222 if handler is None:
223 223 self.log.error("Unhandled message type: %r"%msg_type)
224 224 else:
225 225 try:
226 226 handler(asbytes(msg['content']['queue']))
227 227 except Exception:
228 228 self.log.error("task::Invalid notification msg: %r",msg)
229 229
230 230 def _register_engine(self, uid):
231 231 """New engine with ident `uid` became available."""
232 232 # head of the line:
233 233 self.targets.insert(0,uid)
234 234 self.loads.insert(0,0)
235 235
236 236 # initialize sets
237 237 self.completed[uid] = set()
238 238 self.failed[uid] = set()
239 239 self.pending[uid] = {}
240 240 if len(self.targets) == 1:
241 241 self.resume_receiving()
242 242 # rescan the graph:
243 243 self.update_graph(None)
244 244
245 245 def _unregister_engine(self, uid):
246 246 """Existing engine with ident `uid` became unavailable."""
247 247 if len(self.targets) == 1:
248 248 # this was our only engine
249 249 self.stop_receiving()
250 250
251 251 # handle any potentially finished tasks:
252 252 self.engine_stream.flush()
253 253
254 254 # don't pop destinations, because they might be used later
255 255 # map(self.destinations.pop, self.completed.pop(uid))
256 256 # map(self.destinations.pop, self.failed.pop(uid))
257 257
258 258 # prevent this engine from receiving work
259 259 idx = self.targets.index(uid)
260 260 self.targets.pop(idx)
261 261 self.loads.pop(idx)
262 262
263 263 # wait 5 seconds before cleaning up pending jobs, since the results might
264 264 # still be incoming
265 265 if self.pending[uid]:
266 266 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
267 267 dc.start()
268 268 else:
269 269 self.completed.pop(uid)
270 270 self.failed.pop(uid)
271 271
272 272
273 273 def handle_stranded_tasks(self, engine):
274 274 """Deal with jobs resident in an engine that died."""
275 275 lost = self.pending[engine]
276 276 for msg_id in lost.keys():
277 277 if msg_id not in self.pending[engine]:
278 278 # prevent double-handling of messages
279 279 continue
280 280
281 281 raw_msg = lost[msg_id][0]
282 282 idents,msg = self.session.feed_identities(raw_msg, copy=False)
283 283 parent = self.session.unpack(msg[1].bytes)
284 284 idents = [engine, idents[0]]
285 285
286 286 # build fake error reply
287 287 try:
288 288 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
289 289 except:
290 290 content = error.wrap_exception()
291 291 msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
292 292 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
293 293 # and dispatch it
294 294 self.dispatch_result(raw_reply)
295 295
296 296 # finally scrub completed/failed lists
297 297 self.completed.pop(engine)
298 298 self.failed.pop(engine)
299 299
300 300
301 301 #-----------------------------------------------------------------------
302 302 # Job Submission
303 303 #-----------------------------------------------------------------------
304 304 def dispatch_submission(self, raw_msg):
305 305 """Dispatch job submission to appropriate handlers."""
306 306 # ensure targets up to date:
307 307 self.notifier_stream.flush()
308 308 try:
309 309 idents, msg = self.session.feed_identities(raw_msg, copy=False)
310 310 msg = self.session.unserialize(msg, content=False, copy=False)
311 311 except Exception:
312 312 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
313 313 return
314 314
315 315
316 316 # send to monitor
317 317 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
318 318
319 319 header = msg['header']
320 320 msg_id = header['msg_id']
321 321 self.all_ids.add(msg_id)
322 322
323 323 # get targets as a set of bytes objects
324 324 # from a list of unicode objects
325 325 targets = header.get('targets', [])
326 326 targets = map(asbytes, targets)
327 327 targets = set(targets)
328 328
329 329 retries = header.get('retries', 0)
330 330 self.retries[msg_id] = retries
331 331
332 332 # time dependencies
333 333 after = header.get('after', None)
334 334 if after:
335 335 after = Dependency(after)
336 336 if after.all:
337 337 if after.success:
338 338 after = Dependency(after.difference(self.all_completed),
339 339 success=after.success,
340 340 failure=after.failure,
341 341 all=after.all,
342 342 )
343 343 if after.failure:
344 344 after = Dependency(after.difference(self.all_failed),
345 345 success=after.success,
346 346 failure=after.failure,
347 347 all=after.all,
348 348 )
349 349 if after.check(self.all_completed, self.all_failed):
350 350 # recast as empty set, if `after` already met,
351 351 # to prevent unnecessary set comparisons
352 352 after = MET
353 353 else:
354 354 after = MET
355 355
356 356 # location dependencies
357 357 follow = Dependency(header.get('follow', []))
358 358
359 359 # turn timeouts into datetime objects:
360 360 timeout = header.get('timeout', None)
361 361 if timeout:
362 362 timeout = datetime.now() + timedelta(0,timeout,0)
363 363
364 364 args = [raw_msg, targets, after, follow, timeout]
365 365
366 366 # validate and reduce dependencies:
367 367 for dep in after,follow:
368 368 if not dep: # empty dependency
369 369 continue
370 370 # check valid:
371 371 if msg_id in dep or dep.difference(self.all_ids):
372 372 self.depending[msg_id] = args
373 373 return self.fail_unreachable(msg_id, error.InvalidDependency)
374 374 # check if unreachable:
375 375 if dep.unreachable(self.all_completed, self.all_failed):
376 376 self.depending[msg_id] = args
377 377 return self.fail_unreachable(msg_id)
378 378
379 379 if after.check(self.all_completed, self.all_failed):
380 380 # time deps already met, try to run
381 381 if not self.maybe_run(msg_id, *args):
382 382 # can't run yet
383 383 if msg_id not in self.all_failed:
384 384 # could have failed as unreachable
385 385 self.save_unmet(msg_id, *args)
386 386 else:
387 387 self.save_unmet(msg_id, *args)
388 388
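# A hedged sketch of the header fields dispatch_submission reads above;
# all values are hypothetical. 'after'/'follow' hold msg_ids, and
# 'timeout' is in seconds (converted to an absolute datetime above):
#
#     header = {
#         'msg_id'  : 'abc123',
#         'targets' : ['engine-0'],   # restrict execution to these engines
#         'after'   : ['dep-1'],      # time dependencies
#         'follow'  : [],             # location dependencies
#         'timeout' : 30,             # seconds before TaskTimeout
#         'retries' : 2,              # resubmissions before giving up
#     }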
389 389 def audit_timeouts(self):
390 390 """Audit all waiting tasks for expired timeouts."""
391 391 now = datetime.now()
392 392 for msg_id in self.depending.keys():
393 393 # must recheck, in case one failure cascaded to another:
394 394 if msg_id in self.depending:
395 395 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
396 396 if timeout and timeout < now:
397 397 self.fail_unreachable(msg_id, error.TaskTimeout)
398 398
399 399 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
400 400 """a task has become unreachable, send a reply with an ImpossibleDependency
401 401 error."""
402 402 if msg_id not in self.depending:
403 403 self.log.error("msg %r already failed!", msg_id)
404 404 return
405 405 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
406 406 for mid in follow.union(after):
407 407 if mid in self.graph:
408 408 self.graph[mid].remove(msg_id)
409 409
410 410 # FIXME: unpacking a message I've already unpacked, but didn't save:
411 411 idents,msg = self.session.feed_identities(raw_msg, copy=False)
412 412 header = self.session.unpack(msg[1].bytes)
413 413
414 414 try:
415 415 raise why()
416 416 except:
417 417 content = error.wrap_exception()
418 418
419 419 self.all_done.add(msg_id)
420 420 self.all_failed.add(msg_id)
421 421
422 422 msg = self.session.send(self.client_stream, 'apply_reply', content,
423 423 parent=header, ident=idents)
424 424 self.session.send(self.mon_stream, msg, ident=[b'outtask']+idents)
425 425
426 426 self.update_graph(msg_id, success=False)
427 427
428 428 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
429 429 """check location dependencies, and run if they are met."""
430 430 blacklist = self.blacklist.setdefault(msg_id, set())
431 431 if follow or targets or blacklist or self.hwm:
432 432 # we need a can_run filter
433 433 def can_run(idx):
434 434 # check hwm
435 435 if self.hwm and self.loads[idx] == self.hwm:
436 436 return False
437 437 target = self.targets[idx]
438 438 # check blacklist
439 439 if target in blacklist:
440 440 return False
441 441 # check targets
442 442 if targets and target not in targets:
443 443 return False
444 444 # check follow
445 445 return follow.check(self.completed[target], self.failed[target])
446 446
447 447 indices = filter(can_run, range(len(self.targets)))
448 448
449 449 if not indices:
450 450 # couldn't run
451 451 if follow.all:
452 452 # check follow for impossibility
453 453 dests = set()
454 454 relevant = set()
455 455 if follow.success:
456 456 relevant = self.all_completed
457 457 if follow.failure:
458 458 relevant = relevant.union(self.all_failed)
459 459 for m in follow.intersection(relevant):
460 460 dests.add(self.destinations[m])
461 461 if len(dests) > 1:
462 462 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
463 463 self.fail_unreachable(msg_id)
464 464 return False
465 465 if targets:
466 466 # check blacklist+targets for impossibility
467 467 targets.difference_update(blacklist)
468 468 if not targets or not targets.intersection(self.targets):
469 469 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
470 470 self.fail_unreachable(msg_id)
471 471 return False
472 472 return False
473 473 else:
474 474 indices = None
475 475
476 476 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
477 477 return True
478 478
479 479 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
480 480 """Save a message for later submission when its dependencies are met."""
481 481 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
482 482 # track the ids in follow or after, but not those already finished
483 483 for dep_id in after.union(follow).difference(self.all_done):
484 484 if dep_id not in self.graph:
485 485 self.graph[dep_id] = set()
486 486 self.graph[dep_id].add(msg_id)
487 487
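# The graph maps each unfinished dependency to the set of tasks waiting
# on it; a hypothetical snapshot after two calls to save_unmet:
#
#     self.graph = {
#         'dep-1' : set(['task-a', 'task-b']),  # both wait on dep-1
#         'dep-2' : set(['task-b']),
#     }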
488 488 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
489 489 """Submit a task to any of a subset of our targets."""
490 490 if indices:
491 491 loads = [self.loads[i] for i in indices]
492 492 else:
493 493 loads = self.loads
494 494 idx = self.scheme(loads)
495 495 if indices:
496 496 idx = indices[idx]
497 497 target = self.targets[idx]
498 498 # print (target, map(str, msg[:3]))
499 499 # send job to the engine
500 500 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
501 501 self.engine_stream.send_multipart(raw_msg, copy=False)
502 502 # update load
503 503 self.add_job(idx)
504 504 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
505 505 # notify Hub
506 506 content = dict(msg_id=msg_id, engine_id=target.decode('ascii'))
507 507 self.session.send(self.mon_stream, 'task_destination', content=content,
508 508 ident=[b'tracktask',self.ident])
509 509
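# The `scheme` callable picks an index into `loads`. The scheduler's
# default is least-load selection, which can be sketched as:
#
#     def leastload(loads):
#         """Pick the index of the minimum load (ties go to the first)."""
#         return loads.index(min(loads))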
510 510
511 511 #-----------------------------------------------------------------------
512 512 # Result Handling
513 513 #-----------------------------------------------------------------------
514 514 def dispatch_result(self, raw_msg):
515 515 """dispatch method for result replies"""
516 516 try:
517 517 idents,msg = self.session.feed_identities(raw_msg, copy=False)
518 518 msg = self.session.unserialize(msg, content=False, copy=False)
519 519 engine = idents[0]
520 520 try:
521 521 idx = self.targets.index(engine)
522 522 except ValueError:
523 523 pass # skip load-update for dead engines
524 524 else:
525 525 self.finish_job(idx)
526 526 except Exception:
527 527 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
528 528 return
529 529
530 530 header = msg['header']
531 531 parent = msg['parent_header']
532 532 if header.get('dependencies_met', True):
533 533 success = (header['status'] == 'ok')
534 534 msg_id = parent['msg_id']
535 535 retries = self.retries[msg_id]
536 536 if not success and retries > 0:
537 537 # failed
538 538 self.retries[msg_id] = retries - 1
539 539 self.handle_unmet_dependency(idents, parent)
540 540 else:
541 541 del self.retries[msg_id]
542 542 # relay to client and update graph
543 543 self.handle_result(idents, parent, raw_msg, success)
544 544 # send to Hub monitor
545 545 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
546 546 else:
547 547 self.handle_unmet_dependency(idents, parent)
548 548
549 549 def handle_result(self, idents, parent, raw_msg, success=True):
550 550 """handle a real task result, either success or failure"""
551 551 # first, relay result to client
552 552 engine = idents[0]
553 553 client = idents[1]
554 554 # swap_ids for XREP-XREP mirror
555 555 raw_msg[:2] = [client,engine]
556 556 # print (map(str, raw_msg[:4]))
557 557 self.client_stream.send_multipart(raw_msg, copy=False)
558 558 # now, update our data structures
559 559 msg_id = parent['msg_id']
560 560 self.blacklist.pop(msg_id, None)
561 561 self.pending[engine].pop(msg_id)
562 562 if success:
563 563 self.completed[engine].add(msg_id)
564 564 self.all_completed.add(msg_id)
565 565 else:
566 566 self.failed[engine].add(msg_id)
567 567 self.all_failed.add(msg_id)
568 568 self.all_done.add(msg_id)
569 569 self.destinations[msg_id] = engine
570 570
571 571 self.update_graph(msg_id, success)
572 572
573 573 def handle_unmet_dependency(self, idents, parent):
574 574 """handle an unmet dependency"""
575 575 engine = idents[0]
576 576 msg_id = parent['msg_id']
577 577
578 578 if msg_id not in self.blacklist:
579 579 self.blacklist[msg_id] = set()
580 580 self.blacklist[msg_id].add(engine)
581 581
582 582 args = self.pending[engine].pop(msg_id)
583 583 raw,targets,after,follow,timeout = args
584 584
585 585 if self.blacklist[msg_id] == targets:
586 586 self.depending[msg_id] = args
587 587 self.fail_unreachable(msg_id)
588 588 elif not self.maybe_run(msg_id, *args):
589 589 # resubmit failed
590 590 if msg_id not in self.all_failed:
591 591 # put it back in our dependency tree
592 592 self.save_unmet(msg_id, *args)
593 593
594 594 if self.hwm:
595 595 try:
596 596 idx = self.targets.index(engine)
597 597 except ValueError:
598 598 pass # skip load-update for dead engines
599 599 else:
600 600 if self.loads[idx] == self.hwm-1:
601 601 self.update_graph(None)
602 602
603 603
604 604
605 605 def update_graph(self, dep_id=None, success=True):
606 606 """dep_id just finished. Update our dependency
607 607 graph and submit any jobs that just became runnable.
608 608 
609 609 Called with dep_id=None to update the entire graph for hwm, without finishing
610 610 a task.
611 611 """
612 612 # print ("\n\n***********")
613 613 # pprint (dep_id)
614 614 # pprint (self.graph)
615 615 # pprint (self.depending)
616 616 # pprint (self.all_completed)
617 617 # pprint (self.all_failed)
618 618 # print ("\n\n***********\n\n")
619 619 # update any jobs that depended on the dependency
620 620 jobs = self.graph.pop(dep_id, [])
621 621
622 622 # recheck *all* jobs if
623 623 # a) we have HWM and an engine just became no longer full
624 624 # or b) dep_id was given as None
625 625 if dep_id is None or (self.hwm and any(load == self.hwm-1 for load in self.loads)):
626 626 jobs = self.depending.keys()
627 627
628 628 for msg_id in jobs:
629 629 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
630 630
631 631 if after.unreachable(self.all_completed, self.all_failed)\
632 632 or follow.unreachable(self.all_completed, self.all_failed):
633 633 self.fail_unreachable(msg_id)
634 634
635 635 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
636 636 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
637 637
638 638 self.depending.pop(msg_id)
639 639 for mid in follow.union(after):
640 640 if mid in self.graph:
641 641 self.graph[mid].remove(msg_id)
642 642
643 643 #----------------------------------------------------------------------
644 644 # methods to be overridden by subclasses
645 645 #----------------------------------------------------------------------
646 646
647 647 def add_job(self, idx):
648 648 """Called after self.targets[idx] just got the job with header.
649 649 Override in subclasses. The default ordering is simple LRU.
650 650 The default loads are the number of outstanding jobs."""
651 651 self.loads[idx] += 1
652 652 for lis in (self.targets, self.loads):
653 653 lis.append(lis.pop(idx))
654 654
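# A tiny worked example of the rotation above (values hypothetical):
# with targets = [b'a', b'b', b'c'] and loads = [0, 1, 0], add_job(0)
# bumps loads[0] and rotates both lists, leaving
# targets = [b'b', b'c', b'a'] and loads = [1, 0, 1], so engine b'a'
# drops to the back of the LRU order.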
655 655
656 656 def finish_job(self, idx):
657 657 """Called after self.targets[idx] just finished a job.
658 658 Override in subclasses."""
659 659 self.loads[idx] -= 1
660 660
661 661
662 662
663 663 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
664 664 logname='root', log_url=None, loglevel=logging.DEBUG,
665 665 identity=b'task', in_thread=False):
666 666
667 667 ZMQStream = zmqstream.ZMQStream
668 668
669 669 if config:
670 670 # unwrap dict back into Config
671 671 config = Config(config)
672 672
673 673 if in_thread:
674 674 # use instance() to get the same Context/Loop as our parent
675 675 ctx = zmq.Context.instance()
676 676 loop = ioloop.IOLoop.instance()
677 677 else:
678 678 # in a process, don't use instance()
679 679 # for safety with multiprocessing
680 680 ctx = zmq.Context()
681 681 loop = ioloop.IOLoop()
682 682 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
683 683 ins.setsockopt(zmq.IDENTITY, identity)
684 684 ins.bind(in_addr)
685 685
686 686 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
687 687 outs.setsockopt(zmq.IDENTITY, identity)
688 688 outs.bind(out_addr)
689 689 mons = ZMQStream(ctx.socket(zmq.PUB), loop)
690 690 mons.connect(mon_addr)
691 691 nots = ZMQStream(ctx.socket(zmq.SUB), loop)
692 692 nots.setsockopt(zmq.SUBSCRIBE, b'')
693 693 nots.connect(not_addr)
694 694
695 695 # setup logging.
696 696 if in_thread:
697 697 log = Application.instance().log
698 698 else:
699 699 if log_url:
700 700 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
701 701 else:
702 702 log = local_logger(logname, loglevel)
703 703
704 704 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
705 705 mon_stream=mons, notifier_stream=nots,
706 706 loop=loop, log=log,
707 707 config=config)
708 708 scheduler.start()
709 709 if not in_thread:
710 710 try:
711 711 loop.start()
712 712 except KeyboardInterrupt:
713 713 print ("interrupted, exiting...", file=sys.__stderr__)
714 714
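# A hedged usage sketch (addresses hypothetical). With in_thread=True the
# scheduler reuses the parent's Context/IOLoop instead of a private loop:
#
#     launch_scheduler('tcp://127.0.0.1:5100', 'tcp://127.0.0.1:5101',
#                      'tcp://127.0.0.1:5102', 'tcp://127.0.0.1:5103',
#                      identity=b'task', in_thread=True)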
@@ -1,983 +1,983 b''
1 1 """Base classes to manage the interaction with a running kernel.
2 2
3 3 TODO
4 4 * Create logger to handle debugging and console messages.
5 5 """
6 6
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2010 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 # Standard library imports.
19 19 import atexit
20 20 import errno
21 21 from Queue import Queue, Empty
22 22 from subprocess import Popen
23 23 import signal
24 24 import sys
25 25 from threading import Thread
26 26 import time
27 27 import logging
28 28
29 29 # System library imports.
30 30 import zmq
31 31 from zmq import POLLIN, POLLOUT, POLLERR
32 32 from zmq.eventloop import ioloop
33 33
34 34 # Local imports.
35 35 from IPython.config.loader import Config
36 36 from IPython.utils import io
37 37 from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
38 38 from IPython.utils.traitlets import HasTraits, Any, Instance, Type, TCPAddress
39 39 from session import Session, Message
40 40
41 41 #-----------------------------------------------------------------------------
42 42 # Constants and exceptions
43 43 #-----------------------------------------------------------------------------
44 44
45 45 class InvalidPortNumber(Exception):
46 46 pass
47 47
48 48 #-----------------------------------------------------------------------------
49 49 # Utility functions
50 50 #-----------------------------------------------------------------------------
51 51
52 52 # some utilities to validate message structure, these might get moved elsewhere
53 53 # if they prove to have more generic utility
54 54
55 55 def validate_string_list(lst):
56 56 """Validate that the input is a list of strings.
57 57
58 58 Raises ValueError if not."""
59 59 if not isinstance(lst, list):
60 60 raise ValueError('input %r must be a list' % lst)
61 61 for x in lst:
62 62 if not isinstance(x, basestring):
63 63 raise ValueError('element %r in list must be a string' % x)
64 64
65 65
66 66 def validate_string_dict(dct):
67 67 """Validate that the input is a dict with string keys and values.
68 68
69 69 Raises ValueError if not."""
70 70 for k,v in dct.iteritems():
71 71 if not isinstance(k, basestring):
72 72 raise ValueError('key %r in dict must be a string' % k)
73 73 if not isinstance(v, basestring):
74 74 raise ValueError('value %r in dict must be a string' % v)
75 75
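# Example: both validators return None on success and raise ValueError on
# malformed input:
#
#     validate_string_list(['a', 'b'])          # ok
#     validate_string_dict({'x': 'repr(x)'})    # ok
#     validate_string_list('ab')                # ValueError: not a list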
76 76
77 77 #-----------------------------------------------------------------------------
78 78 # ZMQ Socket Channel classes
79 79 #-----------------------------------------------------------------------------
80 80
81 81 class ZMQSocketChannel(Thread):
82 82 """The base class for the channels that use ZMQ sockets.
83 83 """
84 84 context = None
85 85 session = None
86 86 socket = None
87 87 ioloop = None
88 88 iostate = None
89 89 _address = None
90 90
91 91 def __init__(self, context, session, address):
92 92 """Create a channel
93 93
94 94 Parameters
95 95 ----------
96 96 context : :class:`zmq.Context`
97 97 The ZMQ context to use.
98 98 session : :class:`session.Session`
99 99 The session to use.
100 100 address : tuple
101 101 Standard (ip, port) tuple that the kernel is listening on.
102 102 """
103 103 super(ZMQSocketChannel, self).__init__()
104 104 self.daemon = True
105 105
106 106 self.context = context
107 107 self.session = session
108 108 if address[1] == 0:
109 109 message = 'The port number for a channel cannot be 0.'
110 110 raise InvalidPortNumber(message)
111 111 self._address = address
112 112
113 113 def _run_loop(self):
114 114 """Run my loop, ignoring EINTR events in the poller"""
115 115 while True:
116 116 try:
117 117 self.ioloop.start()
118 118 except zmq.ZMQError as e:
119 119 if e.errno == errno.EINTR:
120 120 continue
121 121 else:
122 122 raise
123 123 else:
124 124 break
125 125
126 126 def stop(self):
127 127 """Stop the channel's activity.
128 128
129 129 This calls :meth:`Thread.join` and returns when the thread
130 130 terminates. :class:`RuntimeError` will be raised if
131 131 :meth:`self.start` is called again.
132 132 """
133 133 self.join()
134 134
135 135 @property
136 136 def address(self):
137 137 """Get the channel's address as an (ip, port) tuple.
138 138
139 139 By default, the address is (localhost, 0), where 0 means a random
140 140 port.
141 141 """
142 142 return self._address
143 143
144 144 def add_io_state(self, state):
145 145 """Add IO state to the eventloop.
146 146
147 147 Parameters
148 148 ----------
149 149 state : zmq.POLLIN|zmq.POLLOUT|zmq.POLLERR
150 150 The IO state flag to set.
151 151
152 152 This is thread safe as it uses the thread safe IOLoop.add_callback.
153 153 """
154 154 def add_io_state_callback():
155 155 if not self.iostate & state:
156 156 self.iostate = self.iostate | state
157 157 self.ioloop.update_handler(self.socket, self.iostate)
158 158 self.ioloop.add_callback(add_io_state_callback)
159 159
160 160 def drop_io_state(self, state):
161 161 """Drop IO state from the eventloop.
162 162
163 163 Parameters
164 164 ----------
165 165 state : zmq.POLLIN|zmq.POLLOUT|zmq.POLLERR
166 166 The IO state flag to drop.
167 167
168 168 This is thread safe as it uses the thread safe IOLoop.add_callback.
169 169 """
170 170 def drop_io_state_callback():
171 171 if self.iostate & state:
172 172 self.iostate = self.iostate & (~state)
173 173 self.ioloop.update_handler(self.socket, self.iostate)
174 174 self.ioloop.add_callback(drop_io_state_callback)
175 175
176 176
177 177 class ShellSocketChannel(ZMQSocketChannel):
178 178 """The XREQ channel for issues request/replies to the kernel.
179 179 """
180 180
181 181 command_queue = None
182 182
183 183 def __init__(self, context, session, address):
184 184 super(ShellSocketChannel, self).__init__(context, session, address)
185 185 self.command_queue = Queue()
186 186 self.ioloop = ioloop.IOLoop()
187 187
188 188 def run(self):
189 189 """The thread's main activity. Call start() instead."""
190 190 self.socket = self.context.socket(zmq.DEALER)
191 self.socket.setsockopt(zmq.IDENTITY, self.session.session.encode("ascii"))
191 self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
192 192 self.socket.connect('tcp://%s:%i' % self.address)
193 193 self.iostate = POLLERR|POLLIN
194 194 self.ioloop.add_handler(self.socket, self._handle_events,
195 195 self.iostate)
196 196 self._run_loop()
197 197
198 198 def stop(self):
199 199 self.ioloop.stop()
200 200 super(ShellSocketChannel, self).stop()
201 201
202 202 def call_handlers(self, msg):
203 203 """This method is called in the ioloop thread when a message arrives.
204 204
205 205 Subclasses should override this method to handle incoming messages.
206 206 It is important to remember that this method is called in the ioloop
207 207 thread, so some work must be done to ensure that the application-level
208 208 handlers are called in the application thread.
209 209 """
210 210 raise NotImplementedError('call_handlers must be defined in a subclass.')
211 211
212 212 def execute(self, code, silent=False,
213 213 user_variables=None, user_expressions=None):
214 214 """Execute code in the kernel.
215 215
216 216 Parameters
217 217 ----------
218 218 code : str
219 219 A string of Python code.
220 220
221 221 silent : bool, optional (default False)
222 222 If set, the kernel will execute the code as quietly as possible.
223 223
224 224 user_variables : list, optional
225 225 A list of variable names to pull from the user's namespace. They
226 226 will come back as a dict with these names as keys and their
227 227 :func:`repr` as values.
228 228
229 229 user_expressions : dict, optional
230 230 A dict with string keys and expressions to evaluate in the user's
231 231 namespace. They will come back as a dict with these names as keys
232 232 and their :func:`repr` as values.
233 233
234 234 Returns
235 235 -------
236 236 The msg_id of the message sent.
237 237 """
238 238 if user_variables is None:
239 239 user_variables = []
240 240 if user_expressions is None:
241 241 user_expressions = {}
242 242
243 243 # Don't waste network traffic if inputs are invalid
244 244 if not isinstance(code, basestring):
245 245 raise ValueError('code %r must be a string' % code)
246 246 validate_string_list(user_variables)
247 247 validate_string_dict(user_expressions)
248 248
249 249 # Create class for content/msg creation. Related to, but possibly
250 250 # not in Session.
251 251 content = dict(code=code, silent=silent,
252 252 user_variables=user_variables,
253 253 user_expressions=user_expressions)
254 254 msg = self.session.msg('execute_request', content)
255 255 self._queue_request(msg)
256 256 return msg['header']['msg_id']
257 257
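# A hedged usage sketch (`km` is a hypothetical, already-started
# KernelManager): queue an execution and keep the msg_id so the eventual
# execute_reply can be matched back to this request:
#
#     msg_id = km.shell_channel.execute('a = 1 + 1', user_variables=['a'])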
258 258 def complete(self, text, line, cursor_pos, block=None):
259 259 """Tab complete text in the kernel's namespace.
260 260
261 261 Parameters
262 262 ----------
263 263 text : str
264 264 The text to complete.
265 265 line : str
266 266 The full line of text that is the surrounding context for the
267 267 text to complete.
268 268 cursor_pos : int
269 269 The position of the cursor in the line where the completion was
270 270 requested.
271 271 block : str, optional
272 272 The full block of code in which the completion is being requested.
273 273
274 274 Returns
275 275 -------
276 276 The msg_id of the message sent.
277 277 """
278 278 content = dict(text=text, line=line, block=block, cursor_pos=cursor_pos)
279 279 msg = self.session.msg('complete_request', content)
280 280 self._queue_request(msg)
281 281 return msg['header']['msg_id']
282 282
283 283 def object_info(self, oname):
284 284 """Get metadata information about an object.
285 285
286 286 Parameters
287 287 ----------
288 288 oname : str
289 289 A string specifying the object name.
290 290
291 291 Returns
292 292 -------
293 293 The msg_id of the message sent.
294 294 """
295 295 content = dict(oname=oname)
296 296 msg = self.session.msg('object_info_request', content)
297 297 self._queue_request(msg)
298 298 return msg['header']['msg_id']
299 299
300 300 def history(self, raw=True, output=False, hist_access_type='range', **kwargs):
301 301 """Get entries from the history list.
302 302
303 303 Parameters
304 304 ----------
305 305 raw : bool
306 306 If True, return the raw input.
307 307 output : bool
308 308 If True, then return the output as well.
309 309 hist_access_type : str
310 310 'range' (fill in session, start and stop params), 'tail' (fill in n)
311 311 or 'search' (fill in pattern param).
312 312
313 313 session : int
314 314 For a range request, the session from which to get lines. Session
315 315 numbers are positive integers; negative ones count back from the
316 316 current session.
317 317 start : int
318 318 The first line number of a history range.
319 319 stop : int
320 320 The final (excluded) line number of a history range.
321 321
322 322 n : int
323 323 The number of lines of history to get for a tail request.
324 324
325 325 pattern : str
326 326 The glob-syntax pattern for a search request.
327 327
328 328 Returns
329 329 -------
330 330 The msg_id of the message sent.
331 331 """
332 332 content = dict(raw=raw, output=output, hist_access_type=hist_access_type,
333 333 **kwargs)
334 334 msg = self.session.msg('history_request', content)
335 335 self._queue_request(msg)
336 336 return msg['header']['msg_id']
337 337
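# Hedged examples of the three access types (`km` hypothetical, as above):
#
#     km.shell_channel.history(hist_access_type='tail', n=10)
#     km.shell_channel.history(hist_access_type='range',
#                              session=-1, start=1, stop=5)
#     km.shell_channel.history(hist_access_type='search', pattern='import*')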
338 338 def shutdown(self, restart=False):
339 339 """Request an immediate kernel shutdown.
340 340
341 341 Upon receipt of the (empty) reply, client code can safely assume that
342 342 the kernel has shut down and it's safe to forcefully terminate it if
343 343 it's still alive.
344 344
345 345 The kernel will send the reply via a function registered with Python's
346 346 atexit module, ensuring it's truly done as the kernel is done with all
347 347 normal operation.
348 348 """
349 349 # Send quit message to kernel. Once we implement kernel-side setattr,
350 350 # this should probably be done that way, but for now this will do.
351 351 msg = self.session.msg('shutdown_request', {'restart':restart})
352 352 self._queue_request(msg)
353 353 return msg['header']['msg_id']
354 354
355 355 def _handle_events(self, socket, events):
356 356 if events & POLLERR:
357 357 self._handle_err()
358 358 if events & POLLOUT:
359 359 self._handle_send()
360 360 if events & POLLIN:
361 361 self._handle_recv()
362 362
363 363 def _handle_recv(self):
364 364 ident,msg = self.session.recv(self.socket, 0)
365 365 self.call_handlers(msg)
366 366
367 367 def _handle_send(self):
368 368 try:
369 369 msg = self.command_queue.get(False)
370 370 except Empty:
371 371 pass
372 372 else:
373 373 self.session.send(self.socket,msg)
374 374 if self.command_queue.empty():
375 375 self.drop_io_state(POLLOUT)
376 376
377 377 def _handle_err(self):
378 378 # We don't want to let this go silently, so eventually we should log.
379 379 raise zmq.ZMQError()
380 380
381 381 def _queue_request(self, msg):
382 382 self.command_queue.put(msg)
383 383 self.add_io_state(POLLOUT)
384 384
385 385
386 386 class SubSocketChannel(ZMQSocketChannel):
387 387 """The SUB channel which listens for messages that the kernel publishes.
388 388 """
389 389
390 390 def __init__(self, context, session, address):
391 391 super(SubSocketChannel, self).__init__(context, session, address)
392 392 self.ioloop = ioloop.IOLoop()
393 393
394 394 def run(self):
395 395 """The thread's main activity. Call start() instead."""
396 396 self.socket = self.context.socket(zmq.SUB)
397 397 self.socket.setsockopt(zmq.SUBSCRIBE,b'')
398 self.socket.setsockopt(zmq.IDENTITY, self.session.session.encode("ascii"))
398 self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
399 399 self.socket.connect('tcp://%s:%i' % self.address)
400 400 self.iostate = POLLIN|POLLERR
401 401 self.ioloop.add_handler(self.socket, self._handle_events,
402 402 self.iostate)
403 403 self._run_loop()
404 404
405 405 def stop(self):
406 406 self.ioloop.stop()
407 407 super(SubSocketChannel, self).stop()
408 408
409 409 def call_handlers(self, msg):
410 410 """This method is called in the ioloop thread when a message arrives.
411 411
412 412 Subclasses should override this method to handle incoming messages.
413 413 It is important to remember that this method is called in the ioloop
414 414 thread, so some work must be done to ensure that the application-level
415 415 handlers are called in the application thread.
416 416 """
417 417 raise NotImplementedError('call_handlers must be defined in a subclass.')
418 418
419 419 def flush(self, timeout=1.0):
420 420 """Immediately processes all pending messages on the SUB channel.
421 421
422 422 Callers should use this method to ensure that :meth:`call_handlers`
423 423 has been called for all messages that have been received on the
424 424 0MQ SUB socket of this channel.
425 425
426 426 This method is thread safe.
427 427
428 428 Parameters
429 429 ----------
430 430 timeout : float, optional
431 431 The maximum amount of time to spend flushing, in seconds. The
432 432 default is one second.
433 433 """
434 434 # We do the IOLoop callback process twice to ensure that the IOLoop
435 435 # gets to perform at least one full poll.
436 436 stop_time = time.time() + timeout
437 437 for i in xrange(2):
438 438 self._flushed = False
439 439 self.ioloop.add_callback(self._flush)
440 440 while not self._flushed and time.time() < stop_time:
441 441 time.sleep(0.01)
442 442
443 443 def _handle_events(self, socket, events):
444 444 # Turn on and off POLLOUT depending on if we have made a request
445 445 if events & POLLERR:
446 446 self._handle_err()
447 447 if events & POLLIN:
448 448 self._handle_recv()
449 449
450 450 def _handle_err(self):
451 451 # We don't want to let this go silently, so eventually we should log.
452 452 raise zmq.ZMQError()
453 453
454 454 def _handle_recv(self):
455 455 # Get all of the messages we can
456 456 while True:
457 457 try:
458 458 ident,msg = self.session.recv(self.socket)
459 459 except zmq.ZMQError:
460 460 # Check the errno?
461 461 # Will this trigger POLLERR?
462 462 break
463 463 else:
464 464 if msg is None:
465 465 break
466 466 self.call_handlers(msg)
467 467
468 468 def _flush(self):
469 469 """Callback for :method:`self.flush`."""
470 470 self._flushed = True
471 471
472 472
473 473 class StdInSocketChannel(ZMQSocketChannel):
474 474 """A reply channel to handle raw_input requests that the kernel makes."""
475 475
476 476 msg_queue = None
477 477
478 478 def __init__(self, context, session, address):
479 479 super(StdInSocketChannel, self).__init__(context, session, address)
480 480 self.ioloop = ioloop.IOLoop()
481 481 self.msg_queue = Queue()
482 482
483 483 def run(self):
484 484 """The thread's main activity. Call start() instead."""
485 485 self.socket = self.context.socket(zmq.DEALER)
486 self.socket.setsockopt(zmq.IDENTITY, self.session.session.encode("ascii"))
486 self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
487 487 self.socket.connect('tcp://%s:%i' % self.address)
488 488 self.iostate = POLLERR|POLLIN
489 489 self.ioloop.add_handler(self.socket, self._handle_events,
490 490 self.iostate)
491 491 self._run_loop()
492 492
493 493 def stop(self):
494 494 self.ioloop.stop()
495 495 super(StdInSocketChannel, self).stop()
496 496
497 497 def call_handlers(self, msg):
498 498 """This method is called in the ioloop thread when a message arrives.
499 499
500 500 Subclasses should override this method to handle incoming messages.
501 501 It is important to remember that this method is called in the ioloop
502 502 thread, so some work must be done to ensure that the application-level
503 503 handlers are called in the application thread.
504 504 """
505 505 raise NotImplementedError('call_handlers must be defined in a subclass.')
506 506
507 507 def input(self, string):
508 508 """Send a string of raw input to the kernel."""
509 509 content = dict(value=string)
510 510 msg = self.session.msg('input_reply', content)
511 511 self._queue_reply(msg)
512 512
513 513 def _handle_events(self, socket, events):
514 514 if events & POLLERR:
515 515 self._handle_err()
516 516 if events & POLLOUT:
517 517 self._handle_send()
518 518 if events & POLLIN:
519 519 self._handle_recv()
520 520
521 521 def _handle_recv(self):
522 522 ident,msg = self.session.recv(self.socket, 0)
523 523 self.call_handlers(msg)
524 524
525 525 def _handle_send(self):
526 526 try:
527 527 msg = self.msg_queue.get(False)
528 528 except Empty:
529 529 pass
530 530 else:
531 531 self.session.send(self.socket,msg)
532 532 if self.msg_queue.empty():
533 533 self.drop_io_state(POLLOUT)
534 534
535 535 def _handle_err(self):
536 536 # We don't want to let this go silently, so eventually we should log.
537 537 raise zmq.ZMQError()
538 538
539 539 def _queue_reply(self, msg):
540 540 self.msg_queue.put(msg)
541 541 self.add_io_state(POLLOUT)
542 542
543 543
544 544 class HBSocketChannel(ZMQSocketChannel):
545 545 """The heartbeat channel which monitors the kernel heartbeat.
546 546
547 547 Note that the heartbeat channel is paused by default. As long as you start
548 548 this channel, the kernel manager will ensure that it is paused and un-paused
549 549 as appropriate.
550 550 """
551 551
552 552 time_to_dead = 3.0
553 553 socket = None
554 554 poller = None
555 555 _running = None
556 556 _pause = None
557 557
558 558 def __init__(self, context, session, address):
559 559 super(HBSocketChannel, self).__init__(context, session, address)
560 560 self._running = False
561 561 self._pause = True
562 562
563 563 def _create_socket(self):
564 564 self.socket = self.context.socket(zmq.REQ)
565 self.socket.setsockopt(zmq.IDENTITY, self.session.session.encode("ascii"))
565 self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
566 566 self.socket.connect('tcp://%s:%i' % self.address)
567 567 self.poller = zmq.Poller()
568 568 self.poller.register(self.socket, zmq.POLLIN)
569 569
570 570 def run(self):
571 571 """The thread's main activity. Call start() instead."""
572 572 self._create_socket()
573 573 self._running = True
574 574 while self._running:
575 575 if self._pause:
576 576 time.sleep(self.time_to_dead)
577 577 else:
578 578 since_last_heartbeat = 0.0
579 579 request_time = time.time()
580 580 try:
581 581 #io.rprint('Ping from HB channel') # dbg
582 582 self.socket.send(b'ping')
583 583 except zmq.ZMQError as e:
584 584 #io.rprint('*** HB Error:', e) # dbg
585 585 if e.errno == zmq.EFSM:
586 586 #io.rprint('sleep...', self.time_to_dead) # dbg
587 587 time.sleep(self.time_to_dead)
588 588 self._create_socket()
589 589 else:
590 590 raise
591 591 else:
592 592 while True:
593 593 try:
594 594 self.socket.recv(zmq.NOBLOCK)
595 595 except zmq.ZMQError as e:
596 596 #io.rprint('*** HB Error 2:', e) # dbg
597 597 if e.errno == zmq.EAGAIN:
598 598 before_poll = time.time()
599 599 until_dead = self.time_to_dead - (before_poll -
600 600 request_time)
601 601
602 602 # When the return value of poll() is an empty
603 603 # list, that is when things have gone wrong
604 604 # (zeromq bug). As long as it is not an empty
605 605 # list, poll is working correctly even if it
606 606 # returns quickly. Note: poll timeout is in
607 607 # milliseconds.
608 608 if until_dead > 0.0:
609 609 while True:
610 610 try:
611 611 self.poller.poll(1000 * until_dead)
612 612 except zmq.ZMQError as e:
613 613 if e.errno == errno.EINTR:
614 614 continue
615 615 else:
616 616 raise
617 617 else:
618 618 break
619 619
620 620 since_last_heartbeat = time.time()-request_time
621 621 if since_last_heartbeat > self.time_to_dead:
622 622 self.call_handlers(since_last_heartbeat)
623 623 break
624 624 else:
625 625 # FIXME: We should probably log this instead.
626 626 raise
627 627 else:
628 628 until_dead = self.time_to_dead - (time.time() -
629 629 request_time)
630 630 if until_dead > 0.0:
631 631 #io.rprint('sleep...', self.time_to_dead) # dbg
632 632 time.sleep(until_dead)
633 633 break
634 634
635 635 def pause(self):
636 636 """Pause the heartbeat."""
637 637 self._pause = True
638 638
639 639 def unpause(self):
640 640 """Unpause the heartbeat."""
641 641 self._pause = False
642 642
643 643 def is_beating(self):
644 644 """Is the heartbeat running and not paused."""
645 645 if self.is_alive() and not self._pause:
646 646 return True
647 647 else:
648 648 return False
649 649
650 650 def stop(self):
651 651 self._running = False
652 652 super(HBSocketChannel, self).stop()
653 653
654 654 def call_handlers(self, since_last_heartbeat):
655 655 """This method is called in the ioloop thread when a message arrives.
656 656
657 657 Subclasses should override this method to handle incoming messages.
658 658 It is important to remember that this method is called in the ioloop
659 659 thread, so some work must be done to ensure that the application-level
660 660 handlers are called in the application thread.
661 661 """
662 662 raise NotImplementedError('call_handlers must be defined in a subclass.')
663 663
664 664
665 665 #-----------------------------------------------------------------------------
666 666 # Main kernel manager class
667 667 #-----------------------------------------------------------------------------
668 668
669 669 class KernelManager(HasTraits):
670 670 """ Manages a kernel for a frontend.
671 671
672 672 The SUB channel is for the frontend to receive messages published by the
673 673 kernel.
674 674
675 675 The REQ channel is for the frontend to make requests of the kernel.
676 676
677 677 The REP channel is for the kernel to request stdin (raw_input) from the
678 678 frontend.
679 679 """
680 680 # config object for passing to child configurables
681 681 config = Instance(Config)
682 682
683 683 # The PyZMQ Context to use for communication with the kernel.
684 684 context = Instance(zmq.Context)
685 685 def _context_default(self):
686 686 return zmq.Context.instance()
687 687
688 688 # The Session to use for communication with the kernel.
689 689 session = Instance(Session)
690 690
691 691 # The kernel process with which the KernelManager is communicating.
692 692 kernel = Instance(Popen)
693 693
694 694 # The addresses for the communication channels.
695 695 shell_address = TCPAddress((LOCALHOST, 0))
696 696 sub_address = TCPAddress((LOCALHOST, 0))
697 697 stdin_address = TCPAddress((LOCALHOST, 0))
698 698 hb_address = TCPAddress((LOCALHOST, 0))
699 699
700 700 # The classes to use for the various channels.
701 701 shell_channel_class = Type(ShellSocketChannel)
702 702 sub_channel_class = Type(SubSocketChannel)
703 703 stdin_channel_class = Type(StdInSocketChannel)
704 704 hb_channel_class = Type(HBSocketChannel)
705 705
706 706 # Protected traits.
707 707 _launch_args = Any
708 708 _shell_channel = Any
709 709 _sub_channel = Any
710 710 _stdin_channel = Any
711 711 _hb_channel = Any
712 712
713 713 def __init__(self, **kwargs):
714 714 super(KernelManager, self).__init__(**kwargs)
715 715 if self.session is None:
716 716 self.session = Session(config=self.config)
717 717 # Uncomment this to try closing the context.
718 718 # atexit.register(self.context.term)
719 719
720 720 #--------------------------------------------------------------------------
721 721 # Channel management methods:
722 722 #--------------------------------------------------------------------------
723 723
724 724 def start_channels(self, shell=True, sub=True, stdin=True, hb=True):
725 725 """Starts the channels for this kernel.
726 726
727 727 This will create the channels if they do not exist and then start
728 728 them. If port numbers of 0 are being used (random ports) then you
729 729 must first call :meth:`start_kernel`. If the channels have been
730 730 stopped and you call this, :class:`RuntimeError` will be raised.
731 731 """
732 732 if shell:
733 733 self.shell_channel.start()
734 734 if sub:
735 735 self.sub_channel.start()
736 736 if stdin:
737 737 self.stdin_channel.start()
738 738 if hb:
739 739 self.hb_channel.start()
740 740
741 741 def stop_channels(self):
742 742 """Stops all the running channels for this kernel.
743 743 """
744 744 if self.shell_channel.is_alive():
745 745 self.shell_channel.stop()
746 746 if self.sub_channel.is_alive():
747 747 self.sub_channel.stop()
748 748 if self.stdin_channel.is_alive():
749 749 self.stdin_channel.stop()
750 750 if self.hb_channel.is_alive():
751 751 self.hb_channel.stop()
752 752
753 753 @property
754 754 def channels_running(self):
755 755 """Are any of the channels created and running?"""
756 756 return (self.shell_channel.is_alive() or self.sub_channel.is_alive() or
757 757 self.stdin_channel.is_alive() or self.hb_channel.is_alive())
758 758
759 759 #--------------------------------------------------------------------------
760 760 # Kernel process management methods:
761 761 #--------------------------------------------------------------------------
762 762
763 763 def start_kernel(self, **kw):
764 764 """Starts a kernel process and configures the manager to use it.
765 765
766 766 If random ports (port=0) are being used, this method must be called
767 767 before the channels are created.
768 768
769 769 Parameters
770 770 ----------
771 771 ipython : bool, optional (default True)
772 772 Whether to use an IPython kernel instead of a plain Python kernel.
773 773
774 774 launcher : callable, optional (default None)
775 775 A custom function for launching the kernel process (generally a
776 776 wrapper around ``entry_point.base_launch_kernel``). In most cases,
777 777 it should not be necessary to use this parameter.
778 778
779 779 **kw : optional
780 780 See respective options for IPython and Python kernels.
781 781 """
782 782 shell, sub, stdin, hb = self.shell_address, self.sub_address, \
783 783 self.stdin_address, self.hb_address
784 784 if shell[0] not in LOCAL_IPS or sub[0] not in LOCAL_IPS or \
785 785 stdin[0] not in LOCAL_IPS or hb[0] not in LOCAL_IPS:
786 786 raise RuntimeError("Can only launch a kernel on a local interface. "
787 787 "Make sure that the '*_address' attributes are "
788 788 "configured properly. "
789 789 "Currently valid addresses are: %s"%LOCAL_IPS
790 790 )
791 791
792 792 self._launch_args = kw.copy()
793 793 launch_kernel = kw.pop('launcher', None)
794 794 if launch_kernel is None:
795 795 if kw.pop('ipython', True):
796 796 from ipkernel import launch_kernel
797 797 else:
798 798 from pykernel import launch_kernel
799 799 self.kernel, xrep, pub, req, _hb = launch_kernel(
800 800 shell_port=shell[1], iopub_port=sub[1],
801 801 stdin_port=stdin[1], hb_port=hb[1], **kw)
802 802 self.shell_address = (shell[0], xrep)
803 803 self.sub_address = (sub[0], pub)
804 804 self.stdin_address = (stdin[0], req)
805 805 self.hb_address = (hb[0], _hb)
806 806
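# A hedged end-to-end sketch: with the default (LOCALHOST, 0) addresses,
# start_kernel must run first so the random ports are filled in before the
# channels connect:
#
#     km = KernelManager()
#     km.start_kernel()      # launch the kernel, fix up the four addresses
#     km.start_channels()    # shell/sub/stdin/hb threads connect and start
#     # ... interact via km.shell_channel etc. ...
#     km.stop_channels()
#     km.shutdown_kernel()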
807 807 def shutdown_kernel(self, restart=False):
808 808 """ Attempts to the stop the kernel process cleanly. If the kernel
809 809 cannot be stopped, it is killed, if possible.
810 810 """
811 811 # FIXME: Shutdown does not work on Windows due to ZMQ errors!
812 812 if sys.platform == 'win32':
813 813 self.kill_kernel()
814 814 return
815 815
816 816 # Pause the heartbeat channel if it exists.
817 817 if self._hb_channel is not None:
818 818 self._hb_channel.pause()
819 819
820 820 # Don't send any additional kernel kill messages immediately, to give
821 821 # the kernel a chance to properly execute shutdown actions. Wait for at
822 822 # most 1s, checking every 0.1s.
823 823 self.shell_channel.shutdown(restart=restart)
824 824 for i in range(10):
825 825 if self.is_alive:
826 826 time.sleep(0.1)
827 827 else:
828 828 break
829 829 else:
830 830 # OK, we've waited long enough.
831 831 if self.has_kernel:
832 832 self.kill_kernel()
833 833
834 834 def restart_kernel(self, now=False, **kw):
835 835 """Restarts a kernel with the arguments that were used to launch it.
836 836
837 837 If the old kernel was launched with random ports, the same ports will be
838 838 used for the new kernel.
839 839
840 840 Parameters
841 841 ----------
842 842 now : bool, optional
843 843 If True, the kernel is forcefully restarted *immediately*, without
844 844 having a chance to do any cleanup action. Otherwise the kernel is
845 845 given 1s to clean up before a forceful restart is issued.
846 846
847 847 In all cases the kernel is restarted, the only difference is whether
848 848 it is given a chance to perform a clean shutdown or not.
849 849
850 850 **kw : optional
851 851 Any options specified here will replace those used to launch the
852 852 kernel.
853 853 """
854 854 if self._launch_args is None:
855 855 raise RuntimeError("Cannot restart the kernel. "
856 856 "No previous call to 'start_kernel'.")
857 857 else:
858 858 # Stop currently running kernel.
859 859 if self.has_kernel:
860 860 if now:
861 861 self.kill_kernel()
862 862 else:
863 863 self.shutdown_kernel(restart=True)
864 864
865 865 # Start new kernel.
866 866 self._launch_args.update(kw)
867 867 self.start_kernel(**self._launch_args)
868 868
869 869 # FIXME: Messages get dropped in Windows due to probable ZMQ bug
870 870 # unless there is some delay here.
871 871 if sys.platform == 'win32':
872 872 time.sleep(0.2)
873 873
874 874 @property
875 875 def has_kernel(self):
876 876 """Returns whether a kernel process has been specified for the kernel
877 877 manager.
878 878 """
879 879 return self.kernel is not None
880 880
881 881 def kill_kernel(self):
882 882 """ Kill the running kernel. """
883 883 if self.has_kernel:
884 884 # Pause the heartbeat channel if it exists.
885 885 if self._hb_channel is not None:
886 886 self._hb_channel.pause()
887 887
888 888 # Attempt to kill the kernel.
889 889 try:
890 890 self.kernel.kill()
891 891 except OSError as e:
892 892 # In Windows, we will get an Access Denied error if the process
893 893 # has already terminated. Ignore it.
894 894 if sys.platform == 'win32':
895 895 if e.winerror != 5:
896 896 raise
897 897 # On Unix, we may get an ESRCH error if the process has already
898 898 # terminated. Ignore it.
899 899 else:
900 900 from errno import ESRCH
901 901 if e.errno != ESRCH:
902 902 raise
903 903 self.kernel = None
904 904 else:
905 905 raise RuntimeError("Cannot kill kernel. No kernel is running!")
906 906
907 907 def interrupt_kernel(self):
908 908 """ Interrupts the kernel. Unlike ``signal_kernel``, this operation is
909 909 well supported on all platforms.
910 910 """
911 911 if self.has_kernel:
912 912 if sys.platform == 'win32':
913 913 from parentpoller import ParentPollerWindows as Poller
914 914 Poller.send_interrupt(self.kernel.win32_interrupt_event)
915 915 else:
916 916 self.kernel.send_signal(signal.SIGINT)
917 917 else:
918 918 raise RuntimeError("Cannot interrupt kernel. No kernel is running!")
919 919
920 920 def signal_kernel(self, signum):
921 921 """ Sends a signal to the kernel. Note that since only SIGTERM is
922 922 supported on Windows, this function is only useful on Unix systems.
923 923 """
924 924 if self.has_kernel:
925 925 self.kernel.send_signal(signum)
926 926 else:
927 927 raise RuntimeError("Cannot signal kernel. No kernel is running!")
928 928
929 929 @property
930 930 def is_alive(self):
931 931 """Is the kernel process still running?"""
932 932 # FIXME: not using a heartbeat means this method is broken for any
933 933 # remote kernel, it's only capable of handling local kernels.
934 934 if self.has_kernel:
935 935 if self.kernel.poll() is None:
936 936 return True
937 937 else:
938 938 return False
939 939 else:
940 940 # We didn't start the kernel with this KernelManager so we don't
941 941 # know if it is running. We should use a heartbeat for this case.
942 942 return True
943 943
944 944 #--------------------------------------------------------------------------
945 945 # Channels used for communication with the kernel:
946 946 #--------------------------------------------------------------------------
947 947
948 948 @property
949 949 def shell_channel(self):
950 950 """Get the REQ socket channel object to make requests of the kernel."""
951 951 if self._shell_channel is None:
952 952 self._shell_channel = self.shell_channel_class(self.context,
953 953 self.session,
954 954 self.shell_address)
955 955 return self._shell_channel
956 956
957 957 @property
958 958 def sub_channel(self):
959 959 """Get the SUB socket channel object."""
960 960 if self._sub_channel is None:
961 961 self._sub_channel = self.sub_channel_class(self.context,
962 962 self.session,
963 963 self.sub_address)
964 964 return self._sub_channel
965 965
966 966 @property
967 967 def stdin_channel(self):
968 968 """Get the REP socket channel object to handle stdin (raw_input)."""
969 969 if self._stdin_channel is None:
970 970 self._stdin_channel = self.stdin_channel_class(self.context,
971 971 self.session,
972 972 self.stdin_address)
973 973 return self._stdin_channel
974 974
975 975 @property
976 976 def hb_channel(self):
977 977 """Get the heartbeat socket channel object to check that the
978 978 kernel is alive."""
979 979 if self._hb_channel is None:
980 980 self._hb_channel = self.hb_channel_class(self.context,
981 981 self.session,
982 982 self.hb_address)
983 983 return self._hb_channel
@@ -1,704 +1,716 b''
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9 9
10 10 Authors:
11 11
12 12 * Min RK
13 13 * Brian Granger
14 14 * Fernando Perez
15 15 """
16 16 #-----------------------------------------------------------------------------
17 17 # Copyright (C) 2010-2011 The IPython Development Team
18 18 #
19 19 # Distributed under the terms of the BSD License. The full license is in
20 20 # the file COPYING, distributed as part of this software.
21 21 #-----------------------------------------------------------------------------
22 22
23 23 #-----------------------------------------------------------------------------
24 24 # Imports
25 25 #-----------------------------------------------------------------------------
26 26
27 27 import hmac
28 28 import logging
29 29 import os
30 30 import pprint
31 31 import uuid
32 32 from datetime import datetime
33 33
34 34 try:
35 35 import cPickle
36 36 pickle = cPickle
37 37 except:
38 38 cPickle = None
39 39 import pickle
40 40
41 41 import zmq
42 42 from zmq.utils import jsonapi
43 43 from zmq.eventloop.ioloop import IOLoop
44 44 from zmq.eventloop.zmqstream import ZMQStream
45 45
46 46 from IPython.config.configurable import Configurable, LoggingConfigurable
47 47 from IPython.utils.importstring import import_item
48 48 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
49 49 from IPython.utils.py3compat import str_to_bytes
50 50 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
51 51 DottedObjectName, CUnicode)
52 52
53 53 #-----------------------------------------------------------------------------
54 54 # utility functions
55 55 #-----------------------------------------------------------------------------
56 56
57 57 def squash_unicode(obj):
58 58 """coerce unicode back to bytestrings."""
59 59 if isinstance(obj,dict):
60 60 for key in obj.keys():
61 61 obj[key] = squash_unicode(obj[key])
62 62 if isinstance(key, unicode):
63 63 obj[squash_unicode(key)] = obj.pop(key)
64 64 elif isinstance(obj, list):
65 65 for i,v in enumerate(obj):
66 66 obj[i] = squash_unicode(v)
67 67 elif isinstance(obj, unicode):
68 68 obj = obj.encode('utf8')
69 69 return obj
70 70
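# Example (Python 2): unicode keys and values are coerced back to utf8
# bytestrings, recursively through dicts and lists:
#
#     squash_unicode({u'a': [u'b', 1]})   # -> {'a': ['b', 1]}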
71 71 #-----------------------------------------------------------------------------
72 72 # globals and defaults
73 73 #-----------------------------------------------------------------------------
74 74 key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
75 75 json_packer = lambda obj: jsonapi.dumps(obj, **{key:date_default})
76 76 json_unpacker = lambda s: extract_dates(jsonapi.loads(s))
77 77
78 78 pickle_packer = lambda o: pickle.dumps(o,-1)
79 79 pickle_unpacker = pickle.loads
80 80
81 81 default_packer = json_packer
82 82 default_unpacker = json_unpacker
83 83
84 84
85 85 DELIM=b"<IDS|MSG>"
86 86
87 87 #-----------------------------------------------------------------------------
88 88 # Classes
89 89 #-----------------------------------------------------------------------------
90 90
91 91 class SessionFactory(LoggingConfigurable):
92 92 """The Base class for configurables that have a Session, Context, logger,
93 93 and IOLoop.
94 94 """
95 95
96 96 logname = Unicode('')
97 97 def _logname_changed(self, name, old, new):
98 98 self.log = logging.getLogger(new)
99 99
100 100 # not configurable:
101 101 context = Instance('zmq.Context')
102 102 def _context_default(self):
103 103 return zmq.Context.instance()
104 104
105 105 session = Instance('IPython.zmq.session.Session')
106 106
107 107 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
108 108 def _loop_default(self):
109 109 return IOLoop.instance()
110 110
111 111 def __init__(self, **kwargs):
112 112 super(SessionFactory, self).__init__(**kwargs)
113 113
114 114 if self.session is None:
115 115 # construct the session
116 116 self.session = Session(**kwargs)
117 117
118 118
119 119 class Message(object):
120 120 """A simple message object that maps dict keys to attributes.
121 121
122 122 A Message can be created from a dict and a dict from a Message instance
123 123 simply by calling dict(msg_obj)."""
124 124
125 125 def __init__(self, msg_dict):
126 126 dct = self.__dict__
127 127 for k, v in dict(msg_dict).iteritems():
128 128 if isinstance(v, dict):
129 129 v = Message(v)
130 130 dct[k] = v
131 131
132 132 # Having this iterator lets dict(msg_obj) work out of the box.
133 133 def __iter__(self):
134 134 return iter(self.__dict__.iteritems())
135 135
136 136 def __repr__(self):
137 137 return repr(self.__dict__)
138 138
139 139 def __str__(self):
140 140 return pprint.pformat(self.__dict__)
141 141
142 142 def __contains__(self, k):
143 143 return k in self.__dict__
144 144
145 145 def __getitem__(self, k):
146 146 return self.__dict__[k]
147 147
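# Round-trip example: nested dicts become nested Messages with attribute
# access, and dict() recovers a plain dict at the top level:
#
#     m = Message({'header': {'msg_id': 'abc'}})
#     m.header.msg_id    # -> 'abc'
#     dict(m)            # -> {'header': Message({'msg_id': 'abc'})}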
148 148
149 149 def msg_header(msg_id, msg_type, username, session):
150 150 date = datetime.now()
151 151 return locals()
152 152
153 153 def extract_header(msg_or_header):
154 154 """Given a message or header, return the header."""
155 155 if not msg_or_header:
156 156 return {}
157 157 try:
158 158 # See if msg_or_header is the entire message.
159 159 h = msg_or_header['header']
160 160 except KeyError:
161 161 try:
162 162 # See if msg_or_header is just the header
163 163 h = msg_or_header['msg_id']
164 164 except KeyError:
165 165 raise
166 166 else:
167 167 h = msg_or_header
168 168 if not isinstance(h, dict):
169 169 h = dict(h)
170 170 return h
171 171
172 172 class Session(Configurable):
173 173 """Object for handling serialization and sending of messages.
174 174
175 175 The Session object handles building messages and sending them
176 176 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
177 177 other over the network via Session objects, and only need to work with the
178 178 dict-based IPython message spec. The Session will handle
179 179 serialization/deserialization, security, and metadata.
180 180
181 181 Sessions support configurable serialization via packer/unpacker traits,
182 182 and signing with HMAC digests via the key/keyfile traits.
183 183
184 184 Parameters
185 185 ----------
186 186
187 187 debug : bool
188 188 whether to trigger extra debugging statements
189 189 packer/unpacker : str : 'json', 'pickle' or import_string
190 190 importstrings for methods to serialize message parts. If just
191 191 'json' or 'pickle', predefined JSON and pickle packers will be used.
192 192 Otherwise, the entire importstring must be used.
193 193
194 194 The functions must accept at least valid JSON input, and output *bytes*.
195 195
196 196 For example, to use msgpack:
197 197 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
198 198 pack/unpack : callables
199 199 You can also set the pack/unpack callables for serialization directly.
200 200 session : unicode (must be ascii)
201 201 the ID of this Session object. The default is to generate a new UUID.
202 202 username : unicode
203 203 username added to message headers. The default is to ask the OS.
204 204 key : bytes
205 205 The key used to initialize an HMAC signature. If unset, messages
206 206 will not be signed or checked.
207 207 keyfile : filepath
208 208 The file containing a key. If this is set, `key` will be initialized
209 209 to the contents of the file.
210 210
211 211 """
212 212
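# An illustrative configuration sketch, assuming the optional msgpack
# package is installed (msgpack is not required by this code):
#
#   s = Session(packer='msgpack.packb', unpacker='msgpack.unpackb')
#   assert s.unpack(s.pack({'a': 5}))['a'] == 5
#
# _check_packers (below) transparently wraps such packers with
# squash_dates/extract_dates when they cannot round-trip datetimes.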
213 213 debug=Bool(False, config=True, help="""Debug output in the Session""")
214 214
215 215 packer = DottedObjectName('json',config=True,
216 216 help="""The name of the packer for serializing messages.
217 217 Should be one of 'json', 'pickle', or an import name
218 218 for a custom callable serializer.""")
219 219 def _packer_changed(self, name, old, new):
220 220 if new.lower() == 'json':
221 221 self.pack = json_packer
222 222 self.unpack = json_unpacker
223 223 elif new.lower() == 'pickle':
224 224 self.pack = pickle_packer
225 225 self.unpack = pickle_unpacker
226 226 else:
227 227 self.pack = import_item(str(new))
228 228
229 229 unpacker = DottedObjectName('json', config=True,
230 230 help="""The name of the unpacker for unserializing messages.
231 231 Only used with custom functions for `packer`.""")
232 232 def _unpacker_changed(self, name, old, new):
233 233 if new.lower() == 'json':
234 234 self.pack = json_packer
235 235 self.unpack = json_unpacker
236 236 elif new.lower() == 'pickle':
237 237 self.pack = pickle_packer
238 238 self.unpack = pickle_unpacker
239 239 else:
240 240 self.unpack = import_item(str(new))
241 241
242 242 session = CUnicode(u'', config=True,
243 243 help="""The UUID identifying this session.""")
244 244 def _session_default(self):
245 return unicode(uuid.uuid4())
245 u = unicode(uuid.uuid4())
246 self.bsession = u.encode('ascii')
247 return u
248
249 def _session_changed(self, name, old, new):
250 self.bsession = self.session.encode('ascii')
251
252 # bsession is the session as bytes
253 bsession = CBytes(b'')
246 254
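# A sketch of the session/bsession coupling added here (exercised by
# test_session_id in the test file below):
#
#   s = Session(session=u'abc')
#   assert s.bsession == b'abc'
#   s.session = u'xyz'                  # _session_changed fires
#   assert s.bsession == b'xyz'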
247 255 username = Unicode(os.environ.get('USER',u'username'), config=True,
248 256 help="""Username for the Session. Default is your system username.""")
249 257
250 258 # message signature related traits:
251 259 key = CBytes(b'', config=True,
252 260 help="""execution key, for extra authentication.""")
253 261 def _key_changed(self, name, old, new):
254 262 if new:
255 263 self.auth = hmac.HMAC(new)
256 264 else:
257 265 self.auth = None
258 266 auth = Instance(hmac.HMAC)
259 267 digest_history = Set()
260 268
261 269 keyfile = Unicode('', config=True,
262 270 help="""path to file containing execution key.""")
263 271 def _keyfile_changed(self, name, old, new):
264 272 with open(new, 'rb') as f:
265 273 self.key = f.read().strip()
266 274
267 275 pack = Any(default_packer) # the actual packer function
268 276 def _pack_changed(self, name, old, new):
269 277 if not callable(new):
270 278 raise TypeError("packer must be callable, not %s"%type(new))
271 279
272 280 unpack = Any(default_unpacker) # the actual unpacker function
273 281 def _unpack_changed(self, name, old, new):
274 282 # unpacker is not checked against packer - it is assumed to be its inverse
275 283 if not callable(new):
276 284 raise TypeError("unpacker must be callable, not %s"%type(new))
277 285
278 286 def __init__(self, **kwargs):
279 287 """create a Session object
280 288
281 289 Parameters
282 290 ----------
283 291
284 292 debug : bool
285 293 whether to trigger extra debugging statements
286 294 packer/unpacker : str : 'json', 'pickle' or import_string
287 295 importstrings for methods to serialize message parts. If just
288 296 'json' or 'pickle', predefined JSON and pickle packers will be used.
289 297 Otherwise, the entire importstring must be used.
290 298
291 299 The functions must accept at least valid JSON input, and output
292 300 *bytes*.
293 301
294 302 For example, to use msgpack:
295 303 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
296 304 pack/unpack : callables
297 305 You can also set the pack/unpack callables for serialization
298 306 directly.
299 session : bytes
307 session : unicode (must be ascii)
300 308 the ID of this Session object. The default is to generate a new
301 309 UUID.
310 bsession : bytes
311 The session as bytes
302 312 username : unicode
303 313 username added to message headers. The default is to ask the OS.
304 314 key : bytes
305 315 The key used to initialize an HMAC signature. If unset, messages
306 316 will not be signed or checked.
307 317 keyfile : filepath
308 318 The file containing a key. If this is set, `key` will be
309 319 initialized to the contents of the file.
310 320 """
311 321 super(Session, self).__init__(**kwargs)
312 322 self._check_packers()
313 323 self.none = self.pack({})
324 # touch self.session to trigger _session_default() if necessary, so bsession is defined:
325 self.session
314 326
315 327 @property
316 328 def msg_id(self):
317 329 """always return new uuid"""
318 330 return str(uuid.uuid4())
319 331
320 332 def _check_packers(self):
321 333 """check packers for binary data and datetime support."""
322 334 pack = self.pack
323 335 unpack = self.unpack
324 336
325 337 # check simple serialization
326 338 msg = dict(a=[1,'hi'])
327 339 try:
328 340 packed = pack(msg)
329 341 except Exception:
330 342 raise ValueError("packer could not serialize a simple message")
331 343
332 344 # ensure packed message is bytes
333 345 if not isinstance(packed, bytes):
334 346 raise ValueError("message packed to %r, but bytes are required"%type(packed))
335 347
336 348 # check that unpack is pack's inverse
337 349 try:
338 350 unpacked = unpack(packed)
339 351 except Exception:
340 352 raise ValueError("unpacker could not handle the packer's output")
341 353
342 354 # check datetime support
343 355 msg = dict(t=datetime.now())
344 356 try:
345 357 unpacked = unpack(pack(msg))
346 358 except Exception:
347 359 self.pack = lambda o: pack(squash_dates(o))
348 360 self.unpack = lambda s: extract_dates(unpack(s))
349 361
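# Net effect of the datetime fallback above, sketched with the default
# json packers (squash_dates/extract_dates round-trip datetime values
# through ISO8601 strings):
#
#   s = Session()
#   t = s.unpack(s.pack({'t': datetime.now()}))['t']
#   assert isinstance(t, datetime)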
350 362 def msg_header(self, msg_type):
351 363 return msg_header(self.msg_id, msg_type, self.username, self.session)
352 364
353 365 def msg(self, msg_type, content=None, parent=None, subheader=None, header=None):
354 366 """Return the nested message dict.
355 367
356 368 This format is different from what is sent over the wire. The
357 369 serialize/unserialize methods convert this nested message dict to the wire
358 370 format, which is a list of message parts.
359 371 """
360 372 msg = {}
361 373 header = self.msg_header(msg_type) if header is None else header
362 374 msg['header'] = header
363 375 msg['msg_id'] = header['msg_id']
364 376 msg['msg_type'] = header['msg_type']
365 377 msg['parent_header'] = {} if parent is None else extract_header(parent)
366 378 msg['content'] = {} if content is None else content
367 379 sub = {} if subheader is None else subheader
368 380 msg['header'].update(sub)
369 381 return msg
370 382
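# A sketch of the nested dict this produces (cf. test_msg in the test
# file below):
#
#   s = Session()
#   m = s.msg('execute', content={'code': 'a=1'})
#   assert sorted(m.keys()) == ['content', 'header', 'msg_id',
#                               'msg_type', 'parent_header']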
371 383 def sign(self, msg_list):
372 384 """Sign a message with HMAC digest. If no auth, return b''.
373 385
374 386 Parameters
375 387 ----------
376 388 msg_list : list
377 389 The [p_header,p_parent,p_content] part of the message list.
378 390 """
379 391 if self.auth is None:
380 392 return b''
381 393 h = self.auth.copy()
382 394 for m in msg_list:
383 395 h.update(m)
384 396 return str_to_bytes(h.hexdigest())
385 397
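# Signing sketch: one HMAC is fed each serialized part in order (the
# b'secret' key is an illustrative value; with no key, sign() returns b''):
#
#   s = Session(key=b'secret')
#   h = hmac.HMAC(b'secret')
#   for part in [b'{}', b'{}', b'{}']:
#       h.update(part)
#   assert s.sign([b'{}', b'{}', b'{}']) == h.hexdigest().encode('ascii')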
386 398 def serialize(self, msg, ident=None):
387 399 """Serialize the message components to bytes.
388 400
389 401 This is roughly the inverse of unserialize. The serialize/unserialize
390 402 methods work with full message lists, whereas pack/unpack work with
391 403 the individual message parts in the message list.
392 404
393 405 Parameters
394 406 ----------
395 407 msg : dict or Message
396 408 The nested message dict as returned by the self.msg method.
397 409
398 410 Returns
399 411 -------
400 412 msg_list : list
401 413 The list of bytes objects to be sent with the format:
402 414 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
403 415 buffer1,buffer2,...]. In this list, the p_* entities are
404 416 the packed or serialized versions, so if JSON is used, these
405 417 are utf8 encoded JSON strings.
406 418 """
407 419 content = msg.get('content', {})
408 420 if content is None:
409 421 content = self.none
410 422 elif isinstance(content, dict):
411 423 content = self.pack(content)
412 424 elif isinstance(content, bytes):
413 425 # content is already packed, as in a relayed message
414 426 pass
415 427 elif isinstance(content, unicode):
416 428 # should be bytes, but JSON often spits out unicode
417 429 content = content.encode('utf8')
418 430 else:
419 431 raise TypeError("Content incorrect type: %s"%type(content))
420 432
421 433 real_message = [self.pack(msg['header']),
422 434 self.pack(msg['parent_header']),
423 435 content
424 436 ]
425 437
426 438 to_send = []
427 439
428 440 if isinstance(ident, list):
429 441 # accept list of idents
430 442 to_send.extend(ident)
431 443 elif ident is not None:
432 444 to_send.append(ident)
433 445 to_send.append(DELIM)
434 446
435 447 signature = self.sign(real_message)
436 448 to_send.append(signature)
437 449
438 450 to_send.extend(real_message)
439 451
440 452 return to_send
441 453
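# Shape of the returned list for a signed, single-ident message (the
# values are illustrative placeholders, not real output):
#
#   [b'engine0', DELIM, b'<hmac-hexdigest>',
#    b'<header-json>', b'<parent-json>', b'<content-json>']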
442 454 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
443 455 buffers=None, subheader=None, track=False, header=None):
444 456 """Build and send a message via stream or socket.
445 457
446 458 The message format used by this function internally is as follows:
447 459
448 460 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
449 461 buffer1,buffer2,...]
450 462
451 463 The serialize/unserialize methods convert the nested message dict into this
452 464 format.
453 465
454 466 Parameters
455 467 ----------
456 468
457 469 stream : zmq.Socket or ZMQStream
458 470 The socket-like object used to send the data.
459 471 msg_or_type : str or Message/dict
460 472 Normally, msg_or_type will be a msg_type unless a message is being
461 473 sent more than once. If a header is supplied, this can be set to
462 474 None and the msg_type will be pulled from the header.
463 475
464 476 content : dict or None
465 477 The content of the message (ignored if msg_or_type is a message).
466 478 header : dict or None
467 479 The header dict for the message (ignored if msg_or_type is a message).
468 480 parent : Message or dict or None
469 481 The parent or parent header describing the parent of this message
470 482 (ignored if msg_or_type is a message).
471 483 ident : bytes or list of bytes
472 484 The zmq.IDENTITY routing path.
473 485 subheader : dict or None
474 486 Extra header keys for this message's header (ignored if msg_or_type
475 487 is a message).
476 488 buffers : list or None
477 489 The already-serialized buffers to be appended to the message.
478 490 track : bool
479 491 Whether to track. Only for use with Sockets, because ZMQStream
480 492 objects cannot track messages.
481 493
482 494 Returns
483 495 -------
484 496 msg : dict
485 497 The constructed message.
486 498 (msg,tracker) : (dict, MessageTracker)
487 499 if track=True, then a 2-tuple will be returned,
488 500 the first element being the constructed
489 501 message, and the second being the MessageTracker
490 502
491 503 """
492 504
493 505 if not isinstance(stream, (zmq.Socket, ZMQStream)):
494 506 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
495 507 elif track and isinstance(stream, ZMQStream):
496 508 raise TypeError("ZMQStream cannot track messages")
497 509
498 510 if isinstance(msg_or_type, (Message, dict)):
499 511 # We got a Message or message dict, not a msg_type so don't
500 512 # build a new Message.
501 513 msg = msg_or_type
502 514 else:
503 515 msg = self.msg(msg_or_type, content=content, parent=parent,
504 516 subheader=subheader, header=header)
505 517
506 518 buffers = [] if buffers is None else buffers
507 519 to_send = self.serialize(msg, ident)
508 520 flag = 0
509 521 if buffers:
510 522 flag = zmq.SNDMORE
511 523 _track = False
512 524 else:
513 525 _track = track
514 526 if track:
515 527 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
516 528 else:
517 529 tracker = stream.send_multipart(to_send, flag, copy=False)
518 530 for b in buffers[:-1]:
519 531 stream.send(b, flag, copy=False)
520 532 if buffers:
521 533 if track:
522 534 tracker = stream.send(buffers[-1], copy=False, track=track)
523 535 else:
524 536 tracker = stream.send(buffers[-1], copy=False)
525 537
526 538 # omsg = Message(msg)
527 539 if self.debug:
528 540 pprint.pprint(msg)
529 541 pprint.pprint(to_send)
530 542 pprint.pprint(buffers)
531 543
532 544 msg['tracker'] = tracker
533 545
534 546 return msg
535 547
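# A minimal end-to-end sketch over inproc PUSH/PULL sockets (the socket
# names and inproc address are illustrative; cf. test_send below):
#
#   ctx = zmq.Context.instance()
#   a = ctx.socket(zmq.PUSH); a.bind('inproc://example')
#   b = ctx.socket(zmq.PULL); b.connect('inproc://example')
#   s = Session()
#   s.send(a, 'execute', content={'code': 'a=1'})
#   idents, msg = s.recv(b, mode=0)     # blocking recv
#   assert msg['content']['code'] == 'a=1'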
536 548 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
537 549 """Send a raw message via ident path.
538 550
539 551 This method is used to send an already-serialized message.
540 552
541 553 Parameters
542 554 ----------
543 555 stream : ZMQStream or Socket
544 556 The ZMQ stream or socket to use for sending the message.
545 557 msg_list : list
546 558 The serialized list of messages to send. This only includes the
547 559 [p_header,p_parent,p_content,buffer1,buffer2,...] portion of
548 560 the message.
549 561 ident : ident or list
550 562 A single ident or a list of idents to use in sending.
551 563 """
552 564 to_send = []
553 565 if isinstance(ident, bytes):
554 566 ident = [ident]
555 567 if ident is not None:
556 568 to_send.extend(ident)
557 569
558 570 to_send.append(DELIM)
559 571 to_send.append(self.sign(msg_list))
560 572 to_send.extend(msg_list)
561 573 stream.send_multipart(to_send, flags, copy=copy)
562 574
563 575 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
564 576 """Receive and unpack a message.
565 577
566 578 Parameters
567 579 ----------
568 580 socket : ZMQStream or Socket
569 581 The socket or stream to use in receiving.
570 582
571 583 Returns
572 584 -------
573 585 [idents], msg
574 586 [idents] is a list of idents and msg is a nested message dict of
575 587 same format as self.msg returns.
576 588 """
577 589 if isinstance(socket, ZMQStream):
578 590 socket = socket.socket
579 591 try:
580 592 msg_list = socket.recv_multipart(mode)
581 593 except zmq.ZMQError as e:
582 594 if e.errno == zmq.EAGAIN:
583 595 # We can convert EAGAIN to None as we know in this case
584 596 # recv_multipart won't return None.
585 597 return None,None
586 598 else:
587 599 raise
588 600 # split multipart message into identity list and message dict
589 601 # invalid large messages can cause very expensive string comparisons
590 602 idents, msg_list = self.feed_identities(msg_list, copy)
591 603 try:
592 604 return idents, self.unserialize(msg_list, content=content, copy=copy)
593 605 except Exception:
594 606 # TODO: handle it
595 607 raise
596 608
597 609 def feed_identities(self, msg_list, copy=True):
598 610 """Split the identities from the rest of the message.
599 611
600 612 Feed until DELIM is reached, then return the prefix as idents and
601 613 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
602 614 but that would be silly.
603 615
604 616 Parameters
605 617 ----------
606 618 msg_list : a list of Message or bytes objects
607 619 The message to be split.
608 620 copy : bool
609 621 flag determining whether the arguments are bytes or Messages
610 622
611 623 Returns
612 624 -------
613 625 (idents, msg_list) : two lists
614 626 idents will always be a list of bytes, each of which is a ZMQ
615 627 identity. msg_list will be a list of bytes or zmq.Messages of the
616 628 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
617 629 should be unpackable/unserializable via self.unserialize at this
618 630 point.
619 631 """
620 632 if copy:
621 633 idx = msg_list.index(DELIM)
622 634 return msg_list[:idx], msg_list[idx+1:]
623 635 else:
624 636 failed = True
625 637 for idx,m in enumerate(msg_list):
626 638 if m.bytes == DELIM:
627 639 failed = False
628 640 break
629 641 if failed:
630 642 raise ValueError("DELIM not in msg_list")
631 643 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
632 644 return [m.bytes for m in idents], msg_list
633 645
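# Splitting sketch for the copy=True (plain bytes) path:
#
#   s = Session()
#   wire = [b'engine0', DELIM, b'sig', b'{}', b'{}', b'{}']
#   idents, rest = s.feed_identities(wire)
#   assert idents == [b'engine0'] and rest[0] == b'sig'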
634 646 def unserialize(self, msg_list, content=True, copy=True):
635 647 """Unserialize a msg_list to a nested message dict.
636 648
637 649 This is roughly the inverse of serialize. The serialize/unserialize
638 650 methods work with full message lists, whereas pack/unpack work with
639 651 the individual message parts in the message list.
640 652
641 653 Parameters
642 654 ----------
643 655 msg_list : list of bytes or Message objects
644 656 The list of message parts of the form [HMAC,p_header,p_parent,
645 657 p_content,buffer1,buffer2,...].
646 658 content : bool (True)
647 659 Whether to unpack the content dict (True), or leave it packed
648 660 (False).
649 661 copy : bool (True)
650 662 Whether to return the bytes (True), or the non-copying Message
651 663 object in each place (False).
652 664
653 665 Returns
654 666 -------
655 667 msg : dict
656 668 The nested message dict with top-level keys [header, parent_header,
657 669 content, buffers].
658 670 """
659 671 minlen = 4
660 672 message = {}
661 673 if not len(msg_list) >= minlen:
662 674 raise TypeError("malformed message, must have at least %i elements"%minlen)
663 675 if not copy:
664 676 for i in range(minlen):
665 677 msg_list[i] = msg_list[i].bytes
666 678 if self.auth is not None:
667 679 signature = msg_list[0]
668 680 if not signature:
669 681 raise ValueError("Unsigned Message")
670 682 if signature in self.digest_history:
671 683 raise ValueError("Duplicate Signature: %r"%signature)
672 684 self.digest_history.add(signature)
673 685 check = self.sign(msg_list[1:4])
674 686 if not signature == check:
675 687 raise ValueError("Invalid Signature: %r"%signature)
676 688 header = self.unpack(msg_list[1])
677 689 message['header'] = header
678 690 message['msg_id'] = header['msg_id']
679 691 message['msg_type'] = header['msg_type']
680 692 message['parent_header'] = self.unpack(msg_list[2])
681 693 if content:
682 694 message['content'] = self.unpack(msg_list[3])
683 695 else:
684 696 message['content'] = msg_list[3]
685 697
686 698 message['buffers'] = msg_list[4:]
687 699 return message
688 700
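# A full serialize -> feed_identities -> unserialize round trip (this
# mirrors test_serialize in the test file below):
#
#   s = Session()
#   msg = s.msg('execute', content=dict(a=10))
#   idents, parts = s.feed_identities(s.serialize(msg, ident=b'foo'))
#   new = s.unserialize(parts)
#   assert idents == [b'foo'] and new['content'] == msg['content']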
689 701 def test_msg2obj():
690 702 am = dict(x=1)
691 703 ao = Message(am)
692 704 assert ao.x == am['x']
693 705
694 706 am['y'] = dict(z=1)
695 707 ao = Message(am)
696 708 assert ao.y.z == am['y']['z']
697 709
698 710 k1, k2 = 'y', 'z'
699 711 assert ao[k1][k2] == am[k1][k2]
700 712
701 713 am2 = dict(ao)
702 714 assert am['x'] == am2['x']
703 715 assert am['y']['z'] == am2['y']['z']
704 716
@@ -1,188 +1,210 b''
1 1 """test building messages with streamsession"""
2 2
3 3 #-------------------------------------------------------------------------------
4 4 # Copyright (C) 2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-------------------------------------------------------------------------------
9 9
10 10 #-------------------------------------------------------------------------------
11 11 # Imports
12 12 #-------------------------------------------------------------------------------
13 13
14 14 import os
15 15 import uuid
16 16 import zmq
17 17
18 18 from zmq.tests import BaseZMQTestCase
19 19 from zmq.eventloop.zmqstream import ZMQStream
20 20
21 21 from IPython.zmq import session as ss
22 22
23 23 class SessionTestCase(BaseZMQTestCase):
24 24
25 25 def setUp(self):
26 26 BaseZMQTestCase.setUp(self)
27 27 self.session = ss.Session()
28 28
29 29
30 30 class MockSocket(zmq.Socket):
31 31
32 32 def __init__(self, *args, **kwargs):
33 33 super(MockSocket,self).__init__(*args,**kwargs)
34 34 self.data = []
35 35
36 36 def send_multipart(self, msgparts, *args, **kwargs):
37 37 self.data.extend(msgparts)
38 38
39 39 def send(self, part, *args, **kwargs):
40 40 self.data.append(part)
41 41
42 42 def recv_multipart(self, *args, **kwargs):
43 43 return self.data
44 44
45 45 class TestSession(SessionTestCase):
46 46
47 47 def test_msg(self):
48 48 """message format"""
49 49 msg = self.session.msg('execute')
50 50 thekeys = set('header parent_header content msg_type msg_id'.split())
51 51 s = set(msg.keys())
52 52 self.assertEquals(s, thekeys)
53 53 self.assertTrue(isinstance(msg['content'],dict))
54 54 self.assertTrue(isinstance(msg['header'],dict))
55 55 self.assertTrue(isinstance(msg['parent_header'],dict))
56 56 self.assertTrue(isinstance(msg['msg_id'],str))
57 57 self.assertTrue(isinstance(msg['msg_type'],str))
58 58 self.assertEquals(msg['header']['msg_type'], 'execute')
59 59 self.assertEquals(msg['msg_type'], 'execute')
60 60
61 61 def test_serialize(self):
62 62 msg = self.session.msg('execute',content=dict(a=10))
63 63 msg_list = self.session.serialize(msg, ident=b'foo')
64 64 ident, msg_list = self.session.feed_identities(msg_list)
65 65 new_msg = self.session.unserialize(msg_list)
66 66 self.assertEquals(ident[0], b'foo')
67 67 self.assertEquals(new_msg['msg_id'],msg['msg_id'])
68 68 self.assertEquals(new_msg['msg_type'],msg['msg_type'])
69 69 self.assertEquals(new_msg['header'],msg['header'])
70 70 self.assertEquals(new_msg['content'],msg['content'])
71 71 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
72 72
73 73 def test_send(self):
74 74 socket = MockSocket(zmq.Context.instance(),zmq.PAIR)
75 75
76 76 msg = self.session.msg('execute', content=dict(a=10))
77 77 self.session.send(socket, msg, ident=b'foo', buffers=[b'bar'])
78 78 ident, msg_list = self.session.feed_identities(socket.data)
79 79 new_msg = self.session.unserialize(msg_list)
80 80 self.assertEquals(ident[0], b'foo')
81 81 self.assertEquals(new_msg['msg_id'],msg['msg_id'])
82 82 self.assertEquals(new_msg['msg_type'],msg['msg_type'])
83 83 self.assertEquals(new_msg['header'],msg['header'])
84 84 self.assertEquals(new_msg['content'],msg['content'])
85 85 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
86 86 self.assertEquals(new_msg['buffers'],[b'bar'])
87 87
88 88 socket.data = []
89 89
90 90 content = msg['content']
91 91 header = msg['header']
92 92 parent = msg['parent_header']
93 93 msg_type = header['msg_type']
94 94 self.session.send(socket, None, content=content, parent=parent,
95 95 header=header, ident=b'foo', buffers=[b'bar'])
96 96 ident, msg_list = self.session.feed_identities(socket.data)
97 97 new_msg = self.session.unserialize(msg_list)
98 98 self.assertEquals(ident[0], b'foo')
99 99 self.assertEquals(new_msg['msg_id'],msg['msg_id'])
100 100 self.assertEquals(new_msg['msg_type'],msg['msg_type'])
101 101 self.assertEquals(new_msg['header'],msg['header'])
102 102 self.assertEquals(new_msg['content'],msg['content'])
103 103 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
104 104 self.assertEquals(new_msg['buffers'],[b'bar'])
105 105
106 106 socket.data = []
107 107
108 108 self.session.send(socket, msg, ident=b'foo', buffers=[b'bar'])
109 109 ident, new_msg = self.session.recv(socket)
110 110 self.assertEquals(ident[0], b'foo')
111 111 self.assertEquals(new_msg['msg_id'],msg['msg_id'])
112 112 self.assertEquals(new_msg['msg_type'],msg['msg_type'])
113 113 self.assertEquals(new_msg['header'],msg['header'])
114 114 self.assertEquals(new_msg['content'],msg['content'])
115 115 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
116 116 self.assertEquals(new_msg['buffers'],[b'bar'])
117 117
118 118 socket.close()
119 119
120 120 def test_args(self):
121 121 """initialization arguments for Session"""
122 122 s = self.session
123 123 self.assertTrue(s.pack is ss.default_packer)
124 124 self.assertTrue(s.unpack is ss.default_unpacker)
125 125 self.assertEquals(s.username, os.environ.get('USER', u'username'))
126 126
127 127 s = ss.Session()
128 128 self.assertEquals(s.username, os.environ.get('USER', u'username'))
129 129
130 130 self.assertRaises(TypeError, ss.Session, pack='hi')
131 131 self.assertRaises(TypeError, ss.Session, unpack='hi')
132 132 u = str(uuid.uuid4())
133 133 s = ss.Session(username=u'carrot', session=u)
134 134 self.assertEquals(s.session, u)
135 135 self.assertEquals(s.username, u'carrot')
136 136
137 137 def test_tracking(self):
138 138 """test tracking messages"""
139 139 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
140 140 s = self.session
141 141 stream = ZMQStream(a)
142 142 msg = s.send(a, 'hello', track=False)
143 143 self.assertTrue(msg['tracker'] is None)
144 144 msg = s.send(a, 'hello', track=True)
145 145 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
146 146 M = zmq.Message(b'hi there', track=True)
147 147 msg = s.send(a, 'hello', buffers=[M], track=True)
148 148 t = msg['tracker']
149 149 self.assertTrue(isinstance(t, zmq.MessageTracker))
150 150 self.assertRaises(zmq.NotDone, t.wait, .1)
151 151 del M
152 152 t.wait(1) # now that M is gone, this should complete (raises zmq.NotDone on timeout)
153 153
154 154
155 155 # def test_rekey(self):
156 156 # """rekeying dict around json str keys"""
157 157 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
158 158 # self.assertRaises(KeyError, ss.rekey, d)
159 159 #
160 160 # d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
161 161 # d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
162 162 # rd = ss.rekey(d)
163 163 # self.assertEquals(d2,rd)
164 164 #
165 165 # d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
166 166 # d2 = {1.5:d['1.5'],1:d['1']}
167 167 # rd = ss.rekey(d)
168 168 # self.assertEquals(d2,rd)
169 169 #
170 170 # d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
171 171 # self.assertRaises(KeyError, ss.rekey, d)
172 172 #
173 173 def test_unique_msg_ids(self):
174 174 """test that messages receive unique ids"""
175 175 ids = set()
176 176 for i in range(2**12):
177 177 h = self.session.msg_header('test')
178 178 msg_id = h['msg_id']
179 179 self.assertTrue(msg_id not in ids)
180 180 ids.add(msg_id)
181 181
182 182 def test_feed_identities(self):
183 183 """scrub the front for zmq IDENTITIES"""
184 184 theids = [b"engine", b"client", b"other"]
185 185 themsg = self.session.msg('execute', content=dict(code='whoda'))
186 186 pmsg = self.session.serialize(themsg, ident=theids)
187 187 self.assertEquals(self.session.feed_identities(pmsg)[0], theids)
188 188
189 def test_session_id(self):
190 session = ss.Session()
191 # get bs before us
192 bs = session.bsession
193 us = session.session
194 self.assertEquals(us.encode('ascii'), bs)
195 session = ss.Session()
196 # get us before bs
197 us = session.session
198 bs = session.bsession
199 self.assertEquals(us.encode('ascii'), bs)
200 # change propagates:
201 session.session = 'something else'
202 bs = session.bsession
203 us = session.session
204 self.assertEquals(us.encode('ascii'), bs)
205 session = ss.Session(session='stuff')
206 # get us before bs
207 self.assertEquals(session.bsession, session.session.encode('ascii'))
208 self.assertEquals(b'stuff', session.bsession)
209
210