From a70d89c1e669f57235caa91079ff079617818c3e 2012-07-18 14:36:46 From: MinRK Date: 2012-07-18 14:36:46 Subject: [PATCH] simplify IPython.parallel connections Rolls back two-stage connection, putting complete connection info into the connection files. This makes it easier to use hand-crafted ssh tunnels, as all ports are read from the file, rather than from the reply to registration/connection requests. It is no longer possible to connect to the Controller without a connection file. Adding the serialization methods to the connection file also makes it harder for custom serialization to result in a mismatch in configuration between the various objects. --- diff --git a/IPython/parallel/apps/ipcontrollerapp.py b/IPython/parallel/apps/ipcontrollerapp.py index 1c36420..7262815 100755 --- a/IPython/parallel/apps/ipcontrollerapp.py +++ b/IPython/parallel/apps/ipcontrollerapp.py @@ -209,7 +209,7 @@ class IPControllerApp(BaseParallelApplication): def save_connection_dict(self, fname, cdict): """save a connection dict to json file.""" c = self.config - url = cdict['url'] + url = cdict['registration'] location = cdict['location'] if not location: try: @@ -314,15 +314,21 @@ class IPControllerApp(BaseParallelApplication): if self.write_connection_files: # save to new json config files f = self.factory - cdict = {'exec_key' : f.session.key.decode('ascii'), - 'ssh' : self.ssh_server, - 'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport), - 'location' : self.location - } + base = { + 'exec_key' : f.session.key.decode('ascii'), + 'location' : self.location, + 'pack' : f.session.packer, + 'unpack' : f.session.unpacker, + } + + cdict = {'ssh' : self.ssh_server} + cdict.update(f.client_info) + cdict.update(base) self.save_connection_dict(self.client_json_file, cdict) - edict = cdict - edict['url']="%s://%s:%s"%((f.client_transport, f.client_ip, f.regport)) - edict['ssh'] = self.engine_ssh_server + + edict = {'ssh' : self.engine_ssh_server} + edict.update(f.engine_info) + edict.update(base) self.save_connection_dict(self.engine_json_file, edict) def init_schedulers(self): @@ -379,7 +385,7 @@ class IPControllerApp(BaseParallelApplication): else: self.log.info("task::using Python %s Task scheduler"%scheme) - sargs = (hub.client_info['task'][1], hub.engine_info['task'], + sargs = (hub.client_info['task'], hub.engine_info['task'], monitor_url, disambiguate_url(hub.client_info['notification'])) kwargs = dict(logname='scheduler', loglevel=self.log_level, log_url = self.log_url, config=dict(self.config)) diff --git a/IPython/parallel/apps/ipengineapp.py b/IPython/parallel/apps/ipengineapp.py index d24f574..04261b4 100755 --- a/IPython/parallel/apps/ipengineapp.py +++ b/IPython/parallel/apps/ipengineapp.py @@ -211,24 +211,35 @@ class IPEngineApp(BaseParallelApplication): with open(self.url_file) as f: d = json.loads(f.read()) - if 'exec_key' in d: - config.Session.key = cast_bytes(d['exec_key']) - + # allow hand-override of location for disambiguation + # and ssh-server try: config.EngineFactory.location except AttributeError: config.EngineFactory.location = d['location'] - d['url'] = disambiguate_url(d['url'], config.EngineFactory.location) - try: - config.EngineFactory.url - except AttributeError: - config.EngineFactory.url = d['url'] - try: config.EngineFactory.sshserver except AttributeError: - config.EngineFactory.sshserver = d['ssh'] + config.EngineFactory.sshserver = d.get('ssh') + + location = config.EngineFactory.location + + for key in ('registration', 'hb_ping', 'hb_pong', 'mux', 'task', 'control'): + d[key] = disambiguate_url(d[key], location) + + # DO NOT allow override of basic URLs, serialization, or exec_key + # JSON file takes top priority there + config.Session.key = asbytes(d['exec_key']) + + config.EngineFactory.url = d['registration'] + + config.Session.packer = d['pack'] + config.Session.unpacker = d['unpack'] + + self.log.debug("Config changed:") + self.log.debug("%r", config) + self.connection_info = d def bind_kernel(self, **kwargs): """Promote engine to listening kernel, accessible to frontends.""" @@ -320,7 +331,9 @@ class IPEngineApp(BaseParallelApplication): # shell_class = import_item(self.master_config.Global.shell_class) # print self.config try: - self.engine = EngineFactory(config=config, log=self.log) + self.engine = EngineFactory(config=config, log=self.log, + connection_info=self.connection_info, + ) except: self.log.error("Couldn't start the Engine", exc_info=True) self.exit(1) diff --git a/IPython/parallel/client/client.py b/IPython/parallel/client/client.py index 233226c..5ca6059 100644 --- a/IPython/parallel/client/client.py +++ b/IPython/parallel/client/client.py @@ -217,7 +217,9 @@ class Client(HasTraits): Parameters ---------- - url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json + url_file : str/unicode; path to ipcontroller-client.json + This JSON file should contain all the information needed to connect to a cluster, + and is likely the only argument needed. Connection information for the Hub's registration. If a json connector file is given, then likely no further configuration is necessary. [Default: use profile] @@ -239,14 +241,6 @@ class Client(HasTraits): If specified, this will be relayed to the Session for configuration username : str set username for the session object - packer : str (import_string) or callable - Can be either the simple keyword 'json' or 'pickle', or an import_string to a - function to serialize messages. Must support same input as - JSON, and output must be bytes. - You can pass a callable directly as `pack` - unpacker : str (import_string) or callable - The inverse of packer. Only necessary if packer is specified as *not* one - of 'json' or 'pickle'. #-------------- ssh related args ---------------- # These are args for configuring the ssh tunnel to be used @@ -271,17 +265,6 @@ class Client(HasTraits): flag for whether to use paramiko instead of shell ssh for tunneling. [default: True on win32, False else] - ------- exec authentication args ------- - If even localhost is untrusted, you can have some protection against - unauthorized execution by signing messages with HMAC digests. - Messages are still sent as cleartext, so if someone can snoop your - loopback traffic this will not protect your privacy, but will prevent - unauthorized execution. - - exec_key : str - an authentication key or file containing a key - default: None - Attributes ---------- @@ -378,8 +361,8 @@ class Client(HasTraits): # don't raise on positional args return HasTraits.__new__(self, **kw) - def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None, - context=None, debug=False, exec_key=None, + def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None, + context=None, debug=False, sshserver=None, sshkey=None, password=None, paramiko=None, timeout=10, **extra_args ): @@ -391,38 +374,38 @@ class Client(HasTraits): context = zmq.Context.instance() self._context = context self._stop_spinning = Event() + + if 'url_or_file' in extra_args: + url_file = extra_args['url_or_file'] + warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning) + + if url_file and util.is_url(url_file): + raise ValueError("single urls cannot be specified, url-files must be used.") self._setup_profile_dir(self.profile, profile_dir, ipython_dir) + if self._cd is not None: - if url_or_file is None: - url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json') - if url_or_file is None: + if url_file is None: + url_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json') + if url_file is None: raise ValueError( "I can't find enough information to connect to a hub!" - " Please specify at least one of url_or_file or profile." + " Please specify at least one of url_file or profile." ) - - if not util.is_url(url_or_file): - # it's not a url, try for a file - if not os.path.exists(url_or_file): - if self._cd: - url_or_file = os.path.join(self._cd.security_dir, url_or_file) - if not os.path.exists(url_or_file): - raise IOError("Connection file not found: %r" % url_or_file) - with open(url_or_file) as f: - cfg = json.loads(f.read()) - else: - cfg = {'url':url_or_file} + + with open(url_file) as f: + cfg = json.load(f) + + self._task_scheme = cfg['task_scheme'] # sync defaults from args, json: if sshserver: cfg['ssh'] = sshserver - if exec_key: - cfg['exec_key'] = exec_key - exec_key = cfg['exec_key'] + location = cfg.setdefault('location', None) - cfg['url'] = util.disambiguate_url(cfg['url'], location) - url = cfg['url'] + for key in ('control', 'task', 'mux', 'notification', 'registration'): + cfg[key] = util.disambiguate_url(cfg[key], location) + url = cfg['registration'] proto,addr,port = util.split_url(url) if location is not None and addr == '127.0.0.1': # location specified, and connection is expected to be local @@ -457,12 +440,10 @@ class Client(HasTraits): ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko) # configure and construct the session - if exec_key is not None: - if os.path.isfile(exec_key): - extra_args['keyfile'] = exec_key - else: - exec_key = cast_bytes(exec_key) - extra_args['key'] = exec_key + extra_args['packer'] = cfg['pack'] + extra_args['unpacker'] = cfg['unpack'] + extra_args['key'] = cfg['exec_key'] + self.session = Session(**extra_args) self._query_socket = self._context.socket(zmq.DEALER) @@ -583,7 +564,7 @@ class Client(HasTraits): self._connected=True def connect_socket(s, url): - url = util.disambiguate_url(url, self._config['location']) + # url = util.disambiguate_url(url, self._config['location']) if self._ssh: return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs) else: @@ -600,30 +581,28 @@ class Client(HasTraits): idents,msg = self.session.recv(self._query_socket,mode=0) if self.debug: pprint(msg) - msg = Message(msg) - content = msg.content - self._config['registration'] = dict(content) - if content.status == 'ok': - ident = self.session.bsession - if content.mux: - self._mux_socket = self._context.socket(zmq.DEALER) - connect_socket(self._mux_socket, content.mux) - if content.task: - self._task_scheme, task_addr = content.task - self._task_socket = self._context.socket(zmq.DEALER) - connect_socket(self._task_socket, task_addr) - if content.notification: - self._notification_socket = self._context.socket(zmq.SUB) - connect_socket(self._notification_socket, content.notification) - self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'') - if content.control: - self._control_socket = self._context.socket(zmq.DEALER) - connect_socket(self._control_socket, content.control) - if content.iopub: - self._iopub_socket = self._context.socket(zmq.SUB) - self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'') - connect_socket(self._iopub_socket, content.iopub) - self._update_engines(dict(content.engines)) + content = msg['content'] + # self._config['registration'] = dict(content) + cfg = self._config + if content['status'] == 'ok': + self._mux_socket = self._context.socket(zmq.DEALER) + connect_socket(self._mux_socket, cfg['mux']) + + self._task_socket = self._context.socket(zmq.DEALER) + connect_socket(self._task_socket, cfg['task']) + + self._notification_socket = self._context.socket(zmq.SUB) + self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'') + connect_socket(self._notification_socket, cfg['notification']) + + self._control_socket = self._context.socket(zmq.DEALER) + connect_socket(self._control_socket, cfg['control']) + + self._iopub_socket = self._context.socket(zmq.SUB) + self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'') + connect_socket(self._iopub_socket, cfg['iopub']) + + self._update_engines(dict(content['engines'])) else: self._connected = False raise Exception("Failed to connect!") diff --git a/IPython/parallel/controller/hub.py b/IPython/parallel/controller/hub.py index 2967c7b..b5bacf7 100644 --- a/IPython/parallel/controller/hub.py +++ b/IPython/parallel/controller/hub.py @@ -239,30 +239,61 @@ class HubFactory(RegistrationFactory): ctx = self.context loop = self.loop + try: + scheme = self.config.TaskScheduler.scheme_name + except AttributeError: + from .scheduler import TaskScheduler + scheme = TaskScheduler.scheme_name.get_default_value() + + # build connection dicts + engine = self.engine_info = { + 'registration' : engine_iface % self.regport, + 'control' : engine_iface % self.control[1], + 'mux' : engine_iface % self.mux[1], + 'hb_ping' : engine_iface % self.hb[0], + 'hb_pong' : engine_iface % self.hb[1], + 'task' : engine_iface % self.task[1], + 'iopub' : engine_iface % self.iopub[1], + } + + client = self.client_info = { + 'registration' : client_iface % self.regport, + 'control' : client_iface % self.control[0], + 'mux' : client_iface % self.mux[0], + 'task' : client_iface % self.task[0], + 'task_scheme' : scheme, + 'iopub' : client_iface % self.iopub[0], + 'notification' : client_iface % self.notifier_port, + } + + self.log.debug("Hub engine addrs: %s", self.engine_info) + self.log.debug("Hub client addrs: %s", self.client_info) + # Registrar socket q = ZMQStream(ctx.socket(zmq.ROUTER), loop) - q.bind(client_iface % self.regport) + q.bind(client['registration']) self.log.info("Hub listening on %s for registration.", client_iface % self.regport) if self.client_ip != self.engine_ip: - q.bind(engine_iface % self.regport) + q.bind(engine['registration']) self.log.info("Hub listening on %s for registration.", engine_iface % self.regport) ### Engine connections ### # heartbeat hpub = ctx.socket(zmq.PUB) - hpub.bind(engine_iface % self.hb[0]) + hpub.bind(engine['hb_ping']) hrep = ctx.socket(zmq.ROUTER) - hrep.bind(engine_iface % self.hb[1]) + hrep.bind(engine['hb_pong']) self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log, pingstream=ZMQStream(hpub,loop), pongstream=ZMQStream(hrep,loop) ) ### Client connections ### + # Notifier socket n = ZMQStream(ctx.socket(zmq.PUB), loop) - n.bind(client_iface%self.notifier_port) + n.bind(client['notification']) ### build and launch the queues ### @@ -279,35 +310,10 @@ class HubFactory(RegistrationFactory): self.db = import_item(str(db_class))(session=self.session.session, config=self.config, log=self.log) time.sleep(.25) - try: - scheme = self.config.TaskScheduler.scheme_name - except AttributeError: - from .scheduler import TaskScheduler - scheme = TaskScheduler.scheme_name.get_default_value() - # build connection dicts - self.engine_info = { - 'control' : engine_iface%self.control[1], - 'mux': engine_iface%self.mux[1], - 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]), - 'task' : engine_iface%self.task[1], - 'iopub' : engine_iface%self.iopub[1], - # 'monitor' : engine_iface%self.mon_port, - } - - self.client_info = { - 'control' : client_iface%self.control[0], - 'mux': client_iface%self.mux[0], - 'task' : (scheme, client_iface%self.task[0]), - 'iopub' : client_iface%self.iopub[0], - 'notification': client_iface%self.notifier_port - } - self.log.debug("Hub engine addrs: %s", self.engine_info) - self.log.debug("Hub client addrs: %s", self.client_info) # resubmit stream r = ZMQStream(ctx.socket(zmq.DEALER), loop) - url = util.disambiguate_url(self.client_info['task'][-1]) - r.setsockopt(zmq.IDENTITY, self.session.bsession) + url = util.disambiguate_url(self.client_info['task']) r.connect(url) self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor, @@ -384,8 +390,8 @@ class Hub(SessionFactory): # validate connection dicts: for k,v in self.client_info.iteritems(): - if k == 'task': - util.validate_url_container(v[1]) + if k == 'task_scheme': + continue else: util.validate_url_container(v) # util.validate_url_container(self.client_info) @@ -865,7 +871,6 @@ class Hub(SessionFactory): """Reply with connection addresses for clients.""" self.log.info("client::client %r connected", client_id) content = dict(status='ok') - content.update(self.client_info) jsonable = {} for k,v in self.keytable.iteritems(): if v not in self.dead_engines: @@ -891,7 +896,6 @@ class Hub(SessionFactory): self.log.debug("registration::register_engine(%i, %r, %r, %r)", eid, queue, reg, heart) content = dict(id=eid,status='ok') - content.update(self.engine_info) # check if requesting available IDs: if queue in self.by_ident: try: diff --git a/IPython/parallel/engine/engine.py b/IPython/parallel/engine/engine.py index 136af66..05c5092 100644 --- a/IPython/parallel/engine/engine.py +++ b/IPython/parallel/engine/engine.py @@ -50,7 +50,7 @@ class EngineFactory(RegistrationFactory): help="""The location (an IP address) of the controller. This is used for disambiguating URLs, to determine whether loopback should be used to connect or the public address.""") - timeout=CFloat(2,config=True, + timeout=CFloat(5, config=True, help="""The time (in seconds) to wait for the Controller to respond to registration requests before giving up.""") sshserver=Unicode(config=True, @@ -61,10 +61,11 @@ class EngineFactory(RegistrationFactory): help="""Whether to use paramiko instead of openssh for tunnels.""") # not configurable: - user_ns=Dict() - id=Integer(allow_none=True) - registrar=Instance('zmq.eventloop.zmqstream.ZMQStream') - kernel=Instance(Kernel) + connection_info = Dict() + user_ns = Dict() + id = Integer(allow_none=True) + registrar = Instance('zmq.eventloop.zmqstream.ZMQStream') + kernel = Instance(Kernel) bident = CBytes() ident = Unicode() @@ -96,7 +97,7 @@ class EngineFactory(RegistrationFactory): def connect(s, url): url = disambiguate_url(url, self.location) if self.using_ssh: - self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver)) + self.log.debug("Tunneling connection to %s via %s", url, self.sshserver) return tunnel.tunnel_connection(s, url, self.sshserver, keyfile=self.sshkey, paramiko=self.paramiko, password=password, @@ -108,12 +109,12 @@ class EngineFactory(RegistrationFactory): """like connect, but don't complete the connection (for use by heartbeat)""" url = disambiguate_url(url, self.location) if self.using_ssh: - self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver)) + self.log.debug("Tunneling connection to %s via %s", url, self.sshserver) url,tunnelobj = tunnel.open_tunnel(url, self.sshserver, keyfile=self.sshkey, paramiko=self.paramiko, password=password, ) - return url + return str(url) return connect, maybe_tunnel def register(self): @@ -131,7 +132,7 @@ class EngineFactory(RegistrationFactory): content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident) self.registrar.on_recv(lambda msg: self.complete_registration(msg, connect, maybe_tunnel)) # print (self.session.key) - self.session.send(self.registrar, "registration_request",content=content) + self.session.send(self.registrar, "registration_request", content=content) def complete_registration(self, msg, connect, maybe_tunnel): # print msg @@ -140,50 +141,39 @@ class EngineFactory(RegistrationFactory): loop = self.loop identity = self.bident idents,msg = self.session.feed_identities(msg) - msg = Message(self.session.unserialize(msg)) - - if msg.content.status == 'ok': - self.id = int(msg.content.id) + msg = self.session.unserialize(msg) + content = msg['content'] + info = self.connection_info + + if content['status'] == 'ok': + self.id = int(content['id']) # launch heartbeat - hb_addrs = msg.content.heartbeat - # possibly forward hb ports with tunnels - hb_addrs = [ maybe_tunnel(addr) for addr in hb_addrs ] - heart = Heart(*map(str, hb_addrs), heart_id=identity) + hb_ping = maybe_tunnel(info['hb_ping']) + hb_pong = maybe_tunnel(info['hb_pong']) + + heart = Heart(hb_ping, hb_pong, heart_id=identity) heart.start() - # create Shell Streams (MUX, Task, etc.): - queue_addr = msg.content.mux - shell_addrs = [ str(queue_addr) ] - task_addr = msg.content.task - if task_addr: - shell_addrs.append(str(task_addr)) - - # Uncomment this to go back to two-socket model - # shell_streams = [] - # for addr in shell_addrs: - # stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop) - # stream.setsockopt(zmq.IDENTITY, identity) - # stream.connect(disambiguate_url(addr, self.location)) - # shell_streams.append(stream) - - # Now use only one shell stream for mux and tasks + # create Shell Connections (MUX, Task, etc.): + shell_addrs = map(str, [info['mux'], info['task']]) + + # Use only one shell stream for mux and tasks stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop) stream.setsockopt(zmq.IDENTITY, identity) shell_streams = [stream] for addr in shell_addrs: connect(stream, addr) - # end single stream-socket # control stream: - control_addr = str(msg.content.control) + control_addr = str(info['control']) control_stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop) control_stream.setsockopt(zmq.IDENTITY, identity) connect(control_stream, control_addr) # create iopub stream: - iopub_addr = msg.content.iopub + iopub_addr = info['iopub'] iopub_socket = ctx.socket(zmq.PUB) iopub_socket.setsockopt(zmq.IDENTITY, identity) connect(iopub_socket, iopub_addr) diff --git a/IPython/zmq/session.py b/IPython/zmq/session.py index 04d70d2..4cf1ab2 100644 --- a/IPython/zmq/session.py +++ b/IPython/zmq/session.py @@ -257,9 +257,11 @@ class Session(Configurable): if new.lower() == 'json': self.pack = json_packer self.unpack = json_unpacker + self.unpacker = new elif new.lower() == 'pickle': self.pack = pickle_packer self.unpack = pickle_unpacker + self.unpacker = new else: self.pack = import_item(str(new)) @@ -270,9 +272,11 @@ class Session(Configurable): if new.lower() == 'json': self.pack = json_packer self.unpack = json_unpacker + self.packer = new elif new.lower() == 'pickle': self.pack = pickle_packer self.unpack = pickle_unpacker + self.packer = new else: self.unpack = import_item(str(new))