##// END OF EJS Templates
Merge pull request #6951 from ccordoba12/fix-draft76...
Min RK -
r18923:a70bd067 merge
parent child Browse files
Show More
@@ -1,312 +1,311 b''
1 1 """WebsocketProtocol76 from tornado 3.2.2 for tornado >= 4.0
2 2
3 3 The contents of this file are Copyright (c) Tornado
4 4 Used under the Apache 2.0 license
5 5 """
6 6
7 7
8 8 from __future__ import absolute_import, division, print_function, with_statement
9 9 # Author: Jacob Kristhammar, 2010
10 10
11 11 import functools
12 12 import hashlib
13 13 import struct
14 14 import time
15 15 import tornado.escape
16 16 import tornado.web
17 17
18 18 from tornado.log import gen_log, app_log
19 19 from tornado.util import bytes_type, unicode_type
20 20
21 21 from tornado.websocket import WebSocketHandler, WebSocketProtocol13
22 22
23 23 class AllowDraftWebSocketHandler(WebSocketHandler):
24 24 """Restore Draft76 support for tornado 4
25 25
26 26 Remove when we can run tests without phantomjs + qt4
27 27 """
28 28
29 29 # get is unmodified except between the BEGIN/END PATCH lines
30 30 @tornado.web.asynchronous
31 31 def get(self, *args, **kwargs):
32 32 self.open_args = args
33 33 self.open_kwargs = kwargs
34 34
35 35 # Upgrade header should be present and should be equal to WebSocket
36 36 if self.request.headers.get("Upgrade", "").lower() != 'websocket':
37 37 self.set_status(400)
38 38 self.finish("Can \"Upgrade\" only to \"WebSocket\".")
39 39 return
40 40
41 41 # Connection header should be upgrade. Some proxy servers/load balancers
42 42 # might mess with it.
43 43 headers = self.request.headers
44 44 connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(","))
45 45 if 'upgrade' not in connection:
46 46 self.set_status(400)
47 47 self.finish("\"Connection\" must be \"Upgrade\".")
48 48 return
49 49
50 50 # Handle WebSocket Origin naming convention differences
51 51 # The difference between version 8 and 13 is that in 8 the
52 52 # client sends a "Sec-Websocket-Origin" header and in 13 it's
53 53 # simply "Origin".
54 54 if "Origin" in self.request.headers:
55 55 origin = self.request.headers.get("Origin")
56 56 else:
57 57 origin = self.request.headers.get("Sec-Websocket-Origin", None)
58 58
59 59
60 60 # If there was an origin header, check to make sure it matches
61 61 # according to check_origin. When the origin is None, we assume it
62 62 # did not come from a browser and that it can be passed on.
63 63 if origin is not None and not self.check_origin(origin):
64 64 self.set_status(403)
65 65 self.finish("Cross origin websockets not allowed")
66 66 return
67 67
68 68 self.stream = self.request.connection.detach()
69 69 self.stream.set_close_callback(self.on_connection_close)
70 70
71 71 if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
72 self.ws_connection = WebSocketProtocol13(
73 self, compression_options=self.get_compression_options())
72 self.ws_connection = WebSocketProtocol13(self)
74 73 self.ws_connection.accept_connection()
75 74 #--------------- BEGIN PATCH ----------------
76 75 elif (self.allow_draft76() and
77 76 "Sec-WebSocket-Version" not in self.request.headers):
78 77 self.ws_connection = WebSocketProtocol76(self)
79 78 self.ws_connection.accept_connection()
80 79 #--------------- END PATCH ----------------
81 80 else:
82 81 if not self.stream.closed():
83 82 self.stream.write(tornado.escape.utf8(
84 83 "HTTP/1.1 426 Upgrade Required\r\n"
85 84 "Sec-WebSocket-Version: 8\r\n\r\n"))
86 85 self.stream.close()
87 86
88 87 # 3.2 methods removed in 4.0:
89 88 def allow_draft76(self):
90 89 """Using this class allows draft76 connections by default"""
91 90 return True
92 91
93 92 def get_websocket_scheme(self):
94 93 """Return the url scheme used for this request, either "ws" or "wss".
95 94 This is normally decided by HTTPServer, but applications
96 95 may wish to override this if they are using an SSL proxy
97 96 that does not provide the X-Scheme header as understood
98 97 by HTTPServer.
99 98 Note that this is only used by the draft76 protocol.
100 99 """
101 100 return "wss" if self.request.protocol == "https" else "ws"
102 101
103 102
104 103
105 104 # No modifications from tornado-3.2.2 below this line
106 105
107 106 class WebSocketProtocol(object):
108 107 """Base class for WebSocket protocol versions.
109 108 """
110 109 def __init__(self, handler):
111 110 self.handler = handler
112 111 self.request = handler.request
113 112 self.stream = handler.stream
114 113 self.client_terminated = False
115 114 self.server_terminated = False
116 115
117 116 def async_callback(self, callback, *args, **kwargs):
118 117 """Wrap callbacks with this if they are used on asynchronous requests.
119 118
120 119 Catches exceptions properly and closes this WebSocket if an exception
121 120 is uncaught.
122 121 """
123 122 if args or kwargs:
124 123 callback = functools.partial(callback, *args, **kwargs)
125 124
126 125 def wrapper(*args, **kwargs):
127 126 try:
128 127 return callback(*args, **kwargs)
129 128 except Exception:
130 129 app_log.error("Uncaught exception in %s",
131 130 self.request.path, exc_info=True)
132 131 self._abort()
133 132 return wrapper
134 133
135 134 def on_connection_close(self):
136 135 self._abort()
137 136
138 137 def _abort(self):
139 138 """Instantly aborts the WebSocket connection by closing the socket"""
140 139 self.client_terminated = True
141 140 self.server_terminated = True
142 141 self.stream.close() # forcibly tear down the connection
143 142 self.close() # let the subclass cleanup
144 143
145 144
146 145 class WebSocketProtocol76(WebSocketProtocol):
147 146 """Implementation of the WebSockets protocol, version hixie-76.
148 147
149 148 This class provides basic functionality to process WebSockets requests as
150 149 specified in
151 150 http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76
152 151 """
153 152 def __init__(self, handler):
154 153 WebSocketProtocol.__init__(self, handler)
155 154 self.challenge = None
156 155 self._waiting = None
157 156
158 157 def accept_connection(self):
159 158 try:
160 159 self._handle_websocket_headers()
161 160 except ValueError:
162 161 gen_log.debug("Malformed WebSocket request received")
163 162 self._abort()
164 163 return
165 164
166 165 scheme = self.handler.get_websocket_scheme()
167 166
168 167 # draft76 only allows a single subprotocol
169 168 subprotocol_header = ''
170 169 subprotocol = self.request.headers.get("Sec-WebSocket-Protocol", None)
171 170 if subprotocol:
172 171 selected = self.handler.select_subprotocol([subprotocol])
173 172 if selected:
174 173 assert selected == subprotocol
175 174 subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
176 175
177 176 # Write the initial headers before attempting to read the challenge.
178 177 # This is necessary when using proxies (such as HAProxy), which
179 178 # need to see the Upgrade headers before passing through the
180 179 # non-HTTP traffic that follows.
181 180 self.stream.write(tornado.escape.utf8(
182 181 "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
183 182 "Upgrade: WebSocket\r\n"
184 183 "Connection: Upgrade\r\n"
185 184 "Server: TornadoServer/%(version)s\r\n"
186 185 "Sec-WebSocket-Origin: %(origin)s\r\n"
187 186 "Sec-WebSocket-Location: %(scheme)s://%(host)s%(uri)s\r\n"
188 187 "%(subprotocol)s"
189 188 "\r\n" % (dict(
190 189 version=tornado.version,
191 190 origin=self.request.headers["Origin"],
192 191 scheme=scheme,
193 192 host=self.request.host,
194 193 uri=self.request.uri,
195 194 subprotocol=subprotocol_header))))
196 195 self.stream.read_bytes(8, self._handle_challenge)
197 196
198 197 def challenge_response(self, challenge):
199 198 """Generates the challenge response that's needed in the handshake
200 199
201 200 The challenge parameter should be the raw bytes as sent from the
202 201 client.
203 202 """
204 203 key_1 = self.request.headers.get("Sec-Websocket-Key1")
205 204 key_2 = self.request.headers.get("Sec-Websocket-Key2")
206 205 try:
207 206 part_1 = self._calculate_part(key_1)
208 207 part_2 = self._calculate_part(key_2)
209 208 except ValueError:
210 209 raise ValueError("Invalid Keys/Challenge")
211 210 return self._generate_challenge_response(part_1, part_2, challenge)
212 211
213 212 def _handle_challenge(self, challenge):
214 213 try:
215 214 challenge_response = self.challenge_response(challenge)
216 215 except ValueError:
217 216 gen_log.debug("Malformed key data in WebSocket request")
218 217 self._abort()
219 218 return
220 219 self._write_response(challenge_response)
221 220
222 221 def _write_response(self, challenge):
223 222 self.stream.write(challenge)
224 223 self.async_callback(self.handler.open)(*self.handler.open_args, **self.handler.open_kwargs)
225 224 self._receive_message()
226 225
227 226 def _handle_websocket_headers(self):
228 227 """Verifies all invariant- and required headers
229 228
230 229 If a header is missing or have an incorrect value ValueError will be
231 230 raised
232 231 """
233 232 fields = ("Origin", "Host", "Sec-Websocket-Key1",
234 233 "Sec-Websocket-Key2")
235 234 if not all(map(lambda f: self.request.headers.get(f), fields)):
236 235 raise ValueError("Missing/Invalid WebSocket headers")
237 236
238 237 def _calculate_part(self, key):
239 238 """Processes the key headers and calculates their key value.
240 239
241 240 Raises ValueError when feed invalid key."""
242 241 # pyflakes complains about variable reuse if both of these lines use 'c'
243 242 number = int(''.join(c for c in key if c.isdigit()))
244 243 spaces = len([c2 for c2 in key if c2.isspace()])
245 244 try:
246 245 key_number = number // spaces
247 246 except (ValueError, ZeroDivisionError):
248 247 raise ValueError
249 248 return struct.pack(">I", key_number)
250 249
251 250 def _generate_challenge_response(self, part_1, part_2, part_3):
252 251 m = hashlib.md5()
253 252 m.update(part_1)
254 253 m.update(part_2)
255 254 m.update(part_3)
256 255 return m.digest()
257 256
258 257 def _receive_message(self):
259 258 self.stream.read_bytes(1, self._on_frame_type)
260 259
261 260 def _on_frame_type(self, byte):
262 261 frame_type = ord(byte)
263 262 if frame_type == 0x00:
264 263 self.stream.read_until(b"\xff", self._on_end_delimiter)
265 264 elif frame_type == 0xff:
266 265 self.stream.read_bytes(1, self._on_length_indicator)
267 266 else:
268 267 self._abort()
269 268
270 269 def _on_end_delimiter(self, frame):
271 270 if not self.client_terminated:
272 271 self.async_callback(self.handler.on_message)(
273 272 frame[:-1].decode("utf-8", "replace"))
274 273 if not self.client_terminated:
275 274 self._receive_message()
276 275
277 276 def _on_length_indicator(self, byte):
278 277 if ord(byte) != 0x00:
279 278 self._abort()
280 279 return
281 280 self.client_terminated = True
282 281 self.close()
283 282
284 283 def write_message(self, message, binary=False):
285 284 """Sends the given message to the client of this Web Socket."""
286 285 if binary:
287 286 raise ValueError(
288 287 "Binary messages not supported by this version of websockets")
289 288 if isinstance(message, unicode_type):
290 289 message = message.encode("utf-8")
291 290 assert isinstance(message, bytes_type)
292 291 self.stream.write(b"\x00" + message + b"\xff")
293 292
294 293 def write_ping(self, data):
295 294 """Send ping frame."""
296 295 raise ValueError("Ping messages not supported by this version of websockets")
297 296
298 297 def close(self):
299 298 """Closes the WebSocket connection."""
300 299 if not self.server_terminated:
301 300 if not self.stream.closed():
302 301 self.stream.write("\xff\x00")
303 302 self.server_terminated = True
304 303 if self.client_terminated:
305 304 if self._waiting is not None:
306 305 self.stream.io_loop.remove_timeout(self._waiting)
307 306 self._waiting = None
308 307 self.stream.close()
309 308 elif self._waiting is None:
310 309 self._waiting = self.stream.io_loop.add_timeout(
311 310 time.time() + 5, self._abort)
312 311
General Comments 0
You need to be logged in to leave comments. Login now