##// END OF EJS Templates
add ssh tunneling to Engine...
MinRK -
Show More
@@ -116,6 +116,7 b" flags.update(boolean_flag('secure', 'IPControllerApp.secure',"
116 116 aliases = dict(
117 117 secure = 'IPControllerApp.secure',
118 118 ssh = 'IPControllerApp.ssh_server',
119 enginessh = 'IPControllerApp.engine_ssh_server',
119 120 location = 'IPControllerApp.location',
120 121
121 122 ident = 'Session.session',
@@ -158,6 +159,11 b' class IPControllerApp(BaseParallelApplication):'
158 159 processes. It should be of the form: [user@]server[:port]. The
159 160 Controller's listening addresses must be accessible from the ssh server""",
160 161 )
162 engine_ssh_server = Unicode(u'', config=True,
163 help="""ssh url for engines to use when connecting to the Controller
164 processes. It should be of the form: [user@]server[:port]. The
165 Controller's listening addresses must be accessible from the ssh server""",
166 )
161 167 location = Unicode(u'', config=True,
162 168 help="""The external IP or domain name of the Controller, used for disambiguating
163 169 engine and client connections.""",
@@ -218,6 +224,8 b' class IPControllerApp(BaseParallelApplication):'
218 224 c.HubFactory.engine_ip = ip
219 225 c.HubFactory.regport = int(ports)
220 226 self.location = cfg['location']
227 if not self.engine_ssh_server:
228 self.engine_ssh_server = cfg['ssh']
221 229 # load client config
222 230 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-client.json')) as f:
223 231 cfg = json.loads(f.read())
@@ -226,7 +234,8 b' class IPControllerApp(BaseParallelApplication):'
226 234 c.HubFactory.client_transport = xport
227 235 ip,ports = addr.split(':')
228 236 c.HubFactory.client_ip = ip
229 self.ssh_server = cfg['ssh']
237 if not self.ssh_server:
238 self.ssh_server = cfg['ssh']
230 239 assert int(ports) == c.HubFactory.regport, "regport mismatch"
231 240
232 241 def init_hub(self):
@@ -271,6 +280,7 b' class IPControllerApp(BaseParallelApplication):'
271 280 self.save_connection_dict('ipcontroller-client.json', cdict)
272 281 edict = cdict
273 282 edict['url']="%s://%s:%s"%((f.client_transport, f.client_ip, f.regport))
283 edict['ssh'] = self.engine_ssh_server
274 284 self.save_connection_dict('ipcontroller-engine.json', edict)
275 285
276 286 #
@@ -118,6 +118,8 b' aliases = dict('
118 118 keyfile = 'Session.keyfile',
119 119
120 120 url = 'EngineFactory.url',
121 ssh = 'EngineFactory.sshserver',
122 sshkey = 'EngineFactory.sshkey',
121 123 ip = 'EngineFactory.ip',
122 124 transport = 'EngineFactory.transport',
123 125 port = 'EngineFactory.regport',
@@ -192,6 +194,40 b' class IPEngineApp(BaseParallelApplication):'
192 194 self.profile_dir.security_dir,
193 195 self.url_file_name
194 196 )
197
198 def load_connector_file(self):
199 """load config from a JSON connector file,
200 at a *lower* priority than command-line/config files.
201 """
202
203 self.log.info("Loading url_file %r"%self.url_file)
204 config = self.config
205
206 with open(self.url_file) as f:
207 d = json.loads(f.read())
208
209 try:
210 config.Session.key
211 except AttributeError:
212 if d['exec_key']:
213 config.Session.key = asbytes(d['exec_key'])
214
215 try:
216 config.EngineFactory.location
217 except AttributeError:
218 config.EngineFactory.location = d['location']
219
220 d['url'] = disambiguate_url(d['url'], config.EngineFactory.location)
221 try:
222 config.EngineFactory.url
223 except AttributeError:
224 config.EngineFactory.url = d['url']
225
226 try:
227 config.EngineFactory.sshserver
228 except AttributeError:
229 config.EngineFactory.sshserver = d['ssh']
230
195 231 def init_engine(self):
196 232 # This is the working dir by now.
197 233 sys.path.insert(0, '')
@@ -219,14 +255,7 b' class IPEngineApp(BaseParallelApplication):'
219 255 time.sleep(0.1)
220 256
221 257 if os.path.exists(self.url_file):
222 self.log.info("Loading url_file %r"%self.url_file)
223 with open(self.url_file) as f:
224 d = json.loads(f.read())
225 if d['exec_key']:
226 config.Session.key = asbytes(d['exec_key'])
227 d['url'] = disambiguate_url(d['url'], d['location'])
228 config.EngineFactory.url = d['url']
229 config.EngineFactory.location = d['location']
258 self.load_connector_file()
230 259 elif not url_specified:
231 260 self.log.critical("Fatal: url file never arrived: %s"%self.url_file)
232 261 self.exit(1)
@@ -253,7 +282,7 b' class IPEngineApp(BaseParallelApplication):'
253 282 except:
254 283 self.log.error("Couldn't start the Engine", exc_info=True)
255 284 self.exit(1)
256
285
257 286 def forward_logging(self):
258 287 if self.log_url:
259 288 self.log.info("Forwarding logging to %s"%self.log_url)
@@ -265,7 +294,7 b' class IPEngineApp(BaseParallelApplication):'
265 294 handler.setLevel(self.log_level)
266 295 self.log.addHandler(handler)
267 296 self._log_handler = handler
268 #
297
269 298 def init_mpi(self):
270 299 global mpi
271 300 self.mpi = MPI(config=self.config)
@@ -17,12 +17,16 b' from __future__ import print_function'
17 17
18 18 import sys
19 19 import time
20 from getpass import getpass
20 21
21 22 import zmq
22 23 from zmq.eventloop import ioloop, zmqstream
23 24
25 from IPython.external.ssh import tunnel
24 26 # internal
25 from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode, CBytes
27 from IPython.utils.traitlets import (
28 Instance, Dict, Int, Type, CFloat, Unicode, CBytes, Bool
29 )
26 30 # from IPython.utils.localinterfaces import LOCALHOST
27 31
28 32 from IPython.parallel.controller.heartmonitor import Heart
@@ -50,6 +54,12 b' class EngineFactory(RegistrationFactory):'
50 54 timeout=CFloat(2,config=True,
51 55 help="""The time (in seconds) to wait for the Controller to respond
52 56 to registration requests before giving up.""")
57 sshserver=Unicode(config=True,
58 help="""The SSH server to use for tunneling connections to the Controller.""")
59 sshkey=Unicode(config=True,
60 help="""The SSH keyfile to use when tunneling connections to the Controller.""")
61 paramiko=Bool(sys.platform == 'win32', config=True,
62 help="""Whether to use paramiko instead of openssh for tunnels.""")
53 63
54 64 # not configurable:
55 65 user_ns=Dict()
@@ -61,28 +71,70 b' class EngineFactory(RegistrationFactory):'
61 71 ident = Unicode()
62 72 def _ident_changed(self, name, old, new):
63 73 self.bident = asbytes(new)
74 using_ssh=Bool(False)
64 75
65 76
66 77 def __init__(self, **kwargs):
67 78 super(EngineFactory, self).__init__(**kwargs)
68 79 self.ident = self.session.session
69 ctx = self.context
80
81 def init_connector(self):
82 """construct connection function, which handles tunnels."""
83 self.using_ssh = bool(self.sshkey or self.sshserver)
70 84
71 reg = ctx.socket(zmq.XREQ)
72 reg.setsockopt(zmq.IDENTITY, self.bident)
73 reg.connect(self.url)
74 self.registrar = zmqstream.ZMQStream(reg, self.loop)
85 if self.sshkey and not self.sshserver:
86 # We are using ssh directly to the controller, tunneling localhost to localhost
87 self.sshserver = self.url.split('://')[1].split(':')[0]
88
89 if self.using_ssh:
90 if tunnel.try_passwordless_ssh(self.sshserver, self.sshkey, self.paramiko):
91 password=False
92 else:
93 password = getpass("SSH Password for %s: "%self.sshserver)
94 else:
95 password = False
96
97 def connect(s, url):
98 url = disambiguate_url(url, self.location)
99 if self.using_ssh:
100 self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver))
101 return tunnel.tunnel_connection(s, url, self.sshserver,
102 keyfile=self.sshkey, paramiko=self.paramiko,
103 password=password,
104 )
105 else:
106 return s.connect(url)
107
108 def maybe_tunnel(url):
109 """like connect, but don't complete the connection (for use by heartbeat)"""
110 url = disambiguate_url(url, self.location)
111 if self.using_ssh:
112 self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver))
113 url,tunnelobj = tunnel.open_tunnel(url, self.sshserver,
114 keyfile=self.sshkey, paramiko=self.paramiko,
115 password=password,
116 )
117 return url
118 return connect, maybe_tunnel
75 119
76 120 def register(self):
77 121 """send the registration_request"""
78 122
79 123 self.log.info("Registering with controller at %s"%self.url)
124 ctx = self.context
125 connect,maybe_tunnel = self.init_connector()
126 reg = ctx.socket(zmq.XREQ)
127 reg.setsockopt(zmq.IDENTITY, self.bident)
128 connect(reg, self.url)
129 self.registrar = zmqstream.ZMQStream(reg, self.loop)
130
131
80 132 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
81 self.registrar.on_recv(self.complete_registration)
133 self.registrar.on_recv(lambda msg: self.complete_registration(msg, connect, maybe_tunnel))
82 134 # print (self.session.key)
83 135 self.session.send(self.registrar, "registration_request",content=content)
84 136
85 def complete_registration(self, msg):
137 def complete_registration(self, msg, connect, maybe_tunnel):
86 138 # print msg
87 139 self._abort_dc.stop()
88 140 ctx = self.context
@@ -94,6 +146,14 b' class EngineFactory(RegistrationFactory):'
94 146 if msg.content.status == 'ok':
95 147 self.id = int(msg.content.id)
96 148
149 # launch heartbeat
150 hb_addrs = msg.content.heartbeat
151
152 # possibly forward hb ports with tunnels
153 hb_addrs = [ maybe_tunnel(addr) for addr in hb_addrs ]
154 heart = Heart(*map(str, hb_addrs), heart_id=identity)
155 heart.start()
156
97 157 # create Shell Streams (MUX, Task, etc.):
98 158 queue_addr = msg.content.mux
99 159 shell_addrs = [ str(queue_addr) ]
@@ -114,24 +174,20 b' class EngineFactory(RegistrationFactory):'
114 174 stream.setsockopt(zmq.IDENTITY, identity)
115 175 shell_streams = [stream]
116 176 for addr in shell_addrs:
117 stream.connect(disambiguate_url(addr, self.location))
177 connect(stream, addr)
118 178 # end single stream-socket
119 179
120 180 # control stream:
121 181 control_addr = str(msg.content.control)
122 182 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
123 183 control_stream.setsockopt(zmq.IDENTITY, identity)
124 control_stream.connect(disambiguate_url(control_addr, self.location))
184 connect(control_stream, control_addr)
125 185
126 186 # create iopub stream:
127 187 iopub_addr = msg.content.iopub
128 188 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
129 189 iopub_stream.setsockopt(zmq.IDENTITY, identity)
130 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
131
132 # launch heartbeat
133 hb_addrs = msg.content.heartbeat
134 # print (hb_addrs)
190 connect(iopub_stream, iopub_addr)
135 191
136 192 # # Redirect input streams and set a display hook.
137 193 if self.out_stream_factory:
@@ -147,9 +203,6 b' class EngineFactory(RegistrationFactory):'
147 203 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
148 204 loop=loop, user_ns = self.user_ns, log=self.log)
149 205 self.kernel.start()
150 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
151 heart = Heart(*map(str, hb_addrs), heart_id=identity)
152 heart.start()
153 206
154 207
155 208 else:
General Comments 0
You need to be logged in to leave comments. Login now