diff --git a/IPython/zmq/tests/test_serialize.py b/IPython/zmq/tests/test_serialize.py
index c1e0429..2a1a33a 100644
--- a/IPython/zmq/tests/test_serialize.py
+++ b/IPython/zmq/tests/test_serialize.py
@@ -20,6 +20,10 @@ from IPython.zmq.serialize import serialize_object, unserialize_object
 from IPython.testing import decorators as dec
 from IPython.utils.pickleutil import CannedArray
 
+#-------------------------------------------------------------------------------
+# Globals and Utilities
+#-------------------------------------------------------------------------------
+
 def roundtrip(obj):
     """roundtrip an object through serialization"""
     bufs = serialize_object(obj)
@@ -34,6 +38,12 @@ class C(object):
         for key,value in kwargs.iteritems():
             setattr(self, key, value)
 
+SHAPES = ((100,), (1024,10), (10,8,6,5), (), (0,))
+DTYPES = ('uint8', 'float64', 'int32', [('g', 'float32')], '|S10')
+#-------------------------------------------------------------------------------
+# Tests
+#-------------------------------------------------------------------------------
+
 @dec.parametric
 def test_roundtrip_simple():
     for obj in [
@@ -67,17 +77,54 @@ def test_roundtrip_buffered():
         yield nt.assert_equals(remainder, [])
         yield nt.assert_equals(obj, obj2)
 
+def _scrub_nan(A):
+    """scrub nans out of empty arrays
+
+    since nan != nan
+    """
+    import numpy
+    if A.dtype.fields and A.shape:
+        for field in A.dtype.fields.keys():
+            try:
+                A[field][numpy.isnan(A[field])] = 0
+            except TypeError:
+                # e.g. str dtype
+                pass
+
 @dec.parametric
 @dec.skip_without('numpy')
 def test_numpy():
     import numpy
     from numpy.testing.utils import assert_array_equal
-    for shape in ((), (0,), (100,), (1024,10), (10,8,6,5)):
-        for dtype in ('uint8', 'float64', 'int32', [('int16', 'float32')]):
+    for shape in SHAPES:
+        for dtype in DTYPES:
+            A = numpy.empty(shape, dtype=dtype)
+            _scrub_nan(A)
+            bufs = serialize_object(A)
+            B, r = unserialize_object(bufs)
+            yield nt.assert_equals(r, [])
+            yield nt.assert_equals(A.shape, B.shape)
+            yield nt.assert_equals(A.dtype, B.dtype)
+            yield assert_array_equal(A,B)
+
+@dec.parametric
+@dec.skip_without('numpy')
+def test_recarray():
+    import numpy
+    from numpy.testing.utils import assert_array_equal
+    for shape in SHAPES:
+        for dtype in [
+            [('f', float), ('s', '|S10')],
+            [('n', int), ('s', '|S1'), ('u', 'uint32')],
+        ]:
             A = numpy.empty(shape, dtype=dtype)
+            _scrub_nan(A)
+
             bufs = serialize_object(A)
             B, r = unserialize_object(bufs)
             yield nt.assert_equals(r, [])
+            yield nt.assert_equals(A.shape, B.shape)
+            yield nt.assert_equals(A.dtype, B.dtype)
             yield assert_array_equal(A,B)
 
 @dec.parametric
@@ -85,15 +132,18 @@ def test_numpy():
 def test_numpy_in_seq():
     import numpy
     from numpy.testing.utils import assert_array_equal
-    for shape in ((), (0,), (100,), (1024,10), (10,8,6,5)):
-        for dtype in ('uint8', 'float64', 'int32', [('int16', 'float32')]):
+    for shape in SHAPES:
+        for dtype in DTYPES:
             A = numpy.empty(shape, dtype=dtype)
+            _scrub_nan(A)
             bufs = serialize_object((A,1,2,b'hello'))
             canned = pickle.loads(bufs[0])
             yield nt.assert_true(canned[0], CannedArray)
             tup, r = unserialize_object(bufs)
             B = tup[0]
             yield nt.assert_equals(r, [])
+            yield nt.assert_equals(A.shape, B.shape)
+            yield nt.assert_equals(A.dtype, B.dtype)
             yield assert_array_equal(A,B)
 
 @dec.parametric
@@ -101,15 +151,18 @@ def test_numpy_in_seq():
 def test_numpy_in_dict():
     import numpy
     from numpy.testing.utils import assert_array_equal
-    for shape in ((), (0,), (100,), (1024,10), (10,8,6,5)):
-        for dtype in ('uint8', 'float64', 'int32', [('int16', 'float32')]):
+    for shape in SHAPES:
+        for dtype in DTYPES:
             A = numpy.empty(shape, dtype=dtype)
+            _scrub_nan(A)
             bufs = serialize_object(dict(a=A,b=1,c=range(20)))
             canned = pickle.loads(bufs[0])
             yield nt.assert_true(canned['a'], CannedArray)
             d, r = unserialize_object(bufs)
             B = d['a']
             yield nt.assert_equals(r, [])
+            yield nt.assert_equals(A.shape, B.shape)
+            yield nt.assert_equals(A.dtype, B.dtype)
             yield assert_array_equal(A,B)
 
 