##// END OF EJS Templates
Merge pull request #7798 from jasongrout/buffer-memoryview...
Min RK -
r20764:84817ef0 merge
parent child Browse files
Show More
@@ -1,278 +1,281 b''
1 1 # coding: utf-8
2 2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
3 3
4 4 # Copyright (c) IPython Development Team.
5 5 # Distributed under the terms of the Modified BSD License.
6 6
7 7 import os
8 8 import json
9 9 import struct
10 10 import warnings
11 import sys
11 12
12 13 try:
13 14 from urllib.parse import urlparse # Py 3
14 15 except ImportError:
15 16 from urlparse import urlparse # Py 2
16 17
17 18 import tornado
18 19 from tornado import gen, ioloop, web
19 20 from tornado.websocket import WebSocketHandler
20 21
21 22 from IPython.kernel.zmq.session import Session
22 23 from IPython.utils.jsonutil import date_default, extract_dates
23 24 from IPython.utils.py3compat import cast_unicode
24 25
25 26 from .handlers import IPythonHandler
26 27
def serialize_binary_message(msg):
    """serialize a message as a binary blob

    Header:

    4 bytes: number of msg parts (nbufs) as 32b int
    4 * nbufs bytes: offset for each buffer as integer as 32b int

    Offsets are from the start of the buffer, including the header.

    Returns
    -------

    The message serialized to bytes.

    """
    # don't modify msg or buffer list in-place
    msg = msg.copy()
    buffers = list(msg.pop('buffers'))
    if sys.version_info < (3, 4):
        # bytes(memoryview) doesn't produce the raw payload before
        # Python 3.4, so flatten memoryviews explicitly.  Plain bytes
        # objects have no .tobytes(), so leave them untouched.
        buffers = [
            x.tobytes() if isinstance(x, memoryview) else x
            for x in buffers
        ]
    bmsg = json.dumps(msg, default=date_default).encode('utf8')
    # the JSON dict travels as the first "buffer"
    buffers.insert(0, bmsg)
    nbufs = len(buffers)
    # header size: 4 bytes for nbufs plus 4 bytes per offset entry
    offsets = [4 * (nbufs + 1)]
    for buf in buffers[:-1]:
        offsets.append(offsets[-1] + len(buf))
    offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
    buffers.insert(0, offsets_buf)
    return b''.join(buffers)
55 58
56 59
def deserialize_binary_message(bmsg):
    """deserialize a message from a binary blob

    Header:

    4 bytes: number of msg parts (nbufs) as 32b int
    4 * nbufs bytes: offset for each buffer as integer as 32b int

    Offsets are from the start of the buffer, including the header.

    Returns
    -------

    message dictionary
    """
    # nbufs is packed unsigned ('!I') by serialize_binary_message;
    # unpack it the same way for symmetry.
    nbufs = struct.unpack('!I', bmsg[:4])[0]
    offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
    # sentinel so the last buffer slices to the end of the blob
    offsets.append(None)
    bufs = []
    for start, stop in zip(offsets[:-1], offsets[1:]):
        bufs.append(bmsg[start:stop])
    # first part is the JSON-encoded message dict; the rest are raw buffers
    msg = json.loads(bufs[0].decode('utf8'))
    msg['header'] = extract_dates(msg['header'])
    msg['parent_header'] = extract_dates(msg['parent_header'])
    msg['buffers'] = bufs[1:]
    return msg
83 86
# ping interval for keeping websockets alive (30 seconds)
WS_PING_INTERVAL = 30000

# Testing escape hatch: phantomjs only speaks the obsolete draft76
# websocket protocol, so test runs may explicitly opt in to it.
if os.environ.get('IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS', False):
    warnings.warn("""Allowing draft76 websocket connections!
    This should only be done for testing with phantomjs!""")
    from IPython.html import allow76
    # shadow the tornado handler with the draft76-tolerant one
    WebSocketHandler = allow76.AllowDraftWebSocketHandler
    # draft 76 doesn't support ping
    WS_PING_INTERVAL = 0
