##// END OF EJS Templates
Merge pull request #6071 from minrk/serialize-nan...
Thomas Kluyver -
r17149:639fc3ac merge
parent child Browse files
Show More
@@ -1,15 +1,7 b''
1 """test serialization tools"""
1 """test serialization tools"""
2
2
3 #-------------------------------------------------------------------------------
3 # Copyright (c) IPython Development Team.
4 # Copyright (C) 2011 The IPython Development Team
4 # Distributed under the terms of the Modified BSD License.
5 #
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
9
10 #-------------------------------------------------------------------------------
11 # Imports
12 #-------------------------------------------------------------------------------
13
5
14 import pickle
6 import pickle
15 from collections import namedtuple
7 from collections import namedtuple
@@ -43,10 +35,15 b' class C(object):'
43
35
44 SHAPES = ((100,), (1024,10), (10,8,6,5), (), (0,))
36 SHAPES = ((100,), (1024,10), (10,8,6,5), (), (0,))
45 DTYPES = ('uint8', 'float64', 'int32', [('g', 'float32')], '|S10')
37 DTYPES = ('uint8', 'float64', 'int32', [('g', 'float32')], '|S10')
38
46 #-------------------------------------------------------------------------------
39 #-------------------------------------------------------------------------------
47 # Tests
40 # Tests
48 #-------------------------------------------------------------------------------
41 #-------------------------------------------------------------------------------
49
42
43 def new_array(shape, dtype):
44 import numpy
45 return numpy.random.random(shape).astype(dtype)
46
50 def test_roundtrip_simple():
47 def test_roundtrip_simple():
51 for obj in [
48 for obj in [
52 'hello',
49 'hello',
@@ -77,28 +74,13 b' def test_roundtrip_buffered():'
77 nt.assert_equal(remainder, [])
74 nt.assert_equal(remainder, [])
78 nt.assert_equal(obj, obj2)
75 nt.assert_equal(obj, obj2)
79
76
80 def _scrub_nan(A):
81 """scrub nans out of empty arrays
82
83 since nan != nan
84 """
85 import numpy
86 if A.dtype.fields and A.shape:
87 for field in A.dtype.fields.keys():
88 try:
89 A[field][numpy.isnan(A[field])] = 0
90 except (TypeError, NotImplementedError):
91 # e.g. str dtype
92 pass
93
94 @dec.skip_without('numpy')
77 @dec.skip_without('numpy')
95 def test_numpy():
78 def test_numpy():
96 import numpy
79 import numpy
97 from numpy.testing.utils import assert_array_equal
80 from numpy.testing.utils import assert_array_equal
98 for shape in SHAPES:
81 for shape in SHAPES:
99 for dtype in DTYPES:
82 for dtype in DTYPES:
100 A = numpy.empty(shape, dtype=dtype)
83 A = new_array(shape, dtype=dtype)
101 _scrub_nan(A)
102 bufs = serialize_object(A)
84 bufs = serialize_object(A)
103 B, r = unserialize_object(bufs)
85 B, r = unserialize_object(bufs)
104 nt.assert_equal(r, [])
86 nt.assert_equal(r, [])
@@ -115,8 +97,7 b' def test_recarray():'
115 [('f', float), ('s', '|S10')],
97 [('f', float), ('s', '|S10')],
116 [('n', int), ('s', '|S1'), ('u', 'uint32')],
98 [('n', int), ('s', '|S1'), ('u', 'uint32')],
117 ]:
99 ]:
118 A = numpy.empty(shape, dtype=dtype)
100 A = new_array(shape, dtype=dtype)
119 _scrub_nan(A)
120
101
121 bufs = serialize_object(A)
102 bufs = serialize_object(A)
122 B, r = unserialize_object(bufs)
103 B, r = unserialize_object(bufs)
@@ -131,8 +112,7 b' def test_numpy_in_seq():'
131 from numpy.testing.utils import assert_array_equal
112 from numpy.testing.utils import assert_array_equal
132 for shape in SHAPES:
113 for shape in SHAPES:
133 for dtype in DTYPES:
114 for dtype in DTYPES:
134 A = numpy.empty(shape, dtype=dtype)
115 A = new_array(shape, dtype=dtype)
135 _scrub_nan(A)
136 bufs = serialize_object((A,1,2,b'hello'))
116 bufs = serialize_object((A,1,2,b'hello'))
137 canned = pickle.loads(bufs[0])
117 canned = pickle.loads(bufs[0])
138 nt.assert_is_instance(canned[0], CannedArray)
118 nt.assert_is_instance(canned[0], CannedArray)
@@ -149,8 +129,7 b' def test_numpy_in_dict():'
149 from numpy.testing.utils import assert_array_equal
129 from numpy.testing.utils import assert_array_equal
150 for shape in SHAPES:
130 for shape in SHAPES:
151 for dtype in DTYPES:
131 for dtype in DTYPES:
152 A = numpy.empty(shape, dtype=dtype)
132 A = new_array(shape, dtype=dtype)
153 _scrub_nan(A)
154 bufs = serialize_object(dict(a=A,b=1,c=range(20)))
133 bufs = serialize_object(dict(a=A,b=1,c=range(20)))
155 canned = pickle.loads(bufs[0])
134 canned = pickle.loads(bufs[0])
156 nt.assert_is_instance(canned['a'], CannedArray)
135 nt.assert_is_instance(canned['a'], CannedArray)
General Comments 0
You need to be logged in to leave comments. Login now