More Python 3 compatibility fixes.
Thomas Kluyver
@@ -1,201 +1,203 b''
1 1 """Dependency utilities
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 from types import ModuleType
15 15
16 16 from IPython.parallel.client.asyncresult import AsyncResult
17 17 from IPython.parallel.error import UnmetDependency
18 18 from IPython.parallel.util import interactive
19 from IPython.utils import py3compat
19 20
20 21 class depend(object):
21 22 """Dependency decorator, for use with tasks.
22 23
23 24 `@depend` lets you define a function for engine dependencies
24 25 just like you use `apply` for tasks.
25 26
26 27
27 28 Examples
28 29 --------
29 30 ::
30 31
31 32 @depend(df, a,b, c=5)
32 33 def f(m,n,p):
33 34
34 35 view.apply(f, 1,2,3)
35 36
36 37 will call df(a,b,c=5) on the engine, and if it returns False or
37 38 raises an UnmetDependency error, then the task will not be run
38 39 and another engine will be tried.
39 40 """
40 41 def __init__(self, f, *args, **kwargs):
41 42 self.f = f
42 43 self.args = args
43 44 self.kwargs = kwargs
44 45
45 46 def __call__(self, f):
46 47 return dependent(f, self.f, *self.args, **self.kwargs)
47 48
48 49 class dependent(object):
49 50 """A function that depends on another function.
50 51 This is an object to avoid the closure used
51 52 in traditional decorators, which is not picklable.
52 53 """
53 54
54 55 def __init__(self, f, df, *dargs, **dkwargs):
55 56 self.f = f
56 57 self.func_name = getattr(f, '__name__', 'f')
57 58 self.df = df
58 59 self.dargs = dargs
59 60 self.dkwargs = dkwargs
60 61
61 62 def __call__(self, *args, **kwargs):
62 63 # if hasattr(self.f, 'func_globals') and hasattr(self.df, 'func_globals'):
63 64 # self.df.func_globals = self.f.func_globals
64 65 if self.df(*self.dargs, **self.dkwargs) is False:
65 66 raise UnmetDependency()
66 67 return self.f(*args, **kwargs)
67 68
68 @property
69 def __name__(self):
70 return self.func_name
69 if not py3compat.PY3:
70 @property
71 def __name__(self):
72 return self.func_name
71 73
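# Illustrative sketch (not from the diff): how `depend`/`dependent` fit together.
# `has_numpy`, `task`, and `view` are made-up names for the example.
def has_numpy():
    try:
        import numpy
    except ImportError:
        return False
    return True

@depend(has_numpy)           # wraps `task` in a picklable `dependent` object
def task(x):
    import numpy
    return numpy.sqrt(x)

# view.apply(task, 2) would run has_numpy() on the engine first; if it returns
# False (or raises UnmetDependency), the scheduler retries on another engine.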
72 74 @interactive
73 75 def _require(*names):
74 76 """Helper for @require decorator."""
75 77 from IPython.parallel.error import UnmetDependency
76 78 user_ns = globals()
77 79 for name in names:
78 80 if name in user_ns:
79 81 continue
80 82 try:
81 83 exec 'import %s'%name in user_ns
82 84 except ImportError:
83 85 raise UnmetDependency(name)
84 86 return True
85 87
86 88 def require(*mods):
87 89 """Simple decorator for requiring names to be importable.
88 90
89 91 Examples
90 92 --------
91 93
92 94 In [1]: @require('numpy')
93 95 ...: def norm(a):
94 96 ...: import numpy
95 97 ...: return numpy.linalg.norm(a,2)
96 98 """
97 99 names = []
98 100 for mod in mods:
99 101 if isinstance(mod, ModuleType):
100 102 mod = mod.__name__
101 103
102 104 if isinstance(mod, basestring):
103 105 names.append(mod)
104 106 else:
105 107 raise TypeError("names must be modules or module names, not %s"%type(mod))
106 108
107 109 return depend(_require, *names)
108 110
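# Illustrative sketch: both forms accepted by require() above -- a module object
# (its __name__ is extracted) or a dotted module name given as a string.
import numpy

@require(numpy, 'scipy.linalg')
def solve(A, b):
    from scipy import linalg
    return linalg.solve(A, b)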
109 111 class Dependency(set):
110 112 """An object for representing a set of msg_id dependencies.
111 113
112 114 Subclassed from set().
113 115
114 116 Parameters
115 117 ----------
116 118 dependencies: list/set of msg_ids or AsyncResult objects or output of Dependency.as_dict()
117 119 The msg_ids to depend on
118 120 all : bool [default True]
119 121 Whether the dependency should be considered met when *all* depended-upon tasks have completed
120 122 or only when *any* have been completed.
121 123 success : bool [default True]
122 124 Whether to consider successes as fulfilling dependencies.
123 125 failure : bool [default False]
124 126 Whether to consider failures as fulfilling dependencies.
125 127
126 128 If `all=success=True` and `failure=False`, then the task will fail with an ImpossibleDependency
127 129 as soon as the first depended-upon task fails.
128 130 """
129 131
130 132 all=True
131 133 success=True
132 134 failure=True
133 135
134 136 def __init__(self, dependencies=[], all=True, success=True, failure=False):
135 137 if isinstance(dependencies, dict):
136 138 # load from dict
137 139 all = dependencies.get('all', True)
138 140 success = dependencies.get('success', success)
139 141 failure = dependencies.get('failure', failure)
140 142 dependencies = dependencies.get('dependencies', [])
141 143 ids = []
142 144
143 145 # extract ids from various sources:
144 146 if isinstance(dependencies, (basestring, AsyncResult)):
145 147 dependencies = [dependencies]
146 148 for d in dependencies:
147 149 if isinstance(d, basestring):
148 150 ids.append(d)
149 151 elif isinstance(d, AsyncResult):
150 152 ids.extend(d.msg_ids)
151 153 else:
152 154 raise TypeError("invalid dependency type: %r"%type(d))
153 155
154 156 set.__init__(self, ids)
155 157 self.all = all
156 158 if not (success or failure):
157 159 raise ValueError("Must depend on at least one of successes or failures!")
158 160 self.success=success
159 161 self.failure = failure
160 162
161 163 def check(self, completed, failed=None):
162 164 """check whether our dependencies have been met."""
163 165 if len(self) == 0:
164 166 return True
165 167 against = set()
166 168 if self.success:
167 169 against = completed
168 170 if failed is not None and self.failure:
169 171 against = against.union(failed)
170 172 if self.all:
171 173 return self.issubset(against)
172 174 else:
173 175 return not self.isdisjoint(against)
174 176
175 177 def unreachable(self, completed, failed=None):
176 178 """return whether this dependency has become impossible."""
177 179 if len(self) == 0:
178 180 return False
179 181 against = set()
180 182 if not self.success:
181 183 against = completed
182 184 if failed is not None and not self.failure:
183 185 against = against.union(failed)
184 186 if self.all:
185 187 return not self.isdisjoint(against)
186 188 else:
187 189 return self.issubset(against)
188 190
189 191
190 192 def as_dict(self):
191 193 """Represent this dependency as a dict. For json compatibility."""
192 194 return dict(
193 195 dependencies=list(self),
194 196 all=self.all,
195 197 success=self.success,
196 198 failure=self.failure
197 199 )
198 200
199 201
200 202 __all__ = ['depend', 'require', 'dependent', 'Dependency']
201 203
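A quick sketch of the Dependency set semantics defined above, with made-up msg_ids (illustrative only):

    dep = Dependency(['a', 'b'], all=True, success=True, failure=False)
    dep.check(completed={'a'})                       # False: not every dependency is met yet
    dep.check(completed={'a', 'b'})                  # True
    dep.unreachable(completed=set(), failed={'a'})   # True: 'a' failed and failures don't count
    dep.as_dict()   # {'dependencies': [...], 'all': True, 'success': True, 'failure': False}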
@@ -1,983 +1,983 b''
1 1 """Base classes to manage the interaction with a running kernel.
2 2
3 3 TODO
4 4 * Create logger to handle debugging and console messages.
5 5 """
6 6
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2010 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 # Standard library imports.
19 19 import atexit
20 20 import errno
21 21 from Queue import Queue, Empty
22 22 from subprocess import Popen
23 23 import signal
24 24 import sys
25 25 from threading import Thread
26 26 import time
27 27 import logging
28 28
29 29 # System library imports.
30 30 import zmq
31 31 from zmq import POLLIN, POLLOUT, POLLERR
32 32 from zmq.eventloop import ioloop
33 33
34 34 # Local imports.
35 35 from IPython.config.loader import Config
36 36 from IPython.utils import io
37 37 from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
38 38 from IPython.utils.traitlets import HasTraits, Any, Instance, Type, TCPAddress
39 39 from session import Session, Message
40 40
41 41 #-----------------------------------------------------------------------------
42 42 # Constants and exceptions
43 43 #-----------------------------------------------------------------------------
44 44
45 45 class InvalidPortNumber(Exception):
46 46 pass
47 47
48 48 #-----------------------------------------------------------------------------
49 49 # Utility functions
50 50 #-----------------------------------------------------------------------------
51 51
52 52 # some utilities to validate message structure, these might get moved elsewhere
53 53 # if they prove to have more generic utility
54 54
55 55 def validate_string_list(lst):
56 56 """Validate that the input is a list of strings.
57 57
58 58 Raises ValueError if not."""
59 59 if not isinstance(lst, list):
60 60 raise ValueError('input %r must be a list' % lst)
61 61 for x in lst:
62 62 if not isinstance(x, basestring):
63 63 raise ValueError('element %r in list must be a string' % x)
64 64
65 65
66 66 def validate_string_dict(dct):
67 67 """Validate that the input is a dict with string keys and values.
68 68
69 69 Raises ValueError if not."""
70 70 for k,v in dct.iteritems():
71 71 if not isinstance(k, basestring):
72 72 raise ValueError('key %r in dict must be a string' % k)
73 73 if not isinstance(v, basestring):
74 74 raise ValueError('value %r in dict must be a string' % v)
75 75
76 76
77 77 #-----------------------------------------------------------------------------
78 78 # ZMQ Socket Channel classes
79 79 #-----------------------------------------------------------------------------
80 80
81 81 class ZMQSocketChannel(Thread):
82 82 """The base class for the channels that use ZMQ sockets.
83 83 """
84 84 context = None
85 85 session = None
86 86 socket = None
87 87 ioloop = None
88 88 iostate = None
89 89 _address = None
90 90
91 91 def __init__(self, context, session, address):
92 92 """Create a channel
93 93
94 94 Parameters
95 95 ----------
96 96 context : :class:`zmq.Context`
97 97 The ZMQ context to use.
98 98 session : :class:`session.Session`
99 99 The session to use.
100 100 address : tuple
101 101 Standard (ip, port) tuple that the kernel is listening on.
102 102 """
103 103 super(ZMQSocketChannel, self).__init__()
104 104 self.daemon = True
105 105
106 106 self.context = context
107 107 self.session = session
108 108 if address[1] == 0:
109 109 message = 'The port number for a channel cannot be 0.'
110 110 raise InvalidPortNumber(message)
111 111 self._address = address
112 112
113 113 def _run_loop(self):
114 114 """Run my loop, ignoring EINTR events in the poller"""
115 115 while True:
116 116 try:
117 117 self.ioloop.start()
118 118 except zmq.ZMQError as e:
119 119 if e.errno == errno.EINTR:
120 120 continue
121 121 else:
122 122 raise
123 123 else:
124 124 break
125 125
126 126 def stop(self):
127 127 """Stop the channel's activity.
128 128
129 129 This calls :method:`Thread.join` and returns when the thread
130 130 terminates. :class:`RuntimeError` will be raised if
131 131 :method:`self.start` is called again.
132 132 """
133 133 self.join()
134 134
135 135 @property
136 136 def address(self):
137 137 """Get the channel's address as an (ip, port) tuple.
138 138
139 139 By default, the address is (localhost, 0), where 0 means a random
140 140 port.
141 141 """
142 142 return self._address
143 143
144 144 def add_io_state(self, state):
145 145 """Add IO state to the eventloop.
146 146
147 147 Parameters
148 148 ----------
149 149 state : zmq.POLLIN|zmq.POLLOUT|zmq.POLLERR
150 150 The IO state flag to set.
151 151
152 152 This is thread safe as it uses the thread safe IOLoop.add_callback.
153 153 """
154 154 def add_io_state_callback():
155 155 if not self.iostate & state:
156 156 self.iostate = self.iostate | state
157 157 self.ioloop.update_handler(self.socket, self.iostate)
158 158 self.ioloop.add_callback(add_io_state_callback)
159 159
160 160 def drop_io_state(self, state):
161 161 """Drop IO state from the eventloop.
162 162
163 163 Parameters
164 164 ----------
165 165 state : zmq.POLLIN|zmq.POLLOUT|zmq.POLLERR
166 166 The IO state flag to set.
167 167
168 168 This is thread safe as it uses the thread safe IOLoop.add_callback.
169 169 """
170 170 def drop_io_state_callback():
171 171 if self.iostate & state:
172 172 self.iostate = self.iostate & (~state)
173 173 self.ioloop.update_handler(self.socket, self.iostate)
174 174 self.ioloop.add_callback(drop_io_state_callback)
175 175
176 176
177 177 class ShellSocketChannel(ZMQSocketChannel):
178 178 """The XREQ channel for issuing requests and receiving replies from the kernel.
179 179 """
180 180
181 181 command_queue = None
182 182
183 183 def __init__(self, context, session, address):
184 184 super(ShellSocketChannel, self).__init__(context, session, address)
185 185 self.command_queue = Queue()
186 186 self.ioloop = ioloop.IOLoop()
187 187
188 188 def run(self):
189 189 """The thread's main activity. Call start() instead."""
190 190 self.socket = self.context.socket(zmq.DEALER)
191 self.socket.setsockopt(zmq.IDENTITY, self.session.session)
191 self.socket.setsockopt(zmq.IDENTITY, self.session.session.encode("ascii"))
192 192 self.socket.connect('tcp://%s:%i' % self.address)
193 193 self.iostate = POLLERR|POLLIN
194 194 self.ioloop.add_handler(self.socket, self._handle_events,
195 195 self.iostate)
196 196 self._run_loop()
197 197
198 198 def stop(self):
199 199 self.ioloop.stop()
200 200 super(ShellSocketChannel, self).stop()
201 201
202 202 def call_handlers(self, msg):
203 203 """This method is called in the ioloop thread when a message arrives.
204 204
205 205 Subclasses should override this method to handle incoming messages.
206 206 It is important to remember that this method is called in the thread
207 207 so some logic must be done to ensure that the application level
208 208 handlers are called in the application thread.
209 209 """
210 210 raise NotImplementedError('call_handlers must be defined in a subclass.')
211 211
212 212 def execute(self, code, silent=False,
213 213 user_variables=None, user_expressions=None):
214 214 """Execute code in the kernel.
215 215
216 216 Parameters
217 217 ----------
218 218 code : str
219 219 A string of Python code.
220 220
221 221 silent : bool, optional (default False)
222 222 If set, the kernel will execute the code as quietly as possible.
223 223
224 224 user_variables : list, optional
225 225 A list of variable names to pull from the user's namespace. They
226 226 will come back as a dict with these names as keys and their
227 227 :func:`repr` as values.
228 228
229 229 user_expressions : dict, optional
230 230 A dict with string keys and to pull from the user's
231 231 namespace. They will come back as a dict with these names as keys
232 232 and their :func:`repr` as values.
233 233
234 234 Returns
235 235 -------
236 236 The msg_id of the message sent.
237 237 """
238 238 if user_variables is None:
239 239 user_variables = []
240 240 if user_expressions is None:
241 241 user_expressions = {}
242 242
243 243 # Don't waste network traffic if inputs are invalid
244 244 if not isinstance(code, basestring):
245 245 raise ValueError('code %r must be a string' % code)
246 246 validate_string_list(user_variables)
247 247 validate_string_dict(user_expressions)
248 248
249 249 # Create class for content/msg creation. Related to, but possibly
250 250 # not in Session.
251 251 content = dict(code=code, silent=silent,
252 252 user_variables=user_variables,
253 253 user_expressions=user_expressions)
254 254 msg = self.session.msg('execute_request', content)
255 255 self._queue_request(msg)
256 256 return msg['header']['msg_id']
257 257
258 258 def complete(self, text, line, cursor_pos, block=None):
259 259 """Tab complete text in the kernel's namespace.
260 260
261 261 Parameters
262 262 ----------
263 263 text : str
264 264 The text to complete.
265 265 line : str
266 266 The full line of text that is the surrounding context for the
267 267 text to complete.
268 268 cursor_pos : int
269 269 The position of the cursor in the line where the completion was
270 270 requested.
271 271 block : str, optional
272 272 The full block of code in which the completion is being requested.
273 273
274 274 Returns
275 275 -------
276 276 The msg_id of the message sent.
277 277 """
278 278 content = dict(text=text, line=line, block=block, cursor_pos=cursor_pos)
279 279 msg = self.session.msg('complete_request', content)
280 280 self._queue_request(msg)
281 281 return msg['header']['msg_id']
282 282
283 283 def object_info(self, oname):
284 284 """Get metadata information about an object.
285 285
286 286 Parameters
287 287 ----------
288 288 oname : str
289 289 A string specifying the object name.
290 290
291 291 Returns
292 292 -------
293 293 The msg_id of the message sent.
294 294 """
295 295 content = dict(oname=oname)
296 296 msg = self.session.msg('object_info_request', content)
297 297 self._queue_request(msg)
298 298 return msg['header']['msg_id']
299 299
300 300 def history(self, raw=True, output=False, hist_access_type='range', **kwargs):
301 301 """Get entries from the history list.
302 302
303 303 Parameters
304 304 ----------
305 305 raw : bool
306 306 If True, return the raw input.
307 307 output : bool
308 308 If True, then return the output as well.
309 309 hist_access_type : str
310 310 'range' (fill in session, start and stop params), 'tail' (fill in n)
311 311 or 'search' (fill in pattern param).
312 312
313 313 session : int
314 314 For a range request, the session from which to get lines. Session
315 315 numbers are positive integers; negative ones count back from the
316 316 current session.
317 317 start : int
318 318 The first line number of a history range.
319 319 stop : int
320 320 The final (excluded) line number of a history range.
321 321
322 322 n : int
323 323 The number of lines of history to get for a tail request.
324 324
325 325 pattern : str
326 326 The glob-syntax pattern for a search request.
327 327
328 328 Returns
329 329 -------
330 330 The msg_id of the message sent.
331 331 """
332 332 content = dict(raw=raw, output=output, hist_access_type=hist_access_type,
333 333 **kwargs)
334 334 msg = self.session.msg('history_request', content)
335 335 self._queue_request(msg)
336 336 return msg['header']['msg_id']
337 337
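# Usage sketch (assumes `shell` is a connected ShellSocketChannel; values are illustrative):
shell.history(hist_access_type='tail', n=10)                            # last 10 input lines
shell.history(hist_access_type='range', session=-1, start=1, stop=50)   # a line range from a relative session
shell.history(hist_access_type='search', pattern='import *')            # glob-style search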
338 338 def shutdown(self, restart=False):
339 339 """Request an immediate kernel shutdown.
340 340
341 341 Upon receipt of the (empty) reply, client code can safely assume that
342 342 the kernel has shut down and it's safe to forcefully terminate it if
343 343 it's still alive.
344 344
345 345 The kernel will send the reply via a function registered with Python's
346 346 atexit module, ensuring it's truly done as the kernel is done with all
347 347 normal operation.
348 348 """
349 349 # Send quit message to kernel. Once we implement kernel-side setattr,
350 350 # this should probably be done that way, but for now this will do.
351 351 msg = self.session.msg('shutdown_request', {'restart':restart})
352 352 self._queue_request(msg)
353 353 return msg['header']['msg_id']
354 354
355 355 def _handle_events(self, socket, events):
356 356 if events & POLLERR:
357 357 self._handle_err()
358 358 if events & POLLOUT:
359 359 self._handle_send()
360 360 if events & POLLIN:
361 361 self._handle_recv()
362 362
363 363 def _handle_recv(self):
364 364 ident,msg = self.session.recv(self.socket, 0)
365 365 self.call_handlers(msg)
366 366
367 367 def _handle_send(self):
368 368 try:
369 369 msg = self.command_queue.get(False)
370 370 except Empty:
371 371 pass
372 372 else:
373 373 self.session.send(self.socket,msg)
374 374 if self.command_queue.empty():
375 375 self.drop_io_state(POLLOUT)
376 376
377 377 def _handle_err(self):
378 378 # We don't want to let this go silently, so eventually we should log.
379 379 raise zmq.ZMQError()
380 380
381 381 def _queue_request(self, msg):
382 382 self.command_queue.put(msg)
383 383 self.add_io_state(POLLOUT)
384 384
385 385
386 386 class SubSocketChannel(ZMQSocketChannel):
387 387 """The SUB channel which listens for messages that the kernel publishes.
388 388 """
389 389
390 390 def __init__(self, context, session, address):
391 391 super(SubSocketChannel, self).__init__(context, session, address)
392 392 self.ioloop = ioloop.IOLoop()
393 393
394 394 def run(self):
395 395 """The thread's main activity. Call start() instead."""
396 396 self.socket = self.context.socket(zmq.SUB)
397 self.socket.setsockopt(zmq.SUBSCRIBE,'')
398 self.socket.setsockopt(zmq.IDENTITY, self.session.session)
397 self.socket.setsockopt(zmq.SUBSCRIBE,b'')
398 self.socket.setsockopt(zmq.IDENTITY, self.session.session.encode("ascii"))
399 399 self.socket.connect('tcp://%s:%i' % self.address)
400 400 self.iostate = POLLIN|POLLERR
401 401 self.ioloop.add_handler(self.socket, self._handle_events,
402 402 self.iostate)
403 403 self._run_loop()
404 404
405 405 def stop(self):
406 406 self.ioloop.stop()
407 407 super(SubSocketChannel, self).stop()
408 408
409 409 def call_handlers(self, msg):
410 410 """This method is called in the ioloop thread when a message arrives.
411 411
412 412 Subclasses should override this method to handle incoming messages.
413 413 It is important to remember that this method is called in the thread
414 414 so some logic must be done to ensure that the application level
415 415 handlers are called in the application thread.
416 416 """
417 417 raise NotImplementedError('call_handlers must be defined in a subclass.')
418 418
419 419 def flush(self, timeout=1.0):
420 420 """Immediately processes all pending messages on the SUB channel.
421 421
422 422 Callers should use this method to ensure that :method:`call_handlers`
423 423 has been called for all messages that have been received on the
424 424 0MQ SUB socket of this channel.
425 425
426 426 This method is thread safe.
427 427
428 428 Parameters
429 429 ----------
430 430 timeout : float, optional
431 431 The maximum amount of time to spend flushing, in seconds. The
432 432 default is one second.
433 433 """
434 434 # We do the IOLoop callback process twice to ensure that the IOLoop
435 435 # gets to perform at least one full poll.
436 436 stop_time = time.time() + timeout
437 437 for i in xrange(2):
438 438 self._flushed = False
439 439 self.ioloop.add_callback(self._flush)
440 440 while not self._flushed and time.time() < stop_time:
441 441 time.sleep(0.01)
442 442
443 443 def _handle_events(self, socket, events):
444 444 # Turn on and off POLLOUT depending on if we have made a request
445 445 if events & POLLERR:
446 446 self._handle_err()
447 447 if events & POLLIN:
448 448 self._handle_recv()
449 449
450 450 def _handle_err(self):
451 451 # We don't want to let this go silently, so eventually we should log.
452 452 raise zmq.ZMQError()
453 453
454 454 def _handle_recv(self):
455 455 # Get all of the messages we can
456 456 while True:
457 457 try:
458 458 ident,msg = self.session.recv(self.socket)
459 459 except zmq.ZMQError:
460 460 # Check the errno?
461 461 # Will this trigger POLLERR?
462 462 break
463 463 else:
464 464 if msg is None:
465 465 break
466 466 self.call_handlers(msg)
467 467
468 468 def _flush(self):
469 469 """Callback for :method:`self.flush`."""
470 470 self._flushed = True
471 471
472 472
473 473 class StdInSocketChannel(ZMQSocketChannel):
474 474 """A reply channel to handle raw_input requests that the kernel makes."""
475 475
476 476 msg_queue = None
477 477
478 478 def __init__(self, context, session, address):
479 479 super(StdInSocketChannel, self).__init__(context, session, address)
480 480 self.ioloop = ioloop.IOLoop()
481 481 self.msg_queue = Queue()
482 482
483 483 def run(self):
484 484 """The thread's main activity. Call start() instead."""
485 485 self.socket = self.context.socket(zmq.DEALER)
486 self.socket.setsockopt(zmq.IDENTITY, self.session.session)
486 self.socket.setsockopt(zmq.IDENTITY, self.session.session.encode("ascii"))
487 487 self.socket.connect('tcp://%s:%i' % self.address)
488 488 self.iostate = POLLERR|POLLIN
489 489 self.ioloop.add_handler(self.socket, self._handle_events,
490 490 self.iostate)
491 491 self._run_loop()
492 492
493 493 def stop(self):
494 494 self.ioloop.stop()
495 495 super(StdInSocketChannel, self).stop()
496 496
497 497 def call_handlers(self, msg):
498 498 """This method is called in the ioloop thread when a message arrives.
499 499
500 500 Subclasses should override this method to handle incoming messages.
501 501 It is important to remember that this method is called in the thread
502 502 so some logic must be done to ensure that the application level
503 503 handlers are called in the application thread.
504 504 """
505 505 raise NotImplementedError('call_handlers must be defined in a subclass.')
506 506
507 507 def input(self, string):
508 508 """Send a string of raw input to the kernel."""
509 509 content = dict(value=string)
510 510 msg = self.session.msg('input_reply', content)
511 511 self._queue_reply(msg)
512 512
513 513 def _handle_events(self, socket, events):
514 514 if events & POLLERR:
515 515 self._handle_err()
516 516 if events & POLLOUT:
517 517 self._handle_send()
518 518 if events & POLLIN:
519 519 self._handle_recv()
520 520
521 521 def _handle_recv(self):
522 522 ident,msg = self.session.recv(self.socket, 0)
523 523 self.call_handlers(msg)
524 524
525 525 def _handle_send(self):
526 526 try:
527 527 msg = self.msg_queue.get(False)
528 528 except Empty:
529 529 pass
530 530 else:
531 531 self.session.send(self.socket,msg)
532 532 if self.msg_queue.empty():
533 533 self.drop_io_state(POLLOUT)
534 534
535 535 def _handle_err(self):
536 536 # We don't want to let this go silently, so eventually we should log.
537 537 raise zmq.ZMQError()
538 538
539 539 def _queue_reply(self, msg):
540 540 self.msg_queue.put(msg)
541 541 self.add_io_state(POLLOUT)
542 542
543 543
544 544 class HBSocketChannel(ZMQSocketChannel):
545 545 """The heartbeat channel which monitors the kernel heartbeat.
546 546
547 547 Note that the heartbeat channel is paused by default. As long as you start
548 548 this channel, the kernel manager will ensure that it is paused and un-paused
549 549 as appropriate.
550 550 """
551 551
552 552 time_to_dead = 3.0
553 553 socket = None
554 554 poller = None
555 555 _running = None
556 556 _pause = None
557 557
558 558 def __init__(self, context, session, address):
559 559 super(HBSocketChannel, self).__init__(context, session, address)
560 560 self._running = False
561 561 self._pause = True
562 562
563 563 def _create_socket(self):
564 564 self.socket = self.context.socket(zmq.REQ)
565 self.socket.setsockopt(zmq.IDENTITY, self.session.session)
565 self.socket.setsockopt(zmq.IDENTITY, self.session.session.encode("ascii"))
566 566 self.socket.connect('tcp://%s:%i' % self.address)
567 567 self.poller = zmq.Poller()
568 568 self.poller.register(self.socket, zmq.POLLIN)
569 569
570 570 def run(self):
571 571 """The thread's main activity. Call start() instead."""
572 572 self._create_socket()
573 573 self._running = True
574 574 while self._running:
575 575 if self._pause:
576 576 time.sleep(self.time_to_dead)
577 577 else:
578 578 since_last_heartbeat = 0.0
579 579 request_time = time.time()
580 580 try:
581 581 #io.rprint('Ping from HB channel') # dbg
582 582 self.socket.send(b'ping')
583 583 except zmq.ZMQError, e:
584 584 #io.rprint('*** HB Error:', e) # dbg
585 585 if e.errno == zmq.EFSM:
586 586 #io.rprint('sleep...', self.time_to_dead) # dbg
587 587 time.sleep(self.time_to_dead)
588 588 self._create_socket()
589 589 else:
590 590 raise
591 591 else:
592 592 while True:
593 593 try:
594 594 self.socket.recv(zmq.NOBLOCK)
595 595 except zmq.ZMQError, e:
596 596 #io.rprint('*** HB Error 2:', e) # dbg
597 597 if e.errno == zmq.EAGAIN:
598 598 before_poll = time.time()
599 599 until_dead = self.time_to_dead - (before_poll -
600 600 request_time)
601 601
602 602 # When the return value of poll() is an empty
603 603 # list, that is when things have gone wrong
604 604 # (zeromq bug). As long as it is not an empty
605 605 # list, poll is working correctly even if it
606 606 # returns quickly. Note: poll timeout is in
607 607 # milliseconds.
608 608 if until_dead > 0.0:
609 609 while True:
610 610 try:
611 611 self.poller.poll(1000 * until_dead)
612 612 except zmq.ZMQError as e:
613 613 if e.errno == errno.EINTR:
614 614 continue
615 615 else:
616 616 raise
617 617 else:
618 618 break
619 619
620 620 since_last_heartbeat = time.time()-request_time
621 621 if since_last_heartbeat > self.time_to_dead:
622 622 self.call_handlers(since_last_heartbeat)
623 623 break
624 624 else:
625 625 # FIXME: We should probably log this instead.
626 626 raise
627 627 else:
628 628 until_dead = self.time_to_dead - (time.time() -
629 629 request_time)
630 630 if until_dead > 0.0:
631 631 #io.rprint('sleep...', self.time_to_dead) # dbg
632 632 time.sleep(until_dead)
633 633 break
634 634
635 635 def pause(self):
636 636 """Pause the heartbeat."""
637 637 self._pause = True
638 638
639 639 def unpause(self):
640 640 """Unpause the heartbeat."""
641 641 self._pause = False
642 642
643 643 def is_beating(self):
644 644 """Is the heartbeat running and not paused."""
645 645 if self.is_alive() and not self._pause:
646 646 return True
647 647 else:
648 648 return False
649 649
650 650 def stop(self):
651 651 self._running = False
652 652 super(HBSocketChannel, self).stop()
653 653
654 654 def call_handlers(self, since_last_heartbeat):
655 655 """This method is called in the ioloop thread when a message arrives.
656 656
657 657 Subclasses should override this method to handle incoming messages.
658 658 It is important to remember that this method is called in the thread
659 659 so some logic must be done to ensure that the application level
660 660 handlers are called in the application thread.
661 661 """
662 662 raise NotImplementedError('call_handlers must be defined in a subclass.')
663 663
664 664
665 665 #-----------------------------------------------------------------------------
666 666 # Main kernel manager class
667 667 #-----------------------------------------------------------------------------
668 668
669 669 class KernelManager(HasTraits):
670 670 """ Manages a kernel for a frontend.
671 671
672 672 The SUB channel is for the frontend to receive messages published by the
673 673 kernel.
674 674
675 675 The REQ channel is for the frontend to make requests of the kernel.
676 676
677 677 The REP channel is for the kernel to request stdin (raw_input) from the
678 678 frontend.
679 679 """
680 680 # config object for passing to child configurables
681 681 config = Instance(Config)
682 682
683 683 # The PyZMQ Context to use for communication with the kernel.
684 684 context = Instance(zmq.Context)
685 685 def _context_default(self):
686 686 return zmq.Context.instance()
687 687
688 688 # The Session to use for communication with the kernel.
689 689 session = Instance(Session)
690 690
691 691 # The kernel process with which the KernelManager is communicating.
692 692 kernel = Instance(Popen)
693 693
694 694 # The addresses for the communication channels.
695 695 shell_address = TCPAddress((LOCALHOST, 0))
696 696 sub_address = TCPAddress((LOCALHOST, 0))
697 697 stdin_address = TCPAddress((LOCALHOST, 0))
698 698 hb_address = TCPAddress((LOCALHOST, 0))
699 699
700 700 # The classes to use for the various channels.
701 701 shell_channel_class = Type(ShellSocketChannel)
702 702 sub_channel_class = Type(SubSocketChannel)
703 703 stdin_channel_class = Type(StdInSocketChannel)
704 704 hb_channel_class = Type(HBSocketChannel)
705 705
706 706 # Protected traits.
707 707 _launch_args = Any
708 708 _shell_channel = Any
709 709 _sub_channel = Any
710 710 _stdin_channel = Any
711 711 _hb_channel = Any
712 712
713 713 def __init__(self, **kwargs):
714 714 super(KernelManager, self).__init__(**kwargs)
715 715 if self.session is None:
716 716 self.session = Session(config=self.config)
717 717 # Uncomment this to try closing the context.
718 718 # atexit.register(self.context.term)
719 719
720 720 #--------------------------------------------------------------------------
721 721 # Channel management methods:
722 722 #--------------------------------------------------------------------------
723 723
724 724 def start_channels(self, shell=True, sub=True, stdin=True, hb=True):
725 725 """Starts the channels for this kernel.
726 726
727 727 This will create the channels if they do not exist and then start
728 728 them. If port numbers of 0 are being used (random ports) then you
729 729 must first call :method:`start_kernel`. If the channels have been
730 730 stopped and you call this, :class:`RuntimeError` will be raised.
731 731 """
732 732 if shell:
733 733 self.shell_channel.start()
734 734 if sub:
735 735 self.sub_channel.start()
736 736 if stdin:
737 737 self.stdin_channel.start()
738 738 if hb:
739 739 self.hb_channel.start()
740 740
741 741 def stop_channels(self):
742 742 """Stops all the running channels for this kernel.
743 743 """
744 744 if self.shell_channel.is_alive():
745 745 self.shell_channel.stop()
746 746 if self.sub_channel.is_alive():
747 747 self.sub_channel.stop()
748 748 if self.stdin_channel.is_alive():
749 749 self.stdin_channel.stop()
750 750 if self.hb_channel.is_alive():
751 751 self.hb_channel.stop()
752 752
753 753 @property
754 754 def channels_running(self):
755 755 """Are any of the channels created and running?"""
756 756 return (self.shell_channel.is_alive() or self.sub_channel.is_alive() or
757 757 self.stdin_channel.is_alive() or self.hb_channel.is_alive())
758 758
759 759 #--------------------------------------------------------------------------
760 760 # Kernel process management methods:
761 761 #--------------------------------------------------------------------------
762 762
763 763 def start_kernel(self, **kw):
764 764 """Starts a kernel process and configures the manager to use it.
765 765
766 766 If random ports (port=0) are being used, this method must be called
767 767 before the channels are created.
768 768
769 769 Parameters:
770 770 -----------
771 771 ipython : bool, optional (default True)
772 772 Whether to use an IPython kernel instead of a plain Python kernel.
773 773
774 774 launcher : callable, optional (default None)
775 775 A custom function for launching the kernel process (generally a
776 776 wrapper around ``entry_point.base_launch_kernel``). In most cases,
777 777 it should not be necessary to use this parameter.
778 778
779 779 **kw : optional
780 780 See respective options for IPython and Python kernels.
781 781 """
782 782 shell, sub, stdin, hb = self.shell_address, self.sub_address, \
783 783 self.stdin_address, self.hb_address
784 784 if shell[0] not in LOCAL_IPS or sub[0] not in LOCAL_IPS or \
785 785 stdin[0] not in LOCAL_IPS or hb[0] not in LOCAL_IPS:
786 786 raise RuntimeError("Can only launch a kernel on a local interface. "
787 787 "Make sure that the '*_address' attributes are "
788 788 "configured properly. "
789 789 "Currently valid addresses are: %s"%LOCAL_IPS
790 790 )
791 791
792 792 self._launch_args = kw.copy()
793 793 launch_kernel = kw.pop('launcher', None)
794 794 if launch_kernel is None:
795 795 if kw.pop('ipython', True):
796 796 from ipkernel import launch_kernel
797 797 else:
798 798 from pykernel import launch_kernel
799 799 self.kernel, xrep, pub, req, _hb = launch_kernel(
800 800 shell_port=shell[1], iopub_port=sub[1],
801 801 stdin_port=stdin[1], hb_port=hb[1], **kw)
802 802 self.shell_address = (shell[0], xrep)
803 803 self.sub_address = (sub[0], pub)
804 804 self.stdin_address = (stdin[0], req)
805 805 self.hb_address = (hb[0], _hb)
806 806
807 807 def shutdown_kernel(self, restart=False):
808 808 """ Attempts to stop the kernel process cleanly. If the kernel
809 809 cannot be stopped, it is killed, if possible.
810 810 """
811 811 # FIXME: Shutdown does not work on Windows due to ZMQ errors!
812 812 if sys.platform == 'win32':
813 813 self.kill_kernel()
814 814 return
815 815
816 816 # Pause the heart beat channel if it exists.
817 817 if self._hb_channel is not None:
818 818 self._hb_channel.pause()
819 819
820 820 # Don't send any additional kernel kill messages immediately, to give
821 821 # the kernel a chance to properly execute shutdown actions. Wait for at
822 822 # most 1s, checking every 0.1s.
823 823 self.shell_channel.shutdown(restart=restart)
824 824 for i in range(10):
825 825 if self.is_alive:
826 826 time.sleep(0.1)
827 827 else:
828 828 break
829 829 else:
830 830 # OK, we've waited long enough.
831 831 if self.has_kernel:
832 832 self.kill_kernel()
833 833
834 834 def restart_kernel(self, now=False, **kw):
835 835 """Restarts a kernel with the arguments that were used to launch it.
836 836
837 837 If the old kernel was launched with random ports, the same ports will be
838 838 used for the new kernel.
839 839
840 840 Parameters
841 841 ----------
842 842 now : bool, optional
843 843 If True, the kernel is forcefully restarted *immediately*, without
844 844 having a chance to do any cleanup action. Otherwise the kernel is
845 845 given 1s to clean up before a forceful restart is issued.
846 846
847 847 In all cases the kernel is restarted, the only difference is whether
848 848 it is given a chance to perform a clean shutdown or not.
849 849
850 850 **kw : optional
851 851 Any options specified here will replace those used to launch the
852 852 kernel.
853 853 """
854 854 if self._launch_args is None:
855 855 raise RuntimeError("Cannot restart the kernel. "
856 856 "No previous call to 'start_kernel'.")
857 857 else:
858 858 # Stop currently running kernel.
859 859 if self.has_kernel:
860 860 if now:
861 861 self.kill_kernel()
862 862 else:
863 863 self.shutdown_kernel(restart=True)
864 864
865 865 # Start new kernel.
866 866 self._launch_args.update(kw)
867 867 self.start_kernel(**self._launch_args)
868 868
869 869 # FIXME: Messages get dropped in Windows due to probable ZMQ bug
870 870 # unless there is some delay here.
871 871 if sys.platform == 'win32':
872 872 time.sleep(0.2)
873 873
874 874 @property
875 875 def has_kernel(self):
876 876 """Returns whether a kernel process has been specified for the kernel
877 877 manager.
878 878 """
879 879 return self.kernel is not None
880 880
881 881 def kill_kernel(self):
882 882 """ Kill the running kernel. """
883 883 if self.has_kernel:
884 884 # Pause the heart beat channel if it exists.
885 885 if self._hb_channel is not None:
886 886 self._hb_channel.pause()
887 887
888 888 # Attempt to kill the kernel.
889 889 try:
890 890 self.kernel.kill()
891 891 except OSError, e:
892 892 # In Windows, we will get an Access Denied error if the process
893 893 # has already terminated. Ignore it.
894 894 if sys.platform == 'win32':
895 895 if e.winerror != 5:
896 896 raise
897 897 # On Unix, we may get an ESRCH error if the process has already
898 898 # terminated. Ignore it.
899 899 else:
900 900 from errno import ESRCH
901 901 if e.errno != ESRCH:
902 902 raise
903 903 self.kernel = None
904 904 else:
905 905 raise RuntimeError("Cannot kill kernel. No kernel is running!")
906 906
907 907 def interrupt_kernel(self):
908 908 """ Interrupts the kernel. Unlike ``signal_kernel``, this operation is
909 909 well supported on all platforms.
910 910 """
911 911 if self.has_kernel:
912 912 if sys.platform == 'win32':
913 913 from parentpoller import ParentPollerWindows as Poller
914 914 Poller.send_interrupt(self.kernel.win32_interrupt_event)
915 915 else:
916 916 self.kernel.send_signal(signal.SIGINT)
917 917 else:
918 918 raise RuntimeError("Cannot interrupt kernel. No kernel is running!")
919 919
920 920 def signal_kernel(self, signum):
921 921 """ Sends a signal to the kernel. Note that since only SIGTERM is
922 922 supported on Windows, this function is only useful on Unix systems.
923 923 """
924 924 if self.has_kernel:
925 925 self.kernel.send_signal(signum)
926 926 else:
927 927 raise RuntimeError("Cannot signal kernel. No kernel is running!")
928 928
929 929 @property
930 930 def is_alive(self):
931 931 """Is the kernel process still running?"""
932 932 # FIXME: not using a heartbeat means this method is broken for any
933 933 # remote kernel, it's only capable of handling local kernels.
934 934 if self.has_kernel:
935 935 if self.kernel.poll() is None:
936 936 return True
937 937 else:
938 938 return False
939 939 else:
940 940 # We didn't start the kernel with this KernelManager so we don't
941 941 # know if it is running. We should use a heartbeat for this case.
942 942 return True
943 943
944 944 #--------------------------------------------------------------------------
945 945 # Channels used for communication with the kernel:
946 946 #--------------------------------------------------------------------------
947 947
948 948 @property
949 949 def shell_channel(self):
950 950 """Get the REQ socket channel object to make requests of the kernel."""
951 951 if self._shell_channel is None:
952 952 self._shell_channel = self.shell_channel_class(self.context,
953 953 self.session,
954 954 self.shell_address)
955 955 return self._shell_channel
956 956
957 957 @property
958 958 def sub_channel(self):
959 959 """Get the SUB socket channel object."""
960 960 if self._sub_channel is None:
961 961 self._sub_channel = self.sub_channel_class(self.context,
962 962 self.session,
963 963 self.sub_address)
964 964 return self._sub_channel
965 965
966 966 @property
967 967 def stdin_channel(self):
968 968 """Get the REP socket channel object to handle stdin (raw_input)."""
969 969 if self._stdin_channel is None:
970 970 self._stdin_channel = self.stdin_channel_class(self.context,
971 971 self.session,
972 972 self.stdin_address)
973 973 return self._stdin_channel
974 974
975 975 @property
976 976 def hb_channel(self):
977 977 """Get the heartbeat socket channel object to check that the
978 978 kernel is alive."""
979 979 if self._hb_channel is None:
980 980 self._hb_channel = self.hb_channel_class(self.context,
981 981 self.session,
982 982 self.hb_address)
983 983 return self._hb_channel
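A minimal usage sketch of the KernelManager API shown above (illustrative; it assumes a local IPython kernel can be launched, and handling the replies requires channel subclasses that override call_handlers):

    km = KernelManager()
    km.start_kernel()                  # launch the kernel process on random local ports
    km.start_channels()                # start the shell/sub/stdin/hb channel threads
    msg_id = km.shell_channel.execute('a = 1 + 1')
    # ... replies arrive on the shell and SUB channels ...
    km.shutdown_kernel()
    km.stop_channels()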
@@ -1,704 +1,704 b''
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9 9
10 10 Authors:
11 11
12 12 * Min RK
13 13 * Brian Granger
14 14 * Fernando Perez
15 15 """
16 16 #-----------------------------------------------------------------------------
17 17 # Copyright (C) 2010-2011 The IPython Development Team
18 18 #
19 19 # Distributed under the terms of the BSD License. The full license is in
20 20 # the file COPYING, distributed as part of this software.
21 21 #-----------------------------------------------------------------------------
22 22
23 23 #-----------------------------------------------------------------------------
24 24 # Imports
25 25 #-----------------------------------------------------------------------------
26 26
27 27 import hmac
28 28 import logging
29 29 import os
30 30 import pprint
31 31 import uuid
32 32 from datetime import datetime
33 33
34 34 try:
35 35 import cPickle
36 36 pickle = cPickle
37 37 except:
38 38 cPickle = None
39 39 import pickle
40 40
41 41 import zmq
42 42 from zmq.utils import jsonapi
43 43 from zmq.eventloop.ioloop import IOLoop
44 44 from zmq.eventloop.zmqstream import ZMQStream
45 45
46 46 from IPython.config.configurable import Configurable, LoggingConfigurable
47 47 from IPython.utils.importstring import import_item
48 48 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
49 49 from IPython.utils.py3compat import str_to_bytes
50 50 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
51 DottedObjectName)
51 DottedObjectName, CUnicode)
52 52
53 53 #-----------------------------------------------------------------------------
54 54 # utility functions
55 55 #-----------------------------------------------------------------------------
56 56
57 57 def squash_unicode(obj):
58 58 """coerce unicode back to bytestrings."""
59 59 if isinstance(obj,dict):
60 60 for key in obj.keys():
61 61 obj[key] = squash_unicode(obj[key])
62 62 if isinstance(key, unicode):
63 63 obj[squash_unicode(key)] = obj.pop(key)
64 64 elif isinstance(obj, list):
65 65 for i,v in enumerate(obj):
66 66 obj[i] = squash_unicode(v)
67 67 elif isinstance(obj, unicode):
68 68 obj = obj.encode('utf8')
69 69 return obj
70 70
71 71 #-----------------------------------------------------------------------------
72 72 # globals and defaults
73 73 #-----------------------------------------------------------------------------
74 74 key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
75 75 json_packer = lambda obj: jsonapi.dumps(obj, **{key:date_default})
76 76 json_unpacker = lambda s: extract_dates(jsonapi.loads(s))
77 77
78 78 pickle_packer = lambda o: pickle.dumps(o,-1)
79 79 pickle_unpacker = pickle.loads
80 80
81 81 default_packer = json_packer
82 82 default_unpacker = json_unpacker
83 83
84 84
85 85 DELIM=b"<IDS|MSG>"
86 86
87 87 #-----------------------------------------------------------------------------
88 88 # Classes
89 89 #-----------------------------------------------------------------------------
90 90
91 91 class SessionFactory(LoggingConfigurable):
92 92 """The Base class for configurables that have a Session, Context, logger,
93 93 and IOLoop.
94 94 """
95 95
96 96 logname = Unicode('')
97 97 def _logname_changed(self, name, old, new):
98 98 self.log = logging.getLogger(new)
99 99
100 100 # not configurable:
101 101 context = Instance('zmq.Context')
102 102 def _context_default(self):
103 103 return zmq.Context.instance()
104 104
105 105 session = Instance('IPython.zmq.session.Session')
106 106
107 107 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
108 108 def _loop_default(self):
109 109 return IOLoop.instance()
110 110
111 111 def __init__(self, **kwargs):
112 112 super(SessionFactory, self).__init__(**kwargs)
113 113
114 114 if self.session is None:
115 115 # construct the session
116 116 self.session = Session(**kwargs)
117 117
118 118
119 119 class Message(object):
120 120 """A simple message object that maps dict keys to attributes.
121 121
122 122 A Message can be created from a dict and a dict from a Message instance
123 123 simply by calling dict(msg_obj)."""
124 124
125 125 def __init__(self, msg_dict):
126 126 dct = self.__dict__
127 127 for k, v in dict(msg_dict).iteritems():
128 128 if isinstance(v, dict):
129 129 v = Message(v)
130 130 dct[k] = v
131 131
132 132 # Having this iterator lets dict(msg_obj) work out of the box.
133 133 def __iter__(self):
134 134 return iter(self.__dict__.iteritems())
135 135
136 136 def __repr__(self):
137 137 return repr(self.__dict__)
138 138
139 139 def __str__(self):
140 140 return pprint.pformat(self.__dict__)
141 141
142 142 def __contains__(self, k):
143 143 return k in self.__dict__
144 144
145 145 def __getitem__(self, k):
146 146 return self.__dict__[k]
147 147
148 148
149 149 def msg_header(msg_id, msg_type, username, session):
150 150 date = datetime.now()
151 151 return locals()
152 152
153 153 def extract_header(msg_or_header):
154 154 """Given a message or header, return the header."""
155 155 if not msg_or_header:
156 156 return {}
157 157 try:
158 158 # See if msg_or_header is the entire message.
159 159 h = msg_or_header['header']
160 160 except KeyError:
161 161 try:
162 162 # See if msg_or_header is just the header
163 163 h = msg_or_header['msg_id']
164 164 except KeyError:
165 165 raise
166 166 else:
167 167 h = msg_or_header
168 168 if not isinstance(h, dict):
169 169 h = dict(h)
170 170 return h
171 171
172 172 class Session(Configurable):
173 173 """Object for handling serialization and sending of messages.
174 174
175 175 The Session object handles building messages and sending them
176 176 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
177 177 other over the network via Session objects, and only need to work with the
178 178 dict-based IPython message spec. The Session will handle
179 179 serialization/deserialization, security, and metadata.
180 180
181 181 Sessions support configurable serialization via packer/unpacker traits,
182 182 and signing with HMAC digests via the key/keyfile traits.
183 183
184 184 Parameters
185 185 ----------
186 186
187 187 debug : bool
188 188 whether to trigger extra debugging statements
189 189 packer/unpacker : str : 'json', 'pickle' or import_string
190 190 importstrings for methods to serialize message parts. If just
191 191 'json' or 'pickle', predefined JSON and pickle packers will be used.
192 192 Otherwise, the entire importstring must be used.
193 193
194 194 The functions must accept at least valid JSON input, and output *bytes*.
195 195
196 196 For example, to use msgpack:
197 197 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
198 198 pack/unpack : callables
199 199 You can also set the pack/unpack callables for serialization directly.
200 200 session : bytes
201 201 the ID of this Session object. The default is to generate a new UUID.
202 202 username : unicode
203 203 username added to message headers. The default is to ask the OS.
204 204 key : bytes
205 205 The key used to initialize an HMAC signature. If unset, messages
206 206 will not be signed or checked.
207 207 keyfile : filepath
208 208 The file containing a key. If this is set, `key` will be initialized
209 209 to the contents of the file.
210 210
211 211 """
212 212
213 213 debug=Bool(False, config=True, help="""Debug output in the Session""")
214 214
215 215 packer = DottedObjectName('json',config=True,
216 216 help="""The name of the packer for serializing messages.
217 217 Should be one of 'json', 'pickle', or an import name
218 218 for a custom callable serializer.""")
219 219 def _packer_changed(self, name, old, new):
220 220 if new.lower() == 'json':
221 221 self.pack = json_packer
222 222 self.unpack = json_unpacker
223 223 elif new.lower() == 'pickle':
224 224 self.pack = pickle_packer
225 225 self.unpack = pickle_unpacker
226 226 else:
227 227 self.pack = import_item(str(new))
228 228
229 229 unpacker = DottedObjectName('json', config=True,
230 230 help="""The name of the unpacker for unserializing messages.
231 231 Only used with custom functions for `packer`.""")
232 232 def _unpacker_changed(self, name, old, new):
233 233 if new.lower() == 'json':
234 234 self.pack = json_packer
235 235 self.unpack = json_unpacker
236 236 elif new.lower() == 'pickle':
237 237 self.pack = pickle_packer
238 238 self.unpack = pickle_unpacker
239 239 else:
240 240 self.unpack = import_item(str(new))
241 241
242 session = CBytes(b'', config=True,
242 session = CUnicode(u'', config=True,
243 243 help="""The UUID identifying this session.""")
244 244 def _session_default(self):
245 return bytes(uuid.uuid4())
245 return unicode(uuid.uuid4())
246 246
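# Note: with the CUnicode change above, the session id is unicode, so callers that
# use it as a zmq IDENTITY now encode it first, as in the channel classes earlier
# in this changeset:
#     self.socket.setsockopt(zmq.IDENTITY, self.session.session.encode("ascii"))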
247 247 username = Unicode(os.environ.get('USER',u'username'), config=True,
248 248 help="""Username for the Session. Default is your system username.""")
249 249
250 250 # message signature related traits:
251 251 key = CBytes(b'', config=True,
252 252 help="""execution key, for extra authentication.""")
253 253 def _key_changed(self, name, old, new):
254 254 if new:
255 255 self.auth = hmac.HMAC(new)
256 256 else:
257 257 self.auth = None
258 258 auth = Instance(hmac.HMAC)
259 259 digest_history = Set()
260 260
261 261 keyfile = Unicode('', config=True,
262 262 help="""path to file containing execution key.""")
263 263 def _keyfile_changed(self, name, old, new):
264 264 with open(new, 'rb') as f:
265 265 self.key = f.read().strip()
266 266
267 267 pack = Any(default_packer) # the actual packer function
268 268 def _pack_changed(self, name, old, new):
269 269 if not callable(new):
270 270 raise TypeError("packer must be callable, not %s"%type(new))
271 271
272 272 unpack = Any(default_unpacker) # the actual packer function
273 273 def _unpack_changed(self, name, old, new):
274 274 # unpacker is not checked - it is assumed to be
275 275 if not callable(new):
276 276 raise TypeError("unpacker must be callable, not %s"%type(new))
277 277
278 278 def __init__(self, **kwargs):
279 279 """create a Session object
280 280
281 281 Parameters
282 282 ----------
283 283
284 284 debug : bool
285 285 whether to trigger extra debugging statements
286 286 packer/unpacker : str : 'json', 'pickle' or import_string
287 287 importstrings for methods to serialize message parts. If just
288 288 'json' or 'pickle', predefined JSON and pickle packers will be used.
289 289 Otherwise, the entire importstring must be used.
290 290
291 291 The functions must accept at least valid JSON input, and output
292 292 *bytes*.
293 293
294 294 For example, to use msgpack:
295 295 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
296 296 pack/unpack : callables
297 297 You can also set the pack/unpack callables for serialization
298 298 directly.
299 299 session : bytes
300 300 the ID of this Session object. The default is to generate a new
301 301 UUID.
302 302 username : unicode
303 303 username added to message headers. The default is to ask the OS.
304 304 key : bytes
305 305 The key used to initialize an HMAC signature. If unset, messages
306 306 will not be signed or checked.
307 307 keyfile : filepath
308 308 The file containing a key. If this is set, `key` will be
309 309 initialized to the contents of the file.
310 310 """
311 311 super(Session, self).__init__(**kwargs)
312 312 self._check_packers()
313 313 self.none = self.pack({})
314 314
315 315 @property
316 316 def msg_id(self):
317 317 """always return new uuid"""
318 318 return str(uuid.uuid4())
319 319
320 320 def _check_packers(self):
321 321 """check packers for binary data and datetime support."""
322 322 pack = self.pack
323 323 unpack = self.unpack
324 324
325 325 # check simple serialization
326 326 msg = dict(a=[1,'hi'])
327 327 try:
328 328 packed = pack(msg)
329 329 except Exception:
330 330 raise ValueError("packer could not serialize a simple message")
331 331
332 332 # ensure packed message is bytes
333 333 if not isinstance(packed, bytes):
334 334 raise ValueError("message packed to %r, but bytes are required"%type(packed))
335 335
336 336 # check that unpack is pack's inverse
337 337 try:
338 338 unpacked = unpack(packed)
339 339 except Exception:
340 340 raise ValueError("unpacker could not handle the packer's output")
341 341
342 342 # check datetime support
343 343 msg = dict(t=datetime.now())
344 344 try:
345 345 unpacked = unpack(pack(msg))
346 346 except Exception:
347 347 self.pack = lambda o: pack(squash_dates(o))
348 348 self.unpack = lambda s: extract_dates(unpack(s))
349 349
350 350 def msg_header(self, msg_type):
351 351 return msg_header(self.msg_id, msg_type, self.username, self.session)
352 352
353 353 def msg(self, msg_type, content=None, parent=None, subheader=None, header=None):
354 354 """Return the nested message dict.
355 355
356 356 This format is different from what is sent over the wire. The
357 357 serialize/unserialize methods convert this nested message dict to the wire
358 358 format, which is a list of message parts.
359 359 """
360 360 msg = {}
361 361 header = self.msg_header(msg_type) if header is None else header
362 362 msg['header'] = header
363 363 msg['msg_id'] = header['msg_id']
364 364 msg['msg_type'] = header['msg_type']
365 365 msg['parent_header'] = {} if parent is None else extract_header(parent)
366 366 msg['content'] = {} if content is None else content
367 367 sub = {} if subheader is None else subheader
368 368 msg['header'].update(sub)
369 369 return msg
370 370
371 371 def sign(self, msg_list):
372 372 """Sign a message with HMAC digest. If no auth, return b''.
373 373
374 374 Parameters
375 375 ----------
376 376 msg_list : list
377 377 The [p_header,p_parent,p_content] part of the message list.
378 378 """
379 379 if self.auth is None:
380 380 return b''
381 381 h = self.auth.copy()
382 382 for m in msg_list:
383 383 h.update(m)
384 384 return str_to_bytes(h.hexdigest())
385 385
386 386 def serialize(self, msg, ident=None):
387 387 """Serialize the message components to bytes.
388 388
389 389 This is roughly the inverse of unserialize. The serialize/unserialize
390 390 methods work with full message lists, whereas pack/unpack work with
391 391 the individual message parts in the message list.
392 392
393 393 Parameters
394 394 ----------
395 395 msg : dict or Message
396 396 The nested message dict as returned by the self.msg method.
397 397
398 398 Returns
399 399 -------
400 400 msg_list : list
401 401 The list of bytes objects to be sent with the format:
402 402 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
403 403 buffer1,buffer2,...]. In this list, the p_* entities are
404 404 the packed or serialized versions, so if JSON is used, these
405 405 are utf8 encoded JSON strings.
406 406 """
407 407 content = msg.get('content', {})
408 408 if content is None:
409 409 content = self.none
410 410 elif isinstance(content, dict):
411 411 content = self.pack(content)
412 412 elif isinstance(content, bytes):
413 413 # content is already packed, as in a relayed message
414 414 pass
415 415 elif isinstance(content, unicode):
416 416 # should be bytes, but JSON often spits out unicode
417 417 content = content.encode('utf8')
418 418 else:
419 419 raise TypeError("Content incorrect type: %s"%type(content))
420 420
421 421 real_message = [self.pack(msg['header']),
422 422 self.pack(msg['parent_header']),
423 423 content
424 424 ]
425 425
426 426 to_send = []
427 427
428 428 if isinstance(ident, list):
429 429 # accept list of idents
430 430 to_send.extend(ident)
431 431 elif ident is not None:
432 432 to_send.append(ident)
433 433 to_send.append(DELIM)
434 434
435 435 signature = self.sign(real_message)
436 436 to_send.append(signature)
437 437
438 438 to_send.extend(real_message)
439 439
440 440 return to_send
441 441
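# Example of the returned list for a message with one routing ident
# (byte values abbreviated):
#
#     [b'engine-0',                      # ident
#      DELIM,                            # ident/message delimiter
#      b'1f2a...',                       # HMAC hex digest from self.sign
#      p_header, p_parent, p_content]    # packed dicts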
442 442 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
443 443 buffers=None, subheader=None, track=False, header=None):
444 444 """Build and send a message via stream or socket.
445 445
446 446 The message format used by this function internally is as follows:
447 447
448 448 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
449 449 buffer1,buffer2,...]
450 450
451 451 The serialize/unserialize methods convert the nested message dict into this
452 452 format.
453 453
454 454 Parameters
455 455 ----------
456 456
457 457 stream : zmq.Socket or ZMQStream
458 458 The socket-like object used to send the data.
459 459 msg_or_type : str or Message/dict
460 460 Normally, msg_or_type will be a msg_type unless a message is being
461 461 sent more than once. If a header is supplied, this can be set to
462 462 None and the msg_type will be pulled from the header.
463 463
464 464 content : dict or None
465 465 The content of the message (ignored if msg_or_type is a message).
466 466 header : dict or None
467 467 The header dict for the message (ignored if msg_or_type is a message).
468 468 parent : Message or dict or None
469 469 The parent or parent header describing the parent of this message
470 470 (ignored if msg_or_type is a message).
471 471 ident : bytes or list of bytes
472 472 The zmq.IDENTITY routing path.
473 473 subheader : dict or None
474 474 Extra header keys for this message's header (ignored if msg_or_type
475 475 is a message).
476 476 buffers : list or None
477 477 The already-serialized buffers to be appended to the message.
478 478 track : bool
479 479 Whether to track message delivery. Only for use with Sockets, because ZMQStream
480 480 objects cannot track messages.
481 481
482 482 Returns
483 483 -------
484 484 msg : dict
485 485 The constructed message.
486 486 (msg, tracker) : (dict, MessageTracker)
487 487 If track=True, a 2-tuple is returned: the first element is
488 488 the constructed message dict, and the second is the
489 489 MessageTracker for the send.
490 490
491 491 """
492 492
493 493 if not isinstance(stream, (zmq.Socket, ZMQStream)):
494 494 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
495 495 elif track and isinstance(stream, ZMQStream):
496 496 raise TypeError("ZMQStream cannot track messages")
497 497
498 498 if isinstance(msg_or_type, (Message, dict)):
499 499 # We got a Message or message dict, not a msg_type so don't
500 500 # build a new Message.
501 501 msg = msg_or_type
502 502 else:
503 503 msg = self.msg(msg_or_type, content=content, parent=parent,
504 504 subheader=subheader, header=header)
505 505
506 506 buffers = [] if buffers is None else buffers
507 507 to_send = self.serialize(msg, ident)
508 508 flag = 0
509 509 if buffers:
510 510 flag = zmq.SNDMORE
511 511 _track = False
512 512 else:
513 513 _track=track
514 514 if track:
515 515 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
516 516 else:
517 517 tracker = stream.send_multipart(to_send, flag, copy=False)
518 518 for b in buffers[:-1]:
519 519 stream.send(b, flag, copy=False)
520 520 if buffers:
521 521 if track:
522 522 tracker = stream.send(buffers[-1], copy=False, track=track)
523 523 else:
524 524 tracker = stream.send(buffers[-1], copy=False)
525 525
526 526 # omsg = Message(msg)
527 527 if self.debug:
528 528 pprint.pprint(msg)
529 529 pprint.pprint(to_send)
530 530 pprint.pprint(buffers)
531 531
532 532 msg['tracker'] = tracker
533 533
534 534 return msg
535 535
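# Hedged usage sketch, assuming `session` is a configured Session and `sock`
# a connected zmq or ZMQStream socket (both names illustrative):
#
#     msg = session.send(sock, 'execute_request',
#                        content={'code': 'a=1'}, ident=b'engine-0')
#     # with track=True on a plain zmq.Socket, msg['tracker'] can be used
#     # to wait until the message has actually left the process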
536 536 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
537 537 """Send a raw message via ident path.
538 538
539 539 This method is used to send an already serialized message.
540 540
541 541 Parameters
542 542 ----------
543 543 stream : ZMQStream or Socket
544 544 The ZMQ stream or socket to use for sending the message.
545 545 msg_list : list
546 546 The serialized list of messages to send. This only includes the
547 547 [p_header,p_parent,p_content,buffer1,buffer2,...] portion of
548 548 the message.
549 549 ident : ident or list
550 550 A single ident or a list of idents to use in sending.
551 551 """
552 552 to_send = []
553 553 if isinstance(ident, bytes):
554 554 ident = [ident]
555 555 if ident is not None:
556 556 to_send.extend(ident)
557 557
558 558 to_send.append(DELIM)
559 559 to_send.append(self.sign(msg_list))
560 560 to_send.extend(msg_list)
561 561 stream.send_multipart(to_send, flags, copy=copy)
562 562
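# Relay sketch, assuming `raw_parts` is a full multipart list as received
# from a socket (names illustrative). The leading HMAC is dropped because
# send_raw re-signs the parts it is given:
#
#     idents, wire = session.feed_identities(raw_parts)
#     session.send_raw(out_socket, wire[1:], ident=idents)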
563 563 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
564 564 """Receive and unpack a message.
565 565
566 566 Parameters
567 567 ----------
568 568 socket : ZMQStream or Socket
569 569 The socket or stream to use in receiving.
570 570
571 571 Returns
572 572 -------
573 573 [idents], msg
574 574 [idents] is a list of idents and msg is a nested message dict of
575 575 same format as self.msg returns.
576 576 """
577 577 if isinstance(socket, ZMQStream):
578 578 socket = socket.socket
579 579 try:
580 580 msg_list = socket.recv_multipart(mode)
581 581 except zmq.ZMQError as e:
582 582 if e.errno == zmq.EAGAIN:
583 583 # EAGAIN just means no message was waiting, so return (None, None),
584 584 # a value recv_multipart itself never produces.
585 585 return None,None
586 586 else:
587 587 raise
588 588 # split multipart message into identity list and message dict
589 589 # invalid large messages can cause very expensive string comparisons
590 590 idents, msg_list = self.feed_identities(msg_list, copy)
591 591 try:
592 592 return idents, self.unserialize(msg_list, content=content, copy=copy)
593 593 except Exception as e:
594 594 # TODO: handle it
595 595 raise e
596 596
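# Hedged receive-side sketch (socket name illustrative); recv returns
# (None, None) when no message is waiting in NOBLOCK mode:
#
#     idents, msg = session.recv(sock, mode=zmq.NOBLOCK)
#     if msg is not None:
#         print(msg['msg_type'], msg['content'])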
597 597 def feed_identities(self, msg_list, copy=True):
598 598 """Split the identities from the rest of the message.
599 599
600 600 Feed until DELIM is reached, then return the prefix as idents and
601 601 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
602 602 but that would be silly.
603 603
604 604 Parameters
605 605 ----------
606 606 msg_list : a list of Message or bytes objects
607 607 The message to be split.
608 608 copy : bool
609 609 Whether the items in msg_list are bytes (True) or zmq.Message objects (False)
610 610
611 611 Returns
612 612 -------
613 613 (idents, msg_list) : two lists
614 614 idents will always be a list of bytes, each of which is a ZMQ
615 615 identity. msg_list will be a list of bytes or zmq.Messages of the
616 616 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
617 617 should be unpackable/unserializable via self.unserialize at this
618 618 point.
619 619 """
620 620 if copy:
621 621 idx = msg_list.index(DELIM)
622 622 return msg_list[:idx], msg_list[idx+1:]
623 623 else:
624 624 failed = True
625 625 for idx,m in enumerate(msg_list):
626 626 if m.bytes == DELIM:
627 627 failed = False
628 628 break
629 629 if failed:
630 630 raise ValueError("DELIM not in msg_list")
631 631 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
632 632 return [m.bytes for m in idents], msg_list
633 633
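# Worked example of the copy=True branch above (byte values illustrative):
#
#     parts = [b'engine-0', b'client-1', DELIM, b'hmac...', b'hdr', b'prnt', b'cont']
#     idents, rest = session.feed_identities(parts)
#     # idents == [b'engine-0', b'client-1']
#     # rest   == [b'hmac...', b'hdr', b'prnt', b'cont']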
634 634 def unserialize(self, msg_list, content=True, copy=True):
635 635 """Unserialize a msg_list to a nested message dict.
636 636
637 637 This is roughly the inverse of serialize. The serialize/unserialize
638 638 methods work with full message lists, whereas pack/unpack work with
639 639 the individual message parts in the message list.
640 640
641 641 Parameters
642 642 ----------
643 643 msg_list : list of bytes or Message objects
644 644 The list of message parts of the form [HMAC,p_header,p_parent,
645 645 p_content,buffer1,buffer2,...].
646 646 content : bool (True)
647 647 Whether to unpack the content dict (True), or leave it packed
648 648 (False).
649 649 copy : bool (True)
650 650 Whether to return the bytes (True), or the non-copying Message
651 651 object in each place (False).
652 652
653 653 Returns
654 654 -------
655 655 msg : dict
656 656 The nested message dict with top-level keys [header, parent_header,
657 657 content, buffers].
658 658 """
659 659 minlen = 4
660 660 message = {}
661 661 if not copy:
662 662 for i in range(minlen):
663 663 msg_list[i] = msg_list[i].bytes
664 664 if self.auth is not None:
665 665 signature = msg_list[0]
666 666 if not signature:
667 667 raise ValueError("Unsigned Message")
668 668 if signature in self.digest_history:
669 669 raise ValueError("Duplicate Signature: %r"%signature)
670 670 self.digest_history.add(signature)
671 671 check = self.sign(msg_list[1:4])
672 672 if not signature == check:
673 673 raise ValueError("Invalid Signature: %r"%signature)
674 674 if len(msg_list) < minlen:
675 675 raise TypeError("malformed message, must have at least %i elements"%minlen)
676 676 header = self.unpack(msg_list[1])
677 677 message['header'] = header
678 678 message['msg_id'] = header['msg_id']
679 679 message['msg_type'] = header['msg_type']
680 680 message['parent_header'] = self.unpack(msg_list[2])
681 681 if content:
682 682 message['content'] = self.unpack(msg_list[3])
683 683 else:
684 684 message['content'] = msg_list[3]
685 685
686 686 message['buffers'] = msg_list[4:]
687 687 return message
688 688
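# Round-trip sketch tying serialize and unserialize together, assuming a
# Session with no routing idents involved:
#
#     msg = session.msg('apply_request', content={'a': 1})
#     wire = session.serialize(msg)              # [DELIM, HMAC, p_*...]
#     idents, rest = session.feed_identities(wire)
#     restored = session.unserialize(rest)
#     # restored['content'] == {'a': 1}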
689 689 def test_msg2obj():
690 690 am = dict(x=1)
691 691 ao = Message(am)
692 692 assert ao.x == am['x']
693 693
694 694 am['y'] = dict(z=1)
695 695 ao = Message(am)
696 696 assert ao.y.z == am['y']['z']
697 697
698 698 k1, k2 = 'y', 'z'
699 699 assert ao[k1][k2] == am[k1][k2]
700 700
701 701 am2 = dict(ao)
702 702 assert am['x'] == am2['x']
703 703 assert am['y']['z'] == am2['y']['z']
704 704