94 97
class ZMQStreamHandler(WebSocketHandler):
    """WebSocket handler that relays messages between ZMQ streams and a browser.

    Handles origin checking and reserialization of ZMQ messages into
    JSON (or a binary blob, when the message carries buffers).
    """

    if tornado.version_info < (4,1):
        """Backport send_error from tornado 4.1 to 4.0"""
        def send_error(self, *args, **kwargs):
            if self.stream is None:
                super(WebSocketHandler, self).send_error(*args, **kwargs)
            else:
                # If we get an uncaught exception during the handshake,
                # we have no choice but to abruptly close the connection.
                # TODO: for uncaught exceptions after the handshake,
                # we can close the connection more gracefully.
                self.stream.close()


    def check_origin(self, origin):
        """Check Origin == Host or Access-Control-Allow-Origin.

        Tornado >= 4 calls this method automatically, raising 403 if it returns False.
        """
        if self.allow_origin == '*':
            return True

        host = self.request.headers.get("Host")

        # If no header is provided, assume we can't verify origin
        if origin is None:
            self.log.warn("Missing Origin header, rejecting WebSocket connection.")
            return False
        if host is None:
            self.log.warn("Missing Host header, rejecting WebSocket connection.")
            return False

        # compare the network location of the (lowercased) origin to the Host
        origin = origin.lower()
        origin_host = urlparse(origin).netloc

        # OK if origin matches host
        if origin_host == host:
            return True

        # Check CORS headers
        if self.allow_origin:
            allow = self.allow_origin == origin
        elif self.allow_origin_pat:
            allow = bool(self.allow_origin_pat.match(origin))
        else:
            # No CORS headers deny the request
            allow = False
        if not allow:
            self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
                origin, host,
            )
        return allow

    def clear_cookie(self, *args, **kwargs):
        """meaningless for websockets"""
        pass

    def _reserialize_reply(self, msg_list, channel=None):
        """Reserialize a reply message using JSON.

        This takes the msg list from the ZMQ socket, deserializes it using
        self.session and then serializes the result using JSON. This method
        should be used by self._on_zmq_reply to build messages that can
        be sent back to the browser.
        """
        idents, msg_list = self.session.feed_identities(msg_list)
        msg = self.session.deserialize(msg_list)
        if channel:
            msg['channel'] = channel
        if msg['buffers']:
            # messages with binary buffers travel as one binary blob
            buf = serialize_binary_message(msg)
            return buf
        else:
            smsg = json.dumps(msg, default=date_default)
            return cast_unicode(smsg)

    def _on_zmq_reply(self, stream, msg_list):
        # Sometimes this gets triggered when the on_close method is scheduled in the
        # eventloop but hasn't been called.
        if self.stream.closed() or stream.closed():
            self.log.warn("zmq message arrived on closed channel")
            self.close()
            return
        channel = getattr(stream, 'channel', None)
        try:
            msg = self._reserialize_reply(msg_list, channel=channel)
        except Exception:
            self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
        else:
            # bytes => binary websocket frame; unicode => text frame
            self.write_message(msg, binary=isinstance(msg, bytes))
186 189
class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
    """ZMQStreamHandler that authenticates the request and keeps the
    websocket alive with periodic pings."""

    # PeriodicCallback driving send_ping; None until open() starts it
    ping_callback = None
    # ioloop timestamps (seconds) of the last ping sent / pong received
    last_ping = 0
    last_pong = 0

    @property
    def ping_interval(self):
        """The interval for websocket keep-alive pings.

        Set ws_ping_interval = 0 to disable pings.
        """
        return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)

    @property
    def ping_timeout(self):
        """If no ping is received in this many milliseconds,
        close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
        Default is max of 3 pings or 30 seconds.
        """
        return self.settings.get('ws_ping_timeout',
            max(3 * self.ping_interval, WS_PING_INTERVAL)
        )

    def set_default_headers(self):
        """Undo the set_default_headers in IPythonHandler

        which doesn't make sense for websockets
        """
        pass

    def pre_get(self):
        """Run before finishing the GET request

        Extend this method to add logic that should fire before
        the websocket finishes completing.
        """
        # authenticate the request before opening the websocket
        if self.get_current_user() is None:
            self.log.warn("Couldn't authenticate WebSocket connection")
            raise web.HTTPError(403)

        if self.get_argument('session_id', False):
            self.session.session = cast_unicode(self.get_argument('session_id'))
        else:
            self.log.warn("No session ID specified")

    @gen.coroutine
    def get(self, *args, **kwargs):
        # pre_get can be a coroutine in subclasses
        # assign and yield in two step to avoid tornado 3 issues
        res = self.pre_get()
        yield gen.maybe_future(res)
        super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)

    def initialize(self):
        self.log.debug("Initializing websocket connection %s", self.request.path)
        self.session = Session(config=self.config)

    def open(self, *args, **kwargs):
        self.log.debug("Opening websocket %s", self.request.path)

        # start the pinging
        if self.ping_interval > 0:
            loop = ioloop.IOLoop.current()
            self.last_ping = loop.time()  # Remember time of last ping
            self.last_pong = self.last_ping
            self.ping_callback = ioloop.PeriodicCallback(
                self.send_ping, self.ping_interval, io_loop=loop,
            )
            self.ping_callback.start()

    def send_ping(self):
        """send a ping to keep the websocket alive"""
        if self.stream.closed() and self.ping_callback is not None:
            self.ping_callback.stop()
            return

        # check for timeout on pong. Make sure that we really have sent a recent ping in
        # case the machine with both server and client has been suspended since the last ping.
        now = ioloop.IOLoop.current().time()
        since_last_pong = 1e3 * (now - self.last_pong)
        since_last_ping = 1e3 * (now - self.last_ping)
        if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
            self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
            self.close()
            return

        self.ping(b'')
        self.last_ping = now

    def on_pong(self, data):
        # record when the client last answered a ping
        self.last_pong = ioloop.IOLoop.current().time()
