##// END OF EJS Templates
don't special case for py3k+numpy...
MinRK -
Show More
@@ -16,6 +16,8 b' Authors:'
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
20
19 from unittest import TestCase
21 from unittest import TestCase
20
22
21 from IPython.testing.decorators import parametric
23 from IPython.testing.decorators import parametric
@@ -23,6 +25,8 b' from IPython.utils import newserialized as ns'
23 from IPython.utils.pickleutil import can, uncan, CannedObject, CannedFunction
25 from IPython.utils.pickleutil import can, uncan, CannedObject, CannedFunction
24 from IPython.parallel.tests.clienttest import skip_without
26 from IPython.parallel.tests.clienttest import skip_without
25
27
28 if sys.version_info[0] >= 3:
29 buffer = memoryview
26
30
27 class CanningTestCase(TestCase):
31 class CanningTestCase(TestCase):
28 def test_canning(self):
32 def test_canning(self):
@@ -88,10 +92,10 b' class CanningTestCase(TestCase):'
88 self.assertEquals(md['shape'], a.shape)
92 self.assertEquals(md['shape'], a.shape)
89 self.assertEquals(md['dtype'], a.dtype.str)
93 self.assertEquals(md['dtype'], a.dtype.str)
90 buff = ser1.getData()
94 buff = ser1.getData()
91 self.assertEquals(buff, numpy.getbuffer(a))
95 self.assertEquals(buff, buffer(a))
92 s = ns.Serialized(buff, td, md)
96 s = ns.Serialized(buff, td, md)
93 final = ns.unserialize(s)
97 final = ns.unserialize(s)
94 self.assertEquals(numpy.getbuffer(a), numpy.getbuffer(final))
98 self.assertEquals(buffer(a), buffer(final))
95 self.assertTrue((a==final).all())
99 self.assertTrue((a==final).all())
96 self.assertEquals(a.dtype.str, final.dtype.str)
100 self.assertEquals(a.dtype.str, final.dtype.str)
97 self.assertEquals(a.shape, final.shape)
101 self.assertEquals(a.shape, final.shape)
@@ -35,6 +35,8 b' if sys.version_info[0] >= 3:'
35 py3k = True
35 py3k = True
36 else:
36 else:
37 py3k = False
37 py3k = False
38 if sys.version_info[:2] <= (2,6):
39 memoryview = buffer
38
40
39 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
40 # Classes and functions
42 # Classes and functions
@@ -101,10 +103,7 b' class SerializeIt(object):'
101 self.data = None
103 self.data = None
102 self.obj = unSerialized.getObject()
104 self.obj = unSerialized.getObject()
103 if numpy is not None and isinstance(self.obj, numpy.ndarray):
105 if numpy is not None and isinstance(self.obj, numpy.ndarray):
104 if py3k or len(self.obj.shape) == 0: # length 0 arrays are just pickled
106 if len(self.obj.shape) == 0: # length 0 arrays are just pickled
105 # FIXME:
106 # also use pickle for numpy arrays on py3k, since
107 # pyzmq doesn't rebuild from memoryviews properly
108 self.typeDescriptor = 'pickle'
107 self.typeDescriptor = 'pickle'
109 self.metadata = {}
108 self.metadata = {}
110 else:
109 else:
@@ -125,7 +124,7 b' class SerializeIt(object):'
125
124
126 def _generateData(self):
125 def _generateData(self):
127 if self.typeDescriptor == 'ndarray':
126 if self.typeDescriptor == 'ndarray':
128 self.data = numpy.getbuffer(self.obj)
127 self.data = buffer(self.obj)
129 elif self.typeDescriptor in ('bytes', 'buffer'):
128 elif self.typeDescriptor in ('bytes', 'buffer'):
130 self.data = self.obj
129 self.data = self.obj
131 elif self.typeDescriptor == 'pickle':
130 elif self.typeDescriptor == 'pickle':
@@ -158,11 +157,10 b' class UnSerializeIt(UnSerialized):'
158 typeDescriptor = self.serialized.getTypeDescriptor()
157 typeDescriptor = self.serialized.getTypeDescriptor()
159 if numpy is not None and typeDescriptor == 'ndarray':
158 if numpy is not None and typeDescriptor == 'ndarray':
160 buf = self.serialized.getData()
159 buf = self.serialized.getData()
161 if isinstance(buf, (bytes, buffer)):
160 if isinstance(buf, (bytes, buffer, memoryview)):
162 result = numpy.frombuffer(buf, dtype = self.serialized.metadata['dtype'])
161 result = numpy.frombuffer(buf, dtype = self.serialized.metadata['dtype'])
163 else:
162 else:
164 # memoryview
163 raise TypeError("Expected bytes or buffer/memoryview, but got %r"%type(buf))
165 result = numpy.array(buf, dtype = self.serialized.metadata['dtype'])
166 result.shape = self.serialized.metadata['shape']
164 result.shape = self.serialized.metadata['shape']
167 elif typeDescriptor == 'pickle':
165 elif typeDescriptor == 'pickle':
168 result = pickle.loads(self.serialized.getData())
166 result = pickle.loads(self.serialized.getData())
General Comments 0
You need to be logged in to leave comments. Login now