diff --git a/IPython/html/base/zmqhandlers.py b/IPython/html/base/zmqhandlers.py
index 0d4c95a..c651dbb 100644
--- a/IPython/html/base/zmqhandlers.py
+++ b/IPython/html/base/zmqhandlers.py
@@ -17,6 +17,11 @@ Authors:
 #-----------------------------------------------------------------------------
 
 try:
+    from urllib.parse import urlparse
+except ImportError:
+    from urlparse import urlparse
+
+try:
     from http.cookies import SimpleCookie # Py 3
 except ImportError:
     from Cookie import SimpleCookie # Py 2
@@ -37,6 +42,54 @@ from .handlers import IPythonHandler
 #-----------------------------------------------------------------------------
 
 class ZMQStreamHandler(websocket.WebSocketHandler):
+
+    def check_origin(self):
+        """Reject the request unless its origin matches its Host header.
+
+        Browsers send an ``Origin`` header on WebSocket handshakes; older
+        draft versions of the protocol send ``Sec-WebSocket-Origin`` instead.
+        The network location (host[:port]) of that origin must equal the
+        ``Host`` header. Host names are case-insensitive, so both sides are
+        lower-cased before comparing.
+
+        Raises
+        ------
+        web.HTTPError
+            404 if either header is missing or the two do not match.
+        """
+        headers = self.request.headers
+        origin_header = (headers.get("Origin")
+                         or headers.get("Sec-WebSocket-Origin"))
+        host = headers.get("Host")
+
+        # Without both headers the origin cannot be verified, so refuse the
+        # connection instead of failing with a KeyError.
+        if not origin_header or not host:
+            self.log.warning("Cross Origin WebSocket Attempt: "
+                             "missing Origin or Host header.")
+            raise web.HTTPError(404)
+
+        # Check to see that origin matches host directly, including ports
+        origin = urlparse(origin_header).netloc
+        if origin.lower() != host.lower():
+            self.log.warning("Cross Origin WebSocket Attempt from %s "
+                             "(Host: %s).", origin_header, host)
+            raise web.HTTPError(404)
+
+    def _execute(self, transforms, *args, **kwargs):
+        """Check the origin, then let WebSocketHandler handle the request.
+
+        NOTE(review): ``_execute`` is a private Tornado hook. Tornado 4.0+
+        also defines ``WebSocketHandler.check_origin(origin)``, which
+        returns a bool; ``check_origin`` above would shadow it with an
+        incompatible signature, so revisit both when upgrading Tornado.
+        """
+        # Check to see that origin matches host directly, including ports
+        self.check_origin()
+
+        # Pass on the rest of the handling (and its result) to WebSocketHandler
+        return super(ZMQStreamHandler, self)._execute(transforms, *args,
+                                                      **kwargs)
 
     def clear_cookie(self, *args, **kwargs):
         """meaningless for websockets"""