From 600762a669331a4cb5fcdcedfe6da1df655aa986 2014-06-28 00:47:56
From: MinRK <benjaminrk@gmail.com>
Date: 2014-06-28 00:47:56
Subject: [PATCH] make CORS configurable

allows setting CORS headers.

- cors_origin sets Access-Control-Allow-Origin directly
- cors_origin_pat allows setting Access-Control-Allow-Origin via regular expression, since the header itself only accepts a single origin or '*' (no lists or patterns)
- cors_credentials sets Access-Control-Allow-Credentials: true

To allow CORS from everywhere:

    ipython notebook --NotebookApp.cors_origin='*'


---

diff --git a/IPython/html/base/handlers.py b/IPython/html/base/handlers.py
index d8d107c..604e0f0 100644
--- a/IPython/html/base/handlers.py
+++ b/IPython/html/base/handlers.py
@@ -153,6 +153,48 @@ class IPythonHandler(AuthenticatedHandler):
         return self.notebook_manager.notebook_dir
     
     #---------------------------------------------------------------
+    # CORS
+    #---------------------------------------------------------------
+    
+    @property
+    def cors_origin(self):
+        """Fixed Access-Control-Allow-Origin value from settings ('' if unset)"""
+        return self.settings.get('cors_origin', '')
+    
+    @property
+    def cors_origin_pat(self):
+        """Regex (from settings) matched against the request origin; None if unset"""
+        return self.settings.get('cors_origin_pat', None)
+    
+    @property
+    def cors_credentials(self):
+        """Whether to send Access-Control-Allow-Credentials: true (default False)"""
+        return self.settings.get('cors_credentials', False)
+    
+    def set_default_headers(self):
+        """Add CORS headers from settings; cors_origin takes precedence over cors_origin_pat"""
+        super(IPythonHandler, self).set_default_headers()
+        if self.cors_origin:
+            self.set_header("Access-Control-Allow-Origin", self.cors_origin)
+        elif self.cors_origin_pat:
+            origin = self.get_origin()
+            if origin and self.cors_origin_pat.match(origin):
+                self.set_header("Access-Control-Allow-Origin", origin)
+        if self.cors_credentials:
+            self.set_header("Access-Control-Allow-Credentials", 'true')
+    
+    def get_origin(self):
+        """Return the request's origin header value, or None if absent."""
+        # WebSocket protocol version 8 clients send "Sec-Websocket-Origin",
+        # while version 13 clients send plain "Origin"; the latter is
+        # preferred when both are present.
+        if "Origin" in self.request.headers:
+            origin = self.request.headers.get("Origin")
+        else:
+            origin = self.request.headers.get("Sec-Websocket-Origin", None)
+        return origin
+    
+    #---------------------------------------------------------------
     # template rendering
     #---------------------------------------------------------------
     
diff --git a/IPython/html/base/zmqhandlers.py b/IPython/html/base/zmqhandlers.py
index 8999b26..dc18bb8 100644
--- a/IPython/html/base/zmqhandlers.py
+++ b/IPython/html/base/zmqhandlers.py
@@ -15,6 +15,8 @@ try:
 except ImportError:
     from Cookie import SimpleCookie  # Py 2
 import logging
+
+import tornado
 from tornado import web
 from tornado import websocket
 
@@ -26,29 +28,35 @@ from .handlers import IPythonHandler
 
 
 class ZMQStreamHandler(websocket.WebSocketHandler):
-
-    def same_origin(self):
-        """Check to see that origin and host match in the headers."""
-
-        # The difference between version 8 and 13 is that in 8 the
-        # client sends a "Sec-Websocket-Origin" header and in 13 it's
-        # simply "Origin".
-        if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8"):
-            origin_header = self.request.headers.get("Sec-Websocket-Origin")
-        else:
-            origin_header = self.request.headers.get("Origin")
+    
+    def check_origin(self, origin):
+        """Check Origin == Host or CORS origins."""
+        if self.cors_origin == '*':
+            return True
 
         host = self.request.headers.get("Host")
 
         # If no header is provided, assume we can't verify origin
