##// END OF EJS Templates
fixed buffer serialization for buffers below threshold
MinRK -
Show More
@@ -1,167 +1,167 b''
1 1 # encoding: utf-8
2 2 # -*- test-case-name: IPython.kernel.test.test_newserialized -*-
3 3
4 4 """Refactored serialization classes and interfaces."""
5 5
6 6 __docformat__ = "restructuredtext en"
7 7
8 8 # Tell nose to skip this module
9 9 __test__ = {}
10 10
11 11 #-------------------------------------------------------------------------------
12 12 # Copyright (C) 2008 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-------------------------------------------------------------------------------
17 17
18 18 #-------------------------------------------------------------------------------
19 19 # Imports
20 20 #-------------------------------------------------------------------------------
21 21
22 22 import cPickle as pickle
23 23
24 24 # from twisted.python import components
25 25 # from zope.interface import Interface, implements
26 26
27 27 try:
28 28 import numpy
29 29 except ImportError:
30 30 pass
31 31
32 32 from IPython.kernel.error import SerializationError
33 33
34 34 #-----------------------------------------------------------------------------
35 35 # Classes and functions
36 36 #-----------------------------------------------------------------------------
37 37
class ISerialized:
    """Interface for a serialized object.

    Zope-style interface stub (the ``implements`` machinery is commented
    out module-wide); methods intentionally have no ``self`` and no body.
    """

    def getData():
        """Return the raw serialized data."""

    def getDataSize(units=10.0**6):
        """Return the size of the data in *units* (default 10**6, i.e. MB)."""

    def getTypeDescriptor():
        """Return a string naming the serialization format used."""

    def getMetadata():
        """Return a dict of metadata needed to reconstruct the object."""
51 51
52 52
class IUnSerialized:
    """Interface for a wrapper around a live (unserialized) object."""

    def getObject():
        """Return the wrapped live object."""
57 57
class Serialized(object):
    """Concrete ISerialized: holds already-serialized data together with
    its type descriptor and reconstruction metadata."""

    # implements(ISerialized)

    def __init__(self, data, typeDescriptor, metadata=None):
        """Wrap serialized *data*.

        Parameters
        ----------
        data : the serialized payload (str/buffer/pickle bytes)
        typeDescriptor : str naming the serialization format
        metadata : dict, optional
            Extra info needed to reconstruct the object (e.g. shape/dtype).
        """
        self.data = data
        self.typeDescriptor = typeDescriptor
        # BUGFIX: the previous default ``metadata={}`` was a mutable default
        # argument shared by every instance constructed without metadata;
        # allocate a fresh dict per instance instead.
        self.metadata = metadata if metadata is not None else {}

    def getData(self):
        """Return the raw serialized data."""
        return self.data

    def getDataSize(self, units=10.0**6):
        """Return the size of the data in *units* (default: megabytes)."""
        return len(self.data)/units

    def getTypeDescriptor(self):
        """Return the serialization format name."""
        return self.typeDescriptor

    def getMetadata(self):
        """Return the metadata dict (never shared between instances)."""
        return self.metadata
78 78
79 79
class UnSerialized(object):
    """Concrete IUnSerialized: a trivial wrapper around a live Python
    object, kept by reference until something serializes it."""

    # implements(IUnSerialized)

    def __init__(self, obj):
        # Store the live object as-is; no copying or conversion happens here.
        self.obj = obj

    def getObject(self):
        """Return the wrapped live object unchanged."""
        return self.obj
