##// END OF EJS Templates
update Session object per review...
MinRK -
Show More
@@ -1,538 +1,621 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """edited session.py to work with streams, and move msg_type to the header
2 """Session object for building, serializing, sending, and receiving messages in
3 IPython. The Session object supports serialization, HMAC signatures, and
4 metadata on messages.
5
6 Also defined here are utilities for working with Sessions:
7 * A SessionFactory to be used as a base class for configurables that work with
8 Sessions.
9 * A Message object for convenience that allows attribute-access to the msg dict.
3 """
10 """
4 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
5 # Copyright (C) 2010-2011 The IPython Development Team
12 # Copyright (C) 2010-2011 The IPython Development Team
6 #
13 #
7 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
9 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
10
17
11 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
12 # Imports
19 # Imports
13 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
14
21
15 import hmac
22 import hmac
16 import logging
23 import logging
17 import os
24 import os
18 import pprint
25 import pprint
19 import uuid
26 import uuid
20 from datetime import datetime
27 from datetime import datetime
21
28
22 try:
29 try:
23 import cPickle
30 import cPickle
24 pickle = cPickle
31 pickle = cPickle
25 except:
32 except:
26 cPickle = None
33 cPickle = None
27 import pickle
34 import pickle
28
35
29 import zmq
36 import zmq
30 from zmq.utils import jsonapi
37 from zmq.utils import jsonapi
31 from zmq.eventloop.ioloop import IOLoop
38 from zmq.eventloop.ioloop import IOLoop
32 from zmq.eventloop.zmqstream import ZMQStream
39 from zmq.eventloop.zmqstream import ZMQStream
33
40
34 from IPython.config.application import Application
41 from IPython.config.configurable import Configurable, LoggingConfigurable
35 from IPython.config.configurable import Configurable
36 from IPython.utils.importstring import import_item
42 from IPython.utils.importstring import import_item
37 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
43 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
38 from IPython.utils.traitlets import CStr, Unicode, Bool, Any, Instance, Set
44 from IPython.utils.traitlets import CStr, Unicode, Bool, Any, Instance, Set
39
45
40 #-----------------------------------------------------------------------------
46 #-----------------------------------------------------------------------------
41 # utility functions
47 # utility functions
42 #-----------------------------------------------------------------------------
48 #-----------------------------------------------------------------------------
43
49
44 def squash_unicode(obj):
50 def squash_unicode(obj):
45 """coerce unicode back to bytestrings."""
51 """coerce unicode back to bytestrings."""
46 if isinstance(obj,dict):
52 if isinstance(obj,dict):
47 for key in obj.keys():
53 for key in obj.keys():
48 obj[key] = squash_unicode(obj[key])
54 obj[key] = squash_unicode(obj[key])
49 if isinstance(key, unicode):
55 if isinstance(key, unicode):
50 obj[squash_unicode(key)] = obj.pop(key)
56 obj[squash_unicode(key)] = obj.pop(key)
51 elif isinstance(obj, list):
57 elif isinstance(obj, list):
52 for i,v in enumerate(obj):
58 for i,v in enumerate(obj):
53 obj[i] = squash_unicode(v)
59 obj[i] = squash_unicode(v)
54 elif isinstance(obj, unicode):
60 elif isinstance(obj, unicode):
55 obj = obj.encode('utf8')
61 obj = obj.encode('utf8')
56 return obj
62 return obj
57
63
58 #-----------------------------------------------------------------------------
64 #-----------------------------------------------------------------------------
59 # globals and defaults
65 # globals and defaults
60 #-----------------------------------------------------------------------------
66 #-----------------------------------------------------------------------------
61 key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
67 key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
62 json_packer = lambda obj: jsonapi.dumps(obj, **{key:date_default})
68 json_packer = lambda obj: jsonapi.dumps(obj, **{key:date_default})
63 json_unpacker = lambda s: squash_unicode(extract_dates(jsonapi.loads(s)))
69 json_unpacker = lambda s: squash_unicode(extract_dates(jsonapi.loads(s)))
64
70
65 pickle_packer = lambda o: pickle.dumps(o,-1)
71 pickle_packer = lambda o: pickle.dumps(o,-1)
66 pickle_unpacker = pickle.loads
72 pickle_unpacker = pickle.loads
67
73
68 default_packer = json_packer
74 default_packer = json_packer
69 default_unpacker = json_unpacker
75 default_unpacker = json_unpacker
70
76
71
77
72 DELIM="<IDS|MSG>"
78 DELIM="<IDS|MSG>"
73
79
74 #-----------------------------------------------------------------------------
80 #-----------------------------------------------------------------------------
75 # Classes
81 # Classes
76 #-----------------------------------------------------------------------------
82 #-----------------------------------------------------------------------------
77
83
78 class SessionFactory(Configurable):
84 class SessionFactory(LoggingConfigurable):
79 """The Base class for configurables that have a Session, Context, logger,
85 """The Base class for configurables that have a Session, Context, logger,
80 and IOLoop.
86 and IOLoop.
81 """
87 """
82
88
83 log = Instance('logging.Logger')
84 def _log_default(self):
85 return Application.instance().log
86
87 logname = Unicode('')
89 logname = Unicode('')
88 def _logname_changed(self, name, old, new):
90 def _logname_changed(self, name, old, new):
89 self.log = logging.getLogger(new)
91 self.log = logging.getLogger(new)
90
92
91 # not configurable:
93 # not configurable:
92 context = Instance('zmq.Context')
94 context = Instance('zmq.Context')
93 def _context_default(self):
95 def _context_default(self):
94 return zmq.Context.instance()
96 return zmq.Context.instance()
95
97
96 session = Instance('IPython.zmq.session.Session')
98 session = Instance('IPython.zmq.session.Session')
97
99
98 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
100 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
99 def _loop_default(self):
101 def _loop_default(self):
100 return IOLoop.instance()
102 return IOLoop.instance()
101
103
102 def __init__(self, **kwargs):
104 def __init__(self, **kwargs):
103 super(SessionFactory, self).__init__(**kwargs)
105 super(SessionFactory, self).__init__(**kwargs)
104
106
105 if self.session is None:
107 if self.session is None:
106 # construct the session
108 # construct the session
107 self.session = Session(**kwargs)
109 self.session = Session(**kwargs)
108
110
109
111
110 class Message(object):
112 class Message(object):
111 """A simple message object that maps dict keys to attributes.
113 """A simple message object that maps dict keys to attributes.
112
114
113 A Message can be created from a dict and a dict from a Message instance
115 A Message can be created from a dict and a dict from a Message instance
114 simply by calling dict(msg_obj)."""
116 simply by calling dict(msg_obj)."""
115
117
116 def __init__(self, msg_dict):
118 def __init__(self, msg_dict):
117 dct = self.__dict__
119 dct = self.__dict__
118 for k, v in dict(msg_dict).iteritems():
120 for k, v in dict(msg_dict).iteritems():
119 if isinstance(v, dict):
121 if isinstance(v, dict):
120 v = Message(v)
122 v = Message(v)
121 dct[k] = v
123 dct[k] = v
122
124
123 # Having this iterator lets dict(msg_obj) work out of the box.
125 # Having this iterator lets dict(msg_obj) work out of the box.
124 def __iter__(self):
126 def __iter__(self):
125 return iter(self.__dict__.iteritems())
127 return iter(self.__dict__.iteritems())
126
128
127 def __repr__(self):
129 def __repr__(self):
128 return repr(self.__dict__)
130 return repr(self.__dict__)
129
131
130 def __str__(self):
132 def __str__(self):
131 return pprint.pformat(self.__dict__)
133 return pprint.pformat(self.__dict__)
132
134
133 def __contains__(self, k):
135 def __contains__(self, k):
134 return k in self.__dict__
136 return k in self.__dict__
135
137
136 def __getitem__(self, k):
138 def __getitem__(self, k):
137 return self.__dict__[k]
139 return self.__dict__[k]
138
140
139
141
140 def msg_header(msg_id, msg_type, username, session):
142 def msg_header(msg_id, msg_type, username, session):
141 date = datetime.now()
143 date = datetime.now()
142 return locals()
144 return locals()
143
145
144 def extract_header(msg_or_header):
146 def extract_header(msg_or_header):
145 """Given a message or header, return the header."""
147 """Given a message or header, return the header."""
146 if not msg_or_header:
148 if not msg_or_header:
147 return {}
149 return {}
148 try:
150 try:
149 # See if msg_or_header is the entire message.
151 # See if msg_or_header is the entire message.
150 h = msg_or_header['header']
152 h = msg_or_header['header']
151 except KeyError:
153 except KeyError:
152 try:
154 try:
153 # See if msg_or_header is just the header
155 # See if msg_or_header is just the header
154 h = msg_or_header['msg_id']
156 h = msg_or_header['msg_id']
155 except KeyError:
157 except KeyError:
156 raise
158 raise
157 else:
159 else:
158 h = msg_or_header
160 h = msg_or_header
159 if not isinstance(h, dict):
161 if not isinstance(h, dict):
160 h = dict(h)
162 h = dict(h)
161 return h
163 return h
162
164
163 class Session(Configurable):
165 class Session(Configurable):
164 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
166 """Object for handling serialization and sending of messages.
167
168 The Session object handles building messages and sending them
169 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
170 other over the network via Session objects, and only need to work with the
171 dict-based IPython message spec. The Session will handle
172 serialization/deserialization, security, and metadata.
173
174 Sessions support configurable serialiization via packer/unpacker traits,
175 and signing with HMAC digests via the key/keyfile traits.
176
177 Parameters
178 ----------
179
180 debug : bool
181 whether to trigger extra debugging statements
182 packer/unpacker : str : 'json', 'pickle' or import_string
183 importstrings for methods to serialize message parts. If just
184 'json' or 'pickle', predefined JSON and pickle packers will be used.
185 Otherwise, the entire importstring must be used.
186
187 The functions must accept at least valid JSON input, and output *bytes*.
188
189 For example, to use msgpack:
190 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
191 pack/unpack : callables
192 You can also set the pack/unpack callables for serialization directly.
193 session : bytes
194 the ID of this Session object. The default is to generate a new UUID.
195 username : unicode
196 username added to message headers. The default is to ask the OS.
197 key : bytes
198 The key used to initialize an HMAC signature. If unset, messages
199 will not be signed or checked.
200 keyfile : filepath
201 The file containing a key. If this is set, `key` will be initialized
202 to the contents of the file.
203
204 """
205
165 debug=Bool(False, config=True, help="""Debug output in the Session""")
206 debug=Bool(False, config=True, help="""Debug output in the Session""")
207
166 packer = Unicode('json',config=True,
208 packer = Unicode('json',config=True,
167 help="""The name of the packer for serializing messages.
209 help="""The name of the packer for serializing messages.
168 Should be one of 'json', 'pickle', or an import name
210 Should be one of 'json', 'pickle', or an import name
169 for a custom callable serializer.""")
211 for a custom callable serializer.""")
170 def _packer_changed(self, name, old, new):
212 def _packer_changed(self, name, old, new):
171 if new.lower() == 'json':
213 if new.lower() == 'json':
172 self.pack = json_packer
214 self.pack = json_packer
173 self.unpack = json_unpacker
215 self.unpack = json_unpacker
174 elif new.lower() == 'pickle':
216 elif new.lower() == 'pickle':
175 self.pack = pickle_packer
217 self.pack = pickle_packer
176 self.unpack = pickle_unpacker
218 self.unpack = pickle_unpacker
177 else:
219 else:
178 self.pack = import_item(str(new))
220 self.pack = import_item(str(new))
179
221
180 unpacker = Unicode('json', config=True,
222 unpacker = Unicode('json', config=True,
181 help="""The name of the unpacker for unserializing messages.
223 help="""The name of the unpacker for unserializing messages.
182 Only used with custom functions for `packer`.""")
224 Only used with custom functions for `packer`.""")
183 def _unpacker_changed(self, name, old, new):
225 def _unpacker_changed(self, name, old, new):
184 if new.lower() == 'json':
226 if new.lower() == 'json':
185 self.pack = json_packer
227 self.pack = json_packer
186 self.unpack = json_unpacker
228 self.unpack = json_unpacker
187 elif new.lower() == 'pickle':
229 elif new.lower() == 'pickle':
188 self.pack = pickle_packer
230 self.pack = pickle_packer
189 self.unpack = pickle_unpacker
231 self.unpack = pickle_unpacker
190 else:
232 else:
191 self.unpack = import_item(str(new))
233 self.unpack = import_item(str(new))
192
234
193 session = CStr('', config=True,
235 session = CStr('', config=True,
194 help="""The UUID identifying this session.""")
236 help="""The UUID identifying this session.""")
195 def _session_default(self):
237 def _session_default(self):
196 return bytes(uuid.uuid4())
238 return bytes(uuid.uuid4())
239
197 username = Unicode(os.environ.get('USER','username'), config=True,
240 username = Unicode(os.environ.get('USER','username'), config=True,
198 help="""Username for the Session. Default is your system username.""")
241 help="""Username for the Session. Default is your system username.""")
199
242
200 # message signature related traits:
243 # message signature related traits:
201 key = CStr('', config=True,
244 key = CStr('', config=True,
202 help="""execution key, for extra authentication.""")
245 help="""execution key, for extra authentication.""")
203 def _key_changed(self, name, old, new):
246 def _key_changed(self, name, old, new):
204 if new:
247 if new:
205 self.auth = hmac.HMAC(new)
248 self.auth = hmac.HMAC(new)
206 else:
249 else:
207 self.auth = None
250 self.auth = None
208 auth = Instance(hmac.HMAC)
251 auth = Instance(hmac.HMAC)
209 counters = Instance('collections.defaultdict', (int,))
210 digest_history = Set()
252 digest_history = Set()
211
253
212 keyfile = Unicode('', config=True,
254 keyfile = Unicode('', config=True,
213 help="""path to file containing execution key.""")
255 help="""path to file containing execution key.""")
214 def _keyfile_changed(self, name, old, new):
256 def _keyfile_changed(self, name, old, new):
215 with open(new, 'rb') as f:
257 with open(new, 'rb') as f:
216 self.key = f.read().strip()
258 self.key = f.read().strip()
217
259
218 pack = Any(default_packer) # the actual packer function
260 pack = Any(default_packer) # the actual packer function
219 def _pack_changed(self, name, old, new):
261 def _pack_changed(self, name, old, new):
220 if not callable(new):
262 if not callable(new):
221 raise TypeError("packer must be callable, not %s"%type(new))
263 raise TypeError("packer must be callable, not %s"%type(new))
222
264
223 unpack = Any(default_unpacker) # the actual packer function
265 unpack = Any(default_unpacker) # the actual packer function
224 def _unpack_changed(self, name, old, new):
266 def _unpack_changed(self, name, old, new):
225 # unpacker is not checked - it is assumed to be
267 # unpacker is not checked - it is assumed to be
226 if not callable(new):
268 if not callable(new):
227 raise TypeError("unpacker must be callable, not %s"%type(new))
269 raise TypeError("unpacker must be callable, not %s"%type(new))
228
270
229 def __init__(self, **kwargs):
271 def __init__(self, **kwargs):
272 """create a Session object
273
274 Parameters
275 ----------
276
277 debug : bool
278 whether to trigger extra debugging statements
279 packer/unpacker : str : 'json', 'pickle' or import_string
280 importstrings for methods to serialize message parts. If just
281 'json' or 'pickle', predefined JSON and pickle packers will be used.
282 Otherwise, the entire importstring must be used.
283
284 The functions must accept at least valid JSON input, and output
285 *bytes*.
286
287 For example, to use msgpack:
288 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
289 pack/unpack : callables
290 You can also set the pack/unpack callables for serialization
291 directly.
292 session : bytes
293 the ID of this Session object. The default is to generate a new
294 UUID.
295 username : unicode
296 username added to message headers. The default is to ask the OS.
297 key : bytes
298 The key used to initialize an HMAC signature. If unset, messages
299 will not be signed or checked.
300 keyfile : filepath
301 The file containing a key. If this is set, `key` will be
302 initialized to the contents of the file.
303 """
230 super(Session, self).__init__(**kwargs)
304 super(Session, self).__init__(**kwargs)
231 self._check_packers()
305 self._check_packers()
232 self.none = self.pack({})
306 self.none = self.pack({})
233
307
234 @property
308 @property
235 def msg_id(self):
309 def msg_id(self):
236 """always return new uuid"""
310 """always return new uuid"""
237 return str(uuid.uuid4())
311 return str(uuid.uuid4())
238
312
239 def _check_packers(self):
313 def _check_packers(self):
240 """check packers for binary data and datetime support."""
314 """check packers for binary data and datetime support."""
241 pack = self.pack
315 pack = self.pack
242 unpack = self.unpack
316 unpack = self.unpack
243
317
244 # check simple serialization
318 # check simple serialization
245 msg = dict(a=[1,'hi'])
319 msg = dict(a=[1,'hi'])
246 try:
320 try:
247 packed = pack(msg)
321 packed = pack(msg)
248 except Exception:
322 except Exception:
249 raise ValueError("packer could not serialize a simple message")
323 raise ValueError("packer could not serialize a simple message")
250
324
251 # ensure packed message is bytes
325 # ensure packed message is bytes
252 if not isinstance(packed, bytes):
326 if not isinstance(packed, bytes):
253 raise ValueError("message packed to %r, but bytes are required"%type(packed))
327 raise ValueError("message packed to %r, but bytes are required"%type(packed))
254
328
255 # check that unpack is pack's inverse
329 # check that unpack is pack's inverse
256 try:
330 try:
257 unpacked = unpack(packed)
331 unpacked = unpack(packed)
258 except Exception:
332 except Exception:
259 raise ValueError("unpacker could not handle the packer's output")
333 raise ValueError("unpacker could not handle the packer's output")
260
334
261 # check datetime support
335 # check datetime support
262 msg = dict(t=datetime.now())
336 msg = dict(t=datetime.now())
263 try:
337 try:
264 unpacked = unpack(pack(msg))
338 unpacked = unpack(pack(msg))
265 except Exception:
339 except Exception:
266 self.pack = lambda o: pack(squash_dates(o))
340 self.pack = lambda o: pack(squash_dates(o))
267 self.unpack = lambda s: extract_dates(unpack(s))
341 self.unpack = lambda s: extract_dates(unpack(s))
268
342
269 def msg_header(self, msg_type):
343 def msg_header(self, msg_type):
270 return msg_header(self.msg_id, msg_type, self.username, self.session)
344 return msg_header(self.msg_id, msg_type, self.username, self.session)
271
345
272 def msg(self, msg_type, content=None, parent=None, subheader=None):
346 def msg(self, msg_type, content=None, parent=None, subheader=None):
273 msg = {}
347 msg = {}
274 msg['header'] = self.msg_header(msg_type)
348 msg['header'] = self.msg_header(msg_type)
275 msg['msg_id'] = msg['header']['msg_id']
349 msg['msg_id'] = msg['header']['msg_id']
276 msg['parent_header'] = {} if parent is None else extract_header(parent)
350 msg['parent_header'] = {} if parent is None else extract_header(parent)
277 msg['msg_type'] = msg_type
351 msg['msg_type'] = msg_type
278 msg['content'] = {} if content is None else content
352 msg['content'] = {} if content is None else content
279 sub = {} if subheader is None else subheader
353 sub = {} if subheader is None else subheader
280 msg['header'].update(sub)
354 msg['header'].update(sub)
281 return msg
355 return msg
282
356
283 def sign(self, msg):
357 def sign(self, msg):
284 """Sign a message with HMAC digest. If no auth, return b''."""
358 """Sign a message with HMAC digest. If no auth, return b''."""
285 if self.auth is None:
359 if self.auth is None:
286 return b''
360 return b''
287 h = self.auth.copy()
361 h = self.auth.copy()
288 for m in msg:
362 for m in msg:
289 h.update(m)
363 h.update(m)
290 return h.hexdigest()
364 return h.hexdigest()
291
365
292 def serialize(self, msg, ident=None):
366 def serialize(self, msg, ident=None):
367 """Serialize the message components to bytes.
368
369 Returns
370 -------
371
372 list of bytes objects
373
374 """
293 content = msg.get('content', {})
375 content = msg.get('content', {})
294 if content is None:
376 if content is None:
295 content = self.none
377 content = self.none
296 elif isinstance(content, dict):
378 elif isinstance(content, dict):
297 content = self.pack(content)
379 content = self.pack(content)
298 elif isinstance(content, bytes):
380 elif isinstance(content, bytes):
299 # content is already packed, as in a relayed message
381 # content is already packed, as in a relayed message
300 pass
382 pass
301 elif isinstance(content, unicode):
383 elif isinstance(content, unicode):
302 # should be bytes, but JSON often spits out unicode
384 # should be bytes, but JSON often spits out unicode
303 content = content.encode('utf8')
385 content = content.encode('utf8')
304 else:
386 else:
305 raise TypeError("Content incorrect type: %s"%type(content))
387 raise TypeError("Content incorrect type: %s"%type(content))
306
388
307 real_message = [self.pack(msg['header']),
389 real_message = [self.pack(msg['header']),
308 self.pack(msg['parent_header']),
390 self.pack(msg['parent_header']),
309 content
391 content
310 ]
392 ]
311
393
312 to_send = []
394 to_send = []
313
395
314 if isinstance(ident, list):
396 if isinstance(ident, list):
315 # accept list of idents
397 # accept list of idents
316 to_send.extend(ident)
398 to_send.extend(ident)
317 elif ident is not None:
399 elif ident is not None:
318 to_send.append(ident)
400 to_send.append(ident)
319 to_send.append(DELIM)
401 to_send.append(DELIM)
320
402
321 signature = self.sign(real_message)
403 signature = self.sign(real_message)
322 to_send.append(signature)
404 to_send.append(signature)
323
405
324 to_send.extend(real_message)
406 to_send.extend(real_message)
325
407
326 return to_send
408 return to_send
327
409
328 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
410 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
329 buffers=None, subheader=None, track=False):
411 buffers=None, subheader=None, track=False):
330 """Build and send a message via stream or socket.
412 """Build and send a message via stream or socket.
331
413
332 Parameters
414 Parameters
333 ----------
415 ----------
334
416
335 stream : zmq.Socket or ZMQStream
417 stream : zmq.Socket or ZMQStream
336 the socket-like object used to send the data
418 the socket-like object used to send the data
337 msg_or_type : str or Message/dict
419 msg_or_type : str or Message/dict
338 Normally, msg_or_type will be a msg_type unless a message is being sent more
420 Normally, msg_or_type will be a msg_type unless a message is being
339 than once.
421 sent more than once.
340
422
341 content : dict or None
423 content : dict or None
342 the content of the message (ignored if msg_or_type is a message)
424 the content of the message (ignored if msg_or_type is a message)
343 parent : Message or dict or None
425 parent : Message or dict or None
344 the parent or parent header describing the parent of this message
426 the parent or parent header describing the parent of this message
345 ident : bytes or list of bytes
427 ident : bytes or list of bytes
346 the zmq.IDENTITY routing path
428 the zmq.IDENTITY routing path
347 subheader : dict or None
429 subheader : dict or None
348 extra header keys for this message's header
430 extra header keys for this message's header
349 buffers : list or None
431 buffers : list or None
350 the already-serialized buffers to be appended to the message
432 the already-serialized buffers to be appended to the message
351 track : bool
433 track : bool
352 whether to track. Only for use with Sockets,
434 whether to track. Only for use with Sockets,
353 because ZMQStream objects cannot track messages.
435 because ZMQStream objects cannot track messages.
354
436
355 Returns
437 Returns
356 -------
438 -------
357 msg : message dict
439 msg : message dict
358 the constructed message
440 the constructed message
359 (msg,tracker) : (message dict, MessageTracker)
441 (msg,tracker) : (message dict, MessageTracker)
360 if track=True, then a 2-tuple will be returned,
442 if track=True, then a 2-tuple will be returned,
361 the first element being the constructed
443 the first element being the constructed
362 message, and the second being the MessageTracker
444 message, and the second being the MessageTracker
363
445
364 """
446 """
365
447
366 if not isinstance(stream, (zmq.Socket, ZMQStream)):
448 if not isinstance(stream, (zmq.Socket, ZMQStream)):
367 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
449 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
368 elif track and isinstance(stream, ZMQStream):
450 elif track and isinstance(stream, ZMQStream):
369 raise TypeError("ZMQStream cannot track messages")
451 raise TypeError("ZMQStream cannot track messages")
370
452
371 if isinstance(msg_or_type, (Message, dict)):
453 if isinstance(msg_or_type, (Message, dict)):
372 # we got a Message, not a msg_type
454 # we got a Message, not a msg_type
373 # don't build a new Message
455 # don't build a new Message
374 msg = msg_or_type
456 msg = msg_or_type
375 else:
457 else:
376 msg = self.msg(msg_or_type, content, parent, subheader)
458 msg = self.msg(msg_or_type, content, parent, subheader)
377
459
378 buffers = [] if buffers is None else buffers
460 buffers = [] if buffers is None else buffers
379 to_send = self.serialize(msg, ident)
461 to_send = self.serialize(msg, ident)
380 flag = 0
462 flag = 0
381 if buffers:
463 if buffers:
382 flag = zmq.SNDMORE
464 flag = zmq.SNDMORE
383 _track = False
465 _track = False
384 else:
466 else:
385 _track=track
467 _track=track
386 if track:
468 if track:
387 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
469 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
388 else:
470 else:
389 tracker = stream.send_multipart(to_send, flag, copy=False)
471 tracker = stream.send_multipart(to_send, flag, copy=False)
390 for b in buffers[:-1]:
472 for b in buffers[:-1]:
391 stream.send(b, flag, copy=False)
473 stream.send(b, flag, copy=False)
392 if buffers:
474 if buffers:
393 if track:
475 if track:
394 tracker = stream.send(buffers[-1], copy=False, track=track)
476 tracker = stream.send(buffers[-1], copy=False, track=track)
395 else:
477 else:
396 tracker = stream.send(buffers[-1], copy=False)
478 tracker = stream.send(buffers[-1], copy=False)
397
479
398 # omsg = Message(msg)
480 # omsg = Message(msg)
399 if self.debug:
481 if self.debug:
400 pprint.pprint(msg)
482 pprint.pprint(msg)
401 pprint.pprint(to_send)
483 pprint.pprint(to_send)
402 pprint.pprint(buffers)
484 pprint.pprint(buffers)
403
485
404 msg['tracker'] = tracker
486 msg['tracker'] = tracker
405
487
406 return msg
488 return msg
407
489
408 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
490 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
409 """Send a raw message via ident path.
491 """Send a raw message via ident path.
410
492
411 Parameters
493 Parameters
412 ----------
494 ----------
413 msg : list of sendable buffers"""
495 msg : list of sendable buffers"""
414 to_send = []
496 to_send = []
415 if isinstance(ident, bytes):
497 if isinstance(ident, bytes):
416 ident = [ident]
498 ident = [ident]
417 if ident is not None:
499 if ident is not None:
418 to_send.extend(ident)
500 to_send.extend(ident)
419
501
420 to_send.append(DELIM)
502 to_send.append(DELIM)
421 to_send.append(self.sign(msg))
503 to_send.append(self.sign(msg))
422 to_send.extend(msg)
504 to_send.extend(msg)
423 stream.send_multipart(msg, flags, copy=copy)
505 stream.send_multipart(msg, flags, copy=copy)
424
506
425 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
507 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
426 """receives and unpacks a message
508 """receives and unpacks a message
427 returns [idents], msg"""
509 returns [idents], msg"""
428 if isinstance(socket, ZMQStream):
510 if isinstance(socket, ZMQStream):
429 socket = socket.socket
511 socket = socket.socket
430 try:
512 try:
431 msg = socket.recv_multipart(mode)
513 msg = socket.recv_multipart(mode)
432 except zmq.ZMQError as e:
514 except zmq.ZMQError as e:
433 if e.errno == zmq.EAGAIN:
515 if e.errno == zmq.EAGAIN:
434 # We can convert EAGAIN to None as we know in this case
516 # We can convert EAGAIN to None as we know in this case
435 # recv_multipart won't return None.
517 # recv_multipart won't return None.
436 return None,None
518 return None,None
437 else:
519 else:
438 raise
520 raise
439 # return an actual Message object
521 # split multipart message into identity list and message dict
440 # determine the number of idents by trying to unpack them.
522 # invalid large messages can cause very expensive string comparisons
441 # this is terrible:
442 idents, msg = self.feed_identities(msg, copy)
523 idents, msg = self.feed_identities(msg, copy)
443 try:
524 try:
444 return idents, self.unpack_message(msg, content=content, copy=copy)
525 return idents, self.unpack_message(msg, content=content, copy=copy)
445 except Exception as e:
526 except Exception as e:
446 print (idents, msg)
527 print (idents, msg)
447 # TODO: handle it
528 # TODO: handle it
448 raise e
529 raise e
449
530
450 def feed_identities(self, msg, copy=True):
531 def feed_identities(self, msg, copy=True):
451 """feed until DELIM is reached, then return the prefix as idents and remainder as
532 """feed until DELIM is reached, then return the prefix as idents and
452 msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.
533 remainder as msg. This is easily broken by setting an IDENT to DELIM,
534 but that would be silly.
453
535
454 Parameters
536 Parameters
455 ----------
537 ----------
456 msg : a list of Message or bytes objects
538 msg : a list of Message or bytes objects
457 the message to be split
539 the message to be split
458 copy : bool
540 copy : bool
459 flag determining whether the arguments are bytes or Messages
541 flag determining whether the arguments are bytes or Messages
460
542
461 Returns
543 Returns
462 -------
544 -------
463 (idents,msg) : two lists
545 (idents,msg) : two lists
464 idents will always be a list of bytes - the indentity prefix
546 idents will always be a list of bytes - the indentity prefix
465 msg will be a list of bytes or Messages, unchanged from input
547 msg will be a list of bytes or Messages, unchanged from input
466 msg should be unpackable via self.unpack_message at this point.
548 msg should be unpackable via self.unpack_message at this point.
467 """
549 """
468 if copy:
550 if copy:
469 idx = msg.index(DELIM)
551 idx = msg.index(DELIM)
470 return msg[:idx], msg[idx+1:]
552 return msg[:idx], msg[idx+1:]
471 else:
553 else:
472 failed = True
554 failed = True
473 for idx,m in enumerate(msg):
555 for idx,m in enumerate(msg):
474 if m.bytes == DELIM:
556 if m.bytes == DELIM:
475 failed = False
557 failed = False
476 break
558 break
477 if failed:
559 if failed:
478 raise ValueError("DELIM not in msg")
560 raise ValueError("DELIM not in msg")
479 idents, msg = msg[:idx], msg[idx+1:]
561 idents, msg = msg[:idx], msg[idx+1:]
480 return [m.bytes for m in idents], msg
562 return [m.bytes for m in idents], msg
481
563
482 def unpack_message(self, msg, content=True, copy=True):
564 def unpack_message(self, msg, content=True, copy=True):
483 """Return a message object from the format
565 """Return a message object from the format
484 sent by self.send.
566 sent by self.send.
485
567
486 Parameters:
568 Parameters:
487 -----------
569 -----------
488
570
489 content : bool (True)
571 content : bool (True)
490 whether to unpack the content dict (True),
572 whether to unpack the content dict (True),
491 or leave it serialized (False)
573 or leave it serialized (False)
492
574
493 copy : bool (True)
575 copy : bool (True)
494 whether to return the bytes (True),
576 whether to return the bytes (True),
495 or the non-copying Message object in each place (False)
577 or the non-copying Message object in each place (False)
496
578
497 """
579 """
498 minlen = 4
580 minlen = 4
499 message = {}
581 message = {}
500 if not copy:
582 if not copy:
501 for i in range(minlen):
583 for i in range(minlen):
502 msg[i] = msg[i].bytes
584 msg[i] = msg[i].bytes
503 if self.auth is not None:
585 if self.auth is not None:
504 signature = msg[0]
586 signature = msg[0]
505 if signature in self.digest_history:
587 if signature in self.digest_history:
506 raise ValueError("Duplicate Signature: %r"%signature)
588 raise ValueError("Duplicate Signature: %r"%signature)
507 self.digest_history.add(signature)
589 self.digest_history.add(signature)
508 check = self.sign(msg[1:4])
590 check = self.sign(msg[1:4])
509 if not signature == check:
591 if not signature == check:
510 raise ValueError("Invalid Signature: %r"%signature)
592 raise ValueError("Invalid Signature: %r"%signature)
511 if not len(msg) >= minlen:
593 if not len(msg) >= minlen:
512 raise TypeError("malformed message, must have at least %i elements"%minlen)
594 raise TypeError("malformed message, must have at least %i elements"%minlen)
513 message['header'] = self.unpack(msg[1])
595 message['header'] = self.unpack(msg[1])
514 message['msg_type'] = message['header']['msg_type']
596 message['msg_type'] = message['header']['msg_type']
515 message['parent_header'] = self.unpack(msg[2])
597 message['parent_header'] = self.unpack(msg[2])
516 if content:
598 if content:
517 message['content'] = self.unpack(msg[3])
599 message['content'] = self.unpack(msg[3])
518 else:
600 else:
519 message['content'] = msg[3]
601 message['content'] = msg[3]
520
602
521 message['buffers'] = msg[4:]
603 message['buffers'] = msg[4:]
522 return message
604 return message
523
605
524 def test_msg2obj():
606 def test_msg2obj():
525 am = dict(x=1)
607 am = dict(x=1)
526 ao = Message(am)
608 ao = Message(am)
527 assert ao.x == am['x']
609 assert ao.x == am['x']
528
610
529 am['y'] = dict(z=1)
611 am['y'] = dict(z=1)
530 ao = Message(am)
612 ao = Message(am)
531 assert ao.y.z == am['y']['z']
613 assert ao.y.z == am['y']['z']
532
614
533 k1, k2 = 'y', 'z'
615 k1, k2 = 'y', 'z'
534 assert ao[k1][k2] == am[k1][k2]
616 assert ao[k1][k2] == am[k1][k2]
535
617
536 am2 = dict(ao)
618 am2 = dict(ao)
537 assert am['x'] == am2['x']
619 assert am['x'] == am2['x']
538 assert am['y']['z'] == am2['y']['z']
620 assert am['y']['z'] == am2['y']['z']
621
General Comments 0
You need to be logged in to leave comments. Login now