##// END OF EJS Templates
Fix websocket/zmq serialization to expect memoryviews
Jason Grout -
Show More
@@ -1,278 +1,280
1 # coding: utf-8
1 # coding: utf-8
2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 import os
7 import os
8 import json
8 import json
9 import struct
9 import struct
10 import warnings
10 import warnings
11
11
12 try:
12 try:
13 from urllib.parse import urlparse # Py 3
13 from urllib.parse import urlparse # Py 3
14 except ImportError:
14 except ImportError:
15 from urlparse import urlparse # Py 2
15 from urlparse import urlparse # Py 2
16
16
17 import tornado
17 import tornado
18 from tornado import gen, ioloop, web
18 from tornado import gen, ioloop, web
19 from tornado.websocket import WebSocketHandler
19 from tornado.websocket import WebSocketHandler
20
20
21 from IPython.kernel.zmq.session import Session
21 from IPython.kernel.zmq.session import Session
22 from IPython.utils.jsonutil import date_default, extract_dates
22 from IPython.utils.jsonutil import date_default, extract_dates
23 from IPython.utils.py3compat import cast_unicode
23 from IPython.utils.py3compat import cast_unicode
24
24
25 from .handlers import IPythonHandler
25 from .handlers import IPythonHandler
26
26
27 def serialize_binary_message(msg):
27 def serialize_binary_message(msg):
28 """serialize a message as a binary blob
28 """serialize a message as a binary blob
29
29
30 Header:
30 Header:
31
31
32 4 bytes: number of msg parts (nbufs) as 32b int
32 4 bytes: number of msg parts (nbufs) as 32b int
33 4 * nbufs bytes: offset for each buffer as integer as 32b int
33 4 * nbufs bytes: offset for each buffer as integer as 32b int
34
34
35 Offsets are from the start of the buffer, including the header.
35 Offsets are from the start of the buffer, including the header.
36
36
37 Returns
37 Returns
38 -------
38 -------
39
39
40 The message serialized to bytes.
40 The message serialized to bytes.
41
41
42 """
42 """
43 # don't modify msg or buffer list in-place
43 # don't modify msg or buffer list in-place
44 msg = msg.copy()
44 msg = msg.copy()
45 buffers = list(msg.pop('buffers'))
45 buffers = list(msg.pop('buffers'))
46 # for python 2, copy the buffer memoryviews to byte strings
47 buffers = [x.tobytes() for x in buffers]
46 bmsg = json.dumps(msg, default=date_default).encode('utf8')
48 bmsg = json.dumps(msg, default=date_default).encode('utf8')
47 buffers.insert(0, bmsg)
49 buffers.insert(0, bmsg)
48 nbufs = len(buffers)
50 nbufs = len(buffers)
49 offsets = [4 * (nbufs + 1)]
51 offsets = [4 * (nbufs + 1)]
50 for buf in buffers[:-1]:
52 for buf in buffers[:-1]:
51 offsets.append(offsets[-1] + len(buf))
53 offsets.append(offsets[-1] + len(buf))
52 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
54 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
53 buffers.insert(0, offsets_buf)
55 buffers.insert(0, offsets_buf)
54 return b''.join(buffers)
56 return b''.join(buffers)
55
57
56
58
57 def deserialize_binary_message(bmsg):
59 def deserialize_binary_message(bmsg):
58 """deserialize a message from a binary blog
60 """deserialize a message from a binary blog
59
61
60 Header:
62 Header:
61
63
62 4 bytes: number of msg parts (nbufs) as 32b int
64 4 bytes: number of msg parts (nbufs) as 32b int
63 4 * nbufs bytes: offset for each buffer as integer as 32b int
65 4 * nbufs bytes: offset for each buffer as integer as 32b int
64
66
65 Offsets are from the start of the buffer, including the header.
67 Offsets are from the start of the buffer, including the header.
66
68
67 Returns
69 Returns
68 -------
70 -------
69
71
70 message dictionary
72 message dictionary
71 """
73 """
72 nbufs = struct.unpack('!i', bmsg[:4])[0]
74 nbufs = struct.unpack('!i', bmsg[:4])[0]
73 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
75 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
74 offsets.append(None)
76 offsets.append(None)
75 bufs = []
77 bufs = []
76 for start, stop in zip(offsets[:-1], offsets[1:]):
78 for start, stop in zip(offsets[:-1], offsets[1:]):
77 bufs.append(bmsg[start:stop])
79 bufs.append(bmsg[start:stop])
78 msg = json.loads(bufs[0].decode('utf8'))
80 msg = json.loads(bufs[0].decode('utf8'))
79 msg['header'] = extract_dates(msg['header'])
81 msg['header'] = extract_dates(msg['header'])
80 msg['parent_header'] = extract_dates(msg['parent_header'])
82 msg['parent_header'] = extract_dates(msg['parent_header'])
81 msg['buffers'] = bufs[1:]
83 msg['buffers'] = bufs[1:]
82 return msg
84 return msg
83
85
84 # ping interval for keeping websockets alive (30 seconds)
86 # ping interval for keeping websockets alive (30 seconds)
85 WS_PING_INTERVAL = 30000
87 WS_PING_INTERVAL = 30000
86
88
87 if os.environ.get('IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS', False):
89 if os.environ.get('IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS', False):
88 warnings.warn("""Allowing draft76 websocket connections!
90 warnings.warn("""Allowing draft76 websocket connections!
89 This should only be done for testing with phantomjs!""")
91 This should only be done for testing with phantomjs!""")
90 from IPython.html import allow76
92 from IPython.html import allow76
91 WebSocketHandler = allow76.AllowDraftWebSocketHandler
93 WebSocketHandler = allow76.AllowDraftWebSocketHandler
92 # draft 76 doesn't support ping
94 # draft 76 doesn't support ping
93 WS_PING_INTERVAL = 0
95 WS_PING_INTERVAL = 0
94
96
95 class ZMQStreamHandler(WebSocketHandler):
97 class ZMQStreamHandler(WebSocketHandler):
96
98
97 if tornado.version_info < (4,1):
99 if tornado.version_info < (4,1):
98 """Backport send_error from tornado 4.1 to 4.0"""
100 """Backport send_error from tornado 4.1 to 4.0"""
99 def send_error(self, *args, **kwargs):
101 def send_error(self, *args, **kwargs):
100 if self.stream is None:
102 if self.stream is None:
101 super(WebSocketHandler, self).send_error(*args, **kwargs)
103 super(WebSocketHandler, self).send_error(*args, **kwargs)
102 else:
104 else:
103 # If we get an uncaught exception during the handshake,
105 # If we get an uncaught exception during the handshake,
104 # we have no choice but to abruptly close the connection.
106 # we have no choice but to abruptly close the connection.
105 # TODO: for uncaught exceptions after the handshake,
107 # TODO: for uncaught exceptions after the handshake,
106 # we can close the connection more gracefully.
108 # we can close the connection more gracefully.
107 self.stream.close()
109 self.stream.close()
108
110
109
111
110 def check_origin(self, origin):
112 def check_origin(self, origin):
111 """Check Origin == Host or Access-Control-Allow-Origin.
113 """Check Origin == Host or Access-Control-Allow-Origin.
112
114
113 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
115 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
114 """
116 """
115 if self.allow_origin == '*':
117 if self.allow_origin == '*':
116 return True
118 return True
117
119
118 host = self.request.headers.get("Host")
120 host = self.request.headers.get("Host")
119
121
120 # If no header is provided, assume we can't verify origin
122 # If no header is provided, assume we can't verify origin
121 if origin is None:
123 if origin is None:
122 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
124 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
123 return False
125 return False
124 if host is None:
126 if host is None:
125 self.log.warn("Missing Host header, rejecting WebSocket connection.")
127 self.log.warn("Missing Host header, rejecting WebSocket connection.")
126 return False
128 return False
127
129
128 origin = origin.lower()
130 origin = origin.lower()
129 origin_host = urlparse(origin).netloc
131 origin_host = urlparse(origin).netloc
130
132
131 # OK if origin matches host
133 # OK if origin matches host
132 if origin_host == host:
134 if origin_host == host:
133 return True
135 return True
134
136
135 # Check CORS headers
137 # Check CORS headers
136 if self.allow_origin:
138 if self.allow_origin:
137 allow = self.allow_origin == origin
139 allow = self.allow_origin == origin
138 elif self.allow_origin_pat:
140 elif self.allow_origin_pat:
139 allow = bool(self.allow_origin_pat.match(origin))
141 allow = bool(self.allow_origin_pat.match(origin))
140 else:
142 else:
141 # No CORS headers deny the request
143 # No CORS headers deny the request
142 allow = False
144 allow = False
143 if not allow:
145 if not allow:
144 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
146 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
145 origin, host,
147 origin, host,
146 )
148 )
147 return allow
149 return allow
148
150
149 def clear_cookie(self, *args, **kwargs):
151 def clear_cookie(self, *args, **kwargs):
150 """meaningless for websockets"""
152 """meaningless for websockets"""
151 pass
153 pass
152
154
153 def _reserialize_reply(self, msg_list, channel=None):
155 def _reserialize_reply(self, msg_list, channel=None):
154 """Reserialize a reply message using JSON.
156 """Reserialize a reply message using JSON.
155
157
156 This takes the msg list from the ZMQ socket, deserializes it using
158 This takes the msg list from the ZMQ socket, deserializes it using
157 self.session and then serializes the result using JSON. This method
159 self.session and then serializes the result using JSON. This method
158 should be used by self._on_zmq_reply to build messages that can
160 should be used by self._on_zmq_reply to build messages that can
159 be sent back to the browser.
161 be sent back to the browser.
160 """
162 """
161 idents, msg_list = self.session.feed_identities(msg_list)
163 idents, msg_list = self.session.feed_identities(msg_list)
162 msg = self.session.deserialize(msg_list)
164 msg = self.session.deserialize(msg_list)
163 if channel:
165 if channel:
164 msg['channel'] = channel
166 msg['channel'] = channel
165 if msg['buffers']:
167 if msg['buffers']:
166 buf = serialize_binary_message(msg)
168 buf = serialize_binary_message(msg)
167 return buf
169 return buf
168 else:
170 else:
169 smsg = json.dumps(msg, default=date_default)
171 smsg = json.dumps(msg, default=date_default)
170 return cast_unicode(smsg)
172 return cast_unicode(smsg)
171
173
172 def _on_zmq_reply(self, stream, msg_list):
174 def _on_zmq_reply(self, stream, msg_list):
173 # Sometimes this gets triggered when the on_close method is scheduled in the
175 # Sometimes this gets triggered when the on_close method is scheduled in the
174 # eventloop but hasn't been called.
176 # eventloop but hasn't been called.
175 if self.stream.closed() or stream.closed():
177 if self.stream.closed() or stream.closed():
176 self.log.warn("zmq message arrived on closed channel")
178 self.log.warn("zmq message arrived on closed channel")
177 self.close()
179 self.close()
178 return
180 return
179 channel = getattr(stream, 'channel', None)
181 channel = getattr(stream, 'channel', None)
180 try:
182 try:
181 msg = self._reserialize_reply(msg_list, channel=channel)
183 msg = self._reserialize_reply(msg_list, channel=channel)
182 except Exception:
184 except Exception:
183 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
185 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
184 else:
186 else:
185 self.write_message(msg, binary=isinstance(msg, bytes))
187 self.write_message(msg, binary=isinstance(msg, bytes))
186
188
187 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
189 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
188 ping_callback = None
190 ping_callback = None
189 last_ping = 0
191 last_ping = 0
190 last_pong = 0
192 last_pong = 0
191
193
192 @property
194 @property
193 def ping_interval(self):
195 def ping_interval(self):
194 """The interval for websocket keep-alive pings.
196 """The interval for websocket keep-alive pings.
195
197
196 Set ws_ping_interval = 0 to disable pings.
198 Set ws_ping_interval = 0 to disable pings.
197 """
199 """
198 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
200 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
199
201
200 @property
202 @property
201 def ping_timeout(self):
203 def ping_timeout(self):
202 """If no ping is received in this many milliseconds,
204 """If no ping is received in this many milliseconds,
203 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
205 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
204 Default is max of 3 pings or 30 seconds.
206 Default is max of 3 pings or 30 seconds.
205 """
207 """
206 return self.settings.get('ws_ping_timeout',
208 return self.settings.get('ws_ping_timeout',
207 max(3 * self.ping_interval, WS_PING_INTERVAL)
209 max(3 * self.ping_interval, WS_PING_INTERVAL)
208 )
210 )
209
211
210 def set_default_headers(self):
212 def set_default_headers(self):
211 """Undo the set_default_headers in IPythonHandler
213 """Undo the set_default_headers in IPythonHandler
212
214
213 which doesn't make sense for websockets
215 which doesn't make sense for websockets
214 """
216 """
215 pass
217 pass
216
218
217 def pre_get(self):
219 def pre_get(self):
218 """Run before finishing the GET request
220 """Run before finishing the GET request
219
221
220 Extend this method to add logic that should fire before
222 Extend this method to add logic that should fire before
221 the websocket finishes completing.
223 the websocket finishes completing.
222 """
224 """
223 # authenticate the request before opening the websocket
225 # authenticate the request before opening the websocket
224 if self.get_current_user() is None:
226 if self.get_current_user() is None:
225 self.log.warn("Couldn't authenticate WebSocket connection")
227 self.log.warn("Couldn't authenticate WebSocket connection")
226 raise web.HTTPError(403)
228 raise web.HTTPError(403)
227
229
228 if self.get_argument('session_id', False):
230 if self.get_argument('session_id', False):
229 self.session.session = cast_unicode(self.get_argument('session_id'))
231 self.session.session = cast_unicode(self.get_argument('session_id'))
230 else:
232 else:
231 self.log.warn("No session ID specified")
233 self.log.warn("No session ID specified")
232
234
233 @gen.coroutine
235 @gen.coroutine
234 def get(self, *args, **kwargs):
236 def get(self, *args, **kwargs):
235 # pre_get can be a coroutine in subclasses
237 # pre_get can be a coroutine in subclasses
236 # assign and yield in two step to avoid tornado 3 issues
238 # assign and yield in two step to avoid tornado 3 issues
237 res = self.pre_get()
239 res = self.pre_get()
238 yield gen.maybe_future(res)
240 yield gen.maybe_future(res)
239 super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
241 super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
240
242
241 def initialize(self):
243 def initialize(self):
242 self.log.debug("Initializing websocket connection %s", self.request.path)
244 self.log.debug("Initializing websocket connection %s", self.request.path)
243 self.session = Session(config=self.config)
245 self.session = Session(config=self.config)
244
246
245 def open(self, *args, **kwargs):
247 def open(self, *args, **kwargs):
246 self.log.debug("Opening websocket %s", self.request.path)
248 self.log.debug("Opening websocket %s", self.request.path)
247
249
248 # start the pinging
250 # start the pinging
249 if self.ping_interval > 0:
251 if self.ping_interval > 0:
250 loop = ioloop.IOLoop.current()
252 loop = ioloop.IOLoop.current()
251 self.last_ping = loop.time() # Remember time of last ping
253 self.last_ping = loop.time() # Remember time of last ping
252 self.last_pong = self.last_ping
254 self.last_pong = self.last_ping
253 self.ping_callback = ioloop.PeriodicCallback(
255 self.ping_callback = ioloop.PeriodicCallback(
254 self.send_ping, self.ping_interval, io_loop=loop,
256 self.send_ping, self.ping_interval, io_loop=loop,
255 )
257 )
256 self.ping_callback.start()
258 self.ping_callback.start()
257
259
258 def send_ping(self):
260 def send_ping(self):
259 """send a ping to keep the websocket alive"""
261 """send a ping to keep the websocket alive"""
260 if self.stream.closed() and self.ping_callback is not None:
262 if self.stream.closed() and self.ping_callback is not None:
261 self.ping_callback.stop()
263 self.ping_callback.stop()
262 return
264 return
263
265
264 # check for timeout on pong. Make sure that we really have sent a recent ping in
266 # check for timeout on pong. Make sure that we really have sent a recent ping in
265 # case the machine with both server and client has been suspended since the last ping.
267 # case the machine with both server and client has been suspended since the last ping.
266 now = ioloop.IOLoop.current().time()
268 now = ioloop.IOLoop.current().time()
267 since_last_pong = 1e3 * (now - self.last_pong)
269 since_last_pong = 1e3 * (now - self.last_pong)
268 since_last_ping = 1e3 * (now - self.last_ping)
270 since_last_ping = 1e3 * (now - self.last_ping)
269 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
271 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
270 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
272 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
271 self.close()
273 self.close()
272 return
274 return
273
275
274 self.ping(b'')
276 self.ping(b'')
275 self.last_ping = now
277 self.last_ping = now
276
278
277 def on_pong(self, data):
279 def on_pong(self, data):
278 self.last_pong = ioloop.IOLoop.current().time()
280 self.last_pong = ioloop.IOLoop.current().time()
General Comments 0
You need to be logged in to leave comments. Login now