Merge pull request #7798 from jasongrout/buffer-memoryview...
Min RK
r20764:84817ef0 merge
@@ -1,278 +1,281 b''
1 # coding: utf-8
1 # coding: utf-8
2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 import os
7 import os
8 import json
8 import json
9 import struct
9 import struct
10 import warnings
10 import warnings
11 import sys
11
12
12 try:
13 try:
13 from urllib.parse import urlparse # Py 3
14 from urllib.parse import urlparse # Py 3
14 except ImportError:
15 except ImportError:
15 from urlparse import urlparse # Py 2
16 from urlparse import urlparse # Py 2
16
17
17 import tornado
18 import tornado
18 from tornado import gen, ioloop, web
19 from tornado import gen, ioloop, web
19 from tornado.websocket import WebSocketHandler
20 from tornado.websocket import WebSocketHandler
20
21
21 from IPython.kernel.zmq.session import Session
22 from IPython.kernel.zmq.session import Session
22 from IPython.utils.jsonutil import date_default, extract_dates
23 from IPython.utils.jsonutil import date_default, extract_dates
23 from IPython.utils.py3compat import cast_unicode
24 from IPython.utils.py3compat import cast_unicode
24
25
25 from .handlers import IPythonHandler
26 from .handlers import IPythonHandler
26
27
27 def serialize_binary_message(msg):
28 def serialize_binary_message(msg):
28 """serialize a message as a binary blob
29 """serialize a message as a binary blob
29
30
30 Header:
31 Header:
31
32
32 4 bytes: number of msg parts (nbufs) as 32b int
33 4 bytes: number of msg parts (nbufs) as 32b int
33 4 * nbufs bytes: offset for each buffer as 32b int
34 4 * nbufs bytes: offset for each buffer as 32b int
34
35
35 Offsets are from the start of the buffer, including the header.
36 Offsets are from the start of the buffer, including the header.
36
37
37 Returns
38 Returns
38 -------
39 -------
39
40
40 The message serialized to bytes.
41 The message serialized to bytes.
41
42
42 """
43 """
43 # don't modify msg or buffer list in-place
44 # don't modify msg or buffer list in-place
44 msg = msg.copy()
45 msg = msg.copy()
45 buffers = list(msg.pop('buffers'))
46 buffers = list(msg.pop('buffers'))
47 if sys.version_info < (3, 4):
48 buffers = [x.tobytes() for x in buffers]
46 bmsg = json.dumps(msg, default=date_default).encode('utf8')
49 bmsg = json.dumps(msg, default=date_default).encode('utf8')
47 buffers.insert(0, bmsg)
50 buffers.insert(0, bmsg)
48 nbufs = len(buffers)
51 nbufs = len(buffers)
49 offsets = [4 * (nbufs + 1)]
52 offsets = [4 * (nbufs + 1)]
50 for buf in buffers[:-1]:
53 for buf in buffers[:-1]:
51 offsets.append(offsets[-1] + len(buf))
54 offsets.append(offsets[-1] + len(buf))
52 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
55 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
53 buffers.insert(0, offsets_buf)
56 buffers.insert(0, offsets_buf)
54 return b''.join(buffers)
57 return b''.join(buffers)
55
58
56
59
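As a worked illustration of the header layout documented above (not part of this changeset): for a message whose packed JSON body is 100 bytes and which carries two extra buffers of 10 and 20 bytes, the count and offsets come out as follows. The sizes are invented for the example.

import struct

nbufs = 3                              # JSON part + 2 binary buffers
offsets = [4 * (nbufs + 1)]            # 16: 4 bytes for nbufs + 4 bytes per offset
offsets.append(offsets[-1] + 100)      # 116: start of the first binary buffer
offsets.append(offsets[-1] + 10)       # 126: start of the second binary buffer
header = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
assert len(header) == 16               # the last buffer needs no trailing offset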
57 def deserialize_binary_message(bmsg):
60 def deserialize_binary_message(bmsg):
58 """deserialize a message from a binary blob
61 """deserialize a message from a binary blob
59
62
60 Header:
63 Header:
61
64
62 4 bytes: number of msg parts (nbufs) as 32b int
65 4 bytes: number of msg parts (nbufs) as 32b int
63 4 * nbufs bytes: offset for each buffer as 32b int
66 4 * nbufs bytes: offset for each buffer as 32b int
64
67
65 Offsets are from the start of the buffer, including the header.
68 Offsets are from the start of the buffer, including the header.
66
69
67 Returns
70 Returns
68 -------
71 -------
69
72
70 message dictionary
73 message dictionary
71 """
74 """
72 nbufs = struct.unpack('!i', bmsg[:4])[0]
75 nbufs = struct.unpack('!i', bmsg[:4])[0]
73 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
76 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
74 offsets.append(None)
77 offsets.append(None)
75 bufs = []
78 bufs = []
76 for start, stop in zip(offsets[:-1], offsets[1:]):
79 for start, stop in zip(offsets[:-1], offsets[1:]):
77 bufs.append(bmsg[start:stop])
80 bufs.append(bmsg[start:stop])
78 msg = json.loads(bufs[0].decode('utf8'))
81 msg = json.loads(bufs[0].decode('utf8'))
79 msg['header'] = extract_dates(msg['header'])
82 msg['header'] = extract_dates(msg['header'])
80 msg['parent_header'] = extract_dates(msg['parent_header'])
83 msg['parent_header'] = extract_dates(msg['parent_header'])
81 msg['buffers'] = bufs[1:]
84 msg['buffers'] = bufs[1:]
82 return msg
85 return msg
83
86
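A quick round-trip sketch of the two helpers above, mirroring the new tests; the import path IPython.html.base.zmqhandlers is assumed from the test's relative import, and memoryview buffers are used because that is what this pull request makes work:

import os
from IPython.kernel.zmq.session import Session
from IPython.html.base.zmqhandlers import (
    serialize_binary_message,
    deserialize_binary_message,
)

s = Session()
msg = s.msg('data_pub', content={'a': 'b'})
msg['buffers'] = [memoryview(os.urandom(3)) for _ in range(3)]

wire = serialize_binary_message(msg)        # header + JSON body + raw buffers
assert isinstance(wire, bytes)
msg2 = deserialize_binary_message(wire)
assert msg2['buffers'] == [b.tobytes() for b in msg['buffers']]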
84 # ping interval for keeping websockets alive (30 seconds)
87 # ping interval for keeping websockets alive (30 seconds)
85 WS_PING_INTERVAL = 30000
88 WS_PING_INTERVAL = 30000
86
89
87 if os.environ.get('IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS', False):
90 if os.environ.get('IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS', False):
88 warnings.warn("""Allowing draft76 websocket connections!
91 warnings.warn("""Allowing draft76 websocket connections!
89 This should only be done for testing with phantomjs!""")
92 This should only be done for testing with phantomjs!""")
90 from IPython.html import allow76
93 from IPython.html import allow76
91 WebSocketHandler = allow76.AllowDraftWebSocketHandler
94 WebSocketHandler = allow76.AllowDraftWebSocketHandler
92 # draft 76 doesn't support ping
95 # draft 76 doesn't support ping
93 WS_PING_INTERVAL = 0
96 WS_PING_INTERVAL = 0
94
97
95 class ZMQStreamHandler(WebSocketHandler):
98 class ZMQStreamHandler(WebSocketHandler):
96
99
97 if tornado.version_info < (4,1):
100 if tornado.version_info < (4,1):
98 """Backport send_error from tornado 4.1 to 4.0"""
101 """Backport send_error from tornado 4.1 to 4.0"""
99 def send_error(self, *args, **kwargs):
102 def send_error(self, *args, **kwargs):
100 if self.stream is None:
103 if self.stream is None:
101 super(WebSocketHandler, self).send_error(*args, **kwargs)
104 super(WebSocketHandler, self).send_error(*args, **kwargs)
102 else:
105 else:
103 # If we get an uncaught exception during the handshake,
106 # If we get an uncaught exception during the handshake,
104 # we have no choice but to abruptly close the connection.
107 # we have no choice but to abruptly close the connection.
105 # TODO: for uncaught exceptions after the handshake,
108 # TODO: for uncaught exceptions after the handshake,
106 # we can close the connection more gracefully.
109 # we can close the connection more gracefully.
107 self.stream.close()
110 self.stream.close()
108
111
109
112
110 def check_origin(self, origin):
113 def check_origin(self, origin):
111 """Check Origin == Host or Access-Control-Allow-Origin.
114 """Check Origin == Host or Access-Control-Allow-Origin.
112
115
113 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
116 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
114 """
117 """
115 if self.allow_origin == '*':
118 if self.allow_origin == '*':
116 return True
119 return True
117
120
118 host = self.request.headers.get("Host")
121 host = self.request.headers.get("Host")
119
122
120 # If no header is provided, assume we can't verify origin
123 # If no header is provided, assume we can't verify origin
121 if origin is None:
124 if origin is None:
122 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
125 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
123 return False
126 return False
124 if host is None:
127 if host is None:
125 self.log.warn("Missing Host header, rejecting WebSocket connection.")
128 self.log.warn("Missing Host header, rejecting WebSocket connection.")
126 return False
129 return False
127
130
128 origin = origin.lower()
131 origin = origin.lower()
129 origin_host = urlparse(origin).netloc
132 origin_host = urlparse(origin).netloc
130
133
131 # OK if origin matches host
134 # OK if origin matches host
132 if origin_host == host:
135 if origin_host == host:
133 return True
136 return True
134
137
135 # Check CORS headers
138 # Check CORS headers
136 if self.allow_origin:
139 if self.allow_origin:
137 allow = self.allow_origin == origin
140 allow = self.allow_origin == origin
138 elif self.allow_origin_pat:
141 elif self.allow_origin_pat:
139 allow = bool(self.allow_origin_pat.match(origin))
142 allow = bool(self.allow_origin_pat.match(origin))
140 else:
143 else:
141 # No CORS headers deny the request
144 # No CORS headers deny the request
142 allow = False
145 allow = False
143 if not allow:
146 if not allow:
144 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
147 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
145 origin, host,
148 origin, host,
146 )
149 )
147 return allow
150 return allow
148
151
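check_origin reduces the Origin header to a host:port via urlparse and compares it against the Host header before falling back to the CORS settings. A small illustration, not part of this changeset; the hostname is invented:

try:
    from urllib.parse import urlparse  # Py 3
except ImportError:
    from urlparse import urlparse      # Py 2

origin = "https://notebook.example.com:8888"
host = "notebook.example.com:8888"           # value of the Host request header
same_origin = urlparse(origin.lower()).netloc == host
print(same_origin)                           # True -> connection allowed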
149 def clear_cookie(self, *args, **kwargs):
152 def clear_cookie(self, *args, **kwargs):
150 """meaningless for websockets"""
153 """meaningless for websockets"""
151 pass
154 pass
152
155
153 def _reserialize_reply(self, msg_list, channel=None):
156 def _reserialize_reply(self, msg_list, channel=None):
154 """Reserialize a reply message using JSON.
157 """Reserialize a reply message using JSON.
155
158
156 This takes the msg list from the ZMQ socket, deserializes it using
159 This takes the msg list from the ZMQ socket, deserializes it using
157 self.session and then serializes the result using JSON. This method
160 self.session and then serializes the result using JSON. This method
158 should be used by self._on_zmq_reply to build messages that can
161 should be used by self._on_zmq_reply to build messages that can
159 be sent back to the browser.
162 be sent back to the browser.
160 """
163 """
161 idents, msg_list = self.session.feed_identities(msg_list)
164 idents, msg_list = self.session.feed_identities(msg_list)
162 msg = self.session.deserialize(msg_list)
165 msg = self.session.deserialize(msg_list)
163 if channel:
166 if channel:
164 msg['channel'] = channel
167 msg['channel'] = channel
165 if msg['buffers']:
168 if msg['buffers']:
166 buf = serialize_binary_message(msg)
169 buf = serialize_binary_message(msg)
167 return buf
170 return buf
168 else:
171 else:
169 smsg = json.dumps(msg, default=date_default)
172 smsg = json.dumps(msg, default=date_default)
170 return cast_unicode(smsg)
173 return cast_unicode(smsg)
171
174
172 def _on_zmq_reply(self, stream, msg_list):
175 def _on_zmq_reply(self, stream, msg_list):
173 # Sometimes this gets triggered when the on_close method is scheduled in the
176 # Sometimes this gets triggered when the on_close method is scheduled in the
174 # eventloop but hasn't been called.
177 # eventloop but hasn't been called.
175 if self.stream.closed() or stream.closed():
178 if self.stream.closed() or stream.closed():
176 self.log.warn("zmq message arrived on closed channel")
179 self.log.warn("zmq message arrived on closed channel")
177 self.close()
180 self.close()
178 return
181 return
179 channel = getattr(stream, 'channel', None)
182 channel = getattr(stream, 'channel', None)
180 try:
183 try:
181 msg = self._reserialize_reply(msg_list, channel=channel)
184 msg = self._reserialize_reply(msg_list, channel=channel)
182 except Exception:
185 except Exception:
183 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
186 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
184 else:
187 else:
185 self.write_message(msg, binary=isinstance(msg, bytes))
188 self.write_message(msg, binary=isinstance(msg, bytes))
186
189
187 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
190 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
188 ping_callback = None
191 ping_callback = None
189 last_ping = 0
192 last_ping = 0
190 last_pong = 0
193 last_pong = 0
191
194
192 @property
195 @property
193 def ping_interval(self):
196 def ping_interval(self):
194 """The interval for websocket keep-alive pings.
197 """The interval for websocket keep-alive pings.
195
198
196 Set ws_ping_interval = 0 to disable pings.
199 Set ws_ping_interval = 0 to disable pings.
197 """
200 """
198 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
201 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
199
202
200 @property
203 @property
201 def ping_timeout(self):
204 def ping_timeout(self):
202 """If no ping is received in this many milliseconds,
205 """If no ping is received in this many milliseconds,
203 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
206 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
204 Default is max of 3 pings or 30 seconds.
207 Default is max of 3 pings or 30 seconds.
205 """
208 """
206 return self.settings.get('ws_ping_timeout',
209 return self.settings.get('ws_ping_timeout',
207 max(3 * self.ping_interval, WS_PING_INTERVAL)
210 max(3 * self.ping_interval, WS_PING_INTERVAL)
208 )
211 )
209
212
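With the defaults above, the keep-alive numbers work out as follows (a sketch, not part of this changeset):

WS_PING_INTERVAL = 30000                                  # ms, module constant
ping_interval = WS_PING_INTERVAL                          # no ws_ping_interval override
ping_timeout = max(3 * ping_interval, WS_PING_INTERVAL)   # 90000 ms
# so a pong must arrive within 90 s of a recent ping before the socket is closed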
210 def set_default_headers(self):
213 def set_default_headers(self):
211 """Undo the set_default_headers in IPythonHandler
214 """Undo the set_default_headers in IPythonHandler
212
215
213 which doesn't make sense for websockets
216 which doesn't make sense for websockets
214 """
217 """
215 pass
218 pass
216
219
217 def pre_get(self):
220 def pre_get(self):
218 """Run before finishing the GET request
221 """Run before finishing the GET request
219
222
220 Extend this method to add logic that should fire before
223 Extend this method to add logic that should fire before
221 the websocket finishes completing.
224 the websocket finishes completing.
222 """
225 """
223 # authenticate the request before opening the websocket
226 # authenticate the request before opening the websocket
224 if self.get_current_user() is None:
227 if self.get_current_user() is None:
225 self.log.warn("Couldn't authenticate WebSocket connection")
228 self.log.warn("Couldn't authenticate WebSocket connection")
226 raise web.HTTPError(403)
229 raise web.HTTPError(403)
227
230
228 if self.get_argument('session_id', False):
231 if self.get_argument('session_id', False):
229 self.session.session = cast_unicode(self.get_argument('session_id'))
232 self.session.session = cast_unicode(self.get_argument('session_id'))
230 else:
233 else:
231 self.log.warn("No session ID specified")
234 self.log.warn("No session ID specified")
232
235
233 @gen.coroutine
236 @gen.coroutine
234 def get(self, *args, **kwargs):
237 def get(self, *args, **kwargs):
235 # pre_get can be a coroutine in subclasses
238 # pre_get can be a coroutine in subclasses
236 # assign and yield in two step to avoid tornado 3 issues
239 # assign and yield in two step to avoid tornado 3 issues
237 res = self.pre_get()
240 res = self.pre_get()
238 yield gen.maybe_future(res)
241 yield gen.maybe_future(res)
239 super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
242 super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
240
243
241 def initialize(self):
244 def initialize(self):
242 self.log.debug("Initializing websocket connection %s", self.request.path)
245 self.log.debug("Initializing websocket connection %s", self.request.path)
243 self.session = Session(config=self.config)
246 self.session = Session(config=self.config)
244
247
245 def open(self, *args, **kwargs):
248 def open(self, *args, **kwargs):
246 self.log.debug("Opening websocket %s", self.request.path)
249 self.log.debug("Opening websocket %s", self.request.path)
247
250
248 # start the pinging
251 # start the pinging
249 if self.ping_interval > 0:
252 if self.ping_interval > 0:
250 loop = ioloop.IOLoop.current()
253 loop = ioloop.IOLoop.current()
251 self.last_ping = loop.time() # Remember time of last ping
254 self.last_ping = loop.time() # Remember time of last ping
252 self.last_pong = self.last_ping
255 self.last_pong = self.last_ping
253 self.ping_callback = ioloop.PeriodicCallback(
256 self.ping_callback = ioloop.PeriodicCallback(
254 self.send_ping, self.ping_interval, io_loop=loop,
257 self.send_ping, self.ping_interval, io_loop=loop,
255 )
258 )
256 self.ping_callback.start()
259 self.ping_callback.start()
257
260
258 def send_ping(self):
261 def send_ping(self):
259 """send a ping to keep the websocket alive"""
262 """send a ping to keep the websocket alive"""
260 if self.stream.closed() and self.ping_callback is not None:
263 if self.stream.closed() and self.ping_callback is not None:
261 self.ping_callback.stop()
264 self.ping_callback.stop()
262 return
265 return
263
266
264 # check for timeout on pong. Make sure that we really have sent a recent ping in
267 # check for timeout on pong. Make sure that we really have sent a recent ping in
265 # case the machine with both server and client has been suspended since the last ping.
268 # case the machine with both server and client has been suspended since the last ping.
266 now = ioloop.IOLoop.current().time()
269 now = ioloop.IOLoop.current().time()
267 since_last_pong = 1e3 * (now - self.last_pong)
270 since_last_pong = 1e3 * (now - self.last_pong)
268 since_last_ping = 1e3 * (now - self.last_ping)
271 since_last_ping = 1e3 * (now - self.last_ping)
269 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
272 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
270 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
273 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
271 self.close()
274 self.close()
272 return
275 return
273
276
274 self.ping(b'')
277 self.ping(b'')
275 self.last_ping = now
278 self.last_ping = now
276
279
277 def on_pong(self, data):
280 def on_pong(self, data):
278 self.last_pong = ioloop.IOLoop.current().time()
281 self.last_pong = ioloop.IOLoop.current().time()
@@ -1,26 +1,26 b''
1 """Test serialize/deserialize messages with buffers"""
1 """Test serialize/deserialize messages with buffers"""
2
2
3 import os
3 import os
4
4
5 import nose.tools as nt
5 import nose.tools as nt
6
6
7 from IPython.kernel.zmq.session import Session
7 from IPython.kernel.zmq.session import Session
8 from ..base.zmqhandlers import (
8 from ..base.zmqhandlers import (
9 serialize_binary_message,
9 serialize_binary_message,
10 deserialize_binary_message,
10 deserialize_binary_message,
11 )
11 )
12
12
13 def test_serialize_binary():
13 def test_serialize_binary():
14 s = Session()
14 s = Session()
15 msg = s.msg('data_pub', content={'a': 'b'})
15 msg = s.msg('data_pub', content={'a': 'b'})
16 msg['buffers'] = [ os.urandom(3) for i in range(3) ]
16 msg['buffers'] = [ memoryview(os.urandom(3)) for i in range(3) ]
17 bmsg = serialize_binary_message(msg)
17 bmsg = serialize_binary_message(msg)
18 nt.assert_is_instance(bmsg, bytes)
18 nt.assert_is_instance(bmsg, bytes)
19
19
20 def test_deserialize_binary():
20 def test_deserialize_binary():
21 s = Session()
21 s = Session()
22 msg = s.msg('data_pub', content={'a': 'b'})
22 msg = s.msg('data_pub', content={'a': 'b'})
23 msg['buffers'] = [ os.urandom(2) for i in range(3) ]
23 msg['buffers'] = [ memoryview(os.urandom(2)) for i in range(3) ]
24 bmsg = serialize_binary_message(msg)
24 bmsg = serialize_binary_message(msg)
25 msg2 = deserialize_binary_message(bmsg)
25 msg2 = deserialize_binary_message(bmsg)
26 nt.assert_equal(msg2, msg)
26 nt.assert_equal(msg2, msg)
@@ -1,877 +1,881 b''
1 """Session object for building, serializing, sending, and receiving messages in
1 """Session object for building, serializing, sending, and receiving messages in
2 IPython. The Session object supports serialization, HMAC signatures, and
2 IPython. The Session object supports serialization, HMAC signatures, and
3 metadata on messages.
3 metadata on messages.
4
4
5 Also defined here are utilities for working with Sessions:
5 Also defined here are utilities for working with Sessions:
6 * A SessionFactory to be used as a base class for configurables that work with
6 * A SessionFactory to be used as a base class for configurables that work with
7 Sessions.
7 Sessions.
8 * A Message object for convenience that allows attribute-access to the msg dict.
8 * A Message object for convenience that allows attribute-access to the msg dict.
9 """
9 """
10
10
11 # Copyright (c) IPython Development Team.
11 # Copyright (c) IPython Development Team.
12 # Distributed under the terms of the Modified BSD License.
12 # Distributed under the terms of the Modified BSD License.
13
13
14 import hashlib
14 import hashlib
15 import hmac
15 import hmac
16 import logging
16 import logging
17 import os
17 import os
18 import pprint
18 import pprint
19 import random
19 import random
20 import uuid
20 import uuid
21 import warnings
21 import warnings
22 from datetime import datetime
22 from datetime import datetime
23
23
24 try:
24 try:
25 import cPickle
25 import cPickle
26 pickle = cPickle
26 pickle = cPickle
27 except:
27 except:
28 cPickle = None
28 cPickle = None
29 import pickle
29 import pickle
30
30
31 try:
31 try:
32 # We are using compare_digest to limit the surface of timing attacks
32 # We are using compare_digest to limit the surface of timing attacks
33 from hmac import compare_digest
33 from hmac import compare_digest
34 except ImportError:
34 except ImportError:
35 # Python < 2.7.7: When digests don't match no feedback is provided,
35 # Python < 2.7.7: When digests don't match no feedback is provided,
36 # limiting the surface of attack
36 # limiting the surface of attack
37 def compare_digest(a,b): return a == b
37 def compare_digest(a,b): return a == b
38
38
39 import zmq
39 import zmq
40 from zmq.utils import jsonapi
40 from zmq.utils import jsonapi
41 from zmq.eventloop.ioloop import IOLoop
41 from zmq.eventloop.ioloop import IOLoop
42 from zmq.eventloop.zmqstream import ZMQStream
42 from zmq.eventloop.zmqstream import ZMQStream
43
43
44 from IPython.core.release import kernel_protocol_version
44 from IPython.core.release import kernel_protocol_version
45 from IPython.config.configurable import Configurable, LoggingConfigurable
45 from IPython.config.configurable import Configurable, LoggingConfigurable
46 from IPython.utils import io
46 from IPython.utils import io
47 from IPython.utils.importstring import import_item
47 from IPython.utils.importstring import import_item
48 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
48 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
49 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
49 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
50 iteritems)
50 iteritems)
51 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
51 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
52 DottedObjectName, CUnicode, Dict, Integer,
52 DottedObjectName, CUnicode, Dict, Integer,
53 TraitError,
53 TraitError,
54 )
54 )
55 from IPython.utils.pickleutil import PICKLE_PROTOCOL
55 from IPython.utils.pickleutil import PICKLE_PROTOCOL
56 from IPython.kernel.adapter import adapt
56 from IPython.kernel.adapter import adapt
57 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
57 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
58
58
59 #-----------------------------------------------------------------------------
59 #-----------------------------------------------------------------------------
60 # utility functions
60 # utility functions
61 #-----------------------------------------------------------------------------
61 #-----------------------------------------------------------------------------
62
62
63 def squash_unicode(obj):
63 def squash_unicode(obj):
64 """coerce unicode back to bytestrings."""
64 """coerce unicode back to bytestrings."""
65 if isinstance(obj,dict):
65 if isinstance(obj,dict):
66 for key in obj.keys():
66 for key in obj.keys():
67 obj[key] = squash_unicode(obj[key])
67 obj[key] = squash_unicode(obj[key])
68 if isinstance(key, unicode_type):
68 if isinstance(key, unicode_type):
69 obj[squash_unicode(key)] = obj.pop(key)
69 obj[squash_unicode(key)] = obj.pop(key)
70 elif isinstance(obj, list):
70 elif isinstance(obj, list):
71 for i,v in enumerate(obj):
71 for i,v in enumerate(obj):
72 obj[i] = squash_unicode(v)
72 obj[i] = squash_unicode(v)
73 elif isinstance(obj, unicode_type):
73 elif isinstance(obj, unicode_type):
74 obj = obj.encode('utf8')
74 obj = obj.encode('utf8')
75 return obj
75 return obj
76
76
77 #-----------------------------------------------------------------------------
77 #-----------------------------------------------------------------------------
78 # globals and defaults
78 # globals and defaults
79 #-----------------------------------------------------------------------------
79 #-----------------------------------------------------------------------------
80
80
81 # ISO8601-ify datetime objects
81 # ISO8601-ify datetime objects
82 # allow unicode
82 # allow unicode
83 # disallow nan, because it's not actually valid JSON
83 # disallow nan, because it's not actually valid JSON
84 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
84 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
85 ensure_ascii=False, allow_nan=False,
85 ensure_ascii=False, allow_nan=False,
86 )
86 )
87 json_unpacker = lambda s: jsonapi.loads(s)
87 json_unpacker = lambda s: jsonapi.loads(s)
88
88
89 pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
89 pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
90 pickle_unpacker = pickle.loads
90 pickle_unpacker = pickle.loads
91
91
92 default_packer = json_packer
92 default_packer = json_packer
93 default_unpacker = json_unpacker
93 default_unpacker = json_unpacker
94
94
95 DELIM = b"<IDS|MSG>"
95 DELIM = b"<IDS|MSG>"
96 # singleton dummy tracker, which will always report as done
96 # singleton dummy tracker, which will always report as done
97 DONE = zmq.MessageTracker()
97 DONE = zmq.MessageTracker()
98
98
99 #-----------------------------------------------------------------------------
99 #-----------------------------------------------------------------------------
100 # Mixin tools for apps that use Sessions
100 # Mixin tools for apps that use Sessions
101 #-----------------------------------------------------------------------------
101 #-----------------------------------------------------------------------------
102
102
103 session_aliases = dict(
103 session_aliases = dict(
104 ident = 'Session.session',
104 ident = 'Session.session',
105 user = 'Session.username',
105 user = 'Session.username',
106 keyfile = 'Session.keyfile',
106 keyfile = 'Session.keyfile',
107 )
107 )
108
108
109 session_flags = {
109 session_flags = {
110 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
110 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
111 'keyfile' : '' }},
111 'keyfile' : '' }},
112 """Use HMAC digests for authentication of messages.
112 """Use HMAC digests for authentication of messages.
113 Setting this flag will generate a new UUID to use as the HMAC key.
113 Setting this flag will generate a new UUID to use as the HMAC key.
114 """),
114 """),
115 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
115 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
116 """Don't authenticate messages."""),
116 """Don't authenticate messages."""),
117 }
117 }
118
118
119 def default_secure(cfg):
119 def default_secure(cfg):
120 """Set the default behavior for a config environment to be secure.
120 """Set the default behavior for a config environment to be secure.
121
121
122 If Session.key/keyfile have not been set, set Session.key to
122 If Session.key/keyfile have not been set, set Session.key to
123 a new random UUID.
123 a new random UUID.
124 """
124 """
125 warnings.warn("default_secure is deprecated", DeprecationWarning)
125 warnings.warn("default_secure is deprecated", DeprecationWarning)
126 if 'Session' in cfg:
126 if 'Session' in cfg:
127 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
127 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
128 return
128 return
129 # key/keyfile not specified, generate new UUID:
129 # key/keyfile not specified, generate new UUID:
130 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
130 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
131
131
132
132
133 #-----------------------------------------------------------------------------
133 #-----------------------------------------------------------------------------
134 # Classes
134 # Classes
135 #-----------------------------------------------------------------------------
135 #-----------------------------------------------------------------------------
136
136
137 class SessionFactory(LoggingConfigurable):
137 class SessionFactory(LoggingConfigurable):
138 """The Base class for configurables that have a Session, Context, logger,
138 """The Base class for configurables that have a Session, Context, logger,
139 and IOLoop.
139 and IOLoop.
140 """
140 """
141
141
142 logname = Unicode('')
142 logname = Unicode('')
143 def _logname_changed(self, name, old, new):
143 def _logname_changed(self, name, old, new):
144 self.log = logging.getLogger(new)
144 self.log = logging.getLogger(new)
145
145
146 # not configurable:
146 # not configurable:
147 context = Instance('zmq.Context')
147 context = Instance('zmq.Context')
148 def _context_default(self):
148 def _context_default(self):
149 return zmq.Context.instance()
149 return zmq.Context.instance()
150
150
151 session = Instance('IPython.kernel.zmq.session.Session')
151 session = Instance('IPython.kernel.zmq.session.Session')
152
152
153 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
153 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
154 def _loop_default(self):
154 def _loop_default(self):
155 return IOLoop.instance()
155 return IOLoop.instance()
156
156
157 def __init__(self, **kwargs):
157 def __init__(self, **kwargs):
158 super(SessionFactory, self).__init__(**kwargs)
158 super(SessionFactory, self).__init__(**kwargs)
159
159
160 if self.session is None:
160 if self.session is None:
161 # construct the session
161 # construct the session
162 self.session = Session(**kwargs)
162 self.session = Session(**kwargs)
163
163
164
164
165 class Message(object):
165 class Message(object):
166 """A simple message object that maps dict keys to attributes.
166 """A simple message object that maps dict keys to attributes.
167
167
168 A Message can be created from a dict and a dict from a Message instance
168 A Message can be created from a dict and a dict from a Message instance
169 simply by calling dict(msg_obj)."""
169 simply by calling dict(msg_obj)."""
170
170
171 def __init__(self, msg_dict):
171 def __init__(self, msg_dict):
172 dct = self.__dict__
172 dct = self.__dict__
173 for k, v in iteritems(dict(msg_dict)):
173 for k, v in iteritems(dict(msg_dict)):
174 if isinstance(v, dict):
174 if isinstance(v, dict):
175 v = Message(v)
175 v = Message(v)
176 dct[k] = v
176 dct[k] = v
177
177
178 # Having this iterator lets dict(msg_obj) work out of the box.
178 # Having this iterator lets dict(msg_obj) work out of the box.
179 def __iter__(self):
179 def __iter__(self):
180 return iter(iteritems(self.__dict__))
180 return iter(iteritems(self.__dict__))
181
181
182 def __repr__(self):
182 def __repr__(self):
183 return repr(self.__dict__)
183 return repr(self.__dict__)
184
184
185 def __str__(self):
185 def __str__(self):
186 return pprint.pformat(self.__dict__)
186 return pprint.pformat(self.__dict__)
187
187
188 def __contains__(self, k):
188 def __contains__(self, k):
189 return k in self.__dict__
189 return k in self.__dict__
190
190
191 def __getitem__(self, k):
191 def __getitem__(self, k):
192 return self.__dict__[k]
192 return self.__dict__[k]
193
193
194
194
195 def msg_header(msg_id, msg_type, username, session):
195 def msg_header(msg_id, msg_type, username, session):
196 date = datetime.now()
196 date = datetime.now()
197 version = kernel_protocol_version
197 version = kernel_protocol_version
198 return locals()
198 return locals()
199
199
200 def extract_header(msg_or_header):
200 def extract_header(msg_or_header):
201 """Given a message or header, return the header."""
201 """Given a message or header, return the header."""
202 if not msg_or_header:
202 if not msg_or_header:
203 return {}
203 return {}
204 try:
204 try:
205 # See if msg_or_header is the entire message.
205 # See if msg_or_header is the entire message.
206 h = msg_or_header['header']
206 h = msg_or_header['header']
207 except KeyError:
207 except KeyError:
208 try:
208 try:
209 # See if msg_or_header is just the header
209 # See if msg_or_header is just the header
210 h = msg_or_header['msg_id']
210 h = msg_or_header['msg_id']
211 except KeyError:
211 except KeyError:
212 raise
212 raise
213 else:
213 else:
214 h = msg_or_header
214 h = msg_or_header
215 if not isinstance(h, dict):
215 if not isinstance(h, dict):
216 h = dict(h)
216 h = dict(h)
217 return h
217 return h
218
218
219 class Session(Configurable):
219 class Session(Configurable):
220 """Object for handling serialization and sending of messages.
220 """Object for handling serialization and sending of messages.
221
221
222 The Session object handles building messages and sending them
222 The Session object handles building messages and sending them
223 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
223 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
224 other over the network via Session objects, and only need to work with the
224 other over the network via Session objects, and only need to work with the
225 dict-based IPython message spec. The Session will handle
225 dict-based IPython message spec. The Session will handle
226 serialization/deserialization, security, and metadata.
226 serialization/deserialization, security, and metadata.
227
227
228 Sessions support configurable serialization via packer/unpacker traits,
228 Sessions support configurable serialization via packer/unpacker traits,
229 and signing with HMAC digests via the key/keyfile traits.
229 and signing with HMAC digests via the key/keyfile traits.
230
230
231 Parameters
231 Parameters
232 ----------
232 ----------
233
233
234 debug : bool
234 debug : bool
235 whether to trigger extra debugging statements
235 whether to trigger extra debugging statements
236 packer/unpacker : str : 'json', 'pickle' or import_string
236 packer/unpacker : str : 'json', 'pickle' or import_string
237 importstrings for methods to serialize message parts. If just
237 importstrings for methods to serialize message parts. If just
238 'json' or 'pickle', predefined JSON and pickle packers will be used.
238 'json' or 'pickle', predefined JSON and pickle packers will be used.
239 Otherwise, the entire importstring must be used.
239 Otherwise, the entire importstring must be used.
240
240
241 The functions must accept at least valid JSON input, and output *bytes*.
241 The functions must accept at least valid JSON input, and output *bytes*.
242
242
243 For example, to use msgpack:
243 For example, to use msgpack:
244 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
244 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
245 pack/unpack : callables
245 pack/unpack : callables
246 You can also set the pack/unpack callables for serialization directly.
246 You can also set the pack/unpack callables for serialization directly.
247 session : bytes
247 session : bytes
248 the ID of this Session object. The default is to generate a new UUID.
248 the ID of this Session object. The default is to generate a new UUID.
249 username : unicode
249 username : unicode
250 username added to message headers. The default is to ask the OS.
250 username added to message headers. The default is to ask the OS.
251 key : bytes
251 key : bytes
252 The key used to initialize an HMAC signature. If unset, messages
252 The key used to initialize an HMAC signature. If unset, messages
253 will not be signed or checked.
253 will not be signed or checked.
254 keyfile : filepath
254 keyfile : filepath
255 The file containing a key. If this is set, `key` will be initialized
255 The file containing a key. If this is set, `key` will be initialized
256 to the contents of the file.
256 to the contents of the file.
257
257
258 """
258 """
259
259
260 debug=Bool(False, config=True, help="""Debug output in the Session""")
260 debug=Bool(False, config=True, help="""Debug output in the Session""")
261
261
262 packer = DottedObjectName('json',config=True,
262 packer = DottedObjectName('json',config=True,
263 help="""The name of the packer for serializing messages.
263 help="""The name of the packer for serializing messages.
264 Should be one of 'json', 'pickle', or an import name
264 Should be one of 'json', 'pickle', or an import name
265 for a custom callable serializer.""")
265 for a custom callable serializer.""")
266 def _packer_changed(self, name, old, new):
266 def _packer_changed(self, name, old, new):
267 if new.lower() == 'json':
267 if new.lower() == 'json':
268 self.pack = json_packer
268 self.pack = json_packer
269 self.unpack = json_unpacker
269 self.unpack = json_unpacker
270 self.unpacker = new
270 self.unpacker = new
271 elif new.lower() == 'pickle':
271 elif new.lower() == 'pickle':
272 self.pack = pickle_packer
272 self.pack = pickle_packer
273 self.unpack = pickle_unpacker
273 self.unpack = pickle_unpacker
274 self.unpacker = new
274 self.unpacker = new
275 else:
275 else:
276 self.pack = import_item(str(new))
276 self.pack = import_item(str(new))
277
277
278 unpacker = DottedObjectName('json', config=True,
278 unpacker = DottedObjectName('json', config=True,
279 help="""The name of the unpacker for deserializing messages.
279 help="""The name of the unpacker for deserializing messages.
280 Only used with custom functions for `packer`.""")
280 Only used with custom functions for `packer`.""")
281 def _unpacker_changed(self, name, old, new):
281 def _unpacker_changed(self, name, old, new):
282 if new.lower() == 'json':
282 if new.lower() == 'json':
283 self.pack = json_packer
283 self.pack = json_packer
284 self.unpack = json_unpacker
284 self.unpack = json_unpacker
285 self.packer = new
285 self.packer = new
286 elif new.lower() == 'pickle':
286 elif new.lower() == 'pickle':
287 self.pack = pickle_packer
287 self.pack = pickle_packer
288 self.unpack = pickle_unpacker
288 self.unpack = pickle_unpacker
289 self.packer = new
289 self.packer = new
290 else:
290 else:
291 self.unpack = import_item(str(new))
291 self.unpack = import_item(str(new))
292
292
293 session = CUnicode(u'', config=True,
293 session = CUnicode(u'', config=True,
294 help="""The UUID identifying this session.""")
294 help="""The UUID identifying this session.""")
295 def _session_default(self):
295 def _session_default(self):
296 u = unicode_type(uuid.uuid4())
296 u = unicode_type(uuid.uuid4())
297 self.bsession = u.encode('ascii')
297 self.bsession = u.encode('ascii')
298 return u
298 return u
299
299
300 def _session_changed(self, name, old, new):
300 def _session_changed(self, name, old, new):
301 self.bsession = self.session.encode('ascii')
301 self.bsession = self.session.encode('ascii')
302
302
303 # bsession is the session as bytes
303 # bsession is the session as bytes
304 bsession = CBytes(b'')
304 bsession = CBytes(b'')
305
305
306 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
306 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
307 help="""Username for the Session. Default is your system username.""",
307 help="""Username for the Session. Default is your system username.""",
308 config=True)
308 config=True)
309
309
310 metadata = Dict({}, config=True,
310 metadata = Dict({}, config=True,
311 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
311 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
312
312
313 # if 0, no adapting to do.
313 # if 0, no adapting to do.
314 adapt_version = Integer(0)
314 adapt_version = Integer(0)
315
315
316 # message signature related traits:
316 # message signature related traits:
317
317
318 key = CBytes(config=True,
318 key = CBytes(config=True,
319 help="""execution key, for signing messages.""")
319 help="""execution key, for signing messages.""")
320 def _key_default(self):
320 def _key_default(self):
321 return str_to_bytes(str(uuid.uuid4()))
321 return str_to_bytes(str(uuid.uuid4()))
322
322
323 def _key_changed(self):
323 def _key_changed(self):
324 self._new_auth()
324 self._new_auth()
325
325
326 signature_scheme = Unicode('hmac-sha256', config=True,
326 signature_scheme = Unicode('hmac-sha256', config=True,
327 help="""The digest scheme used to construct the message signatures.
327 help="""The digest scheme used to construct the message signatures.
328 Must have the form 'hmac-HASH'.""")
328 Must have the form 'hmac-HASH'.""")
329 def _signature_scheme_changed(self, name, old, new):
329 def _signature_scheme_changed(self, name, old, new):
330 if not new.startswith('hmac-'):
330 if not new.startswith('hmac-'):
331 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
331 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
332 hash_name = new.split('-', 1)[1]
332 hash_name = new.split('-', 1)[1]
333 try:
333 try:
334 self.digest_mod = getattr(hashlib, hash_name)
334 self.digest_mod = getattr(hashlib, hash_name)
335 except AttributeError:
335 except AttributeError:
336 raise TraitError("hashlib has no such attribute: %s" % hash_name)
336 raise TraitError("hashlib has no such attribute: %s" % hash_name)
337 self._new_auth()
337 self._new_auth()
338
338
339 digest_mod = Any()
339 digest_mod = Any()
340 def _digest_mod_default(self):
340 def _digest_mod_default(self):
341 return hashlib.sha256
341 return hashlib.sha256
342
342
343 auth = Instance(hmac.HMAC)
343 auth = Instance(hmac.HMAC)
344
344
345 def _new_auth(self):
345 def _new_auth(self):
346 if self.key:
346 if self.key:
347 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
347 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
348 else:
348 else:
349 self.auth = None
349 self.auth = None
350
350
351 digest_history = Set()
351 digest_history = Set()
352 digest_history_size = Integer(2**16, config=True,
352 digest_history_size = Integer(2**16, config=True,
353 help="""The maximum number of digests to remember.
353 help="""The maximum number of digests to remember.
354
354
355 The digest history will be culled when it exceeds this value.
355 The digest history will be culled when it exceeds this value.
356 """
356 """
357 )
357 )
358
358
359 keyfile = Unicode('', config=True,
359 keyfile = Unicode('', config=True,
360 help="""path to file containing execution key.""")
360 help="""path to file containing execution key.""")
361 def _keyfile_changed(self, name, old, new):
361 def _keyfile_changed(self, name, old, new):
362 with open(new, 'rb') as f:
362 with open(new, 'rb') as f:
363 self.key = f.read().strip()
363 self.key = f.read().strip()
364
364
365 # for protecting against sends from forks
365 # for protecting against sends from forks
366 pid = Integer()
366 pid = Integer()
367
367
368 # serialization traits:
368 # serialization traits:
369
369
370 pack = Any(default_packer) # the actual packer function
370 pack = Any(default_packer) # the actual packer function
371 def _pack_changed(self, name, old, new):
371 def _pack_changed(self, name, old, new):
372 if not callable(new):
372 if not callable(new):
373 raise TypeError("packer must be callable, not %s"%type(new))
373 raise TypeError("packer must be callable, not %s"%type(new))
374
374
375 unpack = Any(default_unpacker) # the actual unpacker function
375 unpack = Any(default_unpacker) # the actual unpacker function
376 def _unpack_changed(self, name, old, new):
376 def _unpack_changed(self, name, old, new):
377 # unpacker is not checked - it is assumed to be
377 # unpacker is not checked - it is assumed to be
378 if not callable(new):
378 if not callable(new):
379 raise TypeError("unpacker must be callable, not %s"%type(new))
379 raise TypeError("unpacker must be callable, not %s"%type(new))
380
380
381 # thresholds:
381 # thresholds:
382 copy_threshold = Integer(2**16, config=True,
382 copy_threshold = Integer(2**16, config=True,
383 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
383 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
384 buffer_threshold = Integer(MAX_BYTES, config=True,
384 buffer_threshold = Integer(MAX_BYTES, config=True,
385 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
385 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
386 item_threshold = Integer(MAX_ITEMS, config=True,
386 item_threshold = Integer(MAX_ITEMS, config=True,
387 help="""The maximum number of items for a container to be introspected for custom serialization.
387 help="""The maximum number of items for a container to be introspected for custom serialization.
388 Containers larger than this are pickled outright.
388 Containers larger than this are pickled outright.
389 """
389 """
390 )
390 )
391
391
392
392
393 def __init__(self, **kwargs):
393 def __init__(self, **kwargs):
394 """create a Session object
394 """create a Session object
395
395
396 Parameters
396 Parameters
397 ----------
397 ----------
398
398
399 debug : bool
399 debug : bool
400 whether to trigger extra debugging statements
400 whether to trigger extra debugging statements
401 packer/unpacker : str : 'json', 'pickle' or import_string
401 packer/unpacker : str : 'json', 'pickle' or import_string
402 importstrings for methods to serialize message parts. If just
402 importstrings for methods to serialize message parts. If just
403 'json' or 'pickle', predefined JSON and pickle packers will be used.
403 'json' or 'pickle', predefined JSON and pickle packers will be used.
404 Otherwise, the entire importstring must be used.
404 Otherwise, the entire importstring must be used.
405
405
406 The functions must accept at least valid JSON input, and output
406 The functions must accept at least valid JSON input, and output
407 *bytes*.
407 *bytes*.
408
408
409 For example, to use msgpack:
409 For example, to use msgpack:
410 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
410 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
411 pack/unpack : callables
411 pack/unpack : callables
412 You can also set the pack/unpack callables for serialization
412 You can also set the pack/unpack callables for serialization
413 directly.
413 directly.
414 session : unicode (must be ascii)
414 session : unicode (must be ascii)
415 the ID of this Session object. The default is to generate a new
415 the ID of this Session object. The default is to generate a new
416 UUID.
416 UUID.
417 bsession : bytes
417 bsession : bytes
418 The session as bytes
418 The session as bytes
419 username : unicode
419 username : unicode
420 username added to message headers. The default is to ask the OS.
420 username added to message headers. The default is to ask the OS.
421 key : bytes
421 key : bytes
422 The key used to initialize an HMAC signature. If unset, messages
422 The key used to initialize an HMAC signature. If unset, messages
423 will not be signed or checked.
423 will not be signed or checked.
424 signature_scheme : str
424 signature_scheme : str
425 The message digest scheme. Currently must be of the form 'hmac-HASH',
425 The message digest scheme. Currently must be of the form 'hmac-HASH',
426 where 'HASH' is a hashing function available in Python's hashlib.
426 where 'HASH' is a hashing function available in Python's hashlib.
427 The default is 'hmac-sha256'.
427 The default is 'hmac-sha256'.
428 This is ignored if 'key' is empty.
428 This is ignored if 'key' is empty.
429 keyfile : filepath
429 keyfile : filepath
430 The file containing a key. If this is set, `key` will be
430 The file containing a key. If this is set, `key` will be
431 initialized to the contents of the file.
431 initialized to the contents of the file.
432 """
432 """
433 super(Session, self).__init__(**kwargs)
433 super(Session, self).__init__(**kwargs)
434 self._check_packers()
434 self._check_packers()
435 self.none = self.pack({})
435 self.none = self.pack({})
436 # ensure self._session_default() if necessary, so bsession is defined:
436 # ensure self._session_default() if necessary, so bsession is defined:
437 self.session
437 self.session
438 self.pid = os.getpid()
438 self.pid = os.getpid()
439 self._new_auth()
439 self._new_auth()
440
440
441 @property
441 @property
442 def msg_id(self):
442 def msg_id(self):
443 """always return new uuid"""
443 """always return new uuid"""
444 return str(uuid.uuid4())
444 return str(uuid.uuid4())
445
445
446 def _check_packers(self):
446 def _check_packers(self):
447 """check packers for datetime support."""
447 """check packers for datetime support."""
448 pack = self.pack
448 pack = self.pack
449 unpack = self.unpack
449 unpack = self.unpack
450
450
451 # check simple serialization
451 # check simple serialization
452 msg = dict(a=[1,'hi'])
452 msg = dict(a=[1,'hi'])
453 try:
453 try:
454 packed = pack(msg)
454 packed = pack(msg)
455 except Exception as e:
455 except Exception as e:
456 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
456 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
457 if self.packer == 'json':
457 if self.packer == 'json':
458 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
458 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
459 else:
459 else:
460 jsonmsg = ""
460 jsonmsg = ""
461 raise ValueError(
461 raise ValueError(
462 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
462 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
463 )
463 )
464
464
465 # ensure packed message is bytes
465 # ensure packed message is bytes
466 if not isinstance(packed, bytes):
466 if not isinstance(packed, bytes):
467 raise ValueError("message packed to %r, but bytes are required"%type(packed))
467 raise ValueError("message packed to %r, but bytes are required"%type(packed))
468
468
469 # check that unpack is pack's inverse
469 # check that unpack is pack's inverse
470 try:
470 try:
471 unpacked = unpack(packed)
471 unpacked = unpack(packed)
472 assert unpacked == msg
472 assert unpacked == msg
473 except Exception as e:
473 except Exception as e:
474 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
474 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
475 if self.packer == 'json':
475 if self.packer == 'json':
476 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
476 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
477 else:
477 else:
478 jsonmsg = ""
478 jsonmsg = ""
479 raise ValueError(
479 raise ValueError(
480 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
480 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
481 )
481 )
482
482
483 # check datetime support
483 # check datetime support
484 msg = dict(t=datetime.now())
484 msg = dict(t=datetime.now())
485 try:
485 try:
486 unpacked = unpack(pack(msg))
486 unpacked = unpack(pack(msg))
487 if isinstance(unpacked['t'], datetime):
487 if isinstance(unpacked['t'], datetime):
488 raise ValueError("Shouldn't deserialize to datetime")
488 raise ValueError("Shouldn't deserialize to datetime")
489 except Exception:
489 except Exception:
490 self.pack = lambda o: pack(squash_dates(o))
490 self.pack = lambda o: pack(squash_dates(o))
491 self.unpack = lambda s: unpack(s)
491 self.unpack = lambda s: unpack(s)
492
492
493 def msg_header(self, msg_type):
493 def msg_header(self, msg_type):
494 return msg_header(self.msg_id, msg_type, self.username, self.session)
494 return msg_header(self.msg_id, msg_type, self.username, self.session)
495
495
496 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
496 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
497 """Return the nested message dict.
497 """Return the nested message dict.
498
498
499 This format is different from what is sent over the wire. The
499 This format is different from what is sent over the wire. The
500 serialize/deserialize methods convert this nested message dict to the wire
500 serialize/deserialize methods convert this nested message dict to the wire
501 format, which is a list of message parts.
501 format, which is a list of message parts.
502 """
502 """
503 msg = {}
503 msg = {}
504 header = self.msg_header(msg_type) if header is None else header
504 header = self.msg_header(msg_type) if header is None else header
505 msg['header'] = header
505 msg['header'] = header
506 msg['msg_id'] = header['msg_id']
506 msg['msg_id'] = header['msg_id']
507 msg['msg_type'] = header['msg_type']
507 msg['msg_type'] = header['msg_type']
508 msg['parent_header'] = {} if parent is None else extract_header(parent)
508 msg['parent_header'] = {} if parent is None else extract_header(parent)
509 msg['content'] = {} if content is None else content
509 msg['content'] = {} if content is None else content
510 msg['metadata'] = self.metadata.copy()
510 msg['metadata'] = self.metadata.copy()
511 if metadata is not None:
511 if metadata is not None:
512 msg['metadata'].update(metadata)
512 msg['metadata'].update(metadata)
513 return msg
513 return msg
514
514
515 def sign(self, msg_list):
515 def sign(self, msg_list):
516 """Sign a message with HMAC digest. If no auth, return b''.
516 """Sign a message with HMAC digest. If no auth, return b''.
517
517
518 Parameters
518 Parameters
519 ----------
519 ----------
520 msg_list : list
520 msg_list : list
521 The [p_header,p_parent,p_content] part of the message list.
521 The [p_header,p_parent,p_content] part of the message list.
522 """
522 """
523 if self.auth is None:
523 if self.auth is None:
524 return b''
524 return b''
525 h = self.auth.copy()
525 h = self.auth.copy()
526 for m in msg_list:
526 for m in msg_list:
527 h.update(m)
527 h.update(m)
528 return str_to_bytes(h.hexdigest())
528 return str_to_bytes(h.hexdigest())
529
529
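sign() computes an HMAC over the packed message parts and returns the hex digest as bytes. A minimal sketch of the same computation with the standard library, assuming an invented key and already-packed parts:

import hashlib
import hmac

key = b'my-session-key'                               # normally a random UUID
parts = [b'{"msg_id": "1"}', b'{}', b'{}', b'{}']     # p_header, p_parent, p_metadata, p_content
h = hmac.HMAC(key, digestmod=hashlib.sha256)          # the default hmac-sha256 scheme
for p in parts:
    h.update(p)
signature = h.hexdigest().encode('ascii')             # sign() uses str_to_bytes for this step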
530 def serialize(self, msg, ident=None):
530 def serialize(self, msg, ident=None):
531 """Serialize the message components to bytes.
531 """Serialize the message components to bytes.
532
532
533 This is roughly the inverse of deserialize. The serialize/deserialize
533 This is roughly the inverse of deserialize. The serialize/deserialize
534 methods work with full message lists, whereas pack/unpack work with
534 methods work with full message lists, whereas pack/unpack work with
535 the individual message parts in the message list.
535 the individual message parts in the message list.
536
536
537 Parameters
537 Parameters
538 ----------
538 ----------
539 msg : dict or Message
539 msg : dict or Message
540 The nested message dict as returned by the self.msg method.
540 The nested message dict as returned by the self.msg method.
541
541
542 Returns
542 Returns
543 -------
543 -------
544 msg_list : list
544 msg_list : list
545 The list of bytes objects to be sent with the format::
545 The list of bytes objects to be sent with the format::
546
546
547 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
547 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
548 p_metadata, p_content, buffer1, buffer2, ...]
548 p_metadata, p_content, buffer1, buffer2, ...]
549
549
550 In this list, the ``p_*`` entities are the packed or serialized
550 In this list, the ``p_*`` entities are the packed or serialized
551 versions, so if JSON is used, these are utf8 encoded JSON strings.
551 versions, so if JSON is used, these are utf8 encoded JSON strings.
552 """
552 """
553 content = msg.get('content', {})
553 content = msg.get('content', {})
554 if content is None:
554 if content is None:
555 content = self.none
555 content = self.none
556 elif isinstance(content, dict):
556 elif isinstance(content, dict):
557 content = self.pack(content)
557 content = self.pack(content)
558 elif isinstance(content, bytes):
558 elif isinstance(content, bytes):
559 # content is already packed, as in a relayed message
559 # content is already packed, as in a relayed message
560 pass
560 pass
561 elif isinstance(content, unicode_type):
561 elif isinstance(content, unicode_type):
562 # should be bytes, but JSON often spits out unicode
562 # should be bytes, but JSON often spits out unicode
563 content = content.encode('utf8')
563 content = content.encode('utf8')
564 else:
564 else:
565 raise TypeError("Content incorrect type: %s"%type(content))
565 raise TypeError("Content incorrect type: %s"%type(content))
566
566
567 real_message = [self.pack(msg['header']),
567 real_message = [self.pack(msg['header']),
568 self.pack(msg['parent_header']),
568 self.pack(msg['parent_header']),
569 self.pack(msg['metadata']),
569 self.pack(msg['metadata']),
570 content,
570 content,
571 ]
571 ]
572
572
573 to_send = []
573 to_send = []
574
574
575 if isinstance(ident, list):
575 if isinstance(ident, list):
576 # accept list of idents
576 # accept list of idents
577 to_send.extend(ident)
577 to_send.extend(ident)
578 elif ident is not None:
578 elif ident is not None:
579 to_send.append(ident)
579 to_send.append(ident)
580 to_send.append(DELIM)
580 to_send.append(DELIM)
581
581
582 signature = self.sign(real_message)
582 signature = self.sign(real_message)
583 to_send.append(signature)
583 to_send.append(signature)
584
584
585 to_send.extend(real_message)
585 to_send.extend(real_message)
586
586
587 return to_send
587 return to_send
588
588
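    # Illustrative sketch of the list serialize() returns, assuming a
    # hypothetical routing ident b'client-1' and the default JSON packer:
    #
    #     to_send = session.serialize(msg, ident=b'client-1')
    #     # [b'client-1', DELIM, signature,
    #     #  p_header, p_parent, p_metadata, p_content]
    #
    # Binary buffers are not handled here; send() appends them after the
    # packed content frame.
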
    def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
             buffers=None, track=False, header=None, metadata=None):
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
         buffer1,buffer2,...]

        The serialize/deserialize methods convert the nested message dict into this
        format.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            The socket-like object used to send the data.
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once. If a header is supplied, this can be set to
            None and the msg_type will be pulled from the header.

        content : dict or None
            The content of the message (ignored if msg_or_type is a message).
        header : dict or None
            The header dict for the message (ignored if msg_or_type is a message).
        parent : Message or dict or None
            The parent or parent header describing the parent of this message
            (ignored if msg_or_type is a message).
        ident : bytes or list of bytes
            The zmq.IDENTITY routing path.
        metadata : dict or None
            The metadata describing the message.
        buffers : list or None
            The already-serialized buffers to be appended to the message.
        track : bool
            Whether to track. Only for use with Sockets, because ZMQStream
            objects cannot track messages.


        Returns
        -------
        msg : dict
            The constructed message.
        """
        if not isinstance(stream, zmq.Socket):
            # ZMQStreams and dummy sockets do not support tracking.
            track = False

        if isinstance(msg_or_type, (Message, dict)):
            # We got a Message or message dict, not a msg_type so don't
            # build a new Message.
            msg = msg_or_type
            buffers = buffers or msg.get('buffers', [])
        else:
            msg = self.msg(msg_or_type, content=content, parent=parent,
                           header=header, metadata=metadata)
        if not os.getpid() == self.pid:
            io.rprint("WARNING: attempted to send message from fork")
            io.rprint(msg)
            return
        buffers = [] if buffers is None else buffers
        if self.adapt_version:
            msg = adapt(msg, self.adapt_version)
        to_send = self.serialize(msg, ident)
        to_send.extend(buffers)
        longest = max([ len(s) for s in to_send ])
        copy = (longest < self.copy_threshold)

        if buffers and track and not copy:
            # only really track when we are doing zero-copy buffers
            tracker = stream.send_multipart(to_send, copy=False, track=True)
        else:
            # use dummy tracker, which will be done immediately
            tracker = DONE
            stream.send_multipart(to_send, copy=copy)

        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        msg['tracker'] = tracker

        return msg

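    # Illustrative usage sketch; ``session``, ``stream`` and the msg_type below
    # are assumptions, not part of this module:
    #
    #     msg = session.send(stream, 'my_custom_msg',
    #                        content={'value': 1},
    #                        buffers=[b'\x00\x01\x02\x03'],
    #                        track=True)
    #     msg['tracker']   # a zmq MessageTracker only for zero-copy Socket
    #                      # sends; otherwise the DONE dummy tracker
    #
    # Tracking only happens for zero-copy sends on plain zmq.Socket objects,
    # since ZMQStream cannot track outgoing messages.
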
    def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        This method is used to send an already serialized message.

        Parameters
        ----------
        stream : ZMQStream or Socket
            The ZMQ stream or socket to use for sending the message.
        msg_list : list
            The serialized list of messages to send. This only includes the
            [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
            the message.
        ident : ident or list
            A single ident or a list of idents to use in sending.
        """
        to_send = []
        if isinstance(ident, bytes):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        to_send.append(self.sign(msg_list))
        to_send.extend(msg_list)
        stream.send_multipart(to_send, flags, copy=copy)

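    # Illustrative relay sketch (assumed names): forwarding already-packed
    # frames on another stream, letting send_raw() prepend idents, DELIM and a
    # fresh signature:
    #
    #     parts = [p_header, p_parent, p_metadata, p_content]
    #     session.send_raw(out_stream, parts, ident=b'engine-0')
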
    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns.
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg_list = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None, None
            else:
                raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        try:
            return idents, self.deserialize(msg_list, content=content, copy=copy)
        except Exception as e:
            # TODO: handle it
            raise e

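    # Illustrative receive sketch; ``session``, ``sock`` and ``handler`` are
    # assumed names:
    #
    #     idents, msg = session.recv(sock, mode=0)   # mode=0 blocks
    #     if msg is not None:
    #         handler(msg['header'], msg['content'], msg['buffers'])
    #
    # With the default zmq.NOBLOCK mode, (None, None) is returned when no
    # message is waiting instead of raising EAGAIN.
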
    def feed_identities(self, msg_list, copy=True):
        """Split the identities from the rest of the message.

        Feed until DELIM is reached, then return the prefix as idents and
        remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
        but that would be silly.

        Parameters
        ----------
        msg_list : a list of Message or bytes objects
            The message to be split.
        copy : bool
            flag determining whether the arguments are bytes or Messages

        Returns
        -------
        (idents, msg_list) : two lists
            idents will always be a list of bytes, each of which is a ZMQ
            identity. msg_list will be a list of bytes or zmq.Messages of the
            form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
            should be unpackable/unserializable via self.deserialize at this
            point.
        """
        if copy:
            idx = msg_list.index(DELIM)
            return msg_list[:idx], msg_list[idx+1:]
        else:
            failed = True
            for idx, m in enumerate(msg_list):
                if m.bytes == DELIM:
                    failed = False
                    break
            if failed:
                raise ValueError("DELIM not in msg_list")
            idents, msg_list = msg_list[:idx], msg_list[idx+1:]
            return [m.bytes for m in idents], msg_list

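    # Illustrative split with copy=True, so every frame is plain bytes
    # (placeholder values):
    #
    #     wire = [b'routing-id', DELIM, b'<hmac>', b'<p_header>', b'<p_parent>',
    #             b'<p_metadata>', b'<p_content>']
    #     idents, parts = session.feed_identities(wire)
    #     # idents == [b'routing-id']
    #     # parts  == [b'<hmac>', b'<p_header>', b'<p_parent>',
    #     #            b'<p_metadata>', b'<p_content>']
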
    def _add_digest(self, signature):
        """add a digest to history to protect against replay attacks"""
        if self.digest_history_size == 0:
            # no history, never add digests
            return

        self.digest_history.add(signature)
        if len(self.digest_history) > self.digest_history_size:
            # threshold reached, cull 10%
            self._cull_digest_history()

    def _cull_digest_history(self):
        """cull the digest history

        Removes a randomly selected 10% of the digest history
        """
        current = len(self.digest_history)
        n_to_cull = max(int(current // 10), current - self.digest_history_size)
        if n_to_cull >= current:
            self.digest_history = set()
            return
        to_cull = random.sample(self.digest_history, n_to_cull)
        self.digest_history.difference_update(to_cull)

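    # Worked example of the culling arithmetic: with digest_history_size = 1000
    # and 1001 stored digests, n_to_cull = max(1001 // 10, 1001 - 1000) = 100,
    # so roughly 10% of the history is dropped at random once the threshold is
    # crossed.
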
    def deserialize(self, msg_list, content=True, copy=True):
        """Deserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_metadata,p_content,buffer1,buffer2,...].
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            Whether msg_list contains bytes (True) or the non-copying Message
            objects in each place (False).

        Returns
        -------
        msg : dict
            The nested message dict with top-level keys [header, parent_header,
            content, buffers]. The buffers are returned as memoryviews.
        """
        minlen = 5
        message = {}
        if not copy:
            # pyzmq didn't copy the first parts of the message, so we'll do it
            for i in range(minlen):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            signature = msg_list[0]
            if not signature:
                raise ValueError("Unsigned Message")
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            self._add_digest(signature)
            check = self.sign(msg_list[1:5])
            if not compare_digest(signature, check):
                raise ValueError("Invalid Signature: %r" % signature)
        if not len(msg_list) >= minlen:
            raise TypeError("malformed message, must have at least %i elements"%minlen)
        header = self.unpack(msg_list[1])
        message['header'] = extract_dates(header)
        message['msg_id'] = header['msg_id']
        message['msg_type'] = header['msg_type']
        message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
        message['metadata'] = self.unpack(msg_list[3])
        if content:
            message['content'] = self.unpack(msg_list[4])
        else:
            message['content'] = msg_list[4]

        buffers = [memoryview(b) for b in msg_list[5:]]
        if buffers and buffers[0].shape is None:
            # force copy to workaround pyzmq #646
            buffers = [memoryview(b.bytes) for b in msg_list[5:]]
        message['buffers'] = buffers
        # adapt to the current version
        return adapt(message)

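    # Sketch of the buffer handling above; ``session`` and ``sock`` are assumed
    # names. Each trailing frame comes back as a memoryview, falling back to a
    # copied view when pyzmq returns a memoryview without shape information
    # (pyzmq #646):
    #
    #     idents, msg = session.recv(sock)
    #     view = msg['buffers'][0]    # memoryview over the received frame
    #     data = view.tobytes()       # concrete bytes copy when needed
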
    def unserialize(self, *args, **kwargs):
        warnings.warn(
            "Session.unserialize is deprecated. Use Session.deserialize.",
            DeprecationWarning,
        )
        return self.deserialize(*args, **kwargs)


def test_msg2obj():
    am = dict(x=1)
    ao = Message(am)
    assert ao.x == am['x']

    am['y'] = dict(z=1)
    ao = Message(am)
    assert ao.y.z == am['y']['z']

    k1, k2 = 'y', 'z'
    assert ao[k1][k2] == am[k1][k2]

    am2 = dict(ao)
    assert am['x'] == am2['x']
    assert am['y']['z'] == am2['y']['z']