89 89
90 90
class SerializeIt(object):
    """ISerialized implementation that serializes a live object eagerly,
    at construction time.

    The strategy is picked from the object's type: numpy arrays ship their
    raw buffer (with shape/dtype metadata), str and buffer objects pass
    through unchanged, and everything else falls back to a pickle.
    """

    # implements(ISerialized)

    def __init__(self, unSerialized):
        self.data = None
        self.obj = unSerialized.getObject()
        # 'numpy' is only present in globals() if the optional import at
        # the top of this module succeeded.
        if globals().has_key('numpy') and isinstance(self.obj, numpy.ndarray):
            if len(self.obj) == 0: # length 0 arrays can't be reconstructed
                raise SerializationError("You cannot send a length 0 array")
            # the receiving side rebuilds the array with frombuffer, which
            # requires contiguous memory
            self.obj = numpy.ascontiguousarray(self.obj, dtype=None)
            self.typeDescriptor = 'ndarray'
            # shape and dtype are the minimum needed to rebuild the array
            self.metadata = {'shape':self.obj.shape,
                    'dtype':self.obj.dtype.str}
        elif isinstance(self.obj, str):
            self.typeDescriptor = 'bytes'
            self.metadata = {}
        elif isinstance(self.obj, buffer):
            self.typeDescriptor = 'buffer'
            self.metadata = {}
        else:
            self.typeDescriptor = 'pickle'
            self.metadata = {}
        self._generateData()

    def _generateData(self):
        """Fill self.data according to typeDescriptor, then drop the live
        object so only the serialized form is retained."""
        if self.typeDescriptor == 'ndarray':
            self.data = numpy.getbuffer(self.obj)
        elif self.typeDescriptor in ('bytes', 'buffer'):
            self.data = self.obj
        elif self.typeDescriptor == 'pickle':
            # protocol -1: always use the highest pickle protocol available
            self.data = pickle.dumps(self.obj, -1)
        else:
            raise SerializationError("Really wierd serialization error.")
        del self.obj

    def getData(self):
        """Return the serialized data."""
        return self.data

    def getDataSize(self, units=10.0**6):
        """Return the data size in *units* (default: megabytes)."""
        return 1.0*len(self.data)/units

    def getTypeDescriptor(self):
        """Return the serialization format name."""
        return self.typeDescriptor

    def getMetadata(self):
        """Return reconstruction metadata (shape/dtype for ndarray, else {})."""
        return self.metadata
138 138
139 139
class UnSerializeIt(UnSerialized):
    """IUnSerialized implementation that reconstructs the live object from
    an ISerialized on demand (in getObject)."""

    # implements(IUnSerialized)

    def __init__(self, serialized):
        self.serialized = serialized

    def getObject(self):
        """Rebuild and return the original object from the serialized form,
        dispatching on the type descriptor chosen by SerializeIt."""
        typeDescriptor = self.serialized.getTypeDescriptor()
        # 'numpy' is only in globals() if the optional import succeeded
        if globals().has_key('numpy') and typeDescriptor == 'ndarray':
            result = numpy.frombuffer(self.serialized.getData(), dtype = self.serialized.metadata['dtype'])
            result.shape = self.serialized.metadata['shape']
            # This is a hack to make the array writable. We are working with
            # the numpy folks to address this issue.
            result = result.copy()
        elif typeDescriptor == 'pickle':
            result = pickle.loads(self.serialized.getData())
        elif typeDescriptor in ('bytes', 'buffer'):
            # raw data passes through unchanged
            result = self.serialized.getData()
        else:
            raise SerializationError("Really wierd serialization error.")
        return result
162 162
def serialize(obj):
    """Serialize a live object, returning an ISerialized (SerializeIt)."""
    return SerializeIt(UnSerialized(obj))
165 165
def unserialize(serialized):
    """Reconstruct and return the live object from an ISerialized."""
    return UnSerializeIt(serialized).getObject()
