##// END OF EJS Templates
Add check to skip work in versions past 3.4
Jason Grout -
Show More
@@ -1,280 +1,281 b''
1 # coding: utf-8
1 # coding: utf-8
2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 import os
7 import os
8 import json
8 import json
9 import struct
9 import struct
10 import warnings
10 import warnings
11 import sys
11
12
12 try:
13 try:
13 from urllib.parse import urlparse # Py 3
14 from urllib.parse import urlparse # Py 3
14 except ImportError:
15 except ImportError:
15 from urlparse import urlparse # Py 2
16 from urlparse import urlparse # Py 2
16
17
17 import tornado
18 import tornado
18 from tornado import gen, ioloop, web
19 from tornado import gen, ioloop, web
19 from tornado.websocket import WebSocketHandler
20 from tornado.websocket import WebSocketHandler
20
21
21 from IPython.kernel.zmq.session import Session
22 from IPython.kernel.zmq.session import Session
22 from IPython.utils.jsonutil import date_default, extract_dates
23 from IPython.utils.jsonutil import date_default, extract_dates
23 from IPython.utils.py3compat import cast_unicode
24 from IPython.utils.py3compat import cast_unicode
24
25
25 from .handlers import IPythonHandler
26 from .handlers import IPythonHandler
26
27
27 def serialize_binary_message(msg):
28 def serialize_binary_message(msg):
28 """serialize a message as a binary blob
29 """serialize a message as a binary blob
29
30
30 Header:
31 Header:
31
32
32 4 bytes: number of msg parts (nbufs) as 32b int
33 4 bytes: number of msg parts (nbufs) as 32b int
33 4 * nbufs bytes: offset for each buffer as integer as 32b int
34 4 * nbufs bytes: offset for each buffer as integer as 32b int
34
35
35 Offsets are from the start of the buffer, including the header.
36 Offsets are from the start of the buffer, including the header.
36
37
37 Returns
38 Returns
38 -------
39 -------
39
40
40 The message serialized to bytes.
41 The message serialized to bytes.
41
42
42 """
43 """
43 # don't modify msg or buffer list in-place
44 # don't modify msg or buffer list in-place
44 msg = msg.copy()
45 msg = msg.copy()
45 buffers = list(msg.pop('buffers'))
46 buffers = list(msg.pop('buffers'))
46 # for python 2, copy the buffer memoryviews to byte strings
47 if sys.version_info < (3, 4):
47 buffers = [x.tobytes() for x in buffers]
48 buffers = [x.tobytes() for x in buffers]
48 bmsg = json.dumps(msg, default=date_default).encode('utf8')
49 bmsg = json.dumps(msg, default=date_default).encode('utf8')
49 buffers.insert(0, bmsg)
50 buffers.insert(0, bmsg)
50 nbufs = len(buffers)
51 nbufs = len(buffers)
51 offsets = [4 * (nbufs + 1)]
52 offsets = [4 * (nbufs + 1)]
52 for buf in buffers[:-1]:
53 for buf in buffers[:-1]:
53 offsets.append(offsets[-1] + len(buf))
54 offsets.append(offsets[-1] + len(buf))
54 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
55 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
55 buffers.insert(0, offsets_buf)
56 buffers.insert(0, offsets_buf)
56 return b''.join(buffers)
57 return b''.join(buffers)
57
58
58
59
59 def deserialize_binary_message(bmsg):
60 def deserialize_binary_message(bmsg):
60 """deserialize a message from a binary blog
61 """deserialize a message from a binary blog
61
62
62 Header:
63 Header:
63
64
64 4 bytes: number of msg parts (nbufs) as 32b int
65 4 bytes: number of msg parts (nbufs) as 32b int
65 4 * nbufs bytes: offset for each buffer as integer as 32b int
66 4 * nbufs bytes: offset for each buffer as integer as 32b int
66
67
67 Offsets are from the start of the buffer, including the header.
68 Offsets are from the start of the buffer, including the header.
68
69
69 Returns
70 Returns
70 -------
71 -------
71
72
72 message dictionary
73 message dictionary
73 """
74 """
74 nbufs = struct.unpack('!i', bmsg[:4])[0]
75 nbufs = struct.unpack('!i', bmsg[:4])[0]
75 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
76 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
76 offsets.append(None)
77 offsets.append(None)
77 bufs = []
78 bufs = []
78 for start, stop in zip(offsets[:-1], offsets[1:]):
79 for start, stop in zip(offsets[:-1], offsets[1:]):
79 bufs.append(bmsg[start:stop])
80 bufs.append(bmsg[start:stop])
80 msg = json.loads(bufs[0].decode('utf8'))
81 msg = json.loads(bufs[0].decode('utf8'))
81 msg['header'] = extract_dates(msg['header'])
82 msg['header'] = extract_dates(msg['header'])
82 msg['parent_header'] = extract_dates(msg['parent_header'])
83 msg['parent_header'] = extract_dates(msg['parent_header'])
83 msg['buffers'] = bufs[1:]
84 msg['buffers'] = bufs[1:]
84 return msg
85 return msg
85
86
86 # ping interval for keeping websockets alive (30 seconds)
87 # ping interval for keeping websockets alive (30 seconds)
87 WS_PING_INTERVAL = 30000
88 WS_PING_INTERVAL = 30000
88
89
89 if os.environ.get('IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS', False):
90 if os.environ.get('IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS', False):
90 warnings.warn("""Allowing draft76 websocket connections!
91 warnings.warn("""Allowing draft76 websocket connections!
91 This should only be done for testing with phantomjs!""")
92 This should only be done for testing with phantomjs!""")
92 from IPython.html import allow76
93 from IPython.html import allow76
93 WebSocketHandler = allow76.AllowDraftWebSocketHandler
94 WebSocketHandler = allow76.AllowDraftWebSocketHandler
94 # draft 76 doesn't support ping
95 # draft 76 doesn't support ping
95 WS_PING_INTERVAL = 0
96 WS_PING_INTERVAL = 0
96
97
97 class ZMQStreamHandler(WebSocketHandler):
98 class ZMQStreamHandler(WebSocketHandler):
98
99
99 if tornado.version_info < (4,1):
100 if tornado.version_info < (4,1):
100 """Backport send_error from tornado 4.1 to 4.0"""
101 """Backport send_error from tornado 4.1 to 4.0"""
101 def send_error(self, *args, **kwargs):
102 def send_error(self, *args, **kwargs):
102 if self.stream is None:
103 if self.stream is None:
103 super(WebSocketHandler, self).send_error(*args, **kwargs)
104 super(WebSocketHandler, self).send_error(*args, **kwargs)
104 else:
105 else:
105 # If we get an uncaught exception during the handshake,
106 # If we get an uncaught exception during the handshake,
106 # we have no choice but to abruptly close the connection.
107 # we have no choice but to abruptly close the connection.
107 # TODO: for uncaught exceptions after the handshake,
108 # TODO: for uncaught exceptions after the handshake,
108 # we can close the connection more gracefully.
109 # we can close the connection more gracefully.
109 self.stream.close()
110 self.stream.close()
110
111
111
112
112 def check_origin(self, origin):
113 def check_origin(self, origin):
113 """Check Origin == Host or Access-Control-Allow-Origin.
114 """Check Origin == Host or Access-Control-Allow-Origin.
114
115
115 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
116 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
116 """
117 """
117 if self.allow_origin == '*':
118 if self.allow_origin == '*':
118 return True
119 return True
119
120
120 host = self.request.headers.get("Host")
121 host = self.request.headers.get("Host")
121
122
122 # If no header is provided, assume we can't verify origin
123 # If no header is provided, assume we can't verify origin
123 if origin is None:
124 if origin is None:
124 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
125 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
125 return False
126 return False
126 if host is None:
127 if host is None:
127 self.log.warn("Missing Host header, rejecting WebSocket connection.")
128 self.log.warn("Missing Host header, rejecting WebSocket connection.")
128 return False
129 return False
129
130
130 origin = origin.lower()
131 origin = origin.lower()
131 origin_host = urlparse(origin).netloc
132 origin_host = urlparse(origin).netloc
132
133
133 # OK if origin matches host
134 # OK if origin matches host
134 if origin_host == host:
135 if origin_host == host:
135 return True
136 return True
136
137
137 # Check CORS headers
138 # Check CORS headers
138 if self.allow_origin:
139 if self.allow_origin:
139 allow = self.allow_origin == origin
140 allow = self.allow_origin == origin
140 elif self.allow_origin_pat:
141 elif self.allow_origin_pat:
141 allow = bool(self.allow_origin_pat.match(origin))
142 allow = bool(self.allow_origin_pat.match(origin))
142 else:
143 else:
143 # No CORS headers deny the request
144 # No CORS headers deny the request
144 allow = False
145 allow = False
145 if not allow:
146 if not allow:
146 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
147 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
147 origin, host,
148 origin, host,
148 )
149 )
149 return allow
150 return allow
150
151
151 def clear_cookie(self, *args, **kwargs):
152 def clear_cookie(self, *args, **kwargs):
152 """meaningless for websockets"""
153 """meaningless for websockets"""
153 pass
154 pass
154
155
155 def _reserialize_reply(self, msg_list, channel=None):
156 def _reserialize_reply(self, msg_list, channel=None):
156 """Reserialize a reply message using JSON.
157 """Reserialize a reply message using JSON.
157
158
158 This takes the msg list from the ZMQ socket, deserializes it using
159 This takes the msg list from the ZMQ socket, deserializes it using
159 self.session and then serializes the result using JSON. This method
160 self.session and then serializes the result using JSON. This method
160 should be used by self._on_zmq_reply to build messages that can
161 should be used by self._on_zmq_reply to build messages that can
161 be sent back to the browser.
162 be sent back to the browser.
162 """
163 """
163 idents, msg_list = self.session.feed_identities(msg_list)
164 idents, msg_list = self.session.feed_identities(msg_list)
164 msg = self.session.deserialize(msg_list)
165 msg = self.session.deserialize(msg_list)
165 if channel:
166 if channel:
166 msg['channel'] = channel
167 msg['channel'] = channel
167 if msg['buffers']:
168 if msg['buffers']:
168 buf = serialize_binary_message(msg)
169 buf = serialize_binary_message(msg)
169 return buf
170 return buf
170 else:
171 else:
171 smsg = json.dumps(msg, default=date_default)
172 smsg = json.dumps(msg, default=date_default)
172 return cast_unicode(smsg)
173 return cast_unicode(smsg)
173
174
174 def _on_zmq_reply(self, stream, msg_list):
175 def _on_zmq_reply(self, stream, msg_list):
175 # Sometimes this gets triggered when the on_close method is scheduled in the
176 # Sometimes this gets triggered when the on_close method is scheduled in the
176 # eventloop but hasn't been called.
177 # eventloop but hasn't been called.
177 if self.stream.closed() or stream.closed():
178 if self.stream.closed() or stream.closed():
178 self.log.warn("zmq message arrived on closed channel")
179 self.log.warn("zmq message arrived on closed channel")
179 self.close()
180 self.close()
180 return
181 return
181 channel = getattr(stream, 'channel', None)
182 channel = getattr(stream, 'channel', None)
182 try:
183 try:
183 msg = self._reserialize_reply(msg_list, channel=channel)
184 msg = self._reserialize_reply(msg_list, channel=channel)
184 except Exception:
185 except Exception:
185 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
186 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
186 else:
187 else:
187 self.write_message(msg, binary=isinstance(msg, bytes))
188 self.write_message(msg, binary=isinstance(msg, bytes))
188
189
189 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
190 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
190 ping_callback = None
191 ping_callback = None
191 last_ping = 0
192 last_ping = 0
192 last_pong = 0
193 last_pong = 0
193
194
194 @property
195 @property
195 def ping_interval(self):
196 def ping_interval(self):
196 """The interval for websocket keep-alive pings.
197 """The interval for websocket keep-alive pings.
197
198
198 Set ws_ping_interval = 0 to disable pings.
199 Set ws_ping_interval = 0 to disable pings.
199 """
200 """
200 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
201 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
201
202
202 @property
203 @property
203 def ping_timeout(self):
204 def ping_timeout(self):
204 """If no ping is received in this many milliseconds,
205 """If no ping is received in this many milliseconds,
205 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
206 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
206 Default is max of 3 pings or 30 seconds.
207 Default is max of 3 pings or 30 seconds.
207 """
208 """
208 return self.settings.get('ws_ping_timeout',
209 return self.settings.get('ws_ping_timeout',
209 max(3 * self.ping_interval, WS_PING_INTERVAL)
210 max(3 * self.ping_interval, WS_PING_INTERVAL)
210 )
211 )
211
212
212 def set_default_headers(self):
213 def set_default_headers(self):
213 """Undo the set_default_headers in IPythonHandler
214 """Undo the set_default_headers in IPythonHandler
214
215
215 which doesn't make sense for websockets
216 which doesn't make sense for websockets
216 """
217 """
217 pass
218 pass
218
219
219 def pre_get(self):
220 def pre_get(self):
220 """Run before finishing the GET request
221 """Run before finishing the GET request
221
222
222 Extend this method to add logic that should fire before
223 Extend this method to add logic that should fire before
223 the websocket finishes completing.
224 the websocket finishes completing.
224 """
225 """
225 # authenticate the request before opening the websocket
226 # authenticate the request before opening the websocket
226 if self.get_current_user() is None:
227 if self.get_current_user() is None:
227 self.log.warn("Couldn't authenticate WebSocket connection")
228 self.log.warn("Couldn't authenticate WebSocket connection")
228 raise web.HTTPError(403)
229 raise web.HTTPError(403)
229
230
230 if self.get_argument('session_id', False):
231 if self.get_argument('session_id', False):
231 self.session.session = cast_unicode(self.get_argument('session_id'))
232 self.session.session = cast_unicode(self.get_argument('session_id'))
232 else:
233 else:
233 self.log.warn("No session ID specified")
234 self.log.warn("No session ID specified")
234
235
235 @gen.coroutine
236 @gen.coroutine
236 def get(self, *args, **kwargs):
237 def get(self, *args, **kwargs):
237 # pre_get can be a coroutine in subclasses
238 # pre_get can be a coroutine in subclasses
238 # assign and yield in two step to avoid tornado 3 issues
239 # assign and yield in two step to avoid tornado 3 issues
239 res = self.pre_get()
240 res = self.pre_get()
240 yield gen.maybe_future(res)
241 yield gen.maybe_future(res)
241 super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
242 super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
242
243
243 def initialize(self):
244 def initialize(self):
244 self.log.debug("Initializing websocket connection %s", self.request.path)
245 self.log.debug("Initializing websocket connection %s", self.request.path)
245 self.session = Session(config=self.config)
246 self.session = Session(config=self.config)
246
247
247 def open(self, *args, **kwargs):
248 def open(self, *args, **kwargs):
248 self.log.debug("Opening websocket %s", self.request.path)
249 self.log.debug("Opening websocket %s", self.request.path)
249
250
250 # start the pinging
251 # start the pinging
251 if self.ping_interval > 0:
252 if self.ping_interval > 0:
252 loop = ioloop.IOLoop.current()
253 loop = ioloop.IOLoop.current()
253 self.last_ping = loop.time() # Remember time of last ping
254 self.last_ping = loop.time() # Remember time of last ping
254 self.last_pong = self.last_ping
255 self.last_pong = self.last_ping
255 self.ping_callback = ioloop.PeriodicCallback(
256 self.ping_callback = ioloop.PeriodicCallback(
256 self.send_ping, self.ping_interval, io_loop=loop,
257 self.send_ping, self.ping_interval, io_loop=loop,
257 )
258 )
258 self.ping_callback.start()
259 self.ping_callback.start()
259
260
260 def send_ping(self):
261 def send_ping(self):
261 """send a ping to keep the websocket alive"""
262 """send a ping to keep the websocket alive"""
262 if self.stream.closed() and self.ping_callback is not None:
263 if self.stream.closed() and self.ping_callback is not None:
263 self.ping_callback.stop()
264 self.ping_callback.stop()
264 return
265 return
265
266
266 # check for timeout on pong. Make sure that we really have sent a recent ping in
267 # check for timeout on pong. Make sure that we really have sent a recent ping in
267 # case the machine with both server and client has been suspended since the last ping.
268 # case the machine with both server and client has been suspended since the last ping.
268 now = ioloop.IOLoop.current().time()
269 now = ioloop.IOLoop.current().time()
269 since_last_pong = 1e3 * (now - self.last_pong)
270 since_last_pong = 1e3 * (now - self.last_pong)
270 since_last_ping = 1e3 * (now - self.last_ping)
271 since_last_ping = 1e3 * (now - self.last_ping)
271 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
272 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
272 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
273 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
273 self.close()
274 self.close()
274 return
275 return
275
276
276 self.ping(b'')
277 self.ping(b'')
277 self.last_ping = now
278 self.last_ping = now
278
279
279 def on_pong(self, data):
280 def on_pong(self, data):
280 self.last_pong = ioloop.IOLoop.current().time()
281 self.last_pong = ioloop.IOLoop.current().time()
General Comments 0
You need to be logged in to leave comments. Login now