##// END OF EJS Templates
Fix websocket/zmq serialization to expect memoryviews
Jason Grout -
Show More
@@ -1,278 +1,280
1 1 # coding: utf-8
2 2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
3 3
4 4 # Copyright (c) IPython Development Team.
5 5 # Distributed under the terms of the Modified BSD License.
6 6
7 7 import os
8 8 import json
9 9 import struct
10 10 import warnings
11 11
12 12 try:
13 13 from urllib.parse import urlparse # Py 3
14 14 except ImportError:
15 15 from urlparse import urlparse # Py 2
16 16
17 17 import tornado
18 18 from tornado import gen, ioloop, web
19 19 from tornado.websocket import WebSocketHandler
20 20
21 21 from IPython.kernel.zmq.session import Session
22 22 from IPython.utils.jsonutil import date_default, extract_dates
23 23 from IPython.utils.py3compat import cast_unicode
24 24
25 25 from .handlers import IPythonHandler
26 26
def serialize_binary_message(msg):
    """Serialize a message dictionary as a single binary blob.

    Header:

    4 bytes: number of msg parts (nbufs) as 32b int
    4 * nbufs bytes: offset for each buffer as integer as 32b int

    Offsets are from the start of the buffer, including the header.
    All header integers are unsigned and packed in network byte order.

    The header is followed by the parts themselves: first the message
    (minus its 'buffers' entry) encoded as UTF-8 JSON, then every binary
    buffer, back to back.

    Parameters
    ----------

    msg : dict
        Message dictionary.  It must have a 'buffers' entry holding an
        iterable of bytes-like objects (memoryview, bytes, bytearray, ...).
        The dictionary itself is not modified.

    Returns
    -------

    The message serialized to bytes.

    Raises
    ------

    KeyError
        If msg has no 'buffers' entry.
    """
    # don't modify msg or buffer list in-place
    msg = msg.copy()
    # Copy every buffer to a byte string so that b''.join() works on both
    # Python 2 and 3.  Wrapping in memoryview() first accepts plain
    # bytes/bytearray as well as memoryviews; a bare ``x.tobytes()`` would
    # raise AttributeError for anything that is not already a memoryview.
    buffers = [memoryview(x).tobytes() for x in msg.pop('buffers')]
    bmsg = json.dumps(msg, default=date_default).encode('utf8')
    # The JSON part always comes first, followed by the binary buffers.
    parts = [bmsg] + buffers
    nbufs = len(parts)
    # The first part starts right after the header: one count + nbufs offsets.
    offsets = [4 * (nbufs + 1)]
    for part in parts[:-1]:
        offsets.append(offsets[-1] + len(part))
    offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
    return b''.join([offsets_buf] + parts)
55 57
56 58
def deserialize_binary_message(bmsg):
    """Deserialize a message from a binary blob.

    Header:

    4 bytes: number of msg parts (nbufs) as 32b int
    4 * nbufs bytes: offset for each buffer as integer as 32b int

    Offsets are from the start of the buffer, including the header.
    All header integers are unsigned and packed in network byte order.

    The first part is the UTF-8 JSON encoding of the message; any further
    parts are returned untouched as the message's 'buffers' list.

    Parameters
    ----------

    bmsg : bytes
        The serialized message.

    Returns
    -------

    message dictionary

    Raises
    ------

    ValueError
        If the blob is too short to hold the header it declares, or
        declares no parts at all.
    """
    total = len(bmsg)
    if total < 4:
        raise ValueError(
            "binary message too short for a header: %i bytes" % total
        )
    # Unsigned, to match the '!I' the serializer writes.
    nbufs = struct.unpack('!I', bmsg[:4])[0]
    header_len = 4 * (nbufs + 1)
    # Validate the declared count *before* building a format string from it:
    # the blob arrives over the websocket, and an arbitrary count would
    # otherwise make 'I' * nbufs allocate an arbitrarily large string.
    if nbufs < 1 or header_len > total:
        raise ValueError(
            "invalid binary message header: %i parts declared in %i bytes"
            % (nbufs, total)
        )
    offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:header_len]))
    # A trailing None makes the last slice run to the end of the blob.
    offsets.append(None)
    bufs = [bmsg[start:stop] for start, stop in zip(offsets[:-1], offsets[1:])]
    msg = json.loads(bufs[0].decode('utf8'))
    msg['header'] = extract_dates(msg['header'])
    msg['parent_header'] = extract_dates(msg['parent_header'])
    msg['buffers'] = bufs[1:]
    return msg
