##// END OF EJS Templates
test websocket-friendly binary message roundtrip...
MinRK -
Show More
@@ -0,0 +1,26 b''
1 """Test serialize/deserialize messages with buffers"""
2
3 import os
4
5 import nose.tools as nt
6
7 from IPython.kernel.zmq.session import Session
8 from ..base.zmqhandlers import (
9 serialize_binary_message,
10 deserialize_binary_message,
11 )
12
13 def test_serialize_binary():
14 s = Session()
15 msg = s.msg('data_pub', content={'a': 'b'})
16 msg['buffers'] = [ os.urandom(3) for i in range(3) ]
17 bmsg = serialize_binary_message(msg)
18 nt.assert_is_instance(bmsg, bytes)
19
20 def test_deserialize_binary():
21 s = Session()
22 msg = s.msg('data_pub', content={'a': 'b'})
23 msg['buffers'] = [ os.urandom(2) for i in range(3) ]
24 bmsg = serialize_binary_message(msg)
25 msg2 = deserialize_binary_message(bmsg)
26 nt.assert_equal(msg2, msg)
@@ -1,261 +1,257 b''
1 1 """Tornado handlers for WebSocket <-> ZMQ sockets."""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 import json
7 7 import struct
8 8
9 9 try:
10 10 from urllib.parse import urlparse # Py 3
11 11 except ImportError:
12 12 from urlparse import urlparse # Py 2
13 13
14 14 try:
15 15 from http.cookies import SimpleCookie # Py 3
16 16 except ImportError:
17 17 from Cookie import SimpleCookie # Py 2
18 18 import logging
19 19
20 20 import tornado
21 21 from tornado import ioloop
22 22 from tornado import web
23 23 from tornado import websocket
24 24
25 25 from IPython.kernel.zmq.session import Session
26 from IPython.utils.jsonutil import date_default
26 from IPython.utils.jsonutil import date_default, extract_dates
27 27 from IPython.utils.py3compat import PY3, cast_unicode
28 28
29 29 from .handlers import IPythonHandler
30 30
31 31
32 32 def serialize_binary_message(msg):
33 33 """serialize a message as a binary blob
34 34
35 35 Header:
36 36
37 37 4 bytes: number of msg parts (nbufs) as 32b int
38 38 4 * nbufs bytes: offset for each buffer as integer as 32b int
39 39
40 40 Offsets are from the start of the buffer, including the header.
41 41
42 42 Returns
43 43 -------
44 44
45 45 The message serialized to bytes.
46 46
47 47 """
48 buffers = msg.pop('buffers')
48 # don't modify msg or buffer list in-place
49 msg = msg.copy()
50 buffers = list(msg.pop('buffers'))
49 51 bmsg = json.dumps(msg, default=date_default).encode('utf8')
50 52 buffers.insert(0, bmsg)
51 53 nbufs = len(buffers)
52 54 offsets = [4 * (nbufs + 1)]
53 55 for buf in buffers[:-1]:
54 56 offsets.append(offsets[-1] + len(buf))
55 57 offsets_buf = struct.pack('!' + 'i' * (nbufs + 1), nbufs, *offsets)
56 58 buffers.insert(0, offsets_buf)
57 59 return b''.join(buffers)
58 60
59 61
60 62 def deserialize_binary_message(bmsg):
61 63 """deserialize a message from a binary blog
62 64
63 65 Header:
64 66
65 67 4 bytes: number of msg parts (nbufs) as 32b int
66 68 4 * nbufs bytes: offset for each buffer as integer as 32b int
67 69
68 70 Offsets are from the start of the buffer, including the header.
69 71
70 72 Returns
71 73 -------
72 74
73 75 message dictionary
74 76 """
75 nbufs = struct.unpack('i', bmsg[:4])[0]
77 nbufs = struct.unpack('!i', bmsg[:4])[0]
76 78 offsets = list(struct.unpack('!' + 'i' * nbufs, bmsg[4:4*(nbufs+1)]))
77 79 offsets.append(None)
78 80 bufs = []
79 81 for start, stop in zip(offsets[:-1], offsets[1:]):
80 82 bufs.append(bmsg[start:stop])
81 msg = json.loads(bufs[0])
83 msg = json.loads(bufs[0].decode('utf8'))
84 msg['header'] = extract_dates(msg['header'])
85 msg['parent_header'] = extract_dates(msg['parent_header'])
82 86 msg['buffers'] = bufs[1:]
83 87 return msg
84 88
85 89
86 90 class ZMQStreamHandler(websocket.WebSocketHandler):
87 91
88 92 def check_origin(self, origin):
89 93 """Check Origin == Host or Access-Control-Allow-Origin.
90 94
91 95 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
92 96 We call it explicitly in `open` on Tornado < 4.
93 97 """
94 98 if self.allow_origin == '*':
95 99 return True
96 100
97 101 host = self.request.headers.get("Host")
98 102
99 103 # If no header is provided, assume we can't verify origin
100 104 if origin is None:
101 105 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
102 106 return False
103 107 if host is None:
104 108 self.log.warn("Missing Host header, rejecting WebSocket connection.")
105 109 return False
106 110
107 111 origin = origin.lower()
108 112 origin_host = urlparse(origin).netloc
109 113
110 114 # OK if origin matches host
111 115 if origin_host == host:
112 116 return True
113 117
114 118 # Check CORS headers
115 119 if self.allow_origin:
116 120 allow = self.allow_origin == origin
117 121 elif self.allow_origin_pat:
118 122 allow = bool(self.allow_origin_pat.match(origin))
119 123 else:
120 124 # No CORS headers deny the request
121 125 allow = False
122 126 if not allow:
123 127 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
124 128 origin, host,
125 129 )
126 130 return allow
127 131
128 132 def clear_cookie(self, *args, **kwargs):
129 133 """meaningless for websockets"""
130 134 pass
131 135
132 136 def _reserialize_reply(self, msg_list):
133 137 """Reserialize a reply message using JSON.
134 138
135 139 This takes the msg list from the ZMQ socket, deserializes it using
136 140 self.session and then serializes the result using JSON. This method
137 141 should be used by self._on_zmq_reply to build messages that can
138 142 be sent back to the browser.
139 143 """
140 144 idents, msg_list = self.session.feed_identities(msg_list)
141 145 msg = self.session.deserialize(msg_list)
142 try:
143 msg['header'].pop('date')
144 except KeyError:
145 pass
146 try:
147 msg['parent_header'].pop('date')
148 except KeyError:
149 pass
150 146 if msg['buffers']:
151 147 buf = serialize_binary_message(msg)
152 148 return buf
153 149 else:
154 150 smsg = json.dumps(msg, default=date_default)
155 151 return cast_unicode(smsg)
156 152
157 153 def _on_zmq_reply(self, msg_list):
158 154 # Sometimes this gets triggered when the on_close method is scheduled in the
159 155 # eventloop but hasn't been called.
160 156 if self.stream.closed(): return
161 157 try:
162 158 msg = self._reserialize_reply(msg_list)
163 159 except Exception:
164 160 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
165 161 else:
166 162 self.write_message(msg, binary=isinstance(msg, bytes))
167 163
168 164 def allow_draft76(self):
169 165 """Allow draft 76, until browsers such as Safari update to RFC 6455.
170 166
171 167 This has been disabled by default in tornado in release 2.2.0, and
172 168 support will be removed in later versions.
173 169 """
174 170 return True
175 171
176 172 # ping interval for keeping websockets alive (30 seconds)
177 173 WS_PING_INTERVAL = 30000
178 174
179 175 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
180 176 ping_callback = None
181 177 last_ping = 0
182 178 last_pong = 0
183 179
184 180 @property
185 181 def ping_interval(self):
186 182 """The interval for websocket keep-alive pings.
187 183
188 184 Set ws_ping_interval = 0 to disable pings.
189 185 """
190 186 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
191 187
192 188 @property
193 189 def ping_timeout(self):
194 190 """If no ping is received in this many milliseconds,
195 191 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
196 192 Default is max of 3 pings or 30 seconds.
197 193 """
198 194 return self.settings.get('ws_ping_timeout',
199 195 max(3 * self.ping_interval, WS_PING_INTERVAL)
200 196 )
201 197
202 198 def set_default_headers(self):
203 199 """Undo the set_default_headers in IPythonHandler
204 200
205 201 which doesn't make sense for websockets
206 202 """
207 203 pass
208 204
209 205 def get(self, *args, **kwargs):
210 206 # Check to see that origin matches host directly, including ports
211 207 # Tornado 4 already does CORS checking
212 208 if tornado.version_info[0] < 4:
213 209 if not self.check_origin(self.get_origin()):
214 210 raise web.HTTPError(403)
215 211
216 212 # authenticate the request before opening the websocket
217 213 if self.get_current_user() is None:
218 214 self.log.warn("Couldn't authenticate WebSocket connection")
219 215 raise web.HTTPError(403)
220 216
221 217 if self.get_argument('session_id', False):
222 218 self.session.session = cast_unicode(self.get_argument('session_id'))
223 219 else:
224 220 self.log.warn("No session ID specified")
225 221
226 222 return super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
227 223
228 224 def initialize(self):
229 225 self.session = Session(config=self.config)
230 226
231 227 def open(self, kernel_id):
232 228 self.kernel_id = cast_unicode(kernel_id, 'ascii')
233 229
234 230 # start the pinging
235 231 if self.ping_interval > 0:
236 232 self.last_ping = ioloop.IOLoop.instance().time() # Remember time of last ping
237 233 self.last_pong = self.last_ping
238 234 self.ping_callback = ioloop.PeriodicCallback(self.send_ping, self.ping_interval)
239 235 self.ping_callback.start()
240 236
241 237 def send_ping(self):
242 238 """send a ping to keep the websocket alive"""
243 239 if self.stream.closed() and self.ping_callback is not None:
244 240 self.ping_callback.stop()
245 241 return
246 242
247 243 # check for timeout on pong. Make sure that we really have sent a recent ping in
248 244 # case the machine with both server and client has been suspended since the last ping.
249 245 now = ioloop.IOLoop.instance().time()
250 246 since_last_pong = 1e3 * (now - self.last_pong)
251 247 since_last_ping = 1e3 * (now - self.last_ping)
252 248 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
253 249 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
254 250 self.close()
255 251 return
256 252
257 253 self.ping(b'')
258 254 self.last_ping = now
259 255
260 256 def on_pong(self, data):
261 257 self.last_pong = ioloop.IOLoop.instance().time()
General Comments 0
You need to be logged in to leave comments. Login now