##// END OF EJS Templates
Merge pull request #6831 from Carreau/fix-yield...
Min RK -
r18562:241ef9c1 merge
parent child Browse files
Show More
@@ -1,269 +1,271 b''
1 1 # coding: utf-8
2 2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
3 3
4 4 # Copyright (c) IPython Development Team.
5 5 # Distributed under the terms of the Modified BSD License.
6 6
7 7 import json
8 8 import struct
9 9
10 10 try:
11 11 from urllib.parse import urlparse # Py 3
12 12 except ImportError:
13 13 from urlparse import urlparse # Py 2
14 14
15 15 import tornado
16 16 from tornado import gen, ioloop, web, websocket
17 17
18 18 from IPython.kernel.zmq.session import Session
19 19 from IPython.utils.jsonutil import date_default, extract_dates
20 20 from IPython.utils.py3compat import cast_unicode
21 21
22 22 from .handlers import IPythonHandler
23 23
24 24
25 25 def serialize_binary_message(msg):
26 26 """serialize a message as a binary blob
27 27
28 28 Header:
29 29
30 30 4 bytes: number of msg parts (nbufs) as 32b int
31 31 4 * nbufs bytes: offset for each buffer as integer as 32b int
32 32
33 33 Offsets are from the start of the buffer, including the header.
34 34
35 35 Returns
36 36 -------
37 37
38 38 The message serialized to bytes.
39 39
40 40 """
41 41 # don't modify msg or buffer list in-place
42 42 msg = msg.copy()
43 43 buffers = list(msg.pop('buffers'))
44 44 bmsg = json.dumps(msg, default=date_default).encode('utf8')
45 45 buffers.insert(0, bmsg)
46 46 nbufs = len(buffers)
47 47 offsets = [4 * (nbufs + 1)]
48 48 for buf in buffers[:-1]:
49 49 offsets.append(offsets[-1] + len(buf))
50 50 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
51 51 buffers.insert(0, offsets_buf)
52 52 return b''.join(buffers)
53 53
54 54
55 55 def deserialize_binary_message(bmsg):
56 56 """deserialize a message from a binary blog
57 57
58 58 Header:
59 59
60 60 4 bytes: number of msg parts (nbufs) as 32b int
61 61 4 * nbufs bytes: offset for each buffer as integer as 32b int
62 62
63 63 Offsets are from the start of the buffer, including the header.
64 64
65 65 Returns
66 66 -------
67 67
68 68 message dictionary
69 69 """
70 70 nbufs = struct.unpack('!i', bmsg[:4])[0]
71 71 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
72 72 offsets.append(None)
73 73 bufs = []
74 74 for start, stop in zip(offsets[:-1], offsets[1:]):
75 75 bufs.append(bmsg[start:stop])
76 76 msg = json.loads(bufs[0].decode('utf8'))
77 77 msg['header'] = extract_dates(msg['header'])
78 78 msg['parent_header'] = extract_dates(msg['parent_header'])
79 79 msg['buffers'] = bufs[1:]
80 80 return msg
81 81
82 82
83 83 class ZMQStreamHandler(websocket.WebSocketHandler):
84 84
85 85 def check_origin(self, origin):
86 86 """Check Origin == Host or Access-Control-Allow-Origin.
87 87
88 88 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
89 89 We call it explicitly in `open` on Tornado < 4.
90 90 """
91 91 if self.allow_origin == '*':
92 92 return True
93 93
94 94 host = self.request.headers.get("Host")
95 95
96 96 # If no header is provided, assume we can't verify origin
97 97 if origin is None:
98 98 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
99 99 return False
100 100 if host is None:
101 101 self.log.warn("Missing Host header, rejecting WebSocket connection.")
102 102 return False
103 103
104 104 origin = origin.lower()
105 105 origin_host = urlparse(origin).netloc
106 106
107 107 # OK if origin matches host
108 108 if origin_host == host:
109 109 return True
110 110
111 111 # Check CORS headers
112 112 if self.allow_origin:
113 113 allow = self.allow_origin == origin
114 114 elif self.allow_origin_pat:
115 115 allow = bool(self.allow_origin_pat.match(origin))
116 116 else:
117 117 # No CORS headers deny the request
118 118 allow = False
119 119 if not allow:
120 120 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
121 121 origin, host,
122 122 )
123 123 return allow
124 124
125 125 def clear_cookie(self, *args, **kwargs):
126 126 """meaningless for websockets"""
127 127 pass
128 128
129 129 def _reserialize_reply(self, msg_list):
130 130 """Reserialize a reply message using JSON.
131 131
132 132 This takes the msg list from the ZMQ socket, deserializes it using
133 133 self.session and then serializes the result using JSON. This method
134 134 should be used by self._on_zmq_reply to build messages that can
135 135 be sent back to the browser.
136 136 """
137 137 idents, msg_list = self.session.feed_identities(msg_list)
138 138 msg = self.session.deserialize(msg_list)
139 139 if msg['buffers']:
140 140 buf = serialize_binary_message(msg)
141 141 return buf
142 142 else:
143 143 smsg = json.dumps(msg, default=date_default)
144 144 return cast_unicode(smsg)
145 145
146 146 def _on_zmq_reply(self, msg_list):
147 147 # Sometimes this gets triggered when the on_close method is scheduled in the
148 148 # eventloop but hasn't been called.
149 149 if self.stream.closed(): return
150 150 try:
151 151 msg = self._reserialize_reply(msg_list)
152 152 except Exception:
153 153 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
154 154 else:
155 155 self.write_message(msg, binary=isinstance(msg, bytes))
156 156
157 157 def allow_draft76(self):
158 158 """Allow draft 76, until browsers such as Safari update to RFC 6455.
159 159
160 160 This has been disabled by default in tornado in release 2.2.0, and
161 161 support will be removed in later versions.
162 162 """
163 163 return True
164 164
165 165 # ping interval for keeping websockets alive (30 seconds)
166 166 WS_PING_INTERVAL = 30000
167 167
168 168 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
169 169 ping_callback = None
170 170 last_ping = 0
171 171 last_pong = 0
172 172
173 173 @property
174 174 def ping_interval(self):
175 175 """The interval for websocket keep-alive pings.
176 176
177 177 Set ws_ping_interval = 0 to disable pings.
178 178 """
179 179 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
180 180
181 181 @property
182 182 def ping_timeout(self):
183 183 """If no ping is received in this many milliseconds,
184 184 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
185 185 Default is max of 3 pings or 30 seconds.
186 186 """
187 187 return self.settings.get('ws_ping_timeout',
188 188 max(3 * self.ping_interval, WS_PING_INTERVAL)
189 189 )
190 190
191 191 def set_default_headers(self):
192 192 """Undo the set_default_headers in IPythonHandler
193 193
194 194 which doesn't make sense for websockets
195 195 """
196 196 pass
197 197
198 198 def pre_get(self):
199 199 """Run before finishing the GET request
200 200
201 201 Extend this method to add logic that should fire before
202 202 the websocket finishes completing.
203 203 """
204 204 # Check to see that origin matches host directly, including ports
205 205 # Tornado 4 already does CORS checking
206 206 if tornado.version_info[0] < 4:
207 207 if not self.check_origin(self.get_origin()):
208 208 raise web.HTTPError(403)
209 209
210 210 # authenticate the request before opening the websocket
211 211 if self.get_current_user() is None:
212 212 self.log.warn("Couldn't authenticate WebSocket connection")
213 213 raise web.HTTPError(403)
214 214
215 215 if self.get_argument('session_id', False):
216 216 self.session.session = cast_unicode(self.get_argument('session_id'))
217 217 else:
218 218 self.log.warn("No session ID specified")
219 219
220 220 @gen.coroutine
221 221 def get(self, *args, **kwargs):
222 222 # pre_get can be a coroutine in subclasses
223 yield gen.maybe_future(self.pre_get())
223 # assign and yield in two step to avoid tornado 3 issues
224 res = self.pre_get()
225 yield gen.maybe_future(res)
224 226 # FIXME: only do super get on tornado β‰₯ 4
225 227 # tornado 3 has no get, will raise 405
226 228 if tornado.version_info >= (4,):
227 229 super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
228 230
229 231 def initialize(self):
230 232 self.log.debug("Initializing websocket connection %s", self.request.path)
231 233 self.session = Session(config=self.config)
232 234
233 235 def open(self, *args, **kwargs):
234 236 self.log.debug("Opening websocket %s", self.request.path)
235 237 if tornado.version_info < (4,):
236 238 try:
237 239 self.get(*self.open_args, **self.open_kwargs)
238 240 except web.HTTPError:
239 241 self.close()
240 242 raise
241 243
242 244 # start the pinging
243 245 if self.ping_interval > 0:
244 246 self.last_ping = ioloop.IOLoop.instance().time() # Remember time of last ping
245 247 self.last_pong = self.last_ping
246 248 self.ping_callback = ioloop.PeriodicCallback(self.send_ping, self.ping_interval)
247 249 self.ping_callback.start()
248 250
249 251 def send_ping(self):
250 252 """send a ping to keep the websocket alive"""
251 253 if self.stream.closed() and self.ping_callback is not None:
252 254 self.ping_callback.stop()
253 255 return
254 256
255 257 # check for timeout on pong. Make sure that we really have sent a recent ping in
256 258 # case the machine with both server and client has been suspended since the last ping.
257 259 now = ioloop.IOLoop.instance().time()
258 260 since_last_pong = 1e3 * (now - self.last_pong)
259 261 since_last_ping = 1e3 * (now - self.last_ping)
260 262 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
261 263 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
262 264 self.close()
263 265 return
264 266
265 267 self.ping(b'')
266 268 self.last_ping = now
267 269
268 270 def on_pong(self, data):
269 271 self.last_pong = ioloop.IOLoop.instance().time()
General Comments 0
You need to be logged in to leave comments. Login now