From af13023afb7fc05fa40e4c9bc45a480d0167d50b 2012-09-29 08:14:53
From: Bussonnier Matthias <bussonniermatthias@gmail.com>
Date: 2012-09-29 08:14:53
Subject: [PATCH] Merge pull request #1868 from minrk/ipc

enable IPC transport for kernels

works with the qtconsole

Config is a bit clumsy, because when the transport is IPC, 'ip' is interpreted as a filesystem path prefix rather than an IP address.

Notebook does not yet expose the option, because it's still not well integrated into the rest of the config universe.
---

diff --git a/IPython/frontend/consoleapp.py b/IPython/frontend/consoleapp.py
index 4d229ce..8b544d8 100644
--- a/IPython/frontend/consoleapp.py
+++ b/IPython/frontend/consoleapp.py
@@ -24,6 +24,7 @@ Authors:
 import atexit
 import json
 import os
+import shutil
 import signal
 import sys
 import uuid
@@ -38,7 +39,7 @@ from IPython.zmq.blockingkernelmanager import BlockingKernelManager
 from IPython.utils.path import filefind
 from IPython.utils.py3compat import str_to_bytes
 from IPython.utils.traitlets import (
-    Dict, List, Unicode, CUnicode, Int, CBool, Any
+    Dict, List, Unicode, CUnicode, Int, CBool, Any, CaselessStrEnum
 )
 from IPython.zmq.ipkernel import (
     flags as ipkernel_flags,
@@ -151,12 +152,27 @@ class IPythonConsoleApp(Configurable):
     # create requested profiles by default, if they don't exist:
     auto_create = CBool(True)
     # connection info:
-    ip = Unicode(LOCALHOST, config=True,
+    
+    transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)
+    
+    ip = Unicode(config=True,
         help="""Set the kernel\'s IP address [default localhost].
         If the IP address is something other than localhost, then
         Consoles on other machines will be able to connect
         to the Kernel, so be careful!"""
     )
+    def _ip_default(self):
+        if self.transport == 'tcp':
+            return LOCALHOST
+        else:
+            # this can fire early if ip is given,
+            # in which case our return value is meaningless
+            if not hasattr(self, 'profile_dir'):
+                return ''
+            ipcdir = os.path.join(self.profile_dir.security_dir, 'kernel-%s' % os.getpid())
+            os.makedirs(ipcdir)
+            atexit.register(lambda : shutil.rmtree(ipcdir, ignore_errors=True))
+            return os.path.join(ipcdir, 'ipc')
     
     sshserver = Unicode('', config=True,
         help="""The SSH server to use to connect to the kernel.""")
@@ -256,10 +272,10 @@ class IPythonConsoleApp(Configurable):
             return
         self.log.debug(u"Loading connection file %s", fname)
         with open(fname) as f:
-            s = f.read()
-        cfg = json.loads(s)
-        if self.ip == LOCALHOST and 'ip' in cfg:
-            # not overridden by config or cl_args
+            cfg = json.load(f)
+        
+        self.transport = cfg.get('transport', 'tcp')
+        if 'ip' in cfg:
             self.ip = cfg['ip']
         for channel in ('hb', 'shell', 'iopub', 'stdin'):
             name = channel + '_port'
@@ -268,12 +284,17 @@ class IPythonConsoleApp(Configurable):
                 setattr(self, name, cfg[name])
         if 'key' in cfg:
             self.config.Session.key = str_to_bytes(cfg['key'])
+        
     
     def init_ssh(self):
         """set up ssh tunnels, if needed."""
         if not self.sshserver and not self.sshkey:
             return
         
+        if self.transport != 'tcp':
+            self.log.error("Can only use ssh tunnels with TCP sockets, not %s", self.transport)
+            return
+        
         if self.sshkey and not self.sshserver:
             # specifying just the key implies that we are connecting directly
             self.sshserver = self.ip
@@ -326,6 +347,7 @@ class IPythonConsoleApp(Configurable):
 
         # Create a KernelManager and start a kernel.
         self.kernel_manager = self.kernel_manager_class(
+                                transport=self.transport,
                                 ip=self.ip,
                                 shell_port=self.shell_port,
                                 iopub_port=self.iopub_port,
diff --git a/IPython/zmq/entry_point.py b/IPython/zmq/entry_point.py
index ca85a44..dff9286 100644
--- a/IPython/zmq/entry_point.py
+++ b/IPython/zmq/entry_point.py
@@ -21,7 +21,7 @@ from IPython.utils.py3compat import bytes_to_str
 from parentpoller import ParentPollerWindows
 
 def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, hb_port=0,
-                         ip=LOCALHOST, key=b''):
+                         ip=LOCALHOST, key=b'', transport='tcp'):
     """Generates a JSON config file, including the selection of random ports.
     
     Parameters
@@ -54,17 +54,26 @@ def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, 
         fname = tempfile.mktemp('.json')
     
     # Find open ports as necessary.
+    
     ports = []
     ports_needed = int(shell_port <= 0) + int(iopub_port <= 0) + \
                    int(stdin_port <= 0) + int(hb_port <= 0)
-    for i in xrange(ports_needed):
-        sock = socket.socket()
-        sock.bind(('', 0))
-        ports.append(sock)
-    for i, sock in enumerate(ports):
-        port = sock.getsockname()[1]
-        sock.close()
-        ports[i] = port
+    if transport == 'tcp':
+        for i in range(ports_needed):
+            sock = socket.socket()
+            sock.bind(('', 0))
+            ports.append(sock)
+        for i, sock in enumerate(ports):
+            port = sock.getsockname()[1]
+            sock.close()
+            ports[i] = port
+    else:
+        N = 1
+        for i in range(ports_needed):
+            while os.path.exists("%s-%s" % (ip, str(N))):
+                N += 1
+            ports.append(N)
+            N += 1
     if shell_port <= 0:
         shell_port = ports.pop(0)
     if iopub_port <= 0:
@@ -81,6 +90,7 @@ def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, 
               )
     cfg['ip'] = ip
     cfg['key'] = bytes_to_str(key)
+    cfg['transport'] = transport
     
     with open(fname, 'w') as f:
         f.write(json.dumps(cfg, indent=2))
diff --git a/IPython/zmq/heartbeat.py b/IPython/zmq/heartbeat.py
index 20509ec..049483e 100644
--- a/IPython/zmq/heartbeat.py
+++ b/IPython/zmq/heartbeat.py
@@ -12,6 +12,7 @@
 # Imports
 #-----------------------------------------------------------------------------
 
+import os
 import socket
 import sys
 from threading import Thread
@@ -28,21 +29,31 @@ from IPython.utils.localinterfaces import LOCALHOST
 class Heartbeat(Thread):
     "A simple ping-pong style heartbeat that runs in a thread."
 
-    def __init__(self, context, addr=(LOCALHOST, 0)):
+    def __init__(self, context, addr=('tcp', LOCALHOST, 0)):
         Thread.__init__(self)
         self.context = context
-        self.ip, self.port = addr
+        self.transport, self.ip, self.port = addr
         if self.port == 0:
-            s = socket.socket()
-            # '*' means all interfaces to 0MQ, which is '' to socket.socket
-            s.bind(('' if self.ip == '*' else self.ip, 0))
-            self.port = s.getsockname()[1]
-            s.close()
+            if addr[0] == 'tcp':
+                s = socket.socket()
+                # '*' means all interfaces to 0MQ, which is '' to socket.socket
+                s.bind(('' if self.ip == '*' else self.ip, 0))
+                self.port = s.getsockname()[1]
+                s.close()
+            elif addr[0] == 'ipc':
+                # ipc "ports" are integer suffixes on the ip path; like
+                # write_connection_file, start at 1 and skip paths in use
+                self.port = 1
+                while os.path.exists("%s-%i" % (self.ip, self.port)):
+                    self.port = self.port + 1
+            else:
+                raise ValueError("Unrecognized zmq transport: %s" % addr[0])
         self.addr = (self.ip, self.port)
         self.daemon = True
 
     def run(self):
         self.socket = self.context.socket(zmq.REP)
-        self.socket.bind('tcp://%s:%i' % self.addr)
+        c = ':' if self.transport == 'tcp' else '-'
+        self.socket.bind('%s://%s' % (self.transport, self.ip) + c + str(self.port))
         zmq.device(zmq.FORWARDER, self.socket, self.socket)
 
diff --git a/IPython/zmq/kernelapp.py b/IPython/zmq/kernelapp.py
index 80e151f..869f1dd 100644
--- a/IPython/zmq/kernelapp.py
+++ b/IPython/zmq/kernelapp.py
@@ -35,8 +35,10 @@ from IPython.utils import io
 from IPython.utils.localinterfaces import LOCALHOST
 from IPython.utils.path import filefind
 from IPython.utils.py3compat import str_to_bytes
-from IPython.utils.traitlets import (Any, Instance, Dict, Unicode, Integer, Bool,
-                                        DottedObjectName)
+from IPython.utils.traitlets import (
+    Any, Instance, Dict, Unicode, Integer, Bool, CaselessStrEnum,
+    DottedObjectName,
+)
 from IPython.utils.importstring import import_item
 # local imports
 from IPython.zmq.entry_point import write_connection_file
@@ -109,6 +111,7 @@ class KernelApp(BaseIPythonApplication):
         self.config_file_specified = False
         
     # connection info:
+    transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)
     ip = Unicode(LOCALHOST, config=True,
         help="Set the IP or interface on which the kernel will listen.")
     hb_port = Integer(0, config=True, help="set the heartbeat port [default: random]")
@@ -154,11 +157,20 @@ class KernelApp(BaseIPythonApplication):
             self.poller = ParentPollerUnix()
 
     def _bind_socket(self, s, port):
-        iface = 'tcp://%s' % self.ip
-        if port <= 0:
-            port = s.bind_to_random_port(iface)
-        else:
-            s.bind(iface + ':%i'%port)
+        iface = '%s://%s' % (self.transport, self.ip)
+        if self.transport == 'tcp':
+            if port <= 0:
+                port = s.bind_to_random_port(iface)
+            else:
+                s.bind(iface + ':%i' % port)
+        else:
+            # ipc "ports" are integer suffixes on the ip path: when none was
+            # requested, take the first unused one, starting at 1
+            if port <= 0:
+                port = 1
+                while os.path.exists("%s-%i" % (self.ip, port)):
+                    port = port + 1
+            s.bind("%s-%i" % (iface, port))
         return port
 
     def load_connection_file(self):
@@ -174,6 +186,7 @@ class KernelApp(BaseIPythonApplication):
         with open(fname) as f:
             s = f.read()
         cfg = json.loads(s)
+        self.transport = cfg.get('transport', self.transport)
         if self.ip == LOCALHOST and 'ip' in cfg:
             # not overridden by config or cl_args
             self.ip = cfg['ip']
@@ -191,7 +204,7 @@ class KernelApp(BaseIPythonApplication):
             cf = os.path.join(self.profile_dir.security_dir, self.connection_file)
         else:
             cf = self.connection_file
-        write_connection_file(cf, ip=self.ip, key=self.session.key,
+        write_connection_file(cf, ip=self.ip, key=self.session.key, transport=self.transport,
         shell_port=self.shell_port, stdin_port=self.stdin_port, hb_port=self.hb_port,
         iopub_port=self.iopub_port)
         
@@ -204,6 +217,19 @@ class KernelApp(BaseIPythonApplication):
             os.remove(cf)
         except (IOError, OSError):
             pass
+        
+        self._cleanup_ipc_files()
+    
+    def _cleanup_ipc_files(self):
+        """cleanup ipc files if we wrote them"""
+        if self.transport != 'ipc':
+            return
+        for port in (self.shell_port, self.iopub_port, self.stdin_port, self.hb_port):
+            ipcfile = "%s-%i" % (self.ip, port)
+            try:
+                os.remove(ipcfile)
+            except (IOError, OSError):
+                pass
     
     def init_connection_file(self):
         if not self.connection_file:
@@ -238,7 +264,7 @@ class KernelApp(BaseIPythonApplication):
         # heartbeat doesn't share context, because it mustn't be blocked
         # by the GIL, which is accessed by libzmq when freeing zero-copy messages
         hb_ctx = zmq.Context()
-        self.heartbeat = Heartbeat(hb_ctx, (self.ip, self.hb_port))
+        self.heartbeat = Heartbeat(hb_ctx, (self.transport, self.ip, self.hb_port))
         self.hb_port = self.heartbeat.port
         self.log.debug("Heartbeat REP Channel on port: %i"%self.hb_port)
         self.heartbeat.start()
diff --git a/IPython/zmq/kernelmanager.py b/IPython/zmq/kernelmanager.py
index ede90a6..8eb5efd 100644
--- a/IPython/zmq/kernelmanager.py
+++ b/IPython/zmq/kernelmanager.py
@@ -37,7 +37,7 @@ from zmq.eventloop import ioloop, zmqstream
 from IPython.config.loader import Config
 from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
 from IPython.utils.traitlets import (
-    HasTraits, Any, Instance, Type, Unicode, Integer, Bool
+    HasTraits, Any, Instance, Type, Unicode, Integer, Bool, CaselessStrEnum
 )
 from IPython.utils.py3compat import str_to_bytes
 from IPython.zmq.entry_point import write_connection_file
@@ -103,7 +103,7 @@ class ZMQSocketChannel(Thread):
             The ZMQ context to use.
         session : :class:`session.Session`
             The session to use.
-        address : tuple
+        address : zmq url (e.g. 'tcp://127.0.0.1:5555') or legacy (ip, port) tuple
             Standard (ip, port) tuple that the kernel is listening on.
         """
         super(ZMQSocketChannel, self).__init__()
