##// END OF EJS Templates
Merge pull request #2069 from minrk/betterserial...
Fernando Perez -
r8040:8796581d merge
parent child Browse files
Show More
@@ -0,0 +1,168 b''
1 """test serialization tools"""
2
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
5 #
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
9
10 #-------------------------------------------------------------------------------
11 # Imports
12 #-------------------------------------------------------------------------------
13
14 import pickle
15
16 import nose.tools as nt
17
18 # from unittest import TestCase
19 from IPython.zmq.serialize import serialize_object, unserialize_object
20 from IPython.testing import decorators as dec
21 from IPython.utils.pickleutil import CannedArray
22
23 #-------------------------------------------------------------------------------
24 # Globals and Utilities
25 #-------------------------------------------------------------------------------
26
def roundtrip(obj):
    """Serialize *obj* and immediately deserialize it, returning the result.

    Also asserts that deserialization consumed every buffer produced by
    serialization (no leftovers).
    """
    serialized = serialize_object(obj)
    restored, leftover = unserialize_object(serialized)
    nt.assert_equals(leftover, [])
    return restored
33
class C(object):
    """Dummy class whose attributes are set from keyword arguments.

    Serves as a generic stateful object for serialization tests,
    e.g. ``C(a=1, b='x')`` yields an instance with ``.a == 1`` and
    ``.b == 'x'``.  (The original docstring was truncated mid-sentence.)
    """

    def __init__(self, **kwargs):
        # use items() rather than py2-only iteritems() — equivalent here
        # and works on both Python 2 and 3
        for key, value in kwargs.items():
            setattr(self, key, value)
40
# Array shapes exercised by the tests below; deliberately includes the
# degenerate cases () (0-d scalar shape) and (0,) (zero-length axis).
SHAPES = ((100,), (1024,10), (10,8,6,5), (), (0,))
# Dtypes exercised: plain scalar dtypes, one structured (record) dtype,
# and a fixed-width byte-string dtype.
DTYPES = ('uint8', 'float64', 'int32', [('g', 'float32')], '|S10')
43 #-------------------------------------------------------------------------------
44 # Tests
45 #-------------------------------------------------------------------------------
46
@dec.parametric
def test_roundtrip_simple():
    """simple builtin objects should roundtrip unchanged"""
    cases = [
        'hello',
        dict(a='b', b=10),
        [1, 2, 'hi'],
        (b'123', 'hello'),
    ]
    for case in cases:
        restored = roundtrip(case)
        yield nt.assert_equals(case, restored)
57
@dec.parametric
def test_roundtrip_nested():
    """nested containers should roundtrip unchanged"""
    cases = [
        dict(a=range(5), b={1: b'hello'}),
        [range(5), [range(3), (1, [b'whoda'])]],
    ]
    for case in cases:
        restored = roundtrip(case)
        yield nt.assert_equals(case, restored)
66
@dec.parametric
def test_roundtrip_buffered():
    """objects with large byte content should yield one extracted data buffer"""
    cases = [
        dict(a=b"x" * 1025),
        b"hello" * 500,
        [b"hello" * 501, 1, 2, 3],
    ]
    for case in cases:
        bufs = serialize_object(case)
        # pickled metadata plus exactly one extracted data buffer
        yield nt.assert_equals(len(bufs), 2)
        restored, leftover = unserialize_object(bufs)
        yield nt.assert_equals(leftover, [])
        yield nt.assert_equals(case, restored)
79
def _scrub_nan(A):
    """Zero out nans in uninitialized (numpy.empty) record arrays.

    Needed because nan != nan, which would break equality checks on
    arrays whose garbage contents happen to contain nans.  Only acts on
    structured dtypes with a non-empty shape.
    """
    import numpy
    if not (A.dtype.fields and A.shape):
        return
    for name in A.dtype.fields:
        try:
            column = A[name]
            column[numpy.isnan(column)] = 0
        except (TypeError, NotImplementedError):
            # non-numeric fields (e.g. str dtype) don't support isnan
            pass
93
@dec.parametric
@dec.skip_without('numpy')
def test_numpy():
    """bare ndarrays roundtrip with shape, dtype, and contents intact"""
    import numpy
    from numpy.testing.utils import assert_array_equal
    for shape in SHAPES:
        for dtype in DTYPES:
            A = numpy.empty(shape, dtype=dtype)
            _scrub_nan(A)
            B, leftover = unserialize_object(serialize_object(A))
            yield nt.assert_equals(leftover, [])
            yield nt.assert_equals(A.shape, B.shape)
            yield nt.assert_equals(A.dtype, B.dtype)
            yield assert_array_equal(A, B)
