From 79b3544ce9b43e47f89f4c473ada631b096a4c28 2011-07-03 21:01:00 From: MinRK Date: 2011-07-03 21:01:00 Subject: [PATCH] enforce ascii identities in parallel code This should be temporary also use uuid.bytes to get a binary uuid in the HeartMonitor --- diff --git a/IPython/parallel/controller/heartmonitor.py b/IPython/parallel/controller/heartmonitor.py index 47658c8..23a1a92 100644 --- a/IPython/parallel/controller/heartmonitor.py +++ b/IPython/parallel/controller/heartmonitor.py @@ -46,7 +46,7 @@ class Heart(object): if in_type == zmq.SUB: self.device.setsockopt_in(zmq.SUBSCRIBE, b"") if heart_id is None: - heart_id = ensure_bytes(uuid.uuid4()) + heart_id = uuid.uuid4().bytes self.device.setsockopt_out(zmq.IDENTITY, heart_id) self.id = heart_id diff --git a/IPython/parallel/controller/hub.py b/IPython/parallel/controller/hub.py index f601ce5..8549e89 100755 --- a/IPython/parallel/controller/hub.py +++ b/IPython/parallel/controller/hub.py @@ -563,8 +563,8 @@ class Hub(SessionFactory): record = init_record(msg) msg_id = record['msg_id'] # Unicode in records - record['engine_uuid'] = queue_id.decode('utf8', 'replace') - record['client_uuid'] = client_id.decode('utf8', 'replace') + record['engine_uuid'] = queue_id.decode('ascii') + record['client_uuid'] = client_id.decode('ascii') record['queue'] = 'mux' try: @@ -834,7 +834,7 @@ class Hub(SessionFactory): jsonable = {} for k,v in self.keytable.iteritems(): if v not in self.dead_engines: - jsonable[str(k)] = v.decode() + jsonable[str(k)] = v.decode('ascii') content['engines'] = jsonable self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id) diff --git a/IPython/parallel/controller/scheduler.py b/IPython/parallel/controller/scheduler.py index 55ee059..ba86250 100644 --- a/IPython/parallel/controller/scheduler.py +++ b/IPython/parallel/controller/scheduler.py @@ -503,7 +503,7 @@ class TaskScheduler(SessionFactory): self.add_job(idx) self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout) # notify Hub - content = dict(msg_id=msg_id, engine_id=target.decode()) + content = dict(msg_id=msg_id, engine_id=target.decode('ascii')) self.session.send(self.mon_stream, 'task_destination', content=content, ident=[b'tracktask',self.ident]) diff --git a/IPython/parallel/engine/engine.py b/IPython/parallel/engine/engine.py index 43e911b..d4b30d4 100755 --- a/IPython/parallel/engine/engine.py +++ b/IPython/parallel/engine/engine.py @@ -23,7 +23,7 @@ import zmq from zmq.eventloop import ioloop, zmqstream # internal -from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode +from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode, CBytes # from IPython.utils.localinterfaces import LOCALHOST from IPython.parallel.controller.heartmonitor import Heart @@ -58,6 +58,11 @@ class EngineFactory(RegistrationFactory): registrar=Instance('zmq.eventloop.zmqstream.ZMQStream') kernel=Instance(Kernel) + bident = CBytes() + ident = Unicode() + def _ident_changed(self, name, old, new): + self.bident = ensure_bytes(new) + def __init__(self, **kwargs): super(EngineFactory, self).__init__(**kwargs) @@ -65,7 +70,7 @@ class EngineFactory(RegistrationFactory): ctx = self.context reg = ctx.socket(zmq.XREQ) - reg.setsockopt(zmq.IDENTITY, ensure_bytes(self.ident)) + reg.setsockopt(zmq.IDENTITY, self.bident) reg.connect(self.url) self.registrar = zmqstream.ZMQStream(reg, self.loop) @@ -83,8 +88,7 @@ class EngineFactory(RegistrationFactory): self._abort_dc.stop() ctx = self.context loop = self.loop - identity = ensure_bytes(self.ident) - + identity = self.bident idents,msg = self.session.feed_identities(msg) msg = Message(self.session.unpack_message(msg)) @@ -139,7 +143,7 @@ class EngineFactory(RegistrationFactory): if self.display_hook_factory: sys.displayhook = self.display_hook_factory(self.session, iopub_stream) sys.displayhook.topic = 'engine.%i.pyout'%self.id - + self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session, control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream, loop=loop, user_ns = self.user_ns, log=self.log) diff --git a/IPython/parallel/util.py b/IPython/parallel/util.py index fe09943..486b894 100644 --- a/IPython/parallel/util.py +++ b/IPython/parallel/util.py @@ -102,9 +102,9 @@ class ReverseDict(dict): #----------------------------------------------------------------------------- def ensure_bytes(s): - """ensure that an object is bytes""" + """ensure that an object is ascii bytes""" if isinstance(s, unicode): - s = s.encode(sys.getdefaultencoding(), 'replace') + s = s.encode('ascii') return s def validate_url(url):