@@ -111,9 +111,11 @@ class ZMQSocketChannel(Thread):
 
         self.context = context
         self.session = session
-        if address[1] == 0:
-            message = 'The port number for a channel cannot be 0.'
-            raise InvalidPortNumber(message)
+        if isinstance(address, tuple):
+            if address[1] == 0:
+                message = 'The port number for a channel cannot be 0.'
+                raise InvalidPortNumber(message)
+            address = "tcp://%s:%i" % address
         self._address = address
         atexit.register(self._notice_exit)
     
@@ -149,10 +151,7 @@ class ZMQSocketChannel(Thread):
 
     @property
     def address(self):
-        """Get the channel's address as an (ip, port) tuple.
-
-        By the default, the address is (localhost, 0), where 0 means a random
-        port.
+        """Get the channel's address as a zmq url string ('tcp://127.0.0.1:5555').
         """
         return self._address
 
@@ -196,7 +195,7 @@ class ShellSocketChannel(ZMQSocketChannel):
         """The thread's main activity.  Call start() instead."""
         self.socket = self.context.socket(zmq.DEALER)
         self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
-        self.socket.connect('tcp://%s:%i' % self.address)
+        self.socket.connect(self.address)
         self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
         self.stream.on_recv(self._handle_recv)
         self._run_loop()