@@ -1,447 +1,447 b''
1 1 #!/usr/bin/env python
2 2 """edited session.py to work with streams, and move msg_type to the header
3 3 """
4 4
5 5
6 6 import os
7 7 import sys
8 8 import traceback
9 9 import pprint
10 10 import uuid
11 11
12 12 import zmq
13 13 from zmq.utils import jsonapi
14 14 from zmq.eventloop.zmqstream import ZMQStream
15 15
16 16 from IPython.zmq.pickleutil import can, uncan, canSequence, uncanSequence
17 17 from IPython.zmq.newserialized import serialize, unserialize
18 18
19 19 try:
20 20 import cPickle
21 21 pickle = cPickle
22 22 except:
23 23 cPickle = None
24 24 import pickle
25 25
26 26 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
27 27 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
28 28 if json_name in ('jsonlib', 'jsonlib2'):
29 29 use_json = True
30 30 elif json_name:
31 31 if cPickle is None:
32 32 use_json = True
33 33 else:
34 34 use_json = False
35 35 else:
36 36 use_json = False
37 37
38 38 if use_json:
39 39 default_packer = jsonapi.dumps
40 40 default_unpacker = jsonapi.loads
41 41 else:
42 42 default_packer = lambda o: pickle.dumps(o,-1)
43 43 default_unpacker = pickle.loads
44 44
45 45
46 46 DELIM="<IDS|MSG>"
47 47
def wrap_exception():
    """Package the exception currently being handled into a dict suitable
    for use as error-message content (see unwrap_exception).

    Must be called from inside an ``except`` block, since it reads
    ``sys.exc_info()``.
    """
    etype, evalue, tb = sys.exc_info()
    # format_exception yields a list of traceback strings, sent as-is
    tb = traceback.format_exception(etype, evalue, tb)
    exc_content = {
        u'status' : u'error',
        u'traceback' : tb,
        u'etype' : unicode(etype),
        u'evalue' : unicode(evalue)
    }
    return exc_content
58 58
class KernelError(Exception):
    """Exception reconstructed from a remote error-content dict
    (see unwrap_exception)."""
    pass
61 61
def unwrap_exception(content):
    """Rebuild a KernelError from a dict produced by wrap_exception.

    The etype/evalue strings become both constructor args and attributes;
    the traceback list is joined into a single string attribute.
    """
    etype = content['etype']
    evalue = content['evalue']
    err = KernelError(etype, evalue)
    err.etype = etype
    err.evalue = evalue
    err.traceback = ''.join(content['traceback'])
    return err
68 68
69 69
class Message(object):
    """A simple message object that maps dict keys to attributes.

    A Message can be created from a dict and a dict from a Message instance
    simply by calling dict(msg_obj)."""

    def __init__(self, msg_dict):
        dct = self.__dict__
        for k, v in dict(msg_dict).iteritems():
            # nested dicts become nested Messages, so attribute access
            # works recursively (e.g. msg.header.msg_id)
            if isinstance(v, dict):
                v = Message(v)
            dct[k] = v

    # Having this iterator lets dict(msg_obj) work out of the box.
    def __iter__(self):
        return iter(self.__dict__.iteritems())

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return pprint.pformat(self.__dict__)

    def __contains__(self, k):
        return k in self.__dict__

    def __getitem__(self, k):
        return self.__dict__[k]
98 98
99 99
def msg_header(msg_id, msg_type, username, session):
    """Build a message-header dict from the four standard identity fields."""
    # Explicit dict instead of the previous ``return locals()`` hack;
    # the result is identical but no longer breaks if a local is added.
    return {
        'msg_id' : msg_id,
        'msg_type' : msg_type,
        'username' : username,
        'session' : session,
    }
108 108
109 109
def extract_header(msg_or_header):
    """Given a message or header, return the header as a plain dict.

    Accepts a full message (with a 'header' key), a bare header (with a
    'msg_id' key), or a falsy value (returns {}).  Raises KeyError for
    anything else.
    """
    if not msg_or_header:
        return {}
    if 'header' in msg_or_header:
        # full message: pull out its header
        h = msg_or_header['header']
    elif 'msg_id' in msg_or_header:
        # already just a header
        h = msg_or_header
    else:
        # neither shape matched, same error the original lookup raised
        raise KeyError('msg_id')
    if not isinstance(h, dict):
        h = dict(h)
    return h
