##// END OF EJS Templates
handle message arriving when sockets are closed...
Min RK -
Show More
@@ -1,276 +1,278 b''
1 1 # coding: utf-8
2 2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
3 3
4 4 # Copyright (c) IPython Development Team.
5 5 # Distributed under the terms of the Modified BSD License.
6 6
7 7 import os
8 8 import json
9 9 import struct
10 10 import warnings
11 11
12 12 try:
13 13 from urllib.parse import urlparse # Py 3
14 14 except ImportError:
15 15 from urlparse import urlparse # Py 2
16 16
17 17 import tornado
18 18 from tornado import gen, ioloop, web
19 19 from tornado.websocket import WebSocketHandler
20 20
21 21 from IPython.kernel.zmq.session import Session
22 22 from IPython.utils.jsonutil import date_default, extract_dates
23 23 from IPython.utils.py3compat import cast_unicode
24 24
25 25 from .handlers import IPythonHandler
26 26
27 27 def serialize_binary_message(msg):
28 28 """serialize a message as a binary blob
29 29
30 30 Header:
31 31
32 32 4 bytes: number of msg parts (nbufs) as 32b int
33 33 4 * nbufs bytes: offset for each buffer as integer as 32b int
34 34
35 35 Offsets are from the start of the buffer, including the header.
36 36
37 37 Returns
38 38 -------
39 39
40 40 The message serialized to bytes.
41 41
42 42 """
43 43 # don't modify msg or buffer list in-place
44 44 msg = msg.copy()
45 45 buffers = list(msg.pop('buffers'))
46 46 bmsg = json.dumps(msg, default=date_default).encode('utf8')
47 47 buffers.insert(0, bmsg)
48 48 nbufs = len(buffers)
49 49 offsets = [4 * (nbufs + 1)]
50 50 for buf in buffers[:-1]:
51 51 offsets.append(offsets[-1] + len(buf))
52 52 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
53 53 buffers.insert(0, offsets_buf)
54 54 return b''.join(buffers)
55 55
56 56
57 57 def deserialize_binary_message(bmsg):
58 58 """deserialize a message from a binary blog
59 59
60 60 Header:
61 61
62 62 4 bytes: number of msg parts (nbufs) as 32b int
63 63 4 * nbufs bytes: offset for each buffer as integer as 32b int
64 64
65 65 Offsets are from the start of the buffer, including the header.
66 66
67 67 Returns
68 68 -------
69 69
70 70 message dictionary
71 71 """
72 72 nbufs = struct.unpack('!i', bmsg[:4])[0]
73 73 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
74 74 offsets.append(None)
75 75 bufs = []
76 76 for start, stop in zip(offsets[:-1], offsets[1:]):
77 77 bufs.append(bmsg[start:stop])
78 78 msg = json.loads(bufs[0].decode('utf8'))
79 79 msg['header'] = extract_dates(msg['header'])
80 80 msg['parent_header'] = extract_dates(msg['parent_header'])
81 81 msg['buffers'] = bufs[1:]
82 82 return msg
83 83
84 84 # ping interval for keeping websockets alive (30 seconds)
85 85 WS_PING_INTERVAL = 30000
86 86
87 87 if os.environ.get('IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS', False):
88 88 warnings.warn("""Allowing draft76 websocket connections!
89 89 This should only be done for testing with phantomjs!""")
90 90 from IPython.html import allow76
91 91 WebSocketHandler = allow76.AllowDraftWebSocketHandler
92 92 # draft 76 doesn't support ping
93 93 WS_PING_INTERVAL = 0
94 94
95 95 class ZMQStreamHandler(WebSocketHandler):
96 96
97 97 if tornado.version_info < (4,1):
98 98 """Backport send_error from tornado 4.1 to 4.0"""
99 99 def send_error(self, *args, **kwargs):
100 100 if self.stream is None:
101 101 super(WebSocketHandler, self).send_error(*args, **kwargs)
102 102 else:
103 103 # If we get an uncaught exception during the handshake,
104 104 # we have no choice but to abruptly close the connection.
105 105 # TODO: for uncaught exceptions after the handshake,
106 106 # we can close the connection more gracefully.
107 107 self.stream.close()
108 108
109
109
110 110 def check_origin(self, origin):
111 111 """Check Origin == Host or Access-Control-Allow-Origin.
112 112
113 113 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
114 We call it explicitly in `open` on Tornado < 4.
115 114 """
116 115 if self.allow_origin == '*':
117 116 return True
118 117
119 118 host = self.request.headers.get("Host")
120 119
121 120 # If no header is provided, assume we can't verify origin
122 121 if origin is None:
123 122 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
124 123 return False
125 124 if host is None:
126 125 self.log.warn("Missing Host header, rejecting WebSocket connection.")
127 126 return False
128 127
129 128 origin = origin.lower()
130 129 origin_host = urlparse(origin).netloc
131 130
132 131 # OK if origin matches host
133 132 if origin_host == host:
134 133 return True
135 134
136 135 # Check CORS headers
137 136 if self.allow_origin:
138 137 allow = self.allow_origin == origin
139 138 elif self.allow_origin_pat:
140 139 allow = bool(self.allow_origin_pat.match(origin))
141 140 else:
142 141 # No CORS headers deny the request
143 142 allow = False
144 143 if not allow:
145 144 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
146 145 origin, host,
147 146 )
148 147 return allow
149 148
150 149 def clear_cookie(self, *args, **kwargs):
151 150 """meaningless for websockets"""
152 151 pass
153 152
154 153 def _reserialize_reply(self, msg_list, channel=None):
155 154 """Reserialize a reply message using JSON.
156 155
157 156 This takes the msg list from the ZMQ socket, deserializes it using
158 157 self.session and then serializes the result using JSON. This method
159 158 should be used by self._on_zmq_reply to build messages that can
160 159 be sent back to the browser.
161 160 """
162 161 idents, msg_list = self.session.feed_identities(msg_list)
163 162 msg = self.session.deserialize(msg_list)
164 163 if channel:
165 164 msg['channel'] = channel
166 165 if msg['buffers']:
167 166 buf = serialize_binary_message(msg)
168 167 return buf
169 168 else:
170 169 smsg = json.dumps(msg, default=date_default)
171 170 return cast_unicode(smsg)
172 171
173 172 def _on_zmq_reply(self, stream, msg_list):
174 173 # Sometimes this gets triggered when the on_close method is scheduled in the
175 174 # eventloop but hasn't been called.
176 if stream.closed(): return
175 if self.stream.closed() or stream.closed():
176 self.log.warn("zmq message arrived on closed channel")
177 self.close()
178 return
177 179 channel = getattr(stream, 'channel', None)
178 180 try:
179 181 msg = self._reserialize_reply(msg_list, channel=channel)
180 182 except Exception:
181 183 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
182 184 else:
183 185 self.write_message(msg, binary=isinstance(msg, bytes))
184 186
185 187 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
186 188 ping_callback = None
187 189 last_ping = 0
188 190 last_pong = 0
189 191
190 192 @property
191 193 def ping_interval(self):
192 194 """The interval for websocket keep-alive pings.
193 195
194 196 Set ws_ping_interval = 0 to disable pings.
195 197 """
196 198 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
197 199
198 200 @property
199 201 def ping_timeout(self):
200 202 """If no ping is received in this many milliseconds,
201 203 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
202 204 Default is max of 3 pings or 30 seconds.
203 205 """
204 206 return self.settings.get('ws_ping_timeout',
205 207 max(3 * self.ping_interval, WS_PING_INTERVAL)
206 208 )
207 209
208 210 def set_default_headers(self):
209 211 """Undo the set_default_headers in IPythonHandler
210 212
211 213 which doesn't make sense for websockets
212 214 """
213 215 pass
214 216
215 217 def pre_get(self):
216 218 """Run before finishing the GET request
217 219
218 220 Extend this method to add logic that should fire before
219 221 the websocket finishes completing.
220 222 """
221 223 # authenticate the request before opening the websocket
222 224 if self.get_current_user() is None:
223 225 self.log.warn("Couldn't authenticate WebSocket connection")
224 226 raise web.HTTPError(403)
225 227
226 228 if self.get_argument('session_id', False):
227 229 self.session.session = cast_unicode(self.get_argument('session_id'))
228 230 else:
229 231 self.log.warn("No session ID specified")
230 232
231 233 @gen.coroutine
232 234 def get(self, *args, **kwargs):
233 235 # pre_get can be a coroutine in subclasses
234 236 # assign and yield in two step to avoid tornado 3 issues
235 237 res = self.pre_get()
236 238 yield gen.maybe_future(res)
237 239 super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
238 240
239 241 def initialize(self):
240 242 self.log.debug("Initializing websocket connection %s", self.request.path)
241 243 self.session = Session(config=self.config)
242 244
243 245 def open(self, *args, **kwargs):
244 246 self.log.debug("Opening websocket %s", self.request.path)
245 247
246 248 # start the pinging
247 249 if self.ping_interval > 0:
248 250 loop = ioloop.IOLoop.current()
249 251 self.last_ping = loop.time() # Remember time of last ping
250 252 self.last_pong = self.last_ping
251 253 self.ping_callback = ioloop.PeriodicCallback(
252 254 self.send_ping, self.ping_interval, io_loop=loop,
253 255 )
254 256 self.ping_callback.start()
255 257
256 258 def send_ping(self):
257 259 """send a ping to keep the websocket alive"""
258 260 if self.stream.closed() and self.ping_callback is not None:
259 261 self.ping_callback.stop()
260 262 return
261 263
262 264 # check for timeout on pong. Make sure that we really have sent a recent ping in
263 265 # case the machine with both server and client has been suspended since the last ping.
264 266 now = ioloop.IOLoop.current().time()
265 267 since_last_pong = 1e3 * (now - self.last_pong)
266 268 since_last_ping = 1e3 * (now - self.last_ping)
267 269 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
268 270 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
269 271 self.close()
270 272 return
271 273
272 274 self.ping(b'')
273 275 self.last_ping = now
274 276
275 277 def on_pong(self, data):
276 278 self.last_pong = ioloop.IOLoop.current().time()
General Comments 0
You need to be logged in to leave comments. Login now