##// END OF EJS Templates
added exec_key and fixed client.shutdown
MinRK -
Show More
@@ -12,6 +12,7 b''
12 12
13 13 from __future__ import print_function
14 14
15 import os
15 16 import time
16 17 from pprint import pprint
17 18
@@ -139,19 +140,30 b' class Client(object):'
139 140 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
140 141 If keyfile or password is specified, and this is not, it will default to
141 142 the ip given in addr.
142 keyfile : str; path to public key file
143 sshkey : str; path to public ssh key file
143 144 This specifies a key to be used in ssh login, default None.
144 145 Regular default ssh keys will be used without specifying this argument.
145 146 password : str;
146 147 Your ssh password to sshserver. Note that if this is left None,
147 148 you will be prompted for it if passwordless key based login is unavailable.
148 149
150 #------- exec authentication args -------
151 # If even localhost is untrusted, you can have some protection against
152 # unauthorized execution by using a key. Messages are still sent
153 # as cleartext, so if someone can snoop your loopback traffic this will
154 # not help anything.
155
156 exec_key : str
157 an authentication key or file containing a key
158 default: None
159
160
149 161 Attributes
150 162 ----------
151 163 ids : set of int engine IDs
152 164 requesting the ids attribute always synchronizes
153 165 the registration state. To request ids without synchronization,
154 use semi-private _ids.
166 use semi-private _ids attributes.
155 167
156 168 history : list of msg_ids
157 169 a list of msg_ids, keeping track of all the execution
@@ -175,7 +187,7 b' class Client(object):'
175 187
176 188 barrier : wait on one or more msg_ids
177 189
178 execution methods: apply/apply_bound/apply_to/applu_bount
190 execution methods: apply/apply_bound/apply_to/apply_bound
179 191 legacy: execute, run
180 192
181 193 query methods: queue_status, get_result, purge
@@ -202,26 +214,32 b' class Client(object):'
202 214 debug = False
203 215
204 216 def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False,
205 sshserver=None, keyfile=None, password=None, paramiko=None):
217 sshserver=None, sshkey=None, password=None, paramiko=None,
218 exec_key=None,):
206 219 if context is None:
207 220 context = zmq.Context()
208 221 self.context = context
209 222 self._addr = addr
210 self._ssh = bool(sshserver or keyfile or password)
223 self._ssh = bool(sshserver or sshkey or password)
211 224 if self._ssh and sshserver is None:
212 225 # default to the same
213 226 sshserver = addr.split('://')[1].split(':')[0]
214 227 if self._ssh and password is None:
215 if tunnel.try_passwordless_ssh(sshserver, keyfile, paramiko):
228 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
216 229 password=False
217 230 else:
218 231 password = getpass("SSH Password for %s: "%sshserver)
219 ssh_kwargs = dict(keyfile=keyfile, password=password, paramiko=paramiko)
232 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
220 233
234 if os.path.isfile(exec_key):
235 arg = 'keyfile'
236 else:
237 arg = 'key'
238 key_arg = {arg:exec_key}
221 239 if username is None:
222 self.session = ss.StreamSession()
240 self.session = ss.StreamSession(**key_arg)
223 241 else:
224 self.session = ss.StreamSession(username)
242 self.session = ss.StreamSession(username, **key_arg)
225 243 self._registration_socket = self.context.socket(zmq.XREQ)
226 244 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
227 245 if self._ssh:
@@ -536,11 +554,12 b' class Client(object):'
536 554
537 555 @spinfirst
538 556 @defaultblock
539 def kill(self, targets=None, block=None):
557 def shutdown(self, targets=None, restart=False, block=None):
540 558 """Terminates one or more engine processes."""
541 559 targets = self._build_targets(targets)[0]
542 560 for t in targets:
543 self.session.send(self._control_socket, 'kill_request', content={},ident=t)
561 self.session.send(self._control_socket, 'shutdown_request',
562 content={'restart':restart},ident=t)
544 563 error = False
545 564 if self.block:
546 565 for i in range(len(targets)):
@@ -15,6 +15,7 b' and monitors traffic through the various queues.'
15 15 #-----------------------------------------------------------------------------
16 16 from __future__ import print_function
17 17
18 import os
18 19 from datetime import datetime
19 20 import logging
20 21
@@ -28,7 +29,7 b' from IPython.zmq.entry_point import bind_port'
28 29
29 30 from streamsession import Message, wrap_exception
30 31 from entry_point import (make_base_argument_parser, select_random_ports, split_ports,
31 connect_logger, parse_url, signal_children)
32 connect_logger, parse_url, signal_children, generate_exec_key)
32 33
33 34 #-----------------------------------------------------------------------------
34 35 # Code
@@ -283,13 +284,12 b' class Controller(object):'
283 284 logger.debug("registration::dispatch_register_request(%s)"%msg)
284 285 idents,msg = self.session.feed_identities(msg)
285 286 if not idents:
286 logger.error("Bad Queue Message: %s"%msg)
287 logger.error("Bad Queue Message: %s"%msg, exc_info=True)
287 288 return
288 289 try:
289 290 msg = self.session.unpack_message(msg,content=True)
290 except Exception as e:
291 logger.error("registration::got bad registration message: %s"%msg)
292 raise e
291 except:
292 logger.error("registration::got bad registration message: %s"%msg, exc_info=True)
293 293 return
294 294
295 295 msg_type = msg['msg_type']
@@ -326,7 +326,7 b' class Controller(object):'
326 326 msg = self.session.unpack_message(msg, content=True)
327 327 except:
328 328 content = wrap_exception()
329 logger.error("Bad Client Message: %s"%msg)
329 logger.error("Bad Client Message: %s"%msg, exc_info=True)
330 330 self.session.send(self.clientele, "controller_error", ident=client_id,
331 331 content=content)
332 332 return
@@ -340,7 +340,7 b' class Controller(object):'
340 340 assert handler is not None, "Bad Message Type: %s"%msg_type
341 341 except:
342 342 content = wrap_exception()
343 logger.error("Bad Message Type: %s"%msg_type)
343 logger.error("Bad Message Type: %s"%msg_type, exc_info=True)
344 344 self.session.send(self.clientele, "controller_error", ident=client_id,
345 345 content=content)
346 346 return
@@ -390,7 +390,7 b' class Controller(object):'
390 390 try:
391 391 msg = self.session.unpack_message(msg, content=False)
392 392 except:
393 logger.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg))
393 logger.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg), exc_info=True)
394 394 return
395 395
396 396 eid = self.by_ident.get(queue_id, None)
@@ -417,7 +417,7 b' class Controller(object):'
417 417 msg = self.session.unpack_message(msg, content=False)
418 418 except:
419 419 logger.error("queue::engine %r sent invalid message to %r: %s"%(
420 queue_id,client_id, msg))
420 queue_id,client_id, msg), exc_info=True)
421 421 return
422 422
423 423 eid = self.by_ident.get(queue_id, None)
@@ -448,7 +448,7 b' class Controller(object):'
448 448 msg = self.session.unpack_message(msg, content=False)
449 449 except:
450 450 logger.error("task::client %r sent invalid task message: %s"%(
451 client_id, msg))
451 client_id, msg), exc_info=True)
452 452 return
453 453
454 454 header = msg['header']
@@ -871,7 +871,11 b' def main():'
871 871 n = ZMQStream(ctx.socket(zmq.PUB), loop)
872 872 nport = bind_port(n, args.ip, args.notice)
873 873
874 thesession = session.StreamSession(username=args.ident or "controller")
874 ### Key File ###
875 if args.execkey and not os.path.isfile(args.execkey):
876 generate_exec_key(args.execkey)
877
878 thesession = session.StreamSession(username=args.ident or "controller", keyfile=args.execkey)
875 879
876 880 ### build and launch the queues ###
877 881
@@ -40,7 +40,7 b' class Engine(object):'
40 40 heart=None
41 41 kernel=None
42 42
43 def __init__(self, context, loop, session, registrar, client, ident=None, heart_id=None):
43 def __init__(self, context, loop, session, registrar, client=None, ident=None):
44 44 self.context = context
45 45 self.loop = loop
46 46 self.session = session
@@ -53,6 +53,7 b' class Engine(object):'
53 53
54 54 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
55 55 self.registrar.on_recv(self.complete_registration)
56 # print (self.session.key)
56 57 self.session.send(self.registrar, "registration_request",content=content)
57 58
58 59 def complete_registration(self, msg):
@@ -77,9 +78,8 b' class Engine(object):'
77 78 sub.on_recv(lambda *a: None)
78 79 port = sub.bind_to_random_port("tcp://%s"%LOCALHOST)
79 80 iopub_addr = "tcp://%s:%i"%(LOCALHOST,12345)
80
81 81 make_kernel(self.ident, control_addr, shell_addrs, iopub_addr, hb_addrs,
82 client_addr=None, loop=self.loop, context=self.context)
82 client_addr=None, loop=self.loop, context=self.context, key=self.session.key)
83 83
84 84 else:
85 85 # logger.error("Registration Failed: %s"%msg)
@@ -111,7 +111,8 b' def main():'
111 111 iface="%s://%s"%(args.transport,args.ip)+':%i'
112 112
113 113 loop = ioloop.IOLoop.instance()
114 session = StreamSession()
114 session = StreamSession(keyfile=args.execkey)
115 # print (session.key)
115 116 ctx = zmq.Context()
116 117
117 118 # setup logging
@@ -124,7 +125,7 b' def main():'
124 125 reg = ctx.socket(zmq.PAIR)
125 126 reg.connect(reg_conn)
126 127 reg = zmqstream.ZMQStream(reg, loop)
127 client = Client(reg_conn)
128 client = None
128 129
129 130 e = Engine(ctx, loop, session, reg, client, args.ident)
130 131 dc = ioloop.DelayedCallback(e.start, 100, loop)
@@ -7,6 +7,7 b' import logging'
7 7 import atexit
8 8 import sys
9 9 import os
10 import stat
10 11 import socket
11 12 from subprocess import Popen, PIPE
12 13 from signal import signal, SIGINT, SIGABRT, SIGTERM
@@ -33,7 +34,7 b' def split_ports(s, n):'
33 34 return ports
34 35
35 36 def select_random_ports(n):
36 """Selects and return n random ports that are open."""
37 """Selects and return n random ports that are available."""
37 38 ports = []
38 39 for i in xrange(n):
39 40 sock = socket.socket()
@@ -46,6 +47,7 b' def select_random_ports(n):'
46 47 return ports
47 48
48 49 def parse_url(args):
50 """Ensure args.url contains full transport://interface:port"""
49 51 if args.url:
50 52 iface = args.url.split('://',1)
51 53 if len(args) == 2:
@@ -57,6 +59,7 b' def parse_url(args):'
57 59 args.url = "%s://%s:%i"%(args.transport, args.ip,args.regport)
58 60
59 61 def signal_children(children):
62 """Relay interupt/term signals to children, for more solid process cleanup."""
60 63 def terminate_children(sig, frame):
61 64 for child in children:
62 65 child.terminate()
@@ -64,6 +67,17 b' def signal_children(children):'
64 67 for sig in (SIGINT, SIGABRT, SIGTERM):
65 68 signal(sig, terminate_children)
66 69
70 def generate_exec_key(keyfile):
71 import uuid
72 newkey = str(uuid.uuid4())
73 with open(keyfile, 'w') as f:
74 # f.write('ipython-key ')
75 f.write(newkey)
76 # set user-only RW permissions (0600)
77 # this will have no effect on Windows
78 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
79
80
67 81 def make_base_argument_parser():
68 82 """ Creates an ArgumentParser for the generic arguments supported by all
69 83 ipcluster entry points.
@@ -86,6 +100,8 b' def make_base_argument_parser():'
86 100 help='set the message format method [default: json]')
87 101 parser.add_argument('--url', type=str,
88 102 help='set transport,ip,regport in one arg, e.g. tcp://127.0.0.1:10101')
103 parser.add_argument('--execkey', type=str,
104 help="File containing key for authenticating requests.")
89 105
90 106 return parser
91 107
@@ -65,7 +65,7 b' def main():'
65 65
66 66 controller_args = strip_args([('--n','-n')])
67 67 engine_args = filter_args(['--url', '--regport', '--logport', '--ip',
68 '--transport','--loglevel','--packer'])+['--ident']
68 '--transport','--loglevel','--packer', '--execkey'])+['--ident']
69 69
70 70 controller = launch_process('controller', controller_args)
71 71 for i in range(10):
@@ -127,17 +127,21 b' class Kernel(HasTraits):'
127 127 """kill ourself. This should really be handled in an external process"""
128 128 self.abort_queues()
129 129 content = dict(parent['content'])
130 msg = self.session.send(self.reply_socket, 'shutdown_reply',
131 content, parent, ident)
132 msg = self.session.send(self.pub_socket, 'shutdown_reply',
133 content, parent, ident)
130 msg = self.session.send(stream, 'shutdown_reply',
131 content=content, parent=parent, ident=ident)
132 # msg = self.session.send(self.pub_socket, 'shutdown_reply',
133 # content, parent, ident)
134 134 # print >> sys.__stdout__, msg
135 135 time.sleep(0.1)
136 136 sys.exit(0)
137 137
138 138 def dispatch_control(self, msg):
139 139 idents,msg = self.session.feed_identities(msg, copy=False)
140 try:
140 141 msg = self.session.unpack_message(msg, content=True, copy=False)
142 except:
143 logger.error("Invalid Message", exc_info=True)
144 return
141 145
142 146 header = msg['header']
143 147 msg_id = header['msg_id']
@@ -313,7 +317,12 b' class Kernel(HasTraits):'
313 317 def dispatch_queue(self, stream, msg):
314 318 self.control_stream.flush()
315 319 idents,msg = self.session.feed_identities(msg, copy=False)
320 try:
316 321 msg = self.session.unpack_message(msg, content=True, copy=False)
322 except:
323 logger.error("Invalid Message", exc_info=True)
324 return
325
317 326
318 327 header = msg['header']
319 328 msg_id = header['msg_id']
@@ -367,14 +376,15 b' class Kernel(HasTraits):'
367 376 # time.sleep(1e-3)
368 377
369 378 def make_kernel(identity, control_addr, shell_addrs, iopub_addr, hb_addrs,
370 client_addr=None, loop=None, context=None):
379 client_addr=None, loop=None, context=None, key=None):
371 380 # create loop, context, and session:
372 381 if loop is None:
373 382 loop = ioloop.IOLoop.instance()
374 383 if context is None:
375 384 context = zmq.Context()
376 385 c = context
377 session = StreamSession()
386 session = StreamSession(key=key)
387 # print (session.key)
378 388 print (control_addr, shell_addrs, iopub_addr, hb_addrs)
379 389
380 390 # create Control Stream
@@ -277,7 +277,9 b' def unpack_apply_message(bufs, g=None, copy=True):'
277 277 class StreamSession(object):
278 278 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
279 279 debug=False
280 def __init__(self, username=None, session=None, packer=None, unpacker=None):
280 key=None
281
282 def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
281 283 if username is None:
282 284 username = os.environ.get('USER','username')
283 285 self.username = username
@@ -300,6 +302,14 b' class StreamSession(object):'
300 302 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
301 303 self.unpack = unpacker
302 304
305 if key is not None and keyfile is not None:
306 raise TypeError("Must specify key OR keyfile, not both")
307 if keyfile is not None:
308 with open(keyfile) as f:
309 self.key = f.read().strip()
310 else:
311 self.key = key
312 # print key, keyfile, self.key
303 313 self.none = self.pack({})
304 314
305 315 def msg_header(self, msg_type):
@@ -318,6 +328,14 b' class StreamSession(object):'
318 328 msg['header'].update(sub)
319 329 return msg
320 330
331 def check_key(self, msg_or_header):
332 """Check that a message's header has the right key"""
333 if self.key is None:
334 return True
335 header = extract_header(msg_or_header)
336 return header.get('key', None) == self.key
337
338
321 339 def send(self, stream, msg_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
322 340 """Build and send a message via stream or socket.
323 341
@@ -353,6 +371,8 b' class StreamSession(object):'
353 371 elif ident is not None:
354 372 to_send.append(ident)
355 373 to_send.append(DELIM)
374 if self.key is not None:
375 to_send.append(self.key)
356 376 to_send.append(self.pack(msg['header']))
357 377 to_send.append(self.pack(msg['parent_header']))
358 378
@@ -393,6 +413,8 b' class StreamSession(object):'
393 413 if ident is not None:
394 414 to_send.extend(ident)
395 415 to_send.append(DELIM)
416 if self.key is not None:
417 to_send.append(self.key)
396 418 to_send.extend(msg)
397 419 stream.send_multipart(msg, flags, copy=copy)
398 420
@@ -457,19 +479,24 b' class StreamSession(object):'
457 479 or the non-copying Message object in each place (False)
458 480
459 481 """
460 if not len(msg) >= 3:
461 raise TypeError("malformed message, must have at least 3 elements")
482 ikey = int(self.key is not None)
483 minlen = 3 + ikey
484 if not len(msg) >= minlen:
485 raise TypeError("malformed message, must have at least %i elements"%minlen)
462 486 message = {}
463 487 if not copy:
464 for i in range(3):
488 for i in range(minlen):
465 489 msg[i] = msg[i].bytes
466 message['header'] = self.unpack(msg[0])
490 if ikey:
491 if not self.key == msg[0]:
492 raise KeyError("Invalid Session Key: %s"%msg[0])
493 message['header'] = self.unpack(msg[ikey+0])
467 494 message['msg_type'] = message['header']['msg_type']
468 message['parent_header'] = self.unpack(msg[1])
495 message['parent_header'] = self.unpack(msg[ikey+1])
469 496 if content:
470 message['content'] = self.unpack(msg[2])
497 message['content'] = self.unpack(msg[ikey+2])
471 498 else:
472 message['content'] = msg[2]
499 message['content'] = msg[ikey+2]
473 500
474 501 # message['buffers'] = msg[3:]
475 502 # else:
@@ -481,7 +508,7 b' class StreamSession(object):'
481 508 # else:
482 509 # message['content'] = msg[2].bytes
483 510
484 message['buffers'] = msg[3:]# [ m.buffer for m in msg[3:] ]
511 message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
485 512 return message
486 513
487 514
General Comments 0
You need to be logged in to leave comments. Login now