##// END OF EJS Templates
Get pre_get to work and make session logs when adapter changes
Matthias Bussonnier -
Show More
@@ -1,269 +1,271 b''
1 # coding: utf-8
1 # coding: utf-8
2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 import json
7 import json
8 import struct
8 import struct
9
9
10 try:
10 try:
11 from urllib.parse import urlparse # Py 3
11 from urllib.parse import urlparse # Py 3
12 except ImportError:
12 except ImportError:
13 from urlparse import urlparse # Py 2
13 from urlparse import urlparse # Py 2
14
14
15 import tornado
15 import tornado
16 from tornado import gen, ioloop, web, websocket
16 from tornado import gen, ioloop, web, websocket
17
17
18 from IPython.kernel.zmq.session import Session
18 from IPython.kernel.zmq.session import Session
19 from IPython.utils.jsonutil import date_default, extract_dates
19 from IPython.utils.jsonutil import date_default, extract_dates
20 from IPython.utils.py3compat import cast_unicode
20 from IPython.utils.py3compat import cast_unicode
21
21
22 from .handlers import IPythonHandler
22 from .handlers import IPythonHandler
23
23
24
24
25 def serialize_binary_message(msg):
25 def serialize_binary_message(msg):
26 """serialize a message as a binary blob
26 """serialize a message as a binary blob
27
27
28 Header:
28 Header:
29
29
30 4 bytes: number of msg parts (nbufs) as 32b int
30 4 bytes: number of msg parts (nbufs) as 32b int
31 4 * nbufs bytes: offset for each buffer as integer as 32b int
31 4 * nbufs bytes: offset for each buffer as integer as 32b int
32
32
33 Offsets are from the start of the buffer, including the header.
33 Offsets are from the start of the buffer, including the header.
34
34
35 Returns
35 Returns
36 -------
36 -------
37
37
38 The message serialized to bytes.
38 The message serialized to bytes.
39
39
40 """
40 """
41 # don't modify msg or buffer list in-place
41 # don't modify msg or buffer list in-place
42 msg = msg.copy()
42 msg = msg.copy()
43 buffers = list(msg.pop('buffers'))
43 buffers = list(msg.pop('buffers'))
44 bmsg = json.dumps(msg, default=date_default).encode('utf8')
44 bmsg = json.dumps(msg, default=date_default).encode('utf8')
45 buffers.insert(0, bmsg)
45 buffers.insert(0, bmsg)
46 nbufs = len(buffers)
46 nbufs = len(buffers)
47 offsets = [4 * (nbufs + 1)]
47 offsets = [4 * (nbufs + 1)]
48 for buf in buffers[:-1]:
48 for buf in buffers[:-1]:
49 offsets.append(offsets[-1] + len(buf))
49 offsets.append(offsets[-1] + len(buf))
50 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
50 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
51 buffers.insert(0, offsets_buf)
51 buffers.insert(0, offsets_buf)
52 return b''.join(buffers)
52 return b''.join(buffers)
53
53
54
54
55 def deserialize_binary_message(bmsg):
55 def deserialize_binary_message(bmsg):
56 """deserialize a message from a binary blog
56 """deserialize a message from a binary blog
57
57
58 Header:
58 Header:
59
59
60 4 bytes: number of msg parts (nbufs) as 32b int
60 4 bytes: number of msg parts (nbufs) as 32b int
61 4 * nbufs bytes: offset for each buffer as integer as 32b int
61 4 * nbufs bytes: offset for each buffer as integer as 32b int
62
62
63 Offsets are from the start of the buffer, including the header.
63 Offsets are from the start of the buffer, including the header.
64
64
65 Returns
65 Returns
66 -------
66 -------
67
67
68 message dictionary
68 message dictionary
69 """
69 """
70 nbufs = struct.unpack('!i', bmsg[:4])[0]
70 nbufs = struct.unpack('!i', bmsg[:4])[0]
71 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
71 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
72 offsets.append(None)
72 offsets.append(None)
73 bufs = []
73 bufs = []
74 for start, stop in zip(offsets[:-1], offsets[1:]):
74 for start, stop in zip(offsets[:-1], offsets[1:]):
75 bufs.append(bmsg[start:stop])
75 bufs.append(bmsg[start:stop])
76 msg = json.loads(bufs[0].decode('utf8'))
76 msg = json.loads(bufs[0].decode('utf8'))
77 msg['header'] = extract_dates(msg['header'])
77 msg['header'] = extract_dates(msg['header'])
78 msg['parent_header'] = extract_dates(msg['parent_header'])
78 msg['parent_header'] = extract_dates(msg['parent_header'])
79 msg['buffers'] = bufs[1:]
79 msg['buffers'] = bufs[1:]
80 return msg
80 return msg
81
81
82
82
83 class ZMQStreamHandler(websocket.WebSocketHandler):
83 class ZMQStreamHandler(websocket.WebSocketHandler):
84
84
85 def check_origin(self, origin):
85 def check_origin(self, origin):
86 """Check Origin == Host or Access-Control-Allow-Origin.
86 """Check Origin == Host or Access-Control-Allow-Origin.
87
87
88 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
88 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
89 We call it explicitly in `open` on Tornado < 4.
89 We call it explicitly in `open` on Tornado < 4.
90 """
90 """
91 if self.allow_origin == '*':
91 if self.allow_origin == '*':
92 return True
92 return True
93
93
94 host = self.request.headers.get("Host")
94 host = self.request.headers.get("Host")
95
95
96 # If no header is provided, assume we can't verify origin
96 # If no header is provided, assume we can't verify origin
97 if origin is None:
97 if origin is None:
98 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
98 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
99 return False
99 return False
100 if host is None:
100 if host is None:
101 self.log.warn("Missing Host header, rejecting WebSocket connection.")
101 self.log.warn("Missing Host header, rejecting WebSocket connection.")
102 return False
102 return False
103
103
104 origin = origin.lower()
104 origin = origin.lower()
105 origin_host = urlparse(origin).netloc
105 origin_host = urlparse(origin).netloc
106
106
107 # OK if origin matches host
107 # OK if origin matches host
108 if origin_host == host:
108 if origin_host == host:
109 return True
109 return True
110
110
111 # Check CORS headers
111 # Check CORS headers
112 if self.allow_origin:
112 if self.allow_origin:
113 allow = self.allow_origin == origin
113 allow = self.allow_origin == origin
114 elif self.allow_origin_pat:
114 elif self.allow_origin_pat:
115 allow = bool(self.allow_origin_pat.match(origin))
115 allow = bool(self.allow_origin_pat.match(origin))
116 else:
116 else:
117 # No CORS headers deny the request
117 # No CORS headers deny the request
118 allow = False
118 allow = False
119 if not allow:
119 if not allow:
120 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
120 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
121 origin, host,
121 origin, host,
122 )
122 )
123 return allow
123 return allow
124
124
125 def clear_cookie(self, *args, **kwargs):
125 def clear_cookie(self, *args, **kwargs):
126 """meaningless for websockets"""
126 """meaningless for websockets"""
127 pass
127 pass
128
128
129 def _reserialize_reply(self, msg_list):
129 def _reserialize_reply(self, msg_list):
130 """Reserialize a reply message using JSON.
130 """Reserialize a reply message using JSON.
131
131
132 This takes the msg list from the ZMQ socket, deserializes it using
132 This takes the msg list from the ZMQ socket, deserializes it using
133 self.session and then serializes the result using JSON. This method
133 self.session and then serializes the result using JSON. This method
134 should be used by self._on_zmq_reply to build messages that can
134 should be used by self._on_zmq_reply to build messages that can
135 be sent back to the browser.
135 be sent back to the browser.
136 """
136 """
137 idents, msg_list = self.session.feed_identities(msg_list)
137 idents, msg_list = self.session.feed_identities(msg_list)
138 msg = self.session.deserialize(msg_list)
138 msg = self.session.deserialize(msg_list)
139 if msg['buffers']:
139 if msg['buffers']:
140 buf = serialize_binary_message(msg)
140 buf = serialize_binary_message(msg)
141 return buf
141 return buf
142 else:
142 else:
143 smsg = json.dumps(msg, default=date_default)
143 smsg = json.dumps(msg, default=date_default)
144 return cast_unicode(smsg)
144 return cast_unicode(smsg)
145
145
146 def _on_zmq_reply(self, msg_list):
146 def _on_zmq_reply(self, msg_list):
147 # Sometimes this gets triggered when the on_close method is scheduled in the
147 # Sometimes this gets triggered when the on_close method is scheduled in the
148 # eventloop but hasn't been called.
148 # eventloop but hasn't been called.
149 if self.stream.closed(): return
149 if self.stream.closed(): return
150 try:
150 try:
151 msg = self._reserialize_reply(msg_list)
151 msg = self._reserialize_reply(msg_list)
152 except Exception:
152 except Exception:
153 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
153 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
154 else:
154 else:
155 self.write_message(msg, binary=isinstance(msg, bytes))
155 self.write_message(msg, binary=isinstance(msg, bytes))
156
156
157 def allow_draft76(self):
157 def allow_draft76(self):
158 """Allow draft 76, until browsers such as Safari update to RFC 6455.
158 """Allow draft 76, until browsers such as Safari update to RFC 6455.
159
159
160 This has been disabled by default in tornado in release 2.2.0, and
160 This has been disabled by default in tornado in release 2.2.0, and
161 support will be removed in later versions.
161 support will be removed in later versions.
162 """
162 """
163 return True
163 return True
164
164
165 # ping interval for keeping websockets alive (30 seconds)
165 # ping interval for keeping websockets alive (30 seconds)
166 WS_PING_INTERVAL = 30000
166 WS_PING_INTERVAL = 30000
167
167
168 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
168 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
169 ping_callback = None
169 ping_callback = None
170 last_ping = 0
170 last_ping = 0
171 last_pong = 0
171 last_pong = 0
172
172
173 @property
173 @property
174 def ping_interval(self):
174 def ping_interval(self):
175 """The interval for websocket keep-alive pings.
175 """The interval for websocket keep-alive pings.
176
176
177 Set ws_ping_interval = 0 to disable pings.
177 Set ws_ping_interval = 0 to disable pings.
178 """
178 """
179 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
179 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
180
180
181 @property
181 @property
182 def ping_timeout(self):
182 def ping_timeout(self):
183 """If no ping is received in this many milliseconds,
183 """If no ping is received in this many milliseconds,
184 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
184 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
185 Default is max of 3 pings or 30 seconds.
185 Default is max of 3 pings or 30 seconds.
186 """
186 """
187 return self.settings.get('ws_ping_timeout',
187 return self.settings.get('ws_ping_timeout',
188 max(3 * self.ping_interval, WS_PING_INTERVAL)
188 max(3 * self.ping_interval, WS_PING_INTERVAL)
189 )
189 )
190
190
191 def set_default_headers(self):
191 def set_default_headers(self):
192 """Undo the set_default_headers in IPythonHandler
192 """Undo the set_default_headers in IPythonHandler
193
193
194 which doesn't make sense for websockets
194 which doesn't make sense for websockets
195 """
195 """
196 pass
196 pass
197
197
198 def pre_get(self):
198 def pre_get(self):
199 """Run before finishing the GET request
199 """Run before finishing the GET request
200
200
201 Extend this method to add logic that should fire before
201 Extend this method to add logic that should fire before
202 the websocket finishes completing.
202 the websocket finishes completing.
203 """
203 """
204 # Check to see that origin matches host directly, including ports
204 # Check to see that origin matches host directly, including ports
205 # Tornado 4 already does CORS checking
205 # Tornado 4 already does CORS checking
206 if tornado.version_info[0] < 4:
206 if tornado.version_info[0] < 4:
207 if not self.check_origin(self.get_origin()):
207 if not self.check_origin(self.get_origin()):
208 raise web.HTTPError(403)
208 raise web.HTTPError(403)
209
209
210 # authenticate the request before opening the websocket
210 # authenticate the request before opening the websocket
211 if self.get_current_user() is None:
211 if self.get_current_user() is None:
212 self.log.warn("Couldn't authenticate WebSocket connection")
212 self.log.warn("Couldn't authenticate WebSocket connection")
213 raise web.HTTPError(403)
213 raise web.HTTPError(403)
214
214
215 if self.get_argument('session_id', False):
215 if self.get_argument('session_id', False):
216 self.session.session = cast_unicode(self.get_argument('session_id'))
216 self.session.session = cast_unicode(self.get_argument('session_id'))
217 else:
217 else:
218 self.log.warn("No session ID specified")
218 self.log.warn("No session ID specified")
219
219
220 @gen.coroutine
220 @gen.coroutine
221 def get(self, *args, **kwargs):
221 def get(self, *args, **kwargs):
222 # pre_get can be a coroutine in subclasses
222 # pre_get can be a coroutine in subclasses
223 yield gen.maybe_future(self.pre_get())
223 # assign and yield in two step to avoid tornado 3 issues
224 res = self.pre_get()
225 yield gen.maybe_future(res)
224 # FIXME: only do super get on tornado β‰₯ 4
226 # FIXME: only do super get on tornado β‰₯ 4
225 # tornado 3 has no get, will raise 405
227 # tornado 3 has no get, will raise 405
226 if tornado.version_info >= (4,):
228 if tornado.version_info >= (4,):
227 super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
229 super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
228
230
229 def initialize(self):
231 def initialize(self):
230 self.log.debug("Initializing websocket connection %s", self.request.path)
232 self.log.debug("Initializing websocket connection %s", self.request.path)
231 self.session = Session(config=self.config)
233 self.session = Session(config=self.config)
232
234
233 def open(self, *args, **kwargs):
235 def open(self, *args, **kwargs):
234 self.log.debug("Opening websocket %s", self.request.path)
236 self.log.debug("Opening websocket %s", self.request.path)
235 if tornado.version_info < (4,):
237 if tornado.version_info < (4,):
236 try:
238 try:
237 self.get(*self.open_args, **self.open_kwargs)
239 self.get(*self.open_args, **self.open_kwargs)
238 except web.HTTPError:
240 except web.HTTPError:
239 self.close()
241 self.close()
240 raise
242 raise
241
243
242 # start the pinging
244 # start the pinging
243 if self.ping_interval > 0:
245 if self.ping_interval > 0:
244 self.last_ping = ioloop.IOLoop.instance().time() # Remember time of last ping
246 self.last_ping = ioloop.IOLoop.instance().time() # Remember time of last ping
245 self.last_pong = self.last_ping
247 self.last_pong = self.last_ping
246 self.ping_callback = ioloop.PeriodicCallback(self.send_ping, self.ping_interval)
248 self.ping_callback = ioloop.PeriodicCallback(self.send_ping, self.ping_interval)
247 self.ping_callback.start()
249 self.ping_callback.start()
248
250
249 def send_ping(self):
251 def send_ping(self):
250 """send a ping to keep the websocket alive"""
252 """send a ping to keep the websocket alive"""
251 if self.stream.closed() and self.ping_callback is not None:
253 if self.stream.closed() and self.ping_callback is not None:
252 self.ping_callback.stop()
254 self.ping_callback.stop()
253 return
255 return
254
256
255 # check for timeout on pong. Make sure that we really have sent a recent ping in
257 # check for timeout on pong. Make sure that we really have sent a recent ping in
256 # case the machine with both server and client has been suspended since the last ping.
258 # case the machine with both server and client has been suspended since the last ping.
257 now = ioloop.IOLoop.instance().time()
259 now = ioloop.IOLoop.instance().time()
258 since_last_pong = 1e3 * (now - self.last_pong)
260 since_last_pong = 1e3 * (now - self.last_pong)
259 since_last_ping = 1e3 * (now - self.last_ping)
261 since_last_ping = 1e3 * (now - self.last_ping)
260 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
262 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
261 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
263 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
262 self.close()
264 self.close()
263 return
265 return
264
266
265 self.ping(b'')
267 self.ping(b'')
266 self.last_ping = now
268 self.last_ping = now
267
269
268 def on_pong(self, data):
270 def on_pong(self, data):
269 self.last_pong = ioloop.IOLoop.instance().time()
271 self.last_pong = ioloop.IOLoop.instance().time()
General Comments 0
You need to be logged in to leave comments. Login now