##// END OF EJS Templates
add ssh tunneling to Engine...
MinRK -
Show More
@@ -116,6 +116,7 b" flags.update(boolean_flag('secure', 'IPControllerApp.secure',"
116 aliases = dict(
116 aliases = dict(
117 secure = 'IPControllerApp.secure',
117 secure = 'IPControllerApp.secure',
118 ssh = 'IPControllerApp.ssh_server',
118 ssh = 'IPControllerApp.ssh_server',
119 enginessh = 'IPControllerApp.engine_ssh_server',
119 location = 'IPControllerApp.location',
120 location = 'IPControllerApp.location',
120
121
121 ident = 'Session.session',
122 ident = 'Session.session',
@@ -158,6 +159,11 b' class IPControllerApp(BaseParallelApplication):'
158 processes. It should be of the form: [user@]server[:port]. The
159 processes. It should be of the form: [user@]server[:port]. The
159 Controller's listening addresses must be accessible from the ssh server""",
160 Controller's listening addresses must be accessible from the ssh server""",
160 )
161 )
162 engine_ssh_server = Unicode(u'', config=True,
163 help="""ssh url for engines to use when connecting to the Controller
164 processes. It should be of the form: [user@]server[:port]. The
165 Controller's listening addresses must be accessible from the ssh server""",
166 )
161 location = Unicode(u'', config=True,
167 location = Unicode(u'', config=True,
162 help="""The external IP or domain name of the Controller, used for disambiguating
168 help="""The external IP or domain name of the Controller, used for disambiguating
163 engine and client connections.""",
169 engine and client connections.""",
@@ -218,6 +224,8 b' class IPControllerApp(BaseParallelApplication):'
218 c.HubFactory.engine_ip = ip
224 c.HubFactory.engine_ip = ip
219 c.HubFactory.regport = int(ports)
225 c.HubFactory.regport = int(ports)
220 self.location = cfg['location']
226 self.location = cfg['location']
227 if not self.engine_ssh_server:
228 self.engine_ssh_server = cfg['ssh']
221 # load client config
229 # load client config
222 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-client.json')) as f:
230 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-client.json')) as f:
223 cfg = json.loads(f.read())
231 cfg = json.loads(f.read())
@@ -226,7 +234,8 b' class IPControllerApp(BaseParallelApplication):'
226 c.HubFactory.client_transport = xport
234 c.HubFactory.client_transport = xport
227 ip,ports = addr.split(':')
235 ip,ports = addr.split(':')
228 c.HubFactory.client_ip = ip
236 c.HubFactory.client_ip = ip
229 self.ssh_server = cfg['ssh']
237 if not self.ssh_server:
238 self.ssh_server = cfg['ssh']
230 assert int(ports) == c.HubFactory.regport, "regport mismatch"
239 assert int(ports) == c.HubFactory.regport, "regport mismatch"
231
240
232 def init_hub(self):
241 def init_hub(self):
@@ -271,6 +280,7 b' class IPControllerApp(BaseParallelApplication):'
271 self.save_connection_dict('ipcontroller-client.json', cdict)
280 self.save_connection_dict('ipcontroller-client.json', cdict)
272 edict = cdict
281 edict = cdict
273 edict['url']="%s://%s:%s"%((f.client_transport, f.client_ip, f.regport))
282 edict['url']="%s://%s:%s"%((f.client_transport, f.client_ip, f.regport))
283 edict['ssh'] = self.engine_ssh_server
274 self.save_connection_dict('ipcontroller-engine.json', edict)
284 self.save_connection_dict('ipcontroller-engine.json', edict)
275
285
276 #
286 #
@@ -118,6 +118,8 b' aliases = dict('
118 keyfile = 'Session.keyfile',
118 keyfile = 'Session.keyfile',
119
119
120 url = 'EngineFactory.url',
120 url = 'EngineFactory.url',
121 ssh = 'EngineFactory.sshserver',
122 sshkey = 'EngineFactory.sshkey',
121 ip = 'EngineFactory.ip',
123 ip = 'EngineFactory.ip',
122 transport = 'EngineFactory.transport',
124 transport = 'EngineFactory.transport',
123 port = 'EngineFactory.regport',
125 port = 'EngineFactory.regport',
@@ -192,6 +194,40 b' class IPEngineApp(BaseParallelApplication):'
192 self.profile_dir.security_dir,
194 self.profile_dir.security_dir,
193 self.url_file_name
195 self.url_file_name
194 )
196 )
197
198 def load_connector_file(self):
199 """load config from a JSON connector file,
200 at a *lower* priority than command-line/config files.
201 """
202
203 self.log.info("Loading url_file %r"%self.url_file)
204 config = self.config
205
206 with open(self.url_file) as f:
207 d = json.loads(f.read())
208
209 try:
210 config.Session.key
211 except AttributeError:
212 if d['exec_key']:
213 config.Session.key = asbytes(d['exec_key'])
214
215 try:
216 config.EngineFactory.location
217 except AttributeError:
218 config.EngineFactory.location = d['location']
219
220 d['url'] = disambiguate_url(d['url'], config.EngineFactory.location)
221 try:
222 config.EngineFactory.url
223 except AttributeError:
224 config.EngineFactory.url = d['url']
225
226 try:
227 config.EngineFactory.sshserver
228 except AttributeError:
229 config.EngineFactory.sshserver = d['ssh']
230
195 def init_engine(self):
231 def init_engine(self):
196 # This is the working dir by now.
232 # This is the working dir by now.
197 sys.path.insert(0, '')
233 sys.path.insert(0, '')
@@ -219,14 +255,7 b' class IPEngineApp(BaseParallelApplication):'
219 time.sleep(0.1)
255 time.sleep(0.1)
220
256
221 if os.path.exists(self.url_file):
257 if os.path.exists(self.url_file):
222 self.log.info("Loading url_file %r"%self.url_file)
258 self.load_connector_file()
223 with open(self.url_file) as f:
224 d = json.loads(f.read())
225 if d['exec_key']:
226 config.Session.key = asbytes(d['exec_key'])
227 d['url'] = disambiguate_url(d['url'], d['location'])
228 config.EngineFactory.url = d['url']
229 config.EngineFactory.location = d['location']
230 elif not url_specified:
259 elif not url_specified:
231 self.log.critical("Fatal: url file never arrived: %s"%self.url_file)
260 self.log.critical("Fatal: url file never arrived: %s"%self.url_file)
232 self.exit(1)
261 self.exit(1)
@@ -253,7 +282,7 b' class IPEngineApp(BaseParallelApplication):'
253 except:
282 except:
254 self.log.error("Couldn't start the Engine", exc_info=True)
283 self.log.error("Couldn't start the Engine", exc_info=True)
255 self.exit(1)
284 self.exit(1)
256
285
257 def forward_logging(self):
286 def forward_logging(self):
258 if self.log_url:
287 if self.log_url:
259 self.log.info("Forwarding logging to %s"%self.log_url)
288 self.log.info("Forwarding logging to %s"%self.log_url)
@@ -265,7 +294,7 b' class IPEngineApp(BaseParallelApplication):'
265 handler.setLevel(self.log_level)
294 handler.setLevel(self.log_level)
266 self.log.addHandler(handler)
295 self.log.addHandler(handler)
267 self._log_handler = handler
296 self._log_handler = handler
268 #
297
269 def init_mpi(self):
298 def init_mpi(self):
270 global mpi
299 global mpi
271 self.mpi = MPI(config=self.config)
300 self.mpi = MPI(config=self.config)
@@ -17,12 +17,16 b' from __future__ import print_function'
17
17
18 import sys
18 import sys
19 import time
19 import time
20 from getpass import getpass
20
21
21 import zmq
22 import zmq
22 from zmq.eventloop import ioloop, zmqstream
23 from zmq.eventloop import ioloop, zmqstream
23
24
25 from IPython.external.ssh import tunnel
24 # internal
26 # internal
25 from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode, CBytes
27 from IPython.utils.traitlets import (
28 Instance, Dict, Int, Type, CFloat, Unicode, CBytes, Bool
29 )
26 # from IPython.utils.localinterfaces import LOCALHOST
30 # from IPython.utils.localinterfaces import LOCALHOST
27
31
28 from IPython.parallel.controller.heartmonitor import Heart
32 from IPython.parallel.controller.heartmonitor import Heart
@@ -50,6 +54,12 b' class EngineFactory(RegistrationFactory):'
50 timeout=CFloat(2,config=True,
54 timeout=CFloat(2,config=True,
51 help="""The time (in seconds) to wait for the Controller to respond
55 help="""The time (in seconds) to wait for the Controller to respond
52 to registration requests before giving up.""")
56 to registration requests before giving up.""")
57 sshserver=Unicode(config=True,
58 help="""The SSH server to use for tunneling connections to the Controller.""")
59 sshkey=Unicode(config=True,
60 help="""The SSH keyfile to use when tunneling connections to the Controller.""")
61 paramiko=Bool(sys.platform == 'win32', config=True,
62 help="""Whether to use paramiko instead of openssh for tunnels.""")
53
63
54 # not configurable:
64 # not configurable:
55 user_ns=Dict()
65 user_ns=Dict()
@@ -61,28 +71,70 b' class EngineFactory(RegistrationFactory):'
61 ident = Unicode()
71 ident = Unicode()
62 def _ident_changed(self, name, old, new):
72 def _ident_changed(self, name, old, new):
63 self.bident = asbytes(new)
73 self.bident = asbytes(new)
74 using_ssh=Bool(False)
64
75
65
76
66 def __init__(self, **kwargs):
77 def __init__(self, **kwargs):
67 super(EngineFactory, self).__init__(**kwargs)
78 super(EngineFactory, self).__init__(**kwargs)
68 self.ident = self.session.session
79 self.ident = self.session.session
69 ctx = self.context
80
81 def init_connector(self):
82 """construct connection function, which handles tunnels."""
83 self.using_ssh = bool(self.sshkey or self.sshserver)
70
84
71 reg = ctx.socket(zmq.XREQ)
85 if self.sshkey and not self.sshserver:
72 reg.setsockopt(zmq.IDENTITY, self.bident)
86 # We are using ssh directly to the controller, tunneling localhost to localhost
73 reg.connect(self.url)
87 self.sshserver = self.url.split('://')[1].split(':')[0]
74 self.registrar = zmqstream.ZMQStream(reg, self.loop)
88
89 if self.using_ssh:
90 if tunnel.try_passwordless_ssh(self.sshserver, self.sshkey, self.paramiko):
91 password=False
92 else:
93 password = getpass("SSH Password for %s: "%self.sshserver)
94 else:
95 password = False
96
97 def connect(s, url):
98 url = disambiguate_url(url, self.location)
99 if self.using_ssh:
100 self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver))
101 return tunnel.tunnel_connection(s, url, self.sshserver,
102 keyfile=self.sshkey, paramiko=self.paramiko,
103 password=password,
104 )
105 else:
106 return s.connect(url)
107
108 def maybe_tunnel(url):
109 """like connect, but don't complete the connection (for use by heartbeat)"""
110 url = disambiguate_url(url, self.location)
111 if self.using_ssh:
112 self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver))
113 url,tunnelobj = tunnel.open_tunnel(url, self.sshserver,
114 keyfile=self.sshkey, paramiko=self.paramiko,
115 password=password,
116 )
117 return url
118 return connect, maybe_tunnel
75
119
76 def register(self):
120 def register(self):
77 """send the registration_request"""
121 """send the registration_request"""
78
122
79 self.log.info("Registering with controller at %s"%self.url)
123 self.log.info("Registering with controller at %s"%self.url)
124 ctx = self.context
125 connect,maybe_tunnel = self.init_connector()
126 reg = ctx.socket(zmq.XREQ)
127 reg.setsockopt(zmq.IDENTITY, self.bident)
128 connect(reg, self.url)
129 self.registrar = zmqstream.ZMQStream(reg, self.loop)
130
131
80 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
132 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
81 self.registrar.on_recv(self.complete_registration)
133 self.registrar.on_recv(lambda msg: self.complete_registration(msg, connect, maybe_tunnel))
82 # print (self.session.key)
134 # print (self.session.key)
83 self.session.send(self.registrar, "registration_request",content=content)
135 self.session.send(self.registrar, "registration_request",content=content)
84
136
85 def complete_registration(self, msg):
137 def complete_registration(self, msg, connect, maybe_tunnel):
86 # print msg
138 # print msg
87 self._abort_dc.stop()
139 self._abort_dc.stop()
88 ctx = self.context
140 ctx = self.context
@@ -94,6 +146,14 b' class EngineFactory(RegistrationFactory):'
94 if msg.content.status == 'ok':
146 if msg.content.status == 'ok':
95 self.id = int(msg.content.id)
147 self.id = int(msg.content.id)
96
148
149 # launch heartbeat
150 hb_addrs = msg.content.heartbeat
151
152 # possibly forward hb ports with tunnels
153 hb_addrs = [ maybe_tunnel(addr) for addr in hb_addrs ]
154 heart = Heart(*map(str, hb_addrs), heart_id=identity)
155 heart.start()
156
97 # create Shell Streams (MUX, Task, etc.):
157 # create Shell Streams (MUX, Task, etc.):
98 queue_addr = msg.content.mux
158 queue_addr = msg.content.mux
99 shell_addrs = [ str(queue_addr) ]
159 shell_addrs = [ str(queue_addr) ]
@@ -114,24 +174,20 b' class EngineFactory(RegistrationFactory):'
114 stream.setsockopt(zmq.IDENTITY, identity)
174 stream.setsockopt(zmq.IDENTITY, identity)
115 shell_streams = [stream]
175 shell_streams = [stream]
116 for addr in shell_addrs:
176 for addr in shell_addrs:
117 stream.connect(disambiguate_url(addr, self.location))
177 connect(stream, addr)
118 # end single stream-socket
178 # end single stream-socket
119
179
120 # control stream:
180 # control stream:
121 control_addr = str(msg.content.control)
181 control_addr = str(msg.content.control)
122 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
182 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
123 control_stream.setsockopt(zmq.IDENTITY, identity)
183 control_stream.setsockopt(zmq.IDENTITY, identity)
124 control_stream.connect(disambiguate_url(control_addr, self.location))
184 connect(control_stream, control_addr)
125
185
126 # create iopub stream:
186 # create iopub stream:
127 iopub_addr = msg.content.iopub
187 iopub_addr = msg.content.iopub
128 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
188 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
129 iopub_stream.setsockopt(zmq.IDENTITY, identity)
189 iopub_stream.setsockopt(zmq.IDENTITY, identity)
130 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
190 connect(iopub_stream, iopub_addr)
131
132 # launch heartbeat
133 hb_addrs = msg.content.heartbeat
134 # print (hb_addrs)
135
191
136 # # Redirect input streams and set a display hook.
192 # # Redirect input streams and set a display hook.
137 if self.out_stream_factory:
193 if self.out_stream_factory:
@@ -147,9 +203,6 b' class EngineFactory(RegistrationFactory):'
147 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
203 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
148 loop=loop, user_ns = self.user_ns, log=self.log)
204 loop=loop, user_ns = self.user_ns, log=self.log)
149 self.kernel.start()
205 self.kernel.start()
150 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
151 heart = Heart(*map(str, hb_addrs), heart_id=identity)
152 heart.start()
153
206
154
207
155 else:
208 else:
General Comments 0
You need to be logged in to leave comments. Login now