Show More
@@ -0,0 +1,168 b'' | |||||
|
1 | """test serialization tools""" | |||
|
2 | ||||
|
3 | #------------------------------------------------------------------------------- | |||
|
4 | # Copyright (C) 2011 The IPython Development Team | |||
|
5 | # | |||
|
6 | # Distributed under the terms of the BSD License. The full license is in | |||
|
7 | # the file COPYING, distributed as part of this software. | |||
|
8 | #------------------------------------------------------------------------------- | |||
|
9 | ||||
|
10 | #------------------------------------------------------------------------------- | |||
|
11 | # Imports | |||
|
12 | #------------------------------------------------------------------------------- | |||
|
13 | ||||
|
14 | import pickle | |||
|
15 | ||||
|
16 | import nose.tools as nt | |||
|
17 | ||||
|
18 | # from unittest import TestCase | |||
|
19 | from IPython.zmq.serialize import serialize_object, unserialize_object | |||
|
20 | from IPython.testing import decorators as dec | |||
|
21 | from IPython.utils.pickleutil import CannedArray | |||
|
22 | ||||
|
23 | #------------------------------------------------------------------------------- | |||
|
24 | # Globals and Utilities | |||
|
25 | #------------------------------------------------------------------------------- | |||
|
26 | ||||
|
def roundtrip(obj):
    """Serialize `obj`, unserialize the result, and return the rebuilt object.

    Also asserts that unserialization left no unused buffers behind.
    """
    rebuilt, leftover = unserialize_object(serialize_object(obj))
    nt.assert_equals(leftover, [])
    return rebuilt
|
33 | ||||
|
class C(object):
    """Dummy class whose instance attributes are set from keyword arguments."""

    def __init__(self, **kwargs):
        # items() is available on both Python 2 and Python 3 dicts,
        # whereas iteritems() only exists on Python 2
        for key, value in kwargs.items():
            setattr(self, key, value)
|
40 | ||||
|
41 | SHAPES = ((100,), (1024,10), (10,8,6,5), (), (0,)) | |||
|
42 | DTYPES = ('uint8', 'float64', 'int32', [('g', 'float32')], '|S10') | |||
|
43 | #------------------------------------------------------------------------------- | |||
|
44 | # Tests | |||
|
45 | #------------------------------------------------------------------------------- | |||
|
46 | ||||
|
@dec.parametric
def test_roundtrip_simple():
    """Flat builtin objects must compare equal after a serialization roundtrip."""
    samples = [
        'hello',
        dict(a='b', b=10),
        [1,2,'hi'],
        (b'123', 'hello'),
    ]
    for original in samples:
        rebuilt = roundtrip(original)
        yield nt.assert_equals(original, rebuilt)
|
57 | ||||
|
@dec.parametric
def test_roundtrip_nested():
    """Nested containers must compare equal after a serialization roundtrip."""
    samples = [
        dict(a=range(5), b={1:b'hello'}),
        [range(5),[range(3),(1,[b'whoda'])]],
    ]
    for original in samples:
        rebuilt = roundtrip(original)
        yield nt.assert_equals(original, rebuilt)
|
66 | ||||
|
@dec.parametric
def test_roundtrip_buffered():
    """Objects holding one bytes payload over 1024 bytes serialize to two buffers.

    Each sample must produce exactly two buffers, unserialize with nothing
    left over, and compare equal to the original.
    """
    samples = [
        dict(a=b"x"*1025),
        b"hello"*500,
        [b"hello"*501, 1,2,3]
    ]
    for original in samples:
        bufs = serialize_object(original)
        yield nt.assert_equals(len(bufs), 2)
        rebuilt, leftover = unserialize_object(bufs)
        yield nt.assert_equals(leftover, [])
        yield nt.assert_equals(original, rebuilt)
|
79 | ||||
|
80 | def _scrub_nan(A): | |||
|
81 | """scrub nans out of empty arrays | |||
|
82 | ||||
|
83 | since nan != nan | |||
|
84 | """ | |||
|
85 | import numpy | |||
|
86 | if A.dtype.fields and A.shape: | |||
|
87 | for field in A.dtype.fields.keys(): | |||
|
88 | try: | |||
|
89 | A[field][numpy.isnan(A[field])] = 0 | |||
|
90 | except (TypeError, NotImplementedError): | |||
|
91 | # e.g. str dtype | |||
|
92 | pass | |||
|
93 | ||||
|
@dec.parametric
@dec.skip_without('numpy')
def test_numpy():
    """Roundtrip bare numpy arrays for every combination of SHAPES and DTYPES.

    For each array, checks that no buffers are left over and that shape,
    dtype and contents survive serialization.
    """
    import numpy
    # numpy.testing is the public home of assert_array_equal;
    # numpy.testing.utils is a deprecated alias of it
    from numpy.testing import assert_array_equal
    for shape in SHAPES:
        for dtype in DTYPES:
            A = numpy.empty(shape, dtype=dtype)
            # uninitialized data may hold NaNs, which never compare equal
            _scrub_nan(A)
            bufs = serialize_object(A)
            B, r = unserialize_object(bufs)
            yield nt.assert_equals(r, [])
            yield nt.assert_equals(A.shape, B.shape)
            yield nt.assert_equals(A.dtype, B.dtype)
            yield assert_array_equal(A,B)