109
@dec.parametric
@dec.skip_without('numpy')
def test_recarray():
    """record (structured-dtype) arrays roundtrip intact"""
    import numpy
    from numpy.testing.utils import assert_array_equal
    record_dtypes = [
        [('f', float), ('s', '|S10')],
        [('n', int), ('s', '|S1'), ('u', 'uint32')],
    ]
    for shape in SHAPES:
        for dtype in record_dtypes:
            A = numpy.empty(shape, dtype=dtype)
            _scrub_nan(A)
            B, leftover = unserialize_object(serialize_object(A))
            yield nt.assert_equals(leftover, [])
            yield nt.assert_equals(A.shape, B.shape)
            yield nt.assert_equals(A.dtype, B.dtype)
            yield assert_array_equal(A, B)
129
@dec.parametric
@dec.skip_without('numpy')
def test_numpy_in_seq():
    """an ndarray inside a tuple is canned as a CannedArray and roundtrips"""
    import numpy
    from numpy.testing.utils import assert_array_equal
    for shape in SHAPES:
        for dtype in DTYPES:
            A = numpy.empty(shape, dtype=dtype)
            _scrub_nan(A)
            bufs = serialize_object((A,1,2,b'hello'))
            canned = pickle.loads(bufs[0])
            # BUG FIX: the original called nt.assert_true(canned[0], CannedArray),
            # which passed CannedArray as assert_true's *msg* argument and never
            # checked the type at all (any truthy value passed).
            yield nt.assert_true(isinstance(canned[0], CannedArray))
            tup, r = unserialize_object(bufs)
            B = tup[0]
            yield nt.assert_equals(r, [])
            yield nt.assert_equals(A.shape, B.shape)
            yield nt.assert_equals(A.dtype, B.dtype)
            yield assert_array_equal(A,B)
148
@dec.parametric
@dec.skip_without('numpy')
def test_numpy_in_dict():
    """an ndarray inside a dict is canned as a CannedArray and roundtrips"""
    import numpy
    from numpy.testing.utils import assert_array_equal
    for shape in SHAPES:
        for dtype in DTYPES:
            A = numpy.empty(shape, dtype=dtype)
            _scrub_nan(A)
            bufs = serialize_object(dict(a=A,b=1,c=range(20)))
            canned = pickle.loads(bufs[0])
            # BUG FIX: the original called nt.assert_true(canned['a'], CannedArray),
            # which passed CannedArray as assert_true's *msg* argument and never
            # checked the type at all (any truthy value passed).
            yield nt.assert_true(isinstance(canned['a'], CannedArray))
            d, r = unserialize_object(bufs)
            B = d['a']
            yield nt.assert_equals(r, [])
            yield nt.assert_equals(A.shape, B.shape)
            yield nt.assert_equals(A.dtype, B.dtype)
            yield assert_array_equal(A,B)
