##// END OF EJS Templates
test websocket-friendly binary message roundtrip...
MinRK -
Show More
@@ -0,0 +1,26 b''
1 """Test serialize/deserialize messages with buffers"""
2
3 import os
4
5 import nose.tools as nt
6
7 from IPython.kernel.zmq.session import Session
8 from ..base.zmqhandlers import (
9 serialize_binary_message,
10 deserialize_binary_message,
11 )
12
13 def test_serialize_binary():
14 s = Session()
15 msg = s.msg('data_pub', content={'a': 'b'})
16 msg['buffers'] = [ os.urandom(3) for i in range(3) ]
17 bmsg = serialize_binary_message(msg)
18 nt.assert_is_instance(bmsg, bytes)
19
20 def test_deserialize_binary():
21 s = Session()
22 msg = s.msg('data_pub', content={'a': 'b'})
23 msg['buffers'] = [ os.urandom(2) for i in range(3) ]
24 bmsg = serialize_binary_message(msg)
25 msg2 = deserialize_binary_message(bmsg)
26 nt.assert_equal(msg2, msg)
@@ -1,261 +1,257 b''
1 """Tornado handlers for WebSocket <-> ZMQ sockets."""
1 """Tornado handlers for WebSocket <-> ZMQ sockets."""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 import json
6 import json
7 import struct
7 import struct
8
8
9 try:
9 try:
10 from urllib.parse import urlparse # Py 3
10 from urllib.parse import urlparse # Py 3
11 except ImportError:
11 except ImportError:
12 from urlparse import urlparse # Py 2
12 from urlparse import urlparse # Py 2
13
13
14 try:
14 try:
15 from http.cookies import SimpleCookie # Py 3
15 from http.cookies import SimpleCookie # Py 3
16 except ImportError:
16 except ImportError:
17 from Cookie import SimpleCookie # Py 2
17 from Cookie import SimpleCookie # Py 2
18 import logging
18 import logging
19
19
20 import tornado
20 import tornado
21 from tornado import ioloop
21 from tornado import ioloop
22 from tornado import web
22 from tornado import web
23 from tornado import websocket
23 from tornado import websocket
24
24
25 from IPython.kernel.zmq.session import Session
25 from IPython.kernel.zmq.session import Session
26 from IPython.utils.jsonutil import date_default
26 from IPython.utils.jsonutil import date_default, extract_dates
27 from IPython.utils.py3compat import PY3, cast_unicode
27 from IPython.utils.py3compat import PY3, cast_unicode
28
28
29 from .handlers import IPythonHandler
29 from .handlers import IPythonHandler
30
30
31
31
32 def serialize_binary_message(msg):
32 def serialize_binary_message(msg):
33 """serialize a message as a binary blob
33 """serialize a message as a binary blob
34
34
35 Header:
35 Header:
36
36
37 4 bytes: number of msg parts (nbufs) as 32b int
37 4 bytes: number of msg parts (nbufs) as 32b int
38 4 * nbufs bytes: offset for each buffer as integer as 32b int
38 4 * nbufs bytes: offset for each buffer as integer as 32b int
39
39
40 Offsets are from the start of the buffer, including the header.
40 Offsets are from the start of the buffer, including the header.
41
41
42 Returns
42 Returns
43 -------
43 -------
44
44
45 The message serialized to bytes.
45 The message serialized to bytes.
46
46
47 """
47 """
48 buffers = msg.pop('buffers')
48 # don't modify msg or buffer list in-place
49 msg = msg.copy()
50 buffers = list(msg.pop('buffers'))
49 bmsg = json.dumps(msg, default=date_default).encode('utf8')
51 bmsg = json.dumps(msg, default=date_default).encode('utf8')
50 buffers.insert(0, bmsg)
52 buffers.insert(0, bmsg)
51 nbufs = len(buffers)
53 nbufs = len(buffers)
52 offsets = [4 * (nbufs + 1)]
54 offsets = [4 * (nbufs + 1)]
53 for buf in buffers[:-1]:
55 for buf in buffers[:-1]:
54 offsets.append(offsets[-1] + len(buf))
56 offsets.append(offsets[-1] + len(buf))
55 offsets_buf = struct.pack('!' + 'i' * (nbufs + 1), nbufs, *offsets)
57 offsets_buf = struct.pack('!' + 'i' * (nbufs + 1), nbufs, *offsets)
56 buffers.insert(0, offsets_buf)
58 buffers.insert(0, offsets_buf)
57 return b''.join(buffers)
59 return b''.join(buffers)
58
60
59
61
60 def deserialize_binary_message(bmsg):
62 def deserialize_binary_message(bmsg):
61 """deserialize a message from a binary blog
63 """deserialize a message from a binary blog
62
64
63 Header:
65 Header:
64
66
65 4 bytes: number of msg parts (nbufs) as 32b int
67 4 bytes: number of msg parts (nbufs) as 32b int
66 4 * nbufs bytes: offset for each buffer as integer as 32b int
68 4 * nbufs bytes: offset for each buffer as integer as 32b int
67
69
68 Offsets are from the start of the buffer, including the header.
70 Offsets are from the start of the buffer, including the header.
69
71
70 Returns
72 Returns
71 -------
73 -------
72
74
73 message dictionary
75 message dictionary
74 """
76 """
75 nbufs = struct.unpack('i', bmsg[:4])[0]
77 nbufs = struct.unpack('!i', bmsg[:4])[0]
76 offsets = list(struct.unpack('!' + 'i' * nbufs, bmsg[4:4*(nbufs+1)]))
78 offsets = list(struct.unpack('!' + 'i' * nbufs, bmsg[4:4*(nbufs+1)]))
77 offsets.append(None)
79 offsets.append(None)
78 bufs = []
80 bufs = []
79 for start, stop in zip(offsets[:-1], offsets[1:]):
81 for start, stop in zip(offsets[:-1], offsets[1:]):
80 bufs.append(bmsg[start:stop])
82 bufs.append(bmsg[start:stop])
81 msg = json.loads(bufs[0])
83 msg = json.loads(bufs[0].decode('utf8'))
84 msg['header'] = extract_dates(msg['header'])
85 msg['parent_header'] = extract_dates(msg['parent_header'])
82 msg['buffers'] = bufs[1:]
86 msg['buffers'] = bufs[1:]
83 return msg
87 return msg
84
88
85
89
86 class ZMQStreamHandler(websocket.WebSocketHandler):
90 class ZMQStreamHandler(websocket.WebSocketHandler):
87
91
88 def check_origin(self, origin):
92 def check_origin(self, origin):
89 """Check Origin == Host or Access-Control-Allow-Origin.
93 """Check Origin == Host or Access-Control-Allow-Origin.
90
94
91 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
95 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
92 We call it explicitly in `open` on Tornado < 4.
96 We call it explicitly in `open` on Tornado < 4.
93 """
97 """
94 if self.allow_origin == '*':
98 if self.allow_origin == '*':
95 return True
99 return True
96
100
97 host = self.request.headers.get("Host")
101 host = self.request.headers.get("Host")
98
102
99 # If no header is provided, assume we can't verify origin
103 # If no header is provided, assume we can't verify origin
100 if origin is None:
104 if origin is None:
101 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
105 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
102 return False
106 return False
103 if host is None:
107 if host is None:
104 self.log.warn("Missing Host header, rejecting WebSocket connection.")
108 self.log.warn("Missing Host header, rejecting WebSocket connection.")
105 return False
109 return False
106
110
107 origin = origin.lower()
111 origin = origin.lower()
108 origin_host = urlparse(origin).netloc
112 origin_host = urlparse(origin).netloc
109
113
110 # OK if origin matches host
114 # OK if origin matches host
111 if origin_host == host:
115 if origin_host == host:
112 return True
116 return True
113
117
114 # Check CORS headers
118 # Check CORS headers
115 if self.allow_origin:
119 if self.allow_origin:
116 allow = self.allow_origin == origin
120 allow = self.allow_origin == origin
117 elif self.allow_origin_pat:
121 elif self.allow_origin_pat:
118 allow = bool(self.allow_origin_pat.match(origin))
122 allow = bool(self.allow_origin_pat.match(origin))
119 else:
123 else:
120 # No CORS headers deny the request
124 # No CORS headers deny the request
121 allow = False
125 allow = False
122 if not allow:
126 if not allow:
123 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
127 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
124 origin, host,
128 origin, host,
125 )
129 )
126 return allow
130 return allow
127
131
128 def clear_cookie(self, *args, **kwargs):
132 def clear_cookie(self, *args, **kwargs):
129 """meaningless for websockets"""
133 """meaningless for websockets"""
130 pass
134 pass
131
135
132 def _reserialize_reply(self, msg_list):
136 def _reserialize_reply(self, msg_list):
133 """Reserialize a reply message using JSON.
137 """Reserialize a reply message using JSON.
134
138
135 This takes the msg list from the ZMQ socket, deserializes it using
139 This takes the msg list from the ZMQ socket, deserializes it using
136 self.session and then serializes the result using JSON. This method
140 self.session and then serializes the result using JSON. This method
137 should be used by self._on_zmq_reply to build messages that can
141 should be used by self._on_zmq_reply to build messages that can
138 be sent back to the browser.
142 be sent back to the browser.
139 """
143 """
140 idents, msg_list = self.session.feed_identities(msg_list)
144 idents, msg_list = self.session.feed_identities(msg_list)
141 msg = self.session.deserialize(msg_list)
145 msg = self.session.deserialize(msg_list)
142 try:
143 msg['header'].pop('date')
144 except KeyError:
145 pass
146 try:
147 msg['parent_header'].pop('date')
148 except KeyError:
149 pass
150 if msg['buffers']:
146 if msg['buffers']:
151 buf = serialize_binary_message(msg)
147 buf = serialize_binary_message(msg)
152 return buf
148 return buf
153 else:
149 else:
154 smsg = json.dumps(msg, default=date_default)
150 smsg = json.dumps(msg, default=date_default)
155 return cast_unicode(smsg)
151 return cast_unicode(smsg)
156
152
157 def _on_zmq_reply(self, msg_list):
153 def _on_zmq_reply(self, msg_list):
158 # Sometimes this gets triggered when the on_close method is scheduled in the
154 # Sometimes this gets triggered when the on_close method is scheduled in the
159 # eventloop but hasn't been called.
155 # eventloop but hasn't been called.
160 if self.stream.closed(): return
156 if self.stream.closed(): return
161 try:
157 try:
162 msg = self._reserialize_reply(msg_list)
158 msg = self._reserialize_reply(msg_list)
163 except Exception:
159 except Exception:
164 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
160 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
165 else:
161 else:
166 self.write_message(msg, binary=isinstance(msg, bytes))
162 self.write_message(msg, binary=isinstance(msg, bytes))
167
163
168 def allow_draft76(self):
164 def allow_draft76(self):
169 """Allow draft 76, until browsers such as Safari update to RFC 6455.
165 """Allow draft 76, until browsers such as Safari update to RFC 6455.
170
166
171 This has been disabled by default in tornado in release 2.2.0, and
167 This has been disabled by default in tornado in release 2.2.0, and
172 support will be removed in later versions.
168 support will be removed in later versions.
173 """
169 """
174 return True
170 return True
175
171
176 # ping interval for keeping websockets alive (30 seconds)
172 # ping interval for keeping websockets alive (30 seconds)
177 WS_PING_INTERVAL = 30000
173 WS_PING_INTERVAL = 30000
178
174
179 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
175 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
180 ping_callback = None
176 ping_callback = None
181 last_ping = 0
177 last_ping = 0
182 last_pong = 0
178 last_pong = 0
183
179
184 @property
180 @property
185 def ping_interval(self):
181 def ping_interval(self):
186 """The interval for websocket keep-alive pings.
182 """The interval for websocket keep-alive pings.
187
183
188 Set ws_ping_interval = 0 to disable pings.
184 Set ws_ping_interval = 0 to disable pings.
189 """
185 """
190 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
186 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
191
187
192 @property
188 @property
193 def ping_timeout(self):
189 def ping_timeout(self):
194 """If no ping is received in this many milliseconds,
190 """If no ping is received in this many milliseconds,
195 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
191 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
196 Default is max of 3 pings or 30 seconds.
192 Default is max of 3 pings or 30 seconds.
197 """
193 """
198 return self.settings.get('ws_ping_timeout',
194 return self.settings.get('ws_ping_timeout',
199 max(3 * self.ping_interval, WS_PING_INTERVAL)
195 max(3 * self.ping_interval, WS_PING_INTERVAL)
200 )
196 )
201
197
202 def set_default_headers(self):
198 def set_default_headers(self):
203 """Undo the set_default_headers in IPythonHandler
199 """Undo the set_default_headers in IPythonHandler
204
200
205 which doesn't make sense for websockets
201 which doesn't make sense for websockets
206 """
202 """
207 pass
203 pass
208
204
209 def get(self, *args, **kwargs):
205 def get(self, *args, **kwargs):
210 # Check to see that origin matches host directly, including ports
206 # Check to see that origin matches host directly, including ports
211 # Tornado 4 already does CORS checking
207 # Tornado 4 already does CORS checking
212 if tornado.version_info[0] < 4:
208 if tornado.version_info[0] < 4:
213 if not self.check_origin(self.get_origin()):
209 if not self.check_origin(self.get_origin()):
214 raise web.HTTPError(403)
210 raise web.HTTPError(403)
215
211
216 # authenticate the request before opening the websocket
212 # authenticate the request before opening the websocket
217 if self.get_current_user() is None:
213 if self.get_current_user() is None:
218 self.log.warn("Couldn't authenticate WebSocket connection")
214 self.log.warn("Couldn't authenticate WebSocket connection")
219 raise web.HTTPError(403)
215 raise web.HTTPError(403)
220
216
221 if self.get_argument('session_id', False):
217 if self.get_argument('session_id', False):
222 self.session.session = cast_unicode(self.get_argument('session_id'))
218 self.session.session = cast_unicode(self.get_argument('session_id'))
223 else:
219 else:
224 self.log.warn("No session ID specified")
220 self.log.warn("No session ID specified")
225
221
226 return super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
222 return super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
227
223
228 def initialize(self):
224 def initialize(self):
229 self.session = Session(config=self.config)
225 self.session = Session(config=self.config)
230
226
231 def open(self, kernel_id):
227 def open(self, kernel_id):
232 self.kernel_id = cast_unicode(kernel_id, 'ascii')
228 self.kernel_id = cast_unicode(kernel_id, 'ascii')
233
229
234 # start the pinging
230 # start the pinging
235 if self.ping_interval > 0:
231 if self.ping_interval > 0:
236 self.last_ping = ioloop.IOLoop.instance().time() # Remember time of last ping
232 self.last_ping = ioloop.IOLoop.instance().time() # Remember time of last ping
237 self.last_pong = self.last_ping
233 self.last_pong = self.last_ping
238 self.ping_callback = ioloop.PeriodicCallback(self.send_ping, self.ping_interval)
234 self.ping_callback = ioloop.PeriodicCallback(self.send_ping, self.ping_interval)
239 self.ping_callback.start()
235 self.ping_callback.start()
240
236
241 def send_ping(self):
237 def send_ping(self):
242 """send a ping to keep the websocket alive"""
238 """send a ping to keep the websocket alive"""
243 if self.stream.closed() and self.ping_callback is not None:
239 if self.stream.closed() and self.ping_callback is not None:
244 self.ping_callback.stop()
240 self.ping_callback.stop()
245 return
241 return
246
242
247 # check for timeout on pong. Make sure that we really have sent a recent ping in
243 # check for timeout on pong. Make sure that we really have sent a recent ping in
248 # case the machine with both server and client has been suspended since the last ping.
244 # case the machine with both server and client has been suspended since the last ping.
249 now = ioloop.IOLoop.instance().time()
245 now = ioloop.IOLoop.instance().time()
250 since_last_pong = 1e3 * (now - self.last_pong)
246 since_last_pong = 1e3 * (now - self.last_pong)
251 since_last_ping = 1e3 * (now - self.last_ping)
247 since_last_ping = 1e3 * (now - self.last_ping)
252 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
248 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
253 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
249 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
254 self.close()
250 self.close()
255 return
251 return
256
252
257 self.ping(b'')
253 self.ping(b'')
258 self.last_ping = now
254 self.last_ping = now
259
255
260 def on_pong(self, data):
256 def on_pong(self, data):
261 self.last_pong = ioloop.IOLoop.instance().time()
257 self.last_pong = ioloop.IOLoop.instance().time()
General Comments 0
You need to be logged in to leave comments. Login now