|
109 | ||||
|
@dec.parametric
@dec.skip_without('numpy')
def test_recarray():
    """Roundtrip structured (record) arrays for every shape in SHAPES.

    For each array, checks that no buffers are left over and that shape,
    dtype and contents survive serialization.
    """
    import numpy
    # numpy.testing is the public home of assert_array_equal;
    # numpy.testing.utils is a deprecated alias of it
    from numpy.testing import assert_array_equal
    record_dtypes = [
        [('f', float), ('s', '|S10')],
        [('n', int), ('s', '|S1'), ('u', 'uint32')],
    ]
    for shape in SHAPES:
        for dtype in record_dtypes:
            A = numpy.empty(shape, dtype=dtype)
            # uninitialized data may hold NaNs, which never compare equal
            _scrub_nan(A)

            bufs = serialize_object(A)
            B, r = unserialize_object(bufs)
            yield nt.assert_equals(r, [])
            yield nt.assert_equals(A.shape, B.shape)
            yield nt.assert_equals(A.dtype, B.dtype)
            yield assert_array_equal(A,B)
|
129 | ||||
|
@dec.parametric
@dec.skip_without('numpy')
def test_numpy_in_seq():
    """Roundtrip numpy arrays that are the first element of a tuple.

    Checks that the array is canned as a CannedArray inside the pickled
    tuple, and that shape, dtype and contents survive serialization.
    """
    import numpy
    # numpy.testing is the public home of assert_array_equal;
    # numpy.testing.utils is a deprecated alias of it
    from numpy.testing import assert_array_equal
    for shape in SHAPES:
        for dtype in DTYPES:
            A = numpy.empty(shape, dtype=dtype)
            # uninitialized data may hold NaNs, which never compare equal
            _scrub_nan(A)
            bufs = serialize_object((A,1,2,b'hello'))
            canned = pickle.loads(bufs[0])
            # assert_true(x, CannedArray) would use the class as the failure
            # message and pass for any truthy x; check the type explicitly
            yield nt.assert_true(isinstance(canned[0], CannedArray))
            tup, r = unserialize_object(bufs)
            B = tup[0]
            yield nt.assert_equals(r, [])
            yield nt.assert_equals(A.shape, B.shape)
            yield nt.assert_equals(A.dtype, B.dtype)
            yield assert_array_equal(A,B)
|
148 | ||||
|
@dec.parametric
@dec.skip_without('numpy')
def test_numpy_in_dict():
    """Roundtrip numpy arrays stored as a value of a dict.

    Checks that the array is canned as a CannedArray inside the pickled
    dict, and that shape, dtype and contents survive serialization.
    """
    import numpy
    # numpy.testing is the public home of assert_array_equal;
    # numpy.testing.utils is a deprecated alias of it
    from numpy.testing import assert_array_equal
    for shape in SHAPES:
        for dtype in DTYPES:
            A = numpy.empty(shape, dtype=dtype)
            # uninitialized data may hold NaNs, which never compare equal
            _scrub_nan(A)
            bufs = serialize_object(dict(a=A,b=1,c=range(20)))
            canned = pickle.loads(bufs[0])
            # assert_true(x, CannedArray) would use the class as the failure
            # message and pass for any truthy x; check the type explicitly
            yield nt.assert_true(isinstance(canned['a'], CannedArray))
            d, r = unserialize_object(bufs)
            B = d['a']
            yield nt.assert_equals(r, [])
            yield nt.assert_equals(A.shape, B.shape)
            yield nt.assert_equals(A.dtype, B.dtype)
            yield assert_array_equal(A,B)