128 128
def rekey(dikt):
    """Rekey *dikt* in place, converting numeric-looking str keys (forced
    to strings by json) back to int/float.  Non-numeric str keys and
    non-str keys are left alone.  Returns the same dict.

    Raises KeyError if a converted key collides with an existing key.
    This belongs in the jsonutil added by fperez.
    """
    # BUGFIX: iterate over a snapshot of the keys.  The loop mutates the
    # dict (pop + re-insert), which is not safe while iterating the live
    # key view/iterator.
    for k in list(dikt.keys()):
        if isinstance(k, str):
            # try int first, fall back to float, else leave the key as-is
            try:
                nk = int(k)
            except ValueError:
                try:
                    nk = float(k)
                except ValueError:
                    continue
            if nk in dikt:
                raise KeyError("already have key %r"%nk)
            dikt[nk] = dikt.pop(k)
    return dikt
150 150
def serialize_object(obj, threshold=64e-6):
    """serialize an object into a list of sendable buffers.

    Parameters
    ----------
    obj : object
        Lists/tuples and dicts are serialized element-wise; anything else
        is serialized whole.
    threshold : float
        Size in the default units of getDataSize (MB) above which an
        item's data is shipped as a separate buffer instead of being
        embedded in the pickle.  64e-6 corresponds to 64 bytes.

    Returns: (pmd, bufs)
    where pmd is the pickled metadata wrapper, and bufs
    is a list of data buffers"""
    databuffers = []

    def _extract(s):
        # BUGFIX: buffer/ndarray data can never be pickled, so it must
        # always travel out-of-band regardless of size.  This check was
        # previously applied only to the list/tuple branch, leaving small
        # buffers inside dicts (or bare) to crash in pickle.dumps.
        if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
            databuffers.append(s.getData())
            s.data = None
        return s

    if isinstance(obj, (list, tuple)):
        slist = [_extract(s) for s in map(serialize, canSequence(obj))]
        return pickle.dumps(slist,-1), databuffers
    elif isinstance(obj, dict):
        sobj = {}
        # sorted keys so the receiver consumes buffers in the same order
        for k in sorted(obj):
            sobj[k] = _extract(serialize(can(obj[k])))
        return pickle.dumps(sobj,-1), databuffers
    else:
        s = _extract(serialize(can(obj)))
        return pickle.dumps(s,-1), databuffers
182 182
183 183
def unserialize_object(bufs):
    """reconstruct an object serialized by serialize_object from data buffers"""
    bufs = list(bufs)
    # first buffer is always the pickled metadata wrapper: Serialized
    # objects whose .data may have been extracted into the later buffers
    sobj = pickle.loads(bufs.pop(0))
    if isinstance(sobj, (list, tuple)):
        # reattach extracted data in order before unserializing each item
        for s in sobj:
            if s.data is None:
                s.data = bufs.pop(0)
        return uncanSequence(map(unserialize, sobj))
    elif isinstance(sobj, dict):
        newobj = {}
        # sorted() mirrors the ordering used by serialize_object, so the
        # buffers are consumed in the same sequence they were appended
        for k in sorted(sobj.iterkeys()):
            s = sobj[k]
            if s.data is None:
                s.data = bufs.pop(0)
            newobj[k] = uncan(unserialize(s))
        return newobj
    else:
        if sobj.data is None:
            sobj.data = bufs.pop(0)
        return uncan(unserialize(sobj))
205 205
def pack_apply_message(f, args, kwargs, threshold=64e-6):
    """pack up a function, args, and kwargs to be sent over the wire
    as a series of buffers. Any object whose data is larger than `threshold`
    will not have their data copied (currently only numpy arrays support zero-copy)"""
    # frame 0: the canned function; frames 1-2: serialized args/kwargs;
    # remaining frames: the extracted data buffers, args' first
    msg = [pickle.dumps(can(f),-1)]
    databuffers = []
    for part in (args, kwargs):
        packed, bufs = serialize_object(part, threshold)
        msg.append(packed)
        databuffers.extend(bufs)
    msg.extend(databuffers)
    return msg
