##// END OF EJS Templates
cache kernel_info reply for protocol adaptation...
MinRK -
Show More
@@ -1,257 +1,256 b''
1 1 """Tornado handlers for WebSocket <-> ZMQ sockets."""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 import json
7 7 import struct
8 8
9 9 try:
10 10 from urllib.parse import urlparse # Py 3
11 11 except ImportError:
12 12 from urlparse import urlparse # Py 2
13 13
14 14 try:
15 15 from http.cookies import SimpleCookie # Py 3
16 16 except ImportError:
17 17 from Cookie import SimpleCookie # Py 2
18 18 import logging
19 19
20 20 import tornado
21 21 from tornado import ioloop
22 22 from tornado import web
23 23 from tornado import websocket
24 24
25 25 from IPython.kernel.zmq.session import Session
26 26 from IPython.utils.jsonutil import date_default, extract_dates
27 27 from IPython.utils.py3compat import PY3, cast_unicode
28 28
29 29 from .handlers import IPythonHandler
30 30
31 31
32 32 def serialize_binary_message(msg):
33 33 """serialize a message as a binary blob
34 34
35 35 Header:
36 36
37 37 4 bytes: number of msg parts (nbufs) as 32b int
38 38 4 * nbufs bytes: offset for each buffer as integer as 32b int
39 39
40 40 Offsets are from the start of the buffer, including the header.
41 41
42 42 Returns
43 43 -------
44 44
45 45 The message serialized to bytes.
46 46
47 47 """
48 48 # don't modify msg or buffer list in-place
49 49 msg = msg.copy()
50 50 buffers = list(msg.pop('buffers'))
51 51 bmsg = json.dumps(msg, default=date_default).encode('utf8')
52 52 buffers.insert(0, bmsg)
53 53 nbufs = len(buffers)
54 54 offsets = [4 * (nbufs + 1)]
55 55 for buf in buffers[:-1]:
56 56 offsets.append(offsets[-1] + len(buf))
57 57 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
58 58 buffers.insert(0, offsets_buf)
59 59 return b''.join(buffers)
60 60
61 61
62 62 def deserialize_binary_message(bmsg):
63 63 """deserialize a message from a binary blog
64 64
65 65 Header:
66 66
67 67 4 bytes: number of msg parts (nbufs) as 32b int
68 68 4 * nbufs bytes: offset for each buffer as integer as 32b int
69 69
70 70 Offsets are from the start of the buffer, including the header.
71 71
72 72 Returns
73 73 -------
74 74
75 75 message dictionary
76 76 """
77 77 nbufs = struct.unpack('!i', bmsg[:4])[0]
78 78 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
79 79 offsets.append(None)
80 80 bufs = []
81 81 for start, stop in zip(offsets[:-1], offsets[1:]):
82 82 bufs.append(bmsg[start:stop])
83 83 msg = json.loads(bufs[0].decode('utf8'))
84 84 msg['header'] = extract_dates(msg['header'])
85 85 msg['parent_header'] = extract_dates(msg['parent_header'])
86 86 msg['buffers'] = bufs[1:]
87 87 return msg
88 88
89 89
90 90 class ZMQStreamHandler(websocket.WebSocketHandler):
91 91
92 92 def check_origin(self, origin):
93 93 """Check Origin == Host or Access-Control-Allow-Origin.
94 94
95 95 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
96 96 We call it explicitly in `open` on Tornado < 4.
97 97 """
98 98 if self.allow_origin == '*':
99 99 return True
100 100
101 101 host = self.request.headers.get("Host")
102 102
103 103 # If no header is provided, assume we can't verify origin
104 104 if origin is None:
105 105 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
106 106 return False
107 107 if host is None:
108 108 self.log.warn("Missing Host header, rejecting WebSocket connection.")
109 109 return False
110 110
111 111 origin = origin.lower()
112 112 origin_host = urlparse(origin).netloc
113 113
114 114 # OK if origin matches host
115 115 if origin_host == host:
116 116 return True
117 117
118 118 # Check CORS headers
119 119 if self.allow_origin:
120 120 allow = self.allow_origin == origin
121 121 elif self.allow_origin_pat:
122 122 allow = bool(self.allow_origin_pat.match(origin))
123 123 else:
124 124 # No CORS headers deny the request
125 125 allow = False
126 126 if not allow:
127 127 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
128 128 origin, host,
129 129 )
130 130 return allow
131 131
132 132 def clear_cookie(self, *args, **kwargs):
133 133 """meaningless for websockets"""
134 134 pass
135 135
136 136 def _reserialize_reply(self, msg_list):
137 137 """Reserialize a reply message using JSON.
138 138
139 139 This takes the msg list from the ZMQ socket, deserializes it using
140 140 self.session and then serializes the result using JSON. This method
141 141 should be used by self._on_zmq_reply to build messages that can
142 142 be sent back to the browser.
143 143 """
144 144 idents, msg_list = self.session.feed_identities(msg_list)
145 145 msg = self.session.deserialize(msg_list)
146 146 if msg['buffers']:
147 147 buf = serialize_binary_message(msg)
148 148 return buf
149 149 else:
150 150 smsg = json.dumps(msg, default=date_default)
151 151 return cast_unicode(smsg)
152 152
153 153 def _on_zmq_reply(self, msg_list):
154 154 # Sometimes this gets triggered when the on_close method is scheduled in the
155 155 # eventloop but hasn't been called.
156 156 if self.stream.closed(): return
157 157 try:
158 158 msg = self._reserialize_reply(msg_list)
159 159 except Exception:
160 160 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
161 161 else:
162 162 self.write_message(msg, binary=isinstance(msg, bytes))
163 163
164 164 def allow_draft76(self):
165 165 """Allow draft 76, until browsers such as Safari update to RFC 6455.
166 166
167 167 This has been disabled by default in tornado in release 2.2.0, and
168 168 support will be removed in later versions.
169 169 """
170 170 return True
171 171
172 172 # ping interval for keeping websockets alive (30 seconds)
173 173 WS_PING_INTERVAL = 30000
174 174
175 175 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
176 176 ping_callback = None
177 177 last_ping = 0
178 178 last_pong = 0
179 179
180 180 @property
181 181 def ping_interval(self):
182 182 """The interval for websocket keep-alive pings.
183 183
184 184 Set ws_ping_interval = 0 to disable pings.
185 185 """
186 186 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
187 187
188 188 @property
189 189 def ping_timeout(self):
190 190 """If no ping is received in this many milliseconds,
191 191 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
192 192 Default is max of 3 pings or 30 seconds.
193 193 """
194 194 return self.settings.get('ws_ping_timeout',
195 195 max(3 * self.ping_interval, WS_PING_INTERVAL)
196 196 )
197 197
198 198 def set_default_headers(self):
199 199 """Undo the set_default_headers in IPythonHandler
200 200
201 201 which doesn't make sense for websockets
202 202 """
203 203 pass
204 204
205 205 def get(self, *args, **kwargs):
206 206 # Check to see that origin matches host directly, including ports
207 207 # Tornado 4 already does CORS checking
208 208 if tornado.version_info[0] < 4:
209 209 if not self.check_origin(self.get_origin()):
210 210 raise web.HTTPError(403)
211 211
212 212 # authenticate the request before opening the websocket
213 213 if self.get_current_user() is None:
214 214 self.log.warn("Couldn't authenticate WebSocket connection")
215 215 raise web.HTTPError(403)
216 216
217 217 if self.get_argument('session_id', False):
218 218 self.session.session = cast_unicode(self.get_argument('session_id'))
219 219 else:
220 220 self.log.warn("No session ID specified")
221 221
222 222 return super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
223 223
224 224 def initialize(self):
225 225 self.session = Session(config=self.config)
226 226
227 def open(self, kernel_id):
228 self.kernel_id = cast_unicode(kernel_id, 'ascii')
227 def open(self, *args, **kwargs):
229 228
230 229 # start the pinging
231 230 if self.ping_interval > 0:
232 231 self.last_ping = ioloop.IOLoop.instance().time() # Remember time of last ping
233 232 self.last_pong = self.last_ping
234 233 self.ping_callback = ioloop.PeriodicCallback(self.send_ping, self.ping_interval)
235 234 self.ping_callback.start()
236 235
237 236 def send_ping(self):
238 237 """send a ping to keep the websocket alive"""
239 238 if self.stream.closed() and self.ping_callback is not None:
240 239 self.ping_callback.stop()
241 240 return
242 241
243 242 # check for timeout on pong. Make sure that we really have sent a recent ping in
244 243 # case the machine with both server and client has been suspended since the last ping.
245 244 now = ioloop.IOLoop.instance().time()
246 245 since_last_pong = 1e3 * (now - self.last_pong)
247 246 since_last_ping = 1e3 * (now - self.last_ping)
248 247 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
249 248 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
250 249 self.close()
251 250 return
252 251
253 252 self.ping(b'')
254 253 self.last_ping = now
255 254
256 255 def on_pong(self, data):
257 256 self.last_pong = ioloop.IOLoop.instance().time()
@@ -1,233 +1,267 b''
1 1 """Tornado handlers for kernels."""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 import json
7 7 import logging
8 from tornado import web
8 from tornado import gen, web
9 from tornado.concurrent import Future
9 10
10 11 from IPython.utils.jsonutil import date_default
11 from IPython.utils.py3compat import string_types
12 from IPython.utils.py3compat import cast_unicode
12 13 from IPython.html.utils import url_path_join, url_escape
13 14
14 15 from ...base.handlers import IPythonHandler, json_errors
15 16 from ...base.zmqhandlers import AuthenticatedZMQStreamHandler, deserialize_binary_message
16 17
17 18 from IPython.core.release import kernel_protocol_version
18 19
19 20 class MainKernelHandler(IPythonHandler):
20 21
21 22 @web.authenticated
22 23 @json_errors
23 24 def get(self):
24 25 km = self.kernel_manager
25 26 self.finish(json.dumps(km.list_kernels()))
26 27
27 28 @web.authenticated
28 29 @json_errors
29 30 def post(self):
30 31 km = self.kernel_manager
31 32 model = self.get_json_body()
32 33 if model is None:
33 34 model = {
34 35 'name': km.default_kernel_name
35 36 }
36 37 else:
37 38 model.setdefault('name', km.default_kernel_name)
38 39
39 40 kernel_id = km.start_kernel(kernel_name=model['name'])
40 41 model = km.kernel_model(kernel_id)
41 42 location = url_path_join(self.base_url, 'api', 'kernels', kernel_id)
42 43 self.set_header('Location', url_escape(location))
43 44 self.set_status(201)
44 45 self.finish(json.dumps(model))
45 46
46 47
47 48 class KernelHandler(IPythonHandler):
48 49
49 50 SUPPORTED_METHODS = ('DELETE', 'GET')
50 51
51 52 @web.authenticated
52 53 @json_errors
53 54 def get(self, kernel_id):
54 55 km = self.kernel_manager
55 56 km._check_kernel_id(kernel_id)
56 57 model = km.kernel_model(kernel_id)
57 58 self.finish(json.dumps(model))
58 59
59 60 @web.authenticated
60 61 @json_errors
61 62 def delete(self, kernel_id):
62 63 km = self.kernel_manager
63 64 km.shutdown_kernel(kernel_id)
64 65 self.set_status(204)
65 66 self.finish()
66 67
67 68
68 69 class KernelActionHandler(IPythonHandler):
69 70
70 71 @web.authenticated
71 72 @json_errors
72 73 def post(self, kernel_id, action):
73 74 km = self.kernel_manager
74 75 if action == 'interrupt':
75 76 km.interrupt_kernel(kernel_id)
76 77 self.set_status(204)
77 78 if action == 'restart':
78 79 km.restart_kernel(kernel_id)
79 80 model = km.kernel_model(kernel_id)
80 81 self.set_header('Location', '{0}api/kernels/{1}'.format(self.base_url, kernel_id))
81 82 self.write(json.dumps(model))
82 83 self.finish()
83 84
84 85
85 86 class ZMQChannelHandler(AuthenticatedZMQStreamHandler):
86 87
87 88 def __repr__(self):
88 89 return "%s(%s)" % (self.__class__.__name__, getattr(self, 'kernel_id', 'uninitialized'))
89 90
90 91 def create_stream(self):
91 92 km = self.kernel_manager
92 93 meth = getattr(km, 'connect_%s' % self.channel)
93 94 self.zmq_stream = meth(self.kernel_id, identity=self.session.bsession)
94 # Create a kernel_info channel to query the kernel protocol version.
95 # This channel will be closed after the kernel_info reply is received.
96 self.kernel_info_channel = None
97 self.kernel_info_channel = km.connect_shell(self.kernel_id)
98 self.kernel_info_channel.on_recv(self._handle_kernel_info_reply)
99 self._request_kernel_info()
100 95
101 def _request_kernel_info(self):
96 def request_kernel_info(self):
102 97 """send a request for kernel_info"""
103 self.log.debug("requesting kernel info")
104 self.session.send(self.kernel_info_channel, "kernel_info_request")
98 km = self.kernel_manager
99 kernel = km.get_kernel(self.kernel_id)
100 try:
101 # check for cached value
102 kernel_info = kernel._kernel_info
103 except AttributeError:
104 self.log.debug("Requesting kernel info from %s", self.kernel_id)
105 # Create a kernel_info channel to query the kernel protocol version.
106 # This channel will be closed after the kernel_info reply is received.
107 if self.kernel_info_channel is None:
108 self.kernel_info_channel = km.connect_shell(self.kernel_id)
109 self.kernel_info_channel.on_recv(self._handle_kernel_info_reply)
110 self.session.send(self.kernel_info_channel, "kernel_info_request")
111 else:
112 # use cached value, don't resend request
113 self._finish_kernel_info(kernel_info)
114 return self._kernel_info_future
105 115
106 116 def _handle_kernel_info_reply(self, msg):
107 117 """process the kernel_info_reply
108 118
109 119 enabling msg spec adaptation, if necessary
110 120 """
111 121 idents,msg = self.session.feed_identities(msg)
112 122 try:
113 123 msg = self.session.deserialize(msg)
114 124 except:
115 125 self.log.error("Bad kernel_info reply", exc_info=True)
116 self._request_kernel_info()
126 self.request_kernel_info()
117 127 return
118 128 else:
119 if msg['msg_type'] != 'kernel_info_reply' or 'protocol_version' not in msg['content']:
120 self.log.error("Kernel info request failed, assuming current %s", msg['content'])
129 info = msg['content']
130 self.log.debug("Received kernel info: %s", info)
131 if msg['msg_type'] != 'kernel_info_reply' or 'protocol_version' not in info:
132 self.log.error("Kernel info request failed, assuming current %s", info)
121 133 else:
122 protocol_version = msg['content']['protocol_version']
123 if protocol_version != kernel_protocol_version:
124 self.session.adapt_version = int(protocol_version.split('.')[0])
125 self.log.info("adapting kernel to %s" % protocol_version)
126 self.kernel_info_channel.close()
134 kernel = self.kernel_manager.get_kernel(self.kernel_id)
135 kernel._kernel_info = info
136 self._finish_kernel_info(info)
137
138 # close the kernel_info channel, we don't need it anymore
139 if self.kernel_info_channel:
140 self.kernel_info_channel.close()
127 141 self.kernel_info_channel = None
128 142
143 def _finish_kernel_info(self, info):
144 """Finish handling kernel_info reply
145
146 Set up protocol adaptation, if needed,
147 and signal that connection can continue.
148 """
149 protocol_version = info.get('protocol_version', kernel_protocol_version)
150 if protocol_version != kernel_protocol_version:
151 self.session.adapt_version = int(protocol_version.split('.')[0])
152 self.log.info("Kernel %s speaks protocol %s", self.kernel_id, protocol_version)
153 self._kernel_info_future.set_result(info)
154
129 155 def initialize(self):
130 156 super(ZMQChannelHandler, self).initialize()
131 157 self.zmq_stream = None
158 self.kernel_info_channel = None
159 self._kernel_info_future = Future()
160
161 @gen.coroutine
162 def get(self, kernel_id):
163 self.kernel_id = cast_unicode(kernel_id, 'ascii')
164 yield self.request_kernel_info()
165 super(ZMQChannelHandler, self).get(kernel_id)
132 166
133 167 def open(self, kernel_id):
134 super(ZMQChannelHandler, self).open(kernel_id)
168 super(ZMQChannelHandler, self).open()
135 169 try:
136 170 self.create_stream()
137 171 except web.HTTPError:
138 172 # WebSockets don't response to traditional error codes so we
139 173 # close the connection.
140 174 if not self.stream.closed():
141 175 self.stream.close()
142 176 self.close()
143 177 else:
144 178 self.zmq_stream.on_recv(self._on_zmq_reply)
145 179
146 180 def on_message(self, msg):
147 181 if self.zmq_stream is None:
148 182 return
149 183 elif self.zmq_stream.closed():
150 184 self.log.info("%s closed, closing websocket.", self)
151 185 self.close()
152 186 return
153 187 if isinstance(msg, bytes):
154 188 msg = deserialize_binary_message(msg)
155 189 else:
156 190 msg = json.loads(msg)
157 191 self.session.send(self.zmq_stream, msg)
158 192
159 193 def on_close(self):
160 194 # This method can be called twice, once by self.kernel_died and once
161 195 # from the WebSocket close event. If the WebSocket connection is
162 196 # closed before the ZMQ streams are setup, they could be None.
163 197 if self.zmq_stream is not None and not self.zmq_stream.closed():
164 198 self.zmq_stream.on_recv(None)
165 199 # close the socket directly, don't wait for the stream
166 200 socket = self.zmq_stream.socket
167 201 self.zmq_stream.close()
168 202 socket.close()
169 203
170 204
171 205 class IOPubHandler(ZMQChannelHandler):
172 206 channel = 'iopub'
173 207
174 208 def create_stream(self):
175 209 super(IOPubHandler, self).create_stream()
176 210 km = self.kernel_manager
177 211 km.add_restart_callback(self.kernel_id, self.on_kernel_restarted)
178 212 km.add_restart_callback(self.kernel_id, self.on_restart_failed, 'dead')
179 213
180 214 def on_close(self):
181 215 km = self.kernel_manager
182 216 if self.kernel_id in km:
183 217 km.remove_restart_callback(
184 218 self.kernel_id, self.on_kernel_restarted,
185 219 )
186 220 km.remove_restart_callback(
187 221 self.kernel_id, self.on_restart_failed, 'dead',
188 222 )
189 223 super(IOPubHandler, self).on_close()
190 224
191 225 def _send_status_message(self, status):
192 226 msg = self.session.msg("status",
193 227 {'execution_state': status}
194 228 )
195 229 self.write_message(json.dumps(msg, default=date_default))
196 230
197 231 def on_kernel_restarted(self):
198 232 logging.warn("kernel %s restarted", self.kernel_id)
199 233 self._send_status_message('restarting')
200 234
201 235 def on_restart_failed(self):
202 236 logging.error("kernel %s restarted failed!", self.kernel_id)
203 237 self._send_status_message('dead')
204 238
205 239 def on_message(self, msg):
206 240 """IOPub messages make no sense"""
207 241 pass
208 242
209 243
210 244 class ShellHandler(ZMQChannelHandler):
211 245 channel = 'shell'
212 246
213 247
214 248 class StdinHandler(ZMQChannelHandler):
215 249 channel = 'stdin'
216 250
217 251
218 252 #-----------------------------------------------------------------------------
219 253 # URL to handler mappings
220 254 #-----------------------------------------------------------------------------
221 255
222 256
223 257 _kernel_id_regex = r"(?P<kernel_id>\w+-\w+-\w+-\w+-\w+)"
224 258 _kernel_action_regex = r"(?P<action>restart|interrupt)"
225 259
226 260 default_handlers = [
227 261 (r"/api/kernels", MainKernelHandler),
228 262 (r"/api/kernels/%s" % _kernel_id_regex, KernelHandler),
229 263 (r"/api/kernels/%s/%s" % (_kernel_id_regex, _kernel_action_regex), KernelActionHandler),
230 264 (r"/api/kernels/%s/iopub" % _kernel_id_regex, IOPubHandler),
231 265 (r"/api/kernels/%s/shell" % _kernel_id_regex, ShellHandler),
232 266 (r"/api/kernels/%s/stdin" % _kernel_id_regex, StdinHandler)
233 267 ]
General Comments 0
You need to be logged in to leave comments. Login now