|
167 | ||||
|
168 |
@@ -1209,7 +1209,10 b' class Client(HasTraits):' | |||||
1209 | if not isinstance(metadata, dict): |
|
1209 | if not isinstance(metadata, dict): | |
1210 | raise TypeError("metadata must be dict, not %s"%type(metadata)) |
|
1210 | raise TypeError("metadata must be dict, not %s"%type(metadata)) | |
1211 |
|
1211 | |||
1212 |
bufs = util.pack_apply_message(f,args,kwargs |
|
1212 | bufs = util.pack_apply_message(f, args, kwargs, | |
|
1213 | buffer_threshold=self.session.buffer_threshold, | |||
|
1214 | item_threshold=self.session.item_threshold, | |||
|
1215 | ) | |||
1213 |
|
1216 | |||
1214 | msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident, |
|
1217 | msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident, | |
1215 | metadata=metadata, track=track) |
|
1218 | metadata=metadata, track=track) |
@@ -239,7 +239,7 b' class TestView(ClusterTestCase, ParametricTestCase):' | |||||
239 | from numpy.testing.utils import assert_array_equal, assert_array_almost_equal |
|
239 | from numpy.testing.utils import assert_array_equal, assert_array_almost_equal | |
240 | view = self.client[:] |
|
240 | view = self.client[:] | |
241 | a = numpy.arange(64) |
|
241 | a = numpy.arange(64) | |
242 | view.scatter('a', a) |
|
242 | view.scatter('a', a, block=True) | |
243 | b = view.gather('a', block=True) |
|
243 | b = view.gather('a', block=True) | |
244 | assert_array_equal(b, a) |
|
244 | assert_array_equal(b, a) | |
245 |
|
245 | |||
@@ -325,7 +325,7 b' class TestView(ClusterTestCase, ParametricTestCase):' | |||||
325 | r = view.map_sync(lambda x:x, arr) |
|
325 | r = view.map_sync(lambda x:x, arr) | |
326 | self.assertEqual(r, list(arr)) |
|
326 | self.assertEqual(r, list(arr)) | |
327 |
|
327 | |||
328 |
def test_scatter |
|
328 | def test_scatter_gather_nonblocking(self): | |
329 | data = range(16) |
|
329 | data = range(16) | |
330 | view = self.client[:] |
|
330 | view = self.client[:] | |
331 | view.scatter('a', data, block=False) |
|
331 | view.scatter('a', data, block=False) |
@@ -43,17 +43,11 b' from IPython.external.decorator import decorator' | |||||
43 |
|
43 | |||
44 | # IPython imports |
|
44 | # IPython imports | |
45 | from IPython.config.application import Application |
|
45 | from IPython.config.application import Application | |
46 | from IPython.utils import py3compat |
|
|||
47 | from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence |
|
|||
48 | from IPython.utils.newserialized import serialize, unserialize |
|
|||
49 | from IPython.zmq.log import EnginePUBHandler |
|
46 | from IPython.zmq.log import EnginePUBHandler | |
50 | from IPython.zmq.serialize import ( |
|
47 | from IPython.zmq.serialize import ( | |
51 | unserialize_object, serialize_object, pack_apply_message, unpack_apply_message |
|
48 | unserialize_object, serialize_object, pack_apply_message, unpack_apply_message | |
52 | ) |
|
49 | ) | |
53 |
|
50 | |||
54 | if py3compat.PY3: |
|
|||
55 | buffer = memoryview |
|
|||
56 |
|
||||
57 | #----------------------------------------------------------------------------- |
|
51 | #----------------------------------------------------------------------------- | |
58 | # Classes |
|
52 | # Classes | |
59 | #----------------------------------------------------------------------------- |
|
53 | #----------------------------------------------------------------------------- |
@@ -16,10 +16,28 b' __docformat__ = "restructuredtext en"' | |||||
16 | #------------------------------------------------------------------------------- |
|
16 | #------------------------------------------------------------------------------- | |
17 |
|
17 | |||
18 | import copy |
|
18 | import copy | |
|
19 | import logging | |||
19 | import sys |
|
20 | import sys | |
20 | from types import FunctionType |
|
21 | from types import FunctionType | |
21 |
|
22 | |||
|
23 | try: | |||
|
24 | import cPickle as pickle | |||
|
25 | except ImportError: | |||
|
26 | import pickle | |||
|
27 | ||||
|
28 | try: | |||
|
29 | import numpy | |||
|
30 | except: | |||
|
31 | numpy = None | |||
|
32 | ||||
22 | import codeutil |
|
33 | import codeutil | |
|
34 | import py3compat | |||
|
35 | from importstring import import_item | |||
|
36 | ||||
|
37 | from IPython.config import Application | |||
|
38 | ||||
|
39 | if py3compat.PY3: | |||
|
40 | buffer = memoryview | |||
23 |
|
41 | |||
24 | #------------------------------------------------------------------------------- |
|
42 | #------------------------------------------------------------------------------- | |
25 | # Classes |
|
43 | # Classes | |
@@ -32,14 +50,16 b' class CannedObject(object):' | |||||
32 | self.obj = copy.copy(obj) |
|
50 | self.obj = copy.copy(obj) | |
33 | for key in keys: |
|
51 | for key in keys: | |
34 | setattr(self.obj, key, can(getattr(obj, key))) |
|
52 | setattr(self.obj, key, can(getattr(obj, key))) | |
|
53 | ||||
|
54 | self.buffers = [] | |||
35 |
|
55 | |||
36 |
|
56 | def get_object(self, g=None): | ||
37 | def getObject(self, g=None): |
|
|||
38 | if g is None: |
|
57 | if g is None: | |
39 |
g = |
|
58 | g = {} | |
40 | for key in self.keys: |
|
59 | for key in self.keys: | |
41 | setattr(self.obj, key, uncan(getattr(self.obj, key), g)) |
|
60 | setattr(self.obj, key, uncan(getattr(self.obj, key), g)) | |
42 | return self.obj |
|
61 | return self.obj | |
|
62 | ||||
43 |
|
63 | |||
44 | class Reference(CannedObject): |
|
64 | class Reference(CannedObject): | |
45 | """object for wrapping a remote reference by name.""" |
|
65 | """object for wrapping a remote reference by name.""" | |
@@ -47,13 +67,14 b' class Reference(CannedObject):' | |||||
47 | if not isinstance(name, basestring): |
|
67 | if not isinstance(name, basestring): | |
48 | raise TypeError("illegal name: %r"%name) |
|
68 | raise TypeError("illegal name: %r"%name) | |
49 | self.name = name |
|
69 | self.name = name | |
|
70 | self.buffers = [] | |||
50 |
|
71 | |||
51 | def __repr__(self): |
|
72 | def __repr__(self): | |
52 | return "<Reference: %r>"%self.name |
|
73 | return "<Reference: %r>"%self.name | |
53 |
|
74 | |||
54 |
def get |
|
75 | def get_object(self, g=None): | |
55 | if g is None: |
|
76 | if g is None: | |
56 |
g = |
|
77 | g = {} | |
57 |
|
78 | |||
58 | return eval(self.name, g) |
|
79 | return eval(self.name, g) | |
59 |
|
80 | |||
@@ -61,16 +82,17 b' class Reference(CannedObject):' | |||||
61 | class CannedFunction(CannedObject): |
|
82 | class CannedFunction(CannedObject): | |
62 |
|
83 | |||
63 | def __init__(self, f): |
|
84 | def __init__(self, f): | |
64 |
self._check |
|
85 | self._check_type(f) | |
65 | self.code = f.func_code |
|
86 | self.code = f.func_code | |
66 | self.defaults = f.func_defaults |
|
87 | self.defaults = f.func_defaults | |
67 | self.module = f.__module__ or '__main__' |
|
88 | self.module = f.__module__ or '__main__' | |
68 | self.__name__ = f.__name__ |
|
89 | self.__name__ = f.__name__ | |
|
90 | self.buffers = [] | |||
69 |
|
91 | |||
70 |
def _check |
|
92 | def _check_type(self, obj): | |
71 | assert isinstance(obj, FunctionType), "Not a function type" |
|
93 | assert isinstance(obj, FunctionType), "Not a function type" | |
72 |
|
94 | |||
73 |
def get |
|
95 | def get_object(self, g=None): | |
74 | # try to load function back into its module: |
|
96 | # try to load function back into its module: | |
75 | if not self.module.startswith('__'): |
|
97 | if not self.module.startswith('__'): | |
76 | try: |
|
98 | try: | |
@@ -81,30 +103,73 b' class CannedFunction(CannedObject):' | |||||
81 | g = sys.modules[self.module].__dict__ |
|
103 | g = sys.modules[self.module].__dict__ | |
82 |
|
104 | |||
83 | if g is None: |
|
105 | if g is None: | |
84 |
g = |
|
106 | g = {} | |
85 | newFunc = FunctionType(self.code, g, self.__name__, self.defaults) |
|
107 | newFunc = FunctionType(self.code, g, self.__name__, self.defaults) | |
86 | return newFunc |
|
108 | return newFunc | |
87 |
|
109 | |||
|
110 | ||||
|
class CannedArray(CannedObject):
    """Canned form of a numpy array.

    The array's shape and dtype are stored as plain attributes and its data
    is kept in ``self.buffers`` (a one-element list), either as a raw memory
    buffer or, when a raw buffer cannot represent it, as a pickle.

    Attributes
    ----------
    shape : tuple
        shape of the original array
    dtype : str or list
        ``dtype.descr`` for structured dtypes, ``dtype.str`` otherwise
    pickled : bool
        True when ``buffers[0]`` holds a pickle of the array rather than
        its raw memory
    buffers : list
        single-element list holding the array data
    """
    def __init__(self, obj):
        self.shape = obj.shape
        # structured dtypes need the full field description to be rebuilt
        self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
        # Arrays whose shape sums to zero are pickled, as are arrays that
        # hold Python objects: their memory is only a block of pointers,
        # which numpy.frombuffer refuses to turn back into an object array.
        self.pickled = sum(obj.shape) == 0 or obj.dtype.hasobject
        if self.pickled:
            # just pickle it
            self.buffers = [pickle.dumps(obj, -1)]
        else:
            # ensure contiguous
            obj = numpy.ascontiguousarray(obj, dtype=None)
            self.buffers = [buffer(obj)]

    def get_object(self, g=None):
        """Rebuild the array from ``self.buffers[0]``; `g` is unused."""
        data = self.buffers[0]
        # instances that lack the flag fall back to the shape-based rule
        pickled = getattr(self, 'pickled', sum(self.shape) == 0)
        if pickled:
            # NOTE(review): pickle.loads can run arbitrary code if the buffer
            # is attacker-controlled; this presumably assumes trusted peers.
            return pickle.loads(data)
        return numpy.frombuffer(data, dtype=self.dtype).reshape(self.shape)