220 220
def unpack_apply_message(bufs, g=None, copy=True):
    """unpack f,args,kwargs from buffers packed by pack_apply_message()
    Returns: original f,args,kwargs

    With copy=False the entries of bufs are zmq Message frames rather than
    bytes; their .bytes/.buffer attributes are used for zero-copy access.
    """
    bufs = list(bufs) # allow us to pop
    assert len(bufs) >= 3, "not enough buffers!"
    if not copy:
        # the first three frames are always unpickled, so take their bytes
        for i in range(3):
            bufs[i] = bufs[i].bytes
    cf = pickle.loads(bufs.pop(0))
    sargs = list(pickle.loads(bufs.pop(0)))
    skwargs = dict(pickle.loads(bufs.pop(0)))
    # print sargs, skwargs
    f = cf.getFunction(g)
    for sa in sargs:
        if sa.data is None:
            # data was shipped out-of-band; reattach it
            m = bufs.pop(0)
            if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
                if copy:
                    sa.data = buffer(m)
                else:
                    # zero-copy: hand the zmq frame's buffer straight in
                    sa.data = m.buffer
            else:
                if copy:
                    sa.data = m
                else:
                    sa.data = m.bytes

    args = uncanSequence(map(unserialize, sargs), g)
    kwargs = {}
    # sorted() mirrors the order the kwargs buffers were appended in
    for k in sorted(skwargs.iterkeys()):
        sa = skwargs[k]
        if sa.data is None:
            sa.data = bufs.pop(0)
        kwargs[k] = uncan(unserialize(sa), g)

    return f,args,kwargs