83 85
# Default interval between websocket keep-alive pings, in milliseconds
# (30000 ms = 30 seconds).
WS_PING_INTERVAL = 30000

# Opt-in switch, read once at import time: when the environment variable is
# set to a non-empty value, the module-level WebSocketHandler name is rebound
# to the draft76-capable handler, so the classes defined below subclass that
# one instead of tornado's.
if os.environ.get('IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS', False):
    warnings.warn("""Allowing draft76 websocket connections!
    This should only be done for testing with phantomjs!""")
    from IPython.html import allow76
    WebSocketHandler = allow76.AllowDraftWebSocketHandler
    # draft 76 doesn't support ping, so disable keep-alive pings entirely
    WS_PING_INTERVAL = 0
94 96
class ZMQStreamHandler(WebSocketHandler):
    """WebSocket handler that relays messages from ZMQ streams to the browser.

    The methods below read ``self.log``, ``self.session``,
    ``self.allow_origin`` and ``self.allow_origin_pat``.  None of these is
    defined in this class, so a subclass or mixin has to provide them.
    """

    if tornado.version_info < (4, 1):
        def send_error(self, *args, **kwargs):
            """Backport send_error from tornado 4.1 to 4.0"""
            if self.stream is None:
                # Deliberately skip WebSocketHandler in the MRO and use the
                # implementation of the class that follows it.
                super(WebSocketHandler, self).send_error(*args, **kwargs)
            else:
                # If we get an uncaught exception during the handshake,
                # we have no choice but to abruptly close the connection.
                # TODO: for uncaught exceptions after the handshake,
                # we can close the connection more gracefully.
                self.stream.close()

    def check_origin(self, origin):
        """Check Origin == Host or Access-Control-Allow-Origin.

        Tornado >= 4 calls this method automatically, raising 403 if it returns False.

        Parameters
        ----------
        origin : str or None
            Value of the request's Origin header.

        Returns
        -------
        bool
            True if the connection may proceed, False if it must be rejected.
        """
        if self.allow_origin == '*':
            return True

        host = self.request.headers.get("Host")

        # If no header is provided, assume we can't verify origin
        if origin is None:
            self.log.warning("Missing Origin header, rejecting WebSocket connection.")
            return False
        if host is None:
            self.log.warning("Missing Host header, rejecting WebSocket connection.")
            return False

        origin = origin.lower()
        origin_host = urlparse(origin).netloc

        # OK if origin matches host
        if origin_host == host:
            return True

        # Check CORS headers
        if self.allow_origin:
            allow = self.allow_origin == origin
        elif self.allow_origin_pat:
            allow = bool(self.allow_origin_pat.match(origin))
        else:
            # No CORS headers deny the request
            allow = False
        if not allow:
            self.log.warning("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
                origin, host,
            )
        return allow

    def clear_cookie(self, *args, **kwargs):
        """meaningless for websockets"""
        pass

    def _reserialize_reply(self, msg_list, channel=None):
        """Reserialize a reply message using JSON.

        This takes the msg list from the ZMQ socket, deserializes it using
        self.session and then serializes the result using JSON. This method
        should be used by self._on_zmq_reply to build messages that can
        be sent back to the browser.

        Parameters
        ----------
        msg_list : list
            Message parts as received from the ZMQ socket.
        channel : optional
            If truthy, stored in the message under the 'channel' key.

        Returns
        -------
        The binary serialization (see serialize_binary_message) when the
        message carries buffers, otherwise the message as a unicode JSON
        string.
        """
        idents, msg_list = self.session.feed_identities(msg_list)
        msg = self.session.deserialize(msg_list)
        if channel:
            msg['channel'] = channel
        if msg['buffers']:
            buf = serialize_binary_message(msg)
            return buf
        else:
            smsg = json.dumps(msg, default=date_default)
            return cast_unicode(smsg)

    def _on_zmq_reply(self, stream, msg_list):
        """Forward one message received on a ZMQ stream to the websocket.

        A message that cannot be reserialized is logged and dropped rather
        than raised.
        """
        # Sometimes this gets triggered when the on_close method is scheduled in the
        # eventloop but hasn't been called.
        if self.stream.closed() or stream.closed():
            self.log.warning("zmq message arrived on closed channel")
            self.close()
            return
        channel = getattr(stream, 'channel', None)
        try:
            msg = self._reserialize_reply(msg_list, channel=channel)
        except Exception:
            # Pass msg_list as a logging argument instead of %-formatting it
            # eagerly: formatting is deferred to the logger, and a tuple
            # would not be unpacked as multiple format arguments.
            self.log.critical("Malformed message: %r", msg_list, exc_info=True)
        else:
            # Binary frames for serialized buffers, text frames for JSON.
            self.write_message(msg, binary=isinstance(msg, bytes))
