##// END OF EJS Templates
cull Session digest history...
MinRK -
Show More
@@ -1,773 +1,806 b''
1 """Session object for building, serializing, sending, and receiving messages in
1 """Session object for building, serializing, sending, and receiving messages in
2 IPython. The Session object supports serialization, HMAC signatures, and
2 IPython. The Session object supports serialization, HMAC signatures, and
3 metadata on messages.
3 metadata on messages.
4
4
5 Also defined here are utilities for working with Sessions:
5 Also defined here are utilities for working with Sessions:
6 * A SessionFactory to be used as a base class for configurables that work with
6 * A SessionFactory to be used as a base class for configurables that work with
7 Sessions.
7 Sessions.
8 * A Message object for convenience that allows attribute-access to the msg dict.
8 * A Message object for convenience that allows attribute-access to the msg dict.
9
9
10 Authors:
10 Authors:
11
11
12 * Min RK
12 * Min RK
13 * Brian Granger
13 * Brian Granger
14 * Fernando Perez
14 * Fernando Perez
15 """
15 """
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 # Copyright (C) 2010-2011 The IPython Development Team
17 # Copyright (C) 2010-2011 The IPython Development Team
18 #
18 #
19 # Distributed under the terms of the BSD License. The full license is in
19 # Distributed under the terms of the BSD License. The full license is in
20 # the file COPYING, distributed as part of this software.
20 # the file COPYING, distributed as part of this software.
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22
22
23 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
24 # Imports
24 # Imports
25 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
26
26
27 import hmac
27 import hmac
28 import logging
28 import logging
29 import os
29 import os
30 import pprint
30 import pprint
31 import random
31 import uuid
32 import uuid
32 from datetime import datetime
33 from datetime import datetime
33
34
34 try:
35 try:
35 import cPickle
36 import cPickle
36 pickle = cPickle
37 pickle = cPickle
37 except:
38 except:
38 cPickle = None
39 cPickle = None
39 import pickle
40 import pickle
40
41
41 import zmq
42 import zmq
42 from zmq.utils import jsonapi
43 from zmq.utils import jsonapi
43 from zmq.eventloop.ioloop import IOLoop
44 from zmq.eventloop.ioloop import IOLoop
44 from zmq.eventloop.zmqstream import ZMQStream
45 from zmq.eventloop.zmqstream import ZMQStream
45
46
46 from IPython.config.application import Application, boolean_flag
47 from IPython.config.application import Application, boolean_flag
47 from IPython.config.configurable import Configurable, LoggingConfigurable
48 from IPython.config.configurable import Configurable, LoggingConfigurable
48 from IPython.utils import io
49 from IPython.utils import io
49 from IPython.utils.importstring import import_item
50 from IPython.utils.importstring import import_item
50 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
51 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
51 from IPython.utils.py3compat import str_to_bytes
52 from IPython.utils.py3compat import str_to_bytes
52 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
53 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
53 DottedObjectName, CUnicode, Dict, Integer)
54 DottedObjectName, CUnicode, Dict, Integer)
54 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
55 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
55
56
56 #-----------------------------------------------------------------------------
57 #-----------------------------------------------------------------------------
57 # utility functions
58 # utility functions
58 #-----------------------------------------------------------------------------
59 #-----------------------------------------------------------------------------
59
60
60 def squash_unicode(obj):
61 def squash_unicode(obj):
61 """coerce unicode back to bytestrings."""
62 """coerce unicode back to bytestrings."""
62 if isinstance(obj,dict):
63 if isinstance(obj,dict):
63 for key in obj.keys():
64 for key in obj.keys():
64 obj[key] = squash_unicode(obj[key])
65 obj[key] = squash_unicode(obj[key])
65 if isinstance(key, unicode):
66 if isinstance(key, unicode):
66 obj[squash_unicode(key)] = obj.pop(key)
67 obj[squash_unicode(key)] = obj.pop(key)
67 elif isinstance(obj, list):
68 elif isinstance(obj, list):
68 for i,v in enumerate(obj):
69 for i,v in enumerate(obj):
69 obj[i] = squash_unicode(v)
70 obj[i] = squash_unicode(v)
70 elif isinstance(obj, unicode):
71 elif isinstance(obj, unicode):
71 obj = obj.encode('utf8')
72 obj = obj.encode('utf8')
72 return obj
73 return obj
73
74
74 #-----------------------------------------------------------------------------
75 #-----------------------------------------------------------------------------
75 # globals and defaults
76 # globals and defaults
76 #-----------------------------------------------------------------------------
77 #-----------------------------------------------------------------------------
77
78
78 # ISO8601-ify datetime objects
79 # ISO8601-ify datetime objects
79 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default)
80 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default)
80 json_unpacker = lambda s: extract_dates(jsonapi.loads(s))
81 json_unpacker = lambda s: extract_dates(jsonapi.loads(s))
81
82
82 pickle_packer = lambda o: pickle.dumps(o,-1)
83 pickle_packer = lambda o: pickle.dumps(o,-1)
83 pickle_unpacker = pickle.loads
84 pickle_unpacker = pickle.loads
84
85
85 default_packer = json_packer
86 default_packer = json_packer
86 default_unpacker = json_unpacker
87 default_unpacker = json_unpacker
87
88
88 DELIM = b"<IDS|MSG>"
89 DELIM = b"<IDS|MSG>"
89 # singleton dummy tracker, which will always report as done
90 # singleton dummy tracker, which will always report as done
90 DONE = zmq.MessageTracker()
91 DONE = zmq.MessageTracker()
91
92
92 #-----------------------------------------------------------------------------
93 #-----------------------------------------------------------------------------
93 # Mixin tools for apps that use Sessions
94 # Mixin tools for apps that use Sessions
94 #-----------------------------------------------------------------------------
95 #-----------------------------------------------------------------------------
95
96
96 session_aliases = dict(
97 session_aliases = dict(
97 ident = 'Session.session',
98 ident = 'Session.session',
98 user = 'Session.username',
99 user = 'Session.username',
99 keyfile = 'Session.keyfile',
100 keyfile = 'Session.keyfile',
100 )
101 )
101
102
102 session_flags = {
103 session_flags = {
103 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
104 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
104 'keyfile' : '' }},
105 'keyfile' : '' }},
105 """Use HMAC digests for authentication of messages.
106 """Use HMAC digests for authentication of messages.
106 Setting this flag will generate a new UUID to use as the HMAC key.
107 Setting this flag will generate a new UUID to use as the HMAC key.
107 """),
108 """),
108 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
109 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
109 """Don't authenticate messages."""),
110 """Don't authenticate messages."""),
110 }
111 }
111
112
112 def default_secure(cfg):
113 def default_secure(cfg):
113 """Set the default behavior for a config environment to be secure.
114 """Set the default behavior for a config environment to be secure.
114
115
115 If Session.key/keyfile have not been set, set Session.key to
116 If Session.key/keyfile have not been set, set Session.key to
116 a new random UUID.
117 a new random UUID.
117 """
118 """
118
119
119 if 'Session' in cfg:
120 if 'Session' in cfg:
120 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
121 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
121 return
122 return
122 # key/keyfile not specified, generate new UUID:
123 # key/keyfile not specified, generate new UUID:
123 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
124 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
124
125
125
126
126 #-----------------------------------------------------------------------------
127 #-----------------------------------------------------------------------------
127 # Classes
128 # Classes
128 #-----------------------------------------------------------------------------
129 #-----------------------------------------------------------------------------
129
130
130 class SessionFactory(LoggingConfigurable):
131 class SessionFactory(LoggingConfigurable):
131 """The Base class for configurables that have a Session, Context, logger,
132 """The Base class for configurables that have a Session, Context, logger,
132 and IOLoop.
133 and IOLoop.
133 """
134 """
134
135
135 logname = Unicode('')
136 logname = Unicode('')
136 def _logname_changed(self, name, old, new):
137 def _logname_changed(self, name, old, new):
137 self.log = logging.getLogger(new)
138 self.log = logging.getLogger(new)
138
139
139 # not configurable:
140 # not configurable:
140 context = Instance('zmq.Context')
141 context = Instance('zmq.Context')
141 def _context_default(self):
142 def _context_default(self):
142 return zmq.Context.instance()
143 return zmq.Context.instance()
143
144
144 session = Instance('IPython.kernel.zmq.session.Session')
145 session = Instance('IPython.kernel.zmq.session.Session')
145
146
146 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
147 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
147 def _loop_default(self):
148 def _loop_default(self):
148 return IOLoop.instance()
149 return IOLoop.instance()
149
150
150 def __init__(self, **kwargs):
151 def __init__(self, **kwargs):
151 super(SessionFactory, self).__init__(**kwargs)
152 super(SessionFactory, self).__init__(**kwargs)
152
153
153 if self.session is None:
154 if self.session is None:
154 # construct the session
155 # construct the session
155 self.session = Session(**kwargs)
156 self.session = Session(**kwargs)
156
157
157
158
158 class Message(object):
159 class Message(object):
159 """A simple message object that maps dict keys to attributes.
160 """A simple message object that maps dict keys to attributes.
160
161
161 A Message can be created from a dict and a dict from a Message instance
162 A Message can be created from a dict and a dict from a Message instance
162 simply by calling dict(msg_obj)."""
163 simply by calling dict(msg_obj)."""
163
164
164 def __init__(self, msg_dict):
165 def __init__(self, msg_dict):
165 dct = self.__dict__
166 dct = self.__dict__
166 for k, v in dict(msg_dict).iteritems():
167 for k, v in dict(msg_dict).iteritems():
167 if isinstance(v, dict):
168 if isinstance(v, dict):
168 v = Message(v)
169 v = Message(v)
169 dct[k] = v
170 dct[k] = v
170
171
171 # Having this iterator lets dict(msg_obj) work out of the box.
172 # Having this iterator lets dict(msg_obj) work out of the box.
172 def __iter__(self):
173 def __iter__(self):
173 return iter(self.__dict__.iteritems())
174 return iter(self.__dict__.iteritems())
174
175
175 def __repr__(self):
176 def __repr__(self):
176 return repr(self.__dict__)
177 return repr(self.__dict__)
177
178
178 def __str__(self):
179 def __str__(self):
179 return pprint.pformat(self.__dict__)
180 return pprint.pformat(self.__dict__)
180
181
181 def __contains__(self, k):
182 def __contains__(self, k):
182 return k in self.__dict__
183 return k in self.__dict__
183
184
184 def __getitem__(self, k):
185 def __getitem__(self, k):
185 return self.__dict__[k]
186 return self.__dict__[k]
186
187
187
188
188 def msg_header(msg_id, msg_type, username, session):
189 def msg_header(msg_id, msg_type, username, session):
189 date = datetime.now()
190 date = datetime.now()
190 return locals()
191 return locals()
191
192
192 def extract_header(msg_or_header):
193 def extract_header(msg_or_header):
193 """Given a message or header, return the header."""
194 """Given a message or header, return the header."""
194 if not msg_or_header:
195 if not msg_or_header:
195 return {}
196 return {}
196 try:
197 try:
197 # See if msg_or_header is the entire message.
198 # See if msg_or_header is the entire message.
198 h = msg_or_header['header']
199 h = msg_or_header['header']
199 except KeyError:
200 except KeyError:
200 try:
201 try:
201 # See if msg_or_header is just the header
202 # See if msg_or_header is just the header
202 h = msg_or_header['msg_id']
203 h = msg_or_header['msg_id']
203 except KeyError:
204 except KeyError:
204 raise
205 raise
205 else:
206 else:
206 h = msg_or_header
207 h = msg_or_header
207 if not isinstance(h, dict):
208 if not isinstance(h, dict):
208 h = dict(h)
209 h = dict(h)
209 return h
210 return h
210
211
211 class Session(Configurable):
212 class Session(Configurable):
212 """Object for handling serialization and sending of messages.
213 """Object for handling serialization and sending of messages.
213
214
214 The Session object handles building messages and sending them
215 The Session object handles building messages and sending them
215 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
216 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
216 other over the network via Session objects, and only need to work with the
217 other over the network via Session objects, and only need to work with the
217 dict-based IPython message spec. The Session will handle
218 dict-based IPython message spec. The Session will handle
218 serialization/deserialization, security, and metadata.
219 serialization/deserialization, security, and metadata.
219
220
220 Sessions support configurable serialiization via packer/unpacker traits,
221 Sessions support configurable serialiization via packer/unpacker traits,
221 and signing with HMAC digests via the key/keyfile traits.
222 and signing with HMAC digests via the key/keyfile traits.
222
223
223 Parameters
224 Parameters
224 ----------
225 ----------
225
226
226 debug : bool
227 debug : bool
227 whether to trigger extra debugging statements
228 whether to trigger extra debugging statements
228 packer/unpacker : str : 'json', 'pickle' or import_string
229 packer/unpacker : str : 'json', 'pickle' or import_string
229 importstrings for methods to serialize message parts. If just
230 importstrings for methods to serialize message parts. If just
230 'json' or 'pickle', predefined JSON and pickle packers will be used.
231 'json' or 'pickle', predefined JSON and pickle packers will be used.
231 Otherwise, the entire importstring must be used.
232 Otherwise, the entire importstring must be used.
232
233
233 The functions must accept at least valid JSON input, and output *bytes*.
234 The functions must accept at least valid JSON input, and output *bytes*.
234
235
235 For example, to use msgpack:
236 For example, to use msgpack:
236 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
237 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
237 pack/unpack : callables
238 pack/unpack : callables
238 You can also set the pack/unpack callables for serialization directly.
239 You can also set the pack/unpack callables for serialization directly.
239 session : bytes
240 session : bytes
240 the ID of this Session object. The default is to generate a new UUID.
241 the ID of this Session object. The default is to generate a new UUID.
241 username : unicode
242 username : unicode
242 username added to message headers. The default is to ask the OS.
243 username added to message headers. The default is to ask the OS.
243 key : bytes
244 key : bytes
244 The key used to initialize an HMAC signature. If unset, messages
245 The key used to initialize an HMAC signature. If unset, messages
245 will not be signed or checked.
246 will not be signed or checked.
246 keyfile : filepath
247 keyfile : filepath
247 The file containing a key. If this is set, `key` will be initialized
248 The file containing a key. If this is set, `key` will be initialized
248 to the contents of the file.
249 to the contents of the file.
249
250
250 """
251 """
251
252
252 debug=Bool(False, config=True, help="""Debug output in the Session""")
253 debug=Bool(False, config=True, help="""Debug output in the Session""")
253
254
254 packer = DottedObjectName('json',config=True,
255 packer = DottedObjectName('json',config=True,
255 help="""The name of the packer for serializing messages.
256 help="""The name of the packer for serializing messages.
256 Should be one of 'json', 'pickle', or an import name
257 Should be one of 'json', 'pickle', or an import name
257 for a custom callable serializer.""")
258 for a custom callable serializer.""")
258 def _packer_changed(self, name, old, new):
259 def _packer_changed(self, name, old, new):
259 if new.lower() == 'json':
260 if new.lower() == 'json':
260 self.pack = json_packer
261 self.pack = json_packer
261 self.unpack = json_unpacker
262 self.unpack = json_unpacker
262 self.unpacker = new
263 self.unpacker = new
263 elif new.lower() == 'pickle':
264 elif new.lower() == 'pickle':
264 self.pack = pickle_packer
265 self.pack = pickle_packer
265 self.unpack = pickle_unpacker
266 self.unpack = pickle_unpacker
266 self.unpacker = new
267 self.unpacker = new
267 else:
268 else:
268 self.pack = import_item(str(new))
269 self.pack = import_item(str(new))
269
270
270 unpacker = DottedObjectName('json', config=True,
271 unpacker = DottedObjectName('json', config=True,
271 help="""The name of the unpacker for unserializing messages.
272 help="""The name of the unpacker for unserializing messages.
272 Only used with custom functions for `packer`.""")
273 Only used with custom functions for `packer`.""")
273 def _unpacker_changed(self, name, old, new):
274 def _unpacker_changed(self, name, old, new):
274 if new.lower() == 'json':
275 if new.lower() == 'json':
275 self.pack = json_packer
276 self.pack = json_packer
276 self.unpack = json_unpacker
277 self.unpack = json_unpacker
277 self.packer = new
278 self.packer = new
278 elif new.lower() == 'pickle':
279 elif new.lower() == 'pickle':
279 self.pack = pickle_packer
280 self.pack = pickle_packer
280 self.unpack = pickle_unpacker
281 self.unpack = pickle_unpacker
281 self.packer = new
282 self.packer = new
282 else:
283 else:
283 self.unpack = import_item(str(new))
284 self.unpack = import_item(str(new))
284
285
285 session = CUnicode(u'', config=True,
286 session = CUnicode(u'', config=True,
286 help="""The UUID identifying this session.""")
287 help="""The UUID identifying this session.""")
287 def _session_default(self):
288 def _session_default(self):
288 u = unicode(uuid.uuid4())
289 u = unicode(uuid.uuid4())
289 self.bsession = u.encode('ascii')
290 self.bsession = u.encode('ascii')
290 return u
291 return u
291
292
292 def _session_changed(self, name, old, new):
293 def _session_changed(self, name, old, new):
293 self.bsession = self.session.encode('ascii')
294 self.bsession = self.session.encode('ascii')
294
295
295 # bsession is the session as bytes
296 # bsession is the session as bytes
296 bsession = CBytes(b'')
297 bsession = CBytes(b'')
297
298
298 username = Unicode(os.environ.get('USER',u'username'), config=True,
299 username = Unicode(os.environ.get('USER',u'username'), config=True,
299 help="""Username for the Session. Default is your system username.""")
300 help="""Username for the Session. Default is your system username.""")
300
301
301 metadata = Dict({}, config=True,
302 metadata = Dict({}, config=True,
302 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
303 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
303
304
304 # message signature related traits:
305 # message signature related traits:
305
306
306 key = CBytes(b'', config=True,
307 key = CBytes(b'', config=True,
307 help="""execution key, for extra authentication.""")
308 help="""execution key, for extra authentication.""")
308 def _key_changed(self, name, old, new):
309 def _key_changed(self, name, old, new):
309 if new:
310 if new:
310 self.auth = hmac.HMAC(new)
311 self.auth = hmac.HMAC(new)
311 else:
312 else:
312 self.auth = None
313 self.auth = None
314
313 auth = Instance(hmac.HMAC)
315 auth = Instance(hmac.HMAC)
316
314 digest_history = Set()
317 digest_history = Set()
318 digest_history_size = Integer(2**16, config=True,
319 help="""The maximum number of digests to remember.
320
321 The digest history will be culled when it exceeds this value.
322 """
323 )
315
324
316 keyfile = Unicode('', config=True,
325 keyfile = Unicode('', config=True,
317 help="""path to file containing execution key.""")
326 help="""path to file containing execution key.""")
318 def _keyfile_changed(self, name, old, new):
327 def _keyfile_changed(self, name, old, new):
319 with open(new, 'rb') as f:
328 with open(new, 'rb') as f:
320 self.key = f.read().strip()
329 self.key = f.read().strip()
321
330
322 # for protecting against sends from forks
331 # for protecting against sends from forks
323 pid = Integer()
332 pid = Integer()
324
333
325 # serialization traits:
334 # serialization traits:
326
335
327 pack = Any(default_packer) # the actual packer function
336 pack = Any(default_packer) # the actual packer function
328 def _pack_changed(self, name, old, new):
337 def _pack_changed(self, name, old, new):
329 if not callable(new):
338 if not callable(new):
330 raise TypeError("packer must be callable, not %s"%type(new))
339 raise TypeError("packer must be callable, not %s"%type(new))
331
340
332 unpack = Any(default_unpacker) # the actual packer function
341 unpack = Any(default_unpacker) # the actual packer function
333 def _unpack_changed(self, name, old, new):
342 def _unpack_changed(self, name, old, new):
334 # unpacker is not checked - it is assumed to be
343 # unpacker is not checked - it is assumed to be
335 if not callable(new):
344 if not callable(new):
336 raise TypeError("unpacker must be callable, not %s"%type(new))
345 raise TypeError("unpacker must be callable, not %s"%type(new))
337
346
338 # thresholds:
347 # thresholds:
339 copy_threshold = Integer(2**16, config=True,
348 copy_threshold = Integer(2**16, config=True,
340 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
349 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
341 buffer_threshold = Integer(MAX_BYTES, config=True,
350 buffer_threshold = Integer(MAX_BYTES, config=True,
342 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
351 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
343 item_threshold = Integer(MAX_ITEMS, config=True,
352 item_threshold = Integer(MAX_ITEMS, config=True,
344 help="""The maximum number of items for a container to be introspected for custom serialization.
353 help="""The maximum number of items for a container to be introspected for custom serialization.
345 Containers larger than this are pickled outright.
354 Containers larger than this are pickled outright.
346 """
355 """
347 )
356 )
348
357
349
358
350 def __init__(self, **kwargs):
359 def __init__(self, **kwargs):
351 """create a Session object
360 """create a Session object
352
361
353 Parameters
362 Parameters
354 ----------
363 ----------
355
364
356 debug : bool
365 debug : bool
357 whether to trigger extra debugging statements
366 whether to trigger extra debugging statements
358 packer/unpacker : str : 'json', 'pickle' or import_string
367 packer/unpacker : str : 'json', 'pickle' or import_string
359 importstrings for methods to serialize message parts. If just
368 importstrings for methods to serialize message parts. If just
360 'json' or 'pickle', predefined JSON and pickle packers will be used.
369 'json' or 'pickle', predefined JSON and pickle packers will be used.
361 Otherwise, the entire importstring must be used.
370 Otherwise, the entire importstring must be used.
362
371
363 The functions must accept at least valid JSON input, and output
372 The functions must accept at least valid JSON input, and output
364 *bytes*.
373 *bytes*.
365
374
366 For example, to use msgpack:
375 For example, to use msgpack:
367 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
376 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
368 pack/unpack : callables
377 pack/unpack : callables
369 You can also set the pack/unpack callables for serialization
378 You can also set the pack/unpack callables for serialization
370 directly.
379 directly.
371 session : unicode (must be ascii)
380 session : unicode (must be ascii)
372 the ID of this Session object. The default is to generate a new
381 the ID of this Session object. The default is to generate a new
373 UUID.
382 UUID.
374 bsession : bytes
383 bsession : bytes
375 The session as bytes
384 The session as bytes
376 username : unicode
385 username : unicode
377 username added to message headers. The default is to ask the OS.
386 username added to message headers. The default is to ask the OS.
378 key : bytes
387 key : bytes
379 The key used to initialize an HMAC signature. If unset, messages
388 The key used to initialize an HMAC signature. If unset, messages
380 will not be signed or checked.
389 will not be signed or checked.
381 keyfile : filepath
390 keyfile : filepath
382 The file containing a key. If this is set, `key` will be
391 The file containing a key. If this is set, `key` will be
383 initialized to the contents of the file.
392 initialized to the contents of the file.
384 """
393 """
385 super(Session, self).__init__(**kwargs)
394 super(Session, self).__init__(**kwargs)
386 self._check_packers()
395 self._check_packers()
387 self.none = self.pack({})
396 self.none = self.pack({})
388 # ensure self._session_default() if necessary, so bsession is defined:
397 # ensure self._session_default() if necessary, so bsession is defined:
389 self.session
398 self.session
390 self.pid = os.getpid()
399 self.pid = os.getpid()
391
400
392 @property
401 @property
393 def msg_id(self):
402 def msg_id(self):
394 """always return new uuid"""
403 """always return new uuid"""
395 return str(uuid.uuid4())
404 return str(uuid.uuid4())
396
405
397 def _check_packers(self):
406 def _check_packers(self):
398 """check packers for binary data and datetime support."""
407 """check packers for binary data and datetime support."""
399 pack = self.pack
408 pack = self.pack
400 unpack = self.unpack
409 unpack = self.unpack
401
410
402 # check simple serialization
411 # check simple serialization
403 msg = dict(a=[1,'hi'])
412 msg = dict(a=[1,'hi'])
404 try:
413 try:
405 packed = pack(msg)
414 packed = pack(msg)
406 except Exception:
415 except Exception:
407 raise ValueError("packer could not serialize a simple message")
416 raise ValueError("packer could not serialize a simple message")
408
417
409 # ensure packed message is bytes
418 # ensure packed message is bytes
410 if not isinstance(packed, bytes):
419 if not isinstance(packed, bytes):
411 raise ValueError("message packed to %r, but bytes are required"%type(packed))
420 raise ValueError("message packed to %r, but bytes are required"%type(packed))
412
421
413 # check that unpack is pack's inverse
422 # check that unpack is pack's inverse
414 try:
423 try:
415 unpacked = unpack(packed)
424 unpacked = unpack(packed)
416 except Exception:
425 except Exception:
417 raise ValueError("unpacker could not handle the packer's output")
426 raise ValueError("unpacker could not handle the packer's output")
418
427
419 # check datetime support
428 # check datetime support
420 msg = dict(t=datetime.now())
429 msg = dict(t=datetime.now())
421 try:
430 try:
422 unpacked = unpack(pack(msg))
431 unpacked = unpack(pack(msg))
423 except Exception:
432 except Exception:
424 self.pack = lambda o: pack(squash_dates(o))
433 self.pack = lambda o: pack(squash_dates(o))
425 self.unpack = lambda s: extract_dates(unpack(s))
434 self.unpack = lambda s: extract_dates(unpack(s))
426
435
427 def msg_header(self, msg_type):
436 def msg_header(self, msg_type):
428 return msg_header(self.msg_id, msg_type, self.username, self.session)
437 return msg_header(self.msg_id, msg_type, self.username, self.session)
429
438
430 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
439 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
431 """Return the nested message dict.
440 """Return the nested message dict.
432
441
433 This format is different from what is sent over the wire. The
442 This format is different from what is sent over the wire. The
434 serialize/unserialize methods converts this nested message dict to the wire
443 serialize/unserialize methods converts this nested message dict to the wire
435 format, which is a list of message parts.
444 format, which is a list of message parts.
436 """
445 """
437 msg = {}
446 msg = {}
438 header = self.msg_header(msg_type) if header is None else header
447 header = self.msg_header(msg_type) if header is None else header
439 msg['header'] = header
448 msg['header'] = header
440 msg['msg_id'] = header['msg_id']
449 msg['msg_id'] = header['msg_id']
441 msg['msg_type'] = header['msg_type']
450 msg['msg_type'] = header['msg_type']
442 msg['parent_header'] = {} if parent is None else extract_header(parent)
451 msg['parent_header'] = {} if parent is None else extract_header(parent)
443 msg['content'] = {} if content is None else content
452 msg['content'] = {} if content is None else content
444 msg['metadata'] = self.metadata.copy()
453 msg['metadata'] = self.metadata.copy()
445 if metadata is not None:
454 if metadata is not None:
446 msg['metadata'].update(metadata)
455 msg['metadata'].update(metadata)
447 return msg
456 return msg
448
457
449 def sign(self, msg_list):
458 def sign(self, msg_list):
450 """Sign a message with HMAC digest. If no auth, return b''.
459 """Sign a message with HMAC digest. If no auth, return b''.
451
460
452 Parameters
461 Parameters
453 ----------
462 ----------
454 msg_list : list
463 msg_list : list
455 The [p_header,p_parent,p_content] part of the message list.
464 The [p_header,p_parent,p_content] part of the message list.
456 """
465 """
457 if self.auth is None:
466 if self.auth is None:
458 return b''
467 return b''
459 h = self.auth.copy()
468 h = self.auth.copy()
460 for m in msg_list:
469 for m in msg_list:
461 h.update(m)
470 h.update(m)
462 return str_to_bytes(h.hexdigest())
471 return str_to_bytes(h.hexdigest())
463
472
464 def serialize(self, msg, ident=None):
473 def serialize(self, msg, ident=None):
465 """Serialize the message components to bytes.
474 """Serialize the message components to bytes.
466
475
467 This is roughly the inverse of unserialize. The serialize/unserialize
476 This is roughly the inverse of unserialize. The serialize/unserialize
468 methods work with full message lists, whereas pack/unpack work with
477 methods work with full message lists, whereas pack/unpack work with
469 the individual message parts in the message list.
478 the individual message parts in the message list.
470
479
471 Parameters
480 Parameters
472 ----------
481 ----------
473 msg : dict or Message
482 msg : dict or Message
474 The nexted message dict as returned by the self.msg method.
483 The nexted message dict as returned by the self.msg method.
475
484
476 Returns
485 Returns
477 -------
486 -------
478 msg_list : list
487 msg_list : list
479 The list of bytes objects to be sent with the format:
488 The list of bytes objects to be sent with the format:
480 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
489 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
481 buffer1,buffer2,...]. In this list, the p_* entities are
490 buffer1,buffer2,...]. In this list, the p_* entities are
482 the packed or serialized versions, so if JSON is used, these
491 the packed or serialized versions, so if JSON is used, these
483 are utf8 encoded JSON strings.
492 are utf8 encoded JSON strings.
484 """
493 """
485 content = msg.get('content', {})
494 content = msg.get('content', {})
486 if content is None:
495 if content is None:
487 content = self.none
496 content = self.none
488 elif isinstance(content, dict):
497 elif isinstance(content, dict):
489 content = self.pack(content)
498 content = self.pack(content)
490 elif isinstance(content, bytes):
499 elif isinstance(content, bytes):
491 # content is already packed, as in a relayed message
500 # content is already packed, as in a relayed message
492 pass
501 pass
493 elif isinstance(content, unicode):
502 elif isinstance(content, unicode):
494 # should be bytes, but JSON often spits out unicode
503 # should be bytes, but JSON often spits out unicode
495 content = content.encode('utf8')
504 content = content.encode('utf8')
496 else:
505 else:
497 raise TypeError("Content incorrect type: %s"%type(content))
506 raise TypeError("Content incorrect type: %s"%type(content))
498
507
499 real_message = [self.pack(msg['header']),
508 real_message = [self.pack(msg['header']),
500 self.pack(msg['parent_header']),
509 self.pack(msg['parent_header']),
501 self.pack(msg['metadata']),
510 self.pack(msg['metadata']),
502 content,
511 content,
503 ]
512 ]
504
513
505 to_send = []
514 to_send = []
506
515
507 if isinstance(ident, list):
516 if isinstance(ident, list):
508 # accept list of idents
517 # accept list of idents
509 to_send.extend(ident)
518 to_send.extend(ident)
510 elif ident is not None:
519 elif ident is not None:
511 to_send.append(ident)
520 to_send.append(ident)
512 to_send.append(DELIM)
521 to_send.append(DELIM)
513
522
514 signature = self.sign(real_message)
523 signature = self.sign(real_message)
515 to_send.append(signature)
524 to_send.append(signature)
516
525
517 to_send.extend(real_message)
526 to_send.extend(real_message)
518
527
519 return to_send
528 return to_send
520
529
521 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
530 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
522 buffers=None, track=False, header=None, metadata=None):
531 buffers=None, track=False, header=None, metadata=None):
523 """Build and send a message via stream or socket.
532 """Build and send a message via stream or socket.
524
533
525 The message format used by this function internally is as follows:
534 The message format used by this function internally is as follows:
526
535
527 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
536 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
528 buffer1,buffer2,...]
537 buffer1,buffer2,...]
529
538
530 The serialize/unserialize methods convert the nested message dict into this
539 The serialize/unserialize methods convert the nested message dict into this
531 format.
540 format.
532
541
533 Parameters
542 Parameters
534 ----------
543 ----------
535
544
536 stream : zmq.Socket or ZMQStream
545 stream : zmq.Socket or ZMQStream
537 The socket-like object used to send the data.
546 The socket-like object used to send the data.
538 msg_or_type : str or Message/dict
547 msg_or_type : str or Message/dict
539 Normally, msg_or_type will be a msg_type unless a message is being
548 Normally, msg_or_type will be a msg_type unless a message is being
540 sent more than once. If a header is supplied, this can be set to
549 sent more than once. If a header is supplied, this can be set to
541 None and the msg_type will be pulled from the header.
550 None and the msg_type will be pulled from the header.
542
551
543 content : dict or None
552 content : dict or None
544 The content of the message (ignored if msg_or_type is a message).
553 The content of the message (ignored if msg_or_type is a message).
545 header : dict or None
554 header : dict or None
546 The header dict for the message (ignored if msg_to_type is a message).
555 The header dict for the message (ignored if msg_to_type is a message).
547 parent : Message or dict or None
556 parent : Message or dict or None
548 The parent or parent header describing the parent of this message
557 The parent or parent header describing the parent of this message
549 (ignored if msg_or_type is a message).
558 (ignored if msg_or_type is a message).
550 ident : bytes or list of bytes
559 ident : bytes or list of bytes
551 The zmq.IDENTITY routing path.
560 The zmq.IDENTITY routing path.
552 metadata : dict or None
561 metadata : dict or None
553 The metadata describing the message
562 The metadata describing the message
554 buffers : list or None
563 buffers : list or None
555 The already-serialized buffers to be appended to the message.
564 The already-serialized buffers to be appended to the message.
556 track : bool
565 track : bool
557 Whether to track. Only for use with Sockets, because ZMQStream
566 Whether to track. Only for use with Sockets, because ZMQStream
558 objects cannot track messages.
567 objects cannot track messages.
559
568
560
569
561 Returns
570 Returns
562 -------
571 -------
563 msg : dict
572 msg : dict
564 The constructed message.
573 The constructed message.
565 """
574 """
566 if not isinstance(stream, zmq.Socket):
575 if not isinstance(stream, zmq.Socket):
567 # ZMQStreams and dummy sockets do not support tracking.
576 # ZMQStreams and dummy sockets do not support tracking.
568 track = False
577 track = False
569
578
570 if isinstance(msg_or_type, (Message, dict)):
579 if isinstance(msg_or_type, (Message, dict)):
571 # We got a Message or message dict, not a msg_type so don't
580 # We got a Message or message dict, not a msg_type so don't
572 # build a new Message.
581 # build a new Message.
573 msg = msg_or_type
582 msg = msg_or_type
574 else:
583 else:
575 msg = self.msg(msg_or_type, content=content, parent=parent,
584 msg = self.msg(msg_or_type, content=content, parent=parent,
576 header=header, metadata=metadata)
585 header=header, metadata=metadata)
577 if not os.getpid() == self.pid:
586 if not os.getpid() == self.pid:
578 io.rprint("WARNING: attempted to send message from fork")
587 io.rprint("WARNING: attempted to send message from fork")
579 io.rprint(msg)
588 io.rprint(msg)
580 return
589 return
581 buffers = [] if buffers is None else buffers
590 buffers = [] if buffers is None else buffers
582 to_send = self.serialize(msg, ident)
591 to_send = self.serialize(msg, ident)
583 to_send.extend(buffers)
592 to_send.extend(buffers)
584 longest = max([ len(s) for s in to_send ])
593 longest = max([ len(s) for s in to_send ])
585 copy = (longest < self.copy_threshold)
594 copy = (longest < self.copy_threshold)
586
595
587 if buffers and track and not copy:
596 if buffers and track and not copy:
588 # only really track when we are doing zero-copy buffers
597 # only really track when we are doing zero-copy buffers
589 tracker = stream.send_multipart(to_send, copy=False, track=True)
598 tracker = stream.send_multipart(to_send, copy=False, track=True)
590 else:
599 else:
591 # use dummy tracker, which will be done immediately
600 # use dummy tracker, which will be done immediately
592 tracker = DONE
601 tracker = DONE
593 stream.send_multipart(to_send, copy=copy)
602 stream.send_multipart(to_send, copy=copy)
594
603
595 if self.debug:
604 if self.debug:
596 pprint.pprint(msg)
605 pprint.pprint(msg)
597 pprint.pprint(to_send)
606 pprint.pprint(to_send)
598 pprint.pprint(buffers)
607 pprint.pprint(buffers)
599
608
600 msg['tracker'] = tracker
609 msg['tracker'] = tracker
601
610
602 return msg
611 return msg
603
612
604 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
613 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
605 """Send a raw message via ident path.
614 """Send a raw message via ident path.
606
615
607 This method is used to send a already serialized message.
616 This method is used to send a already serialized message.
608
617
609 Parameters
618 Parameters
610 ----------
619 ----------
611 stream : ZMQStream or Socket
620 stream : ZMQStream or Socket
612 The ZMQ stream or socket to use for sending the message.
621 The ZMQ stream or socket to use for sending the message.
613 msg_list : list
622 msg_list : list
614 The serialized list of messages to send. This only includes the
623 The serialized list of messages to send. This only includes the
615 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
624 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
616 the message.
625 the message.
617 ident : ident or list
626 ident : ident or list
618 A single ident or a list of idents to use in sending.
627 A single ident or a list of idents to use in sending.
619 """
628 """
620 to_send = []
629 to_send = []
621 if isinstance(ident, bytes):
630 if isinstance(ident, bytes):
622 ident = [ident]
631 ident = [ident]
623 if ident is not None:
632 if ident is not None:
624 to_send.extend(ident)
633 to_send.extend(ident)
625
634
626 to_send.append(DELIM)
635 to_send.append(DELIM)
627 to_send.append(self.sign(msg_list))
636 to_send.append(self.sign(msg_list))
628 to_send.extend(msg_list)
637 to_send.extend(msg_list)
629 stream.send_multipart(msg_list, flags, copy=copy)
638 stream.send_multipart(msg_list, flags, copy=copy)
630
639
631 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
640 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
632 """Receive and unpack a message.
641 """Receive and unpack a message.
633
642
634 Parameters
643 Parameters
635 ----------
644 ----------
636 socket : ZMQStream or Socket
645 socket : ZMQStream or Socket
637 The socket or stream to use in receiving.
646 The socket or stream to use in receiving.
638
647
639 Returns
648 Returns
640 -------
649 -------
641 [idents], msg
650 [idents], msg
642 [idents] is a list of idents and msg is a nested message dict of
651 [idents] is a list of idents and msg is a nested message dict of
643 same format as self.msg returns.
652 same format as self.msg returns.
644 """
653 """
645 if isinstance(socket, ZMQStream):
654 if isinstance(socket, ZMQStream):
646 socket = socket.socket
655 socket = socket.socket
647 try:
656 try:
648 msg_list = socket.recv_multipart(mode, copy=copy)
657 msg_list = socket.recv_multipart(mode, copy=copy)
649 except zmq.ZMQError as e:
658 except zmq.ZMQError as e:
650 if e.errno == zmq.EAGAIN:
659 if e.errno == zmq.EAGAIN:
651 # We can convert EAGAIN to None as we know in this case
660 # We can convert EAGAIN to None as we know in this case
652 # recv_multipart won't return None.
661 # recv_multipart won't return None.
653 return None,None
662 return None,None
654 else:
663 else:
655 raise
664 raise
656 # split multipart message into identity list and message dict
665 # split multipart message into identity list and message dict
657 # invalid large messages can cause very expensive string comparisons
666 # invalid large messages can cause very expensive string comparisons
658 idents, msg_list = self.feed_identities(msg_list, copy)
667 idents, msg_list = self.feed_identities(msg_list, copy)
659 try:
668 try:
660 return idents, self.unserialize(msg_list, content=content, copy=copy)
669 return idents, self.unserialize(msg_list, content=content, copy=copy)
661 except Exception as e:
670 except Exception as e:
662 # TODO: handle it
671 # TODO: handle it
663 raise e
672 raise e
664
673
665 def feed_identities(self, msg_list, copy=True):
674 def feed_identities(self, msg_list, copy=True):
666 """Split the identities from the rest of the message.
675 """Split the identities from the rest of the message.
667
676
668 Feed until DELIM is reached, then return the prefix as idents and
677 Feed until DELIM is reached, then return the prefix as idents and
669 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
678 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
670 but that would be silly.
679 but that would be silly.
671
680
672 Parameters
681 Parameters
673 ----------
682 ----------
674 msg_list : a list of Message or bytes objects
683 msg_list : a list of Message or bytes objects
675 The message to be split.
684 The message to be split.
676 copy : bool
685 copy : bool
677 flag determining whether the arguments are bytes or Messages
686 flag determining whether the arguments are bytes or Messages
678
687
679 Returns
688 Returns
680 -------
689 -------
681 (idents, msg_list) : two lists
690 (idents, msg_list) : two lists
682 idents will always be a list of bytes, each of which is a ZMQ
691 idents will always be a list of bytes, each of which is a ZMQ
683 identity. msg_list will be a list of bytes or zmq.Messages of the
692 identity. msg_list will be a list of bytes or zmq.Messages of the
684 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
693 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
685 should be unpackable/unserializable via self.unserialize at this
694 should be unpackable/unserializable via self.unserialize at this
686 point.
695 point.
687 """
696 """
688 if copy:
697 if copy:
689 idx = msg_list.index(DELIM)
698 idx = msg_list.index(DELIM)
690 return msg_list[:idx], msg_list[idx+1:]
699 return msg_list[:idx], msg_list[idx+1:]
691 else:
700 else:
692 failed = True
701 failed = True
693 for idx,m in enumerate(msg_list):
702 for idx,m in enumerate(msg_list):
694 if m.bytes == DELIM:
703 if m.bytes == DELIM:
695 failed = False
704 failed = False
696 break
705 break
697 if failed:
706 if failed:
698 raise ValueError("DELIM not in msg_list")
707 raise ValueError("DELIM not in msg_list")
699 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
708 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
700 return [m.bytes for m in idents], msg_list
709 return [m.bytes for m in idents], msg_list
701
710
711 def _add_digest(self, signature):
712 """add a digest to history to protect against replay attacks"""
713 if self.digest_history_size == 0:
714 # no history, never add digests
715 return
716
717 self.digest_history.add(signature)
718 if len(self.digest_history) > self.digest_history_size:
719 # threshold reached, cull 10%
720 self._cull_digest_history()
721
722 def _cull_digest_history(self):
723 """cull the digest history
724
725 Removes a randomly selected 10% of the digest history
726 """
727 current = len(self.digest_history)
728 n_to_cull = max(int(current // 10), current - self.digest_history_size)
729 if n_to_cull >= current:
730 self.digest_history = set()
731 return
732 to_cull = random.sample(self.digest_history, n_to_cull)
733 self.digest_history.difference_update(to_cull)
734
702 def unserialize(self, msg_list, content=True, copy=True):
735 def unserialize(self, msg_list, content=True, copy=True):
703 """Unserialize a msg_list to a nested message dict.
736 """Unserialize a msg_list to a nested message dict.
704
737
705 This is roughly the inverse of serialize. The serialize/unserialize
738 This is roughly the inverse of serialize. The serialize/unserialize
706 methods work with full message lists, whereas pack/unpack work with
739 methods work with full message lists, whereas pack/unpack work with
707 the individual message parts in the message list.
740 the individual message parts in the message list.
708
741
709 Parameters:
742 Parameters:
710 -----------
743 -----------
711 msg_list : list of bytes or Message objects
744 msg_list : list of bytes or Message objects
712 The list of message parts of the form [HMAC,p_header,p_parent,
745 The list of message parts of the form [HMAC,p_header,p_parent,
713 p_metadata,p_content,buffer1,buffer2,...].
746 p_metadata,p_content,buffer1,buffer2,...].
714 content : bool (True)
747 content : bool (True)
715 Whether to unpack the content dict (True), or leave it packed
748 Whether to unpack the content dict (True), or leave it packed
716 (False).
749 (False).
717 copy : bool (True)
750 copy : bool (True)
718 Whether to return the bytes (True), or the non-copying Message
751 Whether to return the bytes (True), or the non-copying Message
719 object in each place (False).
752 object in each place (False).
720
753
721 Returns
754 Returns
722 -------
755 -------
723 msg : dict
756 msg : dict
724 The nested message dict with top-level keys [header, parent_header,
757 The nested message dict with top-level keys [header, parent_header,
725 content, buffers].
758 content, buffers].
726 """
759 """
727 minlen = 5
760 minlen = 5
728 message = {}
761 message = {}
729 if not copy:
762 if not copy:
730 for i in range(minlen):
763 for i in range(minlen):
731 msg_list[i] = msg_list[i].bytes
764 msg_list[i] = msg_list[i].bytes
732 if self.auth is not None:
765 if self.auth is not None:
733 signature = msg_list[0]
766 signature = msg_list[0]
734 if not signature:
767 if not signature:
735 raise ValueError("Unsigned Message")
768 raise ValueError("Unsigned Message")
736 if signature in self.digest_history:
769 if signature in self.digest_history:
737 raise ValueError("Duplicate Signature: %r"%signature)
770 raise ValueError("Duplicate Signature: %r" % signature)
738 self.digest_history.add(signature)
771 self._add_digest(signature)
739 check = self.sign(msg_list[1:5])
772 check = self.sign(msg_list[1:5])
740 if not signature == check:
773 if not signature == check:
741 raise ValueError("Invalid Signature: %r" % signature)
774 raise ValueError("Invalid Signature: %r" % signature)
742 if not len(msg_list) >= minlen:
775 if not len(msg_list) >= minlen:
743 raise TypeError("malformed message, must have at least %i elements"%minlen)
776 raise TypeError("malformed message, must have at least %i elements"%minlen)
744 header = self.unpack(msg_list[1])
777 header = self.unpack(msg_list[1])
745 message['header'] = header
778 message['header'] = header
746 message['msg_id'] = header['msg_id']
779 message['msg_id'] = header['msg_id']
747 message['msg_type'] = header['msg_type']
780 message['msg_type'] = header['msg_type']
748 message['parent_header'] = self.unpack(msg_list[2])
781 message['parent_header'] = self.unpack(msg_list[2])
749 message['metadata'] = self.unpack(msg_list[3])
782 message['metadata'] = self.unpack(msg_list[3])
750 if content:
783 if content:
751 message['content'] = self.unpack(msg_list[4])
784 message['content'] = self.unpack(msg_list[4])
752 else:
785 else:
753 message['content'] = msg_list[4]
786 message['content'] = msg_list[4]
754
787
755 message['buffers'] = msg_list[5:]
788 message['buffers'] = msg_list[5:]
756 return message
789 return message
757
790
758 def test_msg2obj():
791 def test_msg2obj():
759 am = dict(x=1)
792 am = dict(x=1)
760 ao = Message(am)
793 ao = Message(am)
761 assert ao.x == am['x']
794 assert ao.x == am['x']
762
795
763 am['y'] = dict(z=1)
796 am['y'] = dict(z=1)
764 ao = Message(am)
797 ao = Message(am)
765 assert ao.y.z == am['y']['z']
798 assert ao.y.z == am['y']['z']
766
799
767 k1, k2 = 'y', 'z'
800 k1, k2 = 'y', 'z'
768 assert ao[k1][k2] == am[k1][k2]
801 assert ao[k1][k2] == am[k1][k2]
769
802
770 am2 = dict(ao)
803 am2 = dict(ao)
771 assert am['x'] == am2['x']
804 assert am['x'] == am2['x']
772 assert am['y']['z'] == am2['y']['z']
805 assert am['y']['z'] == am2['y']['z']
773
806
@@ -1,207 +1,225 b''
1 """test building messages with streamsession"""
1 """test building messages with streamsession"""
2
2
3 #-------------------------------------------------------------------------------
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
4 # Copyright (C) 2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14 import os
14 import os
15 import uuid
15 import uuid
16 import zmq
16 import zmq
17
17
18 from zmq.tests import BaseZMQTestCase
18 from zmq.tests import BaseZMQTestCase
19 from zmq.eventloop.zmqstream import ZMQStream
19 from zmq.eventloop.zmqstream import ZMQStream
20
20
21 from IPython.kernel.zmq import session as ss
21 from IPython.kernel.zmq import session as ss
22
22
23 class SessionTestCase(BaseZMQTestCase):
23 class SessionTestCase(BaseZMQTestCase):
24
24
25 def setUp(self):
25 def setUp(self):
26 BaseZMQTestCase.setUp(self)
26 BaseZMQTestCase.setUp(self)
27 self.session = ss.Session()
27 self.session = ss.Session()
28
28
29
29
30 class TestSession(SessionTestCase):
30 class TestSession(SessionTestCase):
31
31
32 def test_msg(self):
32 def test_msg(self):
33 """message format"""
33 """message format"""
34 msg = self.session.msg('execute')
34 msg = self.session.msg('execute')
35 thekeys = set('header parent_header metadata content msg_type msg_id'.split())
35 thekeys = set('header parent_header metadata content msg_type msg_id'.split())
36 s = set(msg.keys())
36 s = set(msg.keys())
37 self.assertEqual(s, thekeys)
37 self.assertEqual(s, thekeys)
38 self.assertTrue(isinstance(msg['content'],dict))
38 self.assertTrue(isinstance(msg['content'],dict))
39 self.assertTrue(isinstance(msg['metadata'],dict))
39 self.assertTrue(isinstance(msg['metadata'],dict))
40 self.assertTrue(isinstance(msg['header'],dict))
40 self.assertTrue(isinstance(msg['header'],dict))
41 self.assertTrue(isinstance(msg['parent_header'],dict))
41 self.assertTrue(isinstance(msg['parent_header'],dict))
42 self.assertTrue(isinstance(msg['msg_id'],str))
42 self.assertTrue(isinstance(msg['msg_id'],str))
43 self.assertTrue(isinstance(msg['msg_type'],str))
43 self.assertTrue(isinstance(msg['msg_type'],str))
44 self.assertEqual(msg['header']['msg_type'], 'execute')
44 self.assertEqual(msg['header']['msg_type'], 'execute')
45 self.assertEqual(msg['msg_type'], 'execute')
45 self.assertEqual(msg['msg_type'], 'execute')
46
46
47 def test_serialize(self):
47 def test_serialize(self):
48 msg = self.session.msg('execute', content=dict(a=10, b=1.1))
48 msg = self.session.msg('execute', content=dict(a=10, b=1.1))
49 msg_list = self.session.serialize(msg, ident=b'foo')
49 msg_list = self.session.serialize(msg, ident=b'foo')
50 ident, msg_list = self.session.feed_identities(msg_list)
50 ident, msg_list = self.session.feed_identities(msg_list)
51 new_msg = self.session.unserialize(msg_list)
51 new_msg = self.session.unserialize(msg_list)
52 self.assertEqual(ident[0], b'foo')
52 self.assertEqual(ident[0], b'foo')
53 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
53 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
54 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
54 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
55 self.assertEqual(new_msg['header'],msg['header'])
55 self.assertEqual(new_msg['header'],msg['header'])
56 self.assertEqual(new_msg['content'],msg['content'])
56 self.assertEqual(new_msg['content'],msg['content'])
57 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
57 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
58 self.assertEqual(new_msg['metadata'],msg['metadata'])
58 self.assertEqual(new_msg['metadata'],msg['metadata'])
59 # ensure floats don't come out as Decimal:
59 # ensure floats don't come out as Decimal:
60 self.assertEqual(type(new_msg['content']['b']),type(new_msg['content']['b']))
60 self.assertEqual(type(new_msg['content']['b']),type(new_msg['content']['b']))
61
61
62 def test_send(self):
62 def test_send(self):
63 ctx = zmq.Context.instance()
63 ctx = zmq.Context.instance()
64 A = ctx.socket(zmq.PAIR)
64 A = ctx.socket(zmq.PAIR)
65 B = ctx.socket(zmq.PAIR)
65 B = ctx.socket(zmq.PAIR)
66 A.bind("inproc://test")
66 A.bind("inproc://test")
67 B.connect("inproc://test")
67 B.connect("inproc://test")
68
68
69 msg = self.session.msg('execute', content=dict(a=10))
69 msg = self.session.msg('execute', content=dict(a=10))
70 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
70 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
71
71
72 ident, msg_list = self.session.feed_identities(B.recv_multipart())
72 ident, msg_list = self.session.feed_identities(B.recv_multipart())
73 new_msg = self.session.unserialize(msg_list)
73 new_msg = self.session.unserialize(msg_list)
74 self.assertEqual(ident[0], b'foo')
74 self.assertEqual(ident[0], b'foo')
75 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
75 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
76 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
76 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
77 self.assertEqual(new_msg['header'],msg['header'])
77 self.assertEqual(new_msg['header'],msg['header'])
78 self.assertEqual(new_msg['content'],msg['content'])
78 self.assertEqual(new_msg['content'],msg['content'])
79 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
79 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
80 self.assertEqual(new_msg['metadata'],msg['metadata'])
80 self.assertEqual(new_msg['metadata'],msg['metadata'])
81 self.assertEqual(new_msg['buffers'],[b'bar'])
81 self.assertEqual(new_msg['buffers'],[b'bar'])
82
82
83 content = msg['content']
83 content = msg['content']
84 header = msg['header']
84 header = msg['header']
85 parent = msg['parent_header']
85 parent = msg['parent_header']
86 metadata = msg['metadata']
86 metadata = msg['metadata']
87 msg_type = header['msg_type']
87 msg_type = header['msg_type']
88 self.session.send(A, None, content=content, parent=parent,
88 self.session.send(A, None, content=content, parent=parent,
89 header=header, metadata=metadata, ident=b'foo', buffers=[b'bar'])
89 header=header, metadata=metadata, ident=b'foo', buffers=[b'bar'])
90 ident, msg_list = self.session.feed_identities(B.recv_multipart())
90 ident, msg_list = self.session.feed_identities(B.recv_multipart())
91 new_msg = self.session.unserialize(msg_list)
91 new_msg = self.session.unserialize(msg_list)
92 self.assertEqual(ident[0], b'foo')
92 self.assertEqual(ident[0], b'foo')
93 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
93 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
94 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
94 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
95 self.assertEqual(new_msg['header'],msg['header'])
95 self.assertEqual(new_msg['header'],msg['header'])
96 self.assertEqual(new_msg['content'],msg['content'])
96 self.assertEqual(new_msg['content'],msg['content'])
97 self.assertEqual(new_msg['metadata'],msg['metadata'])
97 self.assertEqual(new_msg['metadata'],msg['metadata'])
98 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
98 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
99 self.assertEqual(new_msg['buffers'],[b'bar'])
99 self.assertEqual(new_msg['buffers'],[b'bar'])
100
100
101 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
101 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
102 ident, new_msg = self.session.recv(B)
102 ident, new_msg = self.session.recv(B)
103 self.assertEqual(ident[0], b'foo')
103 self.assertEqual(ident[0], b'foo')
104 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
104 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
105 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
105 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
106 self.assertEqual(new_msg['header'],msg['header'])
106 self.assertEqual(new_msg['header'],msg['header'])
107 self.assertEqual(new_msg['content'],msg['content'])
107 self.assertEqual(new_msg['content'],msg['content'])
108 self.assertEqual(new_msg['metadata'],msg['metadata'])
108 self.assertEqual(new_msg['metadata'],msg['metadata'])
109 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
109 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
110 self.assertEqual(new_msg['buffers'],[b'bar'])
110 self.assertEqual(new_msg['buffers'],[b'bar'])
111
111
112 A.close()
112 A.close()
113 B.close()
113 B.close()
114 ctx.term()
114 ctx.term()
115
115
116 def test_args(self):
116 def test_args(self):
117 """initialization arguments for Session"""
117 """initialization arguments for Session"""
118 s = self.session
118 s = self.session
119 self.assertTrue(s.pack is ss.default_packer)
119 self.assertTrue(s.pack is ss.default_packer)
120 self.assertTrue(s.unpack is ss.default_unpacker)
120 self.assertTrue(s.unpack is ss.default_unpacker)
121 self.assertEqual(s.username, os.environ.get('USER', u'username'))
121 self.assertEqual(s.username, os.environ.get('USER', u'username'))
122
122
123 s = ss.Session()
123 s = ss.Session()
124 self.assertEqual(s.username, os.environ.get('USER', u'username'))
124 self.assertEqual(s.username, os.environ.get('USER', u'username'))
125
125
126 self.assertRaises(TypeError, ss.Session, pack='hi')
126 self.assertRaises(TypeError, ss.Session, pack='hi')
127 self.assertRaises(TypeError, ss.Session, unpack='hi')
127 self.assertRaises(TypeError, ss.Session, unpack='hi')
128 u = str(uuid.uuid4())
128 u = str(uuid.uuid4())
129 s = ss.Session(username=u'carrot', session=u)
129 s = ss.Session(username=u'carrot', session=u)
130 self.assertEqual(s.session, u)
130 self.assertEqual(s.session, u)
131 self.assertEqual(s.username, u'carrot')
131 self.assertEqual(s.username, u'carrot')
132
132
133 def test_tracking(self):
133 def test_tracking(self):
134 """test tracking messages"""
134 """test tracking messages"""
135 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
135 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
136 s = self.session
136 s = self.session
137 s.copy_threshold = 1
137 s.copy_threshold = 1
138 stream = ZMQStream(a)
138 stream = ZMQStream(a)
139 msg = s.send(a, 'hello', track=False)
139 msg = s.send(a, 'hello', track=False)
140 self.assertTrue(msg['tracker'] is ss.DONE)
140 self.assertTrue(msg['tracker'] is ss.DONE)
141 msg = s.send(a, 'hello', track=True)
141 msg = s.send(a, 'hello', track=True)
142 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
142 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
143 M = zmq.Message(b'hi there', track=True)
143 M = zmq.Message(b'hi there', track=True)
144 msg = s.send(a, 'hello', buffers=[M], track=True)
144 msg = s.send(a, 'hello', buffers=[M], track=True)
145 t = msg['tracker']
145 t = msg['tracker']
146 self.assertTrue(isinstance(t, zmq.MessageTracker))
146 self.assertTrue(isinstance(t, zmq.MessageTracker))
147 self.assertRaises(zmq.NotDone, t.wait, .1)
147 self.assertRaises(zmq.NotDone, t.wait, .1)
148 del M
148 del M
149 t.wait(1) # this will raise
149 t.wait(1) # this will raise
150
150
151
151
152 # def test_rekey(self):
152 # def test_rekey(self):
153 # """rekeying dict around json str keys"""
153 # """rekeying dict around json str keys"""
154 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
154 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
155 # self.assertRaises(KeyError, ss.rekey, d)
155 # self.assertRaises(KeyError, ss.rekey, d)
156 #
156 #
157 # d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
157 # d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
158 # d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
158 # d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
159 # rd = ss.rekey(d)
159 # rd = ss.rekey(d)
160 # self.assertEqual(d2,rd)
160 # self.assertEqual(d2,rd)
161 #
161 #
162 # d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
162 # d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
163 # d2 = {1.5:d['1.5'],1:d['1']}
163 # d2 = {1.5:d['1.5'],1:d['1']}
164 # rd = ss.rekey(d)
164 # rd = ss.rekey(d)
165 # self.assertEqual(d2,rd)
165 # self.assertEqual(d2,rd)
166 #
166 #
167 # d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
167 # d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
168 # self.assertRaises(KeyError, ss.rekey, d)
168 # self.assertRaises(KeyError, ss.rekey, d)
169 #
169 #
170 def test_unique_msg_ids(self):
170 def test_unique_msg_ids(self):
171 """test that messages receive unique ids"""
171 """test that messages receive unique ids"""
172 ids = set()
172 ids = set()
173 for i in range(2**12):
173 for i in range(2**12):
174 h = self.session.msg_header('test')
174 h = self.session.msg_header('test')
175 msg_id = h['msg_id']
175 msg_id = h['msg_id']
176 self.assertTrue(msg_id not in ids)
176 self.assertTrue(msg_id not in ids)
177 ids.add(msg_id)
177 ids.add(msg_id)
178
178
179 def test_feed_identities(self):
179 def test_feed_identities(self):
180 """scrub the front for zmq IDENTITIES"""
180 """scrub the front for zmq IDENTITIES"""
181 theids = "engine client other".split()
181 theids = "engine client other".split()
182 content = dict(code='whoda',stuff=object())
182 content = dict(code='whoda',stuff=object())
183 themsg = self.session.msg('execute',content=content)
183 themsg = self.session.msg('execute',content=content)
184 pmsg = theids
184 pmsg = theids
185
185
186 def test_session_id(self):
186 def test_session_id(self):
187 session = ss.Session()
187 session = ss.Session()
188 # get bs before us
188 # get bs before us
189 bs = session.bsession
189 bs = session.bsession
190 us = session.session
190 us = session.session
191 self.assertEqual(us.encode('ascii'), bs)
191 self.assertEqual(us.encode('ascii'), bs)
192 session = ss.Session()
192 session = ss.Session()
193 # get us before bs
193 # get us before bs
194 us = session.session
194 us = session.session
195 bs = session.bsession
195 bs = session.bsession
196 self.assertEqual(us.encode('ascii'), bs)
196 self.assertEqual(us.encode('ascii'), bs)
197 # change propagates:
197 # change propagates:
198 session.session = 'something else'
198 session.session = 'something else'
199 bs = session.bsession
199 bs = session.bsession
200 us = session.session
200 us = session.session
201 self.assertEqual(us.encode('ascii'), bs)
201 self.assertEqual(us.encode('ascii'), bs)
202 session = ss.Session(session='stuff')
202 session = ss.Session(session='stuff')
203 # get us before bs
203 # get us before bs
204 self.assertEqual(session.bsession, session.session.encode('ascii'))
204 self.assertEqual(session.bsession, session.session.encode('ascii'))
205 self.assertEqual(b'stuff', session.bsession)
205 self.assertEqual(b'stuff', session.bsession)
206
206
207 def test_zero_digest_history(self):
208 session = ss.Session(digest_history_size=0)
209 for i in range(11):
210 session._add_digest(uuid.uuid4().bytes)
211 self.assertEqual(len(session.digest_history), 0)
212
213 def test_cull_digest_history(self):
214 session = ss.Session(digest_history_size=100)
215 for i in range(100):
216 session._add_digest(uuid.uuid4().bytes)
217 self.assertTrue(len(session.digest_history) == 100)
218 session._add_digest(uuid.uuid4().bytes)
219 self.assertTrue(len(session.digest_history) == 91)
220 for i in range(9):
221 session._add_digest(uuid.uuid4().bytes)
222 self.assertTrue(len(session.digest_history) == 100)
223 session._add_digest(uuid.uuid4().bytes)
224 self.assertTrue(len(session.digest_history) == 91)
207
225
General Comments 0
You need to be logged in to leave comments. Login now