##// END OF EJS Templates
add websocket workarounds for tornado 3...
MinRK -
Show More
@@ -1,256 +1,259 b''
1 # coding: utf-8
1 2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
2 3
3 4 # Copyright (c) IPython Development Team.
4 5 # Distributed under the terms of the Modified BSD License.
5 6
6 7 import json
7 8 import struct
8 9
9 10 try:
10 11 from urllib.parse import urlparse # Py 3
11 12 except ImportError:
12 13 from urlparse import urlparse # Py 2
13 14
14 try:
15 from http.cookies import SimpleCookie # Py 3
16 except ImportError:
17 from Cookie import SimpleCookie # Py 2
18 import logging
19
20 15 import tornado
21 16 from tornado import ioloop
22 17 from tornado import web
23 18 from tornado import websocket
24 19
25 20 from IPython.kernel.zmq.session import Session
26 21 from IPython.utils.jsonutil import date_default, extract_dates
27 from IPython.utils.py3compat import PY3, cast_unicode
22 from IPython.utils.py3compat import cast_unicode
28 23
29 24 from .handlers import IPythonHandler
30 25
31 26
32 27 def serialize_binary_message(msg):
33 28 """serialize a message as a binary blob
34 29
35 30 Header:
36 31
37 32 4 bytes: number of msg parts (nbufs) as 32b int
38 33 4 * nbufs bytes: offset for each buffer as integer as 32b int
39 34
40 35 Offsets are from the start of the buffer, including the header.
41 36
42 37 Returns
43 38 -------
44 39
45 40 The message serialized to bytes.
46 41
47 42 """
48 43 # don't modify msg or buffer list in-place
49 44 msg = msg.copy()
50 45 buffers = list(msg.pop('buffers'))
51 46 bmsg = json.dumps(msg, default=date_default).encode('utf8')
52 47 buffers.insert(0, bmsg)
53 48 nbufs = len(buffers)
54 49 offsets = [4 * (nbufs + 1)]
55 50 for buf in buffers[:-1]:
56 51 offsets.append(offsets[-1] + len(buf))
57 52 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
58 53 buffers.insert(0, offsets_buf)
59 54 return b''.join(buffers)
60 55
61 56
62 57 def deserialize_binary_message(bmsg):
63 58 """deserialize a message from a binary blog
64 59
65 60 Header:
66 61
67 62 4 bytes: number of msg parts (nbufs) as 32b int
68 63 4 * nbufs bytes: offset for each buffer as integer as 32b int
69 64
70 65 Offsets are from the start of the buffer, including the header.
71 66
72 67 Returns
73 68 -------
74 69
75 70 message dictionary
76 71 """
77 72 nbufs = struct.unpack('!i', bmsg[:4])[0]
78 73 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
79 74 offsets.append(None)
80 75 bufs = []
81 76 for start, stop in zip(offsets[:-1], offsets[1:]):
82 77 bufs.append(bmsg[start:stop])
83 78 msg = json.loads(bufs[0].decode('utf8'))
84 79 msg['header'] = extract_dates(msg['header'])
85 80 msg['parent_header'] = extract_dates(msg['parent_header'])
86 81 msg['buffers'] = bufs[1:]
87 82 return msg
88 83
89 84
90 85 class ZMQStreamHandler(websocket.WebSocketHandler):
91 86
92 87 def check_origin(self, origin):
93 88 """Check Origin == Host or Access-Control-Allow-Origin.
94 89
95 90 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
96 91 We call it explicitly in `open` on Tornado < 4.
97 92 """
98 93 if self.allow_origin == '*':
99 94 return True
100 95
101 96 host = self.request.headers.get("Host")
102 97
103 98 # If no header is provided, assume we can't verify origin
104 99 if origin is None:
105 100 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
106 101 return False
107 102 if host is None:
108 103 self.log.warn("Missing Host header, rejecting WebSocket connection.")
109 104 return False
110 105
111 106 origin = origin.lower()
112 107 origin_host = urlparse(origin).netloc
113 108
114 109 # OK if origin matches host
115 110 if origin_host == host:
116 111 return True
117 112
118 113 # Check CORS headers
119 114 if self.allow_origin:
120 115 allow = self.allow_origin == origin
121 116 elif self.allow_origin_pat:
122 117 allow = bool(self.allow_origin_pat.match(origin))
123 118 else:
124 119 # No CORS headers deny the request
125 120 allow = False
126 121 if not allow:
127 122 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
128 123 origin, host,
129 124 )
130 125 return allow
131 126
132 127 def clear_cookie(self, *args, **kwargs):
133 128 """meaningless for websockets"""
134 129 pass
135 130
136 131 def _reserialize_reply(self, msg_list):
137 132 """Reserialize a reply message using JSON.
138 133
139 134 This takes the msg list from the ZMQ socket, deserializes it using
140 135 self.session and then serializes the result using JSON. This method
141 136 should be used by self._on_zmq_reply to build messages that can
142 137 be sent back to the browser.
143 138 """
144 139 idents, msg_list = self.session.feed_identities(msg_list)
145 140 msg = self.session.deserialize(msg_list)
146 141 if msg['buffers']:
147 142 buf = serialize_binary_message(msg)
148 143 return buf
149 144 else:
150 145 smsg = json.dumps(msg, default=date_default)
151 146 return cast_unicode(smsg)
152 147
153 148 def _on_zmq_reply(self, msg_list):
154 149 # Sometimes this gets triggered when the on_close method is scheduled in the
155 150 # eventloop but hasn't been called.
156 151 if self.stream.closed(): return
157 152 try:
158 153 msg = self._reserialize_reply(msg_list)
159 154 except Exception:
160 155 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
161 156 else:
162 157 self.write_message(msg, binary=isinstance(msg, bytes))
163 158
164 159 def allow_draft76(self):
165 160 """Allow draft 76, until browsers such as Safari update to RFC 6455.
166 161
167 162 This has been disabled by default in tornado in release 2.2.0, and
168 163 support will be removed in later versions.
169 164 """
170 165 return True
171 166
172 167 # ping interval for keeping websockets alive (30 seconds)
173 168 WS_PING_INTERVAL = 30000
174 169
175 170 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
176 171 ping_callback = None
177 172 last_ping = 0
178 173 last_pong = 0
179 174
180 175 @property
181 176 def ping_interval(self):
182 177 """The interval for websocket keep-alive pings.
183 178
184 179 Set ws_ping_interval = 0 to disable pings.
185 180 """
186 181 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
187 182
188 183 @property
189 184 def ping_timeout(self):
190 185 """If no ping is received in this many milliseconds,
191 186 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
192 187 Default is max of 3 pings or 30 seconds.
193 188 """
194 189 return self.settings.get('ws_ping_timeout',
195 190 max(3 * self.ping_interval, WS_PING_INTERVAL)
196 191 )
197 192
198 193 def set_default_headers(self):
199 194 """Undo the set_default_headers in IPythonHandler
200 195
201 196 which doesn't make sense for websockets
202 197 """
203 198 pass
204 199
205 200 def get(self, *args, **kwargs):
206 201 # Check to see that origin matches host directly, including ports
207 202 # Tornado 4 already does CORS checking
208 203 if tornado.version_info[0] < 4:
209 204 if not self.check_origin(self.get_origin()):
210 205 raise web.HTTPError(403)
211 206
212 207 # authenticate the request before opening the websocket
213 208 if self.get_current_user() is None:
214 209 self.log.warn("Couldn't authenticate WebSocket connection")
215 210 raise web.HTTPError(403)
216 211
217 212 if self.get_argument('session_id', False):
218 213 self.session.session = cast_unicode(self.get_argument('session_id'))
219 214 else:
220 215 self.log.warn("No session ID specified")
221
216 # FIXME: only do super get on tornado β‰₯ 4
217 # tornado 3 has no get, will raise 405
218 if tornado.version_info >= (4,):
222 219 return super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
223 220
224 221 def initialize(self):
225 222 self.session = Session(config=self.config)
226 223
227 224 def open(self, *args, **kwargs):
225 if tornado.version_info < (4,):
226 try:
227 self.get(*self.open_args, **self.open_kwargs)
228 except web.HTTPError:
229 self.close()
230 raise
228 231
229 232 # start the pinging
230 233 if self.ping_interval > 0:
231 234 self.last_ping = ioloop.IOLoop.instance().time() # Remember time of last ping
232 235 self.last_pong = self.last_ping
233 236 self.ping_callback = ioloop.PeriodicCallback(self.send_ping, self.ping_interval)
234 237 self.ping_callback.start()
235 238
236 239 def send_ping(self):
237 240 """send a ping to keep the websocket alive"""
238 241 if self.stream.closed() and self.ping_callback is not None:
239 242 self.ping_callback.stop()
240 243 return
241 244
242 245 # check for timeout on pong. Make sure that we really have sent a recent ping in
243 246 # case the machine with both server and client has been suspended since the last ping.
244 247 now = ioloop.IOLoop.instance().time()
245 248 since_last_pong = 1e3 * (now - self.last_pong)
246 249 since_last_ping = 1e3 * (now - self.last_ping)
247 250 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
248 251 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
249 252 self.close()
250 253 return
251 254
252 255 self.ping(b'')
253 256 self.last_ping = now
254 257
255 258 def on_pong(self, data):
256 259 self.last_pong = ioloop.IOLoop.instance().time()
@@ -1,267 +1,269 b''
1 1 """Tornado handlers for kernels."""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 import json
7 7 import logging
8 8 from tornado import gen, web
9 9 from tornado.concurrent import Future
10 10
11 11 from IPython.utils.jsonutil import date_default
12 12 from IPython.utils.py3compat import cast_unicode
13 13 from IPython.html.utils import url_path_join, url_escape
14 14
15 15 from ...base.handlers import IPythonHandler, json_errors
16 16 from ...base.zmqhandlers import AuthenticatedZMQStreamHandler, deserialize_binary_message
17 17
18 18 from IPython.core.release import kernel_protocol_version
19 19
20 20 class MainKernelHandler(IPythonHandler):
21 21
22 22 @web.authenticated
23 23 @json_errors
24 24 def get(self):
25 25 km = self.kernel_manager
26 26 self.finish(json.dumps(km.list_kernels()))
27 27
28 28 @web.authenticated
29 29 @json_errors
30 30 def post(self):
31 31 km = self.kernel_manager
32 32 model = self.get_json_body()
33 33 if model is None:
34 34 model = {
35 35 'name': km.default_kernel_name
36 36 }
37 37 else:
38 38 model.setdefault('name', km.default_kernel_name)
39 39
40 40 kernel_id = km.start_kernel(kernel_name=model['name'])
41 41 model = km.kernel_model(kernel_id)
42 42 location = url_path_join(self.base_url, 'api', 'kernels', kernel_id)
43 43 self.set_header('Location', url_escape(location))
44 44 self.set_status(201)
45 45 self.finish(json.dumps(model))
46 46
47 47
48 48 class KernelHandler(IPythonHandler):
49 49
50 50 SUPPORTED_METHODS = ('DELETE', 'GET')
51 51
52 52 @web.authenticated
53 53 @json_errors
54 54 def get(self, kernel_id):
55 55 km = self.kernel_manager
56 56 km._check_kernel_id(kernel_id)
57 57 model = km.kernel_model(kernel_id)
58 58 self.finish(json.dumps(model))
59 59
60 60 @web.authenticated
61 61 @json_errors
62 62 def delete(self, kernel_id):
63 63 km = self.kernel_manager
64 64 km.shutdown_kernel(kernel_id)
65 65 self.set_status(204)
66 66 self.finish()
67 67
68 68
69 69 class KernelActionHandler(IPythonHandler):
70 70
71 71 @web.authenticated
72 72 @json_errors
73 73 def post(self, kernel_id, action):
74 74 km = self.kernel_manager
75 75 if action == 'interrupt':
76 76 km.interrupt_kernel(kernel_id)
77 77 self.set_status(204)
78 78 if action == 'restart':
79 79 km.restart_kernel(kernel_id)
80 80 model = km.kernel_model(kernel_id)
81 81 self.set_header('Location', '{0}api/kernels/{1}'.format(self.base_url, kernel_id))
82 82 self.write(json.dumps(model))
83 83 self.finish()
84 84
85 85
86 86 class ZMQChannelHandler(AuthenticatedZMQStreamHandler):
87 87
88 88 def __repr__(self):
89 89 return "%s(%s)" % (self.__class__.__name__, getattr(self, 'kernel_id', 'uninitialized'))
90 90
91 91 def create_stream(self):
92 92 km = self.kernel_manager
93 93 meth = getattr(km, 'connect_%s' % self.channel)
94 94 self.zmq_stream = meth(self.kernel_id, identity=self.session.bsession)
95 95
96 96 def request_kernel_info(self):
97 97 """send a request for kernel_info"""
98 98 km = self.kernel_manager
99 99 kernel = km.get_kernel(self.kernel_id)
100 100 try:
101 101 # check for cached value
102 102 kernel_info = kernel._kernel_info
103 103 except AttributeError:
104 104 self.log.debug("Requesting kernel info from %s", self.kernel_id)
105 105 # Create a kernel_info channel to query the kernel protocol version.
106 106 # This channel will be closed after the kernel_info reply is received.
107 107 if self.kernel_info_channel is None:
108 108 self.kernel_info_channel = km.connect_shell(self.kernel_id)
109 109 self.kernel_info_channel.on_recv(self._handle_kernel_info_reply)
110 110 self.session.send(self.kernel_info_channel, "kernel_info_request")
111 111 else:
112 112 # use cached value, don't resend request
113 113 self._finish_kernel_info(kernel_info)
114 114 return self._kernel_info_future
115 115
116 116 def _handle_kernel_info_reply(self, msg):
117 117 """process the kernel_info_reply
118 118
119 119 enabling msg spec adaptation, if necessary
120 120 """
121 121 idents,msg = self.session.feed_identities(msg)
122 122 try:
123 123 msg = self.session.deserialize(msg)
124 124 except:
125 125 self.log.error("Bad kernel_info reply", exc_info=True)
126 126 self.request_kernel_info()
127 127 return
128 128 else:
129 129 info = msg['content']
130 130 self.log.debug("Received kernel info: %s", info)
131 131 if msg['msg_type'] != 'kernel_info_reply' or 'protocol_version' not in info:
132 132 self.log.error("Kernel info request failed, assuming current %s", info)
133 133 else:
134 134 kernel = self.kernel_manager.get_kernel(self.kernel_id)
135 135 kernel._kernel_info = info
136 136 self._finish_kernel_info(info)
137 137
138 138 # close the kernel_info channel, we don't need it anymore
139 139 if self.kernel_info_channel:
140 140 self.kernel_info_channel.close()
141 141 self.kernel_info_channel = None
142 142
143 143 def _finish_kernel_info(self, info):
144 144 """Finish handling kernel_info reply
145 145
146 146 Set up protocol adaptation, if needed,
147 147 and signal that connection can continue.
148 148 """
149 149 protocol_version = info.get('protocol_version', kernel_protocol_version)
150 150 if protocol_version != kernel_protocol_version:
151 151 self.session.adapt_version = int(protocol_version.split('.')[0])
152 152 self.log.info("Kernel %s speaks protocol %s", self.kernel_id, protocol_version)
153 153 self._kernel_info_future.set_result(info)
154 154
155 155 def initialize(self):
156 156 super(ZMQChannelHandler, self).initialize()
157 157 self.zmq_stream = None
158 self.kernel_id = None
158 159 self.kernel_info_channel = None
159 160 self._kernel_info_future = Future()
160 161
161 162 @gen.coroutine
162 163 def get(self, kernel_id):
163 164 self.kernel_id = cast_unicode(kernel_id, 'ascii')
164 165 yield self.request_kernel_info()
165 166 super(ZMQChannelHandler, self).get(kernel_id)
166 167
167 168 def open(self, kernel_id):
168 169 super(ZMQChannelHandler, self).open()
169 170 try:
170 171 self.create_stream()
171 except web.HTTPError:
172 except web.HTTPError as e:
173 self.log.error("Error opening stream: %s", e)
172 174 # WebSockets don't response to traditional error codes so we
173 175 # close the connection.
174 176 if not self.stream.closed():
175 177 self.stream.close()
176 178 self.close()
177 179 else:
178 180 self.zmq_stream.on_recv(self._on_zmq_reply)
179 181
180 182 def on_message(self, msg):
181 183 if self.zmq_stream is None:
182 184 return
183 185 elif self.zmq_stream.closed():
184 186 self.log.info("%s closed, closing websocket.", self)
185 187 self.close()
186 188 return
187 189 if isinstance(msg, bytes):
188 190 msg = deserialize_binary_message(msg)
189 191 else:
190 192 msg = json.loads(msg)
191 193 self.session.send(self.zmq_stream, msg)
192 194
193 195 def on_close(self):
194 196 # This method can be called twice, once by self.kernel_died and once
195 197 # from the WebSocket close event. If the WebSocket connection is
196 198 # closed before the ZMQ streams are setup, they could be None.
197 199 if self.zmq_stream is not None and not self.zmq_stream.closed():
198 200 self.zmq_stream.on_recv(None)
199 201 # close the socket directly, don't wait for the stream
200 202 socket = self.zmq_stream.socket
201 203 self.zmq_stream.close()
202 204 socket.close()
203 205
204 206
205 207 class IOPubHandler(ZMQChannelHandler):
206 208 channel = 'iopub'
207 209
208 210 def create_stream(self):
209 211 super(IOPubHandler, self).create_stream()
210 212 km = self.kernel_manager
211 213 km.add_restart_callback(self.kernel_id, self.on_kernel_restarted)
212 214 km.add_restart_callback(self.kernel_id, self.on_restart_failed, 'dead')
213 215
214 216 def on_close(self):
215 217 km = self.kernel_manager
216 218 if self.kernel_id in km:
217 219 km.remove_restart_callback(
218 220 self.kernel_id, self.on_kernel_restarted,
219 221 )
220 222 km.remove_restart_callback(
221 223 self.kernel_id, self.on_restart_failed, 'dead',
222 224 )
223 225 super(IOPubHandler, self).on_close()
224 226
225 227 def _send_status_message(self, status):
226 228 msg = self.session.msg("status",
227 229 {'execution_state': status}
228 230 )
229 231 self.write_message(json.dumps(msg, default=date_default))
230 232
231 233 def on_kernel_restarted(self):
232 234 logging.warn("kernel %s restarted", self.kernel_id)
233 235 self._send_status_message('restarting')
234 236
235 237 def on_restart_failed(self):
236 238 logging.error("kernel %s restarted failed!", self.kernel_id)
237 239 self._send_status_message('dead')
238 240
239 241 def on_message(self, msg):
240 242 """IOPub messages make no sense"""
241 243 pass
242 244
243 245
244 246 class ShellHandler(ZMQChannelHandler):
245 247 channel = 'shell'
246 248
247 249
248 250 class StdinHandler(ZMQChannelHandler):
249 251 channel = 'stdin'
250 252
251 253
252 254 #-----------------------------------------------------------------------------
253 255 # URL to handler mappings
254 256 #-----------------------------------------------------------------------------
255 257
256 258
257 259 _kernel_id_regex = r"(?P<kernel_id>\w+-\w+-\w+-\w+-\w+)"
258 260 _kernel_action_regex = r"(?P<action>restart|interrupt)"
259 261
260 262 default_handlers = [
261 263 (r"/api/kernels", MainKernelHandler),
262 264 (r"/api/kernels/%s" % _kernel_id_regex, KernelHandler),
263 265 (r"/api/kernels/%s/%s" % (_kernel_id_regex, _kernel_action_regex), KernelActionHandler),
264 266 (r"/api/kernels/%s/iopub" % _kernel_id_regex, IOPubHandler),
265 267 (r"/api/kernels/%s/shell" % _kernel_id_regex, ShellHandler),
266 268 (r"/api/kernels/%s/stdin" % _kernel_id_regex, StdinHandler)
267 269 ]
General Comments 0
You need to be logged in to leave comments. Login now