@@ -396,7 +395,7 @@ class SubSocketChannel(ZMQSocketChannel):
         self.socket = self.context.socket(zmq.SUB)
         self.socket.setsockopt(zmq.SUBSCRIBE,b'')
         self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
-        self.socket.connect('tcp://%s:%i' % self.address)
+        self.socket.connect(self.address)
         self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
         self.stream.on_recv(self._handle_recv)
         self._run_loop()
@@ -462,7 +461,7 @@ class StdInSocketChannel(ZMQSocketChannel):
         """The thread's main activity.  Call start() instead."""
         self.socket = self.context.socket(zmq.DEALER)
         self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
-        self.socket.connect('tcp://%s:%i' % self.address)
+        self.socket.connect(self.address)
         self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
         self.stream.on_recv(self._handle_recv)
         self._run_loop()
@@ -521,7 +520,7 @@ class HBSocketChannel(ZMQSocketChannel):
             self.socket.close()
         self.socket = self.context.socket(zmq.REQ)
         self.socket.setsockopt(zmq.LINGER, 0)
-        self.socket.connect('tcp://%s:%i' % self.address)
+        self.socket.connect(self.address)
         
         self.poller.register(self.socket, zmq.POLLIN)
     
@@ -660,6 +659,10 @@ class KernelManager(HasTraits):
 
     # The addresses for the communication channels.
     connection_file = Unicode('')