|
130 | ||||
|
131 | ||||
|
class CannedBytes(CannedObject):
    """Canned wrapper that carries a bytes-like object in a single buffer."""

    # callable applied to the stored buffer when the object is uncanned
    wrap = bytes

    def __init__(self, obj):
        self.buffers = [obj]

    def get_object(self, g=None):
        """Return the stored buffer passed through ``wrap``; `g` is unused."""
        return self.wrap(self.buffers[0])
|
140 | ||||
|
class CannedBuffer(CannedBytes):
    """CannedBytes variant that uncans its data as a ``buffer``.

    This must be a class: declared with ``def`` it would be a function that
    returns None, so canning a buffer would silently produce None.
    """
    wrap = buffer
|
143 | ||||
88 | #------------------------------------------------------------------------------- |
|
144 | #------------------------------------------------------------------------------- | |
89 | # Functions |
|
145 | # Functions | |
90 | #------------------------------------------------------------------------------- |
|
146 | #------------------------------------------------------------------------------- | |
91 |
|
147 | |||
92 | def can(obj): |
|
def _error(*args, **kwargs):
    """Log an error message.

    Uses the log of the Application instance when one has been initialized;
    otherwise uses the root logger, configuring it with
    ``logging.basicConfig()`` if it has no handlers yet.  All arguments are
    forwarded to ``logger.error``.
    """
    if not Application.initialized():
        logger = logging.getLogger()
        if not logger.handlers:
            logging.basicConfig()
    else:
        logger = Application.instance().log
    logger.error(*args, **kwargs)
106 |
|
156 | |||
107 |
def can |
|
def can(obj):
    """Prepare an object for pickling.

    Walks ``can_map`` and returns the result of the first canner whose class
    `obj` is an instance of.  Keys that are strings are treated as dotted
    import paths and imported on each call; a key that cannot be imported is
    logged and skipped.  Objects matching no entry are returned unchanged.
    """
    for key, canner in can_map.iteritems():
        cls = key
        if isinstance(key, basestring):
            try:
                cls = import_item(key)
            except Exception:
                _error("canning class not importable: %r", key, exc_info=True)
                continue
        if isinstance(obj, cls):
            return canner(obj)
    return obj
