##// END OF EJS Templates
Adding tested ipc support to MultiKernelManager.
Brian E. Granger -
Show More
@@ -85,7 +85,9 b' class MultiKernelManager(LoggingConfigurable):'
85 kernel_id = kwargs.pop('kernel_id', unicode(uuid.uuid4()))
85 kernel_id = kwargs.pop('kernel_id', unicode(uuid.uuid4()))
86 if kernel_id in self:
86 if kernel_id in self:
87 raise DuplicateKernelError('Kernel already exists: %s' % kernel_id)
87 raise DuplicateKernelError('Kernel already exists: %s' % kernel_id)
88 # use base KernelManager for each Kernel
88 # kernel_manager_factory is the constructor for the KernelManager
89 # subclass we are using. It can be configured as any Configurable,
90 # including things like its transport and ip.
89 km = self.kernel_manager_factory(connection_file=os.path.join(
91 km = self.kernel_manager_factory(connection_file=os.path.join(
90 self.connection_dir, "kernel-%s.json" % kernel_id),
92 self.connection_dir, "kernel-%s.json" % kernel_id),
91 config=self.config,
93 config=self.config,
@@ -175,7 +177,7 b' class MultiKernelManager(LoggingConfigurable):'
175 else:
177 else:
176 raise KeyError("Kernel with id not found: %s" % kernel_id)
178 raise KeyError("Kernel with id not found: %s" % kernel_id)
177
179
178 def get_connection_data(self, kernel_id):
180 def get_connection_info(self, kernel_id):
179 """Return a dictionary of connection data for a kernel.
181 """Return a dictionary of connection data for a kernel.
180
182
181 Parameters
183 Parameters
@@ -192,18 +194,28 b' class MultiKernelManager(LoggingConfigurable):'
192 shell_port, hb_port).
194 shell_port, hb_port).
193 """
195 """
194 km = self.get_kernel(kernel_id)
196 km = self.get_kernel(kernel_id)
195 return dict(ip=km.ip,
197 return dict(transport=km.transport,
198 ip=km.ip,
196 shell_port=km.shell_port,
199 shell_port=km.shell_port,
197 iopub_port=km.iopub_port,
200 iopub_port=km.iopub_port,
198 stdin_port=km.stdin_port,
201 stdin_port=km.stdin_port,
199 hb_port=km.hb_port,
202 hb_port=km.hb_port,
200 )
203 )
201
204
202 def create_connected_stream(self, ip, port, socket_type):
205 def _make_url(self, transport, ip, port):
206 """Make a ZeroMQ URL for a given transport, ip and port."""
207 if transport == 'tcp':
208 return "tcp://%s:%i" % (ip, port)
209 else:
210 return "%s://%s-%s" % (transport, ip, port)
211
212 def _create_connected_stream(self, kernel_id, socket_type):
213 """Create a connected ZMQStream for a kernel."""
214 cinfo = self.get_connection_info(kernel_id)
215 url = self._make_url(cinfo['transport'], cinfo['ip'], cinfo['port'])
203 sock = self.context.socket(socket_type)
216 sock = self.context.socket(socket_type)
204 addr = "tcp://%s:%i" % (ip, port)
217 self.log.info("Connecting to: %s" % url)
205 self.log.info("Connecting to: %s" % addr)
218 sock.connect(url)
206 sock.connect(addr)
207 return ZMQStream(sock)
219 return ZMQStream(sock)
208
220
209 def create_iopub_stream(self, kernel_id):
221 def create_iopub_stream(self, kernel_id):
@@ -218,10 +230,7 b' class MultiKernelManager(LoggingConfigurable):'
218 =======
230 =======
219 stream : ZMQStream
231 stream : ZMQStream
220 """
232 """
221 kdata = self.get_connection_data(kernel_id)
233 iopub_stream = self._create_connected_stream(kernel_id, zmq.SUB)
222 iopub_stream = self.create_connected_stream(
223 kdata['ip'], kdata['iopub_port'], zmq.SUB
224 )
225 iopub_stream.socket.setsockopt(zmq.SUBSCRIBE, b'')
234 iopub_stream.socket.setsockopt(zmq.SUBSCRIBE, b'')
226 return iopub_stream
235 return iopub_stream
227
236
@@ -237,10 +246,7 b' class MultiKernelManager(LoggingConfigurable):'
237 =======
246 =======
238 stream : ZMQStream
247 stream : ZMQStream
239 """
248 """
240 kdata = self.get_connection_data(kernel_id)
249 shell_stream = self._create_connected_stream(kernel_id, zmq.DEALER)
241 shell_stream = self.create_connected_stream(
242 kdata['ip'], kdata['shell_port'], zmq.DEALER
243 )
244 return shell_stream
250 return shell_stream
245
251
246 def create_hb_stream(self, kernel_id):
252 def create_hb_stream(self, kernel_id):
@@ -255,10 +261,7 b' class MultiKernelManager(LoggingConfigurable):'
255 =======
261 =======
256 stream : ZMQStream
262 stream : ZMQStream
257 """
263 """
258 kdata = self.get_connection_data(kernel_id)
264 hb_stream = self._create_connected_stream(kernel_id, zmq.REQ)
259 hb_stream = self.create_connected_stream(
260 kdata['ip'], kdata['hb_port'], zmq.REQ
261 )
262 return hb_stream
265 return hb_stream
263
266
264
267
@@ -2,12 +2,23 b''
2
2
3 from unittest import TestCase
3 from unittest import TestCase
4
4
5 from IPython.config.loader import Config
5 from IPython.frontend.html.notebook.kernelmanager import MultiKernelManager
6 from IPython.frontend.html.notebook.kernelmanager import MultiKernelManager
7 from IPython.zmq.kernelmanager import KernelManager
6
8
7 class TestKernelManager(TestCase):
9 class TestKernelManager(TestCase):
8
10
9 def test_km_lifecycle(self):
11 def _get_tcp_km(self):
10 km = MultiKernelManager()
12 return MultiKernelManager()
13
14 def _get_ipc_km(self):
15 c = Config()
16 c.KernelManager.transport = 'ipc'
17 c.KernelManager.ip = 'test'
18 km = MultiKernelManager(config=c)
19 return km
20
21 def _run_lifecycle(self, km):
11 kid = km.start_kernel()
22 kid = km.start_kernel()
12 self.assertTrue(kid in km)
23 self.assertTrue(kid in km)
13 self.assertTrue(kid in km.list_kernel_ids())
24 self.assertTrue(kid in km.list_kernel_ids())
@@ -15,17 +26,42 b' class TestKernelManager(TestCase):'
15 km.restart_kernel(kid)
26 km.restart_kernel(kid)
16 self.assertTrue(kid in km.list_kernel_ids())
27 self.assertTrue(kid in km.list_kernel_ids())
17 km.interrupt_kernel(kid)
28 km.interrupt_kernel(kid)
18 km.kill_kernel(kid)
29 k = km.get_kernel(kid)
30 self.assertTrue(isinstance(k, KernelManager))
31 km.shutdown_kernel(kid)
19 self.assertTrue(not kid in km)
32 self.assertTrue(not kid in km)
20
33
34 def test_km_tcp(self):
35 km = self._get_tcp_km()
36 self._run_lifecycle(km)
37
38 def test_km_ipc(self):
39 km = self._get_ipc_km()
40 self._run_lifecycle(km)
41
42 def test_tcp_cinfo(self):
43 km = self._get_tcp_km()
21 kid = km.start_kernel()
44 kid = km.start_kernel()
22 cdata = km.get_connection_data(kid)
45 k = km.get_kernel(kid)
23 self.assertEqual('127.0.0.1', cdata['ip'])
46 cinfo = km.get_connection_info(kid)
24 self.assertTrue('stdin_port' in cdata)
47 self.assertEqual('tcp', cinfo['transport'])
25 self.assertTrue('iopub_port' in cdata)
48 self.assertEqual('127.0.0.1', cinfo['ip'])
26 self.assertTrue('shell_port' in cdata)
49 self.assertTrue('stdin_port' in cinfo)
27 self.assertTrue('hb_port' in cdata)
50 self.assertTrue('iopub_port' in cinfo)
28 km.get_kernel(kid)
51 self.assertTrue('shell_port' in cinfo)
29 km.kill_kernel(kid)
52 self.assertTrue('hb_port' in cinfo)
53 km.shutdown_kernel(kid)
30
54
55 def test_ipc_cinfo(self):
56 km = self._get_ipc_km()
57 kid = km.start_kernel()
58 k = km.get_kernel(kid)
59 cinfo = km.get_connection_info(kid)
60 self.assertEqual('ipc', cinfo['transport'])
61 self.assertEqual('test', cinfo['ip'])
62 self.assertTrue('stdin_port' in cinfo)
63 self.assertTrue('iopub_port' in cinfo)
64 self.assertTrue('shell_port' in cinfo)
65 self.assertTrue('hb_port' in cinfo)
66 km.shutdown_kernel(kid)
31
67
@@ -35,6 +35,7 b' from zmq.eventloop import ioloop, zmqstream'
35
35
36 # Local imports.
36 # Local imports.
37 from IPython.config.loader import Config
37 from IPython.config.loader import Config
38 from IPython.config.configurable import Configurable
38 from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
39 from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
39 from IPython.utils.traitlets import (
40 from IPython.utils.traitlets import (
40 HasTraits, Any, Instance, Type, Unicode, Integer, Bool, CaselessStrEnum
41 HasTraits, Any, Instance, Type, Unicode, Integer, Bool, CaselessStrEnum
@@ -638,7 +639,7 b' class HBSocketChannel(ZMQSocketChannel):'
638 # Main kernel manager class
639 # Main kernel manager class
639 #-----------------------------------------------------------------------------
640 #-----------------------------------------------------------------------------
640
641
641 class KernelManager(HasTraits):
642 class KernelManager(Configurable):
642 """ Manages a kernel for a frontend.
643 """ Manages a kernel for a frontend.
643
644
644 The SUB channel is for the frontend to receive messages published by the
645 The SUB channel is for the frontend to receive messages published by the
@@ -649,9 +650,6 b' class KernelManager(HasTraits):'
649 The REP channel is for the kernel to request stdin (raw_input) from the
650 The REP channel is for the kernel to request stdin (raw_input) from the
650 frontend.
651 frontend.
651 """
652 """
652 # config object for passing to child configurables
653 config = Instance(Config)
654
655 # The PyZMQ Context to use for communication with the kernel.
653 # The PyZMQ Context to use for communication with the kernel.
656 context = Instance(zmq.Context)
654 context = Instance(zmq.Context)
657 def _context_default(self):
655 def _context_default(self):
@@ -668,10 +666,9 b' class KernelManager(HasTraits):'
668 # The addresses for the communication channels.
666 # The addresses for the communication channels.
669 connection_file = Unicode('')
667 connection_file = Unicode('')
670
668
671 transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp')
669 transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)
672
670
673
671 ip = Unicode(LOCALHOST, config=True)
674 ip = Unicode(LOCALHOST)
675 def _ip_changed(self, name, old, new):
672 def _ip_changed(self, name, old, new):
676 if new == '*':
673 if new == '*':
677 self.ip = '0.0.0.0'
674 self.ip = '0.0.0.0'
@@ -768,12 +765,12 b' class KernelManager(HasTraits):'
768 os.remove(ipcfile)
765 os.remove(ipcfile)
769 except (IOError, OSError):
766 except (IOError, OSError):
770 pass
767 pass
771
768
772 def load_connection_file(self):
769 def load_connection_file(self):
773 """load connection info from JSON dict in self.connection_file"""
770 """load connection info from JSON dict in self.connection_file"""
774 with open(self.connection_file) as f:
771 with open(self.connection_file) as f:
775 cfg = json.loads(f.read())
772 cfg = json.loads(f.read())
776
773
777 from pprint import pprint
774 from pprint import pprint
778 pprint(cfg)
775 pprint(cfg)
779 self.transport = cfg.get('transport', 'tcp')
776 self.transport = cfg.get('transport', 'tcp')
General Comments 0
You need to be logged in to leave comments. Login now