+    
+    transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp')
+    
+    
     ip = Unicode(LOCALHOST)
     def _ip_changed(self, name, old, new):
         if new == '*':
@@ -748,7 +751,20 @@ class KernelManager(HasTraits):
             self._connection_file_written = False
             try:
                 os.remove(self.connection_file)
-            except OSError:
+            except (IOError, OSError):
+                pass
+            
+            self._cleanup_ipc_files()
+    
+    def _cleanup_ipc_files(self):
+        """cleanup ipc files if we wrote them"""
+        if self.transport != 'ipc':
+            return
+        for port in (self.shell_port, self.iopub_port, self.stdin_port, self.hb_port):
+            ipcfile = "%s-%i" % (self.ip, port)
+            try:
+                os.remove(ipcfile)
+            except (IOError, OSError):
                 pass
     
     def load_connection_file(self):
@@ -756,6 +772,9 @@ class KernelManager(HasTraits):
         with open(self.connection_file) as f:
             cfg = json.loads(f.read())
         
+        # connection files written before IPC support have no 'transport' key;
+        # treat those as TCP
+        self.transport = cfg.get('transport', 'tcp')
         self.ip = cfg['ip']
         self.shell_port = cfg['shell_port']
         self.stdin_port = cfg['stdin_port']
@@ -768,7 +787,7 @@ class KernelManager(HasTraits):
         if self._connection_file_written:
             return
         self.connection_file,cfg = write_connection_file(self.connection_file,
-            ip=self.ip, key=self.session.key,
+            transport=self.transport, ip=self.ip, key=self.session.key,
             stdin_port=self.stdin_port, iopub_port=self.iopub_port,
             shell_port=self.shell_port, hb_port=self.hb_port)
         # write_connection_file also sets default ports:
@@ -795,7 +814,7 @@ class KernelManager(HasTraits):
         **kw : optional
              See respective options for IPython and Python kernels.
         """
-        if self.ip not in LOCAL_IPS:
+        if self.transport == 'tcp' and self.ip not in LOCAL_IPS:
             raise RuntimeError("Can only launch a kernel on a local interface. "
                                "Make sure that the '*_address' attributes are "
                                "configured properly. "
@@ -974,13 +993,21 @@ class KernelManager(HasTraits):
     # Channels used for communication with the kernel:
     #--------------------------------------------------------------------------
 
+    def _make_url(self, port):
+        """make a zmq url with a port"""
+        if self.transport == 'tcp':
+            return "tcp://%s:%i" % (self.ip, port)
+        else:
+            return "%s://%s-%s" % (self.transport, self.ip, port)
+
     @property
     def shell_channel(self):
         """Get the REQ socket channel object to make requests of the kernel."""
         if self._shell_channel is None:
             self._shell_channel = self.shell_channel_class(self.context,
-                                                         self.session,
-                                                         (self.ip, self.shell_port))
+                                                           self.session,
+                                                           self._make_url(self.shell_port),
+            )
         return self._shell_channel
 
     @property
@@ -989,7 +1016,8 @@ class KernelManager(HasTraits):
         if self._sub_channel is None:
             self._sub_channel = self.sub_channel_class(self.context,
                                                        self.session,
-                                                       (self.ip, self.iopub_port))
+                                                       self._make_url(self.iopub_port),
+            )
         return self._sub_channel
 
     @property
@@ -997,8 +1025,9 @@ class KernelManager(HasTraits):
         """Get the REP socket channel object to handle stdin (raw_input)."""
         if self._stdin_channel is None:
             self._stdin_channel = self.stdin_channel_class(self.context,
-                                                       self.session,
-                                                       (self.ip, self.stdin_port))
+                                                           self.session,
+                                                           self._make_url(self.stdin_port),
+            )
         return self._stdin_channel
 
     @property
@@ -1007,6 +1036,7 @@ class KernelManager(HasTraits):
         kernel is alive."""
         if self._hb_channel is None:
             self._hb_channel = self.hb_channel_class(self.context,
-                                                       self.session,
-                                                       (self.ip, self.hb_port))
+                                                     self.session,
+                                                     self._make_url(self.hb_port),
+            )
         return self._hb_channel