From 1160d9f8398ae7dc37273a3f4d3a1d5a3e4c05d6 2011-04-08 00:38:11 From: MinRK Date: 2011-04-08 00:38:11 Subject: [PATCH] added exec_key and fixed client.shutdown --- diff --git a/IPython/zmq/parallel/client.py b/IPython/zmq/parallel/client.py index 5eb2b77..bde8d5b 100644 --- a/IPython/zmq/parallel/client.py +++ b/IPython/zmq/parallel/client.py @@ -12,6 +12,7 @@ from __future__ import print_function +import os import time from pprint import pprint @@ -139,19 +140,30 @@ class Client(object): A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port' If keyfile or password is specified, and this is not, it will default to the ip given in addr. - keyfile : str; path to public key file + sshkey : str; path to public ssh key file This specifies a key to be used in ssh login, default None. Regular default ssh keys will be used without specifying this argument. password : str; Your ssh password to sshserver. Note that if this is left None, you will be prompted for it if passwordless key based login is unavailable. + #------- exec authentication args ------- + # If even localhost is untrusted, you can have some protection against + # unauthorized execution by using a key. Messages are still sent + # as cleartext, so if someone can snoop your loopback traffic this will + # not help anything. + + exec_key : str + an authentication key or file containing a key + default: None + + Attributes ---------- ids : set of int engine IDs requesting the ids attribute always synchronizes the registration state. To request ids without synchronization, - use semi-private _ids. + use semi-private _ids attributes. history : list of msg_ids a list of msg_ids, keeping track of all the execution @@ -175,7 +187,7 @@ class Client(object): barrier : wait on one or more msg_ids - execution methods: apply/apply_bound/apply_to/applu_bount + execution methods: apply/apply_bound/apply_to/apply_bound legacy: execute, run query methods: queue_status, get_result, purge @@ -202,26 +214,32 @@ class Client(object): debug = False def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False, - sshserver=None, keyfile=None, password=None, paramiko=None): + sshserver=None, sshkey=None, password=None, paramiko=None, + exec_key=None,): if context is None: context = zmq.Context() self.context = context self._addr = addr - self._ssh = bool(sshserver or keyfile or password) + self._ssh = bool(sshserver or sshkey or password) if self._ssh and sshserver is None: # default to the same sshserver = addr.split('://')[1].split(':')[0] if self._ssh and password is None: - if tunnel.try_passwordless_ssh(sshserver, keyfile, paramiko): + if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko): password=False else: password = getpass("SSH Password for %s: "%sshserver) - ssh_kwargs = dict(keyfile=keyfile, password=password, paramiko=paramiko) - + ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko) + + if os.path.isfile(exec_key): + arg = 'keyfile' + else: + arg = 'key' + key_arg = {arg:exec_key} if username is None: - self.session = ss.StreamSession() + self.session = ss.StreamSession(**key_arg) else: - self.session = ss.StreamSession(username) + self.session = ss.StreamSession(username, **key_arg) self._registration_socket = self.context.socket(zmq.XREQ) self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session) if self._ssh: @@ -536,11 +554,12 @@ class Client(object): @spinfirst @defaultblock - def kill(self, targets=None, block=None): + def shutdown(self, targets=None, restart=False, block=None): """Terminates one or more engine processes.""" targets = self._build_targets(targets)[0] for t in targets: - self.session.send(self._control_socket, 'kill_request', content={},ident=t) + self.session.send(self._control_socket, 'shutdown_request', + content={'restart':restart},ident=t) error = False if self.block: for i in range(len(targets)): diff --git a/IPython/zmq/parallel/controller.py b/IPython/zmq/parallel/controller.py index 7251009..dde3aca 100644 --- a/IPython/zmq/parallel/controller.py +++ b/IPython/zmq/parallel/controller.py @@ -15,6 +15,7 @@ and monitors traffic through the various queues. #----------------------------------------------------------------------------- from __future__ import print_function +import os from datetime import datetime import logging @@ -28,7 +29,7 @@ from IPython.zmq.entry_point import bind_port from streamsession import Message, wrap_exception from entry_point import (make_base_argument_parser, select_random_ports, split_ports, - connect_logger, parse_url, signal_children) + connect_logger, parse_url, signal_children, generate_exec_key) #----------------------------------------------------------------------------- # Code @@ -283,13 +284,12 @@ class Controller(object): logger.debug("registration::dispatch_register_request(%s)"%msg) idents,msg = self.session.feed_identities(msg) if not idents: - logger.error("Bad Queue Message: %s"%msg) + logger.error("Bad Queue Message: %s"%msg, exc_info=True) return try: msg = self.session.unpack_message(msg,content=True) - except Exception as e: - logger.error("registration::got bad registration message: %s"%msg) - raise e + except: + logger.error("registration::got bad registration message: %s"%msg, exc_info=True) return msg_type = msg['msg_type'] @@ -326,7 +326,7 @@ class Controller(object): msg = self.session.unpack_message(msg, content=True) except: content = wrap_exception() - logger.error("Bad Client Message: %s"%msg) + logger.error("Bad Client Message: %s"%msg, exc_info=True) self.session.send(self.clientele, "controller_error", ident=client_id, content=content) return @@ -340,7 +340,7 @@ class Controller(object): assert handler is not None, "Bad Message Type: %s"%msg_type except: content = wrap_exception() - logger.error("Bad Message Type: %s"%msg_type) + logger.error("Bad Message Type: %s"%msg_type, exc_info=True) self.session.send(self.clientele, "controller_error", ident=client_id, content=content) return @@ -390,7 +390,7 @@ class Controller(object): try: msg = self.session.unpack_message(msg, content=False) except: - logger.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg)) + logger.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg), exc_info=True) return eid = self.by_ident.get(queue_id, None) @@ -417,7 +417,7 @@ class Controller(object): msg = self.session.unpack_message(msg, content=False) except: logger.error("queue::engine %r sent invalid message to %r: %s"%( - queue_id,client_id, msg)) + queue_id,client_id, msg), exc_info=True) return eid = self.by_ident.get(queue_id, None) @@ -448,7 +448,7 @@ class Controller(object): msg = self.session.unpack_message(msg, content=False) except: logger.error("task::client %r sent invalid task message: %s"%( - client_id, msg)) + client_id, msg), exc_info=True) return header = msg['header'] @@ -871,7 +871,11 @@ def main(): n = ZMQStream(ctx.socket(zmq.PUB), loop) nport = bind_port(n, args.ip, args.notice) - thesession = session.StreamSession(username=args.ident or "controller") + ### Key File ### + if args.execkey and not os.path.isfile(args.execkey): + generate_exec_key(args.execkey) + + thesession = session.StreamSession(username=args.ident or "controller", keyfile=args.execkey) ### build and launch the queues ### diff --git a/IPython/zmq/parallel/engine.py b/IPython/zmq/parallel/engine.py index 4eead85..1284874 100644 --- a/IPython/zmq/parallel/engine.py +++ b/IPython/zmq/parallel/engine.py @@ -40,7 +40,7 @@ class Engine(object): heart=None kernel=None - def __init__(self, context, loop, session, registrar, client, ident=None, heart_id=None): + def __init__(self, context, loop, session, registrar, client=None, ident=None): self.context = context self.loop = loop self.session = session @@ -53,6 +53,7 @@ class Engine(object): content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident) self.registrar.on_recv(self.complete_registration) + # print (self.session.key) self.session.send(self.registrar, "registration_request",content=content) def complete_registration(self, msg): @@ -77,9 +78,8 @@ class Engine(object): sub.on_recv(lambda *a: None) port = sub.bind_to_random_port("tcp://%s"%LOCALHOST) iopub_addr = "tcp://%s:%i"%(LOCALHOST,12345) - make_kernel(self.ident, control_addr, shell_addrs, iopub_addr, hb_addrs, - client_addr=None, loop=self.loop, context=self.context) + client_addr=None, loop=self.loop, context=self.context, key=self.session.key) else: # logger.error("Registration Failed: %s"%msg) @@ -111,7 +111,8 @@ def main(): iface="%s://%s"%(args.transport,args.ip)+':%i' loop = ioloop.IOLoop.instance() - session = StreamSession() + session = StreamSession(keyfile=args.execkey) + # print (session.key) ctx = zmq.Context() # setup logging @@ -124,7 +125,7 @@ def main(): reg = ctx.socket(zmq.PAIR) reg.connect(reg_conn) reg = zmqstream.ZMQStream(reg, loop) - client = Client(reg_conn) + client = None e = Engine(ctx, loop, session, reg, client, args.ident) dc = ioloop.DelayedCallback(e.start, 100, loop) diff --git a/IPython/zmq/parallel/entry_point.py b/IPython/zmq/parallel/entry_point.py index 1042987..2036186 100644 --- a/IPython/zmq/parallel/entry_point.py +++ b/IPython/zmq/parallel/entry_point.py @@ -7,6 +7,7 @@ import logging import atexit import sys import os +import stat import socket from subprocess import Popen, PIPE from signal import signal, SIGINT, SIGABRT, SIGTERM @@ -33,7 +34,7 @@ def split_ports(s, n): return ports def select_random_ports(n): - """Selects and return n random ports that are open.""" + """Selects and return n random ports that are available.""" ports = [] for i in xrange(n): sock = socket.socket() @@ -46,6 +47,7 @@ def select_random_ports(n): return ports def parse_url(args): + """Ensure args.url contains full transport://interface:port""" if args.url: iface = args.url.split('://',1) if len(args) == 2: @@ -57,6 +59,7 @@ def parse_url(args): args.url = "%s://%s:%i"%(args.transport, args.ip,args.regport) def signal_children(children): + """Relay interupt/term signals to children, for more solid process cleanup.""" def terminate_children(sig, frame): for child in children: child.terminate() @@ -64,6 +67,17 @@ def signal_children(children): for sig in (SIGINT, SIGABRT, SIGTERM): signal(sig, terminate_children) +def generate_exec_key(keyfile): + import uuid + newkey = str(uuid.uuid4()) + with open(keyfile, 'w') as f: + # f.write('ipython-key ') + f.write(newkey) + # set user-only RW permissions (0600) + # this will have no effect on Windows + os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR) + + def make_base_argument_parser(): """ Creates an ArgumentParser for the generic arguments supported by all ipcluster entry points. @@ -86,6 +100,8 @@ def make_base_argument_parser(): help='set the message format method [default: json]') parser.add_argument('--url', type=str, help='set transport,ip,regport in one arg, e.g. tcp://127.0.0.1:10101') + parser.add_argument('--execkey', type=str, + help="File containing key for authenticating requests.") return parser diff --git a/IPython/zmq/parallel/ipcluster.py b/IPython/zmq/parallel/ipcluster.py index b458240..c413e8c 100644 --- a/IPython/zmq/parallel/ipcluster.py +++ b/IPython/zmq/parallel/ipcluster.py @@ -65,7 +65,7 @@ def main(): controller_args = strip_args([('--n','-n')]) engine_args = filter_args(['--url', '--regport', '--logport', '--ip', - '--transport','--loglevel','--packer'])+['--ident'] + '--transport','--loglevel','--packer', '--execkey'])+['--ident'] controller = launch_process('controller', controller_args) for i in range(10): diff --git a/IPython/zmq/parallel/streamkernel.py b/IPython/zmq/parallel/streamkernel.py index f3139f1..a361bc5 100755 --- a/IPython/zmq/parallel/streamkernel.py +++ b/IPython/zmq/parallel/streamkernel.py @@ -127,17 +127,21 @@ class Kernel(HasTraits): """kill ourself. This should really be handled in an external process""" self.abort_queues() content = dict(parent['content']) - msg = self.session.send(self.reply_socket, 'shutdown_reply', - content, parent, ident) - msg = self.session.send(self.pub_socket, 'shutdown_reply', - content, parent, ident) + msg = self.session.send(stream, 'shutdown_reply', + content=content, parent=parent, ident=ident) + # msg = self.session.send(self.pub_socket, 'shutdown_reply', + # content, parent, ident) # print >> sys.__stdout__, msg time.sleep(0.1) sys.exit(0) def dispatch_control(self, msg): idents,msg = self.session.feed_identities(msg, copy=False) - msg = self.session.unpack_message(msg, content=True, copy=False) + try: + msg = self.session.unpack_message(msg, content=True, copy=False) + except: + logger.error("Invalid Message", exc_info=True) + return header = msg['header'] msg_id = header['msg_id'] @@ -313,7 +317,12 @@ class Kernel(HasTraits): def dispatch_queue(self, stream, msg): self.control_stream.flush() idents,msg = self.session.feed_identities(msg, copy=False) - msg = self.session.unpack_message(msg, content=True, copy=False) + try: + msg = self.session.unpack_message(msg, content=True, copy=False) + except: + logger.error("Invalid Message", exc_info=True) + return + header = msg['header'] msg_id = header['msg_id'] @@ -367,14 +376,15 @@ class Kernel(HasTraits): # time.sleep(1e-3) def make_kernel(identity, control_addr, shell_addrs, iopub_addr, hb_addrs, - client_addr=None, loop=None, context=None): + client_addr=None, loop=None, context=None, key=None): # create loop, context, and session: if loop is None: loop = ioloop.IOLoop.instance() if context is None: context = zmq.Context() c = context - session = StreamSession() + session = StreamSession(key=key) + # print (session.key) print (control_addr, shell_addrs, iopub_addr, hb_addrs) # create Control Stream diff --git a/IPython/zmq/parallel/streamsession.py b/IPython/zmq/parallel/streamsession.py index d8356f4..078b3fa 100644 --- a/IPython/zmq/parallel/streamsession.py +++ b/IPython/zmq/parallel/streamsession.py @@ -277,7 +277,9 @@ def unpack_apply_message(bufs, g=None, copy=True): class StreamSession(object): """tweaked version of IPython.zmq.session.Session, for development in Parallel""" debug=False - def __init__(self, username=None, session=None, packer=None, unpacker=None): + key=None + + def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None): if username is None: username = os.environ.get('USER','username') self.username = username @@ -300,6 +302,14 @@ class StreamSession(object): raise TypeError("unpacker must be callable, not %s"%type(unpacker)) self.unpack = unpacker + if key is not None and keyfile is not None: + raise TypeError("Must specify key OR keyfile, not both") + if keyfile is not None: + with open(keyfile) as f: + self.key = f.read().strip() + else: + self.key = key + # print key, keyfile, self.key self.none = self.pack({}) def msg_header(self, msg_type): @@ -318,6 +328,14 @@ class StreamSession(object): msg['header'].update(sub) return msg + def check_key(self, msg_or_header): + """Check that a message's header has the right key""" + if self.key is None: + return True + header = extract_header(msg_or_header) + return header.get('key', None) == self.key + + def send(self, stream, msg_type, content=None, buffers=None, parent=None, subheader=None, ident=None): """Build and send a message via stream or socket. @@ -353,6 +371,8 @@ class StreamSession(object): elif ident is not None: to_send.append(ident) to_send.append(DELIM) + if self.key is not None: + to_send.append(self.key) to_send.append(self.pack(msg['header'])) to_send.append(self.pack(msg['parent_header'])) @@ -393,6 +413,8 @@ class StreamSession(object): if ident is not None: to_send.extend(ident) to_send.append(DELIM) + if self.key is not None: + to_send.append(self.key) to_send.extend(msg) stream.send_multipart(msg, flags, copy=copy) @@ -457,19 +479,24 @@ class StreamSession(object): or the non-copying Message object in each place (False) """ - if not len(msg) >= 3: - raise TypeError("malformed message, must have at least 3 elements") + ikey = int(self.key is not None) + minlen = 3 + ikey + if not len(msg) >= minlen: + raise TypeError("malformed message, must have at least %i elements"%minlen) message = {} if not copy: - for i in range(3): + for i in range(minlen): msg[i] = msg[i].bytes - message['header'] = self.unpack(msg[0]) + if ikey: + if not self.key == msg[0]: + raise KeyError("Invalid Session Key: %s"%msg[0]) + message['header'] = self.unpack(msg[ikey+0]) message['msg_type'] = message['header']['msg_type'] - message['parent_header'] = self.unpack(msg[1]) + message['parent_header'] = self.unpack(msg[ikey+1]) if content: - message['content'] = self.unpack(msg[2]) + message['content'] = self.unpack(msg[ikey+2]) else: - message['content'] = msg[2] + message['content'] = msg[ikey+2] # message['buffers'] = msg[3:] # else: @@ -481,7 +508,7 @@ class StreamSession(object): # else: # message['content'] = msg[2].bytes - message['buffers'] = msg[3:]# [ m.buffer for m in msg[3:] ] + message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ] return message