@@ -1,26 +1,26 b''
1 1 """Test serialize/deserialize messages with buffers"""
2 2
3 3 import os
4 4
5 5 import nose.tools as nt
6 6
7 7 from IPython.kernel.zmq.session import Session
8 8 from ..base.zmqhandlers import (
9 9 serialize_binary_message,
10 10 deserialize_binary_message,
11 11 )
12 12
def test_serialize_binary():
    """A message with memoryview buffers serializes to a bytes blob."""
    session = Session()
    msg = session.msg('data_pub', content={'a': 'b'})
    msg['buffers'] = [memoryview(os.urandom(3)) for _ in range(3)]
    nt.assert_is_instance(serialize_binary_message(msg), bytes)
19 19
def test_deserialize_binary():
    """deserialize_binary_message inverts serialize_binary_message."""
    session = Session()
    msg = session.msg('data_pub', content={'a': 'b'})
    msg['buffers'] = [memoryview(os.urandom(2)) for _ in range(3)]
    roundtripped = deserialize_binary_message(serialize_binary_message(msg))
    nt.assert_equal(roundtripped, msg)
@@ -1,877 +1,881 b''
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9 9 """
10 10
11 11 # Copyright (c) IPython Development Team.
12 12 # Distributed under the terms of the Modified BSD License.
13 13
14 14 import hashlib
15 15 import hmac
16 16 import logging
17 17 import os
18 18 import pprint
19 19 import random
20 20 import uuid
21 21 import warnings
22 22 from datetime import datetime
23 23
24 24 try:
25 25 import cPickle
26 26 pickle = cPickle
27 27 except:
28 28 cPickle = None
29 29 import pickle
30 30
31 31 try:
32 32 # We are using compare_digest to limit the surface of timing attacks
33 33 from hmac import compare_digest
34 34 except ImportError:
35 35 # Python < 2.7.7: When digests don't match no feedback is provided,
36 36 # limiting the surface of attack
37 37 def compare_digest(a,b): return a == b
38 38
39 39 import zmq
40 40 from zmq.utils import jsonapi
41 41 from zmq.eventloop.ioloop import IOLoop
42 42 from zmq.eventloop.zmqstream import ZMQStream
43 43
44 44 from IPython.core.release import kernel_protocol_version
45 45 from IPython.config.configurable import Configurable, LoggingConfigurable
46 46 from IPython.utils import io
47 47 from IPython.utils.importstring import import_item
48 48 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
49 49 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
50 50 iteritems)
51 51 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
52 52 DottedObjectName, CUnicode, Dict, Integer,
53 53 TraitError,
54 54 )
55 55 from IPython.utils.pickleutil import PICKLE_PROTOCOL
56 56 from IPython.kernel.adapter import adapt
57 57 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
58 58
59 59 #-----------------------------------------------------------------------------
60 60 # utility functions
61 61 #-----------------------------------------------------------------------------
62 62
def squash_unicode(obj):
    """coerce unicode back to bytestrings.

    Recursively walks dicts and lists in-place, utf8-encoding every
    unicode value (and dict key) it finds, and returns the result.
    """
    if isinstance(obj,dict):
        # NOTE(review): this pops keys while iterating obj.keys(); on
        # Python 3 the view would raise during mutation — presumably this
        # path only ever sees bytes keys there.  Verify before touching.
        for key in obj.keys():
            obj[key] = squash_unicode(obj[key])
            if isinstance(key, unicode_type):
                # re-insert under the encoded key, dropping the unicode one
                obj[squash_unicode(key)] = obj.pop(key)
    elif isinstance(obj, list):
        for i,v in enumerate(obj):
            obj[i] = squash_unicode(v)
    elif isinstance(obj, unicode_type):
        obj = obj.encode('utf8')
    return obj
76 76
77 77 #-----------------------------------------------------------------------------
78 78 # globals and defaults
79 79 #-----------------------------------------------------------------------------
80 80
# ISO8601-ify datetime objects
# allow unicode
# disallow nan, because it's not actually valid JSON
json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
    ensure_ascii=False, allow_nan=False,
)
json_unpacker = lambda s: jsonapi.loads(s)

# pickle packer squashes datetimes to ISO8601 strings before dumping,
# so both serializers agree on the wire format for dates
pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
pickle_unpacker = pickle.loads

# JSON is the default wire format
default_packer = json_packer
default_unpacker = json_unpacker

# delimiter separating ZMQ routing identities from the message frames
DELIM = b"<IDS|MSG>"
# singleton dummy tracker, which will always report as done
DONE = zmq.MessageTracker()
98 98
99 99 #-----------------------------------------------------------------------------
100 100 # Mixin tools for apps that use Sessions
101 101 #-----------------------------------------------------------------------------
102 102
# command-line aliases mapping short option names to Session traits
session_aliases = dict(
    ident = 'Session.session',
    user = 'Session.username',
    keyfile = 'Session.keyfile',
)

# command-line flags toggling message authentication
session_flags = {
    'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
                            'keyfile' : '' }},
        """Use HMAC digests for authentication of messages.
        Setting this flag will generate a new UUID to use as the HMAC key.
        """),
    'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
        """Don't authenticate messages."""),
}
118 118
def default_secure(cfg):
    """Set the default behavior for a config environment to be secure.

    If Session.key/keyfile have not been set, set Session.key to
    a new random UUID.
    """
    warnings.warn("default_secure is deprecated", DeprecationWarning)
    # respect an explicitly configured key or keyfile
    if 'Session' in cfg and ('key' in cfg.Session or 'keyfile' in cfg.Session):
        return
    # key/keyfile not specified, generate new UUID:
    cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