-        if(origin_header is None or host is None):
+        if(origin is None or host is None):
+            return False
+        
+        host_origin = "{0}://{1}".format(self.request.protocol, host)
+        
+        # OK if origin matches host
+        if origin == host_origin:
+            return True
+        
+        # Check the configured CORS origins. A cors_origin of '*' has
+        # already been accepted at the top of this method, so a set
+        # cors_origin only admits an exact match here.
+        if self.cors_origin:
+            return self.cors_origin == origin
+        elif self.cors_origin_pat:
+            # cors_origin_pat is a compiled regular expression
+            return bool(self.cors_origin_pat.match(origin))
+        else:
+            # No CORS origin configured: deny the cross-origin request
             return False
-
-        parsed_origin = urlparse(origin_header)
-        origin = parsed_origin.netloc
-
-        # Check to see that origin matches host directly, including ports
-        return origin == host
 
     def clear_cookie(self, *args, **kwargs):
         """meaningless for websockets"""
@@ -96,13 +104,21 @@ class ZMQStreamHandler(websocket.WebSocketHandler):
 
 
 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
+    def set_default_headers(self):
+        """Override IPythonHandler.set_default_headers to add no CORS headers.
+        
+        Those HTTP response headers don't make sense for websockets.
+        """
+        pass
 
     def open(self, kernel_id):
         self.kernel_id = cast_unicode(kernel_id, 'ascii')
         # Check to see that origin matches host directly, including ports
-        if not self.same_origin():
-            self.log.warn("Cross Origin WebSocket Attempt.")
-            raise web.HTTPError(404)
+        # Tornado >= 4 calls check_origin() itself; older versions need this check
+        if tornado.version_info[0] < 4:
+            if not self.check_origin(self.get_origin()):
+                self.log.warning("Cross Origin WebSocket Attempt.")
+                raise web.HTTPError(404)
 
         self.session = Session(config=self.config)
         self.save_on_message = self.on_message
diff --git a/IPython/html/notebookapp.py b/IPython/html/notebookapp.py
index 0533f8a..cf5288b 100644
--- a/IPython/html/notebookapp.py
+++ b/IPython/html/notebookapp.py
@@ -12,6 +12,7 @@ import json
 import logging
 import os
 import random
+import re
 import select
 import signal
 import socket
@@ -333,8 +334,34 @@ class NotebookApp(BaseIPythonApplication):
             self.file_to_run = base
             self.notebook_dir = path
 
-    # Network related information.
-
+    # Network related information
+    
+    cors_origin = Unicode('', config=True,
+        help="""Set the Access-Control-Allow-Origin header
+        
+        Use '*' to allow any origin to access your server.
+        
+        Takes precedence over cors_origin_pat if both are set.
+        """
+    )
+    
+    cors_origin_pat = Unicode('', config=True,
+        help="""Use a regular expression for the Access-Control-Allow-Origin header
+        
+        Requests from an origin matching the expression will get replies with:
+        
+            Access-Control-Allow-Origin: origin
+        
+        where `origin` is the origin of the request.
+        
+        Ignored if cors_origin is set.
+        """
+    )
+    
+    cors_credentials = Bool(False, config=True,
+        help="Set the Access-Control-Allow-Credentials: true header"
+    )
+    
     ip = Unicode('localhost', config=True,
         help="The IP address the notebook server will listen on."
     )
@@ -622,6 +649,10 @@ class NotebookApp(BaseIPythonApplication):
     
     def init_webapp(self):
         """initialize tornado webapp and httpserver"""
+        self.webapp_settings['cors_origin'] = self.cors_origin
+        # An empty pattern must stay unset: re.compile('') matches every origin
+        self.webapp_settings['cors_origin_pat'] = re.compile(self.cors_origin_pat) if self.cors_origin_pat else None
+        self.webapp_settings['cors_credentials'] = self.cors_credentials
         self.web_app = NotebookWebApplication(
             self, self.kernel_manager, self.notebook_manager, 
             self.cluster_manager, self.session_manager, self.kernel_spec_manager,