fixed buffer serialization for buffers below threshold
MinRK
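
What changed: serialize_object previously detached an element's data into the outgoing buffer list only when it exceeded the size threshold, so a buffer below the threshold stayed attached to its Serialized wrapper and then failed when the wrapper list was pickled. This commit makes 'buffer' and 'ndarray' data always detach in the sequence branch shown below, and also switches SerializeIt's pickle fallback from protocol 2 to -1 (the highest available protocol). A minimal sketch of the old failure mode, assuming Python 2, where buffer objects are not picklable; the snippet is illustrative, not repository code:

import cPickle as pickle

small = buffer("abc")          # far below the 64 B default threshold

try:
    # roughly what the pre-fix code attempted: pickling the metadata list
    # with the raw buffer still attached
    pickle.dumps([small], -1)
except TypeError, e:
    print "cannot pickle a buffer inline:", e

# After the fix, 'buffer' and 'ndarray' payloads always travel as separate
# message frames, so only lightweight metadata is ever pickled.
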
@@ -1,167 +1,167 @@
1 # encoding: utf-8
1 # encoding: utf-8
2 # -*- test-case-name: IPython.kernel.test.test_newserialized -*-
2 # -*- test-case-name: IPython.kernel.test.test_newserialized -*-
3
3
4 """Refactored serialization classes and interfaces."""
4 """Refactored serialization classes and interfaces."""
5
5
6 __docformat__ = "restructuredtext en"
6 __docformat__ = "restructuredtext en"
7
7
8 # Tell nose to skip this module
8 # Tell nose to skip this module
9 __test__ = {}
9 __test__ = {}
10
10
11 #-------------------------------------------------------------------------------
11 #-------------------------------------------------------------------------------
12 # Copyright (C) 2008 The IPython Development Team
12 # Copyright (C) 2008 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-------------------------------------------------------------------------------
16 #-------------------------------------------------------------------------------
17
17
18 #-------------------------------------------------------------------------------
18 #-------------------------------------------------------------------------------
19 # Imports
19 # Imports
20 #-------------------------------------------------------------------------------
20 #-------------------------------------------------------------------------------
21
21
22 import cPickle as pickle
22 import cPickle as pickle
23
23
24 # from twisted.python import components
24 # from twisted.python import components
25 # from zope.interface import Interface, implements
25 # from zope.interface import Interface, implements
26
26
27 try:
27 try:
28 import numpy
28 import numpy
29 except ImportError:
29 except ImportError:
30 pass
30 pass
31
31
32 from IPython.kernel.error import SerializationError
32 from IPython.kernel.error import SerializationError
33
33
34 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
35 # Classes and functions
35 # Classes and functions
36 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
37
37
38 class ISerialized:
38 class ISerialized:
39
39
40 def getData():
40 def getData():
41 """"""
41 """"""
42
42
43 def getDataSize(units=10.0**6):
43 def getDataSize(units=10.0**6):
44 """"""
44 """"""
45
45
46 def getTypeDescriptor():
46 def getTypeDescriptor():
47 """"""
47 """"""
48
48
49 def getMetadata():
49 def getMetadata():
50 """"""
50 """"""
51
51
52
52
53 class IUnSerialized:
53 class IUnSerialized:
54
54
55 def getObject():
55 def getObject():
56 """"""
56 """"""
57
57
58 class Serialized(object):
58 class Serialized(object):
59
59
60 # implements(ISerialized)
60 # implements(ISerialized)
61
61
62 def __init__(self, data, typeDescriptor, metadata={}):
62 def __init__(self, data, typeDescriptor, metadata={}):
63 self.data = data
63 self.data = data
64 self.typeDescriptor = typeDescriptor
64 self.typeDescriptor = typeDescriptor
65 self.metadata = metadata
65 self.metadata = metadata
66
66
67 def getData(self):
67 def getData(self):
68 return self.data
68 return self.data
69
69
70 def getDataSize(self, units=10.0**6):
70 def getDataSize(self, units=10.0**6):
71 return len(self.data)/units
71 return len(self.data)/units
72
72
73 def getTypeDescriptor(self):
73 def getTypeDescriptor(self):
74 return self.typeDescriptor
74 return self.typeDescriptor
75
75
76 def getMetadata(self):
76 def getMetadata(self):
77 return self.metadata
77 return self.metadata
78
78
79
79
80 class UnSerialized(object):
80 class UnSerialized(object):
81
81
82 # implements(IUnSerialized)
82 # implements(IUnSerialized)
83
83
84 def __init__(self, obj):
84 def __init__(self, obj):
85 self.obj = obj
85 self.obj = obj
86
86
87 def getObject(self):
87 def getObject(self):
88 return self.obj
88 return self.obj
89
89
90
90
91 class SerializeIt(object):
91 class SerializeIt(object):
92
92
93 # implements(ISerialized)
93 # implements(ISerialized)
94
94
95 def __init__(self, unSerialized):
95 def __init__(self, unSerialized):
96 self.data = None
96 self.data = None
97 self.obj = unSerialized.getObject()
97 self.obj = unSerialized.getObject()
98 if globals().has_key('numpy') and isinstance(self.obj, numpy.ndarray):
98 if globals().has_key('numpy') and isinstance(self.obj, numpy.ndarray):
99 if len(self.obj) == 0: # length 0 arrays can't be reconstructed
99 if len(self.obj) == 0: # length 0 arrays can't be reconstructed
100 raise SerializationError("You cannot send a length 0 array")
100 raise SerializationError("You cannot send a length 0 array")
101 self.obj = numpy.ascontiguousarray(self.obj, dtype=None)
101 self.obj = numpy.ascontiguousarray(self.obj, dtype=None)
102 self.typeDescriptor = 'ndarray'
102 self.typeDescriptor = 'ndarray'
103 self.metadata = {'shape':self.obj.shape,
103 self.metadata = {'shape':self.obj.shape,
104 'dtype':self.obj.dtype.str}
104 'dtype':self.obj.dtype.str}
105 elif isinstance(self.obj, str):
105 elif isinstance(self.obj, str):
106 self.typeDescriptor = 'bytes'
106 self.typeDescriptor = 'bytes'
107 self.metadata = {}
107 self.metadata = {}
108 elif isinstance(self.obj, buffer):
108 elif isinstance(self.obj, buffer):
109 self.typeDescriptor = 'buffer'
109 self.typeDescriptor = 'buffer'
110 self.metadata = {}
110 self.metadata = {}
111 else:
111 else:
112 self.typeDescriptor = 'pickle'
112 self.typeDescriptor = 'pickle'
113 self.metadata = {}
113 self.metadata = {}
114 self._generateData()
114 self._generateData()
115
115
116 def _generateData(self):
116 def _generateData(self):
117 if self.typeDescriptor == 'ndarray':
117 if self.typeDescriptor == 'ndarray':
118 self.data = numpy.getbuffer(self.obj)
118 self.data = numpy.getbuffer(self.obj)
119 elif self.typeDescriptor in ('bytes', 'buffer'):
119 elif self.typeDescriptor in ('bytes', 'buffer'):
120 self.data = self.obj
120 self.data = self.obj
121 elif self.typeDescriptor == 'pickle':
121 elif self.typeDescriptor == 'pickle':
122 -             self.data = pickle.dumps(self.obj, 2)
122 +             self.data = pickle.dumps(self.obj, -1)
123 else:
123 else:
124             raise SerializationError("Really weird serialization error.")
124             raise SerializationError("Really weird serialization error.")
125 del self.obj
125 del self.obj
126
126
127 def getData(self):
127 def getData(self):
128 return self.data
128 return self.data
129
129
130 def getDataSize(self, units=10.0**6):
130 def getDataSize(self, units=10.0**6):
131 return 1.0*len(self.data)/units
131 return 1.0*len(self.data)/units
132
132
133 def getTypeDescriptor(self):
133 def getTypeDescriptor(self):
134 return self.typeDescriptor
134 return self.typeDescriptor
135
135
136 def getMetadata(self):
136 def getMetadata(self):
137 return self.metadata
137 return self.metadata
138
138
139
139
140 class UnSerializeIt(UnSerialized):
140 class UnSerializeIt(UnSerialized):
141
141
142 # implements(IUnSerialized)
142 # implements(IUnSerialized)
143
143
144 def __init__(self, serialized):
144 def __init__(self, serialized):
145 self.serialized = serialized
145 self.serialized = serialized
146
146
147 def getObject(self):
147 def getObject(self):
148 typeDescriptor = self.serialized.getTypeDescriptor()
148 typeDescriptor = self.serialized.getTypeDescriptor()
149 if globals().has_key('numpy') and typeDescriptor == 'ndarray':
149 if globals().has_key('numpy') and typeDescriptor == 'ndarray':
150 result = numpy.frombuffer(self.serialized.getData(), dtype = self.serialized.metadata['dtype'])
150 result = numpy.frombuffer(self.serialized.getData(), dtype = self.serialized.metadata['dtype'])
151 result.shape = self.serialized.metadata['shape']
151 result.shape = self.serialized.metadata['shape']
152 # This is a hack to make the array writable. We are working with
152 # This is a hack to make the array writable. We are working with
153 # the numpy folks to address this issue.
153 # the numpy folks to address this issue.
154 result = result.copy()
154 result = result.copy()
155 elif typeDescriptor == 'pickle':
155 elif typeDescriptor == 'pickle':
156 result = pickle.loads(self.serialized.getData())
156 result = pickle.loads(self.serialized.getData())
157 elif typeDescriptor in ('bytes', 'buffer'):
157 elif typeDescriptor in ('bytes', 'buffer'):
158 result = self.serialized.getData()
158 result = self.serialized.getData()
159 else:
159 else:
160             raise SerializationError("Really weird serialization error.")
160             raise SerializationError("Really weird serialization error.")
161 return result
161 return result
162
162
163 def serialize(obj):
163 def serialize(obj):
164 return SerializeIt(UnSerialized(obj))
164 return SerializeIt(UnSerialized(obj))
165
165
166 def unserialize(serialized):
166 def unserialize(serialized):
167 return UnSerializeIt(serialized).getObject()
167 return UnSerializeIt(serialized).getObject()
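
For reference, a round trip through this module's serialize/unserialize helpers behaves as follows; a sketch assuming Python 2 with numpy installed and the IPython.zmq.newserialized import path used by the second file in this diff:

import numpy
from IPython.zmq.newserialized import serialize, unserialize

a = numpy.arange(12, dtype='float64').reshape(3, 4)
s = serialize(a)                   # SerializeIt wrapping an UnSerialized
print s.getTypeDescriptor()        # 'ndarray'
print s.getMetadata()              # the shape and dtype needed to rebuild it
print s.getDataSize()              # size in MB (units=10.0**6 by default)

b = unserialize(s)                 # rebuilt with numpy.frombuffer, reshaped, then copied
assert (a == b).all()
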
@@ -1,447 +1,447 @@
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """edited session.py to work with streams, and move msg_type to the header
2 """edited session.py to work with streams, and move msg_type to the header
3 """
3 """
4
4
5
5
6 import os
6 import os
7 import sys
7 import sys
8 import traceback
8 import traceback
9 import pprint
9 import pprint
10 import uuid
10 import uuid
11
11
12 import zmq
12 import zmq
13 from zmq.utils import jsonapi
13 from zmq.utils import jsonapi
14 from zmq.eventloop.zmqstream import ZMQStream
14 from zmq.eventloop.zmqstream import ZMQStream
15
15
16 from IPython.zmq.pickleutil import can, uncan, canSequence, uncanSequence
16 from IPython.zmq.pickleutil import can, uncan, canSequence, uncanSequence
17 from IPython.zmq.newserialized import serialize, unserialize
17 from IPython.zmq.newserialized import serialize, unserialize
18
18
19 try:
19 try:
20 import cPickle
20 import cPickle
21 pickle = cPickle
21 pickle = cPickle
22 except:
22 except:
23 cPickle = None
23 cPickle = None
24 import pickle
24 import pickle
25
25
26 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
26 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
27 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
27 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
28 if json_name in ('jsonlib', 'jsonlib2'):
28 if json_name in ('jsonlib', 'jsonlib2'):
29 use_json = True
29 use_json = True
30 elif json_name:
30 elif json_name:
31 if cPickle is None:
31 if cPickle is None:
32 use_json = True
32 use_json = True
33 else:
33 else:
34 use_json = False
34 use_json = False
35 else:
35 else:
36 use_json = False
36 use_json = False
37
37
38 if use_json:
38 if use_json:
39 default_packer = jsonapi.dumps
39 default_packer = jsonapi.dumps
40 default_unpacker = jsonapi.loads
40 default_unpacker = jsonapi.loads
41 else:
41 else:
42 default_packer = lambda o: pickle.dumps(o,-1)
42 default_packer = lambda o: pickle.dumps(o,-1)
43 default_unpacker = pickle.loads
43 default_unpacker = pickle.loads
44
44
45
45
46 DELIM="<IDS|MSG>"
46 DELIM="<IDS|MSG>"
47
47
48 def wrap_exception():
48 def wrap_exception():
49 etype, evalue, tb = sys.exc_info()
49 etype, evalue, tb = sys.exc_info()
50 tb = traceback.format_exception(etype, evalue, tb)
50 tb = traceback.format_exception(etype, evalue, tb)
51 exc_content = {
51 exc_content = {
52 u'status' : u'error',
52 u'status' : u'error',
53 u'traceback' : tb,
53 u'traceback' : tb,
54 u'etype' : unicode(etype),
54 u'etype' : unicode(etype),
55 u'evalue' : unicode(evalue)
55 u'evalue' : unicode(evalue)
56 }
56 }
57 return exc_content
57 return exc_content
58
58
59 class KernelError(Exception):
59 class KernelError(Exception):
60 pass
60 pass
61
61
62 def unwrap_exception(content):
62 def unwrap_exception(content):
63 err = KernelError(content['etype'], content['evalue'])
63 err = KernelError(content['etype'], content['evalue'])
64 err.evalue = content['evalue']
64 err.evalue = content['evalue']
65 err.etype = content['etype']
65 err.etype = content['etype']
66 err.traceback = ''.join(content['traceback'])
66 err.traceback = ''.join(content['traceback'])
67 return err
67 return err
68
68
69
69
70 class Message(object):
70 class Message(object):
71 """A simple message object that maps dict keys to attributes.
71 """A simple message object that maps dict keys to attributes.
72
72
73 A Message can be created from a dict and a dict from a Message instance
73 A Message can be created from a dict and a dict from a Message instance
74 simply by calling dict(msg_obj)."""
74 simply by calling dict(msg_obj)."""
75
75
76 def __init__(self, msg_dict):
76 def __init__(self, msg_dict):
77 dct = self.__dict__
77 dct = self.__dict__
78 for k, v in dict(msg_dict).iteritems():
78 for k, v in dict(msg_dict).iteritems():
79 if isinstance(v, dict):
79 if isinstance(v, dict):
80 v = Message(v)
80 v = Message(v)
81 dct[k] = v
81 dct[k] = v
82
82
83 # Having this iterator lets dict(msg_obj) work out of the box.
83 # Having this iterator lets dict(msg_obj) work out of the box.
84 def __iter__(self):
84 def __iter__(self):
85 return iter(self.__dict__.iteritems())
85 return iter(self.__dict__.iteritems())
86
86
87 def __repr__(self):
87 def __repr__(self):
88 return repr(self.__dict__)
88 return repr(self.__dict__)
89
89
90 def __str__(self):
90 def __str__(self):
91 return pprint.pformat(self.__dict__)
91 return pprint.pformat(self.__dict__)
92
92
93 def __contains__(self, k):
93 def __contains__(self, k):
94 return k in self.__dict__
94 return k in self.__dict__
95
95
96 def __getitem__(self, k):
96 def __getitem__(self, k):
97 return self.__dict__[k]
97 return self.__dict__[k]
98
98
99
99
100 def msg_header(msg_id, msg_type, username, session):
100 def msg_header(msg_id, msg_type, username, session):
101 return locals()
101 return locals()
102 # return {
102 # return {
103 # 'msg_id' : msg_id,
103 # 'msg_id' : msg_id,
104 # 'msg_type': msg_type,
104 # 'msg_type': msg_type,
105 # 'username' : username,
105 # 'username' : username,
106 # 'session' : session
106 # 'session' : session
107 # }
107 # }
108
108
109
109
110 def extract_header(msg_or_header):
110 def extract_header(msg_or_header):
111 """Given a message or header, return the header."""
111 """Given a message or header, return the header."""
112 if not msg_or_header:
112 if not msg_or_header:
113 return {}
113 return {}
114 try:
114 try:
115 # See if msg_or_header is the entire message.
115 # See if msg_or_header is the entire message.
116 h = msg_or_header['header']
116 h = msg_or_header['header']
117 except KeyError:
117 except KeyError:
118 try:
118 try:
119 # See if msg_or_header is just the header
119 # See if msg_or_header is just the header
120 h = msg_or_header['msg_id']
120 h = msg_or_header['msg_id']
121 except KeyError:
121 except KeyError:
122 raise
122 raise
123 else:
123 else:
124 h = msg_or_header
124 h = msg_or_header
125 if not isinstance(h, dict):
125 if not isinstance(h, dict):
126 h = dict(h)
126 h = dict(h)
127 return h
127 return h
128
128
129 def rekey(dikt):
129 def rekey(dikt):
130 """rekey a dict that has been forced to use str keys where there should be
130 """rekey a dict that has been forced to use str keys where there should be
131 ints by json. This belongs in the jsonutil added by fperez."""
131 ints by json. This belongs in the jsonutil added by fperez."""
132 for k in dikt.iterkeys():
132 for k in dikt.iterkeys():
133 if isinstance(k, str):
133 if isinstance(k, str):
134 ik=fk=None
134 ik=fk=None
135 try:
135 try:
136 ik = int(k)
136 ik = int(k)
137 except ValueError:
137 except ValueError:
138 try:
138 try:
139 fk = float(k)
139 fk = float(k)
140 except ValueError:
140 except ValueError:
141 continue
141 continue
142 if ik is not None:
142 if ik is not None:
143 nk = ik
143 nk = ik
144 else:
144 else:
145 nk = fk
145 nk = fk
146 if nk in dikt:
146 if nk in dikt:
147 raise KeyError("already have key %r"%nk)
147 raise KeyError("already have key %r"%nk)
148 dikt[nk] = dikt.pop(k)
148 dikt[nk] = dikt.pop(k)
149 return dikt
149 return dikt
150
150
151 def serialize_object(obj, threshold=64e-6):
151 def serialize_object(obj, threshold=64e-6):
152 """serialize an object into a list of sendable buffers.
152 """serialize an object into a list of sendable buffers.
153
153
154 Returns: (pmd, bufs)
154 Returns: (pmd, bufs)
155 where pmd is the pickled metadata wrapper, and bufs
155 where pmd is the pickled metadata wrapper, and bufs
156 is a list of data buffers"""
156 is a list of data buffers"""
157     # default threshold is 64e-6 MB, i.e. 64 bytes (getDataSize reports sizes in MB)
157     # default threshold is 64e-6 MB, i.e. 64 bytes (getDataSize reports sizes in MB)
158 databuffers = []
158 databuffers = []
159 if isinstance(obj, (list, tuple)):
159 if isinstance(obj, (list, tuple)):
160 clist = canSequence(obj)
160 clist = canSequence(obj)
161 slist = map(serialize, clist)
161 slist = map(serialize, clist)
162 for s in slist:
162 for s in slist:
163 -             if s.getDataSize() > threshold:
163 +             if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
164 databuffers.append(s.getData())
164 databuffers.append(s.getData())
165 s.data = None
165 s.data = None
166 return pickle.dumps(slist,-1), databuffers
166 return pickle.dumps(slist,-1), databuffers
167 elif isinstance(obj, dict):
167 elif isinstance(obj, dict):
168 sobj = {}
168 sobj = {}
169 for k in sorted(obj.iterkeys()):
169 for k in sorted(obj.iterkeys()):
170 s = serialize(can(obj[k]))
170 s = serialize(can(obj[k]))
171 if s.getDataSize() > threshold:
171 if s.getDataSize() > threshold:
172 databuffers.append(s.getData())
172 databuffers.append(s.getData())
173 s.data = None
173 s.data = None
174 sobj[k] = s
174 sobj[k] = s
175 return pickle.dumps(sobj,-1),databuffers
175 return pickle.dumps(sobj,-1),databuffers
176 else:
176 else:
177 s = serialize(can(obj))
177 s = serialize(can(obj))
178 if s.getDataSize() > threshold:
178 if s.getDataSize() > threshold:
179 databuffers.append(s.getData())
179 databuffers.append(s.getData())
180 s.data = None
180 s.data = None
181 return pickle.dumps(s,-1),databuffers
181 return pickle.dumps(s,-1),databuffers
182
182
183
183
184 def unserialize_object(bufs):
184 def unserialize_object(bufs):
185 """reconstruct an object serialized by serialize_object from data buffers"""
185 """reconstruct an object serialized by serialize_object from data buffers"""
186 bufs = list(bufs)
186 bufs = list(bufs)
187 sobj = pickle.loads(bufs.pop(0))
187 sobj = pickle.loads(bufs.pop(0))
188 if isinstance(sobj, (list, tuple)):
188 if isinstance(sobj, (list, tuple)):
189 for s in sobj:
189 for s in sobj:
190 if s.data is None:
190 if s.data is None:
191 s.data = bufs.pop(0)
191 s.data = bufs.pop(0)
192 return uncanSequence(map(unserialize, sobj))
192 return uncanSequence(map(unserialize, sobj))
193 elif isinstance(sobj, dict):
193 elif isinstance(sobj, dict):
194 newobj = {}
194 newobj = {}
195 for k in sorted(sobj.iterkeys()):
195 for k in sorted(sobj.iterkeys()):
196 s = sobj[k]
196 s = sobj[k]
197 if s.data is None:
197 if s.data is None:
198 s.data = bufs.pop(0)
198 s.data = bufs.pop(0)
199 newobj[k] = uncan(unserialize(s))
199 newobj[k] = uncan(unserialize(s))
200 return newobj
200 return newobj
201 else:
201 else:
202 if sobj.data is None:
202 if sobj.data is None:
203 sobj.data = bufs.pop(0)
203 sobj.data = bufs.pop(0)
204 return uncan(unserialize(sobj))
204 return uncan(unserialize(sobj))
205
205
206 def pack_apply_message(f, args, kwargs, threshold=64e-6):
206 def pack_apply_message(f, args, kwargs, threshold=64e-6):
207 """pack up a function, args, and kwargs to be sent over the wire
207 """pack up a function, args, and kwargs to be sent over the wire
208 as a series of buffers. Any object whose data is larger than `threshold`
208 as a series of buffers. Any object whose data is larger than `threshold`
209     will not have its data copied (currently only numpy arrays support zero-copy)
209     will not have its data copied (currently only numpy arrays support zero-copy)
210 msg = [pickle.dumps(can(f),-1)]
210 msg = [pickle.dumps(can(f),-1)]
211 databuffers = [] # for large objects
211 databuffers = [] # for large objects
212 sargs, bufs = serialize_object(args,threshold)
212 sargs, bufs = serialize_object(args,threshold)
213 msg.append(sargs)
213 msg.append(sargs)
214 databuffers.extend(bufs)
214 databuffers.extend(bufs)
215 skwargs, bufs = serialize_object(kwargs,threshold)
215 skwargs, bufs = serialize_object(kwargs,threshold)
216 msg.append(skwargs)
216 msg.append(skwargs)
217 databuffers.extend(bufs)
217 databuffers.extend(bufs)
218 msg.extend(databuffers)
218 msg.extend(databuffers)
219 return msg
219 return msg
220
220
221 def unpack_apply_message(bufs, g=None, copy=True):
221 def unpack_apply_message(bufs, g=None, copy=True):
222 """unpack f,args,kwargs from buffers packed by pack_apply_message()
222 """unpack f,args,kwargs from buffers packed by pack_apply_message()
223 Returns: original f,args,kwargs"""
223 Returns: original f,args,kwargs"""
224 bufs = list(bufs) # allow us to pop
224 bufs = list(bufs) # allow us to pop
225 assert len(bufs) >= 3, "not enough buffers!"
225 assert len(bufs) >= 3, "not enough buffers!"
226 if not copy:
226 if not copy:
227 for i in range(3):
227 for i in range(3):
228 bufs[i] = bufs[i].bytes
228 bufs[i] = bufs[i].bytes
229 cf = pickle.loads(bufs.pop(0))
229 cf = pickle.loads(bufs.pop(0))
230 sargs = list(pickle.loads(bufs.pop(0)))
230 sargs = list(pickle.loads(bufs.pop(0)))
231 skwargs = dict(pickle.loads(bufs.pop(0)))
231 skwargs = dict(pickle.loads(bufs.pop(0)))
232 # print sargs, skwargs
232 # print sargs, skwargs
233 f = cf.getFunction(g)
233 f = cf.getFunction(g)
234 for sa in sargs:
234 for sa in sargs:
235 if sa.data is None:
235 if sa.data is None:
236 m = bufs.pop(0)
236 m = bufs.pop(0)
237 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
237 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
238 if copy:
238 if copy:
239 sa.data = buffer(m)
239 sa.data = buffer(m)
240 else:
240 else:
241 sa.data = m.buffer
241 sa.data = m.buffer
242 else:
242 else:
243 if copy:
243 if copy:
244 sa.data = m
244 sa.data = m
245 else:
245 else:
246 sa.data = m.bytes
246 sa.data = m.bytes
247
247
248 args = uncanSequence(map(unserialize, sargs), g)
248 args = uncanSequence(map(unserialize, sargs), g)
249 kwargs = {}
249 kwargs = {}
250 for k in sorted(skwargs.iterkeys()):
250 for k in sorted(skwargs.iterkeys()):
251 sa = skwargs[k]
251 sa = skwargs[k]
252 if sa.data is None:
252 if sa.data is None:
253 sa.data = bufs.pop(0)
253 sa.data = bufs.pop(0)
254 kwargs[k] = uncan(unserialize(sa), g)
254 kwargs[k] = uncan(unserialize(sa), g)
255
255
256 return f,args,kwargs
256 return f,args,kwargs
257
257
258 class StreamSession(object):
258 class StreamSession(object):
259 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
259 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
260 debug=False
260 debug=False
261 def __init__(self, username=None, session=None, packer=None, unpacker=None):
261 def __init__(self, username=None, session=None, packer=None, unpacker=None):
262 if username is None:
262 if username is None:
263 username = os.environ.get('USER','username')
263 username = os.environ.get('USER','username')
264 self.username = username
264 self.username = username
265 if session is None:
265 if session is None:
266 self.session = str(uuid.uuid4())
266 self.session = str(uuid.uuid4())
267 else:
267 else:
268 self.session = session
268 self.session = session
269 self.msg_id = str(uuid.uuid4())
269 self.msg_id = str(uuid.uuid4())
270 if packer is None:
270 if packer is None:
271 self.pack = default_packer
271 self.pack = default_packer
272 else:
272 else:
273 if not callable(packer):
273 if not callable(packer):
274 raise TypeError("packer must be callable, not %s"%type(packer))
274 raise TypeError("packer must be callable, not %s"%type(packer))
275 self.pack = packer
275 self.pack = packer
276
276
277 if unpacker is None:
277 if unpacker is None:
278 self.unpack = default_unpacker
278 self.unpack = default_unpacker
279 else:
279 else:
280 if not callable(unpacker):
280 if not callable(unpacker):
281 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
281 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
282 self.unpack = unpacker
282 self.unpack = unpacker
283
283
284 self.none = self.pack({})
284 self.none = self.pack({})
285
285
286 def msg_header(self, msg_type):
286 def msg_header(self, msg_type):
287 h = msg_header(self.msg_id, msg_type, self.username, self.session)
287 h = msg_header(self.msg_id, msg_type, self.username, self.session)
288 self.msg_id = str(uuid.uuid4())
288 self.msg_id = str(uuid.uuid4())
289 return h
289 return h
290
290
291 def msg(self, msg_type, content=None, parent=None, subheader=None):
291 def msg(self, msg_type, content=None, parent=None, subheader=None):
292 msg = {}
292 msg = {}
293 msg['header'] = self.msg_header(msg_type)
293 msg['header'] = self.msg_header(msg_type)
294 msg['msg_id'] = msg['header']['msg_id']
294 msg['msg_id'] = msg['header']['msg_id']
295 msg['parent_header'] = {} if parent is None else extract_header(parent)
295 msg['parent_header'] = {} if parent is None else extract_header(parent)
296 msg['msg_type'] = msg_type
296 msg['msg_type'] = msg_type
297 msg['content'] = {} if content is None else content
297 msg['content'] = {} if content is None else content
298 sub = {} if subheader is None else subheader
298 sub = {} if subheader is None else subheader
299 msg['header'].update(sub)
299 msg['header'].update(sub)
300 return msg
300 return msg
301
301
302 def send(self, stream, msg_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
302 def send(self, stream, msg_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
303 """send a message via stream"""
303 """send a message via stream"""
304 msg = self.msg(msg_type, content, parent, subheader)
304 msg = self.msg(msg_type, content, parent, subheader)
305 buffers = [] if buffers is None else buffers
305 buffers = [] if buffers is None else buffers
306 to_send = []
306 to_send = []
307 if isinstance(ident, list):
307 if isinstance(ident, list):
308 # accept list of idents
308 # accept list of idents
309 to_send.extend(ident)
309 to_send.extend(ident)
310 elif ident is not None:
310 elif ident is not None:
311 to_send.append(ident)
311 to_send.append(ident)
312 to_send.append(DELIM)
312 to_send.append(DELIM)
313 to_send.append(self.pack(msg['header']))
313 to_send.append(self.pack(msg['header']))
314 to_send.append(self.pack(msg['parent_header']))
314 to_send.append(self.pack(msg['parent_header']))
315 # if parent is None:
315 # if parent is None:
316 # to_send.append(self.none)
316 # to_send.append(self.none)
317 # else:
317 # else:
318 # to_send.append(self.pack(dict(parent)))
318 # to_send.append(self.pack(dict(parent)))
319 if content is None:
319 if content is None:
320 content = self.none
320 content = self.none
321 elif isinstance(content, dict):
321 elif isinstance(content, dict):
322 content = self.pack(content)
322 content = self.pack(content)
323 elif isinstance(content, str):
323 elif isinstance(content, str):
324 # content is already packed, as in a relayed message
324 # content is already packed, as in a relayed message
325 pass
325 pass
326 else:
326 else:
327 raise TypeError("Content incorrect type: %s"%type(content))
327 raise TypeError("Content incorrect type: %s"%type(content))
328 to_send.append(content)
328 to_send.append(content)
329 flag = 0
329 flag = 0
330 if buffers:
330 if buffers:
331 flag = zmq.SNDMORE
331 flag = zmq.SNDMORE
332 stream.send_multipart(to_send, flag, copy=False)
332 stream.send_multipart(to_send, flag, copy=False)
333 for b in buffers[:-1]:
333 for b in buffers[:-1]:
334 stream.send(b, flag, copy=False)
334 stream.send(b, flag, copy=False)
335 if buffers:
335 if buffers:
336 stream.send(buffers[-1], copy=False)
336 stream.send(buffers[-1], copy=False)
337 omsg = Message(msg)
337 omsg = Message(msg)
338 if self.debug:
338 if self.debug:
339 pprint.pprint(omsg)
339 pprint.pprint(omsg)
340 pprint.pprint(to_send)
340 pprint.pprint(to_send)
341 pprint.pprint(buffers)
341 pprint.pprint(buffers)
342 return omsg
342 return omsg
343
343
344 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
344 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
345 """receives and unpacks a message
345 """receives and unpacks a message
346 returns [idents], msg"""
346 returns [idents], msg"""
347 if isinstance(socket, ZMQStream):
347 if isinstance(socket, ZMQStream):
348 socket = socket.socket
348 socket = socket.socket
349 try:
349 try:
350 msg = socket.recv_multipart(mode)
350 msg = socket.recv_multipart(mode)
351 except zmq.ZMQError, e:
351 except zmq.ZMQError, e:
352 if e.errno == zmq.EAGAIN:
352 if e.errno == zmq.EAGAIN:
353 # We can convert EAGAIN to None as we know in this case
353 # We can convert EAGAIN to None as we know in this case
354 # recv_json won't return None.
354 # recv_json won't return None.
355 return None
355 return None
356 else:
356 else:
357 raise
357 raise
358 # return an actual Message object
358 # return an actual Message object
359 # determine the number of idents by trying to unpack them.
359 # determine the number of idents by trying to unpack them.
360 # this is terrible:
360 # this is terrible:
361 idents, msg = self.feed_identities(msg, copy)
361 idents, msg = self.feed_identities(msg, copy)
362 try:
362 try:
363 return idents, self.unpack_message(msg, content=content, copy=copy)
363 return idents, self.unpack_message(msg, content=content, copy=copy)
364 except Exception, e:
364 except Exception, e:
365 print idents, msg
365 print idents, msg
366 # TODO: handle it
366 # TODO: handle it
367 raise e
367 raise e
368
368
369 def feed_identities(self, msg, copy=True):
369 def feed_identities(self, msg, copy=True):
370 """This is a completely horrible thing, but it strips the zmq
370 """This is a completely horrible thing, but it strips the zmq
371 ident prefixes off of a message. It will break if any identities
371 ident prefixes off of a message. It will break if any identities
372 are unpackable by self.unpack."""
372 are unpackable by self.unpack."""
373 msg = list(msg)
373 msg = list(msg)
374 idents = []
374 idents = []
375 while len(msg) > 3:
375 while len(msg) > 3:
376 if copy:
376 if copy:
377 s = msg[0]
377 s = msg[0]
378 else:
378 else:
379 s = msg[0].bytes
379 s = msg[0].bytes
380 if s == DELIM:
380 if s == DELIM:
381 msg.pop(0)
381 msg.pop(0)
382 break
382 break
383 else:
383 else:
384 idents.append(s)
384 idents.append(s)
385 msg.pop(0)
385 msg.pop(0)
386
386
387 return idents, msg
387 return idents, msg
388
388
389 def unpack_message(self, msg, content=True, copy=True):
389 def unpack_message(self, msg, content=True, copy=True):
390 """return a message object from the format
390 """return a message object from the format
391 sent by self.send.
391 sent by self.send.
392
392
393 parameters:
393 parameters:
394
394
395 content : bool (True)
395 content : bool (True)
396 whether to unpack the content dict (True),
396 whether to unpack the content dict (True),
397 or leave it serialized (False)
397 or leave it serialized (False)
398
398
399 copy : bool (True)
399 copy : bool (True)
400 whether to return the bytes (True),
400 whether to return the bytes (True),
401 or the non-copying Message object in each place (False)
401 or the non-copying Message object in each place (False)
402
402
403 """
403 """
404 if not len(msg) >= 3:
404 if not len(msg) >= 3:
405 raise TypeError("malformed message, must have at least 3 elements")
405 raise TypeError("malformed message, must have at least 3 elements")
406 message = {}
406 message = {}
407 if not copy:
407 if not copy:
408 for i in range(3):
408 for i in range(3):
409 msg[i] = msg[i].bytes
409 msg[i] = msg[i].bytes
410 message['header'] = self.unpack(msg[0])
410 message['header'] = self.unpack(msg[0])
411 message['msg_type'] = message['header']['msg_type']
411 message['msg_type'] = message['header']['msg_type']
412 message['parent_header'] = self.unpack(msg[1])
412 message['parent_header'] = self.unpack(msg[1])
413 if content:
413 if content:
414 message['content'] = self.unpack(msg[2])
414 message['content'] = self.unpack(msg[2])
415 else:
415 else:
416 message['content'] = msg[2]
416 message['content'] = msg[2]
417
417
418 # message['buffers'] = msg[3:]
418 # message['buffers'] = msg[3:]
419 # else:
419 # else:
420 # message['header'] = self.unpack(msg[0].bytes)
420 # message['header'] = self.unpack(msg[0].bytes)
421 # message['msg_type'] = message['header']['msg_type']
421 # message['msg_type'] = message['header']['msg_type']
422 # message['parent_header'] = self.unpack(msg[1].bytes)
422 # message['parent_header'] = self.unpack(msg[1].bytes)
423 # if content:
423 # if content:
424 # message['content'] = self.unpack(msg[2].bytes)
424 # message['content'] = self.unpack(msg[2].bytes)
425 # else:
425 # else:
426 # message['content'] = msg[2].bytes
426 # message['content'] = msg[2].bytes
427
427
428 message['buffers'] = msg[3:]# [ m.buffer for m in msg[3:] ]
428 message['buffers'] = msg[3:]# [ m.buffer for m in msg[3:] ]
429 return message
429 return message
430
430
431
431
432
432
433 def test_msg2obj():
433 def test_msg2obj():
434 am = dict(x=1)
434 am = dict(x=1)
435 ao = Message(am)
435 ao = Message(am)
436 assert ao.x == am['x']
436 assert ao.x == am['x']
437
437
438 am['y'] = dict(z=1)
438 am['y'] = dict(z=1)
439 ao = Message(am)
439 ao = Message(am)
440 assert ao.y.z == am['y']['z']
440 assert ao.y.z == am['y']['z']
441
441
442 k1, k2 = 'y', 'z'
442 k1, k2 = 'y', 'z'
443 assert ao[k1][k2] == am[k1][k2]
443 assert ao[k1][k2] == am[k1][k2]
444
444
445 am2 = dict(ao)
445 am2 = dict(ao)
446 assert am['x'] == am2['x']
446 assert am['x'] == am2['x']
447 assert am['y']['z'] == am2['y']['z']
447 assert am['y']['z'] == am2['y']['z']
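
The core of the change in serialize_object's sequence branch can be summarized in isolation. A self-contained sketch with hypothetical names (not repository code), showing the post-commit detachment rule:

def should_detach(type_descriptor, size_mb, threshold=64e-6):
    # Buffers and ndarrays are always sent as separate frames, since a raw
    # buffer cannot be pickled and arrays are the zero-copy case; everything
    # else is detached only once it exceeds the size threshold (in MB).
    return type_descriptor in ('buffer', 'ndarray') or size_mb > threshold

print should_detach('buffer', 3e-6)     # True:  tiny buffers are now detached
print should_detach('pickle', 3e-6)     # False: small picklables stay inline
print should_detach('ndarray', 1e-6)    # True:  arrays always go as frames
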