##// END OF EJS Templates
debugging websocket connections...
Min RK -
Show More
@@ -1,259 +1,269
1 # coding: utf-8
1 # coding: utf-8
2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 import json
7 import json
8 import struct
8 import struct
9
9
10 try:
10 try:
11 from urllib.parse import urlparse # Py 3
11 from urllib.parse import urlparse # Py 3
12 except ImportError:
12 except ImportError:
13 from urlparse import urlparse # Py 2
13 from urlparse import urlparse # Py 2
14
14
15 import tornado
15 import tornado
16 from tornado import ioloop
16 from tornado import gen, ioloop, web, websocket
17 from tornado import web
18 from tornado import websocket
19
17
20 from IPython.kernel.zmq.session import Session
18 from IPython.kernel.zmq.session import Session
21 from IPython.utils.jsonutil import date_default, extract_dates
19 from IPython.utils.jsonutil import date_default, extract_dates
22 from IPython.utils.py3compat import cast_unicode
20 from IPython.utils.py3compat import cast_unicode
23
21
24 from .handlers import IPythonHandler
22 from .handlers import IPythonHandler
25
23
26
24
27 def serialize_binary_message(msg):
25 def serialize_binary_message(msg):
28 """serialize a message as a binary blob
26 """serialize a message as a binary blob
29
27
30 Header:
28 Header:
31
29
32 4 bytes: number of msg parts (nbufs) as 32b int
30 4 bytes: number of msg parts (nbufs) as 32b int
33 4 * nbufs bytes: offset for each buffer as integer as 32b int
31 4 * nbufs bytes: offset for each buffer as integer as 32b int
34
32
35 Offsets are from the start of the buffer, including the header.
33 Offsets are from the start of the buffer, including the header.
36
34
37 Returns
35 Returns
38 -------
36 -------
39
37
40 The message serialized to bytes.
38 The message serialized to bytes.
41
39
42 """
40 """
43 # don't modify msg or buffer list in-place
41 # don't modify msg or buffer list in-place
44 msg = msg.copy()
42 msg = msg.copy()
45 buffers = list(msg.pop('buffers'))
43 buffers = list(msg.pop('buffers'))
46 bmsg = json.dumps(msg, default=date_default).encode('utf8')
44 bmsg = json.dumps(msg, default=date_default).encode('utf8')
47 buffers.insert(0, bmsg)
45 buffers.insert(0, bmsg)
48 nbufs = len(buffers)
46 nbufs = len(buffers)
49 offsets = [4 * (nbufs + 1)]
47 offsets = [4 * (nbufs + 1)]
50 for buf in buffers[:-1]:
48 for buf in buffers[:-1]:
51 offsets.append(offsets[-1] + len(buf))
49 offsets.append(offsets[-1] + len(buf))
52 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
50 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
53 buffers.insert(0, offsets_buf)
51 buffers.insert(0, offsets_buf)
54 return b''.join(buffers)
52 return b''.join(buffers)
55
53
56
54
57 def deserialize_binary_message(bmsg):
55 def deserialize_binary_message(bmsg):
58 """deserialize a message from a binary blog
56 """deserialize a message from a binary blog
59
57
60 Header:
58 Header:
61
59
62 4 bytes: number of msg parts (nbufs) as 32b int
60 4 bytes: number of msg parts (nbufs) as 32b int
63 4 * nbufs bytes: offset for each buffer as integer as 32b int
61 4 * nbufs bytes: offset for each buffer as integer as 32b int
64
62
65 Offsets are from the start of the buffer, including the header.
63 Offsets are from the start of the buffer, including the header.
66
64
67 Returns
65 Returns
68 -------
66 -------
69
67
70 message dictionary
68 message dictionary
71 """
69 """
72 nbufs = struct.unpack('!i', bmsg[:4])[0]
70 nbufs = struct.unpack('!i', bmsg[:4])[0]
73 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
71 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
74 offsets.append(None)
72 offsets.append(None)
75 bufs = []
73 bufs = []
76 for start, stop in zip(offsets[:-1], offsets[1:]):
74 for start, stop in zip(offsets[:-1], offsets[1:]):
77 bufs.append(bmsg[start:stop])
75 bufs.append(bmsg[start:stop])
78 msg = json.loads(bufs[0].decode('utf8'))
76 msg = json.loads(bufs[0].decode('utf8'))
79 msg['header'] = extract_dates(msg['header'])
77 msg['header'] = extract_dates(msg['header'])
80 msg['parent_header'] = extract_dates(msg['parent_header'])
78 msg['parent_header'] = extract_dates(msg['parent_header'])
81 msg['buffers'] = bufs[1:]
79 msg['buffers'] = bufs[1:]
82 return msg
80 return msg
83
81
84
82
85 class ZMQStreamHandler(websocket.WebSocketHandler):
83 class ZMQStreamHandler(websocket.WebSocketHandler):
86
84
87 def check_origin(self, origin):
85 def check_origin(self, origin):
88 """Check Origin == Host or Access-Control-Allow-Origin.
86 """Check Origin == Host or Access-Control-Allow-Origin.
89
87
90 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
88 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
91 We call it explicitly in `open` on Tornado < 4.
89 We call it explicitly in `open` on Tornado < 4.
92 """
90 """
93 if self.allow_origin == '*':
91 if self.allow_origin == '*':
94 return True
92 return True
95
93
96 host = self.request.headers.get("Host")
94 host = self.request.headers.get("Host")
97
95
98 # If no header is provided, assume we can't verify origin
96 # If no header is provided, assume we can't verify origin
99 if origin is None:
97 if origin is None:
100 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
98 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
101 return False
99 return False
102 if host is None:
100 if host is None:
103 self.log.warn("Missing Host header, rejecting WebSocket connection.")
101 self.log.warn("Missing Host header, rejecting WebSocket connection.")
104 return False
102 return False
105
103
106 origin = origin.lower()
104 origin = origin.lower()
107 origin_host = urlparse(origin).netloc
105 origin_host = urlparse(origin).netloc
108
106
109 # OK if origin matches host
107 # OK if origin matches host
110 if origin_host == host:
108 if origin_host == host:
111 return True
109 return True
112
110
113 # Check CORS headers
111 # Check CORS headers
114 if self.allow_origin:
112 if self.allow_origin:
115 allow = self.allow_origin == origin
113 allow = self.allow_origin == origin
116 elif self.allow_origin_pat:
114 elif self.allow_origin_pat:
117 allow = bool(self.allow_origin_pat.match(origin))
115 allow = bool(self.allow_origin_pat.match(origin))
118 else:
116 else:
119 # No CORS headers deny the request
117 # No CORS headers deny the request
120 allow = False
118 allow = False
121 if not allow:
119 if not allow:
122 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
120 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
123 origin, host,
121 origin, host,
124 )
122 )
125 return allow
123 return allow
126
124
127 def clear_cookie(self, *args, **kwargs):
125 def clear_cookie(self, *args, **kwargs):
128 """meaningless for websockets"""
126 """meaningless for websockets"""
129 pass
127 pass
130
128
131 def _reserialize_reply(self, msg_list):
129 def _reserialize_reply(self, msg_list):
132 """Reserialize a reply message using JSON.
130 """Reserialize a reply message using JSON.
133
131
134 This takes the msg list from the ZMQ socket, deserializes it using
132 This takes the msg list from the ZMQ socket, deserializes it using
135 self.session and then serializes the result using JSON. This method
133 self.session and then serializes the result using JSON. This method
136 should be used by self._on_zmq_reply to build messages that can
134 should be used by self._on_zmq_reply to build messages that can
137 be sent back to the browser.
135 be sent back to the browser.
138 """
136 """
139 idents, msg_list = self.session.feed_identities(msg_list)
137 idents, msg_list = self.session.feed_identities(msg_list)
140 msg = self.session.deserialize(msg_list)
138 msg = self.session.deserialize(msg_list)
141 if msg['buffers']:
139 if msg['buffers']:
142 buf = serialize_binary_message(msg)
140 buf = serialize_binary_message(msg)
143 return buf
141 return buf
144 else:
142 else:
145 smsg = json.dumps(msg, default=date_default)
143 smsg = json.dumps(msg, default=date_default)
146 return cast_unicode(smsg)
144 return cast_unicode(smsg)
147
145
148 def _on_zmq_reply(self, msg_list):
146 def _on_zmq_reply(self, msg_list):
149 # Sometimes this gets triggered when the on_close method is scheduled in the
147 # Sometimes this gets triggered when the on_close method is scheduled in the
150 # eventloop but hasn't been called.
148 # eventloop but hasn't been called.
151 if self.stream.closed(): return
149 if self.stream.closed(): return
152 try:
150 try:
153 msg = self._reserialize_reply(msg_list)
151 msg = self._reserialize_reply(msg_list)
154 except Exception:
152 except Exception:
155 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
153 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
156 else:
154 else:
157 self.write_message(msg, binary=isinstance(msg, bytes))
155 self.write_message(msg, binary=isinstance(msg, bytes))
158
156
159 def allow_draft76(self):
157 def allow_draft76(self):
160 """Allow draft 76, until browsers such as Safari update to RFC 6455.
158 """Allow draft 76, until browsers such as Safari update to RFC 6455.
161
159
162 This has been disabled by default in tornado in release 2.2.0, and
160 This has been disabled by default in tornado in release 2.2.0, and
163 support will be removed in later versions.
161 support will be removed in later versions.
164 """
162 """
165 return True
163 return True
166
164
167 # ping interval for keeping websockets alive (30 seconds)
165 # ping interval for keeping websockets alive (30 seconds)
168 WS_PING_INTERVAL = 30000
166 WS_PING_INTERVAL = 30000
169
167
170 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
168 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
171 ping_callback = None
169 ping_callback = None
172 last_ping = 0
170 last_ping = 0
173 last_pong = 0
171 last_pong = 0
174
172
175 @property
173 @property
176 def ping_interval(self):
174 def ping_interval(self):
177 """The interval for websocket keep-alive pings.
175 """The interval for websocket keep-alive pings.
178
176
179 Set ws_ping_interval = 0 to disable pings.
177 Set ws_ping_interval = 0 to disable pings.
180 """
178 """
181 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
179 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
182
180
183 @property
181 @property
184 def ping_timeout(self):
182 def ping_timeout(self):
185 """If no ping is received in this many milliseconds,
183 """If no ping is received in this many milliseconds,
186 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
184 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
187 Default is max of 3 pings or 30 seconds.
185 Default is max of 3 pings or 30 seconds.
188 """
186 """
189 return self.settings.get('ws_ping_timeout',
187 return self.settings.get('ws_ping_timeout',
190 max(3 * self.ping_interval, WS_PING_INTERVAL)
188 max(3 * self.ping_interval, WS_PING_INTERVAL)
191 )
189 )
192
190
193 def set_default_headers(self):
191 def set_default_headers(self):
194 """Undo the set_default_headers in IPythonHandler
192 """Undo the set_default_headers in IPythonHandler
195
193
196 which doesn't make sense for websockets
194 which doesn't make sense for websockets
197 """
195 """
198 pass
196 pass
199
197
200 def get(self, *args, **kwargs):
198 def pre_get(self):
199 """Run before finishing the GET request
200
201 Extend this method to add logic that should fire before
202 the websocket finishes completing.
203 """
201 # Check to see that origin matches host directly, including ports
204 # Check to see that origin matches host directly, including ports
202 # Tornado 4 already does CORS checking
205 # Tornado 4 already does CORS checking
203 if tornado.version_info[0] < 4:
206 if tornado.version_info[0] < 4:
204 if not self.check_origin(self.get_origin()):
207 if not self.check_origin(self.get_origin()):
205 raise web.HTTPError(403)
208 raise web.HTTPError(403)
206
209
207 # authenticate the request before opening the websocket
210 # authenticate the request before opening the websocket
208 if self.get_current_user() is None:
211 if self.get_current_user() is None:
209 self.log.warn("Couldn't authenticate WebSocket connection")
212 self.log.warn("Couldn't authenticate WebSocket connection")
210 raise web.HTTPError(403)
213 raise web.HTTPError(403)
211
214
212 if self.get_argument('session_id', False):
215 if self.get_argument('session_id', False):
213 self.session.session = cast_unicode(self.get_argument('session_id'))
216 self.session.session = cast_unicode(self.get_argument('session_id'))
214 else:
217 else:
215 self.log.warn("No session ID specified")
218 self.log.warn("No session ID specified")
219
220 @gen.coroutine
221 def get(self, *args, **kwargs):
222 # pre_get can be a coroutine in subclasses
223 yield gen.maybe_future(self.pre_get())
216 # FIXME: only do super get on tornado β‰₯ 4
224 # FIXME: only do super get on tornado β‰₯ 4
217 # tornado 3 has no get, will raise 405
225 # tornado 3 has no get, will raise 405
218 if tornado.version_info >= (4,):
226 if tornado.version_info >= (4,):
219 return super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
227 super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
220
228
221 def initialize(self):
229 def initialize(self):
230 self.log.debug("Initializing websocket connection %s", self.request.path)
222 self.session = Session(config=self.config)
231 self.session = Session(config=self.config)
223
232
224 def open(self, *args, **kwargs):
233 def open(self, *args, **kwargs):
234 self.log.debug("Opening websocket %s", self.request.path)
225 if tornado.version_info < (4,):
235 if tornado.version_info < (4,):
226 try:
236 try:
227 self.get(*self.open_args, **self.open_kwargs)
237 self.get(*self.open_args, **self.open_kwargs)
228 except web.HTTPError:
238 except web.HTTPError:
229 self.close()
239 self.close()
230 raise
240 raise
231
241
232 # start the pinging
242 # start the pinging
233 if self.ping_interval > 0:
243 if self.ping_interval > 0:
234 self.last_ping = ioloop.IOLoop.instance().time() # Remember time of last ping
244 self.last_ping = ioloop.IOLoop.instance().time() # Remember time of last ping
235 self.last_pong = self.last_ping
245 self.last_pong = self.last_ping
236 self.ping_callback = ioloop.PeriodicCallback(self.send_ping, self.ping_interval)
246 self.ping_callback = ioloop.PeriodicCallback(self.send_ping, self.ping_interval)
237 self.ping_callback.start()
247 self.ping_callback.start()
238
248
239 def send_ping(self):
249 def send_ping(self):
240 """send a ping to keep the websocket alive"""
250 """send a ping to keep the websocket alive"""
241 if self.stream.closed() and self.ping_callback is not None:
251 if self.stream.closed() and self.ping_callback is not None:
242 self.ping_callback.stop()
252 self.ping_callback.stop()
243 return
253 return
244
254
245 # check for timeout on pong. Make sure that we really have sent a recent ping in
255 # check for timeout on pong. Make sure that we really have sent a recent ping in
246 # case the machine with both server and client has been suspended since the last ping.
256 # case the machine with both server and client has been suspended since the last ping.
247 now = ioloop.IOLoop.instance().time()
257 now = ioloop.IOLoop.instance().time()
248 since_last_pong = 1e3 * (now - self.last_pong)
258 since_last_pong = 1e3 * (now - self.last_pong)
249 since_last_ping = 1e3 * (now - self.last_ping)
259 since_last_ping = 1e3 * (now - self.last_ping)
250 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
260 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
251 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
261 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
252 self.close()
262 self.close()
253 return
263 return
254
264
255 self.ping(b'')
265 self.ping(b'')
256 self.last_ping = now
266 self.last_ping = now
257
267
258 def on_pong(self, data):
268 def on_pong(self, data):
259 self.last_pong = ioloop.IOLoop.instance().time()
269 self.last_pong = ioloop.IOLoop.instance().time()
@@ -1,269 +1,294
1 """Tornado handlers for kernels."""
1 """Tornado handlers for kernels."""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 import json
6 import json
7 import logging
7 import logging
8 from tornado import gen, web
8 from tornado import gen, web
9 from tornado.concurrent import Future
9 from tornado.concurrent import Future
10 from tornado.ioloop import IOLoop
10
11
11 from IPython.utils.jsonutil import date_default
12 from IPython.utils.jsonutil import date_default
12 from IPython.utils.py3compat import cast_unicode
13 from IPython.utils.py3compat import cast_unicode
13 from IPython.html.utils import url_path_join, url_escape
14 from IPython.html.utils import url_path_join, url_escape
14
15
15 from ...base.handlers import IPythonHandler, json_errors
16 from ...base.handlers import IPythonHandler, json_errors
16 from ...base.zmqhandlers import AuthenticatedZMQStreamHandler, deserialize_binary_message
17 from ...base.zmqhandlers import AuthenticatedZMQStreamHandler, deserialize_binary_message
17
18
18 from IPython.core.release import kernel_protocol_version
19 from IPython.core.release import kernel_protocol_version
19
20
20 class MainKernelHandler(IPythonHandler):
21 class MainKernelHandler(IPythonHandler):
21
22
22 @web.authenticated
23 @web.authenticated
23 @json_errors
24 @json_errors
24 def get(self):
25 def get(self):
25 km = self.kernel_manager
26 km = self.kernel_manager
26 self.finish(json.dumps(km.list_kernels()))
27 self.finish(json.dumps(km.list_kernels()))
27
28
28 @web.authenticated
29 @web.authenticated
29 @json_errors
30 @json_errors
30 def post(self):
31 def post(self):
31 km = self.kernel_manager
32 km = self.kernel_manager
32 model = self.get_json_body()
33 model = self.get_json_body()
33 if model is None:
34 if model is None:
34 model = {
35 model = {
35 'name': km.default_kernel_name
36 'name': km.default_kernel_name
36 }
37 }
37 else:
38 else:
38 model.setdefault('name', km.default_kernel_name)
39 model.setdefault('name', km.default_kernel_name)
39
40
40 kernel_id = km.start_kernel(kernel_name=model['name'])
41 kernel_id = km.start_kernel(kernel_name=model['name'])
41 model = km.kernel_model(kernel_id)
42 model = km.kernel_model(kernel_id)
42 location = url_path_join(self.base_url, 'api', 'kernels', kernel_id)
43 location = url_path_join(self.base_url, 'api', 'kernels', kernel_id)
43 self.set_header('Location', url_escape(location))
44 self.set_header('Location', url_escape(location))
44 self.set_status(201)
45 self.set_status(201)
45 self.finish(json.dumps(model))
46 self.finish(json.dumps(model))
46
47
47
48
48 class KernelHandler(IPythonHandler):
49 class KernelHandler(IPythonHandler):
49
50
50 SUPPORTED_METHODS = ('DELETE', 'GET')
51 SUPPORTED_METHODS = ('DELETE', 'GET')
51
52
52 @web.authenticated
53 @web.authenticated
53 @json_errors
54 @json_errors
54 def get(self, kernel_id):
55 def get(self, kernel_id):
55 km = self.kernel_manager
56 km = self.kernel_manager
56 km._check_kernel_id(kernel_id)
57 km._check_kernel_id(kernel_id)
57 model = km.kernel_model(kernel_id)
58 model = km.kernel_model(kernel_id)
58 self.finish(json.dumps(model))
59 self.finish(json.dumps(model))
59
60
60 @web.authenticated
61 @web.authenticated
61 @json_errors
62 @json_errors
62 def delete(self, kernel_id):
63 def delete(self, kernel_id):
63 km = self.kernel_manager
64 km = self.kernel_manager
64 km.shutdown_kernel(kernel_id)
65 km.shutdown_kernel(kernel_id)
65 self.set_status(204)
66 self.set_status(204)
66 self.finish()
67 self.finish()
67
68
68
69
69 class KernelActionHandler(IPythonHandler):
70 class KernelActionHandler(IPythonHandler):
70
71
71 @web.authenticated
72 @web.authenticated
72 @json_errors
73 @json_errors
73 def post(self, kernel_id, action):
74 def post(self, kernel_id, action):
74 km = self.kernel_manager
75 km = self.kernel_manager
75 if action == 'interrupt':
76 if action == 'interrupt':
76 km.interrupt_kernel(kernel_id)
77 km.interrupt_kernel(kernel_id)
77 self.set_status(204)
78 self.set_status(204)
78 if action == 'restart':
79 if action == 'restart':
79 km.restart_kernel(kernel_id)
80 km.restart_kernel(kernel_id)
80 model = km.kernel_model(kernel_id)
81 model = km.kernel_model(kernel_id)
81 self.set_header('Location', '{0}api/kernels/{1}'.format(self.base_url, kernel_id))
82 self.set_header('Location', '{0}api/kernels/{1}'.format(self.base_url, kernel_id))
82 self.write(json.dumps(model))
83 self.write(json.dumps(model))
83 self.finish()
84 self.finish()
84
85
85
86
86 class ZMQChannelHandler(AuthenticatedZMQStreamHandler):
87 class ZMQChannelHandler(AuthenticatedZMQStreamHandler):
87
88
89 @property
90 def kernel_info_timeout(self):
91 return self.settings.get('kernel_info_timeout', 10)
92
88 def __repr__(self):
93 def __repr__(self):
89 return "%s(%s)" % (self.__class__.__name__, getattr(self, 'kernel_id', 'uninitialized'))
94 return "%s(%s)" % (self.__class__.__name__, getattr(self, 'kernel_id', 'uninitialized'))
90
95
91 def create_stream(self):
96 def create_stream(self):
92 km = self.kernel_manager
97 km = self.kernel_manager
93 meth = getattr(km, 'connect_%s' % self.channel)
98 meth = getattr(km, 'connect_%s' % self.channel)
94 self.zmq_stream = meth(self.kernel_id, identity=self.session.bsession)
99 self.zmq_stream = meth(self.kernel_id, identity=self.session.bsession)
95
100
96 def request_kernel_info(self):
101 def request_kernel_info(self):
97 """send a request for kernel_info"""
102 """send a request for kernel_info"""
98 km = self.kernel_manager
103 km = self.kernel_manager
99 kernel = km.get_kernel(self.kernel_id)
104 kernel = km.get_kernel(self.kernel_id)
100 try:
105 try:
101 # check for cached value
106 # check for cached value
102 kernel_info = kernel._kernel_info
107 kernel_info = kernel._kernel_info
103 except AttributeError:
108 except AttributeError:
104 self.log.debug("Requesting kernel info from %s", self.kernel_id)
109 self.log.debug("Requesting kernel info from %s", self.kernel_id)
105 # Create a kernel_info channel to query the kernel protocol version.
110 # Create a kernel_info channel to query the kernel protocol version.
106 # This channel will be closed after the kernel_info reply is received.
111 # This channel will be closed after the kernel_info reply is received.
107 if self.kernel_info_channel is None:
112 if self.kernel_info_channel is None:
108 self.kernel_info_channel = km.connect_shell(self.kernel_id)
113 self.kernel_info_channel = km.connect_shell(self.kernel_id)
109 self.kernel_info_channel.on_recv(self._handle_kernel_info_reply)
114 self.kernel_info_channel.on_recv(self._handle_kernel_info_reply)
110 self.session.send(self.kernel_info_channel, "kernel_info_request")
115 self.session.send(self.kernel_info_channel, "kernel_info_request")
111 else:
116 else:
112 # use cached value, don't resend request
117 # use cached value, don't resend request
113 self._finish_kernel_info(kernel_info)
118 self._finish_kernel_info(kernel_info)
114 return self._kernel_info_future
119 return self._kernel_info_future
115
120
116 def _handle_kernel_info_reply(self, msg):
121 def _handle_kernel_info_reply(self, msg):
117 """process the kernel_info_reply
122 """process the kernel_info_reply
118
123
119 enabling msg spec adaptation, if necessary
124 enabling msg spec adaptation, if necessary
120 """
125 """
121 idents,msg = self.session.feed_identities(msg)
126 idents,msg = self.session.feed_identities(msg)
122 try:
127 try:
123 msg = self.session.deserialize(msg)
128 msg = self.session.deserialize(msg)
124 except:
129 except:
125 self.log.error("Bad kernel_info reply", exc_info=True)
130 self.log.error("Bad kernel_info reply", exc_info=True)
126 self._kernel_info_future.set_result(None)
131 self._kernel_info_future.set_result(None)
127 return
132 return
128 else:
133 else:
129 info = msg['content']
134 info = msg['content']
130 self.log.debug("Received kernel info: %s", info)
135 self.log.debug("Received kernel info: %s", info)
131 if msg['msg_type'] != 'kernel_info_reply' or 'protocol_version' not in info:
136 if msg['msg_type'] != 'kernel_info_reply' or 'protocol_version' not in info:
132 self.log.error("Kernel info request failed, assuming current %s", info)
137 self.log.error("Kernel info request failed, assuming current %s", info)
133 else:
138 else:
134 kernel = self.kernel_manager.get_kernel(self.kernel_id)
139 kernel = self.kernel_manager.get_kernel(self.kernel_id)
135 kernel._kernel_info = info
140 kernel._kernel_info = info
136 self._finish_kernel_info(info)
141 self._finish_kernel_info(info)
137
142
138 # close the kernel_info channel, we don't need it anymore
143 # close the kernel_info channel, we don't need it anymore
139 if self.kernel_info_channel:
144 if self.kernel_info_channel:
140 self.kernel_info_channel.close()
145 self.kernel_info_channel.close()
141 self.kernel_info_channel = None
146 self.kernel_info_channel = None
142
147
143 def _finish_kernel_info(self, info):
148 def _finish_kernel_info(self, info):
144 """Finish handling kernel_info reply
149 """Finish handling kernel_info reply
145
150
146 Set up protocol adaptation, if needed,
151 Set up protocol adaptation, if needed,
147 and signal that connection can continue.
152 and signal that connection can continue.
148 """
153 """
149 protocol_version = info.get('protocol_version', kernel_protocol_version)
154 protocol_version = info.get('protocol_version', kernel_protocol_version)
150 if protocol_version != kernel_protocol_version:
155 if protocol_version != kernel_protocol_version:
151 self.session.adapt_version = int(protocol_version.split('.')[0])
156 self.session.adapt_version = int(protocol_version.split('.')[0])
152 self.log.info("Kernel %s speaks protocol %s", self.kernel_id, protocol_version)
157 self.log.info("Kernel %s speaks protocol %s", self.kernel_id, protocol_version)
158 if not self._kernel_info_future.done():
153 self._kernel_info_future.set_result(info)
159 self._kernel_info_future.set_result(info)
154
160
155 def initialize(self):
161 def initialize(self):
156 super(ZMQChannelHandler, self).initialize()
162 super(ZMQChannelHandler, self).initialize()
157 self.zmq_stream = None
163 self.zmq_stream = None
158 self.kernel_id = None
164 self.kernel_id = None
159 self.kernel_info_channel = None
165 self.kernel_info_channel = None
160 self._kernel_info_future = Future()
166 self._kernel_info_future = Future()
161
167
162 @gen.coroutine
168 @gen.coroutine
169 def pre_get(self):
170 # authenticate first
171 super(ZMQChannelHandler, self).pre_get()
172 # then request kernel info, waiting up to a certain time before giving up.
173 # We don't want to wait forever, because browsers don't take it well when
174 # servers never respond to websocket connection requests.
175 future = self.request_kernel_info()
176
177 def give_up():
178 """Don't wait forever for the kernel to reply"""
179 if future.done():
180 return
181 self.log.warn("Timeout waiting for kernel_info reply from %s", self.kernel_id)
182 future.set_result(None)
183 loop = IOLoop.current()
184 loop.add_timeout(loop.time() + self.kernel_info_timeout, give_up)
185 # actually wait for it
186 yield future
187
188 @gen.coroutine
163 def get(self, kernel_id):
189 def get(self, kernel_id):
164 self.kernel_id = cast_unicode(kernel_id, 'ascii')
190 self.kernel_id = cast_unicode(kernel_id, 'ascii')
165 yield self.request_kernel_info()
191 yield super(ZMQChannelHandler, self).get(kernel_id=kernel_id)
166 super(ZMQChannelHandler, self).get(kernel_id)
167
192
168 def open(self, kernel_id):
193 def open(self, kernel_id):
169 super(ZMQChannelHandler, self).open()
194 super(ZMQChannelHandler, self).open()
170 try:
195 try:
171 self.create_stream()
196 self.create_stream()
172 except web.HTTPError as e:
197 except web.HTTPError as e:
173 self.log.error("Error opening stream: %s", e)
198 self.log.error("Error opening stream: %s", e)
174 # WebSockets don't response to traditional error codes so we
199 # WebSockets don't response to traditional error codes so we
175 # close the connection.
200 # close the connection.
176 if not self.stream.closed():
201 if not self.stream.closed():
177 self.stream.close()
202 self.stream.close()
178 self.close()
203 self.close()
179 else:
204 else:
180 self.zmq_stream.on_recv(self._on_zmq_reply)
205 self.zmq_stream.on_recv(self._on_zmq_reply)
181
206
182 def on_message(self, msg):
207 def on_message(self, msg):
183 if self.zmq_stream is None:
208 if self.zmq_stream is None:
184 return
209 return
185 elif self.zmq_stream.closed():
210 elif self.zmq_stream.closed():
186 self.log.info("%s closed, closing websocket.", self)
211 self.log.info("%s closed, closing websocket.", self)
187 self.close()
212 self.close()
188 return
213 return
189 if isinstance(msg, bytes):
214 if isinstance(msg, bytes):
190 msg = deserialize_binary_message(msg)
215 msg = deserialize_binary_message(msg)
191 else:
216 else:
192 msg = json.loads(msg)
217 msg = json.loads(msg)
193 self.session.send(self.zmq_stream, msg)
218 self.session.send(self.zmq_stream, msg)
194
219
195 def on_close(self):
220 def on_close(self):
196 # This method can be called twice, once by self.kernel_died and once
221 # This method can be called twice, once by self.kernel_died and once
197 # from the WebSocket close event. If the WebSocket connection is
222 # from the WebSocket close event. If the WebSocket connection is
198 # closed before the ZMQ streams are setup, they could be None.
223 # closed before the ZMQ streams are setup, they could be None.
199 if self.zmq_stream is not None and not self.zmq_stream.closed():
224 if self.zmq_stream is not None and not self.zmq_stream.closed():
200 self.zmq_stream.on_recv(None)
225 self.zmq_stream.on_recv(None)
201 # close the socket directly, don't wait for the stream
226 # close the socket directly, don't wait for the stream
202 socket = self.zmq_stream.socket
227 socket = self.zmq_stream.socket
203 self.zmq_stream.close()
228 self.zmq_stream.close()
204 socket.close()
229 socket.close()
205
230
206
231
207 class IOPubHandler(ZMQChannelHandler):
232 class IOPubHandler(ZMQChannelHandler):
208 channel = 'iopub'
233 channel = 'iopub'
209
234
210 def create_stream(self):
235 def create_stream(self):
211 super(IOPubHandler, self).create_stream()
236 super(IOPubHandler, self).create_stream()
212 km = self.kernel_manager
237 km = self.kernel_manager
213 km.add_restart_callback(self.kernel_id, self.on_kernel_restarted)
238 km.add_restart_callback(self.kernel_id, self.on_kernel_restarted)
214 km.add_restart_callback(self.kernel_id, self.on_restart_failed, 'dead')
239 km.add_restart_callback(self.kernel_id, self.on_restart_failed, 'dead')
215
240
216 def on_close(self):
241 def on_close(self):
217 km = self.kernel_manager
242 km = self.kernel_manager
218 if self.kernel_id in km:
243 if self.kernel_id in km:
219 km.remove_restart_callback(
244 km.remove_restart_callback(
220 self.kernel_id, self.on_kernel_restarted,
245 self.kernel_id, self.on_kernel_restarted,
221 )
246 )
222 km.remove_restart_callback(
247 km.remove_restart_callback(
223 self.kernel_id, self.on_restart_failed, 'dead',
248 self.kernel_id, self.on_restart_failed, 'dead',
224 )
249 )
225 super(IOPubHandler, self).on_close()
250 super(IOPubHandler, self).on_close()
226
251
227 def _send_status_message(self, status):
252 def _send_status_message(self, status):
228 msg = self.session.msg("status",
253 msg = self.session.msg("status",
229 {'execution_state': status}
254 {'execution_state': status}
230 )
255 )
231 self.write_message(json.dumps(msg, default=date_default))
256 self.write_message(json.dumps(msg, default=date_default))
232
257
233 def on_kernel_restarted(self):
258 def on_kernel_restarted(self):
234 logging.warn("kernel %s restarted", self.kernel_id)
259 logging.warn("kernel %s restarted", self.kernel_id)
235 self._send_status_message('restarting')
260 self._send_status_message('restarting')
236
261
237 def on_restart_failed(self):
262 def on_restart_failed(self):
238 logging.error("kernel %s restarted failed!", self.kernel_id)
263 logging.error("kernel %s restarted failed!", self.kernel_id)
239 self._send_status_message('dead')
264 self._send_status_message('dead')
240
265
241 def on_message(self, msg):
266 def on_message(self, msg):
242 """IOPub messages make no sense"""
267 """IOPub messages make no sense"""
243 pass
268 pass
244
269
245
270
246 class ShellHandler(ZMQChannelHandler):
271 class ShellHandler(ZMQChannelHandler):
247 channel = 'shell'
272 channel = 'shell'
248
273
249
274
250 class StdinHandler(ZMQChannelHandler):
275 class StdinHandler(ZMQChannelHandler):
251 channel = 'stdin'
276 channel = 'stdin'
252
277
253
278
254 #-----------------------------------------------------------------------------
279 #-----------------------------------------------------------------------------
255 # URL to handler mappings
280 # URL to handler mappings
256 #-----------------------------------------------------------------------------
281 #-----------------------------------------------------------------------------
257
282
258
283
259 _kernel_id_regex = r"(?P<kernel_id>\w+-\w+-\w+-\w+-\w+)"
284 _kernel_id_regex = r"(?P<kernel_id>\w+-\w+-\w+-\w+-\w+)"
260 _kernel_action_regex = r"(?P<action>restart|interrupt)"
285 _kernel_action_regex = r"(?P<action>restart|interrupt)"
261
286
262 default_handlers = [
287 default_handlers = [
263 (r"/api/kernels", MainKernelHandler),
288 (r"/api/kernels", MainKernelHandler),
264 (r"/api/kernels/%s" % _kernel_id_regex, KernelHandler),
289 (r"/api/kernels/%s" % _kernel_id_regex, KernelHandler),
265 (r"/api/kernels/%s/%s" % (_kernel_id_regex, _kernel_action_regex), KernelActionHandler),
290 (r"/api/kernels/%s/%s" % (_kernel_id_regex, _kernel_action_regex), KernelActionHandler),
266 (r"/api/kernels/%s/iopub" % _kernel_id_regex, IOPubHandler),
291 (r"/api/kernels/%s/iopub" % _kernel_id_regex, IOPubHandler),
267 (r"/api/kernels/%s/shell" % _kernel_id_regex, ShellHandler),
292 (r"/api/kernels/%s/shell" % _kernel_id_regex, ShellHandler),
268 (r"/api/kernels/%s/stdin" % _kernel_id_regex, StdinHandler)
293 (r"/api/kernels/%s/stdin" % _kernel_id_regex, StdinHandler)
269 ]
294 ]
General Comments 0
You need to be logged in to leave comments. Login now