##// END OF EJS Templates
Merge pull request #2034 from minrk/recarray...
Fernando Perez -
r7702:feac4025 merge
parent child Browse files
Show More
@@ -1,117 +1,117 b''
1 """test serialization with newserialized
1 """test serialization with newserialized
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20
20
21 from unittest import TestCase
21 from unittest import TestCase
22
22
23 from IPython.testing.decorators import parametric
23 from IPython.testing.decorators import parametric
24 from IPython.utils import newserialized as ns
24 from IPython.utils import newserialized as ns
25 from IPython.utils.pickleutil import can, uncan, CannedObject, CannedFunction
25 from IPython.utils.pickleutil import can, uncan, CannedObject, CannedFunction
26 from IPython.parallel.tests.clienttest import skip_without
26 from IPython.parallel.tests.clienttest import skip_without
27
27
28 if sys.version_info[0] >= 3:
28 if sys.version_info[0] >= 3:
29 buffer = memoryview
29 buffer = memoryview
30
30
31 class CanningTestCase(TestCase):
31 class CanningTestCase(TestCase):
32 def test_canning(self):
32 def test_canning(self):
33 d = dict(a=5,b=6)
33 d = dict(a=5,b=6)
34 cd = can(d)
34 cd = can(d)
35 self.assertTrue(isinstance(cd, dict))
35 self.assertTrue(isinstance(cd, dict))
36
36
37 def test_canned_function(self):
37 def test_canned_function(self):
38 f = lambda : 7
38 f = lambda : 7
39 cf = can(f)
39 cf = can(f)
40 self.assertTrue(isinstance(cf, CannedFunction))
40 self.assertTrue(isinstance(cf, CannedFunction))
41
41
42 @parametric
42 @parametric
43 def test_can_roundtrip(cls):
43 def test_can_roundtrip(cls):
44 objs = [
44 objs = [
45 dict(),
45 dict(),
46 set(),
46 set(),
47 list(),
47 list(),
48 ['a',1,['a',1],u'e'],
48 ['a',1,['a',1],u'e'],
49 ]
49 ]
50 return map(cls.run_roundtrip, objs)
50 return map(cls.run_roundtrip, objs)
51
51
52 @classmethod
52 @classmethod
53 def run_roundtrip(self, obj):
53 def run_roundtrip(self, obj):
54 o = uncan(can(obj))
54 o = uncan(can(obj))
55 assert o == obj, "failed assertion: %r == %r"%(o,obj)
55 assert o == obj, "failed assertion: %r == %r"%(o,obj)
56
56
57 def test_serialized_interfaces(self):
57 def test_serialized_interfaces(self):
58
58
59 us = {'a':10, 'b':range(10)}
59 us = {'a':10, 'b':range(10)}
60 s = ns.serialize(us)
60 s = ns.serialize(us)
61 uus = ns.unserialize(s)
61 uus = ns.unserialize(s)
62 self.assertTrue(isinstance(s, ns.SerializeIt))
62 self.assertTrue(isinstance(s, ns.SerializeIt))
63 self.assertEquals(uus, us)
63 self.assertEquals(uus, us)
64
64
65 def test_pickle_serialized(self):
65 def test_pickle_serialized(self):
66 obj = {'a':1.45345, 'b':'asdfsdf', 'c':10000L}
66 obj = {'a':1.45345, 'b':'asdfsdf', 'c':10000L}
67 original = ns.UnSerialized(obj)
67 original = ns.UnSerialized(obj)
68 originalSer = ns.SerializeIt(original)
68 originalSer = ns.SerializeIt(original)
69 firstData = originalSer.getData()
69 firstData = originalSer.getData()
70 firstTD = originalSer.getTypeDescriptor()
70 firstTD = originalSer.getTypeDescriptor()
71 firstMD = originalSer.getMetadata()
71 firstMD = originalSer.getMetadata()
72 self.assertEquals(firstTD, 'pickle')
72 self.assertEquals(firstTD, 'pickle')
73 self.assertEquals(firstMD, {})
73 self.assertEquals(firstMD, {})
74 unSerialized = ns.UnSerializeIt(originalSer)
74 unSerialized = ns.UnSerializeIt(originalSer)
75 secondObj = unSerialized.getObject()
75 secondObj = unSerialized.getObject()
76 for k, v in secondObj.iteritems():
76 for k, v in secondObj.iteritems():
77 self.assertEquals(obj[k], v)
77 self.assertEquals(obj[k], v)
78 secondSer = ns.SerializeIt(ns.UnSerialized(secondObj))
78 secondSer = ns.SerializeIt(ns.UnSerialized(secondObj))
79 self.assertEquals(firstData, secondSer.getData())
79 self.assertEquals(firstData, secondSer.getData())
80 self.assertEquals(firstTD, secondSer.getTypeDescriptor() )
80 self.assertEquals(firstTD, secondSer.getTypeDescriptor() )
81 self.assertEquals(firstMD, secondSer.getMetadata())
81 self.assertEquals(firstMD, secondSer.getMetadata())
82
82
83 @skip_without('numpy')
83 @skip_without('numpy')
84 def test_ndarray_serialized(self):
84 def test_ndarray_serialized(self):
85 import numpy
85 import numpy
86 a = numpy.linspace(0.0, 1.0, 1000)
86 a = numpy.linspace(0.0, 1.0, 1000)
87 unSer1 = ns.UnSerialized(a)
87 unSer1 = ns.UnSerialized(a)
88 ser1 = ns.SerializeIt(unSer1)
88 ser1 = ns.SerializeIt(unSer1)
89 td = ser1.getTypeDescriptor()
89 td = ser1.getTypeDescriptor()
90 self.assertEquals(td, 'ndarray')
90 self.assertEquals(td, 'ndarray')
91 md = ser1.getMetadata()
91 md = ser1.getMetadata()
92 self.assertEquals(md['shape'], a.shape)
92 self.assertEquals(md['shape'], a.shape)
93 self.assertEquals(md['dtype'], a.dtype.str)
93 self.assertEquals(md['dtype'], a.dtype)
94 buff = ser1.getData()
94 buff = ser1.getData()
95 self.assertEquals(buff, buffer(a))
95 self.assertEquals(buff, buffer(a))
96 s = ns.Serialized(buff, td, md)
96 s = ns.Serialized(buff, td, md)
97 final = ns.unserialize(s)
97 final = ns.unserialize(s)
98 self.assertEquals(buffer(a), buffer(final))
98 self.assertEquals(buffer(a), buffer(final))
99 self.assertTrue((a==final).all())
99 self.assertTrue((a==final).all())
100 self.assertEquals(a.dtype.str, final.dtype.str)
100 self.assertEquals(a.dtype, final.dtype)
101 self.assertEquals(a.shape, final.shape)
101 self.assertEquals(a.shape, final.shape)
102 # test non-copying:
102 # test non-copying:
103 a[2] = 1e9
103 a[2] = 1e9
104 self.assertTrue((a==final).all())
104 self.assertTrue((a==final).all())
105
105
106 def test_uncan_function_globals(self):
106 def test_uncan_function_globals(self):
107 """test that uncanning a module function restores it into its module"""
107 """test that uncanning a module function restores it into its module"""
108 from re import search
108 from re import search
109 cf = can(search)
109 cf = can(search)
110 csearch = uncan(cf)
110 csearch = uncan(cf)
111 self.assertEqual(csearch.__module__, search.__module__)
111 self.assertEqual(csearch.__module__, search.__module__)
112 self.assertNotEqual(csearch('asd', 'asdf'), None)
112 self.assertNotEqual(csearch('asd', 'asdf'), None)
113 csearch = uncan(cf, dict(a=5))
113 csearch = uncan(cf, dict(a=5))
114 self.assertEqual(csearch.__module__, search.__module__)
114 self.assertEqual(csearch.__module__, search.__module__)
115 self.assertNotEqual(csearch('asd', 'asdf'), None)
115 self.assertNotEqual(csearch('asd', 'asdf'), None)
116
116
117 No newline at end of file
117
@@ -1,573 +1,597 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import platform
20 import platform
21 import time
21 import time
22 from tempfile import mktemp
22 from tempfile import mktemp
23 from StringIO import StringIO
23 from StringIO import StringIO
24
24
25 import zmq
25 import zmq
26 from nose import SkipTest
26 from nose import SkipTest
27
27
28 from IPython.testing import decorators as dec
28 from IPython.testing import decorators as dec
29 from IPython.testing.ipunittest import ParametricTestCase
29 from IPython.testing.ipunittest import ParametricTestCase
30
30
31 from IPython import parallel as pmod
31 from IPython import parallel as pmod
32 from IPython.parallel import error
32 from IPython.parallel import error
33 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
33 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
34 from IPython.parallel import DirectView
34 from IPython.parallel import DirectView
35 from IPython.parallel.util import interactive
35 from IPython.parallel.util import interactive
36
36
37 from IPython.parallel.tests import add_engines
37 from IPython.parallel.tests import add_engines
38
38
39 from .clienttest import ClusterTestCase, crash, wait, skip_without
39 from .clienttest import ClusterTestCase, crash, wait, skip_without
40
40
41 def setup():
41 def setup():
42 add_engines(3, total=True)
42 add_engines(3, total=True)
43
43
44 class TestView(ClusterTestCase, ParametricTestCase):
44 class TestView(ClusterTestCase, ParametricTestCase):
45
45
46 def setUp(self):
46 def setUp(self):
47 # On Win XP, wait for resource cleanup, else parallel test group fails
47 # On Win XP, wait for resource cleanup, else parallel test group fails
48 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
48 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
49 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
49 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
50 time.sleep(2)
50 time.sleep(2)
51 super(TestView, self).setUp()
51 super(TestView, self).setUp()
52
52
53 def test_z_crash_mux(self):
53 def test_z_crash_mux(self):
54 """test graceful handling of engine death (direct)"""
54 """test graceful handling of engine death (direct)"""
55 raise SkipTest("crash tests disabled, due to undesirable crash reports")
55 raise SkipTest("crash tests disabled, due to undesirable crash reports")
56 # self.add_engines(1)
56 # self.add_engines(1)
57 eid = self.client.ids[-1]
57 eid = self.client.ids[-1]
58 ar = self.client[eid].apply_async(crash)
58 ar = self.client[eid].apply_async(crash)
59 self.assertRaisesRemote(error.EngineError, ar.get, 10)
59 self.assertRaisesRemote(error.EngineError, ar.get, 10)
60 eid = ar.engine_id
60 eid = ar.engine_id
61 tic = time.time()
61 tic = time.time()
62 while eid in self.client.ids and time.time()-tic < 5:
62 while eid in self.client.ids and time.time()-tic < 5:
63 time.sleep(.01)
63 time.sleep(.01)
64 self.client.spin()
64 self.client.spin()
65 self.assertFalse(eid in self.client.ids, "Engine should have died")
65 self.assertFalse(eid in self.client.ids, "Engine should have died")
66
66
67 def test_push_pull(self):
67 def test_push_pull(self):
68 """test pushing and pulling"""
68 """test pushing and pulling"""
69 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
69 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
70 t = self.client.ids[-1]
70 t = self.client.ids[-1]
71 v = self.client[t]
71 v = self.client[t]
72 push = v.push
72 push = v.push
73 pull = v.pull
73 pull = v.pull
74 v.block=True
74 v.block=True
75 nengines = len(self.client)
75 nengines = len(self.client)
76 push({'data':data})
76 push({'data':data})
77 d = pull('data')
77 d = pull('data')
78 self.assertEquals(d, data)
78 self.assertEquals(d, data)
79 self.client[:].push({'data':data})
79 self.client[:].push({'data':data})
80 d = self.client[:].pull('data', block=True)
80 d = self.client[:].pull('data', block=True)
81 self.assertEquals(d, nengines*[data])
81 self.assertEquals(d, nengines*[data])
82 ar = push({'data':data}, block=False)
82 ar = push({'data':data}, block=False)
83 self.assertTrue(isinstance(ar, AsyncResult))
83 self.assertTrue(isinstance(ar, AsyncResult))
84 r = ar.get()
84 r = ar.get()
85 ar = self.client[:].pull('data', block=False)
85 ar = self.client[:].pull('data', block=False)
86 self.assertTrue(isinstance(ar, AsyncResult))
86 self.assertTrue(isinstance(ar, AsyncResult))
87 r = ar.get()
87 r = ar.get()
88 self.assertEquals(r, nengines*[data])
88 self.assertEquals(r, nengines*[data])
89 self.client[:].push(dict(a=10,b=20))
89 self.client[:].push(dict(a=10,b=20))
90 r = self.client[:].pull(('a','b'), block=True)
90 r = self.client[:].pull(('a','b'), block=True)
91 self.assertEquals(r, nengines*[[10,20]])
91 self.assertEquals(r, nengines*[[10,20]])
92
92
93 def test_push_pull_function(self):
93 def test_push_pull_function(self):
94 "test pushing and pulling functions"
94 "test pushing and pulling functions"
95 def testf(x):
95 def testf(x):
96 return 2.0*x
96 return 2.0*x
97
97
98 t = self.client.ids[-1]
98 t = self.client.ids[-1]
99 v = self.client[t]
99 v = self.client[t]
100 v.block=True
100 v.block=True
101 push = v.push
101 push = v.push
102 pull = v.pull
102 pull = v.pull
103 execute = v.execute
103 execute = v.execute
104 push({'testf':testf})
104 push({'testf':testf})
105 r = pull('testf')
105 r = pull('testf')
106 self.assertEqual(r(1.0), testf(1.0))
106 self.assertEqual(r(1.0), testf(1.0))
107 execute('r = testf(10)')
107 execute('r = testf(10)')
108 r = pull('r')
108 r = pull('r')
109 self.assertEquals(r, testf(10))
109 self.assertEquals(r, testf(10))
110 ar = self.client[:].push({'testf':testf}, block=False)
110 ar = self.client[:].push({'testf':testf}, block=False)
111 ar.get()
111 ar.get()
112 ar = self.client[:].pull('testf', block=False)
112 ar = self.client[:].pull('testf', block=False)
113 rlist = ar.get()
113 rlist = ar.get()
114 for r in rlist:
114 for r in rlist:
115 self.assertEqual(r(1.0), testf(1.0))
115 self.assertEqual(r(1.0), testf(1.0))
116 execute("def g(x): return x*x")
116 execute("def g(x): return x*x")
117 r = pull(('testf','g'))
117 r = pull(('testf','g'))
118 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
118 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
119
119
120 def test_push_function_globals(self):
120 def test_push_function_globals(self):
121 """test that pushed functions have access to globals"""
121 """test that pushed functions have access to globals"""
122 @interactive
122 @interactive
123 def geta():
123 def geta():
124 return a
124 return a
125 # self.add_engines(1)
125 # self.add_engines(1)
126 v = self.client[-1]
126 v = self.client[-1]
127 v.block=True
127 v.block=True
128 v['f'] = geta
128 v['f'] = geta
129 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
129 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
130 v.execute('a=5')
130 v.execute('a=5')
131 v.execute('b=f()')
131 v.execute('b=f()')
132 self.assertEquals(v['b'], 5)
132 self.assertEquals(v['b'], 5)
133
133
134 def test_push_function_defaults(self):
134 def test_push_function_defaults(self):
135 """test that pushed functions preserve default args"""
135 """test that pushed functions preserve default args"""
136 def echo(a=10):
136 def echo(a=10):
137 return a
137 return a
138 v = self.client[-1]
138 v = self.client[-1]
139 v.block=True
139 v.block=True
140 v['f'] = echo
140 v['f'] = echo
141 v.execute('b=f()')
141 v.execute('b=f()')
142 self.assertEquals(v['b'], 10)
142 self.assertEquals(v['b'], 10)
143
143
144 def test_get_result(self):
144 def test_get_result(self):
145 """test getting results from the Hub."""
145 """test getting results from the Hub."""
146 c = pmod.Client(profile='iptest')
146 c = pmod.Client(profile='iptest')
147 # self.add_engines(1)
147 # self.add_engines(1)
148 t = c.ids[-1]
148 t = c.ids[-1]
149 v = c[t]
149 v = c[t]
150 v2 = self.client[t]
150 v2 = self.client[t]
151 ar = v.apply_async(wait, 1)
151 ar = v.apply_async(wait, 1)
152 # give the monitor time to notice the message
152 # give the monitor time to notice the message
153 time.sleep(.25)
153 time.sleep(.25)
154 ahr = v2.get_result(ar.msg_ids)
154 ahr = v2.get_result(ar.msg_ids)
155 self.assertTrue(isinstance(ahr, AsyncHubResult))
155 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 self.assertEquals(ahr.get(), ar.get())
156 self.assertEquals(ahr.get(), ar.get())
157 ar2 = v2.get_result(ar.msg_ids)
157 ar2 = v2.get_result(ar.msg_ids)
158 self.assertFalse(isinstance(ar2, AsyncHubResult))
158 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 c.spin()
159 c.spin()
160 c.close()
160 c.close()
161
161
162 def test_run_newline(self):
162 def test_run_newline(self):
163 """test that run appends newline to files"""
163 """test that run appends newline to files"""
164 tmpfile = mktemp()
164 tmpfile = mktemp()
165 with open(tmpfile, 'w') as f:
165 with open(tmpfile, 'w') as f:
166 f.write("""def g():
166 f.write("""def g():
167 return 5
167 return 5
168 """)
168 """)
169 v = self.client[-1]
169 v = self.client[-1]
170 v.run(tmpfile, block=True)
170 v.run(tmpfile, block=True)
171 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
171 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
172
172
173 def test_apply_tracked(self):
173 def test_apply_tracked(self):
174 """test tracking for apply"""
174 """test tracking for apply"""
175 # self.add_engines(1)
175 # self.add_engines(1)
176 t = self.client.ids[-1]
176 t = self.client.ids[-1]
177 v = self.client[t]
177 v = self.client[t]
178 v.block=False
178 v.block=False
179 def echo(n=1024*1024, **kwargs):
179 def echo(n=1024*1024, **kwargs):
180 with v.temp_flags(**kwargs):
180 with v.temp_flags(**kwargs):
181 return v.apply(lambda x: x, 'x'*n)
181 return v.apply(lambda x: x, 'x'*n)
182 ar = echo(1, track=False)
182 ar = echo(1, track=False)
183 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
183 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
184 self.assertTrue(ar.sent)
184 self.assertTrue(ar.sent)
185 ar = echo(track=True)
185 ar = echo(track=True)
186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertEquals(ar.sent, ar._tracker.done)
187 self.assertEquals(ar.sent, ar._tracker.done)
188 ar._tracker.wait()
188 ar._tracker.wait()
189 self.assertTrue(ar.sent)
189 self.assertTrue(ar.sent)
190
190
191 def test_push_tracked(self):
191 def test_push_tracked(self):
192 t = self.client.ids[-1]
192 t = self.client.ids[-1]
193 ns = dict(x='x'*1024*1024)
193 ns = dict(x='x'*1024*1024)
194 v = self.client[t]
194 v = self.client[t]
195 ar = v.push(ns, block=False, track=False)
195 ar = v.push(ns, block=False, track=False)
196 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
196 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
197 self.assertTrue(ar.sent)
197 self.assertTrue(ar.sent)
198
198
199 ar = v.push(ns, block=False, track=True)
199 ar = v.push(ns, block=False, track=True)
200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 ar._tracker.wait()
201 ar._tracker.wait()
202 self.assertEquals(ar.sent, ar._tracker.done)
202 self.assertEquals(ar.sent, ar._tracker.done)
203 self.assertTrue(ar.sent)
203 self.assertTrue(ar.sent)
204 ar.get()
204 ar.get()
205
205
206 def test_scatter_tracked(self):
206 def test_scatter_tracked(self):
207 t = self.client.ids
207 t = self.client.ids
208 x='x'*1024*1024
208 x='x'*1024*1024
209 ar = self.client[t].scatter('x', x, block=False, track=False)
209 ar = self.client[t].scatter('x', x, block=False, track=False)
210 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
210 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
211 self.assertTrue(ar.sent)
211 self.assertTrue(ar.sent)
212
212
213 ar = self.client[t].scatter('x', x, block=False, track=True)
213 ar = self.client[t].scatter('x', x, block=False, track=True)
214 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
214 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
215 self.assertEquals(ar.sent, ar._tracker.done)
215 self.assertEquals(ar.sent, ar._tracker.done)
216 ar._tracker.wait()
216 ar._tracker.wait()
217 self.assertTrue(ar.sent)
217 self.assertTrue(ar.sent)
218 ar.get()
218 ar.get()
219
219
220 def test_remote_reference(self):
220 def test_remote_reference(self):
221 v = self.client[-1]
221 v = self.client[-1]
222 v['a'] = 123
222 v['a'] = 123
223 ra = pmod.Reference('a')
223 ra = pmod.Reference('a')
224 b = v.apply_sync(lambda x: x, ra)
224 b = v.apply_sync(lambda x: x, ra)
225 self.assertEquals(b, 123)
225 self.assertEquals(b, 123)
226
226
227
227
228 def test_scatter_gather(self):
228 def test_scatter_gather(self):
229 view = self.client[:]
229 view = self.client[:]
230 seq1 = range(16)
230 seq1 = range(16)
231 view.scatter('a', seq1)
231 view.scatter('a', seq1)
232 seq2 = view.gather('a', block=True)
232 seq2 = view.gather('a', block=True)
233 self.assertEquals(seq2, seq1)
233 self.assertEquals(seq2, seq1)
234 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
234 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
235
235
236 @skip_without('numpy')
236 @skip_without('numpy')
237 def test_scatter_gather_numpy(self):
237 def test_scatter_gather_numpy(self):
238 import numpy
238 import numpy
239 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
239 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
240 view = self.client[:]
240 view = self.client[:]
241 a = numpy.arange(64)
241 a = numpy.arange(64)
242 view.scatter('a', a)
242 view.scatter('a', a)
243 b = view.gather('a', block=True)
243 b = view.gather('a', block=True)
244 assert_array_equal(b, a)
244 assert_array_equal(b, a)
245
245
246 def test_scatter_gather_lazy(self):
246 def test_scatter_gather_lazy(self):
247 """scatter/gather with targets='all'"""
247 """scatter/gather with targets='all'"""
248 view = self.client.direct_view(targets='all')
248 view = self.client.direct_view(targets='all')
249 x = range(64)
249 x = range(64)
250 view.scatter('x', x)
250 view.scatter('x', x)
251 gathered = view.gather('x', block=True)
251 gathered = view.gather('x', block=True)
252 self.assertEquals(gathered, x)
252 self.assertEquals(gathered, x)
253
253
254
254
255 @dec.known_failure_py3
255 @dec.known_failure_py3
256 @skip_without('numpy')
256 @skip_without('numpy')
257 def test_push_numpy_nocopy(self):
257 def test_push_numpy_nocopy(self):
258 import numpy
258 import numpy
259 view = self.client[:]
259 view = self.client[:]
260 a = numpy.arange(64)
260 a = numpy.arange(64)
261 view['A'] = a
261 view['A'] = a
262 @interactive
262 @interactive
263 def check_writeable(x):
263 def check_writeable(x):
264 return x.flags.writeable
264 return x.flags.writeable
265
265
266 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
266 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
267 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
267 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
268
268
269 view.push(dict(B=a))
269 view.push(dict(B=a))
270 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
270 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
271 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
271 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
272
272
273 @skip_without('numpy')
273 @skip_without('numpy')
274 def test_apply_numpy(self):
274 def test_apply_numpy(self):
275 """view.apply(f, ndarray)"""
275 """view.apply(f, ndarray)"""
276 import numpy
276 import numpy
277 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
277 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
278
278
279 A = numpy.random.random((100,100))
279 A = numpy.random.random((100,100))
280 view = self.client[-1]
280 view = self.client[-1]
281 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
281 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
282 B = A.astype(dt)
282 B = A.astype(dt)
283 C = view.apply_sync(lambda x:x, B)
283 C = view.apply_sync(lambda x:x, B)
284 assert_array_equal(B,C)
284 assert_array_equal(B,C)
285
285
286 @skip_without('numpy')
287 def test_push_pull_recarray(self):
288 """push/pull recarrays"""
289 import numpy
290 from numpy.testing.utils import assert_array_equal
291
292 view = self.client[-1]
293
294 R = numpy.array([
295 (1, 'hi', 0.),
296 (2**30, 'there', 2.5),
297 (-99999, 'world', -12345.6789),
298 ], [('n', int), ('s', '|S10'), ('f', float)])
299
300 view['RR'] = R
301 R2 = view['RR']
302
303 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
304 self.assertEquals(r_dtype, R.dtype)
305 self.assertEquals(r_shape, R.shape)
306 self.assertEquals(R2.dtype, R.dtype)
307 self.assertEquals(R2.shape, R.shape)
308 assert_array_equal(R2, R)
309
286 def test_map(self):
310 def test_map(self):
287 view = self.client[:]
311 view = self.client[:]
288 def f(x):
312 def f(x):
289 return x**2
313 return x**2
290 data = range(16)
314 data = range(16)
291 r = view.map_sync(f, data)
315 r = view.map_sync(f, data)
292 self.assertEquals(r, map(f, data))
316 self.assertEquals(r, map(f, data))
293
317
294 def test_map_iterable(self):
318 def test_map_iterable(self):
295 """test map on iterables (direct)"""
319 """test map on iterables (direct)"""
296 view = self.client[:]
320 view = self.client[:]
297 # 101 is prime, so it won't be evenly distributed
321 # 101 is prime, so it won't be evenly distributed
298 arr = range(101)
322 arr = range(101)
299 # ensure it will be an iterator, even in Python 3
323 # ensure it will be an iterator, even in Python 3
300 it = iter(arr)
324 it = iter(arr)
301 r = view.map_sync(lambda x:x, arr)
325 r = view.map_sync(lambda x:x, arr)
302 self.assertEquals(r, list(arr))
326 self.assertEquals(r, list(arr))
303
327
304 def test_scatterGatherNonblocking(self):
328 def test_scatterGatherNonblocking(self):
305 data = range(16)
329 data = range(16)
306 view = self.client[:]
330 view = self.client[:]
307 view.scatter('a', data, block=False)
331 view.scatter('a', data, block=False)
308 ar = view.gather('a', block=False)
332 ar = view.gather('a', block=False)
309 self.assertEquals(ar.get(), data)
333 self.assertEquals(ar.get(), data)
310
334
311 @skip_without('numpy')
335 @skip_without('numpy')
312 def test_scatter_gather_numpy_nonblocking(self):
336 def test_scatter_gather_numpy_nonblocking(self):
313 import numpy
337 import numpy
314 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
338 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
315 a = numpy.arange(64)
339 a = numpy.arange(64)
316 view = self.client[:]
340 view = self.client[:]
317 ar = view.scatter('a', a, block=False)
341 ar = view.scatter('a', a, block=False)
318 self.assertTrue(isinstance(ar, AsyncResult))
342 self.assertTrue(isinstance(ar, AsyncResult))
319 amr = view.gather('a', block=False)
343 amr = view.gather('a', block=False)
320 self.assertTrue(isinstance(amr, AsyncMapResult))
344 self.assertTrue(isinstance(amr, AsyncMapResult))
321 assert_array_equal(amr.get(), a)
345 assert_array_equal(amr.get(), a)
322
346
323 def test_execute(self):
347 def test_execute(self):
324 view = self.client[:]
348 view = self.client[:]
325 # self.client.debug=True
349 # self.client.debug=True
326 execute = view.execute
350 execute = view.execute
327 ar = execute('c=30', block=False)
351 ar = execute('c=30', block=False)
328 self.assertTrue(isinstance(ar, AsyncResult))
352 self.assertTrue(isinstance(ar, AsyncResult))
329 ar = execute('d=[0,1,2]', block=False)
353 ar = execute('d=[0,1,2]', block=False)
330 self.client.wait(ar, 1)
354 self.client.wait(ar, 1)
331 self.assertEquals(len(ar.get()), len(self.client))
355 self.assertEquals(len(ar.get()), len(self.client))
332 for c in view['c']:
356 for c in view['c']:
333 self.assertEquals(c, 30)
357 self.assertEquals(c, 30)
334
358
335 def test_abort(self):
359 def test_abort(self):
336 view = self.client[-1]
360 view = self.client[-1]
337 ar = view.execute('import time; time.sleep(1)', block=False)
361 ar = view.execute('import time; time.sleep(1)', block=False)
338 ar2 = view.apply_async(lambda : 2)
362 ar2 = view.apply_async(lambda : 2)
339 ar3 = view.apply_async(lambda : 3)
363 ar3 = view.apply_async(lambda : 3)
340 view.abort(ar2)
364 view.abort(ar2)
341 view.abort(ar3.msg_ids)
365 view.abort(ar3.msg_ids)
342 self.assertRaises(error.TaskAborted, ar2.get)
366 self.assertRaises(error.TaskAborted, ar2.get)
343 self.assertRaises(error.TaskAborted, ar3.get)
367 self.assertRaises(error.TaskAborted, ar3.get)
344
368
345 def test_abort_all(self):
369 def test_abort_all(self):
346 """view.abort() aborts all outstanding tasks"""
370 """view.abort() aborts all outstanding tasks"""
347 view = self.client[-1]
371 view = self.client[-1]
348 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
372 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
349 view.abort()
373 view.abort()
350 view.wait(timeout=5)
374 view.wait(timeout=5)
351 for ar in ars[5:]:
375 for ar in ars[5:]:
352 self.assertRaises(error.TaskAborted, ar.get)
376 self.assertRaises(error.TaskAborted, ar.get)
353
377
354 def test_temp_flags(self):
378 def test_temp_flags(self):
355 view = self.client[-1]
379 view = self.client[-1]
356 view.block=True
380 view.block=True
357 with view.temp_flags(block=False):
381 with view.temp_flags(block=False):
358 self.assertFalse(view.block)
382 self.assertFalse(view.block)
359 self.assertTrue(view.block)
383 self.assertTrue(view.block)
360
384
361 @dec.known_failure_py3
385 @dec.known_failure_py3
362 def test_importer(self):
386 def test_importer(self):
363 view = self.client[-1]
387 view = self.client[-1]
364 view.clear(block=True)
388 view.clear(block=True)
365 with view.importer:
389 with view.importer:
366 import re
390 import re
367
391
368 @interactive
392 @interactive
369 def findall(pat, s):
393 def findall(pat, s):
370 # this globals() step isn't necessary in real code
394 # this globals() step isn't necessary in real code
371 # only to prevent a closure in the test
395 # only to prevent a closure in the test
372 re = globals()['re']
396 re = globals()['re']
373 return re.findall(pat, s)
397 return re.findall(pat, s)
374
398
375 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
399 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
376
400
377 def test_unicode_execute(self):
401 def test_unicode_execute(self):
378 """test executing unicode strings"""
402 """test executing unicode strings"""
379 v = self.client[-1]
403 v = self.client[-1]
380 v.block=True
404 v.block=True
381 if sys.version_info[0] >= 3:
405 if sys.version_info[0] >= 3:
382 code="a='é'"
406 code="a='é'"
383 else:
407 else:
384 code=u"a=u'é'"
408 code=u"a=u'é'"
385 v.execute(code)
409 v.execute(code)
386 self.assertEquals(v['a'], u'é')
410 self.assertEquals(v['a'], u'é')
387
411
388 def test_unicode_apply_result(self):
412 def test_unicode_apply_result(self):
389 """test unicode apply results"""
413 """test unicode apply results"""
390 v = self.client[-1]
414 v = self.client[-1]
391 r = v.apply_sync(lambda : u'é')
415 r = v.apply_sync(lambda : u'é')
392 self.assertEquals(r, u'é')
416 self.assertEquals(r, u'é')
393
417
394 def test_unicode_apply_arg(self):
418 def test_unicode_apply_arg(self):
395 """test passing unicode arguments to apply"""
419 """test passing unicode arguments to apply"""
396 v = self.client[-1]
420 v = self.client[-1]
397
421
398 @interactive
422 @interactive
399 def check_unicode(a, check):
423 def check_unicode(a, check):
400 assert isinstance(a, unicode), "%r is not unicode"%a
424 assert isinstance(a, unicode), "%r is not unicode"%a
401 assert isinstance(check, bytes), "%r is not bytes"%check
425 assert isinstance(check, bytes), "%r is not bytes"%check
402 assert a.encode('utf8') == check, "%s != %s"%(a,check)
426 assert a.encode('utf8') == check, "%s != %s"%(a,check)
403
427
404 for s in [ u'é', u'ßø®∫',u'asdf' ]:
428 for s in [ u'é', u'ßø®∫',u'asdf' ]:
405 try:
429 try:
406 v.apply_sync(check_unicode, s, s.encode('utf8'))
430 v.apply_sync(check_unicode, s, s.encode('utf8'))
407 except error.RemoteError as e:
431 except error.RemoteError as e:
408 if e.ename == 'AssertionError':
432 if e.ename == 'AssertionError':
409 self.fail(e.evalue)
433 self.fail(e.evalue)
410 else:
434 else:
411 raise e
435 raise e
412
436
413 def test_map_reference(self):
437 def test_map_reference(self):
414 """view.map(<Reference>, *seqs) should work"""
438 """view.map(<Reference>, *seqs) should work"""
415 v = self.client[:]
439 v = self.client[:]
416 v.scatter('n', self.client.ids, flatten=True)
440 v.scatter('n', self.client.ids, flatten=True)
417 v.execute("f = lambda x,y: x*y")
441 v.execute("f = lambda x,y: x*y")
418 rf = pmod.Reference('f')
442 rf = pmod.Reference('f')
419 nlist = list(range(10))
443 nlist = list(range(10))
420 mlist = nlist[::-1]
444 mlist = nlist[::-1]
421 expected = [ m*n for m,n in zip(mlist, nlist) ]
445 expected = [ m*n for m,n in zip(mlist, nlist) ]
422 result = v.map_sync(rf, mlist, nlist)
446 result = v.map_sync(rf, mlist, nlist)
423 self.assertEquals(result, expected)
447 self.assertEquals(result, expected)
424
448
425 def test_apply_reference(self):
449 def test_apply_reference(self):
426 """view.apply(<Reference>, *args) should work"""
450 """view.apply(<Reference>, *args) should work"""
427 v = self.client[:]
451 v = self.client[:]
428 v.scatter('n', self.client.ids, flatten=True)
452 v.scatter('n', self.client.ids, flatten=True)
429 v.execute("f = lambda x: n*x")
453 v.execute("f = lambda x: n*x")
430 rf = pmod.Reference('f')
454 rf = pmod.Reference('f')
431 result = v.apply_sync(rf, 5)
455 result = v.apply_sync(rf, 5)
432 expected = [ 5*id for id in self.client.ids ]
456 expected = [ 5*id for id in self.client.ids ]
433 self.assertEquals(result, expected)
457 self.assertEquals(result, expected)
434
458
435 def test_eval_reference(self):
459 def test_eval_reference(self):
436 v = self.client[self.client.ids[0]]
460 v = self.client[self.client.ids[0]]
437 v['g'] = range(5)
461 v['g'] = range(5)
438 rg = pmod.Reference('g[0]')
462 rg = pmod.Reference('g[0]')
439 echo = lambda x:x
463 echo = lambda x:x
440 self.assertEquals(v.apply_sync(echo, rg), 0)
464 self.assertEquals(v.apply_sync(echo, rg), 0)
441
465
442 def test_reference_nameerror(self):
466 def test_reference_nameerror(self):
443 v = self.client[self.client.ids[0]]
467 v = self.client[self.client.ids[0]]
444 r = pmod.Reference('elvis_has_left')
468 r = pmod.Reference('elvis_has_left')
445 echo = lambda x:x
469 echo = lambda x:x
446 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
470 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
447
471
448 def test_single_engine_map(self):
472 def test_single_engine_map(self):
449 e0 = self.client[self.client.ids[0]]
473 e0 = self.client[self.client.ids[0]]
450 r = range(5)
474 r = range(5)
451 check = [ -1*i for i in r ]
475 check = [ -1*i for i in r ]
452 result = e0.map_sync(lambda x: -1*x, r)
476 result = e0.map_sync(lambda x: -1*x, r)
453 self.assertEquals(result, check)
477 self.assertEquals(result, check)
454
478
455 def test_len(self):
479 def test_len(self):
456 """len(view) makes sense"""
480 """len(view) makes sense"""
457 e0 = self.client[self.client.ids[0]]
481 e0 = self.client[self.client.ids[0]]
458 yield self.assertEquals(len(e0), 1)
482 yield self.assertEquals(len(e0), 1)
459 v = self.client[:]
483 v = self.client[:]
460 yield self.assertEquals(len(v), len(self.client.ids))
484 yield self.assertEquals(len(v), len(self.client.ids))
461 v = self.client.direct_view('all')
485 v = self.client.direct_view('all')
462 yield self.assertEquals(len(v), len(self.client.ids))
486 yield self.assertEquals(len(v), len(self.client.ids))
463 v = self.client[:2]
487 v = self.client[:2]
464 yield self.assertEquals(len(v), 2)
488 yield self.assertEquals(len(v), 2)
465 v = self.client[:1]
489 v = self.client[:1]
466 yield self.assertEquals(len(v), 1)
490 yield self.assertEquals(len(v), 1)
467 v = self.client.load_balanced_view()
491 v = self.client.load_balanced_view()
468 yield self.assertEquals(len(v), len(self.client.ids))
492 yield self.assertEquals(len(v), len(self.client.ids))
469 # parametric tests seem to require manual closing?
493 # parametric tests seem to require manual closing?
470 self.client.close()
494 self.client.close()
471
495
472
496
473 # begin execute tests
497 # begin execute tests
474
498
475 def test_execute_reply(self):
499 def test_execute_reply(self):
476 e0 = self.client[self.client.ids[0]]
500 e0 = self.client[self.client.ids[0]]
477 e0.block = True
501 e0.block = True
478 ar = e0.execute("5", silent=False)
502 ar = e0.execute("5", silent=False)
479 er = ar.get()
503 er = ar.get()
480 self.assertEquals(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
504 self.assertEquals(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
481 self.assertEquals(er.pyout['data']['text/plain'], '5')
505 self.assertEquals(er.pyout['data']['text/plain'], '5')
482
506
483 def test_execute_reply_stdout(self):
507 def test_execute_reply_stdout(self):
484 e0 = self.client[self.client.ids[0]]
508 e0 = self.client[self.client.ids[0]]
485 e0.block = True
509 e0.block = True
486 ar = e0.execute("print (5)", silent=False)
510 ar = e0.execute("print (5)", silent=False)
487 er = ar.get()
511 er = ar.get()
488 self.assertEquals(er.stdout.strip(), '5')
512 self.assertEquals(er.stdout.strip(), '5')
489
513
490 def test_execute_pyout(self):
514 def test_execute_pyout(self):
491 """execute triggers pyout with silent=False"""
515 """execute triggers pyout with silent=False"""
492 view = self.client[:]
516 view = self.client[:]
493 ar = view.execute("5", silent=False, block=True)
517 ar = view.execute("5", silent=False, block=True)
494
518
495 expected = [{'text/plain' : '5'}] * len(view)
519 expected = [{'text/plain' : '5'}] * len(view)
496 mimes = [ out['data'] for out in ar.pyout ]
520 mimes = [ out['data'] for out in ar.pyout ]
497 self.assertEquals(mimes, expected)
521 self.assertEquals(mimes, expected)
498
522
499 def test_execute_silent(self):
523 def test_execute_silent(self):
500 """execute does not trigger pyout with silent=True"""
524 """execute does not trigger pyout with silent=True"""
501 view = self.client[:]
525 view = self.client[:]
502 ar = view.execute("5", block=True)
526 ar = view.execute("5", block=True)
503 expected = [None] * len(view)
527 expected = [None] * len(view)
504 self.assertEquals(ar.pyout, expected)
528 self.assertEquals(ar.pyout, expected)
505
529
506 def test_execute_magic(self):
530 def test_execute_magic(self):
507 """execute accepts IPython commands"""
531 """execute accepts IPython commands"""
508 view = self.client[:]
532 view = self.client[:]
509 view.execute("a = 5")
533 view.execute("a = 5")
510 ar = view.execute("%whos", block=True)
534 ar = view.execute("%whos", block=True)
511 # this will raise, if that failed
535 # this will raise, if that failed
512 ar.get(5)
536 ar.get(5)
513 for stdout in ar.stdout:
537 for stdout in ar.stdout:
514 lines = stdout.splitlines()
538 lines = stdout.splitlines()
515 self.assertEquals(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
539 self.assertEquals(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
516 found = False
540 found = False
517 for line in lines[2:]:
541 for line in lines[2:]:
518 split = line.split()
542 split = line.split()
519 if split == ['a', 'int', '5']:
543 if split == ['a', 'int', '5']:
520 found = True
544 found = True
521 break
545 break
522 self.assertTrue(found, "whos output wrong: %s" % stdout)
546 self.assertTrue(found, "whos output wrong: %s" % stdout)
523
547
524 def test_execute_displaypub(self):
548 def test_execute_displaypub(self):
525 """execute tracks display_pub output"""
549 """execute tracks display_pub output"""
526 view = self.client[:]
550 view = self.client[:]
527 view.execute("from IPython.core.display import *")
551 view.execute("from IPython.core.display import *")
528 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
552 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
529
553
530 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
554 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
531 for outputs in ar.outputs:
555 for outputs in ar.outputs:
532 mimes = [ out['data'] for out in outputs ]
556 mimes = [ out['data'] for out in outputs ]
533 self.assertEquals(mimes, expected)
557 self.assertEquals(mimes, expected)
534
558
535 def test_apply_displaypub(self):
559 def test_apply_displaypub(self):
536 """apply tracks display_pub output"""
560 """apply tracks display_pub output"""
537 view = self.client[:]
561 view = self.client[:]
538 view.execute("from IPython.core.display import *")
562 view.execute("from IPython.core.display import *")
539
563
540 @interactive
564 @interactive
541 def publish():
565 def publish():
542 [ display(i) for i in range(5) ]
566 [ display(i) for i in range(5) ]
543
567
544 ar = view.apply_async(publish)
568 ar = view.apply_async(publish)
545 ar.get(5)
569 ar.get(5)
546 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
570 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
547 for outputs in ar.outputs:
571 for outputs in ar.outputs:
548 mimes = [ out['data'] for out in outputs ]
572 mimes = [ out['data'] for out in outputs ]
549 self.assertEquals(mimes, expected)
573 self.assertEquals(mimes, expected)
550
574
551 def test_execute_raises(self):
575 def test_execute_raises(self):
552 """exceptions in execute requests raise appropriately"""
576 """exceptions in execute requests raise appropriately"""
553 view = self.client[-1]
577 view = self.client[-1]
554 ar = view.execute("1/0")
578 ar = view.execute("1/0")
555 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
579 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
556
580
557 @dec.skipif_not_matplotlib
581 @dec.skipif_not_matplotlib
558 def test_magic_pylab(self):
582 def test_magic_pylab(self):
559 """%pylab works on engines"""
583 """%pylab works on engines"""
560 view = self.client[-1]
584 view = self.client[-1]
561 ar = view.execute("%pylab inline")
585 ar = view.execute("%pylab inline")
562 # at least check if this raised:
586 # at least check if this raised:
563 reply = ar.get(5)
587 reply = ar.get(5)
564 # include imports, in case user config
588 # include imports, in case user config
565 ar = view.execute("plot(rand(100))", silent=False)
589 ar = view.execute("plot(rand(100))", silent=False)
566 reply = ar.get(5)
590 reply = ar.get(5)
567 self.assertEquals(len(reply.outputs), 1)
591 self.assertEquals(len(reply.outputs), 1)
568 output = reply.outputs[0]
592 output = reply.outputs[0]
569 self.assertTrue("data" in output)
593 self.assertTrue("data" in output)
570 data = output['data']
594 data = output['data']
571 self.assertTrue("image/png" in data)
595 self.assertTrue("image/png" in data)
572
596
573
597
@@ -1,177 +1,177 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 # -*- test-case-name: IPython.kernel.test.test_newserialized -*-
2 # -*- test-case-name: IPython.kernel.test.test_newserialized -*-
3
3
4 """Refactored serialization classes and interfaces."""
4 """Refactored serialization classes and interfaces."""
5
5
6 __docformat__ = "restructuredtext en"
6 __docformat__ = "restructuredtext en"
7
7
8 # Tell nose to skip this module
8 # Tell nose to skip this module
9 __test__ = {}
9 __test__ = {}
10
10
11 #-------------------------------------------------------------------------------
11 #-------------------------------------------------------------------------------
12 # Copyright (C) 2008-2011 The IPython Development Team
12 # Copyright (C) 2008-2011 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-------------------------------------------------------------------------------
16 #-------------------------------------------------------------------------------
17
17
18 #-------------------------------------------------------------------------------
18 #-------------------------------------------------------------------------------
19 # Imports
19 # Imports
20 #-------------------------------------------------------------------------------
20 #-------------------------------------------------------------------------------
21
21
22 import sys
22 import sys
23 import cPickle as pickle
23 import cPickle as pickle
24
24
25 try:
25 try:
26 import numpy
26 import numpy
27 except ImportError:
27 except ImportError:
28 numpy = None
28 numpy = None
29
29
30 class SerializationError(Exception):
30 class SerializationError(Exception):
31 pass
31 pass
32
32
33 if sys.version_info[0] >= 3:
33 if sys.version_info[0] >= 3:
34 buffer = memoryview
34 buffer = memoryview
35 py3k = True
35 py3k = True
36 else:
36 else:
37 py3k = False
37 py3k = False
38 if sys.version_info[:2] <= (2,6):
38 if sys.version_info[:2] <= (2,6):
39 memoryview = buffer
39 memoryview = buffer
40
40
41 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
42 # Classes and functions
42 # Classes and functions
43 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
44
44
45 class ISerialized:
45 class ISerialized:
46
46
47 def getData():
47 def getData():
48 """"""
48 """"""
49
49
50 def getDataSize(units=10.0**6):
50 def getDataSize(units=10.0**6):
51 """"""
51 """"""
52
52
53 def getTypeDescriptor():
53 def getTypeDescriptor():
54 """"""
54 """"""
55
55
56 def getMetadata():
56 def getMetadata():
57 """"""
57 """"""
58
58
59
59
60 class IUnSerialized:
60 class IUnSerialized:
61
61
62 def getObject():
62 def getObject():
63 """"""
63 """"""
64
64
65 class Serialized(object):
65 class Serialized(object):
66
66
67 # implements(ISerialized)
67 # implements(ISerialized)
68
68
69 def __init__(self, data, typeDescriptor, metadata={}):
69 def __init__(self, data, typeDescriptor, metadata={}):
70 self.data = data
70 self.data = data
71 self.typeDescriptor = typeDescriptor
71 self.typeDescriptor = typeDescriptor
72 self.metadata = metadata
72 self.metadata = metadata
73
73
74 def getData(self):
74 def getData(self):
75 return self.data
75 return self.data
76
76
77 def getDataSize(self, units=10.0**6):
77 def getDataSize(self, units=10.0**6):
78 return len(self.data)/units
78 return len(self.data)/units
79
79
80 def getTypeDescriptor(self):
80 def getTypeDescriptor(self):
81 return self.typeDescriptor
81 return self.typeDescriptor
82
82
83 def getMetadata(self):
83 def getMetadata(self):
84 return self.metadata
84 return self.metadata
85
85
86
86
87 class UnSerialized(object):
87 class UnSerialized(object):
88
88
89 # implements(IUnSerialized)
89 # implements(IUnSerialized)
90
90
91 def __init__(self, obj):
91 def __init__(self, obj):
92 self.obj = obj
92 self.obj = obj
93
93
94 def getObject(self):
94 def getObject(self):
95 return self.obj
95 return self.obj
96
96
97
97
98 class SerializeIt(object):
98 class SerializeIt(object):
99
99
100 # implements(ISerialized)
100 # implements(ISerialized)
101
101
102 def __init__(self, unSerialized):
102 def __init__(self, unSerialized):
103 self.data = None
103 self.data = None
104 self.obj = unSerialized.getObject()
104 self.obj = unSerialized.getObject()
105 if numpy is not None and isinstance(self.obj, numpy.ndarray):
105 if numpy is not None and isinstance(self.obj, numpy.ndarray):
106 if len(self.obj.shape) == 0: # length 0 arrays are just pickled
106 if len(self.obj.shape) == 0: # length 0 arrays are just pickled
107 self.typeDescriptor = 'pickle'
107 self.typeDescriptor = 'pickle'
108 self.metadata = {}
108 self.metadata = {}
109 else:
109 else:
110 self.obj = numpy.ascontiguousarray(self.obj, dtype=None)
110 self.obj = numpy.ascontiguousarray(self.obj, dtype=None)
111 self.typeDescriptor = 'ndarray'
111 self.typeDescriptor = 'ndarray'
112 self.metadata = {'shape':self.obj.shape,
112 self.metadata = {'shape':self.obj.shape,
113 'dtype':self.obj.dtype.str}
113 'dtype':self.obj.dtype}
114 elif isinstance(self.obj, bytes):
114 elif isinstance(self.obj, bytes):
115 self.typeDescriptor = 'bytes'
115 self.typeDescriptor = 'bytes'
116 self.metadata = {}
116 self.metadata = {}
117 elif isinstance(self.obj, buffer):
117 elif isinstance(self.obj, buffer):
118 self.typeDescriptor = 'buffer'
118 self.typeDescriptor = 'buffer'
119 self.metadata = {}
119 self.metadata = {}
120 else:
120 else:
121 self.typeDescriptor = 'pickle'
121 self.typeDescriptor = 'pickle'
122 self.metadata = {}
122 self.metadata = {}
123 self._generateData()
123 self._generateData()
124
124
125 def _generateData(self):
125 def _generateData(self):
126 if self.typeDescriptor == 'ndarray':
126 if self.typeDescriptor == 'ndarray':
127 self.data = buffer(self.obj)
127 self.data = buffer(self.obj)
128 elif self.typeDescriptor in ('bytes', 'buffer'):
128 elif self.typeDescriptor in ('bytes', 'buffer'):
129 self.data = self.obj
129 self.data = self.obj
130 elif self.typeDescriptor == 'pickle':
130 elif self.typeDescriptor == 'pickle':
131 self.data = pickle.dumps(self.obj, -1)
131 self.data = pickle.dumps(self.obj, -1)
132 else:
132 else:
133 raise SerializationError("Really wierd serialization error.")
133 raise SerializationError("Really wierd serialization error.")
134 del self.obj
134 del self.obj
135
135
136 def getData(self):
136 def getData(self):
137 return self.data
137 return self.data
138
138
139 def getDataSize(self, units=10.0**6):
139 def getDataSize(self, units=10.0**6):
140 return 1.0*len(self.data)/units
140 return 1.0*len(self.data)/units
141
141
142 def getTypeDescriptor(self):
142 def getTypeDescriptor(self):
143 return self.typeDescriptor
143 return self.typeDescriptor
144
144
145 def getMetadata(self):
145 def getMetadata(self):
146 return self.metadata
146 return self.metadata
147
147
148
148
149 class UnSerializeIt(UnSerialized):
149 class UnSerializeIt(UnSerialized):
150
150
151 # implements(IUnSerialized)
151 # implements(IUnSerialized)
152
152
153 def __init__(self, serialized):
153 def __init__(self, serialized):
154 self.serialized = serialized
154 self.serialized = serialized
155
155
156 def getObject(self):
156 def getObject(self):
157 typeDescriptor = self.serialized.getTypeDescriptor()
157 typeDescriptor = self.serialized.getTypeDescriptor()
158 if numpy is not None and typeDescriptor == 'ndarray':
158 if numpy is not None and typeDescriptor == 'ndarray':
159 buf = self.serialized.getData()
159 buf = self.serialized.getData()
160 if isinstance(buf, (bytes, buffer, memoryview)):
160 if isinstance(buf, (bytes, buffer, memoryview)):
161 result = numpy.frombuffer(buf, dtype = self.serialized.metadata['dtype'])
161 result = numpy.frombuffer(buf, dtype = self.serialized.metadata['dtype'])
162 else:
162 else:
163 raise TypeError("Expected bytes or buffer/memoryview, but got %r"%type(buf))
163 raise TypeError("Expected bytes or buffer/memoryview, but got %r"%type(buf))
164 result.shape = self.serialized.metadata['shape']
164 result.shape = self.serialized.metadata['shape']
165 elif typeDescriptor == 'pickle':
165 elif typeDescriptor == 'pickle':
166 result = pickle.loads(self.serialized.getData())
166 result = pickle.loads(self.serialized.getData())
167 elif typeDescriptor in ('bytes', 'buffer'):
167 elif typeDescriptor in ('bytes', 'buffer'):
168 result = self.serialized.getData()
168 result = self.serialized.getData()
169 else:
169 else:
170 raise SerializationError("Really wierd serialization error.")
170 raise SerializationError("Really wierd serialization error.")
171 return result
171 return result
172
172
173 def serialize(obj):
173 def serialize(obj):
174 return SerializeIt(UnSerialized(obj))
174 return SerializeIt(UnSerialized(obj))
175
175
176 def unserialize(serialized):
176 def unserialize(serialized):
177 return UnSerializeIt(serialized).getObject()
177 return UnSerializeIt(serialized).getObject()
General Comments 0
You need to be logged in to leave comments. Login now