##// END OF EJS Templates
handle message arriving when sockets are closed...
Min RK -
Show More
@@ -1,276 +1,278 b''
1 # coding: utf-8
1 # coding: utf-8
2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 import os
7 import os
8 import json
8 import json
9 import struct
9 import struct
10 import warnings
10 import warnings
11
11
12 try:
12 try:
13 from urllib.parse import urlparse # Py 3
13 from urllib.parse import urlparse # Py 3
14 except ImportError:
14 except ImportError:
15 from urlparse import urlparse # Py 2
15 from urlparse import urlparse # Py 2
16
16
17 import tornado
17 import tornado
18 from tornado import gen, ioloop, web
18 from tornado import gen, ioloop, web
19 from tornado.websocket import WebSocketHandler
19 from tornado.websocket import WebSocketHandler
20
20
21 from IPython.kernel.zmq.session import Session
21 from IPython.kernel.zmq.session import Session
22 from IPython.utils.jsonutil import date_default, extract_dates
22 from IPython.utils.jsonutil import date_default, extract_dates
23 from IPython.utils.py3compat import cast_unicode
23 from IPython.utils.py3compat import cast_unicode
24
24
25 from .handlers import IPythonHandler
25 from .handlers import IPythonHandler
26
26
27 def serialize_binary_message(msg):
27 def serialize_binary_message(msg):
28 """serialize a message as a binary blob
28 """serialize a message as a binary blob
29
29
30 Header:
30 Header:
31
31
32 4 bytes: number of msg parts (nbufs) as 32b int
32 4 bytes: number of msg parts (nbufs) as 32b int
33 4 * nbufs bytes: offset for each buffer as integer as 32b int
33 4 * nbufs bytes: offset for each buffer as integer as 32b int
34
34
35 Offsets are from the start of the buffer, including the header.
35 Offsets are from the start of the buffer, including the header.
36
36
37 Returns
37 Returns
38 -------
38 -------
39
39
40 The message serialized to bytes.
40 The message serialized to bytes.
41
41
42 """
42 """
43 # don't modify msg or buffer list in-place
43 # don't modify msg or buffer list in-place
44 msg = msg.copy()
44 msg = msg.copy()
45 buffers = list(msg.pop('buffers'))
45 buffers = list(msg.pop('buffers'))
46 bmsg = json.dumps(msg, default=date_default).encode('utf8')
46 bmsg = json.dumps(msg, default=date_default).encode('utf8')
47 buffers.insert(0, bmsg)
47 buffers.insert(0, bmsg)
48 nbufs = len(buffers)
48 nbufs = len(buffers)
49 offsets = [4 * (nbufs + 1)]
49 offsets = [4 * (nbufs + 1)]
50 for buf in buffers[:-1]:
50 for buf in buffers[:-1]:
51 offsets.append(offsets[-1] + len(buf))
51 offsets.append(offsets[-1] + len(buf))
52 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
52 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
53 buffers.insert(0, offsets_buf)
53 buffers.insert(0, offsets_buf)
54 return b''.join(buffers)
54 return b''.join(buffers)
55
55
56
56
57 def deserialize_binary_message(bmsg):
57 def deserialize_binary_message(bmsg):
58 """deserialize a message from a binary blog
58 """deserialize a message from a binary blog
59
59
60 Header:
60 Header:
61
61
62 4 bytes: number of msg parts (nbufs) as 32b int
62 4 bytes: number of msg parts (nbufs) as 32b int
63 4 * nbufs bytes: offset for each buffer as integer as 32b int
63 4 * nbufs bytes: offset for each buffer as integer as 32b int
64
64
65 Offsets are from the start of the buffer, including the header.
65 Offsets are from the start of the buffer, including the header.
66
66
67 Returns
67 Returns
68 -------
68 -------
69
69
70 message dictionary
70 message dictionary
71 """
71 """
72 nbufs = struct.unpack('!i', bmsg[:4])[0]
72 nbufs = struct.unpack('!i', bmsg[:4])[0]
73 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
73 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
74 offsets.append(None)
74 offsets.append(None)
75 bufs = []
75 bufs = []
76 for start, stop in zip(offsets[:-1], offsets[1:]):
76 for start, stop in zip(offsets[:-1], offsets[1:]):
77 bufs.append(bmsg[start:stop])
77 bufs.append(bmsg[start:stop])
78 msg = json.loads(bufs[0].decode('utf8'))
78 msg = json.loads(bufs[0].decode('utf8'))
79 msg['header'] = extract_dates(msg['header'])
79 msg['header'] = extract_dates(msg['header'])
80 msg['parent_header'] = extract_dates(msg['parent_header'])
80 msg['parent_header'] = extract_dates(msg['parent_header'])
81 msg['buffers'] = bufs[1:]
81 msg['buffers'] = bufs[1:]
82 return msg
82 return msg
83
83
84 # ping interval for keeping websockets alive (30 seconds)
84 # ping interval for keeping websockets alive (30 seconds)
85 WS_PING_INTERVAL = 30000
85 WS_PING_INTERVAL = 30000
86
86
87 if os.environ.get('IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS', False):
87 if os.environ.get('IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS', False):
88 warnings.warn("""Allowing draft76 websocket connections!
88 warnings.warn("""Allowing draft76 websocket connections!
89 This should only be done for testing with phantomjs!""")
89 This should only be done for testing with phantomjs!""")
90 from IPython.html import allow76
90 from IPython.html import allow76
91 WebSocketHandler = allow76.AllowDraftWebSocketHandler
91 WebSocketHandler = allow76.AllowDraftWebSocketHandler
92 # draft 76 doesn't support ping
92 # draft 76 doesn't support ping
93 WS_PING_INTERVAL = 0
93 WS_PING_INTERVAL = 0
94
94
95 class ZMQStreamHandler(WebSocketHandler):
95 class ZMQStreamHandler(WebSocketHandler):
96
96
97 if tornado.version_info < (4,1):
97 if tornado.version_info < (4,1):
98 """Backport send_error from tornado 4.1 to 4.0"""
98 """Backport send_error from tornado 4.1 to 4.0"""
99 def send_error(self, *args, **kwargs):
99 def send_error(self, *args, **kwargs):
100 if self.stream is None:
100 if self.stream is None:
101 super(WebSocketHandler, self).send_error(*args, **kwargs)
101 super(WebSocketHandler, self).send_error(*args, **kwargs)
102 else:
102 else:
103 # If we get an uncaught exception during the handshake,
103 # If we get an uncaught exception during the handshake,
104 # we have no choice but to abruptly close the connection.
104 # we have no choice but to abruptly close the connection.
105 # TODO: for uncaught exceptions after the handshake,
105 # TODO: for uncaught exceptions after the handshake,
106 # we can close the connection more gracefully.
106 # we can close the connection more gracefully.
107 self.stream.close()
107 self.stream.close()
108
108
109
109
110 def check_origin(self, origin):
110 def check_origin(self, origin):
111 """Check Origin == Host or Access-Control-Allow-Origin.
111 """Check Origin == Host or Access-Control-Allow-Origin.
112
112
113 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
113 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
114 We call it explicitly in `open` on Tornado < 4.
115 """
114 """
116 if self.allow_origin == '*':
115 if self.allow_origin == '*':
117 return True
116 return True
118
117
119 host = self.request.headers.get("Host")
118 host = self.request.headers.get("Host")
120
119
121 # If no header is provided, assume we can't verify origin
120 # If no header is provided, assume we can't verify origin
122 if origin is None:
121 if origin is None:
123 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
122 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
124 return False
123 return False
125 if host is None:
124 if host is None:
126 self.log.warn("Missing Host header, rejecting WebSocket connection.")
125 self.log.warn("Missing Host header, rejecting WebSocket connection.")
127 return False
126 return False
128
127
129 origin = origin.lower()
128 origin = origin.lower()
130 origin_host = urlparse(origin).netloc
129 origin_host = urlparse(origin).netloc
131
130
132 # OK if origin matches host
131 # OK if origin matches host
133 if origin_host == host:
132 if origin_host == host:
134 return True
133 return True
135
134
136 # Check CORS headers
135 # Check CORS headers
137 if self.allow_origin:
136 if self.allow_origin:
138 allow = self.allow_origin == origin
137 allow = self.allow_origin == origin
139 elif self.allow_origin_pat:
138 elif self.allow_origin_pat:
140 allow = bool(self.allow_origin_pat.match(origin))
139 allow = bool(self.allow_origin_pat.match(origin))
141 else:
140 else:
142 # No CORS headers deny the request
141 # No CORS headers deny the request
143 allow = False
142 allow = False
144 if not allow:
143 if not allow:
145 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
144 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
146 origin, host,
145 origin, host,
147 )
146 )
148 return allow
147 return allow
149
148
150 def clear_cookie(self, *args, **kwargs):
149 def clear_cookie(self, *args, **kwargs):
151 """meaningless for websockets"""
150 """meaningless for websockets"""
152 pass
151 pass
153
152
154 def _reserialize_reply(self, msg_list, channel=None):
153 def _reserialize_reply(self, msg_list, channel=None):
155 """Reserialize a reply message using JSON.
154 """Reserialize a reply message using JSON.
156
155
157 This takes the msg list from the ZMQ socket, deserializes it using
156 This takes the msg list from the ZMQ socket, deserializes it using
158 self.session and then serializes the result using JSON. This method
157 self.session and then serializes the result using JSON. This method
159 should be used by self._on_zmq_reply to build messages that can
158 should be used by self._on_zmq_reply to build messages that can
160 be sent back to the browser.
159 be sent back to the browser.
161 """
160 """
162 idents, msg_list = self.session.feed_identities(msg_list)
161 idents, msg_list = self.session.feed_identities(msg_list)
163 msg = self.session.deserialize(msg_list)
162 msg = self.session.deserialize(msg_list)
164 if channel:
163 if channel:
165 msg['channel'] = channel
164 msg['channel'] = channel
166 if msg['buffers']:
165 if msg['buffers']:
167 buf = serialize_binary_message(msg)
166 buf = serialize_binary_message(msg)
168 return buf
167 return buf
169 else:
168 else:
170 smsg = json.dumps(msg, default=date_default)
169 smsg = json.dumps(msg, default=date_default)
171 return cast_unicode(smsg)
170 return cast_unicode(smsg)
172
171
173 def _on_zmq_reply(self, stream, msg_list):
172 def _on_zmq_reply(self, stream, msg_list):
174 # Sometimes this gets triggered when the on_close method is scheduled in the
173 # Sometimes this gets triggered when the on_close method is scheduled in the
175 # eventloop but hasn't been called.
174 # eventloop but hasn't been called.
176 if stream.closed(): return
175 if self.stream.closed() or stream.closed():
176 self.log.warn("zmq message arrived on closed channel")
177 self.close()
178 return
177 channel = getattr(stream, 'channel', None)
179 channel = getattr(stream, 'channel', None)
178 try:
180 try:
179 msg = self._reserialize_reply(msg_list, channel=channel)
181 msg = self._reserialize_reply(msg_list, channel=channel)
180 except Exception:
182 except Exception:
181 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
183 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
182 else:
184 else:
183 self.write_message(msg, binary=isinstance(msg, bytes))
185 self.write_message(msg, binary=isinstance(msg, bytes))
184
186
185 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
187 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
186 ping_callback = None
188 ping_callback = None
187 last_ping = 0
189 last_ping = 0
188 last_pong = 0
190 last_pong = 0
189
191
190 @property
192 @property
191 def ping_interval(self):
193 def ping_interval(self):
192 """The interval for websocket keep-alive pings.
194 """The interval for websocket keep-alive pings.
193
195
194 Set ws_ping_interval = 0 to disable pings.
196 Set ws_ping_interval = 0 to disable pings.
195 """
197 """
196 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
198 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
197
199
198 @property
200 @property
199 def ping_timeout(self):
201 def ping_timeout(self):
200 """If no ping is received in this many milliseconds,
202 """If no ping is received in this many milliseconds,
201 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
203 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
202 Default is max of 3 pings or 30 seconds.
204 Default is max of 3 pings or 30 seconds.
203 """
205 """
204 return self.settings.get('ws_ping_timeout',
206 return self.settings.get('ws_ping_timeout',
205 max(3 * self.ping_interval, WS_PING_INTERVAL)
207 max(3 * self.ping_interval, WS_PING_INTERVAL)
206 )
208 )
207
209
208 def set_default_headers(self):
210 def set_default_headers(self):
209 """Undo the set_default_headers in IPythonHandler
211 """Undo the set_default_headers in IPythonHandler
210
212
211 which doesn't make sense for websockets
213 which doesn't make sense for websockets
212 """
214 """
213 pass
215 pass
214
216
215 def pre_get(self):
217 def pre_get(self):
216 """Run before finishing the GET request
218 """Run before finishing the GET request
217
219
218 Extend this method to add logic that should fire before
220 Extend this method to add logic that should fire before
219 the websocket finishes completing.
221 the websocket finishes completing.
220 """
222 """
221 # authenticate the request before opening the websocket
223 # authenticate the request before opening the websocket
222 if self.get_current_user() is None:
224 if self.get_current_user() is None:
223 self.log.warn("Couldn't authenticate WebSocket connection")
225 self.log.warn("Couldn't authenticate WebSocket connection")
224 raise web.HTTPError(403)
226 raise web.HTTPError(403)
225
227
226 if self.get_argument('session_id', False):
228 if self.get_argument('session_id', False):
227 self.session.session = cast_unicode(self.get_argument('session_id'))
229 self.session.session = cast_unicode(self.get_argument('session_id'))
228 else:
230 else:
229 self.log.warn("No session ID specified")
231 self.log.warn("No session ID specified")
230
232
231 @gen.coroutine
233 @gen.coroutine
232 def get(self, *args, **kwargs):
234 def get(self, *args, **kwargs):
233 # pre_get can be a coroutine in subclasses
235 # pre_get can be a coroutine in subclasses
234 # assign and yield in two step to avoid tornado 3 issues
236 # assign and yield in two step to avoid tornado 3 issues
235 res = self.pre_get()
237 res = self.pre_get()
236 yield gen.maybe_future(res)
238 yield gen.maybe_future(res)
237 super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
239 super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
238
240
239 def initialize(self):
241 def initialize(self):
240 self.log.debug("Initializing websocket connection %s", self.request.path)
242 self.log.debug("Initializing websocket connection %s", self.request.path)
241 self.session = Session(config=self.config)
243 self.session = Session(config=self.config)
242
244
243 def open(self, *args, **kwargs):
245 def open(self, *args, **kwargs):
244 self.log.debug("Opening websocket %s", self.request.path)
246 self.log.debug("Opening websocket %s", self.request.path)
245
247
246 # start the pinging
248 # start the pinging
247 if self.ping_interval > 0:
249 if self.ping_interval > 0:
248 loop = ioloop.IOLoop.current()
250 loop = ioloop.IOLoop.current()
249 self.last_ping = loop.time() # Remember time of last ping
251 self.last_ping = loop.time() # Remember time of last ping
250 self.last_pong = self.last_ping
252 self.last_pong = self.last_ping
251 self.ping_callback = ioloop.PeriodicCallback(
253 self.ping_callback = ioloop.PeriodicCallback(
252 self.send_ping, self.ping_interval, io_loop=loop,
254 self.send_ping, self.ping_interval, io_loop=loop,
253 )
255 )
254 self.ping_callback.start()
256 self.ping_callback.start()
255
257
256 def send_ping(self):
258 def send_ping(self):
257 """send a ping to keep the websocket alive"""
259 """send a ping to keep the websocket alive"""
258 if self.stream.closed() and self.ping_callback is not None:
260 if self.stream.closed() and self.ping_callback is not None:
259 self.ping_callback.stop()
261 self.ping_callback.stop()
260 return
262 return
261
263
262 # check for timeout on pong. Make sure that we really have sent a recent ping in
264 # check for timeout on pong. Make sure that we really have sent a recent ping in
263 # case the machine with both server and client has been suspended since the last ping.
265 # case the machine with both server and client has been suspended since the last ping.
264 now = ioloop.IOLoop.current().time()
266 now = ioloop.IOLoop.current().time()
265 since_last_pong = 1e3 * (now - self.last_pong)
267 since_last_pong = 1e3 * (now - self.last_pong)
266 since_last_ping = 1e3 * (now - self.last_ping)
268 since_last_ping = 1e3 * (now - self.last_ping)
267 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
269 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
268 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
270 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
269 self.close()
271 self.close()
270 return
272 return
271
273
272 self.ping(b'')
274 self.ping(b'')
273 self.last_ping = now
275 self.last_ping = now
274
276
275 def on_pong(self, data):
277 def on_pong(self, data):
276 self.last_pong = ioloop.IOLoop.current().time()
278 self.last_pong = ioloop.IOLoop.current().time()
General Comments 0
You need to be logged in to leave comments. Login now