##// END OF EJS Templates
forward-port draft76 websockets...
Min RK -
Show More
@@ -0,0 +1,312 b''
1 """WebsocketProtocol76 from tornado 3.2.2 for tornado >= 4.0
2
3 The contents of this file are Copyright (c) Tornado
4 Used under the Apache 2.0 license
5 """
6
7
8 from __future__ import absolute_import, division, print_function, with_statement
9 # Author: Jacob Kristhammar, 2010
10
11 import functools
12 import hashlib
13 import struct
14 import time
15 import tornado.escape
16 import tornado.web
17
18 from tornado.log import gen_log, app_log
19 from tornado.util import bytes_type, unicode_type
20
21 from tornado.websocket import WebSocketHandler, WebSocketProtocol13
22
23 class AllowDraftWebSocketHandler(WebSocketHandler):
24 """Restore Draft76 support for tornado 4
25
26 Remove when we can run tests without phantomjs + qt4
27 """
28
29 # get is unmodified except between the BEGIN/END PATCH lines
30 @tornado.web.asynchronous
31 def get(self, *args, **kwargs):
32 self.open_args = args
33 self.open_kwargs = kwargs
34
35 # Upgrade header should be present and should be equal to WebSocket
36 if self.request.headers.get("Upgrade", "").lower() != 'websocket':
37 self.set_status(400)
38 self.finish("Can \"Upgrade\" only to \"WebSocket\".")
39 return
40
41 # Connection header should be upgrade. Some proxy servers/load balancers
42 # might mess with it.
43 headers = self.request.headers
44 connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(","))
45 if 'upgrade' not in connection:
46 self.set_status(400)
47 self.finish("\"Connection\" must be \"Upgrade\".")
48 return
49
50 # Handle WebSocket Origin naming convention differences
51 # The difference between version 8 and 13 is that in 8 the
52 # client sends a "Sec-Websocket-Origin" header and in 13 it's
53 # simply "Origin".
54 if "Origin" in self.request.headers:
55 origin = self.request.headers.get("Origin")
56 else:
57 origin = self.request.headers.get("Sec-Websocket-Origin", None)
58
59
60 # If there was an origin header, check to make sure it matches
61 # according to check_origin. When the origin is None, we assume it
62 # did not come from a browser and that it can be passed on.
63 if origin is not None and not self.check_origin(origin):
64 self.set_status(403)
65 self.finish("Cross origin websockets not allowed")
66 return
67
68 self.stream = self.request.connection.detach()
69 self.stream.set_close_callback(self.on_connection_close)
70
71 if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
72 self.ws_connection = WebSocketProtocol13(
73 self, compression_options=self.get_compression_options())
74 self.ws_connection.accept_connection()
75 #--------------- BEGIN PATCH ----------------
76 elif (self.allow_draft76() and
77 "Sec-WebSocket-Version" not in self.request.headers):
78 self.ws_connection = WebSocketProtocol76(self)
79 self.ws_connection.accept_connection()
80 #--------------- END PATCH ----------------
81 else:
82 if not self.stream.closed():
83 self.stream.write(tornado.escape.utf8(
84 "HTTP/1.1 426 Upgrade Required\r\n"
85 "Sec-WebSocket-Version: 8\r\n\r\n"))
86 self.stream.close()
87
88 # 3.2 methods removed in 4.0:
89 def allow_draft76(self):
90 """Using this class allows draft76 connections by default"""
91 return True
92
93 def get_websocket_scheme(self):
94 """Return the url scheme used for this request, either "ws" or "wss".
95 This is normally decided by HTTPServer, but applications
96 may wish to override this if they are using an SSL proxy
97 that does not provide the X-Scheme header as understood
98 by HTTPServer.
99 Note that this is only used by the draft76 protocol.
100 """
101 return "wss" if self.request.protocol == "https" else "ws"
102
103
104
105 # No modifications from tornado-3.2.2 below this line
106
107 class WebSocketProtocol(object):
108 """Base class for WebSocket protocol versions.
109 """
110 def __init__(self, handler):
111 self.handler = handler
112 self.request = handler.request
113 self.stream = handler.stream
114 self.client_terminated = False
115 self.server_terminated = False
116
117 def async_callback(self, callback, *args, **kwargs):
118 """Wrap callbacks with this if they are used on asynchronous requests.
119
120 Catches exceptions properly and closes this WebSocket if an exception
121 is uncaught.
122 """
123 if args or kwargs:
124 callback = functools.partial(callback, *args, **kwargs)
125
126 def wrapper(*args, **kwargs):
127 try:
128 return callback(*args, **kwargs)
129 except Exception:
130 app_log.error("Uncaught exception in %s",
131 self.request.path, exc_info=True)
132 self._abort()
133 return wrapper
134
135 def on_connection_close(self):
136 self._abort()
137
138 def _abort(self):
139 """Instantly aborts the WebSocket connection by closing the socket"""
140 self.client_terminated = True
141 self.server_terminated = True
142 self.stream.close() # forcibly tear down the connection
143 self.close() # let the subclass cleanup
144
145
146 class WebSocketProtocol76(WebSocketProtocol):
147 """Implementation of the WebSockets protocol, version hixie-76.
148
149 This class provides basic functionality to process WebSockets requests as
150 specified in
151 http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76
152 """
153 def __init__(self, handler):
154 WebSocketProtocol.__init__(self, handler)
155 self.challenge = None
156 self._waiting = None
157
158 def accept_connection(self):
159 try:
160 self._handle_websocket_headers()
161 except ValueError:
162 gen_log.debug("Malformed WebSocket request received")
163 self._abort()
164 return
165
166 scheme = self.handler.get_websocket_scheme()
167
168 # draft76 only allows a single subprotocol
169 subprotocol_header = ''
170 subprotocol = self.request.headers.get("Sec-WebSocket-Protocol", None)
171 if subprotocol:
172 selected = self.handler.select_subprotocol([subprotocol])
173 if selected:
174 assert selected == subprotocol
175 subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
176
177 # Write the initial headers before attempting to read the challenge.
178 # This is necessary when using proxies (such as HAProxy), which
179 # need to see the Upgrade headers before passing through the
180 # non-HTTP traffic that follows.
181 self.stream.write(tornado.escape.utf8(
182 "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
183 "Upgrade: WebSocket\r\n"
184 "Connection: Upgrade\r\n"
185 "Server: TornadoServer/%(version)s\r\n"
186 "Sec-WebSocket-Origin: %(origin)s\r\n"
187 "Sec-WebSocket-Location: %(scheme)s://%(host)s%(uri)s\r\n"
188 "%(subprotocol)s"
189 "\r\n" % (dict(
190 version=tornado.version,
191 origin=self.request.headers["Origin"],
192 scheme=scheme,
193 host=self.request.host,
194 uri=self.request.uri,
195 subprotocol=subprotocol_header))))
196 self.stream.read_bytes(8, self._handle_challenge)
197
198 def challenge_response(self, challenge):
199 """Generates the challenge response that's needed in the handshake
200
201 The challenge parameter should be the raw bytes as sent from the
202 client.
203 """
204 key_1 = self.request.headers.get("Sec-Websocket-Key1")
205 key_2 = self.request.headers.get("Sec-Websocket-Key2")
206 try:
207 part_1 = self._calculate_part(key_1)
208 part_2 = self._calculate_part(key_2)
209 except ValueError:
210 raise ValueError("Invalid Keys/Challenge")
211 return self._generate_challenge_response(part_1, part_2, challenge)
212
213 def _handle_challenge(self, challenge):
214 try:
215 challenge_response = self.challenge_response(challenge)
216 except ValueError:
217 gen_log.debug("Malformed key data in WebSocket request")
218 self._abort()
219 return
220 self._write_response(challenge_response)
221
222 def _write_response(self, challenge):
223 self.stream.write(challenge)
224 self.async_callback(self.handler.open)(*self.handler.open_args, **self.handler.open_kwargs)
225 self._receive_message()
226
227 def _handle_websocket_headers(self):
228 """Verifies all invariant- and required headers
229
230 If a header is missing or have an incorrect value ValueError will be
231 raised
232 """
233 fields = ("Origin", "Host", "Sec-Websocket-Key1",
234 "Sec-Websocket-Key2")
235 if not all(map(lambda f: self.request.headers.get(f), fields)):
236 raise ValueError("Missing/Invalid WebSocket headers")
237
238 def _calculate_part(self, key):
239 """Processes the key headers and calculates their key value.
240
241 Raises ValueError when feed invalid key."""
242 # pyflakes complains about variable reuse if both of these lines use 'c'
243 number = int(''.join(c for c in key if c.isdigit()))
244 spaces = len([c2 for c2 in key if c2.isspace()])
245 try:
246 key_number = number // spaces
247 except (ValueError, ZeroDivisionError):
248 raise ValueError
249 return struct.pack(">I", key_number)
250
251 def _generate_challenge_response(self, part_1, part_2, part_3):
252 m = hashlib.md5()
253 m.update(part_1)
254 m.update(part_2)
255 m.update(part_3)
256 return m.digest()
257
258 def _receive_message(self):
259 self.stream.read_bytes(1, self._on_frame_type)
260
261 def _on_frame_type(self, byte):
262 frame_type = ord(byte)
263 if frame_type == 0x00:
264 self.stream.read_until(b"\xff", self._on_end_delimiter)
265 elif frame_type == 0xff:
266 self.stream.read_bytes(1, self._on_length_indicator)
267 else:
268 self._abort()
269
270 def _on_end_delimiter(self, frame):
271 if not self.client_terminated:
272 self.async_callback(self.handler.on_message)(
273 frame[:-1].decode("utf-8", "replace"))
274 if not self.client_terminated:
275 self._receive_message()
276
277 def _on_length_indicator(self, byte):
278 if ord(byte) != 0x00:
279 self._abort()
280 return
281 self.client_terminated = True
282 self.close()
283
284 def write_message(self, message, binary=False):
285 """Sends the given message to the client of this Web Socket."""
286 if binary:
287 raise ValueError(
288 "Binary messages not supported by this version of websockets")
289 if isinstance(message, unicode_type):
290 message = message.encode("utf-8")
291 assert isinstance(message, bytes_type)
292 self.stream.write(b"\x00" + message + b"\xff")
293
294 def write_ping(self, data):
295 """Send ping frame."""
296 raise ValueError("Ping messages not supported by this version of websockets")
297
298 def close(self):
299 """Closes the WebSocket connection."""
300 if not self.server_terminated:
301 if not self.stream.closed():
302 self.stream.write("\xff\x00")
303 self.server_terminated = True
304 if self.client_terminated:
305 if self._waiting is not None:
306 self.stream.io_loop.remove_timeout(self._waiting)
307 self._waiting = None
308 self.stream.close()
309 elif self._waiting is None:
310 self._waiting = self.stream.io_loop.add_timeout(
311 time.time() + 5, self._abort)
312
@@ -4,8 +4,10 b''
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 import os
7 import json
8 import json
8 import struct
9 import struct
10 import warnings
9
11
10 try:
12 try:
11 from urllib.parse import urlparse # Py 3
13 from urllib.parse import urlparse # Py 3
@@ -13,7 +15,8 b' except ImportError:'
13 from urlparse import urlparse # Py 2
15 from urlparse import urlparse # Py 2
14
16
15 import tornado
17 import tornado
16 from tornado import gen, ioloop, web, websocket
18 from tornado import gen, ioloop, web
19 from tornado.websocket import WebSocketHandler
17
20
18 from IPython.kernel.zmq.session import Session
21 from IPython.kernel.zmq.session import Session
19 from IPython.utils.jsonutil import date_default, extract_dates
22 from IPython.utils.jsonutil import date_default, extract_dates
@@ -21,7 +24,6 b' from IPython.utils.py3compat import cast_unicode'
21
24
22 from .handlers import IPythonHandler
25 from .handlers import IPythonHandler
23
26
24
25 def serialize_binary_message(msg):
27 def serialize_binary_message(msg):
26 """serialize a message as a binary blob
28 """serialize a message as a binary blob
27
29
@@ -79,8 +81,18 b' def deserialize_binary_message(bmsg):'
79 msg['buffers'] = bufs[1:]
81 msg['buffers'] = bufs[1:]
80 return msg
82 return msg
81
83
84 # ping interval for keeping websockets alive (30 seconds)
85 WS_PING_INTERVAL = 30000
82
86
83 class ZMQStreamHandler(websocket.WebSocketHandler):
87 if os.environ.get('IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS', False):
88 warnings.warn("""Allowing draft76 websocket connections!
89 This should only be done for testing with phantomjs!""")
90 from IPython.html import allow76
91 WebSocketHandler = allow76.AllowDraftWebSocketHandler
92 # draft 76 doesn't support ping
93 WS_PING_INTERVAL = 0
94
95 class ZMQStreamHandler(WebSocketHandler):
84
96
85 def check_origin(self, origin):
97 def check_origin(self, origin):
86 """Check Origin == Host or Access-Control-Allow-Origin.
98 """Check Origin == Host or Access-Control-Allow-Origin.
@@ -154,17 +166,6 b' class ZMQStreamHandler(websocket.WebSocketHandler):'
154 else:
166 else:
155 self.write_message(msg, binary=isinstance(msg, bytes))
167 self.write_message(msg, binary=isinstance(msg, bytes))
156
168
157 def allow_draft76(self):
158 """Allow draft 76, until browsers such as Safari update to RFC 6455.
159
160 This has been disabled by default in tornado in release 2.2.0, and
161 support will be removed in later versions.
162 """
163 return True
164
165 # ping interval for keeping websockets alive (30 seconds)
166 WS_PING_INTERVAL = 30000
167
168 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
169 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
169 ping_callback = None
170 ping_callback = None
170 last_ping = 0
171 last_ping = 0
@@ -325,7 +325,15 b' class JSController(TestController):'
325 command.append('--KernelManager.transport=ipc')
325 command.append('--KernelManager.transport=ipc')
326 self.stream_capturer = c = StreamCapturer()
326 self.stream_capturer = c = StreamCapturer()
327 c.start()
327 c.start()
328 self.server = subprocess.Popen(command, stdout=c.writefd, stderr=subprocess.STDOUT, cwd=self.nbdir.name)
328 env = os.environ.copy()
329 if self.engine == 'phantomjs':
330 env['IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS'] = '1'
331 self.server = subprocess.Popen(command,
332 stdout=c.writefd,
333 stderr=subprocess.STDOUT,
334 cwd=self.nbdir.name,
335 env=env,
336 )
329 self.server_info_file = os.path.join(self.ipydir.name,
337 self.server_info_file = os.path.join(self.ipydir.name,
330 'profile_default', 'security', 'nbserver-%i.json' % self.server.pid
338 'profile_default', 'security', 'nbserver-%i.json' % self.server.pid
331 )
339 )
General Comments 0
You need to be logged in to leave comments. Login now