|
170 | ||||
|
171 | def can_dict(obj): | |||
|
172 | """can the *values* of a dict""" | |||
108 | if isinstance(obj, dict): |
|
173 | if isinstance(obj, dict): | |
109 | newobj = {} |
|
174 | newobj = {} | |
110 | for k, v in obj.iteritems(): |
|
175 | for k, v in obj.iteritems(): | |
@@ -113,7 +178,8 b' def canDict(obj):' | |||||
113 | else: |
|
178 | else: | |
114 | return obj |
|
179 | return obj | |
115 |
|
180 | |||
116 |
def can |
|
181 | def can_sequence(obj): | |
|
182 | """can the elements of a sequence""" | |||
117 | if isinstance(obj, (list, tuple)): |
|
183 | if isinstance(obj, (list, tuple)): | |
118 | t = type(obj) |
|
184 | t = type(obj) | |
119 | return t([can(i) for i in obj]) |
|
185 | return t([can(i) for i in obj]) | |
@@ -121,16 +187,20 b' def canSequence(obj):' | |||||
121 | return obj |
|
187 | return obj | |
122 |
|
188 | |||
def uncan(obj, g=None):
    """Invert canning.

    Walks ``uncan_map`` and returns the result of the first uncanner whose
    class `obj` is an instance of, passing `g` along as its second argument.
    String keys are imported as dotted paths; a key that cannot be imported
    is logged and skipped.  Objects matching no entry are returned unchanged.
    """
    for key, uncanner in uncan_map.iteritems():
        target = key
        if isinstance(key, basestring):
            try:
                target = import_item(key)
            except Exception:
                _error("uncanning class not importable: %r", key, exc_info=True)
                continue
        if isinstance(obj, target):
            return uncanner(obj, g)
    return obj
|
202 | ||||
|
203 | def uncan_dict(obj, g=None): | |||
134 | if isinstance(obj, dict): |
|
204 | if isinstance(obj, dict): | |
135 | newobj = {} |
|
205 | newobj = {} | |
136 | for k, v in obj.iteritems(): |
|
206 | for k, v in obj.iteritems(): | |
@@ -139,7 +209,7 b' def uncanDict(obj, g=None):' | |||||
139 | else: |
|
209 | else: | |
140 | return obj |
|
210 | return obj | |
141 |
|
211 | |||
142 |
def uncan |
|
212 | def uncan_sequence(obj, g=None): | |
143 | if isinstance(obj, (list, tuple)): |
|
213 | if isinstance(obj, (list, tuple)): | |
144 | t = type(obj) |
|
214 | t = type(obj) | |
145 | return t([uncan(i,g) for i in obj]) |
|
215 | return t([uncan(i,g) for i in obj]) | |
@@ -147,5 +217,21 b' def uncanSequence(obj, g=None):' | |||||
147 | return obj |
|
217 | return obj | |
148 |
|
218 | |||
149 |
|
219 | |||
150 | def rebindFunctionGlobals(f, glbls): |
|
220 | #------------------------------------------------------------------------------- | |
151 | return FunctionType(f.func_code, glbls) |
|
221 | # API dictionary | |
|
222 | #------------------------------------------------------------------------------- | |||
|
223 | ||||
|
224 | # These dicts can be extended for custom serialization of new objects | |||
|
225 | ||||
|
# Maps a class, or the dotted import path of a class given as a string,
# to the callable that cans instances of that class.  Consulted by can().
can_map = {
    'IPython.parallel.dependent' : lambda obj: CannedObject(obj, keys=('f','df')),
    'numpy.ndarray' : CannedArray,
    FunctionType : CannedFunction,
    bytes : CannedBytes,
    buffer : CannedBuffer,
}

# Maps a canned class to the callable(obj, g) that rebuilds the original
# object.  Consulted by uncan().
uncan_map = {
    CannedObject : lambda obj, g: obj.get_object(g),
}
|
237 |
@@ -570,8 +570,11 b' class Kernel(Configurable):' | |||||
570 | for key in ns.iterkeys(): |
|
570 | for key in ns.iterkeys(): | |
571 | working.pop(key) |
|
571 | working.pop(key) | |
572 |
|
572 | |||
573 |
|
|
573 | result_buf = serialize_object(result, | |
574 | result_buf = [packed_result]+buf |
|
574 | buffer_threshold=self.session.buffer_threshold, | |
|
575 | item_threshold=self.session.item_threshold, | |||
|
576 | ) | |||
|
577 | ||||
575 | except: |
|
578 | except: | |
576 | # invoke IPython traceback formatting |
|
579 | # invoke IPython traceback formatting | |
577 | shell.showtraceback() |
|
580 | shell.showtraceback() |
@@ -32,8 +32,9 b' except:' | |||||
32 |
|
32 | |||
33 | # IPython imports |
|
33 | # IPython imports | |
34 | from IPython.utils import py3compat |
|
34 | from IPython.utils import py3compat | |
35 |
from IPython.utils.pickleutil import |
|
35 | from IPython.utils.pickleutil import ( | |
36 | from IPython.utils.newserialized import serialize, unserialize |
|
36 | can, uncan, can_sequence, uncan_sequence, CannedObject | |
|
37 | ) | |||
37 |
|
38 | |||
38 | if py3compat.PY3: |
|
39 | if py3compat.PY3: | |
39 | buffer = memoryview |
|
40 | buffer = memoryview | |
@@ -42,7 +43,33 b' if py3compat.PY3:' | |||||
42 | # Serialization Functions |
|
43 | # Serialization Functions | |
43 | #----------------------------------------------------------------------------- |
|
44 | #----------------------------------------------------------------------------- | |
44 |
|
45 | |||
45 | def serialize_object(obj, threshold=64e-6): |
|
46 | # default values for the thresholds: | |
|
47 | MAX_ITEMS = 64 | |||
|
48 | MAX_BYTES = 1024 | |||
|
49 | ||||
|
def _extract_buffers(obj, threshold=MAX_BYTES):
    """Extract buffers larger than a certain threshold from a canned object.

    For a CannedObject with buffers, every buffer whose ``len`` exceeds
    `threshold` is replaced by None in ``obj.buffers`` and collected in the
    returned list.  Smaller ``buffer`` instances are converted to bytes in
    place.  Anything else yields an empty list.
    """
    extracted = []
    if not (isinstance(obj, CannedObject) and obj.buffers):
        return extracted
    for idx, buf in enumerate(obj.buffers):
        if len(buf) > threshold:
            # buffer larger than threshold, prevent pickling
            obj.buffers[idx] = None
            extracted.append(buf)
        elif isinstance(buf, buffer):
            # buffer too small for separate send, coerce to bytes
            # because pickling buffer objects just results in broken pointers
            obj.buffers[idx] = bytes(buf)
    return extracted
