##// END OF EJS Templates
Merge pull request #6951 from ccordoba12/fix-draft76...
Min RK -
r18923:a70bd067 merge
parent child Browse files
Show More
@@ -1,312 +1,311 b''
1 """WebsocketProtocol76 from tornado 3.2.2 for tornado >= 4.0
1 """WebsocketProtocol76 from tornado 3.2.2 for tornado >= 4.0
2
2
3 The contents of this file are Copyright (c) Tornado
3 The contents of this file are Copyright (c) Tornado
4 Used under the Apache 2.0 license
4 Used under the Apache 2.0 license
5 """
5 """
6
6
7
7
8 from __future__ import absolute_import, division, print_function, with_statement
8 from __future__ import absolute_import, division, print_function, with_statement
9 # Author: Jacob Kristhammar, 2010
9 # Author: Jacob Kristhammar, 2010
10
10
11 import functools
11 import functools
12 import hashlib
12 import hashlib
13 import struct
13 import struct
14 import time
14 import time
15 import tornado.escape
15 import tornado.escape
16 import tornado.web
16 import tornado.web
17
17
18 from tornado.log import gen_log, app_log
18 from tornado.log import gen_log, app_log
19 from tornado.util import bytes_type, unicode_type
19 from tornado.util import bytes_type, unicode_type
20
20
21 from tornado.websocket import WebSocketHandler, WebSocketProtocol13
21 from tornado.websocket import WebSocketHandler, WebSocketProtocol13
22
22
23 class AllowDraftWebSocketHandler(WebSocketHandler):
23 class AllowDraftWebSocketHandler(WebSocketHandler):
24 """Restore Draft76 support for tornado 4
24 """Restore Draft76 support for tornado 4
25
25
26 Remove when we can run tests without phantomjs + qt4
26 Remove when we can run tests without phantomjs + qt4
27 """
27 """
28
28
29 # get is unmodified except between the BEGIN/END PATCH lines
29 # get is unmodified except between the BEGIN/END PATCH lines
30 @tornado.web.asynchronous
30 @tornado.web.asynchronous
31 def get(self, *args, **kwargs):
31 def get(self, *args, **kwargs):
32 self.open_args = args
32 self.open_args = args
33 self.open_kwargs = kwargs
33 self.open_kwargs = kwargs
34
34
35 # Upgrade header should be present and should be equal to WebSocket
35 # Upgrade header should be present and should be equal to WebSocket
36 if self.request.headers.get("Upgrade", "").lower() != 'websocket':
36 if self.request.headers.get("Upgrade", "").lower() != 'websocket':
37 self.set_status(400)
37 self.set_status(400)
38 self.finish("Can \"Upgrade\" only to \"WebSocket\".")
38 self.finish("Can \"Upgrade\" only to \"WebSocket\".")
39 return
39 return
40
40
41 # Connection header should be upgrade. Some proxy servers/load balancers
41 # Connection header should be upgrade. Some proxy servers/load balancers
42 # might mess with it.
42 # might mess with it.
43 headers = self.request.headers
43 headers = self.request.headers
44 connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(","))
44 connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(","))
45 if 'upgrade' not in connection:
45 if 'upgrade' not in connection:
46 self.set_status(400)
46 self.set_status(400)
47 self.finish("\"Connection\" must be \"Upgrade\".")
47 self.finish("\"Connection\" must be \"Upgrade\".")
48 return
48 return
49
49
50 # Handle WebSocket Origin naming convention differences
50 # Handle WebSocket Origin naming convention differences
51 # The difference between version 8 and 13 is that in 8 the
51 # The difference between version 8 and 13 is that in 8 the
52 # client sends a "Sec-Websocket-Origin" header and in 13 it's
52 # client sends a "Sec-Websocket-Origin" header and in 13 it's
53 # simply "Origin".
53 # simply "Origin".
54 if "Origin" in self.request.headers:
54 if "Origin" in self.request.headers:
55 origin = self.request.headers.get("Origin")
55 origin = self.request.headers.get("Origin")
56 else:
56 else:
57 origin = self.request.headers.get("Sec-Websocket-Origin", None)
57 origin = self.request.headers.get("Sec-Websocket-Origin", None)
58
58
59
59
60 # If there was an origin header, check to make sure it matches
60 # If there was an origin header, check to make sure it matches
61 # according to check_origin. When the origin is None, we assume it
61 # according to check_origin. When the origin is None, we assume it
62 # did not come from a browser and that it can be passed on.
62 # did not come from a browser and that it can be passed on.
63 if origin is not None and not self.check_origin(origin):
63 if origin is not None and not self.check_origin(origin):
64 self.set_status(403)
64 self.set_status(403)
65 self.finish("Cross origin websockets not allowed")
65 self.finish("Cross origin websockets not allowed")
66 return
66 return
67
67
68 self.stream = self.request.connection.detach()
68 self.stream = self.request.connection.detach()
69 self.stream.set_close_callback(self.on_connection_close)
69 self.stream.set_close_callback(self.on_connection_close)
70
70
71 if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
71 if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
72 self.ws_connection = WebSocketProtocol13(
72 self.ws_connection = WebSocketProtocol13(self)
73 self, compression_options=self.get_compression_options())
74 self.ws_connection.accept_connection()
73 self.ws_connection.accept_connection()
75 #--------------- BEGIN PATCH ----------------
74 #--------------- BEGIN PATCH ----------------
76 elif (self.allow_draft76() and
75 elif (self.allow_draft76() and
77 "Sec-WebSocket-Version" not in self.request.headers):
76 "Sec-WebSocket-Version" not in self.request.headers):
78 self.ws_connection = WebSocketProtocol76(self)
77 self.ws_connection = WebSocketProtocol76(self)
79 self.ws_connection.accept_connection()
78 self.ws_connection.accept_connection()
80 #--------------- END PATCH ----------------
79 #--------------- END PATCH ----------------
81 else:
80 else:
82 if not self.stream.closed():
81 if not self.stream.closed():
83 self.stream.write(tornado.escape.utf8(
82 self.stream.write(tornado.escape.utf8(
84 "HTTP/1.1 426 Upgrade Required\r\n"
83 "HTTP/1.1 426 Upgrade Required\r\n"
85 "Sec-WebSocket-Version: 8\r\n\r\n"))
84 "Sec-WebSocket-Version: 8\r\n\r\n"))
86 self.stream.close()
85 self.stream.close()
87
86
88 # 3.2 methods removed in 4.0:
87 # 3.2 methods removed in 4.0:
89 def allow_draft76(self):
88 def allow_draft76(self):
90 """Using this class allows draft76 connections by default"""
89 """Using this class allows draft76 connections by default"""
91 return True
90 return True
92
91
93 def get_websocket_scheme(self):
92 def get_websocket_scheme(self):
94 """Return the url scheme used for this request, either "ws" or "wss".
93 """Return the url scheme used for this request, either "ws" or "wss".
95 This is normally decided by HTTPServer, but applications
94 This is normally decided by HTTPServer, but applications
96 may wish to override this if they are using an SSL proxy
95 may wish to override this if they are using an SSL proxy
97 that does not provide the X-Scheme header as understood
96 that does not provide the X-Scheme header as understood
98 by HTTPServer.
97 by HTTPServer.
99 Note that this is only used by the draft76 protocol.
98 Note that this is only used by the draft76 protocol.
100 """
99 """
101 return "wss" if self.request.protocol == "https" else "ws"
100 return "wss" if self.request.protocol == "https" else "ws"
102
101
103
102
104
103
105 # No modifications from tornado-3.2.2 below this line
104 # No modifications from tornado-3.2.2 below this line
106
105
107 class WebSocketProtocol(object):
106 class WebSocketProtocol(object):
108 """Base class for WebSocket protocol versions.
107 """Base class for WebSocket protocol versions.
109 """
108 """
110 def __init__(self, handler):
109 def __init__(self, handler):
111 self.handler = handler
110 self.handler = handler
112 self.request = handler.request
111 self.request = handler.request
113 self.stream = handler.stream
112 self.stream = handler.stream
114 self.client_terminated = False
113 self.client_terminated = False
115 self.server_terminated = False
114 self.server_terminated = False
116
115
117 def async_callback(self, callback, *args, **kwargs):
116 def async_callback(self, callback, *args, **kwargs):
118 """Wrap callbacks with this if they are used on asynchronous requests.
117 """Wrap callbacks with this if they are used on asynchronous requests.
119
118
120 Catches exceptions properly and closes this WebSocket if an exception
119 Catches exceptions properly and closes this WebSocket if an exception
121 is uncaught.
120 is uncaught.
122 """
121 """
123 if args or kwargs:
122 if args or kwargs:
124 callback = functools.partial(callback, *args, **kwargs)
123 callback = functools.partial(callback, *args, **kwargs)
125
124
126 def wrapper(*args, **kwargs):
125 def wrapper(*args, **kwargs):
127 try:
126 try:
128 return callback(*args, **kwargs)
127 return callback(*args, **kwargs)
129 except Exception:
128 except Exception:
130 app_log.error("Uncaught exception in %s",
129 app_log.error("Uncaught exception in %s",
131 self.request.path, exc_info=True)
130 self.request.path, exc_info=True)
132 self._abort()
131 self._abort()
133 return wrapper
132 return wrapper
134
133
135 def on_connection_close(self):
134 def on_connection_close(self):
136 self._abort()
135 self._abort()
137
136
138 def _abort(self):
137 def _abort(self):
139 """Instantly aborts the WebSocket connection by closing the socket"""
138 """Instantly aborts the WebSocket connection by closing the socket"""
140 self.client_terminated = True
139 self.client_terminated = True
141 self.server_terminated = True
140 self.server_terminated = True
142 self.stream.close() # forcibly tear down the connection
141 self.stream.close() # forcibly tear down the connection
143 self.close() # let the subclass cleanup
142 self.close() # let the subclass cleanup
144
143
145
144
146 class WebSocketProtocol76(WebSocketProtocol):
145 class WebSocketProtocol76(WebSocketProtocol):
147 """Implementation of the WebSockets protocol, version hixie-76.
146 """Implementation of the WebSockets protocol, version hixie-76.
148
147
149 This class provides basic functionality to process WebSockets requests as
148 This class provides basic functionality to process WebSockets requests as
150 specified in
149 specified in
151 http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76
150 http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76
152 """
151 """
153 def __init__(self, handler):
152 def __init__(self, handler):
154 WebSocketProtocol.__init__(self, handler)
153 WebSocketProtocol.__init__(self, handler)
155 self.challenge = None
154 self.challenge = None
156 self._waiting = None
155 self._waiting = None
157
156
158 def accept_connection(self):
157 def accept_connection(self):
159 try:
158 try:
160 self._handle_websocket_headers()
159 self._handle_websocket_headers()
161 except ValueError:
160 except ValueError:
162 gen_log.debug("Malformed WebSocket request received")
161 gen_log.debug("Malformed WebSocket request received")
163 self._abort()
162 self._abort()
164 return
163 return
165
164
166 scheme = self.handler.get_websocket_scheme()
165 scheme = self.handler.get_websocket_scheme()
167
166
168 # draft76 only allows a single subprotocol
167 # draft76 only allows a single subprotocol
169 subprotocol_header = ''
168 subprotocol_header = ''
170 subprotocol = self.request.headers.get("Sec-WebSocket-Protocol", None)
169 subprotocol = self.request.headers.get("Sec-WebSocket-Protocol", None)
171 if subprotocol:
170 if subprotocol:
172 selected = self.handler.select_subprotocol([subprotocol])
171 selected = self.handler.select_subprotocol([subprotocol])
173 if selected:
172 if selected:
174 assert selected == subprotocol
173 assert selected == subprotocol
175 subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
174 subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
176
175
177 # Write the initial headers before attempting to read the challenge.
176 # Write the initial headers before attempting to read the challenge.
178 # This is necessary when using proxies (such as HAProxy), which
177 # This is necessary when using proxies (such as HAProxy), which
179 # need to see the Upgrade headers before passing through the
178 # need to see the Upgrade headers before passing through the
180 # non-HTTP traffic that follows.
179 # non-HTTP traffic that follows.
181 self.stream.write(tornado.escape.utf8(
180 self.stream.write(tornado.escape.utf8(
182 "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
181 "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
183 "Upgrade: WebSocket\r\n"
182 "Upgrade: WebSocket\r\n"
184 "Connection: Upgrade\r\n"
183 "Connection: Upgrade\r\n"
185 "Server: TornadoServer/%(version)s\r\n"
184 "Server: TornadoServer/%(version)s\r\n"
186 "Sec-WebSocket-Origin: %(origin)s\r\n"
185 "Sec-WebSocket-Origin: %(origin)s\r\n"
187 "Sec-WebSocket-Location: %(scheme)s://%(host)s%(uri)s\r\n"
186 "Sec-WebSocket-Location: %(scheme)s://%(host)s%(uri)s\r\n"
188 "%(subprotocol)s"
187 "%(subprotocol)s"
189 "\r\n" % (dict(
188 "\r\n" % (dict(
190 version=tornado.version,
189 version=tornado.version,
191 origin=self.request.headers["Origin"],
190 origin=self.request.headers["Origin"],
192 scheme=scheme,
191 scheme=scheme,
193 host=self.request.host,
192 host=self.request.host,
194 uri=self.request.uri,
193 uri=self.request.uri,
195 subprotocol=subprotocol_header))))
194 subprotocol=subprotocol_header))))
196 self.stream.read_bytes(8, self._handle_challenge)
195 self.stream.read_bytes(8, self._handle_challenge)
197
196
198 def challenge_response(self, challenge):
197 def challenge_response(self, challenge):
199 """Generates the challenge response that's needed in the handshake
198 """Generates the challenge response that's needed in the handshake
200
199
201 The challenge parameter should be the raw bytes as sent from the
200 The challenge parameter should be the raw bytes as sent from the
202 client.
201 client.
203 """
202 """
204 key_1 = self.request.headers.get("Sec-Websocket-Key1")
203 key_1 = self.request.headers.get("Sec-Websocket-Key1")
205 key_2 = self.request.headers.get("Sec-Websocket-Key2")
204 key_2 = self.request.headers.get("Sec-Websocket-Key2")
206 try:
205 try:
207 part_1 = self._calculate_part(key_1)
206 part_1 = self._calculate_part(key_1)
208 part_2 = self._calculate_part(key_2)
207 part_2 = self._calculate_part(key_2)
209 except ValueError:
208 except ValueError:
210 raise ValueError("Invalid Keys/Challenge")
209 raise ValueError("Invalid Keys/Challenge")
211 return self._generate_challenge_response(part_1, part_2, challenge)
210 return self._generate_challenge_response(part_1, part_2, challenge)
212
211
213 def _handle_challenge(self, challenge):
212 def _handle_challenge(self, challenge):
214 try:
213 try:
215 challenge_response = self.challenge_response(challenge)
214 challenge_response = self.challenge_response(challenge)
216 except ValueError:
215 except ValueError:
217 gen_log.debug("Malformed key data in WebSocket request")
216 gen_log.debug("Malformed key data in WebSocket request")
218 self._abort()
217 self._abort()
219 return
218 return
220 self._write_response(challenge_response)
219 self._write_response(challenge_response)
221
220
222 def _write_response(self, challenge):
221 def _write_response(self, challenge):
223 self.stream.write(challenge)
222 self.stream.write(challenge)
224 self.async_callback(self.handler.open)(*self.handler.open_args, **self.handler.open_kwargs)
223 self.async_callback(self.handler.open)(*self.handler.open_args, **self.handler.open_kwargs)
225 self._receive_message()
224 self._receive_message()
226
225
227 def _handle_websocket_headers(self):
226 def _handle_websocket_headers(self):
228 """Verifies all invariant- and required headers
227 """Verifies all invariant- and required headers
229
228
230 If a header is missing or have an incorrect value ValueError will be
229 If a header is missing or have an incorrect value ValueError will be
231 raised
230 raised
232 """
231 """
233 fields = ("Origin", "Host", "Sec-Websocket-Key1",
232 fields = ("Origin", "Host", "Sec-Websocket-Key1",
234 "Sec-Websocket-Key2")
233 "Sec-Websocket-Key2")
235 if not all(map(lambda f: self.request.headers.get(f), fields)):
234 if not all(map(lambda f: self.request.headers.get(f), fields)):
236 raise ValueError("Missing/Invalid WebSocket headers")
235 raise ValueError("Missing/Invalid WebSocket headers")
237
236
238 def _calculate_part(self, key):
237 def _calculate_part(self, key):
239 """Processes the key headers and calculates their key value.
238 """Processes the key headers and calculates their key value.
240
239
241 Raises ValueError when feed invalid key."""
240 Raises ValueError when feed invalid key."""
242 # pyflakes complains about variable reuse if both of these lines use 'c'
241 # pyflakes complains about variable reuse if both of these lines use 'c'
243 number = int(''.join(c for c in key if c.isdigit()))
242 number = int(''.join(c for c in key if c.isdigit()))
244 spaces = len([c2 for c2 in key if c2.isspace()])
243 spaces = len([c2 for c2 in key if c2.isspace()])
245 try:
244 try:
246 key_number = number // spaces
245 key_number = number // spaces
247 except (ValueError, ZeroDivisionError):
246 except (ValueError, ZeroDivisionError):
248 raise ValueError
247 raise ValueError
249 return struct.pack(">I", key_number)
248 return struct.pack(">I", key_number)
250
249
251 def _generate_challenge_response(self, part_1, part_2, part_3):
250 def _generate_challenge_response(self, part_1, part_2, part_3):
252 m = hashlib.md5()
251 m = hashlib.md5()
253 m.update(part_1)
252 m.update(part_1)
254 m.update(part_2)
253 m.update(part_2)
255 m.update(part_3)
254 m.update(part_3)
256 return m.digest()
255 return m.digest()
257
256
258 def _receive_message(self):
257 def _receive_message(self):
259 self.stream.read_bytes(1, self._on_frame_type)
258 self.stream.read_bytes(1, self._on_frame_type)
260
259
261 def _on_frame_type(self, byte):
260 def _on_frame_type(self, byte):
262 frame_type = ord(byte)
261 frame_type = ord(byte)
263 if frame_type == 0x00:
262 if frame_type == 0x00:
264 self.stream.read_until(b"\xff", self._on_end_delimiter)
263 self.stream.read_until(b"\xff", self._on_end_delimiter)
265 elif frame_type == 0xff:
264 elif frame_type == 0xff:
266 self.stream.read_bytes(1, self._on_length_indicator)
265 self.stream.read_bytes(1, self._on_length_indicator)
267 else:
266 else:
268 self._abort()
267 self._abort()
269
268
270 def _on_end_delimiter(self, frame):
269 def _on_end_delimiter(self, frame):
271 if not self.client_terminated:
270 if not self.client_terminated:
272 self.async_callback(self.handler.on_message)(
271 self.async_callback(self.handler.on_message)(
273 frame[:-1].decode("utf-8", "replace"))
272 frame[:-1].decode("utf-8", "replace"))
274 if not self.client_terminated:
273 if not self.client_terminated:
275 self._receive_message()
274 self._receive_message()
276
275
277 def _on_length_indicator(self, byte):
276 def _on_length_indicator(self, byte):
278 if ord(byte) != 0x00:
277 if ord(byte) != 0x00:
279 self._abort()
278 self._abort()
280 return
279 return
281 self.client_terminated = True
280 self.client_terminated = True
282 self.close()
281 self.close()
283
282
284 def write_message(self, message, binary=False):
283 def write_message(self, message, binary=False):
285 """Sends the given message to the client of this Web Socket."""
284 """Sends the given message to the client of this Web Socket."""
286 if binary:
285 if binary:
287 raise ValueError(
286 raise ValueError(
288 "Binary messages not supported by this version of websockets")
287 "Binary messages not supported by this version of websockets")
289 if isinstance(message, unicode_type):
288 if isinstance(message, unicode_type):
290 message = message.encode("utf-8")
289 message = message.encode("utf-8")
291 assert isinstance(message, bytes_type)
290 assert isinstance(message, bytes_type)
292 self.stream.write(b"\x00" + message + b"\xff")
291 self.stream.write(b"\x00" + message + b"\xff")
293
292
294 def write_ping(self, data):
293 def write_ping(self, data):
295 """Send ping frame."""
294 """Send ping frame."""
296 raise ValueError("Ping messages not supported by this version of websockets")
295 raise ValueError("Ping messages not supported by this version of websockets")
297
296
298 def close(self):
297 def close(self):
299 """Closes the WebSocket connection."""
298 """Closes the WebSocket connection."""
300 if not self.server_terminated:
299 if not self.server_terminated:
301 if not self.stream.closed():
300 if not self.stream.closed():
302 self.stream.write("\xff\x00")
301 self.stream.write("\xff\x00")
303 self.server_terminated = True
302 self.server_terminated = True
304 if self.client_terminated:
303 if self.client_terminated:
305 if self._waiting is not None:
304 if self._waiting is not None:
306 self.stream.io_loop.remove_timeout(self._waiting)
305 self.stream.io_loop.remove_timeout(self._waiting)
307 self._waiting = None
306 self._waiting = None
308 self.stream.close()
307 self.stream.close()
309 elif self._waiting is None:
308 elif self._waiting is None:
310 self._waiting = self.stream.io_loop.add_timeout(
309 self._waiting = self.stream.io_loop.add_timeout(
311 time.time() + 5, self._abort)
310 time.time() + 5, self._abort)
312
311
General Comments 0
You need to be logged in to leave comments. Login now