diff --git a/IPython/html/base/zmqhandlers.py b/IPython/html/base/zmqhandlers.py index 362c586..1b741ae 100644 --- a/IPython/html/base/zmqhandlers.py +++ b/IPython/html/base/zmqhandlers.py @@ -43,29 +43,17 @@ from .handlers import IPythonHandler class ZMQStreamHandler(websocket.WebSocketHandler): - def check_origin(self): - """Check origin from headers.""" - origin_header = self.request.headers["Origin"] - host = self.request.headers["Host"] + def is_cross_origin(self): + """Check to see that origin and host match in the headers.""" + origin_header = self.request.headers.get("Origin") + host = self.request.headers.get("Host") parsed_origin = urlparse(origin_header) origin = parsed_origin.netloc # Check to see that origin matches host directly, including ports - if origin != host: - self.log.warn("Cross Origin WebSocket Attempt.") - raise web.HTTPError(404) - - - def _execute(self, *args, **kwargs): - """Wrap all calls to make sure origin gets checked.""" - - # Check to see that origin matches host directly, including ports - self.check_origin() + return origin != host - # Pass on the rest of the handling by the WebSocketHandler - super(ZMQStreamHandler, self)._execute(*args, **kwargs) - def clear_cookie(self, *args, **kwargs): """meaningless for websockets""" pass @@ -114,6 +102,11 @@ class ZMQStreamHandler(websocket.WebSocketHandler): class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler): def open(self, kernel_id): + # Check to see that origin matches host directly, including ports + if self.is_cross_origin(): + self.log.warn("Cross Origin WebSocket Attempt.") + raise web.HTTPError(404) + self.kernel_id = cast_unicode(kernel_id, 'ascii') self.session = Session(config=self.config) self.save_on_message = self.on_message @@ -142,4 +135,4 @@ class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler): if self.get_current_user() is None: self.log.warn("Couldn't authenticate WebSocket connection") raise web.HTTPError(403) - self.on_message = self.save_on_message \ No newline at end of file + self.on_message = self.save_on_message