##// END OF EJS Templates
debugging websocket connections...
Min RK -
Show More
@@ -1,259 +1,269 b''
1 1 # coding: utf-8
2 2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
3 3
4 4 # Copyright (c) IPython Development Team.
5 5 # Distributed under the terms of the Modified BSD License.
6 6
7 7 import json
8 8 import struct
9 9
10 10 try:
11 11 from urllib.parse import urlparse # Py 3
12 12 except ImportError:
13 13 from urlparse import urlparse # Py 2
14 14
15 15 import tornado
16 from tornado import ioloop
17 from tornado import web
18 from tornado import websocket
16 from tornado import gen, ioloop, web, websocket
19 17
20 18 from IPython.kernel.zmq.session import Session
21 19 from IPython.utils.jsonutil import date_default, extract_dates
22 20 from IPython.utils.py3compat import cast_unicode
23 21
24 22 from .handlers import IPythonHandler
25 23
26 24
27 25 def serialize_binary_message(msg):
28 26 """serialize a message as a binary blob
29 27
30 28 Header:
31 29
32 30 4 bytes: number of msg parts (nbufs) as 32b int
33 31 4 * nbufs bytes: offset for each buffer as integer as 32b int
34 32
35 33 Offsets are from the start of the buffer, including the header.
36 34
37 35 Returns
38 36 -------
39 37
40 38 The message serialized to bytes.
41 39
42 40 """
43 41 # don't modify msg or buffer list in-place
44 42 msg = msg.copy()
45 43 buffers = list(msg.pop('buffers'))
46 44 bmsg = json.dumps(msg, default=date_default).encode('utf8')
47 45 buffers.insert(0, bmsg)
48 46 nbufs = len(buffers)
49 47 offsets = [4 * (nbufs + 1)]
50 48 for buf in buffers[:-1]:
51 49 offsets.append(offsets[-1] + len(buf))
52 50 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
53 51 buffers.insert(0, offsets_buf)
54 52 return b''.join(buffers)
55 53
56 54
57 55 def deserialize_binary_message(bmsg):
58 56 """deserialize a message from a binary blog
59 57
60 58 Header:
61 59
62 60 4 bytes: number of msg parts (nbufs) as 32b int
63 61 4 * nbufs bytes: offset for each buffer as integer as 32b int
64 62
65 63 Offsets are from the start of the buffer, including the header.
66 64
67 65 Returns
68 66 -------
69 67
70 68 message dictionary
71 69 """
72 70 nbufs = struct.unpack('!i', bmsg[:4])[0]
73 71 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
74 72 offsets.append(None)
75 73 bufs = []
76 74 for start, stop in zip(offsets[:-1], offsets[1:]):
77 75 bufs.append(bmsg[start:stop])
78 76 msg = json.loads(bufs[0].decode('utf8'))
79 77 msg['header'] = extract_dates(msg['header'])
80 78 msg['parent_header'] = extract_dates(msg['parent_header'])
81 79 msg['buffers'] = bufs[1:]
82 80 return msg
83 81
84 82
85 83 class ZMQStreamHandler(websocket.WebSocketHandler):
86 84
87 85 def check_origin(self, origin):
88 86 """Check Origin == Host or Access-Control-Allow-Origin.
89 87
90 88 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
91 89 We call it explicitly in `open` on Tornado < 4.
92 90 """
93 91 if self.allow_origin == '*':
94 92 return True
95 93
96 94 host = self.request.headers.get("Host")
97 95
98 96 # If no header is provided, assume we can't verify origin
99 97 if origin is None:
100 98 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
101 99 return False
102 100 if host is None:
103 101 self.log.warn("Missing Host header, rejecting WebSocket connection.")
104 102 return False
105 103
106 104 origin = origin.lower()
107 105 origin_host = urlparse(origin).netloc
108 106
109 107 # OK if origin matches host
110 108 if origin_host == host:
111 109 return True
112 110
113 111 # Check CORS headers
114 112 if self.allow_origin:
115 113 allow = self.allow_origin == origin
116 114 elif self.allow_origin_pat:
117 115 allow = bool(self.allow_origin_pat.match(origin))
118 116 else:
119 117 # No CORS headers deny the request
120 118 allow = False
121 119 if not allow:
122 120 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
123 121 origin, host,
124 122 )
125 123 return allow
126 124
127 125 def clear_cookie(self, *args, **kwargs):
128 126 """meaningless for websockets"""
129 127 pass
130 128
131 129 def _reserialize_reply(self, msg_list):
132 130 """Reserialize a reply message using JSON.
133 131
134 132 This takes the msg list from the ZMQ socket, deserializes it using
135 133 self.session and then serializes the result using JSON. This method
136 134 should be used by self._on_zmq_reply to build messages that can
137 135 be sent back to the browser.
138 136 """
139 137 idents, msg_list = self.session.feed_identities(msg_list)
140 138 msg = self.session.deserialize(msg_list)
141 139 if msg['buffers']:
142 140 buf = serialize_binary_message(msg)
143 141 return buf
144 142 else:
145 143 smsg = json.dumps(msg, default=date_default)
146 144 return cast_unicode(smsg)
147 145
148 146 def _on_zmq_reply(self, msg_list):
149 147 # Sometimes this gets triggered when the on_close method is scheduled in the
150 148 # eventloop but hasn't been called.
151 149 if self.stream.closed(): return
152 150 try:
153 151 msg = self._reserialize_reply(msg_list)
154 152 except Exception:
155 153 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
156 154 else:
157 155 self.write_message(msg, binary=isinstance(msg, bytes))
158 156
159 157 def allow_draft76(self):
160 158 """Allow draft 76, until browsers such as Safari update to RFC 6455.
161 159
162 160 This has been disabled by default in tornado in release 2.2.0, and
163 161 support will be removed in later versions.
164 162 """
165 163 return True
166 164
167 165 # ping interval for keeping websockets alive (30 seconds)
168 166 WS_PING_INTERVAL = 30000
169 167
170 168 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
171 169 ping_callback = None
172 170 last_ping = 0
173 171 last_pong = 0
174 172
175 173 @property
176 174 def ping_interval(self):
177 175 """The interval for websocket keep-alive pings.
178 176
179 177 Set ws_ping_interval = 0 to disable pings.
180 178 """
181 179 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
182 180
183 181 @property
184 182 def ping_timeout(self):
185 183 """If no ping is received in this many milliseconds,
186 184 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
187 185 Default is max of 3 pings or 30 seconds.
188 186 """
189 187 return self.settings.get('ws_ping_timeout',
190 188 max(3 * self.ping_interval, WS_PING_INTERVAL)
191 189 )
192 190
193 191 def set_default_headers(self):
194 192 """Undo the set_default_headers in IPythonHandler
195 193
196 194 which doesn't make sense for websockets
197 195 """
198 196 pass
199 197
200 def get(self, *args, **kwargs):
198 def pre_get(self):
199 """Run before finishing the GET request
200
201 Extend this method to add logic that should fire before
202 the websocket finishes completing.
203 """
201 204 # Check to see that origin matches host directly, including ports
202 205 # Tornado 4 already does CORS checking
203 206 if tornado.version_info[0] < 4:
204 207 if not self.check_origin(self.get_origin()):
205 208 raise web.HTTPError(403)
206 209
207 210 # authenticate the request before opening the websocket
208 211 if self.get_current_user() is None:
209 212 self.log.warn("Couldn't authenticate WebSocket connection")
210 213 raise web.HTTPError(403)
211 214
212 215 if self.get_argument('session_id', False):
213 216 self.session.session = cast_unicode(self.get_argument('session_id'))
214 217 else:
215 218 self.log.warn("No session ID specified")
219
220 @gen.coroutine
221 def get(self, *args, **kwargs):
222 # pre_get can be a coroutine in subclasses
223 yield gen.maybe_future(self.pre_get())
216 224 # FIXME: only do super get on tornado β‰₯ 4
217 225 # tornado 3 has no get, will raise 405
218 226 if tornado.version_info >= (4,):
219 return super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
227 super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
220 228
221 229 def initialize(self):
230 self.log.debug("Initializing websocket connection %s", self.request.path)
222 231 self.session = Session(config=self.config)
223 232
224 233 def open(self, *args, **kwargs):
234 self.log.debug("Opening websocket %s", self.request.path)
225 235 if tornado.version_info < (4,):
226 236 try:
227 237 self.get(*self.open_args, **self.open_kwargs)
228 238 except web.HTTPError:
229 239 self.close()
230 240 raise
231 241
232 242 # start the pinging
233 243 if self.ping_interval > 0:
234 244 self.last_ping = ioloop.IOLoop.instance().time() # Remember time of last ping
235 245 self.last_pong = self.last_ping
236 246 self.ping_callback = ioloop.PeriodicCallback(self.send_ping, self.ping_interval)
237 247 self.ping_callback.start()
238 248
239 249 def send_ping(self):
240 250 """send a ping to keep the websocket alive"""
241 251 if self.stream.closed() and self.ping_callback is not None:
242 252 self.ping_callback.stop()
243 253 return
244 254
245 255 # check for timeout on pong. Make sure that we really have sent a recent ping in
246 256 # case the machine with both server and client has been suspended since the last ping.
247 257 now = ioloop.IOLoop.instance().time()
248 258 since_last_pong = 1e3 * (now - self.last_pong)
249 259 since_last_ping = 1e3 * (now - self.last_ping)
250 260 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
251 261 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
252 262 self.close()
253 263 return
254 264
255 265 self.ping(b'')
256 266 self.last_ping = now
257 267
258 268 def on_pong(self, data):
259 269 self.last_pong = ioloop.IOLoop.instance().time()
@@ -1,269 +1,294 b''
1 1 """Tornado handlers for kernels."""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 import json
7 7 import logging
8 8 from tornado import gen, web
9 9 from tornado.concurrent import Future
10 from tornado.ioloop import IOLoop
10 11
11 12 from IPython.utils.jsonutil import date_default
12 13 from IPython.utils.py3compat import cast_unicode
13 14 from IPython.html.utils import url_path_join, url_escape
14 15
15 16 from ...base.handlers import IPythonHandler, json_errors
16 17 from ...base.zmqhandlers import AuthenticatedZMQStreamHandler, deserialize_binary_message
17 18
18 19 from IPython.core.release import kernel_protocol_version
19 20
20 21 class MainKernelHandler(IPythonHandler):
21 22
22 23 @web.authenticated
23 24 @json_errors
24 25 def get(self):
25 26 km = self.kernel_manager
26 27 self.finish(json.dumps(km.list_kernels()))
27 28
28 29 @web.authenticated
29 30 @json_errors
30 31 def post(self):
31 32 km = self.kernel_manager
32 33 model = self.get_json_body()
33 34 if model is None:
34 35 model = {
35 36 'name': km.default_kernel_name
36 37 }
37 38 else:
38 39 model.setdefault('name', km.default_kernel_name)
39 40
40 41 kernel_id = km.start_kernel(kernel_name=model['name'])
41 42 model = km.kernel_model(kernel_id)
42 43 location = url_path_join(self.base_url, 'api', 'kernels', kernel_id)
43 44 self.set_header('Location', url_escape(location))
44 45 self.set_status(201)
45 46 self.finish(json.dumps(model))
46 47
47 48
48 49 class KernelHandler(IPythonHandler):
49 50
50 51 SUPPORTED_METHODS = ('DELETE', 'GET')
51 52
52 53 @web.authenticated
53 54 @json_errors
54 55 def get(self, kernel_id):
55 56 km = self.kernel_manager
56 57 km._check_kernel_id(kernel_id)
57 58 model = km.kernel_model(kernel_id)
58 59 self.finish(json.dumps(model))
59 60
60 61 @web.authenticated
61 62 @json_errors
62 63 def delete(self, kernel_id):
63 64 km = self.kernel_manager
64 65 km.shutdown_kernel(kernel_id)
65 66 self.set_status(204)
66 67 self.finish()
67 68
68 69
69 70 class KernelActionHandler(IPythonHandler):
70 71
71 72 @web.authenticated
72 73 @json_errors
73 74 def post(self, kernel_id, action):
74 75 km = self.kernel_manager
75 76 if action == 'interrupt':
76 77 km.interrupt_kernel(kernel_id)
77 78 self.set_status(204)
78 79 if action == 'restart':
79 80 km.restart_kernel(kernel_id)
80 81 model = km.kernel_model(kernel_id)
81 82 self.set_header('Location', '{0}api/kernels/{1}'.format(self.base_url, kernel_id))
82 83 self.write(json.dumps(model))
83 84 self.finish()
84 85
85 86
86 87 class ZMQChannelHandler(AuthenticatedZMQStreamHandler):
87 88
89 @property
90 def kernel_info_timeout(self):
91 return self.settings.get('kernel_info_timeout', 10)
92
88 93 def __repr__(self):
89 94 return "%s(%s)" % (self.__class__.__name__, getattr(self, 'kernel_id', 'uninitialized'))
90 95
91 96 def create_stream(self):
92 97 km = self.kernel_manager
93 98 meth = getattr(km, 'connect_%s' % self.channel)
94 99 self.zmq_stream = meth(self.kernel_id, identity=self.session.bsession)
95 100
96 101 def request_kernel_info(self):
97 102 """send a request for kernel_info"""
98 103 km = self.kernel_manager
99 104 kernel = km.get_kernel(self.kernel_id)
100 105 try:
101 106 # check for cached value
102 107 kernel_info = kernel._kernel_info
103 108 except AttributeError:
104 109 self.log.debug("Requesting kernel info from %s", self.kernel_id)
105 110 # Create a kernel_info channel to query the kernel protocol version.
106 111 # This channel will be closed after the kernel_info reply is received.
107 112 if self.kernel_info_channel is None:
108 113 self.kernel_info_channel = km.connect_shell(self.kernel_id)
109 114 self.kernel_info_channel.on_recv(self._handle_kernel_info_reply)
110 115 self.session.send(self.kernel_info_channel, "kernel_info_request")
111 116 else:
112 117 # use cached value, don't resend request
113 118 self._finish_kernel_info(kernel_info)
114 119 return self._kernel_info_future
115 120
116 121 def _handle_kernel_info_reply(self, msg):
117 122 """process the kernel_info_reply
118 123
119 124 enabling msg spec adaptation, if necessary
120 125 """
121 126 idents,msg = self.session.feed_identities(msg)
122 127 try:
123 128 msg = self.session.deserialize(msg)
124 129 except:
125 130 self.log.error("Bad kernel_info reply", exc_info=True)
126 131 self._kernel_info_future.set_result(None)
127 132 return
128 133 else:
129 134 info = msg['content']
130 135 self.log.debug("Received kernel info: %s", info)
131 136 if msg['msg_type'] != 'kernel_info_reply' or 'protocol_version' not in info:
132 137 self.log.error("Kernel info request failed, assuming current %s", info)
133 138 else:
134 139 kernel = self.kernel_manager.get_kernel(self.kernel_id)
135 140 kernel._kernel_info = info
136 141 self._finish_kernel_info(info)
137 142
138 143 # close the kernel_info channel, we don't need it anymore
139 144 if self.kernel_info_channel:
140 145 self.kernel_info_channel.close()
141 146 self.kernel_info_channel = None
142 147
143 148 def _finish_kernel_info(self, info):
144 149 """Finish handling kernel_info reply
145 150
146 151 Set up protocol adaptation, if needed,
147 152 and signal that connection can continue.
148 153 """
149 154 protocol_version = info.get('protocol_version', kernel_protocol_version)
150 155 if protocol_version != kernel_protocol_version:
151 156 self.session.adapt_version = int(protocol_version.split('.')[0])
152 157 self.log.info("Kernel %s speaks protocol %s", self.kernel_id, protocol_version)
153 self._kernel_info_future.set_result(info)
158 if not self._kernel_info_future.done():
159 self._kernel_info_future.set_result(info)
154 160
155 161 def initialize(self):
156 162 super(ZMQChannelHandler, self).initialize()
157 163 self.zmq_stream = None
158 164 self.kernel_id = None
159 165 self.kernel_info_channel = None
160 166 self._kernel_info_future = Future()
161 167
162 168 @gen.coroutine
169 def pre_get(self):
170 # authenticate first
171 super(ZMQChannelHandler, self).pre_get()
172 # then request kernel info, waiting up to a certain time before giving up.
173 # We don't want to wait forever, because browsers don't take it well when
174 # servers never respond to websocket connection requests.
175 future = self.request_kernel_info()
176
177 def give_up():
178 """Don't wait forever for the kernel to reply"""
179 if future.done():
180 return
181 self.log.warn("Timeout waiting for kernel_info reply from %s", self.kernel_id)
182 future.set_result(None)
183 loop = IOLoop.current()
184 loop.add_timeout(loop.time() + self.kernel_info_timeout, give_up)
185 # actually wait for it
186 yield future
187
188 @gen.coroutine
163 189 def get(self, kernel_id):
164 190 self.kernel_id = cast_unicode(kernel_id, 'ascii')
165 yield self.request_kernel_info()
166 super(ZMQChannelHandler, self).get(kernel_id)
191 yield super(ZMQChannelHandler, self).get(kernel_id=kernel_id)
167 192
168 193 def open(self, kernel_id):
169 194 super(ZMQChannelHandler, self).open()
170 195 try:
171 196 self.create_stream()
172 197 except web.HTTPError as e:
173 198 self.log.error("Error opening stream: %s", e)
174 199 # WebSockets don't response to traditional error codes so we
175 200 # close the connection.
176 201 if not self.stream.closed():
177 202 self.stream.close()
178 203 self.close()
179 204 else:
180 205 self.zmq_stream.on_recv(self._on_zmq_reply)
181 206
182 207 def on_message(self, msg):
183 208 if self.zmq_stream is None:
184 209 return
185 210 elif self.zmq_stream.closed():
186 211 self.log.info("%s closed, closing websocket.", self)
187 212 self.close()
188 213 return
189 214 if isinstance(msg, bytes):
190 215 msg = deserialize_binary_message(msg)
191 216 else:
192 217 msg = json.loads(msg)
193 218 self.session.send(self.zmq_stream, msg)
194 219
195 220 def on_close(self):
196 221 # This method can be called twice, once by self.kernel_died and once
197 222 # from the WebSocket close event. If the WebSocket connection is
198 223 # closed before the ZMQ streams are setup, they could be None.
199 224 if self.zmq_stream is not None and not self.zmq_stream.closed():
200 225 self.zmq_stream.on_recv(None)
201 226 # close the socket directly, don't wait for the stream
202 227 socket = self.zmq_stream.socket
203 228 self.zmq_stream.close()
204 229 socket.close()
205 230
206 231
207 232 class IOPubHandler(ZMQChannelHandler):
208 233 channel = 'iopub'
209 234
210 235 def create_stream(self):
211 236 super(IOPubHandler, self).create_stream()
212 237 km = self.kernel_manager
213 238 km.add_restart_callback(self.kernel_id, self.on_kernel_restarted)
214 239 km.add_restart_callback(self.kernel_id, self.on_restart_failed, 'dead')
215 240
216 241 def on_close(self):
217 242 km = self.kernel_manager
218 243 if self.kernel_id in km:
219 244 km.remove_restart_callback(
220 245 self.kernel_id, self.on_kernel_restarted,
221 246 )
222 247 km.remove_restart_callback(
223 248 self.kernel_id, self.on_restart_failed, 'dead',
224 249 )
225 250 super(IOPubHandler, self).on_close()
226 251
227 252 def _send_status_message(self, status):
228 253 msg = self.session.msg("status",
229 254 {'execution_state': status}
230 255 )
231 256 self.write_message(json.dumps(msg, default=date_default))
232 257
233 258 def on_kernel_restarted(self):
234 259 logging.warn("kernel %s restarted", self.kernel_id)
235 260 self._send_status_message('restarting')
236 261
237 262 def on_restart_failed(self):
238 263 logging.error("kernel %s restarted failed!", self.kernel_id)
239 264 self._send_status_message('dead')
240 265
241 266 def on_message(self, msg):
242 267 """IOPub messages make no sense"""
243 268 pass
244 269
245 270
246 271 class ShellHandler(ZMQChannelHandler):
247 272 channel = 'shell'
248 273
249 274
250 275 class StdinHandler(ZMQChannelHandler):
251 276 channel = 'stdin'
252 277
253 278
254 279 #-----------------------------------------------------------------------------
255 280 # URL to handler mappings
256 281 #-----------------------------------------------------------------------------
257 282
258 283
259 284 _kernel_id_regex = r"(?P<kernel_id>\w+-\w+-\w+-\w+-\w+)"
260 285 _kernel_action_regex = r"(?P<action>restart|interrupt)"
261 286
262 287 default_handlers = [
263 288 (r"/api/kernels", MainKernelHandler),
264 289 (r"/api/kernels/%s" % _kernel_id_regex, KernelHandler),
265 290 (r"/api/kernels/%s/%s" % (_kernel_id_regex, _kernel_action_regex), KernelActionHandler),
266 291 (r"/api/kernels/%s/iopub" % _kernel_id_regex, IOPubHandler),
267 292 (r"/api/kernels/%s/shell" % _kernel_id_regex, ShellHandler),
268 293 (r"/api/kernels/%s/stdin" % _kernel_id_regex, StdinHandler)
269 294 ]
General Comments 0
You need to be logged in to leave comments. Login now