257 257
class StreamSession(object):
    """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
    # class-level default; set True to pprint every sent/received message
    debug=False
    def __init__(self, username=None, session=None, packer=None, unpacker=None):
        """Create a session.

        Parameters
        ----------
        username : str, optional
            Defaults to the USER environment variable (or 'username').
        session : str, optional
            Session id; a fresh uuid4 is generated when omitted.
        packer, unpacker : callables, optional
            Serialize/deserialize message dicts; default to the
            module-level json-or-pickle pair.

        Raises
        ------
        TypeError if packer/unpacker is given but not callable.
        """
        if username is None:
            username = os.environ.get('USER','username')
        self.username = username
        if session is None:
            self.session = str(uuid.uuid4())
        else:
            self.session = session
        self.msg_id = str(uuid.uuid4())
        if packer is None:
            self.pack = default_packer
        else:
            if not callable(packer):
                raise TypeError("packer must be callable, not %s"%type(packer))
            self.pack = packer

        if unpacker is None:
            self.unpack = default_unpacker
        else:
            if not callable(unpacker):
                raise TypeError("unpacker must be callable, not %s"%type(unpacker))
            self.unpack = unpacker

        # pre-packed empty dict, reused when a message has no content
        self.none = self.pack({})

    def msg_header(self, msg_type):
        """Return a header for msg_type, consuming the current msg_id and
        generating a fresh one for the next message."""
        h = msg_header(self.msg_id, msg_type, self.username, self.session)
        self.msg_id = str(uuid.uuid4())
        return h

    def msg(self, msg_type, content=None, parent=None, subheader=None):
        """Build a complete message dict: header, parent_header, content,
        with optional subheader entries merged into the header."""
        msg = {}
        msg['header'] = self.msg_header(msg_type)
        msg['msg_id'] = msg['header']['msg_id']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['msg_type'] = msg_type
        msg['content'] = {} if content is None else content
        sub = {} if subheader is None else subheader
        msg['header'].update(sub)
        return msg

    def send(self, stream, msg_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
        """send a message via stream

        Wire format: [idents...], DELIM, header, parent_header, content,
        then any extra data buffers as additional zmq frames.
        Returns the sent message wrapped in a Message object.
        """
        msg = self.msg(msg_type, content, parent, subheader)
        buffers = [] if buffers is None else buffers
        to_send = []
        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)
        to_send.append(self.pack(msg['header']))
        to_send.append(self.pack(msg['parent_header']))
        # if parent is None:
        #     to_send.append(self.none)
        # else:
        #     to_send.append(self.pack(dict(parent)))
        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, str):
            # content is already packed, as in a relayed message
            pass
        else:
            raise TypeError("Content incorrect type: %s"%type(content))
        to_send.append(content)
        # SNDMORE keeps the message open while trailing buffers follow
        flag = 0
        if buffers:
            flag = zmq.SNDMORE
        stream.send_multipart(to_send, flag, copy=False)
        for b in buffers[:-1]:
            stream.send(b, flag, copy=False)
        if buffers:
            # last buffer closes the multipart message (no SNDMORE)
            stream.send(buffers[-1], copy=False)
        omsg = Message(msg)
        if self.debug:
            pprint.pprint(omsg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)
        return omsg

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """receives and unpacks a message
        returns [idents], msg

        Returns None when mode is NOBLOCK and no message is waiting.
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg = socket.recv_multipart(mode)
        except zmq.ZMQError, e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_json won't return None.
                return None
            else:
                raise
        # return an actual Message object
        # determine the number of idents by trying to unpack them.
        # this is terrible:
        idents, msg = self.feed_identities(msg, copy)
        try:
            return idents, self.unpack_message(msg, content=content, copy=copy)
        except Exception, e:
            print idents, msg
            # TODO: handle it
            raise e

    def feed_identities(self, msg, copy=True):
        """This is a completely horrible thing, but it strips the zmq
        ident prefixes off of a message. It will break if any identities
        are unpackable by self.unpack."""
        msg = list(msg)
        idents = []
        # everything before the DELIM frame is a routing identity;
        # at least 3 frames (header/parent/content) must remain
        while len(msg) > 3:
            if copy:
                s = msg[0]
            else:
                # copy=False frames are zmq Messages; compare their bytes
                s = msg[0].bytes
            if s == DELIM:
                msg.pop(0)
                break
            else:
                idents.append(s)
                msg.pop(0)

        return idents, msg

    def unpack_message(self, msg, content=True, copy=True):
        """return a message object from the format
        sent by self.send.

        parameters:

        content : bool (True)
            whether to unpack the content dict (True),
            or leave it serialized (False)

        copy : bool (True)
            whether to return the bytes (True),
            or the non-copying Message object in each place (False)

        """
        if not len(msg) >= 3:
            raise TypeError("malformed message, must have at least 3 elements")
        message = {}
        if not copy:
            # frames 0-2 always get unpacked, so take their bytes up front
            for i in range(3):
                msg[i] = msg[i].bytes
        message['header'] = self.unpack(msg[0])
        message['msg_type'] = message['header']['msg_type']
        message['parent_header'] = self.unpack(msg[1])
        if content:
            message['content'] = self.unpack(msg[2])
        else:
            message['content'] = msg[2]

        #     message['buffers'] = msg[3:]
        # else:
        #     message['header'] = self.unpack(msg[0].bytes)
        #     message['msg_type'] = message['header']['msg_type']
        #     message['parent_header'] = self.unpack(msg[1].bytes)
        #     if content:
        #         message['content'] = self.unpack(msg[2].bytes)
        #     else:
        #         message['content'] = msg[2].bytes

        message['buffers'] = msg[3:]# [ m.buffer for m in msg[3:] ]
        return message
430 430
431 431
432 432
def test_msg2obj():
    """Round-trip checks: dict -> Message attribute access -> dict."""
    src = dict(x=1)
    m = Message(src)
    assert m.x == src['x']

    # nested dicts become nested Messages with attribute access
    src['y'] = dict(z=1)
    m = Message(src)
    assert m.y.z == src['y']['z']

    k1, k2 = 'y', 'z'
    assert m[k1][k2] == src[k1][k2]

    # dict(Message) converts back via __iter__
    back = dict(m)
    assert src['x'] == back['x']
    assert src['y']['z'] == back['y']['z']
General Comments 0
You need to be logged in to leave comments. Login now