|
64 | ||||
|
def _restore_buffers(obj, buffers):
    """Restore buffers that were replaced by None in a canned object.

    Each None entry in ``obj.buffers`` is filled, in order, with the next
    item popped from the front of `buffers`, which is consumed in place.
    Objects that are not a CannedObject with buffers are left alone.
    """
    if not (isinstance(obj, CannedObject) and obj.buffers):
        return
    for idx, buf in enumerate(obj.buffers):
        if buf is None:
            obj.buffers[idx] = buffers.pop(0)
|
71 | ||||
|
72 | def serialize_object(obj, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS): | |||
46 | """Serialize an object into a list of sendable buffers. |
|
73 | """Serialize an object into a list of sendable buffers. | |
47 |
|
74 | |||
48 | Parameters |
|
75 | Parameters | |
@@ -50,76 +77,83 b' def serialize_object(obj, threshold=64e-6):' | |||||
50 |
|
77 | |||
51 | obj : object |
|
78 | obj : object | |
52 | The object to be serialized |
|
79 | The object to be serialized | |
53 |
threshold : |
|
80 | buffer_threshold : int | |
54 | The threshold for not double-pickling the content. |
|
81 | The threshold (in bytes) for pulling out data buffers | |
55 |
|
82 | to avoid pickling them. | ||
|
83 | item_threshold : int | |||
|
84 | The maximum number of items over which canning will iterate. | |||
|
85 | Containers (lists, dicts) larger than this will be pickled without | |||
|
86 | introspection. | |||
56 |
|
87 | |||
57 | Returns |
|
88 | Returns | |
58 | ------- |
|
89 | ------- | |
59 | ('pmd', [bufs]) : |
|
90 | [bufs] : list of buffers representing the serialized object. | |
60 | where pmd is the pickled metadata wrapper, |
|
|||
61 | bufs is a list of data buffers |
|
|||
62 | """ |
|
91 | """ | |
63 |
|
|
92 | buffers = [] | |
64 | if isinstance(obj, (list, tuple)): |
|
93 | if isinstance(obj, (list, tuple)) and len(obj) < item_threshold: | |
65 |
c |
|
94 | cobj = can_sequence(obj) | |
66 | slist = map(serialize, clist) |
|
95 | for c in cobj: | |
67 | for s in slist: |
|
96 | buffers.extend(_extract_buffers(c, buffer_threshold)) | |
68 | if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold: |
|
97 | elif isinstance(obj, dict) and len(obj) < item_threshold: | |
69 | databuffers.append(s.getData()) |
|
98 | cobj = {} | |
70 | s.data = None |
|
|||
71 | return pickle.dumps(slist,-1), databuffers |
|
|||
72 | elif isinstance(obj, dict): |
|
|||
73 | sobj = {} |
|
|||
74 | for k in sorted(obj.iterkeys()): |
|
99 | for k in sorted(obj.iterkeys()): | |
75 |
|
|
100 | c = can(obj[k]) | |
76 | if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold: |
|
101 | buffers.extend(_extract_buffers(c, buffer_threshold)) | |
77 | databuffers.append(s.getData()) |
|
102 | cobj[k] = c | |
78 | s.data = None |
|
|||
79 | sobj[k] = s |
|
|||
80 | return pickle.dumps(sobj,-1),databuffers |
|
|||
81 | else: |
|
103 | else: | |
82 |
|
|
104 | cobj = can(obj) | |
83 | if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold: |
|
105 | buffers.extend(_extract_buffers(cobj, buffer_threshold)) | |
84 | databuffers.append(s.getData()) |
|
106 | ||
85 | s.data = None |
|
107 | buffers.insert(0, pickle.dumps(cobj,-1)) | |
86 | return pickle.dumps(s,-1),databuffers |
|
108 | return buffers | |
87 |
|
109 | |||
88 |
|
def unserialize_object(buffers, g=None):
    """Reconstruct an object serialized by serialize_object from data buffers.

    Parameters
    ----------

    buffers : list of buffers/bytes
        The first item is the pickled canned object; the following items are
        the data buffers that were pulled out of it.

    g : globals to be used when uncanning

    Returns
    -------

    (newobj, bufs) : unpacked object, and the list of remaining unused buffers.
    """
    # work on a copy, so the caller's list is not consumed
    remaining = list(buffers)
    # NOTE(review): pickle.loads can run arbitrary code if the buffer is
    # attacker-controlled; this presumably assumes trusted peers.
    canned = pickle.loads(remaining.pop(0))
    # NOTE(review): this uses the module default MAX_ITEMS, while
    # serialize_object accepts a configurable item_threshold -- confirm
    # the two cannot diverge.
    introspect = isinstance(canned, (list, tuple, dict)) and len(canned) < MAX_ITEMS
    if introspect and isinstance(canned, dict):
        newobj = {}
        # sorted key order decides which buffer goes to which value
        for key in sorted(canned.iterkeys()):
            item = canned[key]
            _restore_buffers(item, remaining)
            newobj[key] = uncan(item, g)
    elif introspect:
        for item in canned:
            _restore_buffers(item, remaining)
        newobj = uncan_sequence(canned, g)
    else:
        _restore_buffers(canned, remaining)
        newobj = uncan(canned, g)

    return newobj, remaining