132 132
133 133 #-----------------------------------------------------------------------------
134 134 # Classes
135 135 #-----------------------------------------------------------------------------
136 136
class SessionFactory(LoggingConfigurable):
    """The Base class for configurables that have a Session, Context, logger,
    and IOLoop.
    """

    # name of the logger to use; changing it swaps self.log
    logname = Unicode('')
    def _logname_changed(self, name, old, new):
        self.log = logging.getLogger(new)

    # not configurable:
    context = Instance('zmq.Context')
    def _context_default(self):
        # share the process-wide ZMQ context
        return zmq.Context.instance()

    session = Instance('IPython.kernel.zmq.session.Session')

    loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
    def _loop_default(self):
        return IOLoop.instance()

    def __init__(self, **kwargs):
        super(SessionFactory, self).__init__(**kwargs)

        if self.session is None:
            # construct the session
            self.session = Session(**kwargs)
164 164
class Message(object):
    """A simple message object that maps dict keys to attributes.

    A Message can be created from a dict and a dict from a Message instance
    simply by calling dict(msg_obj)."""

    def __init__(self, msg_dict):
        # Recursively wrap nested dicts so attribute access works at depth.
        for key, value in iteritems(dict(msg_dict)):
            self.__dict__[key] = Message(value) if isinstance(value, dict) else value

    def __iter__(self):
        # Having this iterator lets dict(msg_obj) work out of the box.
        return iter(iteritems(self.__dict__))

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return pprint.pformat(self.__dict__)

    def __contains__(self, key):
        return key in self.__dict__

    def __getitem__(self, key):
        return self.__dict__[key]
194 194
def msg_header(msg_id, msg_type, username, session):
    """Build a fresh message header dict for the given identifiers."""
    return dict(
        msg_id=msg_id,
        msg_type=msg_type,
        username=username,
        session=session,
        date=datetime.now(),
        version=kernel_protocol_version,
    )
def extract_header(msg_or_header):
    """Given a message or header, return the header."""
    if not msg_or_header:
        return {}
    try:
        # a whole message keeps its header under the 'header' key
        header = msg_or_header['header']
    except KeyError:
        # otherwise it should itself be a header, identified by 'msg_id';
        # the lookup below re-raises KeyError if it is neither
        msg_or_header['msg_id']
        header = msg_or_header
    if not isinstance(header, dict):
        header = dict(header)
    return header
