##// END OF EJS Templates
better serialization for parallel code...
MinRK -
Show More
@@ -0,0 +1,115 b''
1 """test serialization tools"""
2
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
5 #
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
9
10 #-------------------------------------------------------------------------------
11 # Imports
12 #-------------------------------------------------------------------------------
13
14 import pickle
15
16 import nose.tools as nt
17
18 # from unittest import TestCaes
19 from IPython.zmq.serialize import serialize_object, unserialize_object
20 from IPython.testing import decorators as dec
21 from IPython.utils.pickleutil import CannedArray
22
23 def roundtrip(obj):
24 """roundtrip an object through serialization"""
25 bufs = serialize_object(obj)
26 obj2, remainder = unserialize_object(bufs)
27 nt.assert_equals(remainder, [])
28 return obj2
29
30 class C(object):
31 """dummy class for """
32
33 def __init__(self, **kwargs):
34 for key,value in kwargs.iteritems():
35 setattr(self, key, value)
36
37 @dec.parametric
38 def test_roundtrip_simple():
39 for obj in [
40 'hello',
41 dict(a='b', b=10),
42 [1,2,'hi'],
43 (b'123', 'hello'),
44 ]:
45 obj2 = roundtrip(obj)
46 yield nt.assert_equals(obj, obj2)
47
48 @dec.parametric
49 def test_roundtrip_nested():
50 for obj in [
51 dict(a=range(5), b={1:b'hello'}),
52 [range(5),[range(3),(1,[b'whoda'])]],
53 ]:
54 obj2 = roundtrip(obj)
55 yield nt.assert_equals(obj, obj2)
56
57 @dec.parametric
58 def test_roundtrip_buffered():
59 for obj in [
60 dict(a=b"x"*1025),
61 b"hello"*500,
62 [b"hello"*501, 1,2,3]
63 ]:
64 bufs = serialize_object(obj)
65 yield nt.assert_equals(len(bufs), 2)
66 obj2, remainder = unserialize_object(bufs)
67 yield nt.assert_equals(remainder, [])
68 yield nt.assert_equals(obj, obj2)
69
70 @dec.parametric
71 @dec.skip_without('numpy')
72 def test_numpy():
73 import numpy
74 from numpy.testing.utils import assert_array_equal
75 for shape in ((), (0,), (100,), (1024,10), (10,8,6,5)):
76 for dtype in ('uint8', 'float64', 'int32', [('int16', 'float32')]):
77 A = numpy.empty(shape, dtype=dtype)
78 bufs = serialize_object(A)
79 B, r = unserialize_object(bufs)
80 yield nt.assert_equals(r, [])
81 yield assert_array_equal(A,B)
82
83 @dec.parametric
84 @dec.skip_without('numpy')
85 def test_numpy_in_seq():
86 import numpy
87 from numpy.testing.utils import assert_array_equal
88 for shape in ((), (0,), (100,), (1024,10), (10,8,6,5)):
89 for dtype in ('uint8', 'float64', 'int32', [('int16', 'float32')]):
90 A = numpy.empty(shape, dtype=dtype)
91 bufs = serialize_object((A,1,2,b'hello'))
92 canned = pickle.loads(bufs[0])
93 yield nt.assert_true(canned[0], CannedArray)
94 tup, r = unserialize_object(bufs)
95 B = tup[0]
96 yield nt.assert_equals(r, [])
97 yield assert_array_equal(A,B)
98
99 @dec.parametric
100 @dec.skip_without('numpy')
101 def test_numpy_in_dict():
102 import numpy
103 from numpy.testing.utils import assert_array_equal
104 for shape in ((), (0,), (100,), (1024,10), (10,8,6,5)):
105 for dtype in ('uint8', 'float64', 'int32', [('int16', 'float32')]):
106 A = numpy.empty(shape, dtype=dtype)
107 bufs = serialize_object(dict(a=A,b=1,c=range(20)))
108 canned = pickle.loads(bufs[0])
109 yield nt.assert_true(canned['a'], CannedArray)
110 d, r = unserialize_object(bufs)
111 B = d['a']
112 yield nt.assert_equals(r, [])
113 yield assert_array_equal(A,B)
114
115
@@ -239,7 +239,7 b' class TestView(ClusterTestCase, ParametricTestCase):'
239 239 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
240 240 view = self.client[:]
241 241 a = numpy.arange(64)
242 view.scatter('a', a)
242 view.scatter('a', a, block=True)
243 243 b = view.gather('a', block=True)
244 244 assert_array_equal(b, a)
245 245
@@ -325,7 +325,7 b' class TestView(ClusterTestCase, ParametricTestCase):'
325 325 r = view.map_sync(lambda x:x, arr)
326 326 self.assertEqual(r, list(arr))
327 327
328 def test_scatterGatherNonblocking(self):
328 def test_scatter_gather_nonblocking(self):
329 329 data = range(16)
330 330 view = self.client[:]
331 331 view.scatter('a', data, block=False)
@@ -43,17 +43,11 b' from IPython.external.decorator import decorator'
43 43
44 44 # IPython imports
45 45 from IPython.config.application import Application
46 from IPython.utils import py3compat
47 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
48 from IPython.utils.newserialized import serialize, unserialize
49 46 from IPython.zmq.log import EnginePUBHandler
50 47 from IPython.zmq.serialize import (
51 48 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
52 49 )
53 50
54 if py3compat.PY3:
55 buffer = memoryview
56
57 51 #-----------------------------------------------------------------------------
58 52 # Classes
59 53 #-----------------------------------------------------------------------------
@@ -19,7 +19,22 b' import copy'
19 19 import sys
20 20 from types import FunctionType
21 21
22 try:
23 import cPickle as pickle
24 except ImportError:
25 import pickle
26
27 try:
28 import numpy
29 except:
30 numpy = None
31
22 32 import codeutil
33 import py3compat
34 from importstring import import_item
35
36 if py3compat.PY3:
37 buffer = memoryview
23 38
24 39 #-------------------------------------------------------------------------------
25 40 # Classes
@@ -33,27 +48,30 b' class CannedObject(object):'
33 48 for key in keys:
34 49 setattr(self.obj, key, can(getattr(obj, key)))
35 50
51 self.buffers = []
36 52
37 def getObject(self, g=None):
53 def get_object(self, g=None):
38 54 if g is None:
39 g = globals()
55 g = {}
40 56 for key in self.keys:
41 57 setattr(self.obj, key, uncan(getattr(self.obj, key), g))
42 58 return self.obj
43 59
60
44 61 class Reference(CannedObject):
45 62 """object for wrapping a remote reference by name."""
46 63 def __init__(self, name):
47 64 if not isinstance(name, basestring):
48 65 raise TypeError("illegal name: %r"%name)
49 66 self.name = name
67 self.buffers = []
50 68
51 69 def __repr__(self):
52 70 return "<Reference: %r>"%self.name
53 71
54 def getObject(self, g=None):
72 def get_object(self, g=None):
55 73 if g is None:
56 g = globals()
74 g = {}
57 75
58 76 return eval(self.name, g)
59 77
@@ -61,16 +79,17 b' class Reference(CannedObject):'
61 79 class CannedFunction(CannedObject):
62 80
63 81 def __init__(self, f):
64 self._checkType(f)
82 self._check_type(f)
65 83 self.code = f.func_code
66 84 self.defaults = f.func_defaults
67 85 self.module = f.__module__ or '__main__'
68 86 self.__name__ = f.__name__
87 self.buffers = []
69 88
70 def _checkType(self, obj):
89 def _check_type(self, obj):
71 90 assert isinstance(obj, FunctionType), "Not a function type"
72 91
73 def getObject(self, g=None):
92 def get_object(self, g=None):
74 93 # try to load function back into its module:
75 94 if not self.module.startswith('__'):
76 95 try:
@@ -81,30 +100,65 b' class CannedFunction(CannedObject):'
81 100 g = sys.modules[self.module].__dict__
82 101
83 102 if g is None:
84 g = globals()
103 g = {}
85 104 newFunc = FunctionType(self.code, g, self.__name__, self.defaults)
86 105 return newFunc
87 106
107
108 class CannedArray(CannedObject):
109 def __init__(self, obj):
110 self.shape = obj.shape
111 self.dtype = obj.dtype
112 if sum(obj.shape) == 0:
113 # just pickle it
114 self.buffers = [pickle.dumps(obj, -1)]
115 else:
116 # ensure contiguous
117 obj = numpy.ascontiguousarray(obj, dtype=None)
118 self.buffers = [buffer(obj)]
119
120 def get_object(self, g=None):
121 data = self.buffers[0]
122 if sum(self.shape) == 0:
123 # no shape, we just pickled it
124 return pickle.loads(data)
125 else:
126 return numpy.frombuffer(data, dtype=self.dtype).reshape(self.shape)
127
128
129 class CannedBytes(CannedObject):
130 wrap = bytes
131 def __init__(self, obj):
132 self.buffers = [obj]
133
134 def get_object(self, g=None):
135 data = self.buffers[0]
136 return self.wrap(data)
137
138 def CannedBuffer(CannedBytes):
139 wrap = buffer
140
88 141 #-------------------------------------------------------------------------------
89 142 # Functions
90 143 #-------------------------------------------------------------------------------
91 144
145
92 146 def can(obj):
93 # import here to prevent module-level circular imports
94 from IPython.parallel import dependent
95 if isinstance(obj, dependent):
96 keys = ('f','df')
97 return CannedObject(obj, keys=keys)
98 elif isinstance(obj, FunctionType):
99 return CannedFunction(obj)
100 elif isinstance(obj,dict):
101 return canDict(obj)
102 elif isinstance(obj, (list,tuple)):
103 return canSequence(obj)
104 else:
147 """prepare an object for pickling"""
148 for cls,canner in can_map.iteritems():
149 if isinstance(cls, basestring):
150 try:
151 cls = import_item(cls)
152 except Exception:
153 # not importable
154 print "not importable: %r" % cls
155 continue
156 if isinstance(obj, cls):
157 return canner(obj)
105 158 return obj
106 159
107 def canDict(obj):
160 def can_dict(obj):
161 """can the *values* of a dict"""
108 162 if isinstance(obj, dict):
109 163 newobj = {}
110 164 for k, v in obj.iteritems():
@@ -113,7 +167,8 b' def canDict(obj):'
113 167 else:
114 168 return obj
115 169
116 def canSequence(obj):
170 def can_sequence(obj):
171 """can the elements of a sequence"""
117 172 if isinstance(obj, (list, tuple)):
118 173 t = type(obj)
119 174 return t([can(i) for i in obj])
@@ -121,16 +176,20 b' def canSequence(obj):'
121 176 return obj
122 177
123 178 def uncan(obj, g=None):
124 if isinstance(obj, CannedObject):
125 return obj.getObject(g)
126 elif isinstance(obj,dict):
127 return uncanDict(obj, g)
128 elif isinstance(obj, (list,tuple)):
129 return uncanSequence(obj, g)
130 else:
179 """invert canning"""
180 for cls,uncanner in uncan_map.iteritems():
181 if isinstance(cls, basestring):
182 try:
183 cls = import_item(cls)
184 except Exception:
185 # not importable
186 print "not importable: %r" % cls
187 continue
188 if isinstance(obj, cls):
189 return uncanner(obj, g)
131 190 return obj
132 191
133 def uncanDict(obj, g=None):
192 def uncan_dict(obj, g=None):
134 193 if isinstance(obj, dict):
135 194 newobj = {}
136 195 for k, v in obj.iteritems():
@@ -139,7 +198,7 b' def uncanDict(obj, g=None):'
139 198 else:
140 199 return obj
141 200
142 def uncanSequence(obj, g=None):
201 def uncan_sequence(obj, g=None):
143 202 if isinstance(obj, (list, tuple)):
144 203 t = type(obj)
145 204 return t([uncan(i,g) for i in obj])
@@ -147,5 +206,27 b' def uncanSequence(obj, g=None):'
147 206 return obj
148 207
149 208
150 def rebindFunctionGlobals(f, glbls):
151 return FunctionType(f.func_code, glbls)
209 #-------------------------------------------------------------------------------
210 # API dictionary
211 #-------------------------------------------------------------------------------
212
213 # These dicts can be extended for custom serialization of new objects
214
215 can_map = {
216 'IPython.parallel.dependent' : lambda obj: CannedObject(obj, keys=('f','df')),
217 'numpy.ndarray' : CannedArray,
218 FunctionType : CannedFunction,
219 bytes : CannedBytes,
220 buffer : CannedBuffer,
221 # dict : can_dict,
222 # list : can_sequence,
223 # tuple : can_sequence,
224 }
225
226 uncan_map = {
227 CannedObject : lambda obj, g: obj.get_object(g),
228 # dict : uncan_dict,
229 # list : uncan_sequence,
230 # tuple : uncan_sequence,
231 }
232
@@ -572,8 +572,8 b' class Kernel(Configurable):'
572 572 for key in ns.iterkeys():
573 573 working.pop(key)
574 574
575 packed_result,buf = serialize_object(result)
576 result_buf = [packed_result]+buf
575 result_buf = serialize_object(result)
576
577 577 except:
578 578 # invoke IPython traceback formatting
579 579 shell.showtraceback()
@@ -32,7 +32,9 b' except:'
32 32
33 33 # IPython imports
34 34 from IPython.utils import py3compat
35 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
35 from IPython.utils.pickleutil import (
36 can, uncan, can_sequence, uncan_sequence, CannedObject
37 )
36 38 from IPython.utils.newserialized import serialize, unserialize
37 39
38 40 if py3compat.PY3:
@@ -42,7 +44,32 b' if py3compat.PY3:'
42 44 # Serialization Functions
43 45 #-----------------------------------------------------------------------------
44 46
45 def serialize_object(obj, threshold=64e-6):
47 # maximum items to iterate through in a container
48 MAX_ITEMS = 64
49
50 def _extract_buffers(obj, threshold=1024):
51 """extract buffers larger than a certain threshold"""
52 buffers = []
53 if isinstance(obj, CannedObject) and obj.buffers:
54 for i,buf in enumerate(obj.buffers):
55 if len(buf) > threshold:
56 # buffer larger than threshold, prevent pickling
57 obj.buffers[i] = None
58 buffers.append(buf)
59 elif isinstance(buf, buffer):
60 # buffer too small for separate send, coerce to bytes
61 # because pickling buffer objects just results in broken pointers
62 obj.buffers[i] = bytes(buf)
63 return buffers
64
65 def _restore_buffers(obj, buffers):
66 """restore buffers extracted by """
67 if isinstance(obj, CannedObject) and obj.buffers:
68 for i,buf in enumerate(obj.buffers):
69 if buf is None:
70 obj.buffers[i] = buffers.pop(0)
71
72 def serialize_object(obj, threshold=1024):
46 73 """Serialize an object into a list of sendable buffers.
47 74
48 75 Parameters
@@ -50,76 +77,78 b' def serialize_object(obj, threshold=64e-6):'
50 77
51 78 obj : object
52 79 The object to be serialized
53 threshold : float
54 The threshold for not double-pickling the content.
55
80 threshold : int
81 The threshold (in bytes) for pulling out data buffers
82 to avoid pickling them.
56 83
57 84 Returns
58 85 -------
59 ('pmd', [bufs]) :
60 where pmd is the pickled metadata wrapper,
61 bufs is a list of data buffers
86 [bufs] : list of buffers representing the serialized object.
62 87 """
63 databuffers = []
64 if isinstance(obj, (list, tuple)):
65 clist = canSequence(obj)
66 slist = map(serialize, clist)
67 for s in slist:
68 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
69 databuffers.append(s.getData())
70 s.data = None
71 return pickle.dumps(slist,-1), databuffers
72 elif isinstance(obj, dict):
73 sobj = {}
88 buffers = []
89 if isinstance(obj, (list, tuple)) and len(obj) < MAX_ITEMS:
90 cobj = can_sequence(obj)
91 for c in cobj:
92 buffers.extend(_extract_buffers(c, threshold))
93 elif isinstance(obj, dict) and len(obj) < MAX_ITEMS:
94 cobj = {}
74 95 for k in sorted(obj.iterkeys()):
75 s = serialize(can(obj[k]))
76 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
77 databuffers.append(s.getData())
78 s.data = None
79 sobj[k] = s
80 return pickle.dumps(sobj,-1),databuffers
96 c = can(obj[k])
97 buffers.extend(_extract_buffers(c, threshold))
98 cobj[k] = c
81 99 else:
82 s = serialize(can(obj))
83 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
84 databuffers.append(s.getData())
85 s.data = None
86 return pickle.dumps(s,-1),databuffers
87
88
89 def unserialize_object(bufs):
90 """reconstruct an object serialized by serialize_object from data buffers."""
91 bufs = list(bufs)
92 sobj = pickle.loads(bufs.pop(0))
93 if isinstance(sobj, (list, tuple)):
94 for s in sobj:
95 if s.data is None:
96 s.data = bufs.pop(0)
97 return uncanSequence(map(unserialize, sobj)), bufs
98 elif isinstance(sobj, dict):
100 cobj = can(obj)
101 buffers.extend(_extract_buffers(cobj, threshold))
102
103 buffers.insert(0, pickle.dumps(cobj,-1))
104 return buffers
105
106 def unserialize_object(buffers, g=None):
107 """reconstruct an object serialized by serialize_object from data buffers.
108
109 Parameters
110 ----------
111
112 bufs : list of buffers/bytes
113
114 g : globals to be used when uncanning
115
116 Returns
117 -------
118
119 (newobj, bufs) : unpacked object, and the list of remaining unused buffers.
120 """
121 bufs = list(buffers)
122 canned = pickle.loads(bufs.pop(0))
123 if isinstance(canned, (list, tuple)) and len(canned) < MAX_ITEMS:
124 for c in canned:
125 _restore_buffers(c, bufs)
126 newobj = uncan_sequence(canned, g)
127 elif isinstance(canned, dict) and len(canned) < MAX_ITEMS:
99 128 newobj = {}
100 for k in sorted(sobj.iterkeys()):
101 s = sobj[k]
102 if s.data is None:
103 s.data = bufs.pop(0)
104 newobj[k] = uncan(unserialize(s))
105 return newobj, bufs
129 for k in sorted(canned.iterkeys()):
130 c = canned[k]
131 _restore_buffers(c, bufs)
132 newobj[k] = uncan(c, g)
106 133 else:
107 if sobj.data is None:
108 sobj.data = bufs.pop(0)
109 return uncan(unserialize(sobj)), bufs
134 _restore_buffers(canned, bufs)
135 newobj = uncan(canned, g)
110 136
111 def pack_apply_message(f, args, kwargs, threshold=64e-6):
137 return newobj, bufs
138
139 def pack_apply_message(f, args, kwargs, threshold=1024):
112 140 """pack up a function, args, and kwargs to be sent over the wire
113 141 as a series of buffers. Any object whose data is larger than `threshold`
114 will not have their data copied (currently only numpy arrays support zero-copy)"""
142 will not have their data copied (currently only numpy arrays support zero-copy)
143 """
115 144 msg = [pickle.dumps(can(f),-1)]
116 145 databuffers = [] # for large objects
117 sargs, bufs = serialize_object(args,threshold)
118 msg.append(sargs)
119 databuffers.extend(bufs)
120 skwargs, bufs = serialize_object(kwargs,threshold)
121 msg.append(skwargs)
122 databuffers.extend(bufs)
146 sargs = serialize_object(args,threshold)
147 msg.append(sargs[0])
148 databuffers.extend(sargs[1:])
149 skwargs = serialize_object(kwargs,threshold)
150 msg.append(skwargs[0])
151 databuffers.extend(skwargs[1:])
123 152 msg.extend(databuffers)
124 153 return msg
125 154
@@ -131,49 +160,16 b' def unpack_apply_message(bufs, g=None, copy=True):'
131 160 if not copy:
132 161 for i in range(3):
133 162 bufs[i] = bufs[i].bytes
134 cf = pickle.loads(bufs.pop(0))
135 sargs = list(pickle.loads(bufs.pop(0)))
136 skwargs = dict(pickle.loads(bufs.pop(0)))
137 # print sargs, skwargs
138 f = uncan(cf, g)
139 for sa in sargs:
140 if sa.data is None:
141 m = bufs.pop(0)
142 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
143 # always use a buffer, until memoryviews get sorted out
144 sa.data = buffer(m)
145 # disable memoryview support
146 # if copy:
147 # sa.data = buffer(m)
148 # else:
149 # sa.data = m.buffer
150 else:
151 if copy:
152 sa.data = m
153 else:
154 sa.data = m.bytes
155
156 args = uncanSequence(map(unserialize, sargs), g)
157 kwargs = {}
158 for k in sorted(skwargs.iterkeys()):
159 sa = skwargs[k]
160 if sa.data is None:
161 m = bufs.pop(0)
162 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
163 # always use a buffer, until memoryviews get sorted out
164 sa.data = buffer(m)
165 # disable memoryview support
166 # if copy:
167 # sa.data = buffer(m)
168 # else:
169 # sa.data = m.buffer
170 else:
171 if copy:
172 sa.data = m
173 else:
174 sa.data = m.bytes
175
176 kwargs[k] = uncan(unserialize(sa), g)
163 f = uncan(pickle.loads(bufs.pop(0)), g)
164 # sargs = bufs.pop(0)
165 # pop kwargs out, so first n-elements are args, serialized
166 skwargs = bufs.pop(1)
167 args, bufs = unserialize_object(bufs, g)
168 # put skwargs back in as the first element
169 bufs.insert(0, skwargs)
170 kwargs, bufs = unserialize_object(bufs, g)
171
172 assert not bufs, "Shouldn't be any data left over"
177 173
178 174 return f,args,kwargs
179 175
General Comments 0
You need to be logged in to leave comments. Login now