diff --git a/IPython/frontend/consoleapp.py b/IPython/frontend/consoleapp.py index c9fa362..8b544d8 100644 --- a/IPython/frontend/consoleapp.py +++ b/IPython/frontend/consoleapp.py @@ -24,6 +24,7 @@ Authors: import atexit import json import os +import shutil import signal import sys import uuid @@ -38,7 +39,7 @@ from IPython.zmq.blockingkernelmanager import BlockingKernelManager from IPython.utils.path import filefind from IPython.utils.py3compat import str_to_bytes from IPython.utils.traitlets import ( - Dict, List, Unicode, CUnicode, Int, CBool, Any + Dict, List, Unicode, CUnicode, Int, CBool, Any, CaselessStrEnum ) from IPython.zmq.ipkernel import ( flags as ipkernel_flags, @@ -151,12 +152,27 @@ class IPythonConsoleApp(Configurable): # create requested profiles by default, if they don't exist: auto_create = CBool(True) # connection info: - ip = Unicode(LOCALHOST, config=True, + + transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True) + + ip = Unicode(config=True, help="""Set the kernel\'s IP address [default localhost]. 
If the IP address is something other than localhost, then Consoles on other machines will be able to connect to the Kernel, so be careful!""" ) + def _ip_default(self): + if self.transport == 'tcp': + return LOCALHOST + else: + # this can fire early if ip is given, + # in which case our return value is meaningless + if not hasattr(self, 'profile_dir'): + return '' + ipcdir = os.path.join(self.profile_dir.security_dir, 'kernel-%s' % os.getpid()) + os.makedirs(ipcdir) + atexit.register(lambda : shutil.rmtree(ipcdir)) + return os.path.join(ipcdir, 'ipc') sshserver = Unicode('', config=True, help="""The SSH server to use to connect to the kernel.""") @@ -166,11 +182,11 @@ class IPythonConsoleApp(Configurable): hb_port = Int(0, config=True, help="set the heartbeat port [default: random]") shell_port = Int(0, config=True, - help="set the shell (XREP) port [default: random]") + help="set the shell (ROUTER) port [default: random]") iopub_port = Int(0, config=True, help="set the iopub (PUB) port [default: random]") stdin_port = Int(0, config=True, - help="set the stdin (XREQ) port [default: random]") + help="set the stdin (DEALER) port [default: random]") connection_file = Unicode('', config=True, help="""JSON file in which to store connection info [default: kernel-.json] @@ -256,10 +272,10 @@ class IPythonConsoleApp(Configurable): return self.log.debug(u"Loading connection file %s", fname) with open(fname) as f: - s = f.read() - cfg = json.loads(s) - if self.ip == LOCALHOST and 'ip' in cfg: - # not overridden by config or cl_args + cfg = json.load(f) + + self.transport = cfg.get('transport', 'tcp') + if 'ip' in cfg: self.ip = cfg['ip'] for channel in ('hb', 'shell', 'iopub', 'stdin'): name = channel + '_port' @@ -268,12 +284,17 @@ class IPythonConsoleApp(Configurable): setattr(self, name, cfg[name]) if 'key' in cfg: self.config.Session.key = str_to_bytes(cfg['key']) + def init_ssh(self): """set up ssh tunnels, if needed.""" if not self.sshserver and not self.sshkey: return 
+ if self.transport != 'tcp': + self.log.error("Can only use ssh tunnels with TCP sockets, not %s", self.transport) + return + if self.sshkey and not self.sshserver: # specifying just the key implies that we are connecting directly self.sshserver = self.ip @@ -326,6 +347,7 @@ class IPythonConsoleApp(Configurable): # Create a KernelManager and start a kernel. self.kernel_manager = self.kernel_manager_class( + transport=self.transport, ip=self.ip, shell_port=self.shell_port, iopub_port=self.iopub_port, diff --git a/IPython/zmq/entry_point.py b/IPython/zmq/entry_point.py index bff59b3..86f2391 100644 --- a/IPython/zmq/entry_point.py +++ b/IPython/zmq/entry_point.py @@ -21,7 +21,7 @@ from IPython.utils.py3compat import bytes_to_str from parentpoller import ParentPollerWindows def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, hb_port=0, - ip=LOCALHOST, key=b''): + ip=LOCALHOST, key=b'', transport='tcp'): """Generates a JSON config file, including the selection of random ports. Parameters @@ -54,17 +54,26 @@ def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, fname = tempfile.mktemp('.json') # Find open ports as necessary. 
+ ports = [] ports_needed = int(shell_port <= 0) + int(iopub_port <= 0) + \ int(stdin_port <= 0) + int(hb_port <= 0) - for i in xrange(ports_needed): - sock = socket.socket() - sock.bind(('', 0)) - ports.append(sock) - for i, sock in enumerate(ports): - port = sock.getsockname()[1] - sock.close() - ports[i] = port + if transport == 'tcp': + for i in range(ports_needed): + sock = socket.socket() + sock.bind(('', 0)) + ports.append(sock) + for i, sock in enumerate(ports): + port = sock.getsockname()[1] + sock.close() + ports[i] = port + else: + N = 1 + for i in range(ports_needed): + while os.path.exists("%s-%s" % (ip, str(N))): + N += 1 + ports.append(N) + N += 1 if shell_port <= 0: shell_port = ports.pop(0) if iopub_port <= 0: @@ -81,6 +90,7 @@ def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, ) cfg['ip'] = ip cfg['key'] = bytes_to_str(key) + cfg['transport'] = transport with open(fname, 'w') as f: f.write(json.dumps(cfg, indent=2)) diff --git a/IPython/zmq/heartbeat.py b/IPython/zmq/heartbeat.py index 20509ec..049483e 100644 --- a/IPython/zmq/heartbeat.py +++ b/IPython/zmq/heartbeat.py @@ -12,6 +12,7 @@ # Imports #----------------------------------------------------------------------------- +import os import socket import sys from threading import Thread @@ -28,21 +29,29 @@ from IPython.utils.localinterfaces import LOCALHOST class Heartbeat(Thread): "A simple ping-pong style heartbeat that runs in a thread." 
- def __init__(self, context, addr=(LOCALHOST, 0)): + def __init__(self, context, addr=('tcp', LOCALHOST, 0)): Thread.__init__(self) self.context = context - self.ip, self.port = addr + self.transport, self.ip, self.port = addr if self.port == 0: - s = socket.socket() - # '*' means all interfaces to 0MQ, which is '' to socket.socket - s.bind(('' if self.ip == '*' else self.ip, 0)) - self.port = s.getsockname()[1] - s.close() + if addr[0] == 'tcp': + s = socket.socket() + # '*' means all interfaces to 0MQ, which is '' to socket.socket + s.bind(('' if self.ip == '*' else self.ip, 0)) + self.port = s.getsockname()[1] + s.close() + elif addr[0] == 'ipc': + self.port = 1 + while os.path.exists("%s-%s" % (self.ip, self.port)): + self.port = self.port + 1 + else: + raise ValueError("Unrecognized zmq transport: %s" % addr[0]) self.addr = (self.ip, self.port) self.daemon = True def run(self): self.socket = self.context.socket(zmq.REP) - self.socket.bind('tcp://%s:%i' % self.addr) + c = ':' if self.transport == 'tcp' else '-' + self.socket.bind('%s://%s' % (self.transport, self.ip) + c + str(self.port)) zmq.device(zmq.FORWARDER, self.socket, self.socket) diff --git a/IPython/zmq/kernelapp.py b/IPython/zmq/kernelapp.py index feec06c..672646a 100644 --- a/IPython/zmq/kernelapp.py +++ b/IPython/zmq/kernelapp.py @@ -35,8 +35,10 @@ from IPython.utils import io from IPython.utils.localinterfaces import LOCALHOST from IPython.utils.path import filefind from IPython.utils.py3compat import str_to_bytes -from IPython.utils.traitlets import (Any, Instance, Dict, Unicode, Integer, Bool, - DottedObjectName) +from IPython.utils.traitlets import ( + Any, Instance, Dict, Unicode, Integer, Bool, CaselessStrEnum, + DottedObjectName, +) from IPython.utils.importstring import import_item # local imports from IPython.zmq.entry_point import write_connection_file @@ -109,6 +111,7 @@ class KernelApp(BaseIPythonApplication): self.config_file_specified = False # connection info: + transport = 
CaselessStrEnum(['tcp', 
'ipc'], default_value='tcp', config=True) ip = Unicode(LOCALHOST, config=True, help="Set the IP or interface on which the kernel will listen.") hb_port = Integer(0, config=True, help="set the heartbeat port [default: random]") @@ -154,11 +157,19 @@ class KernelApp(BaseIPythonApplication): self.poller = ParentPollerUnix() def _bind_socket(self, s, port): - iface = 'tcp://%s' % self.ip - if port <= 0: - port = s.bind_to_random_port(iface) - else: - s.bind(iface + ':%i'%port) + iface = '%s://%s' % (self.transport, self.ip) + if self.transport == 'tcp': + if port <= 0: + port = s.bind_to_random_port(iface) + else: + s.bind(iface + ':%i' % port) + else: + # ipc has no random-port bind; use the first unused <ip>-<N> path (N >= 1) + if port <= 0: + port = 1 + while os.path.exists("%s-%i" % (self.ip, port)): + port = port + 1 + s.bind("%s-%i" % (iface, port)) return port def load_connection_file(self): @@ -174,6 +185,7 @@ class KernelApp(BaseIPythonApplication): with open(fname) as f: s = f.read() cfg = json.loads(s) + self.transport = cfg.get('transport', self.transport) if self.ip == LOCALHOST and 'ip' in cfg: # not overridden by config or cl_args self.ip = cfg['ip'] @@ -191,7 +203,7 @@ class KernelApp(BaseIPythonApplication): cf = os.path.join(self.profile_dir.security_dir, self.connection_file) else: cf = self.connection_file - write_connection_file(cf, ip=self.ip, key=self.session.key, + write_connection_file(cf, ip=self.ip, key=self.session.key, transport=self.transport, shell_port=self.shell_port, stdin_port=self.stdin_port, hb_port=self.hb_port, iopub_port=self.iopub_port) @@ -204,6 +216,19 @@ class KernelApp(BaseIPythonApplication): os.remove(cf) except (IOError, OSError): pass + + self._cleanup_ipc_files() + + def _cleanup_ipc_files(self): + """cleanup ipc files if we wrote them""" + if self.transport != 'ipc': + return + for port in (self.shell_port, self.iopub_port, self.stdin_port, self.hb_port): + ipcfile = "%s-%i" % (self.ip, port) + try: + 
os.remove(ipcfile) + except (IOError, OSError): + pass def init_connection_file(self): if not self.connection_file: @@ -238,7 +263,7 @@ class KernelApp(BaseIPythonApplication): # heartbeat doesn't share context, because it mustn't be 
blocked # by the GIL, which is accessed by libzmq when freeing zero-copy messages hb_ctx = zmq.Context() - self.heartbeat = Heartbeat(hb_ctx, (self.ip, self.hb_port)) + self.heartbeat = Heartbeat(hb_ctx, (self.transport, self.ip, self.hb_port)) self.hb_port = self.heartbeat.port self.log.debug("Heartbeat REP Channel on port: %i"%self.hb_port) self.heartbeat.start() diff --git a/IPython/zmq/kernelmanager.py b/IPython/zmq/kernelmanager.py index b175af5..200dc85 100644 --- a/IPython/zmq/kernelmanager.py +++ b/IPython/zmq/kernelmanager.py @@ -37,7 +37,7 @@ from zmq.eventloop import ioloop, zmqstream from IPython.config.loader import Config from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS from IPython.utils.traitlets import ( - HasTraits, Any, Instance, Type, Unicode, Integer, Bool + HasTraits, Any, Instance, Type, Unicode, Integer, Bool, CaselessStrEnum ) from IPython.utils.py3compat import str_to_bytes from IPython.zmq.entry_point import write_connection_file @@ -103,7 +103,7 @@ class ZMQSocketChannel(Thread): The ZMQ context to use. session : :class:`session.Session` The session to use. - address : tuple - Standard (ip, port) tuple that the kernel is listening on. + address : zmq url + The kernel's zmq url (e.g. 'tcp://127.0.0.1:5555'), or an (ip, port) tuple, which is converted to a tcp url. """ super(ZMQSocketChannel, self).__init__() @@ -111,9 +111,11 @@ class ZMQSocketChannel(Thread): self.context = context self.session = session - if address[1] == 0: - message = 'The port number for a channel cannot be 0.' - raise InvalidPortNumber(message) + if isinstance(address, tuple): + if address[1] == 0: + message = 'The port number for a channel cannot be 0.' + raise InvalidPortNumber(message) + address = "tcp://%s:%i" % address self._address = address atexit.register(self._notice_exit) @@ -149,10 +151,7 @@ class ZMQSocketChannel(Thread): @property def address(self): - """Get the channel's address as an (ip, port) tuple. 
- - By the default, the address is (localhost, 0), where 0 means a random - port. 
+ """Get the channel's address as a zmq url string ('tcp://127.0.0.1:5555'). """ return self._address @@ -196,7 +195,7 @@ class ShellSocketChannel(ZMQSocketChannel): """The thread's main activity. Call start() instead.""" self.socket = self.context.socket(zmq.DEALER) self.socket.setsockopt(zmq.IDENTITY, self.session.bsession) - self.socket.connect('tcp://%s:%i' % self.address) + self.socket.connect(self.address) self.stream = zmqstream.ZMQStream(self.socket, self.ioloop) self.stream.on_recv(self._handle_recv) self._run_loop() @@ -390,7 +389,7 @@ class SubSocketChannel(ZMQSocketChannel): self.socket = self.context.socket(zmq.SUB) self.socket.setsockopt(zmq.SUBSCRIBE,b'') self.socket.setsockopt(zmq.IDENTITY, self.session.bsession) - self.socket.connect('tcp://%s:%i' % self.address) + self.socket.connect(self.address) self.stream = zmqstream.ZMQStream(self.socket, self.ioloop) self.stream.on_recv(self._handle_recv) self._run_loop() @@ -456,7 +455,7 @@ class StdInSocketChannel(ZMQSocketChannel): """The thread's main activity. Call start() instead.""" self.socket = self.context.socket(zmq.DEALER) self.socket.setsockopt(zmq.IDENTITY, self.session.bsession) - self.socket.connect('tcp://%s:%i' % self.address) + self.socket.connect(self.address) self.stream = zmqstream.ZMQStream(self.socket, self.ioloop) self.stream.on_recv(self._handle_recv) self._run_loop() @@ -515,7 +514,7 @@ class HBSocketChannel(ZMQSocketChannel): self.socket.close() self.socket = self.context.socket(zmq.REQ) self.socket.setsockopt(zmq.LINGER, 0) - self.socket.connect('tcp://%s:%i' % self.address) + self.socket.connect(self.address) self.poller.register(self.socket, zmq.POLLIN) @@ -654,6 +653,10 @@ class KernelManager(HasTraits): # The addresses for the communication channels. 
connection_file = Unicode('') + + transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp') + + ip = Unicode(LOCALHOST) def _ip_changed(self, name, old, new): if new == '*': @@ -742,7 +745,20 @@ class KernelManager(HasTraits): self._connection_file_written = False try: os.remove(self.connection_file) - except OSError: + except (IOError, OSError): + pass + + self._cleanup_ipc_files() + + def _cleanup_ipc_files(self): + """cleanup ipc files if we wrote them""" + if self.transport != 'ipc': + return + for port in (self.shell_port, self.iopub_port, self.stdin_port, self.hb_port): + ipcfile = "%s-%i" % (self.ip, port) + try: + os.remove(ipcfile) + except (IOError, OSError): pass def load_connection_file(self): @@ -750,6 +766,7 @@ class KernelManager(HasTraits): with open(self.connection_file) as f: cfg = json.loads(f.read()) + self.transport = cfg.get('transport', 'tcp') self.ip = cfg['ip'] self.shell_port = cfg['shell_port'] self.stdin_port = cfg['stdin_port'] @@ -762,7 +779,7 @@ class KernelManager(HasTraits): if self._connection_file_written: return self.connection_file,cfg = write_connection_file(self.connection_file, - ip=self.ip, key=self.session.key, + transport=self.transport, ip=self.ip, key=self.session.key, stdin_port=self.stdin_port, iopub_port=self.iopub_port, shell_port=self.shell_port, hb_port=self.hb_port) # write_connection_file also sets default ports: @@ -789,7 +806,7 @@ class KernelManager(HasTraits): **kw : optional See respective options for IPython and Python kernels. """ - if self.ip not in LOCAL_IPS: + if self.transport == 'tcp' and self.ip not in LOCAL_IPS: raise RuntimeError("Can only launch a kernel on a local interface. " "Make sure that the '*_address' attributes are " "configured properly. 
" @@ -956,13 +973,21 @@ class KernelManager(HasTraits): # Channels used for communication with the kernel: #-------------------------------------------------------------------------- + def _make_url(self, port): + """make a zmq url with a port""" + if self.transport == 'tcp': + return "tcp://%s:%i" % (self.ip, port) + else: + return "%s://%s-%s" % (self.transport, self.ip, port) + @property def shell_channel(self): """Get the REQ socket channel object to make requests of the kernel.""" if self._shell_channel is None: self._shell_channel = self.shell_channel_class(self.context, - self.session, - (self.ip, self.shell_port)) + self.session, + self._make_url(self.shell_port), + ) return self._shell_channel @property @@ -971,7 +996,8 @@ class KernelManager(HasTraits): if self._sub_channel is None: self._sub_channel = self.sub_channel_class(self.context, self.session, - (self.ip, self.iopub_port)) + self._make_url(self.iopub_port), + ) return self._sub_channel @property @@ -979,8 +1005,9 @@ class KernelManager(HasTraits): """Get the REP socket channel object to handle stdin (raw_input).""" if self._stdin_channel is None: self._stdin_channel = self.stdin_channel_class(self.context, - self.session, - (self.ip, self.stdin_port)) + self.session, + self._make_url(self.stdin_port), + ) return self._stdin_channel @property @@ -989,6 +1016,7 @@ class KernelManager(HasTraits): kernel is alive.""" if self._hb_channel is None: self._hb_channel = self.hb_channel_class(self.context, - self.session, - (self.ip, self.hb_port)) + self.session, + self._make_url(self.hb_port), + ) return self._hb_channel