218 218
219 219 class Session(Configurable):
220 220 """Object for handling serialization and sending of messages.
221 221
222 222 The Session object handles building messages and sending them
223 223 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
224 224 other over the network via Session objects, and only need to work with the
225 225 dict-based IPython message spec. The Session will handle
226 226 serialization/deserialization, security, and metadata.
227 227
228 228 Sessions support configurable serialization via packer/unpacker traits,
229 229 and signing with HMAC digests via the key/keyfile traits.
230 230
231 231 Parameters
232 232 ----------
233 233
234 234 debug : bool
235 235 whether to trigger extra debugging statements
236 236 packer/unpacker : str : 'json', 'pickle' or import_string
237 237 importstrings for methods to serialize message parts. If just
238 238 'json' or 'pickle', predefined JSON and pickle packers will be used.
239 239 Otherwise, the entire importstring must be used.
240 240
241 241 The functions must accept at least valid JSON input, and output *bytes*.
242 242
243 243 For example, to use msgpack:
244 244 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
245 245 pack/unpack : callables
246 246 You can also set the pack/unpack callables for serialization directly.
247 247 session : bytes
248 248 the ID of this Session object. The default is to generate a new UUID.
249 249 username : unicode
250 250 username added to message headers. The default is to ask the OS.
251 251 key : bytes
252 252 The key used to initialize an HMAC signature. If unset, messages
253 253 will not be signed or checked.
254 254 keyfile : filepath
255 255 The file containing a key. If this is set, `key` will be initialized
256 256 to the contents of the file.
257 257
258 258 """
259 259
    # emit extra debugging output from this Session
    debug=Bool(False, config=True, help="""Debug output in the Session""")

    # serializer selector; setting it rebinds self.pack/self.unpack
    packer = DottedObjectName('json',config=True,
            help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")
    def _packer_changed(self, name, old, new):
        # 'json'/'pickle' select the predefined pack/unpack pairs and
        # mirror the choice into self.unpacker; any other value is an
        # import string for a custom packer callable.
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.unpacker = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.unpacker = new
        else:
            self.pack = import_item(str(new))
277 277
    # deserializer selector; mirror of `packer` (see _packer_changed)
    unpacker = DottedObjectName('json', config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")
    def _unpacker_changed(self, name, old, new):
        # 'json'/'pickle' select the predefined pairs and mirror the
        # choice into self.packer; otherwise import a custom unpacker.
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.packer = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.packer = new
        else:
            self.unpack = import_item(str(new))
292 292
    session = CUnicode(u'', config=True,
        help="""The UUID identifying this session.""")
    def _session_default(self):
        # generate a fresh UUID and cache its ascii-bytes twin eagerly,
        # so bsession is always consistent with session
        u = unicode_type(uuid.uuid4())
        self.bsession = u.encode('ascii')
        return u

    def _session_changed(self, name, old, new):
        # keep the bytes mirror in sync when the unicode id changes
        self.bsession = self.session.encode('ascii')

    # bsession is the session as bytes
    bsession = CBytes(b'')
305 305
    username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
        help="""Username for the Session. Default is your system username.""",
        config=True)

    metadata = Dict({}, config=True,
        help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")

    # if 0, no adapting to do.
    adapt_version = Integer(0)
315 315
    # message signature related traits:

    key = CBytes(config=True,
        help="""execution key, for signing messages.""")
    def _key_default(self):
        # random per-session key by default
        return str_to_bytes(str(uuid.uuid4()))

    def _key_changed(self):
        # rebuild the HMAC prototype whenever the key changes
        self._new_auth()
325 325
    signature_scheme = Unicode('hmac-sha256', config=True,
        help="""The digest scheme used to construct the message signatures.
        Must have the form 'hmac-HASH'.""")
    def _signature_scheme_changed(self, name, old, new):
        # validate the 'hmac-HASH' form and resolve HASH against hashlib
        if not new.startswith('hmac-'):
            raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
        hash_name = new.split('-', 1)[1]
        try:
            self.digest_mod = getattr(hashlib, hash_name)
        except AttributeError:
            raise TraitError("hashlib has no such attribute: %s" % hash_name)
        self._new_auth()

    # hash constructor used for signatures (e.g. hashlib.sha256)
    digest_mod = Any()
    def _digest_mod_default(self):
        return hashlib.sha256

    # prototype HMAC object; sign() copies it per message
    auth = Instance(hmac.HMAC)

    def _new_auth(self):
        # an empty key disables signing entirely (auth=None)
        if self.key:
            self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
        else:
            self.auth = None
350 350
    # signatures already seen, for replay protection
    digest_history = Set()
    digest_history_size = Integer(2**16, config=True,
        help="""The maximum number of digests to remember.

        The digest history will be culled when it exceeds this value.
        """
    )

    keyfile = Unicode('', config=True,
        help="""path to file containing execution key.""")
    def _keyfile_changed(self, name, old, new):
        # load the signing key from the file's (stripped) contents
        with open(new, 'rb') as f:
            self.key = f.read().strip()

    # for protecting against sends from forks
    pid = Integer()
367 367
    # serialization traits:

    pack = Any(default_packer) # the actual packer function
    def _pack_changed(self, name, old, new):
        if not callable(new):
            raise TypeError("packer must be callable, not %s"%type(new))

    unpack = Any(default_unpacker) # the actual unpacker function
    def _unpack_changed(self, name, old, new):
        # unpacker is not checked - it is assumed to be callable
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s"%type(new))

    # thresholds:
    copy_threshold = Integer(2**16, config=True,
        help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
    buffer_threshold = Integer(MAX_BYTES, config=True,
        help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
    item_threshold = Integer(MAX_ITEMS, config=True,
        help="""The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """
    )
391 391
392 392
    def __init__(self, **kwargs):
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts.  If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object.  The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If unset, messages
            will not be signed or checked.
        signature_scheme : str
            The message digest scheme. Currently must be of the form 'hmac-HASH',
            where 'HASH' is a hashing function available in Python's hashlib.
            The default is 'hmac-sha256'.
            This is ignored if 'key' is empty.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be
            initialized to the contents of the file.
        """
        super(Session, self).__init__(**kwargs)
        # verify pack/unpack roundtrip and datetime handling up front
        self._check_packers()
        # cached packed empty dict, used for None content
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session
        # remember the creating process, to detect sends from forks
        self.pid = os.getpid()
        self._new_auth()
440 440
    @property
    def msg_id(self):
        """always return new uuid

        Each access yields a fresh unique message id string.
        """
        return str(uuid.uuid4())
445 445
    def _check_packers(self):
        """check packers for datetime support.

        Probes the configured pack/unpack pair with a simple message,
        verifies they are inverses producing bytes, and wraps them with
        date-squashing if they cannot serialize datetimes natively.
        Raises ValueError on any serialization failure.
        """
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg = dict(a=[1,'hi'])
        try:
            packed = pack(msg)
        except Exception as e:
            msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
            if self.packer == 'json':
                # report which json implementation zmq picked, to aid debugging
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
            )

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required"%type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
            assert unpacked == msg
        except Exception as e:
            msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
            )

        # check datetime support
        msg = dict(t=datetime.now())
        try:
            unpacked = unpack(pack(msg))
            if isinstance(unpacked['t'], datetime):
                raise ValueError("Shouldn't deserialize to datetime")
        except Exception:
            # fall back to squashing datetimes into ISO8601 strings first
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: unpack(s)
492 492
    def msg_header(self, msg_type):
        """Create a new header dict for a message of type `msg_type`."""
        return msg_header(self.msg_id, msg_type, self.username, self.session)
495 495
496 496 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
497 497 """Return the nested message dict.
498 498
499 499 This format is different from what is sent over the wire. The
500 500 serialize/deserialize methods converts this nested message dict to the wire
501 501 format, which is a list of message parts.
502 502 """
503 503 msg = {}
504 504 header = self.msg_header(msg_type) if header is None else header
505 505 msg['header'] = header
506 506 msg['msg_id'] = header['msg_id']
507 507 msg['msg_type'] = header['msg_type']
508 508 msg['parent_header'] = {} if parent is None else extract_header(parent)
509 509 msg['content'] = {} if content is None else content
510 510 msg['metadata'] = self.metadata.copy()
511 511 if metadata is not None:
512 512 msg['metadata'].update(metadata)
513 513 return msg
514 514
515 515 def sign(self, msg_list):
516 516 """Sign a message with HMAC digest. If no auth, return b''.
517 517
518 518 Parameters
519 519 ----------
520 520 msg_list : list
521 521 The [p_header,p_parent,p_content] part of the message list.
522 522 """
523 523 if self.auth is None:
524 524 return b''
525 525 h = self.auth.copy()
526 526 for m in msg_list:
527 527 h.update(m)
528 528 return str_to_bytes(h.hexdigest())
529 529
    def serialize(self, msg, ident=None):
        """Serialize the message components to bytes.

        This is roughly the inverse of deserialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg : dict or Message
            The next message dict as returned by the self.msg method.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format::

                [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
                 p_metadata, p_content, buffer1, buffer2, ...]

            In this list, the ``p_*`` entities are the packed or serialized
            versions, so if JSON is used, these are utf8 encoded JSON strings.
        """
        content = msg.get('content', {})
        if content is None:
            # use the pre-packed empty dict
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, unicode_type):
            # should be bytes, but JSON often spits out unicode
            content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s"%type(content))

        # the signed portion of the wire message, in fixed order
        real_message = [self.pack(msg['header']),
                        self.pack(msg['parent_header']),
                        self.pack(msg['metadata']),
                        content,
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        # HMAC over the packed parts goes right after the delimiter
        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)

        return to_send
588 588
589 589 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
590 590 buffers=None, track=False, header=None, metadata=None):
591 591 """Build and send a message via stream or socket.
592 592
593 593 The message format used by this function internally is as follows:
594 594
595 595 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
596 596 buffer1,buffer2,...]
597 597
598 598 The serialize/deserialize methods convert the nested message dict into this
599 599 format.
600 600
601 601 Parameters
602 602 ----------
603 603
604 604 stream : zmq.Socket or ZMQStream
605 605 The socket-like object used to send the data.
606 606 msg_or_type : str or Message/dict
607 607 Normally, msg_or_type will be a msg_type unless a message is being
608 608 sent more than once. If a header is supplied, this can be set to
609 609 None and the msg_type will be pulled from the header.
610 610
611 611 content : dict or None
612 612 The content of the message (ignored if msg_or_type is a message).
613 613 header : dict or None
614 614 The header dict for the message (ignored if msg_to_type is a message).
615 615 parent : Message or dict or None
616 616 The parent or parent header describing the parent of this message
617 617 (ignored if msg_or_type is a message).
618 618 ident : bytes or list of bytes
619 619 The zmq.IDENTITY routing path.
620 620 metadata : dict or None
621 621 The metadata describing the message
622 622 buffers : list or None
623 623 The already-serialized buffers to be appended to the message.
624 624 track : bool
625 625 Whether to track. Only for use with Sockets, because ZMQStream
626 626 objects cannot track messages.
627 627
628 628
629 629 Returns
630 630 -------
631 631 msg : dict
632 632 The constructed message.
633 633 """
634 634 if not isinstance(stream, zmq.Socket):
635 635 # ZMQStreams and dummy sockets do not support tracking.
636 636 track = False
637 637
638 638 if isinstance(msg_or_type, (Message, dict)):
639 639 # We got a Message or message dict, not a msg_type so don't
640 640 # build a new Message.
641 641 msg = msg_or_type
642 642 buffers = buffers or msg.get('buffers', [])
643 643 else:
644 644 msg = self.msg(msg_or_type, content=content, parent=parent,
645 645 header=header, metadata=metadata)
646 646 if not os.getpid() == self.pid:
647 647 io.rprint("WARNING: attempted to send message from fork")
648 648 io.rprint(msg)
649 649 return
650 650 buffers = [] if buffers is None else buffers
651 651 if self.adapt_version:
652 652 msg = adapt(msg, self.adapt_version)
653 653 to_send = self.serialize(msg, ident)
654 654 to_send.extend(buffers)
655 655 longest = max([ len(s) for s in to_send ])
656 656 copy = (longest < self.copy_threshold)
657 657
658 658 if buffers and track and not copy:
659 659 # only really track when we are doing zero-copy buffers
660 660 tracker = stream.send_multipart(to_send, copy=False, track=True)
661 661 else:
662 662 # use dummy tracker, which will be done immediately
663 663 tracker = DONE
664 664 stream.send_multipart(to_send, copy=copy)
665 665
666 666 if self.debug:
667 667 pprint.pprint(msg)
668 668 pprint.pprint(to_send)
669 669 pprint.pprint(buffers)
670 670
671 671 msg['tracker'] = tracker
672 672
673 673 return msg
674 674
675 675 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
676 676 """Send a raw message via ident path.
677 677
678 678 This method is used to send a already serialized message.
679 679
680 680 Parameters
681 681 ----------
682 682 stream : ZMQStream or Socket
683 683 The ZMQ stream or socket to use for sending the message.
684 684 msg_list : list
685 685 The serialized list of messages to send. This only includes the
686 686 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
687 687 the message.
688 688 ident : ident or list
689 689 A single ident or a list of idents to use in sending.
690 690 """
691 691 to_send = []
692 692 if isinstance(ident, bytes):
693 693 ident = [ident]
694 694 if ident is not None:
695 695 to_send.extend(ident)
696 696
697 697 to_send.append(DELIM)
698 698 to_send.append(self.sign(msg_list))
699 699 to_send.extend(msg_list)
700 700 stream.send_multipart(to_send, flags, copy=copy)
701 701
702 702 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
703 703 """Receive and unpack a message.
704 704
705 705 Parameters
706 706 ----------
707 707 socket : ZMQStream or Socket
708 708 The socket or stream to use in receiving.
709 709
710 710 Returns
711 711 -------
712 712 [idents], msg
713 713 [idents] is a list of idents and msg is a nested message dict of
714 714 same format as self.msg returns.
715 715 """
716 716 if isinstance(socket, ZMQStream):
717 717 socket = socket.socket
718 718 try:
719 719 msg_list = socket.recv_multipart(mode, copy=copy)
720 720 except zmq.ZMQError as e:
721 721 if e.errno == zmq.EAGAIN:
722 722 # We can convert EAGAIN to None as we know in this case
723 723 # recv_multipart won't return None.
724 724 return None,None
725 725 else:
726 726 raise
727 727 # split multipart message into identity list and message dict
728 728 # invalid large messages can cause very expensive string comparisons
729 729 idents, msg_list = self.feed_identities(msg_list, copy)
730 730 try:
731 731 return idents, self.deserialize(msg_list, content=content, copy=copy)
732 732 except Exception as e:
733 733 # TODO: handle it
734 734 raise e
735 735
736 736 def feed_identities(self, msg_list, copy=True):
737 737 """Split the identities from the rest of the message.
738 738
739 739 Feed until DELIM is reached, then return the prefix as idents and
740 740 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
741 741 but that would be silly.
742 742
743 743 Parameters
744 744 ----------
745 745 msg_list : a list of Message or bytes objects
746 746 The message to be split.
747 747 copy : bool
748 748 flag determining whether the arguments are bytes or Messages
749 749
750 750 Returns
751 751 -------
752 752 (idents, msg_list) : two lists
753 753 idents will always be a list of bytes, each of which is a ZMQ
754 754 identity. msg_list will be a list of bytes or zmq.Messages of the
755 755 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
756 756 should be unpackable/unserializable via self.deserialize at this
757 757 point.
758 758 """
759 759 if copy:
760 760 idx = msg_list.index(DELIM)
761 761 return msg_list[:idx], msg_list[idx+1:]
762 762 else:
763 763 failed = True
764 764 for idx,m in enumerate(msg_list):
765 765 if m.bytes == DELIM:
766 766 failed = False
767 767 break
768 768 if failed:
769 769 raise ValueError("DELIM not in msg_list")
770 770 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
771 771 return [m.bytes for m in idents], msg_list
772 772
773 773 def _add_digest(self, signature):
774 774 """add a digest to history to protect against replay attacks"""
775 775 if self.digest_history_size == 0:
776 776 # no history, never add digests
777 777 return
778 778
779 779 self.digest_history.add(signature)
780 780 if len(self.digest_history) > self.digest_history_size:
781 781 # threshold reached, cull 10%
782 782 self._cull_digest_history()
783 783
784 784 def _cull_digest_history(self):
785 785 """cull the digest history
786 786
787 787 Removes a randomly selected 10% of the digest history
788 788 """
789 789 current = len(self.digest_history)
790 790 n_to_cull = max(int(current // 10), current - self.digest_history_size)
791 791 if n_to_cull >= current:
792 792 self.digest_history = set()
793 793 return
794 794 to_cull = random.sample(self.digest_history, n_to_cull)
795 795 self.digest_history.difference_update(to_cull)
796 796
797 797 def deserialize(self, msg_list, content=True, copy=True):
798 798 """Unserialize a msg_list to a nested message dict.
799 799
800 800 This is roughly the inverse of serialize. The serialize/deserialize
801 801 methods work with full message lists, whereas pack/unpack work with
802 802 the individual message parts in the message list.
803 803
804 804 Parameters
805 805 ----------
806 806 msg_list : list of bytes or Message objects
807 807 The list of message parts of the form [HMAC,p_header,p_parent,
808 808 p_metadata,p_content,buffer1,buffer2,...].
809 809 content : bool (True)
810 810 Whether to unpack the content dict (True), or leave it packed
811 811 (False).
812 812 copy : bool (True)
813 Whether to return the bytes (True), or the non-copying Message
814 object in each place (False).
813 Whether msg_list contains bytes (True) or the non-copying Message
814 objects in each place (False).
815 815
816 816 Returns
817 817 -------
818 818 msg : dict
819 819 The nested message dict with top-level keys [header, parent_header,
820 content, buffers].
820 content, buffers]. The buffers are returned as memoryviews.
821 821 """
822 822 minlen = 5
823 823 message = {}
824 824 if not copy:
825 # pyzmq didn't copy the first parts of the message, so we'll do it
825 826 for i in range(minlen):
826 827 msg_list[i] = msg_list[i].bytes
827 828 if self.auth is not None:
828 829 signature = msg_list[0]
829 830 if not signature:
830 831 raise ValueError("Unsigned Message")
831 832 if signature in self.digest_history:
832 833 raise ValueError("Duplicate Signature: %r" % signature)
833 834 self._add_digest(signature)
834 835 check = self.sign(msg_list[1:5])
835 836 if not compare_digest(signature, check):
836 837 raise ValueError("Invalid Signature: %r" % signature)
837 838 if not len(msg_list) >= minlen:
838 839 raise TypeError("malformed message, must have at least %i elements"%minlen)
839 840 header = self.unpack(msg_list[1])
840 841 message['header'] = extract_dates(header)
841 842 message['msg_id'] = header['msg_id']
842 843 message['msg_type'] = header['msg_type']
843 844 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
844 845 message['metadata'] = self.unpack(msg_list[3])
845 846 if content:
846 847 message['content'] = self.unpack(msg_list[4])
847 848 else:
848 849 message['content'] = msg_list[4]
849
850 message['buffers'] = msg_list[5:]
850 buffers = [memoryview(b) for b in msg_list[5:]]
851 if buffers and buffers[0].shape is None:
852 # force copy to workaround pyzmq #646
853 buffers = [memoryview(b.bytes) for b in msg_list[5:]]
854 message['buffers'] = buffers
851 855 # adapt to the current version
852 856 return adapt(message)
853 857
854 858 def unserialize(self, *args, **kwargs):
855 859 warnings.warn(
856 860 "Session.unserialize is deprecated. Use Session.deserialize.",
857 861 DeprecationWarning,
858 862 )
859 863 return self.deserialize(*args, **kwargs)
860 864
861 865
def test_msg2obj():
    # attribute access mirrors the dict it wraps
    src = dict(x=1)
    wrapped = Message(src)
    assert wrapped.x == src['x']

    # nested dicts are reachable via chained attributes
    src['y'] = dict(z=1)
    wrapped = Message(src)
    assert wrapped.y.z == src['y']['z']

    # item access works the same as attribute access
    outer, inner = 'y', 'z'
    assert wrapped[outer][inner] == src[outer][inner]

    # round-tripping back through dict() preserves the values
    roundtrip = dict(wrapped)
    assert src['x'] == roundtrip['x']
    assert src['y']['z'] == roundtrip['y']['z']
877 881
General Comments 0
You need to be logged in to leave comments. Login now