|
142 | |||
111 |
def pack_apply_message(f, args, kwargs, threshold= |
|
def pack_apply_message(f, args, kwargs, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
    """Pack up a function, args, and kwargs to be sent over the wire
    as a series of buffers.

    The returned list starts with the pickled canned function, the pickled
    args and the pickled kwargs, followed by the data buffers extracted from
    args and then those extracted from kwargs.  `buffer_threshold` and
    `item_threshold` are passed through to serialize_object.
    """
    packed_f = pickle.dumps(can(f),-1)
    sargs = serialize_object(args, buffer_threshold, item_threshold)
    skwargs = serialize_object(kwargs, buffer_threshold, item_threshold)
    # the three pickles come first, the extracted data buffers after them
    return [packed_f, sargs[0], skwargs[0]] + sargs[1:] + skwargs[1:]
|
159 | |||
@@ -131,49 +165,16 b' def unpack_apply_message(bufs, g=None, copy=True):' | |||||
131 | if not copy: |
|
165 | if not copy: | |
132 | for i in range(3): |
|
166 | for i in range(3): | |
133 | bufs[i] = bufs[i].bytes |
|
167 | bufs[i] = bufs[i].bytes | |
134 |
|
|
168 | f = uncan(pickle.loads(bufs.pop(0)), g) | |
135 |
sargs = |
|
169 | # sargs = bufs.pop(0) | |
136 | skwargs = dict(pickle.loads(bufs.pop(0))) |
|
170 | # pop kwargs out, so first n-elements are args, serialized | |
137 | # print sargs, skwargs |
|
171 | skwargs = bufs.pop(1) | |
138 | f = uncan(cf, g) |
|
172 | args, bufs = unserialize_object(bufs, g) | |
139 | for sa in sargs: |
|
173 | # put skwargs back in as the first element | |
140 | if sa.data is None: |
|
174 | bufs.insert(0, skwargs) | |
141 | m = bufs.pop(0) |
|
175 | kwargs, bufs = unserialize_object(bufs, g) | |
142 | if sa.getTypeDescriptor() in ('buffer', 'ndarray'): |
|
|||
143 | # always use a buffer, until memoryviews get sorted out |
|
|||
144 | sa.data = buffer(m) |
|
|||
145 | # disable memoryview support |
|
|||
146 | # if copy: |
|
|||
147 | # sa.data = buffer(m) |
|
|||
148 | # else: |
|
|||
149 | # sa.data = m.buffer |
|
|||
150 | else: |
|
|||
151 | if copy: |
|
|||
152 | sa.data = m |
|
|||
153 | else: |
|
|||
154 | sa.data = m.bytes |
|
|||
155 |
|
176 | |||
156 | args = uncanSequence(map(unserialize, sargs), g) |
|
177 | assert not bufs, "Shouldn't be any data left over" | |
157 | kwargs = {} |
|
|||
158 | for k in sorted(skwargs.iterkeys()): |
|
|||
159 | sa = skwargs[k] |
|
|||
160 | if sa.data is None: |
|
|||
161 | m = bufs.pop(0) |
|
|||
162 | if sa.getTypeDescriptor() in ('buffer', 'ndarray'): |
|
|||
163 | # always use a buffer, until memoryviews get sorted out |
|
|||
164 | sa.data = buffer(m) |
|
|||
165 | # disable memoryview support |
|
|||
166 | # if copy: |
|
|||
167 | # sa.data = buffer(m) |
|
|||
168 | # else: |
|
|||
169 | # sa.data = m.buffer |
|
|||
170 | else: |
|
|||
171 | if copy: |
|
|||
172 | sa.data = m |
|
|||
173 | else: |
|
|||
174 | sa.data = m.bytes |
|
|||
175 |
|
||||
176 | kwargs[k] = uncan(unserialize(sa), g) |
|
|||
177 |
|
178 | |||
178 | return f,args,kwargs |
|
179 | return f,args,kwargs | |
179 |
|
180 |
@@ -49,7 +49,8 b' from IPython.utils.importstring import import_item' | |||||
49 | from IPython.utils.jsonutil import extract_dates, squash_dates, date_default |
|
49 | from IPython.utils.jsonutil import extract_dates, squash_dates, date_default | |
50 | from IPython.utils.py3compat import str_to_bytes |
|
50 | from IPython.utils.py3compat import str_to_bytes | |
51 | from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set, |
|
51 | from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set, | |
52 | DottedObjectName, CUnicode, Dict) |
|
52 | DottedObjectName, CUnicode, Dict, Integer) | |
|
53 | from IPython.zmq.serialize import MAX_ITEMS, MAX_BYTES | |||
53 |
|
54 | |||
54 | #----------------------------------------------------------------------------- |
|
55 | #----------------------------------------------------------------------------- | |
55 | # utility functions |
|
56 | # utility functions | |
@@ -84,8 +85,9 b' pickle_unpacker = pickle.loads' | |||||
84 | default_packer = json_packer |
|
85 | default_packer = json_packer | |
85 | default_unpacker = json_unpacker |
|
86 | default_unpacker = json_unpacker | |
86 |
|
87 | |||
87 | DELIM=b"<IDS|MSG>" |
|
88 | DELIM = b"<IDS|MSG>" | |
88 |
|
89 | # singleton dummy tracker, which will always report as done | ||
|
90 | DONE = zmq.MessageTracker() | |||
89 |
|
91 | |||
90 | #----------------------------------------------------------------------------- |
|
92 | #----------------------------------------------------------------------------- | |
91 | # Mixin tools for apps that use Sessions |
|
93 | # Mixin tools for apps that use Sessions | |
@@ -329,7 +331,18 b' class Session(Configurable):' | |||||
329 | # unpacker is not checked - it is assumed to be |
|
331 | # unpacker is not checked - it is assumed to be | |
330 | if not callable(new): |
|
332 | if not callable(new): | |
331 | raise TypeError("unpacker must be callable, not %s"%type(new)) |
|
333 | raise TypeError("unpacker must be callable, not %s"%type(new)) | |
332 |
|
334 | |||
|
335 | # thresholds: | |||
|
336 | copy_threshold = Integer(2**16, config=True, | |||
|
337 | help="Threshold (in bytes) beyond which a buffer should be sent without copying.") | |||
|
338 | buffer_threshold = Integer(MAX_BYTES, config=True, | |||
|
339 | help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.") | |||
|
340 | item_threshold = Integer(MAX_ITEMS, config=True, | |||
|
341 | help="""The maximum number of items for a container to be introspected for custom serialization. | |||
|
342 | Containers larger than this are pickled outright. | |||
|
343 | """ | |||
|
344 | ) | |||
|
345 | ||||
333 | def __init__(self, **kwargs): |
|
346 | def __init__(self, **kwargs): | |
334 | """create a Session object |
|
347 | """create a Session object | |
335 |
|
348 | |||
@@ -544,11 +557,6 b' class Session(Configurable):' | |||||
544 | ------- |
|
557 | ------- | |
545 | msg : dict |
|
558 | msg : dict | |
546 | The constructed message. |
|
559 | The constructed message. | |
547 | (msg,tracker) : (dict, MessageTracker) |
|
|||
548 | if track=True, then a 2-tuple will be returned, |
|
|||
549 | the first element being the constructed |
|
|||
550 | message, and the second being the MessageTracker |
|
|||
551 |
|
||||
552 | """ |
|
560 | """ | |
553 |
|
561 | |||
554 | if not isinstance(stream, (zmq.Socket, ZMQStream)): |
|
562 | if not isinstance(stream, (zmq.Socket, ZMQStream)): | |
@@ -566,25 +574,18 b' class Session(Configurable):' | |||||
566 |
|
574 | |||
567 | buffers = [] if buffers is None else buffers |
|
575 | buffers = [] if buffers is None else buffers | |
568 | to_send = self.serialize(msg, ident) |
|
576 | to_send = self.serialize(msg, ident) | |
569 | flag = 0 |
|
577 | to_send.extend(buffers) | |
570 | if buffers: |
|
578 | longest = max([ len(s) for s in to_send ]) | |
571 | flag = zmq.SNDMORE |
|
579 | copy = (longest < self.copy_threshold) | |
572 | _track = False |
|
580 | ||
|
581 | if buffers and track and not copy: | |||
|
582 | # only really track when we are doing zero-copy buffers | |||
|
583 | tracker = stream.send_multipart(to_send, copy=False, track=True) | |||
573 | else: |
|
584 | else: | |
574 | _track=track |
|
585 | # use dummy tracker, which will be done immediately | |
575 |
|
|
586 | tracker = DONE | |
576 |
|
|
587 | stream.send_multipart(to_send, copy=copy) | |
577 | else: |
|
|||
578 | tracker = stream.send_multipart(to_send, flag, copy=False) |
|
|||
579 | for b in buffers[:-1]: |
|
|||
580 | stream.send(b, flag, copy=False) |
|
|||
581 | if buffers: |
|
|||
582 | if track: |
|
|||
583 | tracker = stream.send(buffers[-1], copy=False, track=track) |
|
|||
584 | else: |
|
|||
585 | tracker = stream.send(buffers[-1], copy=False) |
|
|||
586 |
|
588 | |||
587 | # omsg = Message(msg) |
|
|||
588 | if self.debug: |
|
589 | if self.debug: | |
589 | pprint.pprint(msg) |
|
590 | pprint.pprint(msg) | |
590 | pprint.pprint(to_send) |
|
591 | pprint.pprint(to_send) |
@@ -146,9 +146,10 b' class TestSession(SessionTestCase):' | |||||
146 | """test tracking messages""" |
|
146 | """test tracking messages""" | |
147 | a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) |
|
147 | a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) | |
148 | s = self.session |
|
148 | s = self.session | |
|
149 | s.copy_threshold = 1 | |||
149 | stream = ZMQStream(a) |
|
150 | stream = ZMQStream(a) | |
150 | msg = s.send(a, 'hello', track=False) |
|
151 | msg = s.send(a, 'hello', track=False) | |
151 |
self.assertTrue(msg['tracker'] is |
|
152 | self.assertTrue(msg['tracker'] is ss.DONE) | |
152 | msg = s.send(a, 'hello', track=True) |
|
153 | msg = s.send(a, 'hello', track=True) | |
153 | self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker)) |
|
154 | self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker)) | |
154 | M = zmq.Message(b'hi there', track=True) |
|
155 | M = zmq.Message(b'hi there', track=True) |
General Comments 0
You need to be logged in to leave comments.
Login now