186 188
class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
    """ZMQ stream handler that authenticates the request before the
    websocket is opened and keeps the connection alive with periodic pings.
    """

    # PeriodicCallback driving send_ping; stays None when pings are disabled.
    ping_callback = None
    # IOLoop timestamps of the most recent ping sent / pong received.
    last_ping = 0
    last_pong = 0

    @property
    def ping_interval(self):
        """The interval for websocket keep-alive pings.

        Set ws_ping_interval = 0 to disable pings.
        """
        return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)

    @property
    def ping_timeout(self):
        """If no pong is received in this many milliseconds,
        close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
        Default is max of 3 pings or 30 seconds.
        """
        return self.settings.get('ws_ping_timeout',
            max(3 * self.ping_interval, WS_PING_INTERVAL)
        )

    def set_default_headers(self):
        """Undo the set_default_headers in IPythonHandler

        which doesn't make sense for websockets
        """
        pass

    def pre_get(self):
        """Run before finishing the GET request

        Extend this method to add logic that should fire before
        the websocket finishes completing.

        Raises
        ------
        tornado.web.HTTPError
            403 if the request has no authenticated user.
        """
        # authenticate the request before opening the websocket
        if self.get_current_user() is None:
            self.log.warning("Couldn't authenticate WebSocket connection")
            raise web.HTTPError(403)

        if self.get_argument('session_id', False):
            self.session.session = cast_unicode(self.get_argument('session_id'))
        else:
            self.log.warning("No session ID specified")

    @gen.coroutine
    def get(self, *args, **kwargs):
        """Run pre_get, then hand the request to the websocket handshake."""
        # pre_get can be a coroutine in subclasses
        # assign and yield in two step to avoid tornado 3 issues
        res = self.pre_get()
        yield gen.maybe_future(res)
        # The parent get() may itself return a future; wait for it instead
        # of dropping the result.  maybe_future turns a plain return value
        # (e.g. None) into an already-resolved future, so this is a no-op
        # when the parent is synchronous.
        res = super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
        yield gen.maybe_future(res)

    def initialize(self):
        """Create the Session used for this connection."""
        self.log.debug("Initializing websocket connection %s", self.request.path)
        self.session = Session(config=self.config)

    def open(self, *args, **kwargs):
        """Start the keep-alive ping timer, unless pings are disabled."""
        self.log.debug("Opening websocket %s", self.request.path)

        # start the pinging
        if self.ping_interval > 0:
            loop = ioloop.IOLoop.current()
            self.last_ping = loop.time()  # Remember time of last ping
            self.last_pong = self.last_ping
            # No explicit io_loop argument: PeriodicCallback uses the current
            # IOLoop by default, which is the loop obtained above, and newer
            # tornado releases no longer accept the argument.
            self.ping_callback = ioloop.PeriodicCallback(
                self.send_ping, self.ping_interval,
            )
            self.ping_callback.start()

    def send_ping(self):
        """send a ping to keep the websocket alive"""
        if self.stream.closed() and self.ping_callback is not None:
            self.ping_callback.stop()
            return

        # check for timeout on pong.  Make sure that we really have sent a recent ping in
        # case the machine with both server and client has been suspended since the last ping.
        now = ioloop.IOLoop.current().time()
        # IOLoop times are in seconds; intervals and timeouts are in ms.
        since_last_pong = 1e3 * (now - self.last_pong)
        since_last_ping = 1e3 * (now - self.last_ping)
        if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
            self.log.warning("WebSocket ping timeout after %i ms.", since_last_pong)
            self.close()
            return

        self.ping(b'')
        self.last_ping = now

    def on_pong(self, data):
        """Record the arrival time of a pong for the timeout check."""
        self.last_pong = ioloop.IOLoop.current().time()
General Comments 0
You need to be logged in to leave comments. Login now