167
168
@@ -1209,7 +1209,10 b' class Client(HasTraits):'
1209 if not isinstance(metadata, dict):
1209 if not isinstance(metadata, dict):
1210 raise TypeError("metadata must be dict, not %s"%type(metadata))
1210 raise TypeError("metadata must be dict, not %s"%type(metadata))
1211
1211
1212 bufs = util.pack_apply_message(f,args,kwargs)
1212 bufs = util.pack_apply_message(f, args, kwargs,
1213 buffer_threshold=self.session.buffer_threshold,
1214 item_threshold=self.session.item_threshold,
1215 )
1213
1216
1214 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1217 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1215 metadata=metadata, track=track)
1218 metadata=metadata, track=track)
@@ -239,7 +239,7 b' class TestView(ClusterTestCase, ParametricTestCase):'
239 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
239 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
240 view = self.client[:]
240 view = self.client[:]
241 a = numpy.arange(64)
241 a = numpy.arange(64)
242 view.scatter('a', a)
242 view.scatter('a', a, block=True)
243 b = view.gather('a', block=True)
243 b = view.gather('a', block=True)
244 assert_array_equal(b, a)
244 assert_array_equal(b, a)
245
245
@@ -325,7 +325,7 b' class TestView(ClusterTestCase, ParametricTestCase):'
325 r = view.map_sync(lambda x:x, arr)
325 r = view.map_sync(lambda x:x, arr)
326 self.assertEqual(r, list(arr))
326 self.assertEqual(r, list(arr))
327
327
328 def test_scatterGatherNonblocking(self):
328 def test_scatter_gather_nonblocking(self):
329 data = range(16)
329 data = range(16)
330 view = self.client[:]
330 view = self.client[:]
331 view.scatter('a', data, block=False)
331 view.scatter('a', data, block=False)
@@ -43,17 +43,11 b' from IPython.external.decorator import decorator'
43
43
44 # IPython imports
44 # IPython imports
45 from IPython.config.application import Application
45 from IPython.config.application import Application
46 from IPython.utils import py3compat
47 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
48 from IPython.utils.newserialized import serialize, unserialize
49 from IPython.zmq.log import EnginePUBHandler
46 from IPython.zmq.log import EnginePUBHandler
50 from IPython.zmq.serialize import (
47 from IPython.zmq.serialize import (
51 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
48 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
52 )
49 )
53
50
54 if py3compat.PY3:
55 buffer = memoryview
56
57 #-----------------------------------------------------------------------------
51 #-----------------------------------------------------------------------------
58 # Classes
52 # Classes
59 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
@@ -16,10 +16,28 b' __docformat__ = "restructuredtext en"'
16 #-------------------------------------------------------------------------------
16 #-------------------------------------------------------------------------------
17
17
18 import copy
18 import copy
19 import logging
19 import sys
20 import sys
20 from types import FunctionType
21 from types import FunctionType
21
22
23 try:
24 import cPickle as pickle
25 except ImportError:
26 import pickle
27
28 try:
29 import numpy
30 except:
31 numpy = None
32
22 import codeutil
33 import codeutil
34 import py3compat
35 from importstring import import_item
36
37 from IPython.config import Application
38
39 if py3compat.PY3:
40 buffer = memoryview
23
41
24 #-------------------------------------------------------------------------------
42 #-------------------------------------------------------------------------------
25 # Classes
43 # Classes
@@ -32,14 +50,16 b' class CannedObject(object):'
32 self.obj = copy.copy(obj)
50 self.obj = copy.copy(obj)
33 for key in keys:
51 for key in keys:
34 setattr(self.obj, key, can(getattr(obj, key)))
52 setattr(self.obj, key, can(getattr(obj, key)))
53
54 self.buffers = []
35
55
36
56 def get_object(self, g=None):
37 def getObject(self, g=None):
38 if g is None:
57 if g is None:
39 g = globals()
58 g = {}
40 for key in self.keys:
59 for key in self.keys:
41 setattr(self.obj, key, uncan(getattr(self.obj, key), g))
60 setattr(self.obj, key, uncan(getattr(self.obj, key), g))
42 return self.obj
61 return self.obj
62
43
63
44 class Reference(CannedObject):
64 class Reference(CannedObject):
45 """object for wrapping a remote reference by name."""
65 """object for wrapping a remote reference by name."""
@@ -47,13 +67,14 b' class Reference(CannedObject):'
47 if not isinstance(name, basestring):
67 if not isinstance(name, basestring):
48 raise TypeError("illegal name: %r"%name)
68 raise TypeError("illegal name: %r"%name)
49 self.name = name
69 self.name = name
70 self.buffers = []
50
71
51 def __repr__(self):
72 def __repr__(self):
52 return "<Reference: %r>"%self.name
73 return "<Reference: %r>"%self.name
53
74
54 def getObject(self, g=None):
75 def get_object(self, g=None):
55 if g is None:
76 if g is None:
56 g = globals()
77 g = {}
57
78
58 return eval(self.name, g)
79 return eval(self.name, g)
59
80
@@ -61,16 +82,17 b' class Reference(CannedObject):'
61 class CannedFunction(CannedObject):
82 class CannedFunction(CannedObject):
62
83
63 def __init__(self, f):
84 def __init__(self, f):
64 self._checkType(f)
85 self._check_type(f)
65 self.code = f.func_code
86 self.code = f.func_code
66 self.defaults = f.func_defaults
87 self.defaults = f.func_defaults
67 self.module = f.__module__ or '__main__'
88 self.module = f.__module__ or '__main__'
68 self.__name__ = f.__name__
89 self.__name__ = f.__name__
90 self.buffers = []
69
91
70 def _checkType(self, obj):
92 def _check_type(self, obj):
71 assert isinstance(obj, FunctionType), "Not a function type"
93 assert isinstance(obj, FunctionType), "Not a function type"
72
94
73 def getObject(self, g=None):
95 def get_object(self, g=None):
74 # try to load function back into its module:
96 # try to load function back into its module:
75 if not self.module.startswith('__'):
97 if not self.module.startswith('__'):
76 try:
98 try:
@@ -81,30 +103,73 b' class CannedFunction(CannedObject):'
81 g = sys.modules[self.module].__dict__
103 g = sys.modules[self.module].__dict__
82
104
83 if g is None:
105 if g is None:
84 g = globals()
106 g = {}
85 newFunc = FunctionType(self.code, g, self.__name__, self.defaults)
107 newFunc = FunctionType(self.code, g, self.__name__, self.defaults)
86 return newFunc
108 return newFunc
87
109
110
111 class CannedArray(CannedObject):
112 def __init__(self, obj):
113 self.shape = obj.shape
114 self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
115 if sum(obj.shape) == 0:
116 # just pickle it
117 self.buffers = [pickle.dumps(obj, -1)]
118 else:
119 # ensure contiguous
120 obj = numpy.ascontiguousarray(obj, dtype=None)
121 self.buffers = [buffer(obj)]
122
123 def get_object(self, g=None):
124 data = self.buffers[0]
125 if sum(self.shape) == 0:
126 # no shape, we just pickled it
127 return pickle.loads(data)
128 else:
129 return numpy.frombuffer(data, dtype=self.dtype).reshape(self.shape)
130
131
132 class CannedBytes(CannedObject):
133 wrap = bytes
134 def __init__(self, obj):
135 self.buffers = [obj]
136
137 def get_object(self, g=None):
138 data = self.buffers[0]
139 return self.wrap(data)
140
141 def CannedBuffer(CannedBytes):
142 wrap = buffer
143
88 #-------------------------------------------------------------------------------
144 #-------------------------------------------------------------------------------
89 # Functions
145 # Functions
90 #-------------------------------------------------------------------------------
146 #-------------------------------------------------------------------------------
91
147
92 def can(obj):
148 def _error(*args, **kwargs):
93 # import here to prevent module-level circular imports
149 if Application.initialized():
94 from IPython.parallel import dependent
150 logger = Application.instance().log
95 if isinstance(obj, dependent):
96 keys = ('f','df')
97 return CannedObject(obj, keys=keys)
98 elif isinstance(obj, FunctionType):
99 return CannedFunction(obj)
100 elif isinstance(obj,dict):
101 return canDict(obj)
102 elif isinstance(obj, (list,tuple)):
103 return canSequence(obj)
104 else:
151 else:
105 return obj
152 logger = logging.getLogger()
153 if not logger.handlers:
154 logging.basicConfig()
155 logger.error(*args, **kwargs)
106
156
107 def canDict(obj):
157 def can(obj):
158 """prepare an object for pickling"""
159 for cls,canner in can_map.iteritems():
160 if isinstance(cls, basestring):
161 try:
162 cls = import_item(cls)
163 except Exception:
164 _error("cannning class not importable: %r", cls, exc_info=True)
165 cls = None
166 continue
167 if isinstance(obj, cls):
168 return canner(obj)
169 return obj
170
171 def can_dict(obj):
172 """can the *values* of a dict"""
108 if isinstance(obj, dict):
173 if isinstance(obj, dict):
109 newobj = {}
174 newobj = {}
110 for k, v in obj.iteritems():
175 for k, v in obj.iteritems():
@@ -113,7 +178,8 b' def canDict(obj):'
113 else:
178 else:
114 return obj
179 return obj
115
180
116 def canSequence(obj):
181 def can_sequence(obj):
182 """can the elements of a sequence"""
117 if isinstance(obj, (list, tuple)):
183 if isinstance(obj, (list, tuple)):
118 t = type(obj)
184 t = type(obj)
119 return t([can(i) for i in obj])
185 return t([can(i) for i in obj])
@@ -121,16 +187,20 b' def canSequence(obj):'
121 return obj
187 return obj
122
188
123 def uncan(obj, g=None):
189 def uncan(obj, g=None):
124 if isinstance(obj, CannedObject):
190 """invert canning"""
125 return obj.getObject(g)
191 for cls,uncanner in uncan_map.iteritems():
126 elif isinstance(obj,dict):
192 if isinstance(cls, basestring):
127 return uncanDict(obj, g)
193 try:
128 elif isinstance(obj, (list,tuple)):
194 cls = import_item(cls)
129 return uncanSequence(obj, g)
195 except Exception:
130 else:
196 _error("uncanning class not importable: %r", cls, exc_info=True)
131 return obj
197 cls = None
132
198 continue
133 def uncanDict(obj, g=None):
199 if isinstance(obj, cls):
200 return uncanner(obj, g)
201 return obj
202
203 def uncan_dict(obj, g=None):
134 if isinstance(obj, dict):
204 if isinstance(obj, dict):
135 newobj = {}
205 newobj = {}
136 for k, v in obj.iteritems():
206 for k, v in obj.iteritems():
@@ -139,7 +209,7 b' def uncanDict(obj, g=None):'
139 else:
209 else:
140 return obj
210 return obj
141
211
142 def uncanSequence(obj, g=None):
212 def uncan_sequence(obj, g=None):
143 if isinstance(obj, (list, tuple)):
213 if isinstance(obj, (list, tuple)):
144 t = type(obj)
214 t = type(obj)
145 return t([uncan(i,g) for i in obj])
215 return t([uncan(i,g) for i in obj])
@@ -147,5 +217,21 b' def uncanSequence(obj, g=None):'
147 return obj
217 return obj
148
218
149
219
150 def rebindFunctionGlobals(f, glbls):
220 #-------------------------------------------------------------------------------
151 return FunctionType(f.func_code, glbls)
221 # API dictionary
222 #-------------------------------------------------------------------------------
223
224 # These dicts can be extended for custom serialization of new objects
225
226 can_map = {
227 'IPython.parallel.dependent' : lambda obj: CannedObject(obj, keys=('f','df')),
228 'numpy.ndarray' : CannedArray,
229 FunctionType : CannedFunction,
230 bytes : CannedBytes,
231 buffer : CannedBuffer,
232 }
233
234 uncan_map = {
235 CannedObject : lambda obj, g: obj.get_object(g),
236 }
237
@@ -570,8 +570,11 b' class Kernel(Configurable):'
570 for key in ns.iterkeys():
570 for key in ns.iterkeys():
571 working.pop(key)
571 working.pop(key)
572
572
573 packed_result,buf = serialize_object(result)
573 result_buf = serialize_object(result,
574 result_buf = [packed_result]+buf
574 buffer_threshold=self.session.buffer_threshold,
575 item_threshold=self.session.item_threshold,
576 )
577
575 except:
578 except:
576 # invoke IPython traceback formatting
579 # invoke IPython traceback formatting
577 shell.showtraceback()
580 shell.showtraceback()
@@ -32,8 +32,9 b' except:'
32
32
33 # IPython imports
33 # IPython imports
34 from IPython.utils import py3compat
34 from IPython.utils import py3compat
35 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
35 from IPython.utils.pickleutil import (
36 from IPython.utils.newserialized import serialize, unserialize
36 can, uncan, can_sequence, uncan_sequence, CannedObject
37 )
37
38
38 if py3compat.PY3:
39 if py3compat.PY3:
39 buffer = memoryview
40 buffer = memoryview
@@ -42,7 +43,33 b' if py3compat.PY3:'
42 # Serialization Functions
43 # Serialization Functions
43 #-----------------------------------------------------------------------------
44 #-----------------------------------------------------------------------------
44
45
45 def serialize_object(obj, threshold=64e-6):
46 # default values for the thresholds:
47 MAX_ITEMS = 64
48 MAX_BYTES = 1024
49
50 def _extract_buffers(obj, threshold=MAX_BYTES):
51 """extract buffers larger than a certain threshold"""
52 buffers = []
53 if isinstance(obj, CannedObject) and obj.buffers:
54 for i,buf in enumerate(obj.buffers):
55 if len(buf) > threshold:
56 # buffer larger than threshold, prevent pickling
57 obj.buffers[i] = None
58 buffers.append(buf)
59 elif isinstance(buf, buffer):
60 # buffer too small for separate send, coerce to bytes
61 # because pickling buffer objects just results in broken pointers
62 obj.buffers[i] = bytes(buf)
63 return buffers
64
65 def _restore_buffers(obj, buffers):
66 """restore buffers extracted by _extract_buffers"""
67 if isinstance(obj, CannedObject) and obj.buffers:
68 for i,buf in enumerate(obj.buffers):
69 if buf is None:
70 obj.buffers[i] = buffers.pop(0)
71
72 def serialize_object(obj, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
46 """Serialize an object into a list of sendable buffers.
73 """Serialize an object into a list of sendable buffers.
47
74
48 Parameters
75 Parameters
@@ -50,76 +77,83 b' def serialize_object(obj, threshold=64e-6):'
50
77
51 obj : object
78 obj : object
52 The object to be serialized
79 The object to be serialized
53 threshold : float
80 buffer_threshold : int
54 The threshold for not double-pickling the content.
81 The threshold (in bytes) for pulling out data buffers
55
82 to avoid pickling them.
83 item_threshold : int
84 The maximum number of items over which canning will iterate.
85 Containers (lists, dicts) larger than this will be pickled without
86 introspection.
56
87
57 Returns
88 Returns
58 -------
89 -------
59 ('pmd', [bufs]) :
90 [bufs] : list of buffers representing the serialized object.
60 where pmd is the pickled metadata wrapper,
61 bufs is a list of data buffers
62 """
91 """
63 databuffers = []
92 buffers = []
64 if isinstance(obj, (list, tuple)):
93 if isinstance(obj, (list, tuple)) and len(obj) < item_threshold:
65 clist = canSequence(obj)
94 cobj = can_sequence(obj)
66 slist = map(serialize, clist)
95 for c in cobj:
67 for s in slist:
96 buffers.extend(_extract_buffers(c, buffer_threshold))
68 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
97 elif isinstance(obj, dict) and len(obj) < item_threshold:
69 databuffers.append(s.getData())
98 cobj = {}
70 s.data = None
71 return pickle.dumps(slist,-1), databuffers
72 elif isinstance(obj, dict):
73 sobj = {}
74 for k in sorted(obj.iterkeys()):
99 for k in sorted(obj.iterkeys()):
75 s = serialize(can(obj[k]))
100 c = can(obj[k])
76 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
101 buffers.extend(_extract_buffers(c, buffer_threshold))
77 databuffers.append(s.getData())
102 cobj[k] = c
78 s.data = None
79 sobj[k] = s
80 return pickle.dumps(sobj,-1),databuffers
81 else:
103 else:
82 s = serialize(can(obj))
104 cobj = can(obj)
83 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
105 buffers.extend(_extract_buffers(cobj, buffer_threshold))
84 databuffers.append(s.getData())
106
85 s.data = None
107 buffers.insert(0, pickle.dumps(cobj,-1))
86 return pickle.dumps(s,-1),databuffers
108 return buffers
87
109
88
110 def unserialize_object(buffers, g=None):
89 def unserialize_object(bufs):
111 """reconstruct an object serialized by serialize_object from data buffers.
90 """reconstruct an object serialized by serialize_object from data buffers."""
112
91 bufs = list(bufs)
113 Parameters
92 sobj = pickle.loads(bufs.pop(0))
114 ----------
93 if isinstance(sobj, (list, tuple)):
115
94 for s in sobj:
116 bufs : list of buffers/bytes
95 if s.data is None:
117
96 s.data = bufs.pop(0)
118 g : globals to be used when uncanning
97 return uncanSequence(map(unserialize, sobj)), bufs
119
98 elif isinstance(sobj, dict):
120 Returns
121 -------
122
123 (newobj, bufs) : unpacked object, and the list of remaining unused buffers.
124 """
125 bufs = list(buffers)
126 canned = pickle.loads(bufs.pop(0))
127 if isinstance(canned, (list, tuple)) and len(canned) < MAX_ITEMS:
128 for c in canned:
129 _restore_buffers(c, bufs)
130 newobj = uncan_sequence(canned, g)
131 elif isinstance(canned, dict) and len(canned) < MAX_ITEMS:
99 newobj = {}
132 newobj = {}
100 for k in sorted(sobj.iterkeys()):
133 for k in sorted(canned.iterkeys()):
101 s = sobj[k]
134 c = canned[k]
102 if s.data is None:
135 _restore_buffers(c, bufs)
103 s.data = bufs.pop(0)
136 newobj[k] = uncan(c, g)
104 newobj[k] = uncan(unserialize(s))
105 return newobj, bufs
106 else:
137 else:
107 if sobj.data is None:
138 _restore_buffers(canned, bufs)
108 sobj.data = bufs.pop(0)
139 newobj = uncan(canned, g)
109 return uncan(unserialize(sobj)), bufs
140
141 return newobj, bufs
110
142
111 def pack_apply_message(f, args, kwargs, threshold=64e-6):
143 def pack_apply_message(f, args, kwargs, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
112 """pack up a function, args, and kwargs to be sent over the wire
144 """pack up a function, args, and kwargs to be sent over the wire
145
113 as a series of buffers. Any object whose data is larger than `threshold`
146 as a series of buffers. Any object whose data is larger than `threshold`
114 will not have their data copied (currently only numpy arrays support zero-copy)"""
147 will not have their data copied (currently only numpy arrays support zero-copy)
148 """
115 msg = [pickle.dumps(can(f),-1)]
149 msg = [pickle.dumps(can(f),-1)]
116 databuffers = [] # for large objects
150 databuffers = [] # for large objects
117 sargs, bufs = serialize_object(args,threshold)
151 sargs = serialize_object(args, buffer_threshold, item_threshold)
118 msg.append(sargs)
152 msg.append(sargs[0])
119 databuffers.extend(bufs)
153 databuffers.extend(sargs[1:])
120 skwargs, bufs = serialize_object(kwargs,threshold)
154 skwargs = serialize_object(kwargs, buffer_threshold, item_threshold)
121 msg.append(skwargs)
155 msg.append(skwargs[0])
122 databuffers.extend(bufs)
156 databuffers.extend(skwargs[1:])
123 msg.extend(databuffers)
157 msg.extend(databuffers)
124 return msg
158 return msg
125
159
@@ -131,49 +165,16 b' def unpack_apply_message(bufs, g=None, copy=True):'
131 if not copy:
165 if not copy:
132 for i in range(3):
166 for i in range(3):
133 bufs[i] = bufs[i].bytes
167 bufs[i] = bufs[i].bytes
134 cf = pickle.loads(bufs.pop(0))
168 f = uncan(pickle.loads(bufs.pop(0)), g)
135 sargs = list(pickle.loads(bufs.pop(0)))
169 # sargs = bufs.pop(0)
136 skwargs = dict(pickle.loads(bufs.pop(0)))
170 # pop kwargs out, so first n-elements are args, serialized
137 # print sargs, skwargs
171 skwargs = bufs.pop(1)
138 f = uncan(cf, g)
172 args, bufs = unserialize_object(bufs, g)
139 for sa in sargs:
173 # put skwargs back in as the first element
140 if sa.data is None:
174 bufs.insert(0, skwargs)
141 m = bufs.pop(0)
175 kwargs, bufs = unserialize_object(bufs, g)
142 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
143 # always use a buffer, until memoryviews get sorted out
144 sa.data = buffer(m)
145 # disable memoryview support
146 # if copy:
147 # sa.data = buffer(m)
148 # else:
149 # sa.data = m.buffer
150 else:
151 if copy:
152 sa.data = m
153 else:
154 sa.data = m.bytes
155
176
156 args = uncanSequence(map(unserialize, sargs), g)
177 assert not bufs, "Shouldn't be any data left over"
157 kwargs = {}
158 for k in sorted(skwargs.iterkeys()):
159 sa = skwargs[k]
160 if sa.data is None:
161 m = bufs.pop(0)
162 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
163 # always use a buffer, until memoryviews get sorted out
164 sa.data = buffer(m)
165 # disable memoryview support
166 # if copy:
167 # sa.data = buffer(m)
168 # else:
169 # sa.data = m.buffer
170 else:
171 if copy:
172 sa.data = m
173 else:
174 sa.data = m.bytes
175
176 kwargs[k] = uncan(unserialize(sa), g)
177
178
178 return f,args,kwargs
179 return f,args,kwargs
179
180
@@ -49,7 +49,8 b' from IPython.utils.importstring import import_item'
49 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
49 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
50 from IPython.utils.py3compat import str_to_bytes
50 from IPython.utils.py3compat import str_to_bytes
51 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
51 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
52 DottedObjectName, CUnicode, Dict)
52 DottedObjectName, CUnicode, Dict, Integer)
53 from IPython.zmq.serialize import MAX_ITEMS, MAX_BYTES
53
54
54 #-----------------------------------------------------------------------------
55 #-----------------------------------------------------------------------------
55 # utility functions
56 # utility functions
@@ -84,8 +85,9 b' pickle_unpacker = pickle.loads'
84 default_packer = json_packer
85 default_packer = json_packer
85 default_unpacker = json_unpacker
86 default_unpacker = json_unpacker
86
87
87 DELIM=b"<IDS|MSG>"
88 DELIM = b"<IDS|MSG>"
88
89 # singleton dummy tracker, which will always report as done
90 DONE = zmq.MessageTracker()
89
91
90 #-----------------------------------------------------------------------------
92 #-----------------------------------------------------------------------------
91 # Mixin tools for apps that use Sessions
93 # Mixin tools for apps that use Sessions
@@ -329,7 +331,18 b' class Session(Configurable):'
329 # unpacker is not checked - it is assumed to be
331 # unpacker is not checked - it is assumed to be
330 if not callable(new):
332 if not callable(new):
331 raise TypeError("unpacker must be callable, not %s"%type(new))
333 raise TypeError("unpacker must be callable, not %s"%type(new))
332
334
335 # thresholds:
336 copy_threshold = Integer(2**16, config=True,
337 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
338 buffer_threshold = Integer(MAX_BYTES, config=True,
339 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
340 item_threshold = Integer(MAX_ITEMS, config=True,
341 help="""The maximum number of items for a container to be introspected for custom serialization.
342 Containers larger than this are pickled outright.
343 """
344 )
345
333 def __init__(self, **kwargs):
346 def __init__(self, **kwargs):
334 """create a Session object
347 """create a Session object
335
348
@@ -544,11 +557,6 b' class Session(Configurable):'
544 -------
557 -------
545 msg : dict
558 msg : dict
546 The constructed message.
559 The constructed message.
547 (msg,tracker) : (dict, MessageTracker)
548 if track=True, then a 2-tuple will be returned,
549 the first element being the constructed
550 message, and the second being the MessageTracker
551
552 """
560 """
553
561
554 if not isinstance(stream, (zmq.Socket, ZMQStream)):
562 if not isinstance(stream, (zmq.Socket, ZMQStream)):
@@ -566,25 +574,18 b' class Session(Configurable):'
566
574
567 buffers = [] if buffers is None else buffers
575 buffers = [] if buffers is None else buffers
568 to_send = self.serialize(msg, ident)
576 to_send = self.serialize(msg, ident)
569 flag = 0
577 to_send.extend(buffers)
570 if buffers:
578 longest = max([ len(s) for s in to_send ])
571 flag = zmq.SNDMORE
579 copy = (longest < self.copy_threshold)
572 _track = False
580
581 if buffers and track and not copy:
582 # only really track when we are doing zero-copy buffers
583 tracker = stream.send_multipart(to_send, copy=False, track=True)
573 else:
584 else:
574 _track=track
585 # use dummy tracker, which will be done immediately
575 if track:
586 tracker = DONE
576 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
587 stream.send_multipart(to_send, copy=copy)
577 else:
578 tracker = stream.send_multipart(to_send, flag, copy=False)
579 for b in buffers[:-1]:
580 stream.send(b, flag, copy=False)
581 if buffers:
582 if track:
583 tracker = stream.send(buffers[-1], copy=False, track=track)
584 else:
585 tracker = stream.send(buffers[-1], copy=False)
586
588
587 # omsg = Message(msg)
588 if self.debug:
589 if self.debug:
589 pprint.pprint(msg)
590 pprint.pprint(msg)
590 pprint.pprint(to_send)
591 pprint.pprint(to_send)
@@ -146,9 +146,10 b' class TestSession(SessionTestCase):'
146 """test tracking messages"""
146 """test tracking messages"""
147 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
147 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
148 s = self.session
148 s = self.session
149 s.copy_threshold = 1
149 stream = ZMQStream(a)
150 stream = ZMQStream(a)
150 msg = s.send(a, 'hello', track=False)
151 msg = s.send(a, 'hello', track=False)
151 self.assertTrue(msg['tracker'] is None)
152 self.assertTrue(msg['tracker'] is ss.DONE)
152 msg = s.send(a, 'hello', track=True)
153 msg = s.send(a, 'hello', track=True)
153 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
154 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
154 M = zmq.Message(b'hi there', track=True)
155 M = zmq.Message(b'hi there', track=True)
General Comments 0
You need to be logged in to leave comments. Login now