diff --git a/IPython/html/base/zmqhandlers.py b/IPython/html/base/zmqhandlers.py
index 362c586..1b741ae 100644
--- a/IPython/html/base/zmqhandlers.py
+++ b/IPython/html/base/zmqhandlers.py
@@ -43,29 +43,17 @@ from .handlers import IPythonHandler
class ZMQStreamHandler(websocket.WebSocketHandler):
- def check_origin(self):
- """Check origin from headers."""
- origin_header = self.request.headers["Origin"]
- host = self.request.headers["Host"]
+ def is_cross_origin(self):
+ """Return True unless the Origin header's host:port equals the Host header (a missing Origin counts as cross-origin)."""
+ origin_header = self.request.headers.get("Origin", "")
+ host = self.request.headers.get("Host")
parsed_origin = urlparse(origin_header)
origin = parsed_origin.netloc
# Check to see that origin matches host directly, including ports
- if origin != host:
- self.log.warn("Cross Origin WebSocket Attempt.")
- raise web.HTTPError(404)
- 
- 
- def _execute(self, *args, **kwargs):
- """Wrap all calls to make sure origin gets checked."""
- 
- # Check to see that origin matches host directly, including ports
- self.check_origin()
+ return origin != host  # '' (no Origin) never equals Host, whether Host is None or a real value
- # Pass on the rest of the handling by the WebSocketHandler
- super(ZMQStreamHandler, self)._execute(*args, **kwargs)
- 
def clear_cookie(self, *args, **kwargs):
"""meaningless for websockets"""
pass
@@ -114,6 +102,11 @@ class ZMQStreamHandler(websocket.WebSocketHandler):
class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
def open(self, kernel_id):
+ # Reject when the Origin header's host:port differs from Host. NOTE(review): if open() runs after the WebSocket handshake, raising HTTPError here may abort the socket rather than send a 404 -- confirm against the Tornado version in use.
+ if self.is_cross_origin():
+ self.log.warn("Cross Origin WebSocket Attempt.")
+ raise web.HTTPError(404)
+ 
self.kernel_id = cast_unicode(kernel_id, 'ascii')
self.session = Session(config=self.config)
self.save_on_message = self.on_message
@@ -142,4 +135,4 @@ class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
if self.get_current_user() is None:
self.log.warn("Couldn't authenticate WebSocket connection")
raise web.HTTPError(403)
- self.on_message = self.save_on_message
\ No newline at end of file
+ self.on_message = self.save_on_message