##// END OF EJS Templates
cache kernel_info reply for protocol adaptation...
MinRK -
Show More
@@ -1,257 +1,256 b''
1 """Tornado handlers for WebSocket <-> ZMQ sockets."""
1 """Tornado handlers for WebSocket <-> ZMQ sockets."""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 import json
6 import json
7 import struct
7 import struct
8
8
9 try:
9 try:
10 from urllib.parse import urlparse # Py 3
10 from urllib.parse import urlparse # Py 3
11 except ImportError:
11 except ImportError:
12 from urlparse import urlparse # Py 2
12 from urlparse import urlparse # Py 2
13
13
14 try:
14 try:
15 from http.cookies import SimpleCookie # Py 3
15 from http.cookies import SimpleCookie # Py 3
16 except ImportError:
16 except ImportError:
17 from Cookie import SimpleCookie # Py 2
17 from Cookie import SimpleCookie # Py 2
18 import logging
18 import logging
19
19
20 import tornado
20 import tornado
21 from tornado import ioloop
21 from tornado import ioloop
22 from tornado import web
22 from tornado import web
23 from tornado import websocket
23 from tornado import websocket
24
24
25 from IPython.kernel.zmq.session import Session
25 from IPython.kernel.zmq.session import Session
26 from IPython.utils.jsonutil import date_default, extract_dates
26 from IPython.utils.jsonutil import date_default, extract_dates
27 from IPython.utils.py3compat import PY3, cast_unicode
27 from IPython.utils.py3compat import PY3, cast_unicode
28
28
29 from .handlers import IPythonHandler
29 from .handlers import IPythonHandler
30
30
31
31
32 def serialize_binary_message(msg):
32 def serialize_binary_message(msg):
33 """serialize a message as a binary blob
33 """serialize a message as a binary blob
34
34
35 Header:
35 Header:
36
36
37 4 bytes: number of msg parts (nbufs) as 32b int
37 4 bytes: number of msg parts (nbufs) as 32b int
38 4 * nbufs bytes: offset for each buffer as integer as 32b int
38 4 * nbufs bytes: offset for each buffer as integer as 32b int
39
39
40 Offsets are from the start of the buffer, including the header.
40 Offsets are from the start of the buffer, including the header.
41
41
42 Returns
42 Returns
43 -------
43 -------
44
44
45 The message serialized to bytes.
45 The message serialized to bytes.
46
46
47 """
47 """
48 # don't modify msg or buffer list in-place
48 # don't modify msg or buffer list in-place
49 msg = msg.copy()
49 msg = msg.copy()
50 buffers = list(msg.pop('buffers'))
50 buffers = list(msg.pop('buffers'))
51 bmsg = json.dumps(msg, default=date_default).encode('utf8')
51 bmsg = json.dumps(msg, default=date_default).encode('utf8')
52 buffers.insert(0, bmsg)
52 buffers.insert(0, bmsg)
53 nbufs = len(buffers)
53 nbufs = len(buffers)
54 offsets = [4 * (nbufs + 1)]
54 offsets = [4 * (nbufs + 1)]
55 for buf in buffers[:-1]:
55 for buf in buffers[:-1]:
56 offsets.append(offsets[-1] + len(buf))
56 offsets.append(offsets[-1] + len(buf))
57 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
57 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
58 buffers.insert(0, offsets_buf)
58 buffers.insert(0, offsets_buf)
59 return b''.join(buffers)
59 return b''.join(buffers)
60
60
61
61
62 def deserialize_binary_message(bmsg):
62 def deserialize_binary_message(bmsg):
63 """deserialize a message from a binary blog
63 """deserialize a message from a binary blog
64
64
65 Header:
65 Header:
66
66
67 4 bytes: number of msg parts (nbufs) as 32b int
67 4 bytes: number of msg parts (nbufs) as 32b int
68 4 * nbufs bytes: offset for each buffer as integer as 32b int
68 4 * nbufs bytes: offset for each buffer as integer as 32b int
69
69
70 Offsets are from the start of the buffer, including the header.
70 Offsets are from the start of the buffer, including the header.
71
71
72 Returns
72 Returns
73 -------
73 -------
74
74
75 message dictionary
75 message dictionary
76 """
76 """
77 nbufs = struct.unpack('!i', bmsg[:4])[0]
77 nbufs = struct.unpack('!i', bmsg[:4])[0]
78 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
78 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
79 offsets.append(None)
79 offsets.append(None)
80 bufs = []
80 bufs = []
81 for start, stop in zip(offsets[:-1], offsets[1:]):
81 for start, stop in zip(offsets[:-1], offsets[1:]):
82 bufs.append(bmsg[start:stop])
82 bufs.append(bmsg[start:stop])
83 msg = json.loads(bufs[0].decode('utf8'))
83 msg = json.loads(bufs[0].decode('utf8'))
84 msg['header'] = extract_dates(msg['header'])
84 msg['header'] = extract_dates(msg['header'])
85 msg['parent_header'] = extract_dates(msg['parent_header'])
85 msg['parent_header'] = extract_dates(msg['parent_header'])
86 msg['buffers'] = bufs[1:]
86 msg['buffers'] = bufs[1:]
87 return msg
87 return msg
88
88
89
89
90 class ZMQStreamHandler(websocket.WebSocketHandler):
90 class ZMQStreamHandler(websocket.WebSocketHandler):
91
91
92 def check_origin(self, origin):
92 def check_origin(self, origin):
93 """Check Origin == Host or Access-Control-Allow-Origin.
93 """Check Origin == Host or Access-Control-Allow-Origin.
94
94
95 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
95 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
96 We call it explicitly in `open` on Tornado < 4.
96 We call it explicitly in `open` on Tornado < 4.
97 """
97 """
98 if self.allow_origin == '*':
98 if self.allow_origin == '*':
99 return True
99 return True
100
100
101 host = self.request.headers.get("Host")
101 host = self.request.headers.get("Host")
102
102
103 # If no header is provided, assume we can't verify origin
103 # If no header is provided, assume we can't verify origin
104 if origin is None:
104 if origin is None:
105 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
105 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
106 return False
106 return False
107 if host is None:
107 if host is None:
108 self.log.warn("Missing Host header, rejecting WebSocket connection.")
108 self.log.warn("Missing Host header, rejecting WebSocket connection.")
109 return False
109 return False
110
110
111 origin = origin.lower()
111 origin = origin.lower()
112 origin_host = urlparse(origin).netloc
112 origin_host = urlparse(origin).netloc
113
113
114 # OK if origin matches host
114 # OK if origin matches host
115 if origin_host == host:
115 if origin_host == host:
116 return True
116 return True
117
117
118 # Check CORS headers
118 # Check CORS headers
119 if self.allow_origin:
119 if self.allow_origin:
120 allow = self.allow_origin == origin
120 allow = self.allow_origin == origin
121 elif self.allow_origin_pat:
121 elif self.allow_origin_pat:
122 allow = bool(self.allow_origin_pat.match(origin))
122 allow = bool(self.allow_origin_pat.match(origin))
123 else:
123 else:
124 # No CORS headers deny the request
124 # No CORS headers deny the request
125 allow = False
125 allow = False
126 if not allow:
126 if not allow:
127 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
127 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
128 origin, host,
128 origin, host,
129 )
129 )
130 return allow
130 return allow
131
131
132 def clear_cookie(self, *args, **kwargs):
132 def clear_cookie(self, *args, **kwargs):
133 """meaningless for websockets"""
133 """meaningless for websockets"""
134 pass
134 pass
135
135
136 def _reserialize_reply(self, msg_list):
136 def _reserialize_reply(self, msg_list):
137 """Reserialize a reply message using JSON.
137 """Reserialize a reply message using JSON.
138
138
139 This takes the msg list from the ZMQ socket, deserializes it using
139 This takes the msg list from the ZMQ socket, deserializes it using
140 self.session and then serializes the result using JSON. This method
140 self.session and then serializes the result using JSON. This method
141 should be used by self._on_zmq_reply to build messages that can
141 should be used by self._on_zmq_reply to build messages that can
142 be sent back to the browser.
142 be sent back to the browser.
143 """
143 """
144 idents, msg_list = self.session.feed_identities(msg_list)
144 idents, msg_list = self.session.feed_identities(msg_list)
145 msg = self.session.deserialize(msg_list)
145 msg = self.session.deserialize(msg_list)
146 if msg['buffers']:
146 if msg['buffers']:
147 buf = serialize_binary_message(msg)
147 buf = serialize_binary_message(msg)
148 return buf
148 return buf
149 else:
149 else:
150 smsg = json.dumps(msg, default=date_default)
150 smsg = json.dumps(msg, default=date_default)
151 return cast_unicode(smsg)
151 return cast_unicode(smsg)
152
152
153 def _on_zmq_reply(self, msg_list):
153 def _on_zmq_reply(self, msg_list):
154 # Sometimes this gets triggered when the on_close method is scheduled in the
154 # Sometimes this gets triggered when the on_close method is scheduled in the
155 # eventloop but hasn't been called.
155 # eventloop but hasn't been called.
156 if self.stream.closed(): return
156 if self.stream.closed(): return
157 try:
157 try:
158 msg = self._reserialize_reply(msg_list)
158 msg = self._reserialize_reply(msg_list)
159 except Exception:
159 except Exception:
160 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
160 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
161 else:
161 else:
162 self.write_message(msg, binary=isinstance(msg, bytes))
162 self.write_message(msg, binary=isinstance(msg, bytes))
163
163
164 def allow_draft76(self):
164 def allow_draft76(self):
165 """Allow draft 76, until browsers such as Safari update to RFC 6455.
165 """Allow draft 76, until browsers such as Safari update to RFC 6455.
166
166
167 This has been disabled by default in tornado in release 2.2.0, and
167 This has been disabled by default in tornado in release 2.2.0, and
168 support will be removed in later versions.
168 support will be removed in later versions.
169 """
169 """
170 return True
170 return True
171
171
172 # ping interval for keeping websockets alive (30 seconds)
172 # ping interval for keeping websockets alive (30 seconds)
173 WS_PING_INTERVAL = 30000
173 WS_PING_INTERVAL = 30000
174
174
175 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
175 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
176 ping_callback = None
176 ping_callback = None
177 last_ping = 0
177 last_ping = 0
178 last_pong = 0
178 last_pong = 0
179
179
180 @property
180 @property
181 def ping_interval(self):
181 def ping_interval(self):
182 """The interval for websocket keep-alive pings.
182 """The interval for websocket keep-alive pings.
183
183
184 Set ws_ping_interval = 0 to disable pings.
184 Set ws_ping_interval = 0 to disable pings.
185 """
185 """
186 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
186 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
187
187
188 @property
188 @property
189 def ping_timeout(self):
189 def ping_timeout(self):
190 """If no ping is received in this many milliseconds,
190 """If no ping is received in this many milliseconds,
191 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
191 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
192 Default is max of 3 pings or 30 seconds.
192 Default is max of 3 pings or 30 seconds.
193 """
193 """
194 return self.settings.get('ws_ping_timeout',
194 return self.settings.get('ws_ping_timeout',
195 max(3 * self.ping_interval, WS_PING_INTERVAL)
195 max(3 * self.ping_interval, WS_PING_INTERVAL)
196 )
196 )
197
197
198 def set_default_headers(self):
198 def set_default_headers(self):
199 """Undo the set_default_headers in IPythonHandler
199 """Undo the set_default_headers in IPythonHandler
200
200
201 which doesn't make sense for websockets
201 which doesn't make sense for websockets
202 """
202 """
203 pass
203 pass
204
204
205 def get(self, *args, **kwargs):
205 def get(self, *args, **kwargs):
206 # Check to see that origin matches host directly, including ports
206 # Check to see that origin matches host directly, including ports
207 # Tornado 4 already does CORS checking
207 # Tornado 4 already does CORS checking
208 if tornado.version_info[0] < 4:
208 if tornado.version_info[0] < 4:
209 if not self.check_origin(self.get_origin()):
209 if not self.check_origin(self.get_origin()):
210 raise web.HTTPError(403)
210 raise web.HTTPError(403)
211
211
212 # authenticate the request before opening the websocket
212 # authenticate the request before opening the websocket
213 if self.get_current_user() is None:
213 if self.get_current_user() is None:
214 self.log.warn("Couldn't authenticate WebSocket connection")
214 self.log.warn("Couldn't authenticate WebSocket connection")
215 raise web.HTTPError(403)
215 raise web.HTTPError(403)
216
216
217 if self.get_argument('session_id', False):
217 if self.get_argument('session_id', False):
218 self.session.session = cast_unicode(self.get_argument('session_id'))
218 self.session.session = cast_unicode(self.get_argument('session_id'))
219 else:
219 else:
220 self.log.warn("No session ID specified")
220 self.log.warn("No session ID specified")
221
221
222 return super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
222 return super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
223
223
224 def initialize(self):
224 def initialize(self):
225 self.session = Session(config=self.config)
225 self.session = Session(config=self.config)
226
226
227 def open(self, kernel_id):
227 def open(self, *args, **kwargs):
228 self.kernel_id = cast_unicode(kernel_id, 'ascii')
229
228
230 # start the pinging
229 # start the pinging
231 if self.ping_interval > 0:
230 if self.ping_interval > 0:
232 self.last_ping = ioloop.IOLoop.instance().time() # Remember time of last ping
231 self.last_ping = ioloop.IOLoop.instance().time() # Remember time of last ping
233 self.last_pong = self.last_ping
232 self.last_pong = self.last_ping
234 self.ping_callback = ioloop.PeriodicCallback(self.send_ping, self.ping_interval)
233 self.ping_callback = ioloop.PeriodicCallback(self.send_ping, self.ping_interval)
235 self.ping_callback.start()
234 self.ping_callback.start()
236
235
237 def send_ping(self):
236 def send_ping(self):
238 """send a ping to keep the websocket alive"""
237 """send a ping to keep the websocket alive"""
239 if self.stream.closed() and self.ping_callback is not None:
238 if self.stream.closed() and self.ping_callback is not None:
240 self.ping_callback.stop()
239 self.ping_callback.stop()
241 return
240 return
242
241
243 # check for timeout on pong. Make sure that we really have sent a recent ping in
242 # check for timeout on pong. Make sure that we really have sent a recent ping in
244 # case the machine with both server and client has been suspended since the last ping.
243 # case the machine with both server and client has been suspended since the last ping.
245 now = ioloop.IOLoop.instance().time()
244 now = ioloop.IOLoop.instance().time()
246 since_last_pong = 1e3 * (now - self.last_pong)
245 since_last_pong = 1e3 * (now - self.last_pong)
247 since_last_ping = 1e3 * (now - self.last_ping)
246 since_last_ping = 1e3 * (now - self.last_ping)
248 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
247 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
249 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
248 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
250 self.close()
249 self.close()
251 return
250 return
252
251
253 self.ping(b'')
252 self.ping(b'')
254 self.last_ping = now
253 self.last_ping = now
255
254
256 def on_pong(self, data):
255 def on_pong(self, data):
257 self.last_pong = ioloop.IOLoop.instance().time()
256 self.last_pong = ioloop.IOLoop.instance().time()
@@ -1,233 +1,267 b''
1 """Tornado handlers for kernels."""
1 """Tornado handlers for kernels."""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 import json
6 import json
7 import logging
7 import logging
8 from tornado import web
8 from tornado import gen, web
9 from tornado.concurrent import Future
9
10
10 from IPython.utils.jsonutil import date_default
11 from IPython.utils.jsonutil import date_default
11 from IPython.utils.py3compat import string_types
12 from IPython.utils.py3compat import cast_unicode
12 from IPython.html.utils import url_path_join, url_escape
13 from IPython.html.utils import url_path_join, url_escape
13
14
14 from ...base.handlers import IPythonHandler, json_errors
15 from ...base.handlers import IPythonHandler, json_errors
15 from ...base.zmqhandlers import AuthenticatedZMQStreamHandler, deserialize_binary_message
16 from ...base.zmqhandlers import AuthenticatedZMQStreamHandler, deserialize_binary_message
16
17
17 from IPython.core.release import kernel_protocol_version
18 from IPython.core.release import kernel_protocol_version
18
19
19 class MainKernelHandler(IPythonHandler):
20 class MainKernelHandler(IPythonHandler):
20
21
21 @web.authenticated
22 @web.authenticated
22 @json_errors
23 @json_errors
23 def get(self):
24 def get(self):
24 km = self.kernel_manager
25 km = self.kernel_manager
25 self.finish(json.dumps(km.list_kernels()))
26 self.finish(json.dumps(km.list_kernels()))
26
27
27 @web.authenticated
28 @web.authenticated
28 @json_errors
29 @json_errors
29 def post(self):
30 def post(self):
30 km = self.kernel_manager
31 km = self.kernel_manager
31 model = self.get_json_body()
32 model = self.get_json_body()
32 if model is None:
33 if model is None:
33 model = {
34 model = {
34 'name': km.default_kernel_name
35 'name': km.default_kernel_name
35 }
36 }
36 else:
37 else:
37 model.setdefault('name', km.default_kernel_name)
38 model.setdefault('name', km.default_kernel_name)
38
39
39 kernel_id = km.start_kernel(kernel_name=model['name'])
40 kernel_id = km.start_kernel(kernel_name=model['name'])
40 model = km.kernel_model(kernel_id)
41 model = km.kernel_model(kernel_id)
41 location = url_path_join(self.base_url, 'api', 'kernels', kernel_id)
42 location = url_path_join(self.base_url, 'api', 'kernels', kernel_id)
42 self.set_header('Location', url_escape(location))
43 self.set_header('Location', url_escape(location))
43 self.set_status(201)
44 self.set_status(201)
44 self.finish(json.dumps(model))
45 self.finish(json.dumps(model))
45
46
46
47
47 class KernelHandler(IPythonHandler):
48 class KernelHandler(IPythonHandler):
48
49
49 SUPPORTED_METHODS = ('DELETE', 'GET')
50 SUPPORTED_METHODS = ('DELETE', 'GET')
50
51
51 @web.authenticated
52 @web.authenticated
52 @json_errors
53 @json_errors
53 def get(self, kernel_id):
54 def get(self, kernel_id):
54 km = self.kernel_manager
55 km = self.kernel_manager
55 km._check_kernel_id(kernel_id)
56 km._check_kernel_id(kernel_id)
56 model = km.kernel_model(kernel_id)
57 model = km.kernel_model(kernel_id)
57 self.finish(json.dumps(model))
58 self.finish(json.dumps(model))
58
59
59 @web.authenticated
60 @web.authenticated
60 @json_errors
61 @json_errors
61 def delete(self, kernel_id):
62 def delete(self, kernel_id):
62 km = self.kernel_manager
63 km = self.kernel_manager
63 km.shutdown_kernel(kernel_id)
64 km.shutdown_kernel(kernel_id)
64 self.set_status(204)
65 self.set_status(204)
65 self.finish()
66 self.finish()
66
67
67
68
68 class KernelActionHandler(IPythonHandler):
69 class KernelActionHandler(IPythonHandler):
69
70
70 @web.authenticated
71 @web.authenticated
71 @json_errors
72 @json_errors
72 def post(self, kernel_id, action):
73 def post(self, kernel_id, action):
73 km = self.kernel_manager
74 km = self.kernel_manager
74 if action == 'interrupt':
75 if action == 'interrupt':
75 km.interrupt_kernel(kernel_id)
76 km.interrupt_kernel(kernel_id)
76 self.set_status(204)
77 self.set_status(204)
77 if action == 'restart':
78 if action == 'restart':
78 km.restart_kernel(kernel_id)
79 km.restart_kernel(kernel_id)
79 model = km.kernel_model(kernel_id)
80 model = km.kernel_model(kernel_id)
80 self.set_header('Location', '{0}api/kernels/{1}'.format(self.base_url, kernel_id))
81 self.set_header('Location', '{0}api/kernels/{1}'.format(self.base_url, kernel_id))
81 self.write(json.dumps(model))
82 self.write(json.dumps(model))
82 self.finish()
83 self.finish()
83
84
84
85
85 class ZMQChannelHandler(AuthenticatedZMQStreamHandler):
86 class ZMQChannelHandler(AuthenticatedZMQStreamHandler):
86
87
87 def __repr__(self):
88 def __repr__(self):
88 return "%s(%s)" % (self.__class__.__name__, getattr(self, 'kernel_id', 'uninitialized'))
89 return "%s(%s)" % (self.__class__.__name__, getattr(self, 'kernel_id', 'uninitialized'))
89
90
90 def create_stream(self):
91 def create_stream(self):
91 km = self.kernel_manager
92 km = self.kernel_manager
92 meth = getattr(km, 'connect_%s' % self.channel)
93 meth = getattr(km, 'connect_%s' % self.channel)
93 self.zmq_stream = meth(self.kernel_id, identity=self.session.bsession)
94 self.zmq_stream = meth(self.kernel_id, identity=self.session.bsession)
94 # Create a kernel_info channel to query the kernel protocol version.
95 # This channel will be closed after the kernel_info reply is received.
96 self.kernel_info_channel = None
97 self.kernel_info_channel = km.connect_shell(self.kernel_id)
98 self.kernel_info_channel.on_recv(self._handle_kernel_info_reply)
99 self._request_kernel_info()
100
95
101 def _request_kernel_info(self):
96 def request_kernel_info(self):
102 """send a request for kernel_info"""
97 """send a request for kernel_info"""
103 self.log.debug("requesting kernel info")
98 km = self.kernel_manager
104 self.session.send(self.kernel_info_channel, "kernel_info_request")
99 kernel = km.get_kernel(self.kernel_id)
100 try:
101 # check for cached value
102 kernel_info = kernel._kernel_info
103 except AttributeError:
104 self.log.debug("Requesting kernel info from %s", self.kernel_id)
105 # Create a kernel_info channel to query the kernel protocol version.
106 # This channel will be closed after the kernel_info reply is received.
107 if self.kernel_info_channel is None:
108 self.kernel_info_channel = km.connect_shell(self.kernel_id)
109 self.kernel_info_channel.on_recv(self._handle_kernel_info_reply)
110 self.session.send(self.kernel_info_channel, "kernel_info_request")
111 else:
112 # use cached value, don't resend request
113 self._finish_kernel_info(kernel_info)
114 return self._kernel_info_future
105
115
106 def _handle_kernel_info_reply(self, msg):
116 def _handle_kernel_info_reply(self, msg):
107 """process the kernel_info_reply
117 """process the kernel_info_reply
108
118
109 enabling msg spec adaptation, if necessary
119 enabling msg spec adaptation, if necessary
110 """
120 """
111 idents,msg = self.session.feed_identities(msg)
121 idents,msg = self.session.feed_identities(msg)
112 try:
122 try:
113 msg = self.session.deserialize(msg)
123 msg = self.session.deserialize(msg)
114 except:
124 except:
115 self.log.error("Bad kernel_info reply", exc_info=True)
125 self.log.error("Bad kernel_info reply", exc_info=True)
116 self._request_kernel_info()
126 self.request_kernel_info()
117 return
127 return
118 else:
128 else:
119 if msg['msg_type'] != 'kernel_info_reply' or 'protocol_version' not in msg['content']:
129 info = msg['content']
120 self.log.error("Kernel info request failed, assuming current %s", msg['content'])
130 self.log.debug("Received kernel info: %s", info)
131 if msg['msg_type'] != 'kernel_info_reply' or 'protocol_version' not in info:
132 self.log.error("Kernel info request failed, assuming current %s", info)
121 else:
133 else:
122 protocol_version = msg['content']['protocol_version']
134 kernel = self.kernel_manager.get_kernel(self.kernel_id)
123 if protocol_version != kernel_protocol_version:
135 kernel._kernel_info = info
124 self.session.adapt_version = int(protocol_version.split('.')[0])
136 self._finish_kernel_info(info)
125 self.log.info("adapting kernel to %s" % protocol_version)
137
126 self.kernel_info_channel.close()
138 # close the kernel_info channel, we don't need it anymore
139 if self.kernel_info_channel:
140 self.kernel_info_channel.close()
127 self.kernel_info_channel = None
141 self.kernel_info_channel = None
128
142
143 def _finish_kernel_info(self, info):
144 """Finish handling kernel_info reply
145
146 Set up protocol adaptation, if needed,
147 and signal that connection can continue.
148 """
149 protocol_version = info.get('protocol_version', kernel_protocol_version)
150 if protocol_version != kernel_protocol_version:
151 self.session.adapt_version = int(protocol_version.split('.')[0])
152 self.log.info("Kernel %s speaks protocol %s", self.kernel_id, protocol_version)
153 self._kernel_info_future.set_result(info)
154
129 def initialize(self):
155 def initialize(self):
130 super(ZMQChannelHandler, self).initialize()
156 super(ZMQChannelHandler, self).initialize()
131 self.zmq_stream = None
157 self.zmq_stream = None
158 self.kernel_info_channel = None
159 self._kernel_info_future = Future()
160
161 @gen.coroutine
162 def get(self, kernel_id):
163 self.kernel_id = cast_unicode(kernel_id, 'ascii')
164 yield self.request_kernel_info()
165 super(ZMQChannelHandler, self).get(kernel_id)
132
166
133 def open(self, kernel_id):
167 def open(self, kernel_id):
134 super(ZMQChannelHandler, self).open(kernel_id)
168 super(ZMQChannelHandler, self).open()
135 try:
169 try:
136 self.create_stream()
170 self.create_stream()
137 except web.HTTPError:
171 except web.HTTPError:
138 # WebSockets don't response to traditional error codes so we
172 # WebSockets don't response to traditional error codes so we
139 # close the connection.
173 # close the connection.
140 if not self.stream.closed():
174 if not self.stream.closed():
141 self.stream.close()
175 self.stream.close()
142 self.close()
176 self.close()
143 else:
177 else:
144 self.zmq_stream.on_recv(self._on_zmq_reply)
178 self.zmq_stream.on_recv(self._on_zmq_reply)
145
179
146 def on_message(self, msg):
180 def on_message(self, msg):
147 if self.zmq_stream is None:
181 if self.zmq_stream is None:
148 return
182 return
149 elif self.zmq_stream.closed():
183 elif self.zmq_stream.closed():
150 self.log.info("%s closed, closing websocket.", self)
184 self.log.info("%s closed, closing websocket.", self)
151 self.close()
185 self.close()
152 return
186 return
153 if isinstance(msg, bytes):
187 if isinstance(msg, bytes):
154 msg = deserialize_binary_message(msg)
188 msg = deserialize_binary_message(msg)
155 else:
189 else:
156 msg = json.loads(msg)
190 msg = json.loads(msg)
157 self.session.send(self.zmq_stream, msg)
191 self.session.send(self.zmq_stream, msg)
158
192
159 def on_close(self):
193 def on_close(self):
160 # This method can be called twice, once by self.kernel_died and once
194 # This method can be called twice, once by self.kernel_died and once
161 # from the WebSocket close event. If the WebSocket connection is
195 # from the WebSocket close event. If the WebSocket connection is
162 # closed before the ZMQ streams are setup, they could be None.
196 # closed before the ZMQ streams are setup, they could be None.
163 if self.zmq_stream is not None and not self.zmq_stream.closed():
197 if self.zmq_stream is not None and not self.zmq_stream.closed():
164 self.zmq_stream.on_recv(None)
198 self.zmq_stream.on_recv(None)
165 # close the socket directly, don't wait for the stream
199 # close the socket directly, don't wait for the stream
166 socket = self.zmq_stream.socket
200 socket = self.zmq_stream.socket
167 self.zmq_stream.close()
201 self.zmq_stream.close()
168 socket.close()
202 socket.close()
169
203
170
204
171 class IOPubHandler(ZMQChannelHandler):
205 class IOPubHandler(ZMQChannelHandler):
172 channel = 'iopub'
206 channel = 'iopub'
173
207
174 def create_stream(self):
208 def create_stream(self):
175 super(IOPubHandler, self).create_stream()
209 super(IOPubHandler, self).create_stream()
176 km = self.kernel_manager
210 km = self.kernel_manager
177 km.add_restart_callback(self.kernel_id, self.on_kernel_restarted)
211 km.add_restart_callback(self.kernel_id, self.on_kernel_restarted)
178 km.add_restart_callback(self.kernel_id, self.on_restart_failed, 'dead')
212 km.add_restart_callback(self.kernel_id, self.on_restart_failed, 'dead')
179
213
180 def on_close(self):
214 def on_close(self):
181 km = self.kernel_manager
215 km = self.kernel_manager
182 if self.kernel_id in km:
216 if self.kernel_id in km:
183 km.remove_restart_callback(
217 km.remove_restart_callback(
184 self.kernel_id, self.on_kernel_restarted,
218 self.kernel_id, self.on_kernel_restarted,
185 )
219 )
186 km.remove_restart_callback(
220 km.remove_restart_callback(
187 self.kernel_id, self.on_restart_failed, 'dead',
221 self.kernel_id, self.on_restart_failed, 'dead',
188 )
222 )
189 super(IOPubHandler, self).on_close()
223 super(IOPubHandler, self).on_close()
190
224
191 def _send_status_message(self, status):
225 def _send_status_message(self, status):
192 msg = self.session.msg("status",
226 msg = self.session.msg("status",
193 {'execution_state': status}
227 {'execution_state': status}
194 )
228 )
195 self.write_message(json.dumps(msg, default=date_default))
229 self.write_message(json.dumps(msg, default=date_default))
196
230
197 def on_kernel_restarted(self):
231 def on_kernel_restarted(self):
198 logging.warn("kernel %s restarted", self.kernel_id)
232 logging.warn("kernel %s restarted", self.kernel_id)
199 self._send_status_message('restarting')
233 self._send_status_message('restarting')
200
234
201 def on_restart_failed(self):
235 def on_restart_failed(self):
202 logging.error("kernel %s restarted failed!", self.kernel_id)
236 logging.error("kernel %s restarted failed!", self.kernel_id)
203 self._send_status_message('dead')
237 self._send_status_message('dead')
204
238
205 def on_message(self, msg):
239 def on_message(self, msg):
206 """IOPub messages make no sense"""
240 """IOPub messages make no sense"""
207 pass
241 pass
208
242
209
243
210 class ShellHandler(ZMQChannelHandler):
244 class ShellHandler(ZMQChannelHandler):
211 channel = 'shell'
245 channel = 'shell'
212
246
213
247
214 class StdinHandler(ZMQChannelHandler):
248 class StdinHandler(ZMQChannelHandler):
215 channel = 'stdin'
249 channel = 'stdin'
216
250
217
251
218 #-----------------------------------------------------------------------------
252 #-----------------------------------------------------------------------------
219 # URL to handler mappings
253 # URL to handler mappings
220 #-----------------------------------------------------------------------------
254 #-----------------------------------------------------------------------------
221
255
222
256
223 _kernel_id_regex = r"(?P<kernel_id>\w+-\w+-\w+-\w+-\w+)"
257 _kernel_id_regex = r"(?P<kernel_id>\w+-\w+-\w+-\w+-\w+)"
224 _kernel_action_regex = r"(?P<action>restart|interrupt)"
258 _kernel_action_regex = r"(?P<action>restart|interrupt)"
225
259
226 default_handlers = [
260 default_handlers = [
227 (r"/api/kernels", MainKernelHandler),
261 (r"/api/kernels", MainKernelHandler),
228 (r"/api/kernels/%s" % _kernel_id_regex, KernelHandler),
262 (r"/api/kernels/%s" % _kernel_id_regex, KernelHandler),
229 (r"/api/kernels/%s/%s" % (_kernel_id_regex, _kernel_action_regex), KernelActionHandler),
263 (r"/api/kernels/%s/%s" % (_kernel_id_regex, _kernel_action_regex), KernelActionHandler),
230 (r"/api/kernels/%s/iopub" % _kernel_id_regex, IOPubHandler),
264 (r"/api/kernels/%s/iopub" % _kernel_id_regex, IOPubHandler),
231 (r"/api/kernels/%s/shell" % _kernel_id_regex, ShellHandler),
265 (r"/api/kernels/%s/shell" % _kernel_id_regex, ShellHandler),
232 (r"/api/kernels/%s/stdin" % _kernel_id_regex, StdinHandler)
266 (r"/api/kernels/%s/stdin" % _kernel_id_regex, StdinHandler)
233 ]
267 ]
General Comments 0
You need to be logged in to leave comments. Login now