##// END OF EJS Templates
Fix NaN warnings and failures in test_serialize...
MinRK -
Show More
@@ -1,229 +1,208 b''
1 1 """test serialization tools"""
2 2
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
5 #
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
9
10 #-------------------------------------------------------------------------------
11 # Imports
12 #-------------------------------------------------------------------------------
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
13 5
14 6 import pickle
15 7 from collections import namedtuple
16 8
17 9 import nose.tools as nt
18 10
19 11 # from unittest import TestCaes
20 12 from IPython.kernel.zmq.serialize import serialize_object, unserialize_object
21 13 from IPython.testing import decorators as dec
22 14 from IPython.utils.pickleutil import CannedArray, CannedClass
23 15 from IPython.utils.py3compat import iteritems
24 16 from IPython.parallel import interactive
25 17
26 18 #-------------------------------------------------------------------------------
27 19 # Globals and Utilities
28 20 #-------------------------------------------------------------------------------
29 21
30 22 def roundtrip(obj):
31 23 """roundtrip an object through serialization"""
32 24 bufs = serialize_object(obj)
33 25 obj2, remainder = unserialize_object(bufs)
34 26 nt.assert_equals(remainder, [])
35 27 return obj2
36 28
37 29 class C(object):
38 30 """dummy class for """
39 31
40 32 def __init__(self, **kwargs):
41 33 for key,value in iteritems(kwargs):
42 34 setattr(self, key, value)
43 35
44 36 SHAPES = ((100,), (1024,10), (10,8,6,5), (), (0,))
45 37 DTYPES = ('uint8', 'float64', 'int32', [('g', 'float32')], '|S10')
38
46 39 #-------------------------------------------------------------------------------
47 40 # Tests
48 41 #-------------------------------------------------------------------------------
49 42
43 def new_array(shape, dtype):
44 import numpy
45 return numpy.random.random(shape).astype(dtype)
46
50 47 def test_roundtrip_simple():
51 48 for obj in [
52 49 'hello',
53 50 dict(a='b', b=10),
54 51 [1,2,'hi'],
55 52 (b'123', 'hello'),
56 53 ]:
57 54 obj2 = roundtrip(obj)
58 55 nt.assert_equal(obj, obj2)
59 56
60 57 def test_roundtrip_nested():
61 58 for obj in [
62 59 dict(a=range(5), b={1:b'hello'}),
63 60 [range(5),[range(3),(1,[b'whoda'])]],
64 61 ]:
65 62 obj2 = roundtrip(obj)
66 63 nt.assert_equal(obj, obj2)
67 64
68 65 def test_roundtrip_buffered():
69 66 for obj in [
70 67 dict(a=b"x"*1025),
71 68 b"hello"*500,
72 69 [b"hello"*501, 1,2,3]
73 70 ]:
74 71 bufs = serialize_object(obj)
75 72 nt.assert_equal(len(bufs), 2)
76 73 obj2, remainder = unserialize_object(bufs)
77 74 nt.assert_equal(remainder, [])
78 75 nt.assert_equal(obj, obj2)
79 76
80 def _scrub_nan(A):
81 """scrub nans out of empty arrays
82
83 since nan != nan
84 """
85 import numpy
86 if A.dtype.fields and A.shape:
87 for field in A.dtype.fields.keys():
88 try:
89 A[field][numpy.isnan(A[field])] = 0
90 except (TypeError, NotImplementedError):
91 # e.g. str dtype
92 pass
93
94 77 @dec.skip_without('numpy')
95 78 def test_numpy():
96 79 import numpy
97 80 from numpy.testing.utils import assert_array_equal
98 81 for shape in SHAPES:
99 82 for dtype in DTYPES:
100 A = numpy.empty(shape, dtype=dtype)
101 _scrub_nan(A)
83 A = new_array(shape, dtype=dtype)
102 84 bufs = serialize_object(A)
103 85 B, r = unserialize_object(bufs)
104 86 nt.assert_equal(r, [])
105 87 nt.assert_equal(A.shape, B.shape)
106 88 nt.assert_equal(A.dtype, B.dtype)
107 89 assert_array_equal(A,B)
108 90
109 91 @dec.skip_without('numpy')
110 92 def test_recarray():
111 93 import numpy
112 94 from numpy.testing.utils import assert_array_equal
113 95 for shape in SHAPES:
114 96 for dtype in [
115 97 [('f', float), ('s', '|S10')],
116 98 [('n', int), ('s', '|S1'), ('u', 'uint32')],
117 99 ]:
118 A = numpy.empty(shape, dtype=dtype)
119 _scrub_nan(A)
100 A = new_array(shape, dtype=dtype)
120 101
121 102 bufs = serialize_object(A)
122 103 B, r = unserialize_object(bufs)
123 104 nt.assert_equal(r, [])
124 105 nt.assert_equal(A.shape, B.shape)
125 106 nt.assert_equal(A.dtype, B.dtype)
126 107 assert_array_equal(A,B)
127 108
128 109 @dec.skip_without('numpy')
129 110 def test_numpy_in_seq():
130 111 import numpy
131 112 from numpy.testing.utils import assert_array_equal
132 113 for shape in SHAPES:
133 114 for dtype in DTYPES:
134 A = numpy.empty(shape, dtype=dtype)
135 _scrub_nan(A)
115 A = new_array(shape, dtype=dtype)
136 116 bufs = serialize_object((A,1,2,b'hello'))
137 117 canned = pickle.loads(bufs[0])
138 118 nt.assert_is_instance(canned[0], CannedArray)
139 119 tup, r = unserialize_object(bufs)
140 120 B = tup[0]
141 121 nt.assert_equal(r, [])
142 122 nt.assert_equal(A.shape, B.shape)
143 123 nt.assert_equal(A.dtype, B.dtype)
144 124 assert_array_equal(A,B)
145 125
146 126 @dec.skip_without('numpy')
147 127 def test_numpy_in_dict():
148 128 import numpy
149 129 from numpy.testing.utils import assert_array_equal
150 130 for shape in SHAPES:
151 131 for dtype in DTYPES:
152 A = numpy.empty(shape, dtype=dtype)
153 _scrub_nan(A)
132 A = new_array(shape, dtype=dtype)
154 133 bufs = serialize_object(dict(a=A,b=1,c=range(20)))
155 134 canned = pickle.loads(bufs[0])
156 135 nt.assert_is_instance(canned['a'], CannedArray)
157 136 d, r = unserialize_object(bufs)
158 137 B = d['a']
159 138 nt.assert_equal(r, [])
160 139 nt.assert_equal(A.shape, B.shape)
161 140 nt.assert_equal(A.dtype, B.dtype)
162 141 assert_array_equal(A,B)
163 142
164 143 def test_class():
165 144 @interactive
166 145 class C(object):
167 146 a=5
168 147 bufs = serialize_object(dict(C=C))
169 148 canned = pickle.loads(bufs[0])
170 149 nt.assert_is_instance(canned['C'], CannedClass)
171 150 d, r = unserialize_object(bufs)
172 151 C2 = d['C']
173 152 nt.assert_equal(C2.a, C.a)
174 153
175 154 def test_class_oldstyle():
176 155 @interactive
177 156 class C:
178 157 a=5
179 158
180 159 bufs = serialize_object(dict(C=C))
181 160 canned = pickle.loads(bufs[0])
182 161 nt.assert_is_instance(canned['C'], CannedClass)
183 162 d, r = unserialize_object(bufs)
184 163 C2 = d['C']
185 164 nt.assert_equal(C2.a, C.a)
186 165
187 166 def test_tuple():
188 167 tup = (lambda x:x, 1)
189 168 bufs = serialize_object(tup)
190 169 canned = pickle.loads(bufs[0])
191 170 nt.assert_is_instance(canned, tuple)
192 171 t2, r = unserialize_object(bufs)
193 172 nt.assert_equal(t2[0](t2[1]), tup[0](tup[1]))
194 173
195 174 point = namedtuple('point', 'x y')
196 175
197 176 def test_namedtuple():
198 177 p = point(1,2)
199 178 bufs = serialize_object(p)
200 179 canned = pickle.loads(bufs[0])
201 180 nt.assert_is_instance(canned, point)
202 181 p2, r = unserialize_object(bufs, globals())
203 182 nt.assert_equal(p2.x, p.x)
204 183 nt.assert_equal(p2.y, p.y)
205 184
206 185 def test_list():
207 186 lis = [lambda x:x, 1]
208 187 bufs = serialize_object(lis)
209 188 canned = pickle.loads(bufs[0])
210 189 nt.assert_is_instance(canned, list)
211 190 l2, r = unserialize_object(bufs)
212 191 nt.assert_equal(l2[0](l2[1]), lis[0](lis[1]))
213 192
214 193 def test_class_inheritance():
215 194 @interactive
216 195 class C(object):
217 196 a=5
218 197
219 198 @interactive
220 199 class D(C):
221 200 b=10
222 201
223 202 bufs = serialize_object(dict(D=D))
224 203 canned = pickle.loads(bufs[0])
225 204 nt.assert_is_instance(canned['D'], CannedClass)
226 205 d, r = unserialize_object(bufs)
227 206 D2 = d['D']
228 207 nt.assert_equal(D2.a, D.a)
229 208 nt.assert_equal(D2.b, D.b)
General Comments 0
You need to be logged in to leave comments. Login now