##// END OF EJS Templates
merge IPython.parallel.streamsession into IPython.zmq.session...
MinRK -
Show More
@@ -41,7 +41,7 b' from IPython.utils.importstring import import_item'
41 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict
41 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict
42
42
43 # from IPython.parallel.controller.controller import ControllerFactory
43 # from IPython.parallel.controller.controller import ControllerFactory
44 from IPython.parallel.streamsession import StreamSession
44 from IPython.zmq.session import Session
45 from IPython.parallel.controller.heartmonitor import HeartMonitor
45 from IPython.parallel.controller.heartmonitor import HeartMonitor
46 from IPython.parallel.controller.hub import HubFactory
46 from IPython.parallel.controller.hub import HubFactory
47 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
47 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
@@ -109,7 +109,7 b' class IPControllerApp(BaseParallelApplication):'
109 name = u'ipcontroller'
109 name = u'ipcontroller'
110 description = _description
110 description = _description
111 config_file_name = Unicode(default_config_file_name)
111 config_file_name = Unicode(default_config_file_name)
112 classes = [ProfileDir, StreamSession, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
112 classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
113
113
114 # change default to True
114 # change default to True
115 auto_create = Bool(True, config=True,
115 auto_create = Bool(True, config=True,
@@ -155,9 +155,9 b' class IPControllerApp(BaseParallelApplication):'
155 import_statements = 'IPControllerApp.import_statements',
155 import_statements = 'IPControllerApp.import_statements',
156 location = 'IPControllerApp.location',
156 location = 'IPControllerApp.location',
157
157
158 ident = 'StreamSession.session',
158 ident = 'Session.session',
159 user = 'StreamSession.username',
159 user = 'Session.username',
160 exec_key = 'StreamSession.keyfile',
160 exec_key = 'Session.keyfile',
161
161
162 url = 'HubFactory.url',
162 url = 'HubFactory.url',
163 ip = 'HubFactory.ip',
163 ip = 'HubFactory.ip',
@@ -201,7 +201,7 b' class IPControllerApp(BaseParallelApplication):'
201 # load from engine config
201 # load from engine config
202 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-engine.json')) as f:
202 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-engine.json')) as f:
203 cfg = json.loads(f.read())
203 cfg = json.loads(f.read())
204 key = c.StreamSession.key = cfg['exec_key']
204 key = c.Session.key = cfg['exec_key']
205 xport,addr = cfg['url'].split('://')
205 xport,addr = cfg['url'].split('://')
206 c.HubFactory.engine_transport = xport
206 c.HubFactory.engine_transport = xport
207 ip,ports = addr.split(':')
207 ip,ports = addr.split(':')
@@ -239,9 +239,9 b' class IPControllerApp(BaseParallelApplication):'
239 # with open(keyfile, 'w') as f:
239 # with open(keyfile, 'w') as f:
240 # f.write(key)
240 # f.write(key)
241 # os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
241 # os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
242 c.StreamSession.key = key
242 c.Session.key = key
243 else:
243 else:
244 key = c.StreamSession.key = ''
244 key = c.Session.key = ''
245
245
246 try:
246 try:
247 self.factory = HubFactory(config=c, log=self.log)
247 self.factory = HubFactory(config=c, log=self.log)
@@ -27,7 +27,7 b' from IPython.parallel.apps.baseapp import BaseParallelApplication'
27 from IPython.zmq.log import EnginePUBHandler
27 from IPython.zmq.log import EnginePUBHandler
28
28
29 from IPython.config.configurable import Configurable
29 from IPython.config.configurable import Configurable
30 from IPython.parallel.streamsession import StreamSession
30 from IPython.zmq.session import Session
31 from IPython.parallel.engine.engine import EngineFactory
31 from IPython.parallel.engine.engine import EngineFactory
32 from IPython.parallel.engine.streamkernel import Kernel
32 from IPython.parallel.engine.streamkernel import Kernel
33 from IPython.parallel.util import disambiguate_url
33 from IPython.parallel.util import disambiguate_url
@@ -100,7 +100,7 b' class IPEngineApp(BaseParallelApplication):'
100 app_name = Unicode(u'ipengine')
100 app_name = Unicode(u'ipengine')
101 description = Unicode(_description)
101 description = Unicode(_description)
102 config_file_name = Unicode(default_config_file_name)
102 config_file_name = Unicode(default_config_file_name)
103 classes = List([ProfileDir, StreamSession, EngineFactory, Kernel, MPI])
103 classes = List([ProfileDir, Session, EngineFactory, Kernel, MPI])
104
104
105 startup_script = Unicode(u'', config=True,
105 startup_script = Unicode(u'', config=True,
106 help='specify a script to be run at startup')
106 help='specify a script to be run at startup')
@@ -124,9 +124,9 b' class IPEngineApp(BaseParallelApplication):'
124 c = 'IPEngineApp.startup_command',
124 c = 'IPEngineApp.startup_command',
125 s = 'IPEngineApp.startup_script',
125 s = 'IPEngineApp.startup_script',
126
126
127 ident = 'StreamSession.session',
127 ident = 'Session.session',
128 user = 'StreamSession.username',
128 user = 'Session.username',
129 exec_key = 'StreamSession.keyfile',
129 exec_key = 'Session.keyfile',
130
130
131 url = 'EngineFactory.url',
131 url = 'EngineFactory.url',
132 ip = 'EngineFactory.ip',
132 ip = 'EngineFactory.ip',
@@ -190,7 +190,7 b' class IPEngineApp(BaseParallelApplication):'
190 if isinstance(v, unicode):
190 if isinstance(v, unicode):
191 d[k] = v.encode()
191 d[k] = v.encode()
192 if d['exec_key']:
192 if d['exec_key']:
193 config.StreamSession.key = d['exec_key']
193 config.Session.key = d['exec_key']
194 d['url'] = disambiguate_url(d['url'], d['location'])
194 d['url'] = disambiguate_url(d['url'], d['location'])
195 config.EngineFactory.url = d['url']
195 config.EngineFactory.url = d['url']
196 config.EngineFactory.location = d['location']
196 config.EngineFactory.location = d['location']
@@ -23,6 +23,7 b' pjoin = os.path.join'
23 import zmq
23 import zmq
24 # from zmq.eventloop import ioloop, zmqstream
24 # from zmq.eventloop import ioloop, zmqstream
25
25
26 from IPython.utils.jsonutil import extract_dates
26 from IPython.utils.path import get_ipython_dir
27 from IPython.utils.path import get_ipython_dir
27 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
28 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
28 Dict, List, Bool, Set)
29 Dict, List, Bool, Set)
@@ -30,9 +31,10 b' from IPython.external.decorator import decorator'
30 from IPython.external.ssh import tunnel
31 from IPython.external.ssh import tunnel
31
32
32 from IPython.parallel import error
33 from IPython.parallel import error
33 from IPython.parallel import streamsession as ss
34 from IPython.parallel import util
34 from IPython.parallel import util
35
35
36 from IPython.zmq.session import Session, Message
37
36 from .asyncresult import AsyncResult, AsyncHubResult
38 from .asyncresult import AsyncResult, AsyncHubResult
37 from IPython.core.newapplication import ProfileDir, ProfileDirError
39 from IPython.core.newapplication import ProfileDir, ProfileDirError
38 from .view import DirectView, LoadBalancedView
40 from .view import DirectView, LoadBalancedView
@@ -294,9 +296,9 b' class Client(HasTraits):'
294 arg = 'key'
296 arg = 'key'
295 key_arg = {arg:exec_key}
297 key_arg = {arg:exec_key}
296 if username is None:
298 if username is None:
297 self.session = ss.StreamSession(**key_arg)
299 self.session = Session(**key_arg)
298 else:
300 else:
299 self.session = ss.StreamSession(username=username, **key_arg)
301 self.session = Session(username=username, **key_arg)
300 self._query_socket = self._context.socket(zmq.XREQ)
302 self._query_socket = self._context.socket(zmq.XREQ)
301 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
303 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
302 if self._ssh:
304 if self._ssh:
@@ -416,7 +418,7 b' class Client(HasTraits):'
416 idents,msg = self.session.recv(self._query_socket,mode=0)
418 idents,msg = self.session.recv(self._query_socket,mode=0)
417 if self.debug:
419 if self.debug:
418 pprint(msg)
420 pprint(msg)
419 msg = ss.Message(msg)
421 msg = Message(msg)
420 content = msg.content
422 content = msg.content
421 self._config['registration'] = dict(content)
423 self._config['registration'] = dict(content)
422 if content.status == 'ok':
424 if content.status == 'ok':
@@ -478,11 +480,11 b' class Client(HasTraits):'
478 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
480 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
479
481
480 if 'date' in parent:
482 if 'date' in parent:
481 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
483 md['submitted'] = parent['date']
482 if 'started' in header:
484 if 'started' in header:
483 md['started'] = datetime.strptime(header['started'], util.ISO8601)
485 md['started'] = header['started']
484 if 'date' in header:
486 if 'date' in header:
485 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
487 md['completed'] = header['date']
486 return md
488 return md
487
489
488 def _register_engine(self, msg):
490 def _register_engine(self, msg):
@@ -528,7 +530,7 b' class Client(HasTraits):'
528 header = {}
530 header = {}
529 parent['msg_id'] = msg_id
531 parent['msg_id'] = msg_id
530 header['engine'] = uuid
532 header['engine'] = uuid
531 header['date'] = datetime.now().strftime(util.ISO8601)
533 header['date'] = datetime.now()
532 msg = dict(parent_header=parent, header=header, content=content)
534 msg = dict(parent_header=parent, header=header, content=content)
533 self._handle_apply_reply(msg)
535 self._handle_apply_reply(msg)
534
536
@@ -551,7 +553,7 b' class Client(HasTraits):'
551
553
552 def _handle_apply_reply(self, msg):
554 def _handle_apply_reply(self, msg):
553 """Save the reply to an apply_request into our results."""
555 """Save the reply to an apply_request into our results."""
554 parent = msg['parent_header']
556 parent = extract_dates(msg['parent_header'])
555 msg_id = parent['msg_id']
557 msg_id = parent['msg_id']
556 if msg_id not in self.outstanding:
558 if msg_id not in self.outstanding:
557 if msg_id in self.history:
559 if msg_id in self.history:
@@ -563,7 +565,7 b' class Client(HasTraits):'
563 else:
565 else:
564 self.outstanding.remove(msg_id)
566 self.outstanding.remove(msg_id)
565 content = msg['content']
567 content = msg['content']
566 header = msg['header']
568 header = extract_dates(msg['header'])
567
569
568 # construct metadata:
570 # construct metadata:
569 md = self.metadata[msg_id]
571 md = self.metadata[msg_id]
@@ -589,33 +591,31 b' class Client(HasTraits):'
589 def _flush_notifications(self):
591 def _flush_notifications(self):
590 """Flush notifications of engine registrations waiting
592 """Flush notifications of engine registrations waiting
591 in ZMQ queue."""
593 in ZMQ queue."""
592 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
594 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
593 while msg is not None:
595 while msg is not None:
594 if self.debug:
596 if self.debug:
595 pprint(msg)
597 pprint(msg)
596 msg = msg[-1]
597 msg_type = msg['msg_type']
598 msg_type = msg['msg_type']
598 handler = self._notification_handlers.get(msg_type, None)
599 handler = self._notification_handlers.get(msg_type, None)
599 if handler is None:
600 if handler is None:
600 raise Exception("Unhandled message type: %s"%msg.msg_type)
601 raise Exception("Unhandled message type: %s"%msg.msg_type)
601 else:
602 else:
602 handler(msg)
603 handler(msg)
603 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
604 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
604
605
605 def _flush_results(self, sock):
606 def _flush_results(self, sock):
606 """Flush task or queue results waiting in ZMQ queue."""
607 """Flush task or queue results waiting in ZMQ queue."""
607 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
608 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
608 while msg is not None:
609 while msg is not None:
609 if self.debug:
610 if self.debug:
610 pprint(msg)
611 pprint(msg)
611 msg = msg[-1]
612 msg_type = msg['msg_type']
612 msg_type = msg['msg_type']
613 handler = self._queue_handlers.get(msg_type, None)
613 handler = self._queue_handlers.get(msg_type, None)
614 if handler is None:
614 if handler is None:
615 raise Exception("Unhandled message type: %s"%msg.msg_type)
615 raise Exception("Unhandled message type: %s"%msg.msg_type)
616 else:
616 else:
617 handler(msg)
617 handler(msg)
618 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
618 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
619
619
620 def _flush_control(self, sock):
620 def _flush_control(self, sock):
621 """Flush replies from the control channel waiting
621 """Flush replies from the control channel waiting
@@ -624,12 +624,12 b' class Client(HasTraits):'
624 Currently: ignore them."""
624 Currently: ignore them."""
625 if self._ignored_control_replies <= 0:
625 if self._ignored_control_replies <= 0:
626 return
626 return
627 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
627 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
628 while msg is not None:
628 while msg is not None:
629 self._ignored_control_replies -= 1
629 self._ignored_control_replies -= 1
630 if self.debug:
630 if self.debug:
631 pprint(msg)
631 pprint(msg)
632 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
632 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
633
633
634 def _flush_ignored_control(self):
634 def _flush_ignored_control(self):
635 """flush ignored control replies"""
635 """flush ignored control replies"""
@@ -638,19 +638,18 b' class Client(HasTraits):'
638 self._ignored_control_replies -= 1
638 self._ignored_control_replies -= 1
639
639
640 def _flush_ignored_hub_replies(self):
640 def _flush_ignored_hub_replies(self):
641 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
641 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
642 while msg is not None:
642 while msg is not None:
643 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
643 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
644
644
645 def _flush_iopub(self, sock):
645 def _flush_iopub(self, sock):
646 """Flush replies from the iopub channel waiting
646 """Flush replies from the iopub channel waiting
647 in the ZMQ queue.
647 in the ZMQ queue.
648 """
648 """
649 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
649 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
650 while msg is not None:
650 while msg is not None:
651 if self.debug:
651 if self.debug:
652 pprint(msg)
652 pprint(msg)
653 msg = msg[-1]
654 parent = msg['parent_header']
653 parent = msg['parent_header']
655 msg_id = parent['msg_id']
654 msg_id = parent['msg_id']
656 content = msg['content']
655 content = msg['content']
@@ -674,7 +673,7 b' class Client(HasTraits):'
674 # reduntant?
673 # reduntant?
675 self.metadata[msg_id] = md
674 self.metadata[msg_id] = md
676
675
677 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
676 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
678
677
679 #--------------------------------------------------------------------------
678 #--------------------------------------------------------------------------
680 # len, getitem
679 # len, getitem
@@ -1172,6 +1171,7 b' class Client(HasTraits):'
1172 failures = []
1171 failures = []
1173 # load cached results into result:
1172 # load cached results into result:
1174 content.update(local_results)
1173 content.update(local_results)
1174 content = extract_dates(content)
1175 # update cache with results:
1175 # update cache with results:
1176 for msg_id in sorted(theids):
1176 for msg_id in sorted(theids):
1177 if msg_id in content['completed']:
1177 if msg_id in content['completed']:
@@ -1338,6 +1338,8 b' class Client(HasTraits):'
1338 has_bufs = buffer_lens is not None
1338 has_bufs = buffer_lens is not None
1339 has_rbufs = result_buffer_lens is not None
1339 has_rbufs = result_buffer_lens is not None
1340 for i,rec in enumerate(records):
1340 for i,rec in enumerate(records):
1341 # unpack timestamps
1342 rec = extract_dates(rec)
1341 # relink buffers
1343 # relink buffers
1342 if has_bufs:
1344 if has_bufs:
1343 blen = buffer_lens[i]
1345 blen = buffer_lens[i]
@@ -1345,11 +1347,6 b' class Client(HasTraits):'
1345 if has_rbufs:
1347 if has_rbufs:
1346 blen = result_buffer_lens[i]
1348 blen = result_buffer_lens[i]
1347 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1349 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1348 # turn timestamps back into times
1349 for key in 'submitted started completed resubmitted'.split():
1350 maybedate = rec.get(key, None)
1351 if maybedate and util.ISO8601_RE.match(maybedate):
1352 rec[key] = datetime.strptime(maybedate, util.ISO8601)
1353
1350
1354 return records
1351 return records
1355
1352
@@ -28,6 +28,7 b' from IPython.utils.importstring import import_item'
28 from IPython.utils.traitlets import (
28 from IPython.utils.traitlets import (
29 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CStr
29 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CStr
30 )
30 )
31 from IPython.utils.jsonutil import ISO8601, extract_dates
31
32
32 from IPython.parallel import error, util
33 from IPython.parallel import error, util
33 from IPython.parallel.factory import RegistrationFactory, LoggingFactory
34 from IPython.parallel.factory import RegistrationFactory, LoggingFactory
@@ -71,13 +72,13 b' def empty_record():'
71
72
72 def init_record(msg):
73 def init_record(msg):
73 """Initialize a TaskRecord based on a request."""
74 """Initialize a TaskRecord based on a request."""
74 header = msg['header']
75 header = extract_dates(msg['header'])
75 return {
76 return {
76 'msg_id' : header['msg_id'],
77 'msg_id' : header['msg_id'],
77 'header' : header,
78 'header' : header,
78 'content': msg['content'],
79 'content': msg['content'],
79 'buffers': msg['buffers'],
80 'buffers': msg['buffers'],
80 'submitted': datetime.strptime(header['date'], util.ISO8601),
81 'submitted': header['date'],
81 'client_uuid' : None,
82 'client_uuid' : None,
82 'engine_uuid' : None,
83 'engine_uuid' : None,
83 'started': None,
84 'started': None,
@@ -295,7 +296,7 b' class Hub(LoggingFactory):'
295 Parameters
296 Parameters
296 ==========
297 ==========
297 loop: zmq IOLoop instance
298 loop: zmq IOLoop instance
298 session: StreamSession object
299 session: Session object
299 <removed> context: zmq context for creating new connections (?)
300 <removed> context: zmq context for creating new connections (?)
300 queue: ZMQStream for monitoring the command queue (SUB)
301 queue: ZMQStream for monitoring the command queue (SUB)
301 query: ZMQStream for engine registration and client queries requests (XREP)
302 query: ZMQStream for engine registration and client queries requests (XREP)
@@ -610,11 +611,9 b' class Hub(LoggingFactory):'
610 self.log.warn("queue:: unknown msg finished %r"%msg_id)
611 self.log.warn("queue:: unknown msg finished %r"%msg_id)
611 return
612 return
612 # update record anyway, because the unregistration could have been premature
613 # update record anyway, because the unregistration could have been premature
613 rheader = msg['header']
614 rheader = extract_dates(msg['header'])
614 completed = datetime.strptime(rheader['date'], util.ISO8601)
615 completed = rheader['date']
615 started = rheader.get('started', None)
616 started = rheader.get('started', None)
616 if started is not None:
617 started = datetime.strptime(started, util.ISO8601)
618 result = {
617 result = {
619 'result_header' : rheader,
618 'result_header' : rheader,
620 'result_content': msg['content'],
619 'result_content': msg['content'],
@@ -695,7 +694,7 b' class Hub(LoggingFactory):'
695 if msg_id in self.unassigned:
694 if msg_id in self.unassigned:
696 self.unassigned.remove(msg_id)
695 self.unassigned.remove(msg_id)
697
696
698 header = msg['header']
697 header = extract_dates(msg['header'])
699 engine_uuid = header.get('engine', None)
698 engine_uuid = header.get('engine', None)
700 eid = self.by_ident.get(engine_uuid, None)
699 eid = self.by_ident.get(engine_uuid, None)
701
700
@@ -706,10 +705,8 b' class Hub(LoggingFactory):'
706 self.completed[eid].append(msg_id)
705 self.completed[eid].append(msg_id)
707 if msg_id in self.tasks[eid]:
706 if msg_id in self.tasks[eid]:
708 self.tasks[eid].remove(msg_id)
707 self.tasks[eid].remove(msg_id)
709 completed = datetime.strptime(header['date'], util.ISO8601)
708 completed = header['date']
710 started = header.get('started', None)
709 started = header.get('started', None)
711 if started is not None:
712 started = datetime.strptime(started, util.ISO8601)
713 result = {
710 result = {
714 'result_header' : header,
711 'result_header' : header,
715 'result_content': msg['content'],
712 'result_content': msg['content'],
@@ -1141,7 +1138,7 b' class Hub(LoggingFactory):'
1141 reply = error.wrap_exception()
1138 reply = error.wrap_exception()
1142 else:
1139 else:
1143 # send the messages
1140 # send the messages
1144 now_s = now.strftime(util.ISO8601)
1141 now_s = now.strftime(ISO8601)
1145 for rec in records:
1142 for rec in records:
1146 header = rec['header']
1143 header = rec['header']
1147 # include resubmitted in header to prevent digest collision
1144 # include resubmitted in header to prevent digest collision
@@ -17,7 +17,7 b' from zmq.eventloop import ioloop'
17
17
18 from IPython.utils.traitlets import Unicode, Instance, List
18 from IPython.utils.traitlets import Unicode, Instance, List
19 from .dictdb import BaseDB
19 from .dictdb import BaseDB
20 from IPython.parallel.util import ISO8601
20 from IPython.utils.jsonutil import date_default, extract_dates
21
21
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23 # SQLite operators, adapters, and converters
23 # SQLite operators, adapters, and converters
@@ -52,13 +52,13 b' def _convert_datetime(ds):'
52 return datetime.strptime(ds, ISO8601)
52 return datetime.strptime(ds, ISO8601)
53
53
54 def _adapt_dict(d):
54 def _adapt_dict(d):
55 return json.dumps(d)
55 return json.dumps(d, default=date_default)
56
56
57 def _convert_dict(ds):
57 def _convert_dict(ds):
58 if ds is None:
58 if ds is None:
59 return ds
59 return ds
60 else:
60 else:
61 return json.loads(ds)
61 return extract_dates(json.loads(ds))
62
62
63 def _adapt_bufs(bufs):
63 def _adapt_bufs(bufs):
64 # this is *horrible*
64 # this is *horrible*
@@ -24,9 +24,10 b' from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode'
24
24
25 from IPython.parallel.controller.heartmonitor import Heart
25 from IPython.parallel.controller.heartmonitor import Heart
26 from IPython.parallel.factory import RegistrationFactory
26 from IPython.parallel.factory import RegistrationFactory
27 from IPython.parallel.streamsession import Message
28 from IPython.parallel.util import disambiguate_url
27 from IPython.parallel.util import disambiguate_url
29
28
29 from IPython.zmq.session import Message
30
30 from .streamkernel import Kernel
31 from .streamkernel import Kernel
31
32
32 class EngineFactory(RegistrationFactory):
33 class EngineFactory(RegistrationFactory):
@@ -8,7 +8,7 b''
8
8
9 from zmq.eventloop import ioloop
9 from zmq.eventloop import ioloop
10
10
11 from IPython.parallel.streamsession import StreamSession
11 from IPython.zmq.session import Session
12
12
13 class KernelStarter(object):
13 class KernelStarter(object):
14 """Object for resetting/killing the Kernel."""
14 """Object for resetting/killing the Kernel."""
@@ -213,7 +213,7 b' def make_starter(up_addr, down_addr, *args, **kwargs):'
213 """entry point function for launching a kernelstarter in a subprocess"""
213 """entry point function for launching a kernelstarter in a subprocess"""
214 loop = ioloop.IOLoop.instance()
214 loop = ioloop.IOLoop.instance()
215 ctx = zmq.Context()
215 ctx = zmq.Context()
216 session = StreamSession()
216 session = Session()
217 upstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop)
217 upstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop)
218 upstream.connect(up_addr)
218 upstream.connect(up_addr)
219 downstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop)
219 downstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop)
@@ -28,12 +28,13 b' import zmq'
28 from zmq.eventloop import ioloop, zmqstream
28 from zmq.eventloop import ioloop, zmqstream
29
29
30 # Local imports.
30 # Local imports.
31 from IPython.utils.jsonutil import ISO8601
31 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Unicode
32 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Unicode
32 from IPython.zmq.completer import KernelCompleter
33 from IPython.zmq.completer import KernelCompleter
33
34
34 from IPython.parallel.error import wrap_exception
35 from IPython.parallel.error import wrap_exception
35 from IPython.parallel.factory import SessionFactory
36 from IPython.parallel.factory import SessionFactory
36 from IPython.parallel.util import serialize_object, unpack_apply_message, ISO8601
37 from IPython.parallel.util import serialize_object, unpack_apply_message
37
38
38 def printer(*args):
39 def printer(*args):
39 pprint(args, stream=sys.__stdout__)
40 pprint(args, stream=sys.__stdout__)
@@ -42,7 +43,7 b' def printer(*args):'
42 class _Passer(zmqstream.ZMQStream):
43 class _Passer(zmqstream.ZMQStream):
43 """Empty class that implements `send()` that does nothing.
44 """Empty class that implements `send()` that does nothing.
44
45
45 Subclass ZMQStream for StreamSession typechecking
46 Subclass ZMQStream for Session typechecking
46
47
47 """
48 """
48 def __init__(self, *args, **kwargs):
49 def __init__(self, *args, **kwargs):
@@ -21,8 +21,8 b' from zmq.eventloop.ioloop import IOLoop'
21 from IPython.config.configurable import Configurable
21 from IPython.config.configurable import Configurable
22 from IPython.utils.traitlets import Int, Instance, Unicode
22 from IPython.utils.traitlets import Int, Instance, Unicode
23
23
24 import IPython.parallel.streamsession as ss
25 from IPython.parallel.util import select_random_ports
24 from IPython.parallel.util import select_random_ports
25 from IPython.zmq.session import Session
26
26
27 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
28 # Classes
28 # Classes
@@ -43,7 +43,7 b' class SessionFactory(LoggingFactory):'
43 def _context_default(self):
43 def _context_default(self):
44 return zmq.Context.instance()
44 return zmq.Context.instance()
45
45
46 session = Instance('IPython.parallel.streamsession.StreamSession')
46 session = Instance('IPython.zmq.session.Session')
47 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
47 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
48 def _loop_default(self):
48 def _loop_default(self):
49 return IOLoop.instance()
49 return IOLoop.instance()
@@ -53,7 +53,7 b' class SessionFactory(LoggingFactory):'
53 super(SessionFactory, self).__init__(**kwargs)
53 super(SessionFactory, self).__init__(**kwargs)
54
54
55 # construct the session
55 # construct the session
56 self.session = ss.StreamSession(**kwargs)
56 self.session = Session(**kwargs)
57
57
58
58
59 class RegistrationFactory(SessionFactory):
59 class RegistrationFactory(SessionFactory):
@@ -20,18 +20,21 b' from unittest import TestCase'
20
20
21 from nose import SkipTest
21 from nose import SkipTest
22
22
23 from IPython.parallel import error, streamsession as ss
23 from IPython.parallel import error
24 from IPython.parallel.controller.dictdb import DictDB
24 from IPython.parallel.controller.dictdb import DictDB
25 from IPython.parallel.controller.sqlitedb import SQLiteDB
25 from IPython.parallel.controller.sqlitedb import SQLiteDB
26 from IPython.parallel.controller.hub import init_record, empty_record
26 from IPython.parallel.controller.hub import init_record, empty_record
27
27
28 from IPython.zmq.session import Session
29
30
28 #-------------------------------------------------------------------------------
31 #-------------------------------------------------------------------------------
29 # TestCases
32 # TestCases
30 #-------------------------------------------------------------------------------
33 #-------------------------------------------------------------------------------
31
34
32 class TestDictBackend(TestCase):
35 class TestDictBackend(TestCase):
33 def setUp(self):
36 def setUp(self):
34 self.session = ss.StreamSession()
37 self.session = Session()
35 self.db = self.create_db()
38 self.db = self.create_db()
36 self.load_records(16)
39 self.load_records(16)
37
40
@@ -17,7 +17,6 b' import re'
17 import stat
17 import stat
18 import socket
18 import socket
19 import sys
19 import sys
20 from datetime import datetime
21 from signal import signal, SIGINT, SIGABRT, SIGTERM
20 from signal import signal, SIGINT, SIGABRT, SIGTERM
22 try:
21 try:
23 from signal import SIGKILL
22 from signal import SIGKILL
@@ -40,10 +39,6 b' from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence'
40 from IPython.utils.newserialized import serialize, unserialize
39 from IPython.utils.newserialized import serialize, unserialize
41 from IPython.zmq.log import EnginePUBHandler
40 from IPython.zmq.log import EnginePUBHandler
42
41
43 # globals
44 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
45 ISO8601_RE=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")
46
47 #-----------------------------------------------------------------------------
42 #-----------------------------------------------------------------------------
48 # Classes
43 # Classes
49 #-----------------------------------------------------------------------------
44 #-----------------------------------------------------------------------------
@@ -101,18 +96,6 b' class ReverseDict(dict):'
101 # Functions
96 # Functions
102 #-----------------------------------------------------------------------------
97 #-----------------------------------------------------------------------------
103
98
104 def extract_dates(obj):
105 """extract ISO8601 dates from unpacked JSON"""
106 if isinstance(obj, dict):
107 for k,v in obj.iteritems():
108 obj[k] = extract_dates(v)
109 elif isinstance(obj, list):
110 obj = [ extract_dates(o) for o in obj ]
111 elif isinstance(obj, basestring):
112 if ISO8601_RE.match(obj):
113 obj = datetime.strptime(obj, ISO8601)
114 return obj
115
116 def validate_url(url):
99 def validate_url(url):
117 """validate a url for zeromq"""
100 """validate a url for zeromq"""
118 if not isinstance(url, basestring):
101 if not isinstance(url, basestring):
@@ -11,12 +11,43 b''
11 # Imports
11 # Imports
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # stdlib
13 # stdlib
14 import re
14 import types
15 import types
16 from datetime import datetime
17
18 #-----------------------------------------------------------------------------
19 # Globals and constants
20 #-----------------------------------------------------------------------------
21
22 # timestamp formats
23 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
24 ISO8601_PAT=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")
15
25
16 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
17 # Classes and functions
27 # Classes and functions
18 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
19
29
30 def extract_dates(obj):
31 """extract ISO8601 dates from unpacked JSON"""
32 if isinstance(obj, dict):
33 for k,v in obj.iteritems():
34 obj[k] = extract_dates(v)
35 elif isinstance(obj, list):
36 obj = [ extract_dates(o) for o in obj ]
37 elif isinstance(obj, basestring):
38 if ISO8601_PAT.match(obj):
39 obj = datetime.strptime(obj, ISO8601)
40 return obj
41
42 def date_default(obj):
43 """default function for packing datetime objects"""
44 if isinstance(obj, datetime):
45 return obj.strftime(ISO8601)
46 else:
47 raise TypeError("%r is not JSON serializable"%obj)
48
49
50
20 def json_clean(obj):
51 def json_clean(obj):
21 """Clean an object to ensure it's safe to encode in JSON.
52 """Clean an object to ensure it's safe to encode in JSON.
22
53
@@ -1,10 +1,77 b''
1 #!/usr/bin/env python
2 """edited session.py to work with streams, and move msg_type to the header
3 """
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2010-2011 The IPython Development Team
6 #
7 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
9 #-----------------------------------------------------------------------------
10
11 #-----------------------------------------------------------------------------
12 # Imports
13 #-----------------------------------------------------------------------------
14
15 import hmac
1 import os
16 import os
2 import uuid
3 import pprint
17 import pprint
18 import uuid
19 from datetime import datetime
20
21 try:
22 import cPickle
23 pickle = cPickle
24 except:
25 cPickle = None
26 import pickle
4
27
5 import zmq
28 import zmq
29 from zmq.utils import jsonapi
30 from zmq.eventloop.zmqstream import ZMQStream
31
32 from IPython.config.configurable import Configurable
33 from IPython.utils.importstring import import_item
34 from IPython.utils.jsonutil import date_default
35 from IPython.utils.traitlets import CStr, Unicode, Bool, Any, Instance, Set
36
37 #-----------------------------------------------------------------------------
38 # utility functions
39 #-----------------------------------------------------------------------------
40
41 def squash_unicode(obj):
42 """coerce unicode back to bytestrings."""
43 if isinstance(obj,dict):
44 for key in obj.keys():
45 obj[key] = squash_unicode(obj[key])
46 if isinstance(key, unicode):
47 obj[squash_unicode(key)] = obj.pop(key)
48 elif isinstance(obj, list):
49 for i,v in enumerate(obj):
50 obj[i] = squash_unicode(v)
51 elif isinstance(obj, unicode):
52 obj = obj.encode('utf8')
53 return obj
54
55 #-----------------------------------------------------------------------------
56 # globals and defaults
57 #-----------------------------------------------------------------------------
58
59 _default_key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
60 json_packer = lambda obj: jsonapi.dumps(obj, **{_default_key:date_default})
61 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
6
62
7 from zmq.utils import jsonapi as json
63 pickle_packer = lambda o: pickle.dumps(o,-1)
64 pickle_unpacker = pickle.loads
65
66 default_packer = json_packer
67 default_unpacker = json_unpacker
68
69
70 DELIM="<IDS|MSG>"
71
72 #-----------------------------------------------------------------------------
73 # Classes
74 #-----------------------------------------------------------------------------
8
75
9 class Message(object):
76 class Message(object):
10 """A simple message object that maps dict keys to attributes.
77 """A simple message object that maps dict keys to attributes.
@@ -14,7 +81,7 b' class Message(object):'
14
81
15 def __init__(self, msg_dict):
82 def __init__(self, msg_dict):
16 dct = self.__dict__
83 dct = self.__dict__
17 for k, v in msg_dict.iteritems():
84 for k, v in dict(msg_dict).iteritems():
18 if isinstance(v, dict):
85 if isinstance(v, dict):
19 v = Message(v)
86 v = Message(v)
20 dct[k] = v
87 dct[k] = v
@@ -36,13 +103,9 b' class Message(object):'
36 return self.__dict__[k]
103 return self.__dict__[k]
37
104
38
105
39 def msg_header(msg_id, username, session):
106 def msg_header(msg_id, msg_type, username, session):
40 return {
107 date=datetime.now()
41 'msg_id' : msg_id,
108 return locals()
42 'username' : username,
43 'session' : session
44 }
45
46
109
47 def extract_header(msg_or_header):
110 def extract_header(msg_or_header):
48 """Given a message or header, return the header."""
111 """Given a message or header, return the header."""
@@ -63,109 +126,341 b' def extract_header(msg_or_header):'
63 h = dict(h)
126 h = dict(h)
64 return h
127 return h
65
128
66
129 class Session(Configurable):
67 class Session(object):
130 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
68
131 debug=Bool(False, config=True, help="""Debug output in the Session""")
69 def __init__(self, username=os.environ.get('USER','username'), session=None):
132 packer = Unicode('json',config=True,
70 self.username = username
133 help="""The name of the packer for serializing messages.
71 if session is None:
134 Should be one of 'json', 'pickle', or an import name
72 self.session = str(uuid.uuid4())
135 for a custom serializer.""")
136 def _packer_changed(self, name, old, new):
137 if new.lower() == 'json':
138 self.pack = json_packer
139 self.unpack = json_unpacker
140 elif new.lower() == 'pickle':
141 self.pack = pickle_packer
142 self.unpack = pickle_unpacker
73 else:
143 else:
74 self.session = session
144 self.pack = import_item(new)
75 self.msg_id = 0
76
145
77 def msg_header(self):
146 unpacker = Unicode('json',config=True,
78 h = msg_header(self.msg_id, self.username, self.session)
147 help="""The name of the unpacker for unserializing messages.
79 self.msg_id += 1
148 Only used with custom functions for `packer`.""")
80 return h
149 def _unpacker_changed(self, name, old, new):
150 if new.lower() == 'json':
151 self.pack = json_packer
152 self.unpack = json_unpacker
153 elif new.lower() == 'pickle':
154 self.pack = pickle_packer
155 self.unpack = pickle_unpacker
156 else:
157 self.unpack = import_item(new)
158
159 session = CStr('',config=True,
160 help="""The UUID identifying this session.""")
161 def _session_default(self):
162 return bytes(uuid.uuid4())
163 username = Unicode(os.environ.get('USER','username'), config=True,
164 help="""Username for the Session. Default is your system username.""")
165
166 # message signature related traits:
167 key = CStr('', config=True,
168 help="""execution key, for extra authentication.""")
169 def _key_changed(self, name, old, new):
170 if new:
171 self.auth = hmac.HMAC(new)
172 else:
173 self.auth = None
174 auth = Instance(hmac.HMAC)
175 counters = Instance('collections.defaultdict', (int,))
176 digest_history = Set()
177
178 keyfile = Unicode('', config=True,
179 help="""path to file containing execution key.""")
180 def _keyfile_changed(self, name, old, new):
181 with open(new, 'rb') as f:
182 self.key = f.read().strip()
81
183
82 def msg(self, msg_type, content=None, parent=None):
184 pack = Any(default_packer) # the actual packer function
83 """Construct a standard-form message, with a given type, content, and parent.
185 def _pack_changed(self, name, old, new):
186 if not callable(new):
187 raise TypeError("packer must be callable, not %s"%type(new))
84
188
85 NOT to be called directly.
189 unpack = Any(default_unpacker) # the actual packer function
86 """
190 def _unpack_changed(self, name, old, new):
191 if not callable(new):
192 raise TypeError("packer must be callable, not %s"%type(new))
193
194 def __init__(self, **kwargs):
195 super(Session, self).__init__(**kwargs)
196 self.none = self.pack({})
197
198 @property
199 def msg_id(self):
200 """always return new uuid"""
201 return str(uuid.uuid4())
202
203 def msg_header(self, msg_type):
204 return msg_header(self.msg_id, msg_type, self.username, self.session)
205
206 def msg(self, msg_type, content=None, parent=None, subheader=None):
87 msg = {}
207 msg = {}
88 msg['header'] = self.msg_header()
208 msg['header'] = self.msg_header(msg_type)
209 msg['msg_id'] = msg['header']['msg_id']
89 msg['parent_header'] = {} if parent is None else extract_header(parent)
210 msg['parent_header'] = {} if parent is None else extract_header(parent)
90 msg['msg_type'] = msg_type
211 msg['msg_type'] = msg_type
91 msg['content'] = {} if content is None else content
212 msg['content'] = {} if content is None else content
213 sub = {} if subheader is None else subheader
214 msg['header'].update(sub)
92 return msg
215 return msg
93
216
94 def send(self, socket, msg_or_type, content=None, parent=None, ident=None):
217 def check_key(self, msg_or_header):
95 """send a message via a socket, using a uniform message pattern.
218 """Check that a message's header has the right key"""
219 if not self.key:
220 return True
221 header = extract_header(msg_or_header)
222 return header.get('key', '') == self.key
223
224 def sign(self, msg):
225 """Sign a message with HMAC digest. If no auth, return b''."""
226 if self.auth is None:
227 return b''
228 h = self.auth.copy()
229 for m in msg:
230 h.update(m)
231 return h.hexdigest()
232
233 def serialize(self, msg, ident=None):
234 content = msg.get('content', {})
235 if content is None:
236 content = self.none
237 elif isinstance(content, dict):
238 content = self.pack(content)
239 elif isinstance(content, bytes):
240 # content is already packed, as in a relayed message
241 pass
242 elif isinstance(content, unicode):
243 # should be bytes, but JSON often spits out unicode
244 content = content.encode('utf8')
245 else:
246 raise TypeError("Content incorrect type: %s"%type(content))
247
248 real_message = [self.pack(msg['header']),
249 self.pack(msg['parent_header']),
250 content
251 ]
252
253 to_send = []
254
255 if isinstance(ident, list):
256 # accept list of idents
257 to_send.extend(ident)
258 elif ident is not None:
259 to_send.append(ident)
260 to_send.append(DELIM)
261
262 signature = self.sign(real_message)
263 to_send.append(signature)
264
265 to_send.extend(real_message)
266
267 return to_send
268
269 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
270 buffers=None, subheader=None, track=False):
271 """Build and send a message via stream or socket.
96
272
97 Parameters
273 Parameters
98 ----------
274 ----------
99 socket : zmq.Socket
275
100 The socket on which to send.
276 stream : zmq.Socket or ZMQStream
101 msg_or_type : Message/dict or str
277 the socket-like object used to send the data
102 if str : then a new message will be constructed from content,parent
278 msg_or_type : str or Message/dict
103 if Message/dict : then content and parent are ignored, and the message
279 Normally, msg_or_type will be a msg_type unless a message is being sent more
104 is sent. This is only for use when sending a Message for a second time.
280 than once.
105 content : dict, optional
281
106 The contents of the message
282 content : dict or None
107 parent : dict, optional
283 the content of the message (ignored if msg_or_type is a message)
108 The parent header, or parent message, of this message
284 parent : Message or dict or None
109 ident : bytes, optional
285 the parent or parent header describing the parent of this message
110 The zmq.IDENTITY prefix of the destination.
286 ident : bytes or list of bytes
111 Only for use on certain socket types.
287 the zmq.IDENTITY routing path
288 subheader : dict or None
289 extra header keys for this message's header
290 buffers : list or None
291 the already-serialized buffers to be appended to the message
292 track : bool
293 whether to track. Only for use with Sockets,
294 because ZMQStream objects cannot track messages.
112
295
113 Returns
296 Returns
114 -------
297 -------
115 msg : dict
298 msg : message dict
116 The message, as constructed by self.msg(msg_type,content,parent)
299 the constructed message
300 (msg,tracker) : (message dict, MessageTracker)
301 if track=True, then a 2-tuple will be returned,
302 the first element being the constructed
303 message, and the second being the MessageTracker
304
117 """
305 """
306
307 if not isinstance(stream, (zmq.Socket, ZMQStream)):
308 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
309 elif track and isinstance(stream, ZMQStream):
310 raise TypeError("ZMQStream cannot track messages")
311
118 if isinstance(msg_or_type, (Message, dict)):
312 if isinstance(msg_or_type, (Message, dict)):
119 msg = dict(msg_or_type)
313 # we got a Message, not a msg_type
314 # don't build a new Message
315 msg = msg_or_type
120 else:
316 else:
121 msg = self.msg(msg_or_type, content, parent)
317 msg = self.msg(msg_or_type, content, parent, subheader)
122 if ident is not None:
123 socket.send(ident, zmq.SNDMORE)
124 socket.send_json(msg)
125 return msg
126
127 def recv(self, socket, mode=zmq.NOBLOCK):
128 """recv a message on a socket.
129
318
130 Receive an optionally identity-prefixed message, as sent via session.send().
319 buffers = [] if buffers is None else buffers
320 to_send = self.serialize(msg, ident)
321 flag = 0
322 if buffers:
323 flag = zmq.SNDMORE
324 _track = False
325 else:
326 _track=track
327 if track:
328 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
329 else:
330 tracker = stream.send_multipart(to_send, flag, copy=False)
331 for b in buffers[:-1]:
332 stream.send(b, flag, copy=False)
333 if buffers:
334 if track:
335 tracker = stream.send(buffers[-1], copy=False, track=track)
336 else:
337 tracker = stream.send(buffers[-1], copy=False)
338
339 # omsg = Message(msg)
340 if self.debug:
341 pprint.pprint(msg)
342 pprint.pprint(to_send)
343 pprint.pprint(buffers)
131
344
132 Parameters
345 msg['tracker'] = tracker
133 ----------
134
346
135 socket : zmq.Socket
347 return msg
136 The socket on which to recv a message.
348
137 mode : int, optional
349 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
138 the mode flag passed to socket.recv
350 """Send a raw message via ident path.
139 default: zmq.NOBLOCK
140
351
141 Returns
352 Parameters
142 -------
353 ----------
143 (ident,msg) : tuple
354 msg : list of sendable buffers"""
144 always length 2. If no message received, then return is (None,None)
355 to_send = []
145 ident : bytes or None
356 if isinstance(ident, bytes):
146 the identity prefix is there was one, None otherwise.
357 ident = [ident]
147 msg : dict or None
358 if ident is not None:
148 The actual message. If mode==zmq.NOBLOCK and no message was waiting,
359 to_send.extend(ident)
149 it will be None.
360
150 """
361 to_send.append(DELIM)
362 to_send.append(self.sign(msg))
363 to_send.extend(msg)
364 stream.send_multipart(msg, flags, copy=copy)
365
366 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
367 """receives and unpacks a message
368 returns [idents], msg"""
369 if isinstance(socket, ZMQStream):
370 socket = socket.socket
151 try:
371 try:
152 msg = socket.recv_multipart(mode)
372 msg = socket.recv_multipart(mode)
153 except zmq.ZMQError, e:
373 except zmq.ZMQError as e:
154 if e.errno == zmq.EAGAIN:
374 if e.errno == zmq.EAGAIN:
155 # We can convert EAGAIN to None as we know in this case
375 # We can convert EAGAIN to None as we know in this case
156 # recv_json won't return None.
376 # recv_multipart won't return None.
157 return None,None
377 return None,None
158 else:
378 else:
159 raise
379 raise
160 if len(msg) == 1:
380 # return an actual Message object
161 ident=None
381 # determine the number of idents by trying to unpack them.
162 msg = msg[0]
382 # this is terrible:
163 elif len(msg) == 2:
383 idents, msg = self.feed_identities(msg, copy)
164 ident, msg = msg
384 try:
385 return idents, self.unpack_message(msg, content=content, copy=copy)
386 except Exception as e:
387 print (idents, msg)
388 # TODO: handle it
389 raise e
390
391 def feed_identities(self, msg, copy=True):
392 """feed until DELIM is reached, then return the prefix as idents and remainder as
393 msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.
394
395 Parameters
396 ----------
397 msg : a list of Message or bytes objects
398 the message to be split
399 copy : bool
400 flag determining whether the arguments are bytes or Messages
401
402 Returns
403 -------
404 (idents,msg) : two lists
405 idents will always be a list of bytes - the indentity prefix
406 msg will be a list of bytes or Messages, unchanged from input
407 msg should be unpackable via self.unpack_message at this point.
408 """
409 if copy:
410 idx = msg.index(DELIM)
411 return msg[:idx], msg[idx+1:]
412 else:
413 failed = True
414 for idx,m in enumerate(msg):
415 if m.bytes == DELIM:
416 failed = False
417 break
418 if failed:
419 raise ValueError("DELIM not in msg")
420 idents, msg = msg[:idx], msg[idx+1:]
421 return [m.bytes for m in idents], msg
422
423 def unpack_message(self, msg, content=True, copy=True):
424 """Return a message object from the format
425 sent by self.send.
426
427 Parameters:
428 -----------
429
430 content : bool (True)
431 whether to unpack the content dict (True),
432 or leave it serialized (False)
433
434 copy : bool (True)
435 whether to return the bytes (True),
436 or the non-copying Message object in each place (False)
437
438 """
439 minlen = 4
440 message = {}
441 if not copy:
442 for i in range(minlen):
443 msg[i] = msg[i].bytes
444 if self.auth is not None:
445 signature = msg[0]
446 if signature in self.digest_history:
447 raise ValueError("Duplicate Signature: %r"%signature)
448 self.digest_history.add(signature)
449 check = self.sign(msg[1:4])
450 if not signature == check:
451 raise ValueError("Invalid Signature: %r"%signature)
452 if not len(msg) >= minlen:
453 raise TypeError("malformed message, must have at least %i elements"%minlen)
454 message['header'] = self.unpack(msg[1])
455 message['msg_type'] = message['header']['msg_type']
456 message['parent_header'] = self.unpack(msg[2])
457 if content:
458 message['content'] = self.unpack(msg[3])
165 else:
459 else:
166 raise ValueError("Got message with length > 2, which is invalid")
460 message['content'] = msg[3]
167
461
168 return ident, json.loads(msg)
462 message['buffers'] = msg[4:]
463 return message
169
464
170 def test_msg2obj():
465 def test_msg2obj():
171 am = dict(x=1)
466 am = dict(x=1)
@@ -17,14 +17,14 b' import zmq'
17
17
18 from zmq.tests import BaseZMQTestCase
18 from zmq.tests import BaseZMQTestCase
19 from zmq.eventloop.zmqstream import ZMQStream
19 from zmq.eventloop.zmqstream import ZMQStream
20 # from IPython.zmq.tests import SessionTestCase
20
21 from IPython.parallel import streamsession as ss
21 from IPython.zmq import session as ss
22
22
23 class SessionTestCase(BaseZMQTestCase):
23 class SessionTestCase(BaseZMQTestCase):
24
24
25 def setUp(self):
25 def setUp(self):
26 BaseZMQTestCase.setUp(self)
26 BaseZMQTestCase.setUp(self)
27 self.session = ss.StreamSession()
27 self.session = ss.Session()
28
28
29 class TestSession(SessionTestCase):
29 class TestSession(SessionTestCase):
30
30
@@ -42,19 +42,19 b' class TestSession(SessionTestCase):'
42
42
43
43
44 def test_args(self):
44 def test_args(self):
45 """initialization arguments for StreamSession"""
45 """initialization arguments for Session"""
46 s = self.session
46 s = self.session
47 self.assertTrue(s.pack is ss.default_packer)
47 self.assertTrue(s.pack is ss.default_packer)
48 self.assertTrue(s.unpack is ss.default_unpacker)
48 self.assertTrue(s.unpack is ss.default_unpacker)
49 self.assertEquals(s.username, os.environ.get('USER', 'username'))
49 self.assertEquals(s.username, os.environ.get('USER', 'username'))
50
50
51 s = ss.StreamSession()
51 s = ss.Session()
52 self.assertEquals(s.username, os.environ.get('USER', 'username'))
52 self.assertEquals(s.username, os.environ.get('USER', 'username'))
53
53
54 self.assertRaises(TypeError, ss.StreamSession, pack='hi')
54 self.assertRaises(TypeError, ss.Session, pack='hi')
55 self.assertRaises(TypeError, ss.StreamSession, unpack='hi')
55 self.assertRaises(TypeError, ss.Session, unpack='hi')
56 u = str(uuid.uuid4())
56 u = str(uuid.uuid4())
57 s = ss.StreamSession(username='carrot', session=u)
57 s = ss.Session(username='carrot', session=u)
58 self.assertEquals(s.session, u)
58 self.assertEquals(s.session, u)
59 self.assertEquals(s.username, 'carrot')
59 self.assertEquals(s.username, 'carrot')
60
60
1 NO CONTENT: file was removed
NO CONTENT: file was removed
1 NO CONTENT: file was removed
NO CONTENT: file was removed
General Comments 0
You need to be logged in to leave comments. Login now