##// END OF EJS Templates
backport WebSocket.send_error from tornado 4.1...
Min RK -
Show More
@@ -1,263 +1,276 b''
1 # coding: utf-8
1 # coding: utf-8
2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 import os
7 import os
8 import json
8 import json
9 import struct
9 import struct
10 import warnings
10 import warnings
11
11
12 try:
12 try:
13 from urllib.parse import urlparse # Py 3
13 from urllib.parse import urlparse # Py 3
14 except ImportError:
14 except ImportError:
15 from urlparse import urlparse # Py 2
15 from urlparse import urlparse # Py 2
16
16
17 import tornado
17 import tornado
18 from tornado import gen, ioloop, web
18 from tornado import gen, ioloop, web
19 from tornado.websocket import WebSocketHandler
19 from tornado.websocket import WebSocketHandler
20
20
21 from IPython.kernel.zmq.session import Session
21 from IPython.kernel.zmq.session import Session
22 from IPython.utils.jsonutil import date_default, extract_dates
22 from IPython.utils.jsonutil import date_default, extract_dates
23 from IPython.utils.py3compat import cast_unicode
23 from IPython.utils.py3compat import cast_unicode
24
24
25 from .handlers import IPythonHandler
25 from .handlers import IPythonHandler
26
26
27 def serialize_binary_message(msg):
27 def serialize_binary_message(msg):
28 """serialize a message as a binary blob
28 """serialize a message as a binary blob
29
29
30 Header:
30 Header:
31
31
32 4 bytes: number of msg parts (nbufs) as 32b int
32 4 bytes: number of msg parts (nbufs) as 32b int
33 4 * nbufs bytes: offset for each buffer as integer as 32b int
33 4 * nbufs bytes: offset for each buffer as integer as 32b int
34
34
35 Offsets are from the start of the buffer, including the header.
35 Offsets are from the start of the buffer, including the header.
36
36
37 Returns
37 Returns
38 -------
38 -------
39
39
40 The message serialized to bytes.
40 The message serialized to bytes.
41
41
42 """
42 """
43 # don't modify msg or buffer list in-place
43 # don't modify msg or buffer list in-place
44 msg = msg.copy()
44 msg = msg.copy()
45 buffers = list(msg.pop('buffers'))
45 buffers = list(msg.pop('buffers'))
46 bmsg = json.dumps(msg, default=date_default).encode('utf8')
46 bmsg = json.dumps(msg, default=date_default).encode('utf8')
47 buffers.insert(0, bmsg)
47 buffers.insert(0, bmsg)
48 nbufs = len(buffers)
48 nbufs = len(buffers)
49 offsets = [4 * (nbufs + 1)]
49 offsets = [4 * (nbufs + 1)]
50 for buf in buffers[:-1]:
50 for buf in buffers[:-1]:
51 offsets.append(offsets[-1] + len(buf))
51 offsets.append(offsets[-1] + len(buf))
52 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
52 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
53 buffers.insert(0, offsets_buf)
53 buffers.insert(0, offsets_buf)
54 return b''.join(buffers)
54 return b''.join(buffers)
55
55
56
56
57 def deserialize_binary_message(bmsg):
57 def deserialize_binary_message(bmsg):
58 """deserialize a message from a binary blog
58 """deserialize a message from a binary blog
59
59
60 Header:
60 Header:
61
61
62 4 bytes: number of msg parts (nbufs) as 32b int
62 4 bytes: number of msg parts (nbufs) as 32b int
63 4 * nbufs bytes: offset for each buffer as integer as 32b int
63 4 * nbufs bytes: offset for each buffer as integer as 32b int
64
64
65 Offsets are from the start of the buffer, including the header.
65 Offsets are from the start of the buffer, including the header.
66
66
67 Returns
67 Returns
68 -------
68 -------
69
69
70 message dictionary
70 message dictionary
71 """
71 """
72 nbufs = struct.unpack('!i', bmsg[:4])[0]
72 nbufs = struct.unpack('!i', bmsg[:4])[0]
73 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
73 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
74 offsets.append(None)
74 offsets.append(None)
75 bufs = []
75 bufs = []
76 for start, stop in zip(offsets[:-1], offsets[1:]):
76 for start, stop in zip(offsets[:-1], offsets[1:]):
77 bufs.append(bmsg[start:stop])
77 bufs.append(bmsg[start:stop])
78 msg = json.loads(bufs[0].decode('utf8'))
78 msg = json.loads(bufs[0].decode('utf8'))
79 msg['header'] = extract_dates(msg['header'])
79 msg['header'] = extract_dates(msg['header'])
80 msg['parent_header'] = extract_dates(msg['parent_header'])
80 msg['parent_header'] = extract_dates(msg['parent_header'])
81 msg['buffers'] = bufs[1:]
81 msg['buffers'] = bufs[1:]
82 return msg
82 return msg
83
83
84 # ping interval for keeping websockets alive (30 seconds)
84 # ping interval for keeping websockets alive (30 seconds)
85 WS_PING_INTERVAL = 30000
85 WS_PING_INTERVAL = 30000
86
86
87 if os.environ.get('IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS', False):
87 if os.environ.get('IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS', False):
88 warnings.warn("""Allowing draft76 websocket connections!
88 warnings.warn("""Allowing draft76 websocket connections!
89 This should only be done for testing with phantomjs!""")
89 This should only be done for testing with phantomjs!""")
90 from IPython.html import allow76
90 from IPython.html import allow76
91 WebSocketHandler = allow76.AllowDraftWebSocketHandler
91 WebSocketHandler = allow76.AllowDraftWebSocketHandler
92 # draft 76 doesn't support ping
92 # draft 76 doesn't support ping
93 WS_PING_INTERVAL = 0
93 WS_PING_INTERVAL = 0
94
94
95 class ZMQStreamHandler(WebSocketHandler):
95 class ZMQStreamHandler(WebSocketHandler):
96
96
97 if tornado.version_info < (4,1):
98 """Backport send_error from tornado 4.1 to 4.0"""
99 def send_error(self, *args, **kwargs):
100 if self.stream is None:
101 super(WebSocketHandler, self).send_error(*args, **kwargs)
102 else:
103 # If we get an uncaught exception during the handshake,
104 # we have no choice but to abruptly close the connection.
105 # TODO: for uncaught exceptions after the handshake,
106 # we can close the connection more gracefully.
107 self.stream.close()
108
109
97 def check_origin(self, origin):
110 def check_origin(self, origin):
98 """Check Origin == Host or Access-Control-Allow-Origin.
111 """Check Origin == Host or Access-Control-Allow-Origin.
99
112
100 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
113 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
101 We call it explicitly in `open` on Tornado < 4.
114 We call it explicitly in `open` on Tornado < 4.
102 """
115 """
103 if self.allow_origin == '*':
116 if self.allow_origin == '*':
104 return True
117 return True
105
118
106 host = self.request.headers.get("Host")
119 host = self.request.headers.get("Host")
107
120
108 # If no header is provided, assume we can't verify origin
121 # If no header is provided, assume we can't verify origin
109 if origin is None:
122 if origin is None:
110 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
123 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
111 return False
124 return False
112 if host is None:
125 if host is None:
113 self.log.warn("Missing Host header, rejecting WebSocket connection.")
126 self.log.warn("Missing Host header, rejecting WebSocket connection.")
114 return False
127 return False
115
128
116 origin = origin.lower()
129 origin = origin.lower()
117 origin_host = urlparse(origin).netloc
130 origin_host = urlparse(origin).netloc
118
131
119 # OK if origin matches host
132 # OK if origin matches host
120 if origin_host == host:
133 if origin_host == host:
121 return True
134 return True
122
135
123 # Check CORS headers
136 # Check CORS headers
124 if self.allow_origin:
137 if self.allow_origin:
125 allow = self.allow_origin == origin
138 allow = self.allow_origin == origin
126 elif self.allow_origin_pat:
139 elif self.allow_origin_pat:
127 allow = bool(self.allow_origin_pat.match(origin))
140 allow = bool(self.allow_origin_pat.match(origin))
128 else:
141 else:
129 # No CORS headers deny the request
142 # No CORS headers deny the request
130 allow = False
143 allow = False
131 if not allow:
144 if not allow:
132 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
145 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
133 origin, host,
146 origin, host,
134 )
147 )
135 return allow
148 return allow
136
149
137 def clear_cookie(self, *args, **kwargs):
150 def clear_cookie(self, *args, **kwargs):
138 """meaningless for websockets"""
151 """meaningless for websockets"""
139 pass
152 pass
140
153
141 def _reserialize_reply(self, msg_list, channel=None):
154 def _reserialize_reply(self, msg_list, channel=None):
142 """Reserialize a reply message using JSON.
155 """Reserialize a reply message using JSON.
143
156
144 This takes the msg list from the ZMQ socket, deserializes it using
157 This takes the msg list from the ZMQ socket, deserializes it using
145 self.session and then serializes the result using JSON. This method
158 self.session and then serializes the result using JSON. This method
146 should be used by self._on_zmq_reply to build messages that can
159 should be used by self._on_zmq_reply to build messages that can
147 be sent back to the browser.
160 be sent back to the browser.
148 """
161 """
149 idents, msg_list = self.session.feed_identities(msg_list)
162 idents, msg_list = self.session.feed_identities(msg_list)
150 msg = self.session.deserialize(msg_list)
163 msg = self.session.deserialize(msg_list)
151 if channel:
164 if channel:
152 msg['channel'] = channel
165 msg['channel'] = channel
153 if msg['buffers']:
166 if msg['buffers']:
154 buf = serialize_binary_message(msg)
167 buf = serialize_binary_message(msg)
155 return buf
168 return buf
156 else:
169 else:
157 smsg = json.dumps(msg, default=date_default)
170 smsg = json.dumps(msg, default=date_default)
158 return cast_unicode(smsg)
171 return cast_unicode(smsg)
159
172
160 def _on_zmq_reply(self, stream, msg_list):
173 def _on_zmq_reply(self, stream, msg_list):
161 # Sometimes this gets triggered when the on_close method is scheduled in the
174 # Sometimes this gets triggered when the on_close method is scheduled in the
162 # eventloop but hasn't been called.
175 # eventloop but hasn't been called.
163 if stream.closed(): return
176 if stream.closed(): return
164 channel = getattr(stream, 'channel', None)
177 channel = getattr(stream, 'channel', None)
165 try:
178 try:
166 msg = self._reserialize_reply(msg_list, channel=channel)
179 msg = self._reserialize_reply(msg_list, channel=channel)
167 except Exception:
180 except Exception:
168 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
181 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
169 else:
182 else:
170 self.write_message(msg, binary=isinstance(msg, bytes))
183 self.write_message(msg, binary=isinstance(msg, bytes))
171
184
172 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
185 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
173 ping_callback = None
186 ping_callback = None
174 last_ping = 0
187 last_ping = 0
175 last_pong = 0
188 last_pong = 0
176
189
177 @property
190 @property
178 def ping_interval(self):
191 def ping_interval(self):
179 """The interval for websocket keep-alive pings.
192 """The interval for websocket keep-alive pings.
180
193
181 Set ws_ping_interval = 0 to disable pings.
194 Set ws_ping_interval = 0 to disable pings.
182 """
195 """
183 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
196 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
184
197
185 @property
198 @property
186 def ping_timeout(self):
199 def ping_timeout(self):
187 """If no ping is received in this many milliseconds,
200 """If no ping is received in this many milliseconds,
188 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
201 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
189 Default is max of 3 pings or 30 seconds.
202 Default is max of 3 pings or 30 seconds.
190 """
203 """
191 return self.settings.get('ws_ping_timeout',
204 return self.settings.get('ws_ping_timeout',
192 max(3 * self.ping_interval, WS_PING_INTERVAL)
205 max(3 * self.ping_interval, WS_PING_INTERVAL)
193 )
206 )
194
207
195 def set_default_headers(self):
208 def set_default_headers(self):
196 """Undo the set_default_headers in IPythonHandler
209 """Undo the set_default_headers in IPythonHandler
197
210
198 which doesn't make sense for websockets
211 which doesn't make sense for websockets
199 """
212 """
200 pass
213 pass
201
214
202 def pre_get(self):
215 def pre_get(self):
203 """Run before finishing the GET request
216 """Run before finishing the GET request
204
217
205 Extend this method to add logic that should fire before
218 Extend this method to add logic that should fire before
206 the websocket finishes completing.
219 the websocket finishes completing.
207 """
220 """
208 # authenticate the request before opening the websocket
221 # authenticate the request before opening the websocket
209 if self.get_current_user() is None:
222 if self.get_current_user() is None:
210 self.log.warn("Couldn't authenticate WebSocket connection")
223 self.log.warn("Couldn't authenticate WebSocket connection")
211 raise web.HTTPError(403)
224 raise web.HTTPError(403)
212
225
213 if self.get_argument('session_id', False):
226 if self.get_argument('session_id', False):
214 self.session.session = cast_unicode(self.get_argument('session_id'))
227 self.session.session = cast_unicode(self.get_argument('session_id'))
215 else:
228 else:
216 self.log.warn("No session ID specified")
229 self.log.warn("No session ID specified")
217
230
218 @gen.coroutine
231 @gen.coroutine
219 def get(self, *args, **kwargs):
232 def get(self, *args, **kwargs):
220 # pre_get can be a coroutine in subclasses
233 # pre_get can be a coroutine in subclasses
221 # assign and yield in two step to avoid tornado 3 issues
234 # assign and yield in two step to avoid tornado 3 issues
222 res = self.pre_get()
235 res = self.pre_get()
223 yield gen.maybe_future(res)
236 yield gen.maybe_future(res)
224 super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
237 super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
225
238
226 def initialize(self):
239 def initialize(self):
227 self.log.debug("Initializing websocket connection %s", self.request.path)
240 self.log.debug("Initializing websocket connection %s", self.request.path)
228 self.session = Session(config=self.config)
241 self.session = Session(config=self.config)
229
242
230 def open(self, *args, **kwargs):
243 def open(self, *args, **kwargs):
231 self.log.debug("Opening websocket %s", self.request.path)
244 self.log.debug("Opening websocket %s", self.request.path)
232
245
233 # start the pinging
246 # start the pinging
234 if self.ping_interval > 0:
247 if self.ping_interval > 0:
235 loop = ioloop.IOLoop.current()
248 loop = ioloop.IOLoop.current()
236 self.last_ping = loop.time() # Remember time of last ping
249 self.last_ping = loop.time() # Remember time of last ping
237 self.last_pong = self.last_ping
250 self.last_pong = self.last_ping
238 self.ping_callback = ioloop.PeriodicCallback(
251 self.ping_callback = ioloop.PeriodicCallback(
239 self.send_ping, self.ping_interval, io_loop=loop,
252 self.send_ping, self.ping_interval, io_loop=loop,
240 )
253 )
241 self.ping_callback.start()
254 self.ping_callback.start()
242
255
243 def send_ping(self):
256 def send_ping(self):
244 """send a ping to keep the websocket alive"""
257 """send a ping to keep the websocket alive"""
245 if self.stream.closed() and self.ping_callback is not None:
258 if self.stream.closed() and self.ping_callback is not None:
246 self.ping_callback.stop()
259 self.ping_callback.stop()
247 return
260 return
248
261
249 # check for timeout on pong. Make sure that we really have sent a recent ping in
262 # check for timeout on pong. Make sure that we really have sent a recent ping in
250 # case the machine with both server and client has been suspended since the last ping.
263 # case the machine with both server and client has been suspended since the last ping.
251 now = ioloop.IOLoop.current().time()
264 now = ioloop.IOLoop.current().time()
252 since_last_pong = 1e3 * (now - self.last_pong)
265 since_last_pong = 1e3 * (now - self.last_pong)
253 since_last_ping = 1e3 * (now - self.last_ping)
266 since_last_ping = 1e3 * (now - self.last_ping)
254 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
267 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
255 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
268 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
256 self.close()
269 self.close()
257 return
270 return
258
271
259 self.ping(b'')
272 self.ping(b'')
260 self.last_ping = now
273 self.last_ping = now
261
274
262 def on_pong(self, data):
275 def on_pong(self, data):
263 self.last_pong = ioloop.IOLoop.current().time()
276 self.last_pong = ioloop.IOLoop.current().time()
General Comments 0
You need to be logged in to leave comments. Login now