##// END OF EJS Templates
better serialization for parallel code...
MinRK -
Show More
@@ -0,0 +1,115 b''
1 """test serialization tools"""
2
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
5 #
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
9
10 #-------------------------------------------------------------------------------
11 # Imports
12 #-------------------------------------------------------------------------------
13
14 import pickle
15
16 import nose.tools as nt
17
18 # from unittest import TestCase
19 from IPython.zmq.serialize import serialize_object, unserialize_object
20 from IPython.testing import decorators as dec
21 from IPython.utils.pickleutil import CannedArray
22
def roundtrip(obj):
    """Serialize *obj* and reconstruct it, asserting no leftover buffers.

    Returns the reconstructed object so callers can compare it to the
    original.
    """
    buffers = serialize_object(obj)
    restored, leftovers = unserialize_object(buffers)
    # a clean round trip consumes every buffer produced by serialization
    nt.assert_equals(leftovers, [])
    return restored
29
class C(object):
    """Dummy class whose attributes are set from keyword arguments.

    Used as a simple picklable fixture for serialization tests.
    """

    def __init__(self, **kwargs):
        # items() behaves the same on Python 2 and 3; the original
        # iteritems() is Python-2-only and breaks under Python 3
        for key, value in kwargs.items():
            setattr(self, key, value)
36
@dec.parametric
def test_roundtrip_simple():
    """Simple builtin values survive a serialize/unserialize cycle."""
    cases = [
        'hello',
        dict(a='b', b=10),
        [1, 2, 'hi'],
        (b'123', 'hello'),
    ]
    for case in cases:
        restored = roundtrip(case)
        yield nt.assert_equals(case, restored)
47
@dec.parametric
def test_roundtrip_nested():
    """Nested containers survive a serialize/unserialize cycle."""
    cases = [
        dict(a=range(5), b={1: b'hello'}),
        [range(5), [range(3), (1, [b'whoda'])]],
    ]
    for case in cases:
        restored = roundtrip(case)
        yield nt.assert_equals(case, restored)
56
@dec.parametric
def test_roundtrip_buffered():
    """Large byte payloads are split out into a second buffer."""
    cases = [
        dict(a=b"x" * 1025),
        b"hello" * 500,
        [b"hello" * 501, 1, 2, 3],
    ]
    for case in cases:
        bufs = serialize_object(case)
        # one pickle buffer plus one extracted data buffer
        yield nt.assert_equals(len(bufs), 2)
        restored, leftovers = unserialize_object(bufs)
        yield nt.assert_equals(leftovers, [])
        yield nt.assert_equals(case, restored)
69
@dec.parametric
@dec.skip_without('numpy')
def test_numpy():
    """ndarrays of assorted shapes and dtypes round-trip intact."""
    import numpy
    from numpy.testing.utils import assert_array_equal
    shapes = ((), (0,), (100,), (1024, 10), (10, 8, 6, 5))
    dtypes = ('uint8', 'float64', 'int32', [('int16', 'float32')])
    for shape in shapes:
        for dtype in dtypes:
            arr = numpy.empty(shape, dtype=dtype)
            bufs = serialize_object(arr)
            restored, leftovers = unserialize_object(bufs)
            yield nt.assert_equals(leftovers, [])
            yield assert_array_equal(arr, restored)
82
@dec.parametric
@dec.skip_without('numpy')
def test_numpy_in_seq():
    """ndarrays inside a tuple are canned and round-trip intact."""
    import numpy
    from numpy.testing.utils import assert_array_equal
    for shape in ((), (0,), (100,), (1024,10), (10,8,6,5)):
        for dtype in ('uint8', 'float64', 'int32', [('int16', 'float32')]):
            A = numpy.empty(shape, dtype=dtype)
            bufs = serialize_object((A,1,2,b'hello'))
            canned = pickle.loads(bufs[0])
            # BUG FIX: assert_true(x, msg) treats its 2nd arg as a message,
            # so the original never actually checked the type of canned[0]
            yield nt.assert_true(isinstance(canned[0], CannedArray))
            tup, r = unserialize_object(bufs)
            B = tup[0]
            yield nt.assert_equals(r, [])
            yield assert_array_equal(A,B)
98
@dec.parametric
@dec.skip_without('numpy')
def test_numpy_in_dict():
    """ndarrays inside a dict are canned and round-trip intact."""
    import numpy
    from numpy.testing.utils import assert_array_equal
    for shape in ((), (0,), (100,), (1024,10), (10,8,6,5)):
        for dtype in ('uint8', 'float64', 'int32', [('int16', 'float32')]):
            A = numpy.empty(shape, dtype=dtype)
            bufs = serialize_object(dict(a=A,b=1,c=range(20)))
            canned = pickle.loads(bufs[0])
            # BUG FIX: assert_true(x, msg) treats its 2nd arg as a message,
            # so the original never actually checked the type of canned['a']
            yield nt.assert_true(isinstance(canned['a'], CannedArray))
            d, r = unserialize_object(bufs)
            B = d['a']
            yield nt.assert_equals(r, [])
            yield assert_array_equal(A,B)
114
115
@@ -1,597 +1,597 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import platform
20 import platform
21 import time
21 import time
22 from tempfile import mktemp
22 from tempfile import mktemp
23 from StringIO import StringIO
23 from StringIO import StringIO
24
24
25 import zmq
25 import zmq
26 from nose import SkipTest
26 from nose import SkipTest
27
27
28 from IPython.testing import decorators as dec
28 from IPython.testing import decorators as dec
29 from IPython.testing.ipunittest import ParametricTestCase
29 from IPython.testing.ipunittest import ParametricTestCase
30
30
31 from IPython import parallel as pmod
31 from IPython import parallel as pmod
32 from IPython.parallel import error
32 from IPython.parallel import error
33 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
33 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
34 from IPython.parallel import DirectView
34 from IPython.parallel import DirectView
35 from IPython.parallel.util import interactive
35 from IPython.parallel.util import interactive
36
36
37 from IPython.parallel.tests import add_engines
37 from IPython.parallel.tests import add_engines
38
38
39 from .clienttest import ClusterTestCase, crash, wait, skip_without
39 from .clienttest import ClusterTestCase, crash, wait, skip_without
40
40
41 def setup():
41 def setup():
42 add_engines(3, total=True)
42 add_engines(3, total=True)
43
43
44 class TestView(ClusterTestCase, ParametricTestCase):
44 class TestView(ClusterTestCase, ParametricTestCase):
45
45
46 def setUp(self):
46 def setUp(self):
47 # On Win XP, wait for resource cleanup, else parallel test group fails
47 # On Win XP, wait for resource cleanup, else parallel test group fails
48 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
48 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
49 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
49 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
50 time.sleep(2)
50 time.sleep(2)
51 super(TestView, self).setUp()
51 super(TestView, self).setUp()
52
52
53 def test_z_crash_mux(self):
53 def test_z_crash_mux(self):
54 """test graceful handling of engine death (direct)"""
54 """test graceful handling of engine death (direct)"""
55 raise SkipTest("crash tests disabled, due to undesirable crash reports")
55 raise SkipTest("crash tests disabled, due to undesirable crash reports")
56 # self.add_engines(1)
56 # self.add_engines(1)
57 eid = self.client.ids[-1]
57 eid = self.client.ids[-1]
58 ar = self.client[eid].apply_async(crash)
58 ar = self.client[eid].apply_async(crash)
59 self.assertRaisesRemote(error.EngineError, ar.get, 10)
59 self.assertRaisesRemote(error.EngineError, ar.get, 10)
60 eid = ar.engine_id
60 eid = ar.engine_id
61 tic = time.time()
61 tic = time.time()
62 while eid in self.client.ids and time.time()-tic < 5:
62 while eid in self.client.ids and time.time()-tic < 5:
63 time.sleep(.01)
63 time.sleep(.01)
64 self.client.spin()
64 self.client.spin()
65 self.assertFalse(eid in self.client.ids, "Engine should have died")
65 self.assertFalse(eid in self.client.ids, "Engine should have died")
66
66
67 def test_push_pull(self):
67 def test_push_pull(self):
68 """test pushing and pulling"""
68 """test pushing and pulling"""
69 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
69 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
70 t = self.client.ids[-1]
70 t = self.client.ids[-1]
71 v = self.client[t]
71 v = self.client[t]
72 push = v.push
72 push = v.push
73 pull = v.pull
73 pull = v.pull
74 v.block=True
74 v.block=True
75 nengines = len(self.client)
75 nengines = len(self.client)
76 push({'data':data})
76 push({'data':data})
77 d = pull('data')
77 d = pull('data')
78 self.assertEqual(d, data)
78 self.assertEqual(d, data)
79 self.client[:].push({'data':data})
79 self.client[:].push({'data':data})
80 d = self.client[:].pull('data', block=True)
80 d = self.client[:].pull('data', block=True)
81 self.assertEqual(d, nengines*[data])
81 self.assertEqual(d, nengines*[data])
82 ar = push({'data':data}, block=False)
82 ar = push({'data':data}, block=False)
83 self.assertTrue(isinstance(ar, AsyncResult))
83 self.assertTrue(isinstance(ar, AsyncResult))
84 r = ar.get()
84 r = ar.get()
85 ar = self.client[:].pull('data', block=False)
85 ar = self.client[:].pull('data', block=False)
86 self.assertTrue(isinstance(ar, AsyncResult))
86 self.assertTrue(isinstance(ar, AsyncResult))
87 r = ar.get()
87 r = ar.get()
88 self.assertEqual(r, nengines*[data])
88 self.assertEqual(r, nengines*[data])
89 self.client[:].push(dict(a=10,b=20))
89 self.client[:].push(dict(a=10,b=20))
90 r = self.client[:].pull(('a','b'), block=True)
90 r = self.client[:].pull(('a','b'), block=True)
91 self.assertEqual(r, nengines*[[10,20]])
91 self.assertEqual(r, nengines*[[10,20]])
92
92
93 def test_push_pull_function(self):
93 def test_push_pull_function(self):
94 "test pushing and pulling functions"
94 "test pushing and pulling functions"
95 def testf(x):
95 def testf(x):
96 return 2.0*x
96 return 2.0*x
97
97
98 t = self.client.ids[-1]
98 t = self.client.ids[-1]
99 v = self.client[t]
99 v = self.client[t]
100 v.block=True
100 v.block=True
101 push = v.push
101 push = v.push
102 pull = v.pull
102 pull = v.pull
103 execute = v.execute
103 execute = v.execute
104 push({'testf':testf})
104 push({'testf':testf})
105 r = pull('testf')
105 r = pull('testf')
106 self.assertEqual(r(1.0), testf(1.0))
106 self.assertEqual(r(1.0), testf(1.0))
107 execute('r = testf(10)')
107 execute('r = testf(10)')
108 r = pull('r')
108 r = pull('r')
109 self.assertEqual(r, testf(10))
109 self.assertEqual(r, testf(10))
110 ar = self.client[:].push({'testf':testf}, block=False)
110 ar = self.client[:].push({'testf':testf}, block=False)
111 ar.get()
111 ar.get()
112 ar = self.client[:].pull('testf', block=False)
112 ar = self.client[:].pull('testf', block=False)
113 rlist = ar.get()
113 rlist = ar.get()
114 for r in rlist:
114 for r in rlist:
115 self.assertEqual(r(1.0), testf(1.0))
115 self.assertEqual(r(1.0), testf(1.0))
116 execute("def g(x): return x*x")
116 execute("def g(x): return x*x")
117 r = pull(('testf','g'))
117 r = pull(('testf','g'))
118 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
118 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
119
119
120 def test_push_function_globals(self):
120 def test_push_function_globals(self):
121 """test that pushed functions have access to globals"""
121 """test that pushed functions have access to globals"""
122 @interactive
122 @interactive
123 def geta():
123 def geta():
124 return a
124 return a
125 # self.add_engines(1)
125 # self.add_engines(1)
126 v = self.client[-1]
126 v = self.client[-1]
127 v.block=True
127 v.block=True
128 v['f'] = geta
128 v['f'] = geta
129 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
129 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
130 v.execute('a=5')
130 v.execute('a=5')
131 v.execute('b=f()')
131 v.execute('b=f()')
132 self.assertEqual(v['b'], 5)
132 self.assertEqual(v['b'], 5)
133
133
134 def test_push_function_defaults(self):
134 def test_push_function_defaults(self):
135 """test that pushed functions preserve default args"""
135 """test that pushed functions preserve default args"""
136 def echo(a=10):
136 def echo(a=10):
137 return a
137 return a
138 v = self.client[-1]
138 v = self.client[-1]
139 v.block=True
139 v.block=True
140 v['f'] = echo
140 v['f'] = echo
141 v.execute('b=f()')
141 v.execute('b=f()')
142 self.assertEqual(v['b'], 10)
142 self.assertEqual(v['b'], 10)
143
143
144 def test_get_result(self):
144 def test_get_result(self):
145 """test getting results from the Hub."""
145 """test getting results from the Hub."""
146 c = pmod.Client(profile='iptest')
146 c = pmod.Client(profile='iptest')
147 # self.add_engines(1)
147 # self.add_engines(1)
148 t = c.ids[-1]
148 t = c.ids[-1]
149 v = c[t]
149 v = c[t]
150 v2 = self.client[t]
150 v2 = self.client[t]
151 ar = v.apply_async(wait, 1)
151 ar = v.apply_async(wait, 1)
152 # give the monitor time to notice the message
152 # give the monitor time to notice the message
153 time.sleep(.25)
153 time.sleep(.25)
154 ahr = v2.get_result(ar.msg_ids)
154 ahr = v2.get_result(ar.msg_ids)
155 self.assertTrue(isinstance(ahr, AsyncHubResult))
155 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 self.assertEqual(ahr.get(), ar.get())
156 self.assertEqual(ahr.get(), ar.get())
157 ar2 = v2.get_result(ar.msg_ids)
157 ar2 = v2.get_result(ar.msg_ids)
158 self.assertFalse(isinstance(ar2, AsyncHubResult))
158 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 c.spin()
159 c.spin()
160 c.close()
160 c.close()
161
161
162 def test_run_newline(self):
162 def test_run_newline(self):
163 """test that run appends newline to files"""
163 """test that run appends newline to files"""
164 tmpfile = mktemp()
164 tmpfile = mktemp()
165 with open(tmpfile, 'w') as f:
165 with open(tmpfile, 'w') as f:
166 f.write("""def g():
166 f.write("""def g():
167 return 5
167 return 5
168 """)
168 """)
169 v = self.client[-1]
169 v = self.client[-1]
170 v.run(tmpfile, block=True)
170 v.run(tmpfile, block=True)
171 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
171 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
172
172
173 def test_apply_tracked(self):
173 def test_apply_tracked(self):
174 """test tracking for apply"""
174 """test tracking for apply"""
175 # self.add_engines(1)
175 # self.add_engines(1)
176 t = self.client.ids[-1]
176 t = self.client.ids[-1]
177 v = self.client[t]
177 v = self.client[t]
178 v.block=False
178 v.block=False
179 def echo(n=1024*1024, **kwargs):
179 def echo(n=1024*1024, **kwargs):
180 with v.temp_flags(**kwargs):
180 with v.temp_flags(**kwargs):
181 return v.apply(lambda x: x, 'x'*n)
181 return v.apply(lambda x: x, 'x'*n)
182 ar = echo(1, track=False)
182 ar = echo(1, track=False)
183 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
183 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
184 self.assertTrue(ar.sent)
184 self.assertTrue(ar.sent)
185 ar = echo(track=True)
185 ar = echo(track=True)
186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertEqual(ar.sent, ar._tracker.done)
187 self.assertEqual(ar.sent, ar._tracker.done)
188 ar._tracker.wait()
188 ar._tracker.wait()
189 self.assertTrue(ar.sent)
189 self.assertTrue(ar.sent)
190
190
191 def test_push_tracked(self):
191 def test_push_tracked(self):
192 t = self.client.ids[-1]
192 t = self.client.ids[-1]
193 ns = dict(x='x'*1024*1024)
193 ns = dict(x='x'*1024*1024)
194 v = self.client[t]
194 v = self.client[t]
195 ar = v.push(ns, block=False, track=False)
195 ar = v.push(ns, block=False, track=False)
196 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
196 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
197 self.assertTrue(ar.sent)
197 self.assertTrue(ar.sent)
198
198
199 ar = v.push(ns, block=False, track=True)
199 ar = v.push(ns, block=False, track=True)
200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 ar._tracker.wait()
201 ar._tracker.wait()
202 self.assertEqual(ar.sent, ar._tracker.done)
202 self.assertEqual(ar.sent, ar._tracker.done)
203 self.assertTrue(ar.sent)
203 self.assertTrue(ar.sent)
204 ar.get()
204 ar.get()
205
205
206 def test_scatter_tracked(self):
206 def test_scatter_tracked(self):
207 t = self.client.ids
207 t = self.client.ids
208 x='x'*1024*1024
208 x='x'*1024*1024
209 ar = self.client[t].scatter('x', x, block=False, track=False)
209 ar = self.client[t].scatter('x', x, block=False, track=False)
210 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
210 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
211 self.assertTrue(ar.sent)
211 self.assertTrue(ar.sent)
212
212
213 ar = self.client[t].scatter('x', x, block=False, track=True)
213 ar = self.client[t].scatter('x', x, block=False, track=True)
214 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
214 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
215 self.assertEqual(ar.sent, ar._tracker.done)
215 self.assertEqual(ar.sent, ar._tracker.done)
216 ar._tracker.wait()
216 ar._tracker.wait()
217 self.assertTrue(ar.sent)
217 self.assertTrue(ar.sent)
218 ar.get()
218 ar.get()
219
219
220 def test_remote_reference(self):
220 def test_remote_reference(self):
221 v = self.client[-1]
221 v = self.client[-1]
222 v['a'] = 123
222 v['a'] = 123
223 ra = pmod.Reference('a')
223 ra = pmod.Reference('a')
224 b = v.apply_sync(lambda x: x, ra)
224 b = v.apply_sync(lambda x: x, ra)
225 self.assertEqual(b, 123)
225 self.assertEqual(b, 123)
226
226
227
227
228 def test_scatter_gather(self):
228 def test_scatter_gather(self):
229 view = self.client[:]
229 view = self.client[:]
230 seq1 = range(16)
230 seq1 = range(16)
231 view.scatter('a', seq1)
231 view.scatter('a', seq1)
232 seq2 = view.gather('a', block=True)
232 seq2 = view.gather('a', block=True)
233 self.assertEqual(seq2, seq1)
233 self.assertEqual(seq2, seq1)
234 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
234 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
235
235
236 @skip_without('numpy')
236 @skip_without('numpy')
237 def test_scatter_gather_numpy(self):
237 def test_scatter_gather_numpy(self):
238 import numpy
238 import numpy
239 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
239 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
240 view = self.client[:]
240 view = self.client[:]
241 a = numpy.arange(64)
241 a = numpy.arange(64)
242 view.scatter('a', a)
242 view.scatter('a', a, block=True)
243 b = view.gather('a', block=True)
243 b = view.gather('a', block=True)
244 assert_array_equal(b, a)
244 assert_array_equal(b, a)
245
245
246 def test_scatter_gather_lazy(self):
246 def test_scatter_gather_lazy(self):
247 """scatter/gather with targets='all'"""
247 """scatter/gather with targets='all'"""
248 view = self.client.direct_view(targets='all')
248 view = self.client.direct_view(targets='all')
249 x = range(64)
249 x = range(64)
250 view.scatter('x', x)
250 view.scatter('x', x)
251 gathered = view.gather('x', block=True)
251 gathered = view.gather('x', block=True)
252 self.assertEqual(gathered, x)
252 self.assertEqual(gathered, x)
253
253
254
254
255 @dec.known_failure_py3
255 @dec.known_failure_py3
256 @skip_without('numpy')
256 @skip_without('numpy')
257 def test_push_numpy_nocopy(self):
257 def test_push_numpy_nocopy(self):
258 import numpy
258 import numpy
259 view = self.client[:]
259 view = self.client[:]
260 a = numpy.arange(64)
260 a = numpy.arange(64)
261 view['A'] = a
261 view['A'] = a
262 @interactive
262 @interactive
263 def check_writeable(x):
263 def check_writeable(x):
264 return x.flags.writeable
264 return x.flags.writeable
265
265
266 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
266 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
267 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
267 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
268
268
269 view.push(dict(B=a))
269 view.push(dict(B=a))
270 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
270 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
271 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
271 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
272
272
273 @skip_without('numpy')
273 @skip_without('numpy')
274 def test_apply_numpy(self):
274 def test_apply_numpy(self):
275 """view.apply(f, ndarray)"""
275 """view.apply(f, ndarray)"""
276 import numpy
276 import numpy
277 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
277 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
278
278
279 A = numpy.random.random((100,100))
279 A = numpy.random.random((100,100))
280 view = self.client[-1]
280 view = self.client[-1]
281 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
281 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
282 B = A.astype(dt)
282 B = A.astype(dt)
283 C = view.apply_sync(lambda x:x, B)
283 C = view.apply_sync(lambda x:x, B)
284 assert_array_equal(B,C)
284 assert_array_equal(B,C)
285
285
286 @skip_without('numpy')
286 @skip_without('numpy')
287 def test_push_pull_recarray(self):
287 def test_push_pull_recarray(self):
288 """push/pull recarrays"""
288 """push/pull recarrays"""
289 import numpy
289 import numpy
290 from numpy.testing.utils import assert_array_equal
290 from numpy.testing.utils import assert_array_equal
291
291
292 view = self.client[-1]
292 view = self.client[-1]
293
293
294 R = numpy.array([
294 R = numpy.array([
295 (1, 'hi', 0.),
295 (1, 'hi', 0.),
296 (2**30, 'there', 2.5),
296 (2**30, 'there', 2.5),
297 (-99999, 'world', -12345.6789),
297 (-99999, 'world', -12345.6789),
298 ], [('n', int), ('s', '|S10'), ('f', float)])
298 ], [('n', int), ('s', '|S10'), ('f', float)])
299
299
300 view['RR'] = R
300 view['RR'] = R
301 R2 = view['RR']
301 R2 = view['RR']
302
302
303 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
303 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
304 self.assertEqual(r_dtype, R.dtype)
304 self.assertEqual(r_dtype, R.dtype)
305 self.assertEqual(r_shape, R.shape)
305 self.assertEqual(r_shape, R.shape)
306 self.assertEqual(R2.dtype, R.dtype)
306 self.assertEqual(R2.dtype, R.dtype)
307 self.assertEqual(R2.shape, R.shape)
307 self.assertEqual(R2.shape, R.shape)
308 assert_array_equal(R2, R)
308 assert_array_equal(R2, R)
309
309
310 def test_map(self):
310 def test_map(self):
311 view = self.client[:]
311 view = self.client[:]
312 def f(x):
312 def f(x):
313 return x**2
313 return x**2
314 data = range(16)
314 data = range(16)
315 r = view.map_sync(f, data)
315 r = view.map_sync(f, data)
316 self.assertEqual(r, map(f, data))
316 self.assertEqual(r, map(f, data))
317
317
def test_map_iterable(self):
    """test map on iterables (direct)"""
    view = self.client[:]
    # 101 is prime, so it won't be evenly distributed
    arr = range(101)
    # ensure it will be an iterator, even in Python 3
    it = iter(arr)
    # BUG FIX: map over the iterator, not the original sequence —
    # the original passed `arr`, so the iterable code path that this
    # test exists to cover was never actually exercised
    r = view.map_sync(lambda x: x, it)
    self.assertEqual(r, list(arr))
327
327
328 def test_scatterGatherNonblocking(self):
328 def test_scatter_gather_nonblocking(self):
329 data = range(16)
329 data = range(16)
330 view = self.client[:]
330 view = self.client[:]
331 view.scatter('a', data, block=False)
331 view.scatter('a', data, block=False)
332 ar = view.gather('a', block=False)
332 ar = view.gather('a', block=False)
333 self.assertEqual(ar.get(), data)
333 self.assertEqual(ar.get(), data)
334
334
335 @skip_without('numpy')
335 @skip_without('numpy')
336 def test_scatter_gather_numpy_nonblocking(self):
336 def test_scatter_gather_numpy_nonblocking(self):
337 import numpy
337 import numpy
338 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
338 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
339 a = numpy.arange(64)
339 a = numpy.arange(64)
340 view = self.client[:]
340 view = self.client[:]
341 ar = view.scatter('a', a, block=False)
341 ar = view.scatter('a', a, block=False)
342 self.assertTrue(isinstance(ar, AsyncResult))
342 self.assertTrue(isinstance(ar, AsyncResult))
343 amr = view.gather('a', block=False)
343 amr = view.gather('a', block=False)
344 self.assertTrue(isinstance(amr, AsyncMapResult))
344 self.assertTrue(isinstance(amr, AsyncMapResult))
345 assert_array_equal(amr.get(), a)
345 assert_array_equal(amr.get(), a)
346
346
347 def test_execute(self):
347 def test_execute(self):
348 view = self.client[:]
348 view = self.client[:]
349 # self.client.debug=True
349 # self.client.debug=True
350 execute = view.execute
350 execute = view.execute
351 ar = execute('c=30', block=False)
351 ar = execute('c=30', block=False)
352 self.assertTrue(isinstance(ar, AsyncResult))
352 self.assertTrue(isinstance(ar, AsyncResult))
353 ar = execute('d=[0,1,2]', block=False)
353 ar = execute('d=[0,1,2]', block=False)
354 self.client.wait(ar, 1)
354 self.client.wait(ar, 1)
355 self.assertEqual(len(ar.get()), len(self.client))
355 self.assertEqual(len(ar.get()), len(self.client))
356 for c in view['c']:
356 for c in view['c']:
357 self.assertEqual(c, 30)
357 self.assertEqual(c, 30)
358
358
359 def test_abort(self):
359 def test_abort(self):
360 view = self.client[-1]
360 view = self.client[-1]
361 ar = view.execute('import time; time.sleep(1)', block=False)
361 ar = view.execute('import time; time.sleep(1)', block=False)
362 ar2 = view.apply_async(lambda : 2)
362 ar2 = view.apply_async(lambda : 2)
363 ar3 = view.apply_async(lambda : 3)
363 ar3 = view.apply_async(lambda : 3)
364 view.abort(ar2)
364 view.abort(ar2)
365 view.abort(ar3.msg_ids)
365 view.abort(ar3.msg_ids)
366 self.assertRaises(error.TaskAborted, ar2.get)
366 self.assertRaises(error.TaskAborted, ar2.get)
367 self.assertRaises(error.TaskAborted, ar3.get)
367 self.assertRaises(error.TaskAborted, ar3.get)
368
368
369 def test_abort_all(self):
369 def test_abort_all(self):
370 """view.abort() aborts all outstanding tasks"""
370 """view.abort() aborts all outstanding tasks"""
371 view = self.client[-1]
371 view = self.client[-1]
372 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
372 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
373 view.abort()
373 view.abort()
374 view.wait(timeout=5)
374 view.wait(timeout=5)
375 for ar in ars[5:]:
375 for ar in ars[5:]:
376 self.assertRaises(error.TaskAborted, ar.get)
376 self.assertRaises(error.TaskAborted, ar.get)
377
377
378 def test_temp_flags(self):
378 def test_temp_flags(self):
379 view = self.client[-1]
379 view = self.client[-1]
380 view.block=True
380 view.block=True
381 with view.temp_flags(block=False):
381 with view.temp_flags(block=False):
382 self.assertFalse(view.block)
382 self.assertFalse(view.block)
383 self.assertTrue(view.block)
383 self.assertTrue(view.block)
384
384
385 @dec.known_failure_py3
385 @dec.known_failure_py3
386 def test_importer(self):
386 def test_importer(self):
387 view = self.client[-1]
387 view = self.client[-1]
388 view.clear(block=True)
388 view.clear(block=True)
389 with view.importer:
389 with view.importer:
390 import re
390 import re
391
391
392 @interactive
392 @interactive
393 def findall(pat, s):
393 def findall(pat, s):
394 # this globals() step isn't necessary in real code
394 # this globals() step isn't necessary in real code
395 # only to prevent a closure in the test
395 # only to prevent a closure in the test
396 re = globals()['re']
396 re = globals()['re']
397 return re.findall(pat, s)
397 return re.findall(pat, s)
398
398
399 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
399 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
400
400
401 def test_unicode_execute(self):
401 def test_unicode_execute(self):
402 """test executing unicode strings"""
402 """test executing unicode strings"""
403 v = self.client[-1]
403 v = self.client[-1]
404 v.block=True
404 v.block=True
405 if sys.version_info[0] >= 3:
405 if sys.version_info[0] >= 3:
406 code="a='é'"
406 code="a='é'"
407 else:
407 else:
408 code=u"a=u'é'"
408 code=u"a=u'é'"
409 v.execute(code)
409 v.execute(code)
410 self.assertEqual(v['a'], u'é')
410 self.assertEqual(v['a'], u'é')
411
411
412 def test_unicode_apply_result(self):
412 def test_unicode_apply_result(self):
413 """test unicode apply results"""
413 """test unicode apply results"""
414 v = self.client[-1]
414 v = self.client[-1]
415 r = v.apply_sync(lambda : u'é')
415 r = v.apply_sync(lambda : u'é')
416 self.assertEqual(r, u'é')
416 self.assertEqual(r, u'é')
417
417
418 def test_unicode_apply_arg(self):
418 def test_unicode_apply_arg(self):
419 """test passing unicode arguments to apply"""
419 """test passing unicode arguments to apply"""
420 v = self.client[-1]
420 v = self.client[-1]
421
421
422 @interactive
422 @interactive
423 def check_unicode(a, check):
423 def check_unicode(a, check):
424 assert isinstance(a, unicode), "%r is not unicode"%a
424 assert isinstance(a, unicode), "%r is not unicode"%a
425 assert isinstance(check, bytes), "%r is not bytes"%check
425 assert isinstance(check, bytes), "%r is not bytes"%check
426 assert a.encode('utf8') == check, "%s != %s"%(a,check)
426 assert a.encode('utf8') == check, "%s != %s"%(a,check)
427
427
428 for s in [ u'é', u'ßø®∫',u'asdf' ]:
428 for s in [ u'é', u'ßø®∫',u'asdf' ]:
429 try:
429 try:
430 v.apply_sync(check_unicode, s, s.encode('utf8'))
430 v.apply_sync(check_unicode, s, s.encode('utf8'))
431 except error.RemoteError as e:
431 except error.RemoteError as e:
432 if e.ename == 'AssertionError':
432 if e.ename == 'AssertionError':
433 self.fail(e.evalue)
433 self.fail(e.evalue)
434 else:
434 else:
435 raise e
435 raise e
436
436
437 def test_map_reference(self):
437 def test_map_reference(self):
438 """view.map(<Reference>, *seqs) should work"""
438 """view.map(<Reference>, *seqs) should work"""
439 v = self.client[:]
439 v = self.client[:]
440 v.scatter('n', self.client.ids, flatten=True)
440 v.scatter('n', self.client.ids, flatten=True)
441 v.execute("f = lambda x,y: x*y")
441 v.execute("f = lambda x,y: x*y")
442 rf = pmod.Reference('f')
442 rf = pmod.Reference('f')
443 nlist = list(range(10))
443 nlist = list(range(10))
444 mlist = nlist[::-1]
444 mlist = nlist[::-1]
445 expected = [ m*n for m,n in zip(mlist, nlist) ]
445 expected = [ m*n for m,n in zip(mlist, nlist) ]
446 result = v.map_sync(rf, mlist, nlist)
446 result = v.map_sync(rf, mlist, nlist)
447 self.assertEqual(result, expected)
447 self.assertEqual(result, expected)
448
448
449 def test_apply_reference(self):
449 def test_apply_reference(self):
450 """view.apply(<Reference>, *args) should work"""
450 """view.apply(<Reference>, *args) should work"""
451 v = self.client[:]
451 v = self.client[:]
452 v.scatter('n', self.client.ids, flatten=True)
452 v.scatter('n', self.client.ids, flatten=True)
453 v.execute("f = lambda x: n*x")
453 v.execute("f = lambda x: n*x")
454 rf = pmod.Reference('f')
454 rf = pmod.Reference('f')
455 result = v.apply_sync(rf, 5)
455 result = v.apply_sync(rf, 5)
456 expected = [ 5*id for id in self.client.ids ]
456 expected = [ 5*id for id in self.client.ids ]
457 self.assertEqual(result, expected)
457 self.assertEqual(result, expected)
458
458
459 def test_eval_reference(self):
459 def test_eval_reference(self):
460 v = self.client[self.client.ids[0]]
460 v = self.client[self.client.ids[0]]
461 v['g'] = range(5)
461 v['g'] = range(5)
462 rg = pmod.Reference('g[0]')
462 rg = pmod.Reference('g[0]')
463 echo = lambda x:x
463 echo = lambda x:x
464 self.assertEqual(v.apply_sync(echo, rg), 0)
464 self.assertEqual(v.apply_sync(echo, rg), 0)
465
465
466 def test_reference_nameerror(self):
466 def test_reference_nameerror(self):
467 v = self.client[self.client.ids[0]]
467 v = self.client[self.client.ids[0]]
468 r = pmod.Reference('elvis_has_left')
468 r = pmod.Reference('elvis_has_left')
469 echo = lambda x:x
469 echo = lambda x:x
470 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
470 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
471
471
472 def test_single_engine_map(self):
472 def test_single_engine_map(self):
473 e0 = self.client[self.client.ids[0]]
473 e0 = self.client[self.client.ids[0]]
474 r = range(5)
474 r = range(5)
475 check = [ -1*i for i in r ]
475 check = [ -1*i for i in r ]
476 result = e0.map_sync(lambda x: -1*x, r)
476 result = e0.map_sync(lambda x: -1*x, r)
477 self.assertEqual(result, check)
477 self.assertEqual(result, check)
478
478
479 def test_len(self):
479 def test_len(self):
480 """len(view) makes sense"""
480 """len(view) makes sense"""
481 e0 = self.client[self.client.ids[0]]
481 e0 = self.client[self.client.ids[0]]
482 yield self.assertEqual(len(e0), 1)
482 yield self.assertEqual(len(e0), 1)
483 v = self.client[:]
483 v = self.client[:]
484 yield self.assertEqual(len(v), len(self.client.ids))
484 yield self.assertEqual(len(v), len(self.client.ids))
485 v = self.client.direct_view('all')
485 v = self.client.direct_view('all')
486 yield self.assertEqual(len(v), len(self.client.ids))
486 yield self.assertEqual(len(v), len(self.client.ids))
487 v = self.client[:2]
487 v = self.client[:2]
488 yield self.assertEqual(len(v), 2)
488 yield self.assertEqual(len(v), 2)
489 v = self.client[:1]
489 v = self.client[:1]
490 yield self.assertEqual(len(v), 1)
490 yield self.assertEqual(len(v), 1)
491 v = self.client.load_balanced_view()
491 v = self.client.load_balanced_view()
492 yield self.assertEqual(len(v), len(self.client.ids))
492 yield self.assertEqual(len(v), len(self.client.ids))
493 # parametric tests seem to require manual closing?
493 # parametric tests seem to require manual closing?
494 self.client.close()
494 self.client.close()
495
495
496
496
497 # begin execute tests
497 # begin execute tests
498
498
499 def test_execute_reply(self):
499 def test_execute_reply(self):
500 e0 = self.client[self.client.ids[0]]
500 e0 = self.client[self.client.ids[0]]
501 e0.block = True
501 e0.block = True
502 ar = e0.execute("5", silent=False)
502 ar = e0.execute("5", silent=False)
503 er = ar.get()
503 er = ar.get()
504 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
504 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
505 self.assertEqual(er.pyout['data']['text/plain'], '5')
505 self.assertEqual(er.pyout['data']['text/plain'], '5')
506
506
507 def test_execute_reply_stdout(self):
507 def test_execute_reply_stdout(self):
508 e0 = self.client[self.client.ids[0]]
508 e0 = self.client[self.client.ids[0]]
509 e0.block = True
509 e0.block = True
510 ar = e0.execute("print (5)", silent=False)
510 ar = e0.execute("print (5)", silent=False)
511 er = ar.get()
511 er = ar.get()
512 self.assertEqual(er.stdout.strip(), '5')
512 self.assertEqual(er.stdout.strip(), '5')
513
513
514 def test_execute_pyout(self):
514 def test_execute_pyout(self):
515 """execute triggers pyout with silent=False"""
515 """execute triggers pyout with silent=False"""
516 view = self.client[:]
516 view = self.client[:]
517 ar = view.execute("5", silent=False, block=True)
517 ar = view.execute("5", silent=False, block=True)
518
518
519 expected = [{'text/plain' : '5'}] * len(view)
519 expected = [{'text/plain' : '5'}] * len(view)
520 mimes = [ out['data'] for out in ar.pyout ]
520 mimes = [ out['data'] for out in ar.pyout ]
521 self.assertEqual(mimes, expected)
521 self.assertEqual(mimes, expected)
522
522
523 def test_execute_silent(self):
523 def test_execute_silent(self):
524 """execute does not trigger pyout with silent=True"""
524 """execute does not trigger pyout with silent=True"""
525 view = self.client[:]
525 view = self.client[:]
526 ar = view.execute("5", block=True)
526 ar = view.execute("5", block=True)
527 expected = [None] * len(view)
527 expected = [None] * len(view)
528 self.assertEqual(ar.pyout, expected)
528 self.assertEqual(ar.pyout, expected)
529
529
530 def test_execute_magic(self):
530 def test_execute_magic(self):
531 """execute accepts IPython commands"""
531 """execute accepts IPython commands"""
532 view = self.client[:]
532 view = self.client[:]
533 view.execute("a = 5")
533 view.execute("a = 5")
534 ar = view.execute("%whos", block=True)
534 ar = view.execute("%whos", block=True)
535 # this will raise, if that failed
535 # this will raise, if that failed
536 ar.get(5)
536 ar.get(5)
537 for stdout in ar.stdout:
537 for stdout in ar.stdout:
538 lines = stdout.splitlines()
538 lines = stdout.splitlines()
539 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
539 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
540 found = False
540 found = False
541 for line in lines[2:]:
541 for line in lines[2:]:
542 split = line.split()
542 split = line.split()
543 if split == ['a', 'int', '5']:
543 if split == ['a', 'int', '5']:
544 found = True
544 found = True
545 break
545 break
546 self.assertTrue(found, "whos output wrong: %s" % stdout)
546 self.assertTrue(found, "whos output wrong: %s" % stdout)
547
547
548 def test_execute_displaypub(self):
548 def test_execute_displaypub(self):
549 """execute tracks display_pub output"""
549 """execute tracks display_pub output"""
550 view = self.client[:]
550 view = self.client[:]
551 view.execute("from IPython.core.display import *")
551 view.execute("from IPython.core.display import *")
552 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
552 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
553
553
554 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
554 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
555 for outputs in ar.outputs:
555 for outputs in ar.outputs:
556 mimes = [ out['data'] for out in outputs ]
556 mimes = [ out['data'] for out in outputs ]
557 self.assertEqual(mimes, expected)
557 self.assertEqual(mimes, expected)
558
558
559 def test_apply_displaypub(self):
559 def test_apply_displaypub(self):
560 """apply tracks display_pub output"""
560 """apply tracks display_pub output"""
561 view = self.client[:]
561 view = self.client[:]
562 view.execute("from IPython.core.display import *")
562 view.execute("from IPython.core.display import *")
563
563
564 @interactive
564 @interactive
565 def publish():
565 def publish():
566 [ display(i) for i in range(5) ]
566 [ display(i) for i in range(5) ]
567
567
568 ar = view.apply_async(publish)
568 ar = view.apply_async(publish)
569 ar.get(5)
569 ar.get(5)
570 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
570 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
571 for outputs in ar.outputs:
571 for outputs in ar.outputs:
572 mimes = [ out['data'] for out in outputs ]
572 mimes = [ out['data'] for out in outputs ]
573 self.assertEqual(mimes, expected)
573 self.assertEqual(mimes, expected)
574
574
575 def test_execute_raises(self):
575 def test_execute_raises(self):
576 """exceptions in execute requests raise appropriately"""
576 """exceptions in execute requests raise appropriately"""
577 view = self.client[-1]
577 view = self.client[-1]
578 ar = view.execute("1/0")
578 ar = view.execute("1/0")
579 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
579 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
580
580
581 @dec.skipif_not_matplotlib
581 @dec.skipif_not_matplotlib
582 def test_magic_pylab(self):
582 def test_magic_pylab(self):
583 """%pylab works on engines"""
583 """%pylab works on engines"""
584 view = self.client[-1]
584 view = self.client[-1]
585 ar = view.execute("%pylab inline")
585 ar = view.execute("%pylab inline")
586 # at least check if this raised:
586 # at least check if this raised:
587 reply = ar.get(5)
587 reply = ar.get(5)
588 # include imports, in case user config
588 # include imports, in case user config
589 ar = view.execute("plot(rand(100))", silent=False)
589 ar = view.execute("plot(rand(100))", silent=False)
590 reply = ar.get(5)
590 reply = ar.get(5)
591 self.assertEqual(len(reply.outputs), 1)
591 self.assertEqual(len(reply.outputs), 1)
592 output = reply.outputs[0]
592 output = reply.outputs[0]
593 self.assertTrue("data" in output)
593 self.assertTrue("data" in output)
594 data = output['data']
594 data = output['data']
595 self.assertTrue("image/png" in data)
595 self.assertTrue("image/png" in data)
596
596
597
597
@@ -1,358 +1,352 b''
1 """some generic utilities for dealing with classes, urls, and serialization
1 """some generic utilities for dealing with classes, urls, and serialization
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 # Standard library imports.
18 # Standard library imports.
19 import logging
19 import logging
20 import os
20 import os
21 import re
21 import re
22 import stat
22 import stat
23 import socket
23 import socket
24 import sys
24 import sys
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 try:
26 try:
27 from signal import SIGKILL
27 from signal import SIGKILL
28 except ImportError:
28 except ImportError:
29 SIGKILL=None
29 SIGKILL=None
30
30
31 try:
31 try:
32 import cPickle
32 import cPickle
33 pickle = cPickle
33 pickle = cPickle
34 except:
34 except:
35 cPickle = None
35 cPickle = None
36 import pickle
36 import pickle
37
37
38 # System library imports
38 # System library imports
39 import zmq
39 import zmq
40 from zmq.log import handlers
40 from zmq.log import handlers
41
41
42 from IPython.external.decorator import decorator
42 from IPython.external.decorator import decorator
43
43
44 # IPython imports
44 # IPython imports
45 from IPython.config.application import Application
45 from IPython.config.application import Application
46 from IPython.utils import py3compat
47 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
48 from IPython.utils.newserialized import serialize, unserialize
49 from IPython.zmq.log import EnginePUBHandler
46 from IPython.zmq.log import EnginePUBHandler
50 from IPython.zmq.serialize import (
47 from IPython.zmq.serialize import (
51 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
48 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
52 )
49 )
53
50
54 if py3compat.PY3:
55 buffer = memoryview
56
57 #-----------------------------------------------------------------------------
51 #-----------------------------------------------------------------------------
58 # Classes
52 # Classes
59 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
60
54
61 class Namespace(dict):
55 class Namespace(dict):
62 """Subclass of dict for attribute access to keys."""
56 """Subclass of dict for attribute access to keys."""
63
57
64 def __getattr__(self, key):
58 def __getattr__(self, key):
65 """getattr aliased to getitem"""
59 """getattr aliased to getitem"""
66 if key in self.iterkeys():
60 if key in self.iterkeys():
67 return self[key]
61 return self[key]
68 else:
62 else:
69 raise NameError(key)
63 raise NameError(key)
70
64
71 def __setattr__(self, key, value):
65 def __setattr__(self, key, value):
72 """setattr aliased to setitem, with strict"""
66 """setattr aliased to setitem, with strict"""
73 if hasattr(dict, key):
67 if hasattr(dict, key):
74 raise KeyError("Cannot override dict keys %r"%key)
68 raise KeyError("Cannot override dict keys %r"%key)
75 self[key] = value
69 self[key] = value
76
70
77
71
78 class ReverseDict(dict):
72 class ReverseDict(dict):
79 """simple double-keyed subset of dict methods."""
73 """simple double-keyed subset of dict methods."""
80
74
81 def __init__(self, *args, **kwargs):
75 def __init__(self, *args, **kwargs):
82 dict.__init__(self, *args, **kwargs)
76 dict.__init__(self, *args, **kwargs)
83 self._reverse = dict()
77 self._reverse = dict()
84 for key, value in self.iteritems():
78 for key, value in self.iteritems():
85 self._reverse[value] = key
79 self._reverse[value] = key
86
80
87 def __getitem__(self, key):
81 def __getitem__(self, key):
88 try:
82 try:
89 return dict.__getitem__(self, key)
83 return dict.__getitem__(self, key)
90 except KeyError:
84 except KeyError:
91 return self._reverse[key]
85 return self._reverse[key]
92
86
93 def __setitem__(self, key, value):
87 def __setitem__(self, key, value):
94 if key in self._reverse:
88 if key in self._reverse:
95 raise KeyError("Can't have key %r on both sides!"%key)
89 raise KeyError("Can't have key %r on both sides!"%key)
96 dict.__setitem__(self, key, value)
90 dict.__setitem__(self, key, value)
97 self._reverse[value] = key
91 self._reverse[value] = key
98
92
99 def pop(self, key):
93 def pop(self, key):
100 value = dict.pop(self, key)
94 value = dict.pop(self, key)
101 self._reverse.pop(value)
95 self._reverse.pop(value)
102 return value
96 return value
103
97
104 def get(self, key, default=None):
98 def get(self, key, default=None):
105 try:
99 try:
106 return self[key]
100 return self[key]
107 except KeyError:
101 except KeyError:
108 return default
102 return default
109
103
110 #-----------------------------------------------------------------------------
104 #-----------------------------------------------------------------------------
111 # Functions
105 # Functions
112 #-----------------------------------------------------------------------------
106 #-----------------------------------------------------------------------------
113
107
114 @decorator
108 @decorator
115 def log_errors(f, self, *args, **kwargs):
109 def log_errors(f, self, *args, **kwargs):
116 """decorator to log unhandled exceptions raised in a method.
110 """decorator to log unhandled exceptions raised in a method.
117
111
118 For use wrapping on_recv callbacks, so that exceptions
112 For use wrapping on_recv callbacks, so that exceptions
119 do not cause the stream to be closed.
113 do not cause the stream to be closed.
120 """
114 """
121 try:
115 try:
122 return f(self, *args, **kwargs)
116 return f(self, *args, **kwargs)
123 except Exception:
117 except Exception:
124 self.log.error("Uncaught exception in %r" % f, exc_info=True)
118 self.log.error("Uncaught exception in %r" % f, exc_info=True)
125
119
126
120
127 def is_url(url):
121 def is_url(url):
128 """boolean check for whether a string is a zmq url"""
122 """boolean check for whether a string is a zmq url"""
129 if '://' not in url:
123 if '://' not in url:
130 return False
124 return False
131 proto, addr = url.split('://', 1)
125 proto, addr = url.split('://', 1)
132 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
126 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
133 return False
127 return False
134 return True
128 return True
135
129
136 def validate_url(url):
130 def validate_url(url):
137 """validate a url for zeromq"""
131 """validate a url for zeromq"""
138 if not isinstance(url, basestring):
132 if not isinstance(url, basestring):
139 raise TypeError("url must be a string, not %r"%type(url))
133 raise TypeError("url must be a string, not %r"%type(url))
140 url = url.lower()
134 url = url.lower()
141
135
142 proto_addr = url.split('://')
136 proto_addr = url.split('://')
143 assert len(proto_addr) == 2, 'Invalid url: %r'%url
137 assert len(proto_addr) == 2, 'Invalid url: %r'%url
144 proto, addr = proto_addr
138 proto, addr = proto_addr
145 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
139 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
146
140
147 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
141 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
148 # author: Remi Sabourin
142 # author: Remi Sabourin
149 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
143 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
150
144
151 if proto == 'tcp':
145 if proto == 'tcp':
152 lis = addr.split(':')
146 lis = addr.split(':')
153 assert len(lis) == 2, 'Invalid url: %r'%url
147 assert len(lis) == 2, 'Invalid url: %r'%url
154 addr,s_port = lis
148 addr,s_port = lis
155 try:
149 try:
156 port = int(s_port)
150 port = int(s_port)
157 except ValueError:
151 except ValueError:
158 raise AssertionError("Invalid port %r in url: %r"%(port, url))
152 raise AssertionError("Invalid port %r in url: %r"%(port, url))
159
153
160 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
154 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
161
155
162 else:
156 else:
163 # only validate tcp urls currently
157 # only validate tcp urls currently
164 pass
158 pass
165
159
166 return True
160 return True
167
161
168
162
169 def validate_url_container(container):
163 def validate_url_container(container):
170 """validate a potentially nested collection of urls."""
164 """validate a potentially nested collection of urls."""
171 if isinstance(container, basestring):
165 if isinstance(container, basestring):
172 url = container
166 url = container
173 return validate_url(url)
167 return validate_url(url)
174 elif isinstance(container, dict):
168 elif isinstance(container, dict):
175 container = container.itervalues()
169 container = container.itervalues()
176
170
177 for element in container:
171 for element in container:
178 validate_url_container(element)
172 validate_url_container(element)
179
173
180
174
181 def split_url(url):
175 def split_url(url):
182 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
176 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
183 proto_addr = url.split('://')
177 proto_addr = url.split('://')
184 assert len(proto_addr) == 2, 'Invalid url: %r'%url
178 assert len(proto_addr) == 2, 'Invalid url: %r'%url
185 proto, addr = proto_addr
179 proto, addr = proto_addr
186 lis = addr.split(':')
180 lis = addr.split(':')
187 assert len(lis) == 2, 'Invalid url: %r'%url
181 assert len(lis) == 2, 'Invalid url: %r'%url
188 addr,s_port = lis
182 addr,s_port = lis
189 return proto,addr,s_port
183 return proto,addr,s_port
190
184
191 def disambiguate_ip_address(ip, location=None):
185 def disambiguate_ip_address(ip, location=None):
192 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
186 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
193 ones, based on the location (default interpretation of location is localhost)."""
187 ones, based on the location (default interpretation of location is localhost)."""
194 if ip in ('0.0.0.0', '*'):
188 if ip in ('0.0.0.0', '*'):
195 try:
189 try:
196 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
190 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
197 except (socket.gaierror, IndexError):
191 except (socket.gaierror, IndexError):
198 # couldn't identify this machine, assume localhost
192 # couldn't identify this machine, assume localhost
199 external_ips = []
193 external_ips = []
200 if location is None or location in external_ips or not external_ips:
194 if location is None or location in external_ips or not external_ips:
201 # If location is unspecified or cannot be determined, assume local
195 # If location is unspecified or cannot be determined, assume local
202 ip='127.0.0.1'
196 ip='127.0.0.1'
203 elif location:
197 elif location:
204 return location
198 return location
205 return ip
199 return ip
206
200
207 def disambiguate_url(url, location=None):
201 def disambiguate_url(url, location=None):
208 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
202 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
209 ones, based on the location (default interpretation is localhost).
203 ones, based on the location (default interpretation is localhost).
210
204
211 This is for zeromq urls, such as tcp://*:10101."""
205 This is for zeromq urls, such as tcp://*:10101."""
212 try:
206 try:
213 proto,ip,port = split_url(url)
207 proto,ip,port = split_url(url)
214 except AssertionError:
208 except AssertionError:
215 # probably not tcp url; could be ipc, etc.
209 # probably not tcp url; could be ipc, etc.
216 return url
210 return url
217
211
218 ip = disambiguate_ip_address(ip,location)
212 ip = disambiguate_ip_address(ip,location)
219
213
220 return "%s://%s:%s"%(proto,ip,port)
214 return "%s://%s:%s"%(proto,ip,port)
221
215
222
216
223 #--------------------------------------------------------------------------
217 #--------------------------------------------------------------------------
224 # helpers for implementing old MEC API via view.apply
218 # helpers for implementing old MEC API via view.apply
225 #--------------------------------------------------------------------------
219 #--------------------------------------------------------------------------
226
220
227 def interactive(f):
221 def interactive(f):
228 """decorator for making functions appear as interactively defined.
222 """decorator for making functions appear as interactively defined.
229 This results in the function being linked to the user_ns as globals()
223 This results in the function being linked to the user_ns as globals()
230 instead of the module globals().
224 instead of the module globals().
231 """
225 """
232 f.__module__ = '__main__'
226 f.__module__ = '__main__'
233 return f
227 return f
234
228
235 @interactive
229 @interactive
236 def _push(**ns):
230 def _push(**ns):
237 """helper method for implementing `client.push` via `client.apply`"""
231 """helper method for implementing `client.push` via `client.apply`"""
238 globals().update(ns)
232 globals().update(ns)
239
233
240 @interactive
234 @interactive
241 def _pull(keys):
235 def _pull(keys):
242 """helper method for implementing `client.pull` via `client.apply`"""
236 """helper method for implementing `client.pull` via `client.apply`"""
243 user_ns = globals()
237 user_ns = globals()
244 if isinstance(keys, (list,tuple, set)):
238 if isinstance(keys, (list,tuple, set)):
245 for key in keys:
239 for key in keys:
246 if key not in user_ns:
240 if key not in user_ns:
247 raise NameError("name '%s' is not defined"%key)
241 raise NameError("name '%s' is not defined"%key)
248 return map(user_ns.get, keys)
242 return map(user_ns.get, keys)
249 else:
243 else:
250 if keys not in user_ns:
244 if keys not in user_ns:
251 raise NameError("name '%s' is not defined"%keys)
245 raise NameError("name '%s' is not defined"%keys)
252 return user_ns.get(keys)
246 return user_ns.get(keys)
253
247
254 @interactive
248 @interactive
255 def _execute(code):
249 def _execute(code):
256 """helper method for implementing `client.execute` via `client.apply`"""
250 """helper method for implementing `client.execute` via `client.apply`"""
257 exec code in globals()
251 exec code in globals()
258
252
259 #--------------------------------------------------------------------------
253 #--------------------------------------------------------------------------
260 # extra process management utilities
254 # extra process management utilities
261 #--------------------------------------------------------------------------
255 #--------------------------------------------------------------------------
262
256
263 _random_ports = set()
257 _random_ports = set()
264
258
265 def select_random_ports(n):
259 def select_random_ports(n):
266 """Selects and return n random ports that are available."""
260 """Selects and return n random ports that are available."""
267 ports = []
261 ports = []
268 for i in xrange(n):
262 for i in xrange(n):
269 sock = socket.socket()
263 sock = socket.socket()
270 sock.bind(('', 0))
264 sock.bind(('', 0))
271 while sock.getsockname()[1] in _random_ports:
265 while sock.getsockname()[1] in _random_ports:
272 sock.close()
266 sock.close()
273 sock = socket.socket()
267 sock = socket.socket()
274 sock.bind(('', 0))
268 sock.bind(('', 0))
275 ports.append(sock)
269 ports.append(sock)
276 for i, sock in enumerate(ports):
270 for i, sock in enumerate(ports):
277 port = sock.getsockname()[1]
271 port = sock.getsockname()[1]
278 sock.close()
272 sock.close()
279 ports[i] = port
273 ports[i] = port
280 _random_ports.add(port)
274 _random_ports.add(port)
281 return ports
275 return ports
282
276
283 def signal_children(children):
277 def signal_children(children):
284 """Relay interupt/term signals to children, for more solid process cleanup."""
278 """Relay interupt/term signals to children, for more solid process cleanup."""
285 def terminate_children(sig, frame):
279 def terminate_children(sig, frame):
286 log = Application.instance().log
280 log = Application.instance().log
287 log.critical("Got signal %i, terminating children..."%sig)
281 log.critical("Got signal %i, terminating children..."%sig)
288 for child in children:
282 for child in children:
289 child.terminate()
283 child.terminate()
290
284
291 sys.exit(sig != SIGINT)
285 sys.exit(sig != SIGINT)
292 # sys.exit(sig)
286 # sys.exit(sig)
293 for sig in (SIGINT, SIGABRT, SIGTERM):
287 for sig in (SIGINT, SIGABRT, SIGTERM):
294 signal(sig, terminate_children)
288 signal(sig, terminate_children)
295
289
296 def generate_exec_key(keyfile):
290 def generate_exec_key(keyfile):
297 import uuid
291 import uuid
298 newkey = str(uuid.uuid4())
292 newkey = str(uuid.uuid4())
299 with open(keyfile, 'w') as f:
293 with open(keyfile, 'w') as f:
300 # f.write('ipython-key ')
294 # f.write('ipython-key ')
301 f.write(newkey+'\n')
295 f.write(newkey+'\n')
302 # set user-only RW permissions (0600)
296 # set user-only RW permissions (0600)
303 # this will have no effect on Windows
297 # this will have no effect on Windows
304 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
298 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
305
299
306
300
307 def integer_loglevel(loglevel):
301 def integer_loglevel(loglevel):
308 try:
302 try:
309 loglevel = int(loglevel)
303 loglevel = int(loglevel)
310 except ValueError:
304 except ValueError:
311 if isinstance(loglevel, str):
305 if isinstance(loglevel, str):
312 loglevel = getattr(logging, loglevel)
306 loglevel = getattr(logging, loglevel)
313 return loglevel
307 return loglevel
314
308
315 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
309 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
316 logger = logging.getLogger(logname)
310 logger = logging.getLogger(logname)
317 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
311 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
318 # don't add a second PUBHandler
312 # don't add a second PUBHandler
319 return
313 return
320 loglevel = integer_loglevel(loglevel)
314 loglevel = integer_loglevel(loglevel)
321 lsock = context.socket(zmq.PUB)
315 lsock = context.socket(zmq.PUB)
322 lsock.connect(iface)
316 lsock.connect(iface)
323 handler = handlers.PUBHandler(lsock)
317 handler = handlers.PUBHandler(lsock)
324 handler.setLevel(loglevel)
318 handler.setLevel(loglevel)
325 handler.root_topic = root
319 handler.root_topic = root
326 logger.addHandler(handler)
320 logger.addHandler(handler)
327 logger.setLevel(loglevel)
321 logger.setLevel(loglevel)
328
322
329 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
323 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
330 logger = logging.getLogger()
324 logger = logging.getLogger()
331 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
325 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
332 # don't add a second PUBHandler
326 # don't add a second PUBHandler
333 return
327 return
334 loglevel = integer_loglevel(loglevel)
328 loglevel = integer_loglevel(loglevel)
335 lsock = context.socket(zmq.PUB)
329 lsock = context.socket(zmq.PUB)
336 lsock.connect(iface)
330 lsock.connect(iface)
337 handler = EnginePUBHandler(engine, lsock)
331 handler = EnginePUBHandler(engine, lsock)
338 handler.setLevel(loglevel)
332 handler.setLevel(loglevel)
339 logger.addHandler(handler)
333 logger.addHandler(handler)
340 logger.setLevel(loglevel)
334 logger.setLevel(loglevel)
341 return logger
335 return logger
342
336
343 def local_logger(logname, loglevel=logging.DEBUG):
337 def local_logger(logname, loglevel=logging.DEBUG):
344 loglevel = integer_loglevel(loglevel)
338 loglevel = integer_loglevel(loglevel)
345 logger = logging.getLogger(logname)
339 logger = logging.getLogger(logname)
346 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
340 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
347 # don't add a second StreamHandler
341 # don't add a second StreamHandler
348 return
342 return
349 handler = logging.StreamHandler()
343 handler = logging.StreamHandler()
350 handler.setLevel(loglevel)
344 handler.setLevel(loglevel)
351 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
345 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
352 datefmt="%Y-%m-%d %H:%M:%S")
346 datefmt="%Y-%m-%d %H:%M:%S")
353 handler.setFormatter(formatter)
347 handler.setFormatter(formatter)
354
348
355 logger.addHandler(handler)
349 logger.addHandler(handler)
356 logger.setLevel(loglevel)
350 logger.setLevel(loglevel)
357 return logger
351 return logger
358
352
@@ -1,151 +1,232 b''
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """Pickle related utilities. Perhaps this should be called 'can'."""
3 """Pickle related utilities. Perhaps this should be called 'can'."""
4
4
5 __docformat__ = "restructuredtext en"
5 __docformat__ = "restructuredtext en"
6
6
7 #-------------------------------------------------------------------------------
7 #-------------------------------------------------------------------------------
8 # Copyright (C) 2008-2011 The IPython Development Team
8 # Copyright (C) 2008-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14 #-------------------------------------------------------------------------------
14 #-------------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-------------------------------------------------------------------------------
16 #-------------------------------------------------------------------------------
17
17
18 import copy
18 import copy
19 import sys
19 import sys
20 from types import FunctionType
20 from types import FunctionType
21
21
22 try:
23 import cPickle as pickle
24 except ImportError:
25 import pickle
26
27 try:
28 import numpy
29 except:
30 numpy = None
31
22 import codeutil
32 import codeutil
33 import py3compat
34 from importstring import import_item
35
36 if py3compat.PY3:
37 buffer = memoryview
23
38
24 #-------------------------------------------------------------------------------
39 #-------------------------------------------------------------------------------
25 # Classes
40 # Classes
26 #-------------------------------------------------------------------------------
41 #-------------------------------------------------------------------------------
27
42
28
43
29 class CannedObject(object):
44 class CannedObject(object):
30 def __init__(self, obj, keys=[]):
45 def __init__(self, obj, keys=[]):
31 self.keys = keys
46 self.keys = keys
32 self.obj = copy.copy(obj)
47 self.obj = copy.copy(obj)
33 for key in keys:
48 for key in keys:
34 setattr(self.obj, key, can(getattr(obj, key)))
49 setattr(self.obj, key, can(getattr(obj, key)))
50
51 self.buffers = []
35
52
36
53 def get_object(self, g=None):
37 def getObject(self, g=None):
38 if g is None:
54 if g is None:
39 g = globals()
55 g = {}
40 for key in self.keys:
56 for key in self.keys:
41 setattr(self.obj, key, uncan(getattr(self.obj, key), g))
57 setattr(self.obj, key, uncan(getattr(self.obj, key), g))
42 return self.obj
58 return self.obj
59
43
60
44 class Reference(CannedObject):
61 class Reference(CannedObject):
45 """object for wrapping a remote reference by name."""
62 """object for wrapping a remote reference by name."""
46 def __init__(self, name):
63 def __init__(self, name):
47 if not isinstance(name, basestring):
64 if not isinstance(name, basestring):
48 raise TypeError("illegal name: %r"%name)
65 raise TypeError("illegal name: %r"%name)
49 self.name = name
66 self.name = name
67 self.buffers = []
50
68
51 def __repr__(self):
69 def __repr__(self):
52 return "<Reference: %r>"%self.name
70 return "<Reference: %r>"%self.name
53
71
54 def getObject(self, g=None):
72 def get_object(self, g=None):
55 if g is None:
73 if g is None:
56 g = globals()
74 g = {}
57
75
58 return eval(self.name, g)
76 return eval(self.name, g)
59
77
60
78
61 class CannedFunction(CannedObject):
79 class CannedFunction(CannedObject):
62
80
63 def __init__(self, f):
81 def __init__(self, f):
64 self._checkType(f)
82 self._check_type(f)
65 self.code = f.func_code
83 self.code = f.func_code
66 self.defaults = f.func_defaults
84 self.defaults = f.func_defaults
67 self.module = f.__module__ or '__main__'
85 self.module = f.__module__ or '__main__'
68 self.__name__ = f.__name__
86 self.__name__ = f.__name__
87 self.buffers = []
69
88
70 def _checkType(self, obj):
89 def _check_type(self, obj):
71 assert isinstance(obj, FunctionType), "Not a function type"
90 assert isinstance(obj, FunctionType), "Not a function type"
72
91
73 def getObject(self, g=None):
92 def get_object(self, g=None):
74 # try to load function back into its module:
93 # try to load function back into its module:
75 if not self.module.startswith('__'):
94 if not self.module.startswith('__'):
76 try:
95 try:
77 __import__(self.module)
96 __import__(self.module)
78 except ImportError:
97 except ImportError:
79 pass
98 pass
80 else:
99 else:
81 g = sys.modules[self.module].__dict__
100 g = sys.modules[self.module].__dict__
82
101
83 if g is None:
102 if g is None:
84 g = globals()
103 g = {}
85 newFunc = FunctionType(self.code, g, self.__name__, self.defaults)
104 newFunc = FunctionType(self.code, g, self.__name__, self.defaults)
86 return newFunc
105 return newFunc
87
106
107
108 class CannedArray(CannedObject):
109 def __init__(self, obj):
110 self.shape = obj.shape
111 self.dtype = obj.dtype
112 if sum(obj.shape) == 0:
113 # just pickle it
114 self.buffers = [pickle.dumps(obj, -1)]
115 else:
116 # ensure contiguous
117 obj = numpy.ascontiguousarray(obj, dtype=None)
118 self.buffers = [buffer(obj)]
119
120 def get_object(self, g=None):
121 data = self.buffers[0]
122 if sum(self.shape) == 0:
123 # no shape, we just pickled it
124 return pickle.loads(data)
125 else:
126 return numpy.frombuffer(data, dtype=self.dtype).reshape(self.shape)
127
128
129 class CannedBytes(CannedObject):
130 wrap = bytes
131 def __init__(self, obj):
132 self.buffers = [obj]
133
134 def get_object(self, g=None):
135 data = self.buffers[0]
136 return self.wrap(data)
137
138 def CannedBuffer(CannedBytes):
139 wrap = buffer
140
88 #-------------------------------------------------------------------------------
141 #-------------------------------------------------------------------------------
89 # Functions
142 # Functions
90 #-------------------------------------------------------------------------------
143 #-------------------------------------------------------------------------------
91
144
92 def can(obj):
93 # import here to prevent module-level circular imports
94 from IPython.parallel import dependent
95 if isinstance(obj, dependent):
96 keys = ('f','df')
97 return CannedObject(obj, keys=keys)
98 elif isinstance(obj, FunctionType):
99 return CannedFunction(obj)
100 elif isinstance(obj,dict):
101 return canDict(obj)
102 elif isinstance(obj, (list,tuple)):
103 return canSequence(obj)
104 else:
105 return obj
106
145
107 def canDict(obj):
146 def can(obj):
147 """prepare an object for pickling"""
148 for cls,canner in can_map.iteritems():
149 if isinstance(cls, basestring):
150 try:
151 cls = import_item(cls)
152 except Exception:
153 # not importable
154 print "not importable: %r" % cls
155 continue
156 if isinstance(obj, cls):
157 return canner(obj)
158 return obj
159
160 def can_dict(obj):
161 """can the *values* of a dict"""
108 if isinstance(obj, dict):
162 if isinstance(obj, dict):
109 newobj = {}
163 newobj = {}
110 for k, v in obj.iteritems():
164 for k, v in obj.iteritems():
111 newobj[k] = can(v)
165 newobj[k] = can(v)
112 return newobj
166 return newobj
113 else:
167 else:
114 return obj
168 return obj
115
169
116 def canSequence(obj):
170 def can_sequence(obj):
171 """can the elements of a sequence"""
117 if isinstance(obj, (list, tuple)):
172 if isinstance(obj, (list, tuple)):
118 t = type(obj)
173 t = type(obj)
119 return t([can(i) for i in obj])
174 return t([can(i) for i in obj])
120 else:
175 else:
121 return obj
176 return obj
122
177
123 def uncan(obj, g=None):
178 def uncan(obj, g=None):
124 if isinstance(obj, CannedObject):
179 """invert canning"""
125 return obj.getObject(g)
180 for cls,uncanner in uncan_map.iteritems():
126 elif isinstance(obj,dict):
181 if isinstance(cls, basestring):
127 return uncanDict(obj, g)
182 try:
128 elif isinstance(obj, (list,tuple)):
183 cls = import_item(cls)
129 return uncanSequence(obj, g)
184 except Exception:
130 else:
185 # not importable
131 return obj
186 print "not importable: %r" % cls
132
187 continue
133 def uncanDict(obj, g=None):
188 if isinstance(obj, cls):
189 return uncanner(obj, g)
190 return obj
191
192 def uncan_dict(obj, g=None):
134 if isinstance(obj, dict):
193 if isinstance(obj, dict):
135 newobj = {}
194 newobj = {}
136 for k, v in obj.iteritems():
195 for k, v in obj.iteritems():
137 newobj[k] = uncan(v,g)
196 newobj[k] = uncan(v,g)
138 return newobj
197 return newobj
139 else:
198 else:
140 return obj
199 return obj
141
200
142 def uncanSequence(obj, g=None):
201 def uncan_sequence(obj, g=None):
143 if isinstance(obj, (list, tuple)):
202 if isinstance(obj, (list, tuple)):
144 t = type(obj)
203 t = type(obj)
145 return t([uncan(i,g) for i in obj])
204 return t([uncan(i,g) for i in obj])
146 else:
205 else:
147 return obj
206 return obj
148
207
149
208
150 def rebindFunctionGlobals(f, glbls):
209 #-------------------------------------------------------------------------------
151 return FunctionType(f.func_code, glbls)
210 # API dictionary
211 #-------------------------------------------------------------------------------
212
213 # These dicts can be extended for custom serialization of new objects
214
215 can_map = {
216 'IPython.parallel.dependent' : lambda obj: CannedObject(obj, keys=('f','df')),
217 'numpy.ndarray' : CannedArray,
218 FunctionType : CannedFunction,
219 bytes : CannedBytes,
220 buffer : CannedBuffer,
221 # dict : can_dict,
222 # list : can_sequence,
223 # tuple : can_sequence,
224 }
225
226 uncan_map = {
227 CannedObject : lambda obj, g: obj.get_object(g),
228 # dict : uncan_dict,
229 # list : uncan_sequence,
230 # tuple : uncan_sequence,
231 }
232
@@ -1,925 +1,925 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """A simple interactive kernel that talks to a frontend over 0MQ.
2 """A simple interactive kernel that talks to a frontend over 0MQ.
3
3
4 Things to do:
4 Things to do:
5
5
6 * Implement `set_parent` logic. Right before doing exec, the Kernel should
6 * Implement `set_parent` logic. Right before doing exec, the Kernel should
7 call set_parent on all the PUB objects with the message about to be executed.
7 call set_parent on all the PUB objects with the message about to be executed.
8 * Implement random port and security key logic.
8 * Implement random port and security key logic.
9 * Implement control messages.
9 * Implement control messages.
10 * Implement event loop and poll version.
10 * Implement event loop and poll version.
11 """
11 """
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 from __future__ import print_function
16 from __future__ import print_function
17
17
18 # Standard library imports
18 # Standard library imports
19 import __builtin__
19 import __builtin__
20 import atexit
20 import atexit
21 import sys
21 import sys
22 import time
22 import time
23 import traceback
23 import traceback
24 import logging
24 import logging
25 import uuid
25 import uuid
26
26
27 from datetime import datetime
27 from datetime import datetime
28 from signal import (
28 from signal import (
29 signal, getsignal, default_int_handler, SIGINT, SIG_IGN
29 signal, getsignal, default_int_handler, SIGINT, SIG_IGN
30 )
30 )
31
31
32 # System library imports
32 # System library imports
33 import zmq
33 import zmq
34 from zmq.eventloop import ioloop
34 from zmq.eventloop import ioloop
35 from zmq.eventloop.zmqstream import ZMQStream
35 from zmq.eventloop.zmqstream import ZMQStream
36
36
37 # Local imports
37 # Local imports
38 from IPython.config.configurable import Configurable
38 from IPython.config.configurable import Configurable
39 from IPython.config.application import boolean_flag, catch_config_error
39 from IPython.config.application import boolean_flag, catch_config_error
40 from IPython.core.application import ProfileDir
40 from IPython.core.application import ProfileDir
41 from IPython.core.error import StdinNotImplementedError
41 from IPython.core.error import StdinNotImplementedError
42 from IPython.core.shellapp import (
42 from IPython.core.shellapp import (
43 InteractiveShellApp, shell_flags, shell_aliases
43 InteractiveShellApp, shell_flags, shell_aliases
44 )
44 )
45 from IPython.utils import io
45 from IPython.utils import io
46 from IPython.utils import py3compat
46 from IPython.utils import py3compat
47 from IPython.utils.frame import extract_module_locals
47 from IPython.utils.frame import extract_module_locals
48 from IPython.utils.jsonutil import json_clean
48 from IPython.utils.jsonutil import json_clean
49 from IPython.utils.traitlets import (
49 from IPython.utils.traitlets import (
50 Any, Instance, Float, Dict, CaselessStrEnum, List, Set, Integer, Unicode
50 Any, Instance, Float, Dict, CaselessStrEnum, List, Set, Integer, Unicode
51 )
51 )
52
52
53 from entry_point import base_launch_kernel
53 from entry_point import base_launch_kernel
54 from kernelapp import KernelApp, kernel_flags, kernel_aliases
54 from kernelapp import KernelApp, kernel_flags, kernel_aliases
55 from serialize import serialize_object, unpack_apply_message
55 from serialize import serialize_object, unpack_apply_message
56 from session import Session, Message
56 from session import Session, Message
57 from zmqshell import ZMQInteractiveShell
57 from zmqshell import ZMQInteractiveShell
58
58
59
59
60 #-----------------------------------------------------------------------------
60 #-----------------------------------------------------------------------------
61 # Main kernel class
61 # Main kernel class
62 #-----------------------------------------------------------------------------
62 #-----------------------------------------------------------------------------
63
63
64 class Kernel(Configurable):
64 class Kernel(Configurable):
65
65
66 #---------------------------------------------------------------------------
66 #---------------------------------------------------------------------------
67 # Kernel interface
67 # Kernel interface
68 #---------------------------------------------------------------------------
68 #---------------------------------------------------------------------------
69
69
70 # attribute to override with a GUI
70 # attribute to override with a GUI
71 eventloop = Any(None)
71 eventloop = Any(None)
72 def _eventloop_changed(self, name, old, new):
72 def _eventloop_changed(self, name, old, new):
73 """schedule call to eventloop from IOLoop"""
73 """schedule call to eventloop from IOLoop"""
74 loop = ioloop.IOLoop.instance()
74 loop = ioloop.IOLoop.instance()
75 loop.add_timeout(time.time()+0.1, self.enter_eventloop)
75 loop.add_timeout(time.time()+0.1, self.enter_eventloop)
76
76
77 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
77 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
78 session = Instance(Session)
78 session = Instance(Session)
79 profile_dir = Instance('IPython.core.profiledir.ProfileDir')
79 profile_dir = Instance('IPython.core.profiledir.ProfileDir')
80 shell_streams = List()
80 shell_streams = List()
81 control_stream = Instance(ZMQStream)
81 control_stream = Instance(ZMQStream)
82 iopub_socket = Instance(zmq.Socket)
82 iopub_socket = Instance(zmq.Socket)
83 stdin_socket = Instance(zmq.Socket)
83 stdin_socket = Instance(zmq.Socket)
84 log = Instance(logging.Logger)
84 log = Instance(logging.Logger)
85
85
86 user_module = Any()
86 user_module = Any()
87 def _user_module_changed(self, name, old, new):
87 def _user_module_changed(self, name, old, new):
88 if self.shell is not None:
88 if self.shell is not None:
89 self.shell.user_module = new
89 self.shell.user_module = new
90
90
91 user_ns = Dict(default_value=None)
91 user_ns = Dict(default_value=None)
92 def _user_ns_changed(self, name, old, new):
92 def _user_ns_changed(self, name, old, new):
93 if self.shell is not None:
93 if self.shell is not None:
94 self.shell.user_ns = new
94 self.shell.user_ns = new
95 self.shell.init_user_ns()
95 self.shell.init_user_ns()
96
96
97 # identities:
97 # identities:
98 int_id = Integer(-1)
98 int_id = Integer(-1)
99 ident = Unicode()
99 ident = Unicode()
100
100
101 def _ident_default(self):
101 def _ident_default(self):
102 return unicode(uuid.uuid4())
102 return unicode(uuid.uuid4())
103
103
104
104
105 # Private interface
105 # Private interface
106
106
107 # Time to sleep after flushing the stdout/err buffers in each execute
107 # Time to sleep after flushing the stdout/err buffers in each execute
108 # cycle. While this introduces a hard limit on the minimal latency of the
108 # cycle. While this introduces a hard limit on the minimal latency of the
109 # execute cycle, it helps prevent output synchronization problems for
109 # execute cycle, it helps prevent output synchronization problems for
110 # clients.
110 # clients.
111 # Units are in seconds. The minimum zmq latency on local host is probably
111 # Units are in seconds. The minimum zmq latency on local host is probably
112 # ~150 microseconds, set this to 500us for now. We may need to increase it
112 # ~150 microseconds, set this to 500us for now. We may need to increase it
113 # a little if it's not enough after more interactive testing.
113 # a little if it's not enough after more interactive testing.
114 _execute_sleep = Float(0.0005, config=True)
114 _execute_sleep = Float(0.0005, config=True)
115
115
116 # Frequency of the kernel's event loop.
116 # Frequency of the kernel's event loop.
117 # Units are in seconds, kernel subclasses for GUI toolkits may need to
117 # Units are in seconds, kernel subclasses for GUI toolkits may need to
118 # adapt to milliseconds.
118 # adapt to milliseconds.
119 _poll_interval = Float(0.05, config=True)
119 _poll_interval = Float(0.05, config=True)
120
120
121 # If the shutdown was requested over the network, we leave here the
121 # If the shutdown was requested over the network, we leave here the
122 # necessary reply message so it can be sent by our registered atexit
122 # necessary reply message so it can be sent by our registered atexit
123 # handler. This ensures that the reply is only sent to clients truly at
123 # handler. This ensures that the reply is only sent to clients truly at
124 # the end of our shutdown process (which happens after the underlying
124 # the end of our shutdown process (which happens after the underlying
125 # IPython shell's own shutdown).
125 # IPython shell's own shutdown).
126 _shutdown_message = None
126 _shutdown_message = None
127
127
128 # This is a dict of port number that the kernel is listening on. It is set
128 # This is a dict of port number that the kernel is listening on. It is set
129 # by record_ports and used by connect_request.
129 # by record_ports and used by connect_request.
130 _recorded_ports = Dict()
130 _recorded_ports = Dict()
131
131
132 # set of aborted msg_ids
132 # set of aborted msg_ids
133 aborted = Set()
133 aborted = Set()
134
134
135
135
136 def __init__(self, **kwargs):
136 def __init__(self, **kwargs):
137 super(Kernel, self).__init__(**kwargs)
137 super(Kernel, self).__init__(**kwargs)
138
138
139 # Initialize the InteractiveShell subclass
139 # Initialize the InteractiveShell subclass
140 self.shell = ZMQInteractiveShell.instance(config=self.config,
140 self.shell = ZMQInteractiveShell.instance(config=self.config,
141 profile_dir = self.profile_dir,
141 profile_dir = self.profile_dir,
142 user_module = self.user_module,
142 user_module = self.user_module,
143 user_ns = self.user_ns,
143 user_ns = self.user_ns,
144 )
144 )
145 self.shell.displayhook.session = self.session
145 self.shell.displayhook.session = self.session
146 self.shell.displayhook.pub_socket = self.iopub_socket
146 self.shell.displayhook.pub_socket = self.iopub_socket
147 self.shell.displayhook.topic = self._topic('pyout')
147 self.shell.displayhook.topic = self._topic('pyout')
148 self.shell.display_pub.session = self.session
148 self.shell.display_pub.session = self.session
149 self.shell.display_pub.pub_socket = self.iopub_socket
149 self.shell.display_pub.pub_socket = self.iopub_socket
150
150
151 # TMP - hack while developing
151 # TMP - hack while developing
152 self.shell._reply_content = None
152 self.shell._reply_content = None
153
153
154 # Build dict of handlers for message types
154 # Build dict of handlers for message types
155 msg_types = [ 'execute_request', 'complete_request',
155 msg_types = [ 'execute_request', 'complete_request',
156 'object_info_request', 'history_request',
156 'object_info_request', 'history_request',
157 'connect_request', 'shutdown_request',
157 'connect_request', 'shutdown_request',
158 'apply_request',
158 'apply_request',
159 ]
159 ]
160 self.shell_handlers = {}
160 self.shell_handlers = {}
161 for msg_type in msg_types:
161 for msg_type in msg_types:
162 self.shell_handlers[msg_type] = getattr(self, msg_type)
162 self.shell_handlers[msg_type] = getattr(self, msg_type)
163
163
164 control_msg_types = msg_types + [ 'clear_request', 'abort_request' ]
164 control_msg_types = msg_types + [ 'clear_request', 'abort_request' ]
165 self.control_handlers = {}
165 self.control_handlers = {}
166 for msg_type in control_msg_types:
166 for msg_type in control_msg_types:
167 self.control_handlers[msg_type] = getattr(self, msg_type)
167 self.control_handlers[msg_type] = getattr(self, msg_type)
168
168
169 def dispatch_control(self, msg):
169 def dispatch_control(self, msg):
170 """dispatch control requests"""
170 """dispatch control requests"""
171 idents,msg = self.session.feed_identities(msg, copy=False)
171 idents,msg = self.session.feed_identities(msg, copy=False)
172 try:
172 try:
173 msg = self.session.unserialize(msg, content=True, copy=False)
173 msg = self.session.unserialize(msg, content=True, copy=False)
174 except:
174 except:
175 self.log.error("Invalid Control Message", exc_info=True)
175 self.log.error("Invalid Control Message", exc_info=True)
176 return
176 return
177
177
178 self.log.debug("Control received: %s", msg)
178 self.log.debug("Control received: %s", msg)
179
179
180 header = msg['header']
180 header = msg['header']
181 msg_id = header['msg_id']
181 msg_id = header['msg_id']
182 msg_type = header['msg_type']
182 msg_type = header['msg_type']
183
183
184 handler = self.control_handlers.get(msg_type, None)
184 handler = self.control_handlers.get(msg_type, None)
185 if handler is None:
185 if handler is None:
186 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type)
186 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type)
187 else:
187 else:
188 try:
188 try:
189 handler(self.control_stream, idents, msg)
189 handler(self.control_stream, idents, msg)
190 except Exception:
190 except Exception:
191 self.log.error("Exception in control handler:", exc_info=True)
191 self.log.error("Exception in control handler:", exc_info=True)
192
192
193 def dispatch_shell(self, stream, msg):
193 def dispatch_shell(self, stream, msg):
194 """dispatch shell requests"""
194 """dispatch shell requests"""
195 # flush control requests first
195 # flush control requests first
196 if self.control_stream:
196 if self.control_stream:
197 self.control_stream.flush()
197 self.control_stream.flush()
198
198
199 idents,msg = self.session.feed_identities(msg, copy=False)
199 idents,msg = self.session.feed_identities(msg, copy=False)
200 try:
200 try:
201 msg = self.session.unserialize(msg, content=True, copy=False)
201 msg = self.session.unserialize(msg, content=True, copy=False)
202 except:
202 except:
203 self.log.error("Invalid Message", exc_info=True)
203 self.log.error("Invalid Message", exc_info=True)
204 return
204 return
205
205
206 header = msg['header']
206 header = msg['header']
207 msg_id = header['msg_id']
207 msg_id = header['msg_id']
208 msg_type = msg['header']['msg_type']
208 msg_type = msg['header']['msg_type']
209
209
210 # Print some info about this message and leave a '--->' marker, so it's
210 # Print some info about this message and leave a '--->' marker, so it's
211 # easier to trace visually the message chain when debugging. Each
211 # easier to trace visually the message chain when debugging. Each
212 # handler prints its message at the end.
212 # handler prints its message at the end.
213 self.log.debug('\n*** MESSAGE TYPE:%s***', msg_type)
213 self.log.debug('\n*** MESSAGE TYPE:%s***', msg_type)
214 self.log.debug(' Content: %s\n --->\n ', msg['content'])
214 self.log.debug(' Content: %s\n --->\n ', msg['content'])
215
215
216 if msg_id in self.aborted:
216 if msg_id in self.aborted:
217 self.aborted.remove(msg_id)
217 self.aborted.remove(msg_id)
218 # is it safe to assume a msg_id will not be resubmitted?
218 # is it safe to assume a msg_id will not be resubmitted?
219 reply_type = msg_type.split('_')[0] + '_reply'
219 reply_type = msg_type.split('_')[0] + '_reply'
220 status = {'status' : 'aborted'}
220 status = {'status' : 'aborted'}
221 md = {'engine' : self.ident}
221 md = {'engine' : self.ident}
222 md.update(status)
222 md.update(status)
223 reply_msg = self.session.send(stream, reply_type, metadata=md,
223 reply_msg = self.session.send(stream, reply_type, metadata=md,
224 content=status, parent=msg, ident=idents)
224 content=status, parent=msg, ident=idents)
225 return
225 return
226
226
227 handler = self.shell_handlers.get(msg_type, None)
227 handler = self.shell_handlers.get(msg_type, None)
228 if handler is None:
228 if handler is None:
229 self.log.error("UNKNOWN MESSAGE TYPE: %r", msg_type)
229 self.log.error("UNKNOWN MESSAGE TYPE: %r", msg_type)
230 else:
230 else:
231 # ensure default_int_handler during handler call
231 # ensure default_int_handler during handler call
232 sig = signal(SIGINT, default_int_handler)
232 sig = signal(SIGINT, default_int_handler)
233 try:
233 try:
234 handler(stream, idents, msg)
234 handler(stream, idents, msg)
235 except Exception:
235 except Exception:
236 self.log.error("Exception in message handler:", exc_info=True)
236 self.log.error("Exception in message handler:", exc_info=True)
237 finally:
237 finally:
238 signal(SIGINT, sig)
238 signal(SIGINT, sig)
239
239
240 def enter_eventloop(self):
240 def enter_eventloop(self):
241 """enter eventloop"""
241 """enter eventloop"""
242 self.log.info("entering eventloop")
242 self.log.info("entering eventloop")
243 # restore default_int_handler
243 # restore default_int_handler
244 signal(SIGINT, default_int_handler)
244 signal(SIGINT, default_int_handler)
245 while self.eventloop is not None:
245 while self.eventloop is not None:
246 try:
246 try:
247 self.eventloop(self)
247 self.eventloop(self)
248 except KeyboardInterrupt:
248 except KeyboardInterrupt:
249 # Ctrl-C shouldn't crash the kernel
249 # Ctrl-C shouldn't crash the kernel
250 self.log.error("KeyboardInterrupt caught in kernel")
250 self.log.error("KeyboardInterrupt caught in kernel")
251 continue
251 continue
252 else:
252 else:
253 # eventloop exited cleanly, this means we should stop (right?)
253 # eventloop exited cleanly, this means we should stop (right?)
254 self.eventloop = None
254 self.eventloop = None
255 break
255 break
256 self.log.info("exiting eventloop")
256 self.log.info("exiting eventloop")
257 # if eventloop exits, IOLoop should stop
257 # if eventloop exits, IOLoop should stop
258 ioloop.IOLoop.instance().stop()
258 ioloop.IOLoop.instance().stop()
259
259
260 def start(self):
260 def start(self):
261 """register dispatchers for streams"""
261 """register dispatchers for streams"""
262 self.shell.exit_now = False
262 self.shell.exit_now = False
263 if self.control_stream:
263 if self.control_stream:
264 self.control_stream.on_recv(self.dispatch_control, copy=False)
264 self.control_stream.on_recv(self.dispatch_control, copy=False)
265
265
266 def make_dispatcher(stream):
266 def make_dispatcher(stream):
267 def dispatcher(msg):
267 def dispatcher(msg):
268 return self.dispatch_shell(stream, msg)
268 return self.dispatch_shell(stream, msg)
269 return dispatcher
269 return dispatcher
270
270
271 for s in self.shell_streams:
271 for s in self.shell_streams:
272 s.on_recv(make_dispatcher(s), copy=False)
272 s.on_recv(make_dispatcher(s), copy=False)
273
273
274 def do_one_iteration(self):
274 def do_one_iteration(self):
275 """step eventloop just once"""
275 """step eventloop just once"""
276 if self.control_stream:
276 if self.control_stream:
277 self.control_stream.flush()
277 self.control_stream.flush()
278 for stream in self.shell_streams:
278 for stream in self.shell_streams:
279 # handle at most one request per iteration
279 # handle at most one request per iteration
280 stream.flush(zmq.POLLIN, 1)
280 stream.flush(zmq.POLLIN, 1)
281 stream.flush(zmq.POLLOUT)
281 stream.flush(zmq.POLLOUT)
282
282
283
283
284 def record_ports(self, ports):
284 def record_ports(self, ports):
285 """Record the ports that this kernel is using.
285 """Record the ports that this kernel is using.
286
286
287 The creator of the Kernel instance must call this methods if they
287 The creator of the Kernel instance must call this methods if they
288 want the :meth:`connect_request` method to return the port numbers.
288 want the :meth:`connect_request` method to return the port numbers.
289 """
289 """
290 self._recorded_ports = ports
290 self._recorded_ports = ports
291
291
292 #---------------------------------------------------------------------------
292 #---------------------------------------------------------------------------
293 # Kernel request handlers
293 # Kernel request handlers
294 #---------------------------------------------------------------------------
294 #---------------------------------------------------------------------------
295
295
296 def _make_metadata(self, other=None):
296 def _make_metadata(self, other=None):
297 """init metadata dict, for execute/apply_reply"""
297 """init metadata dict, for execute/apply_reply"""
298 new_md = {
298 new_md = {
299 'dependencies_met' : True,
299 'dependencies_met' : True,
300 'engine' : self.ident,
300 'engine' : self.ident,
301 'started': datetime.now(),
301 'started': datetime.now(),
302 }
302 }
303 if other:
303 if other:
304 new_md.update(other)
304 new_md.update(other)
305 return new_md
305 return new_md
306
306
307 def _publish_pyin(self, code, parent, execution_count):
307 def _publish_pyin(self, code, parent, execution_count):
308 """Publish the code request on the pyin stream."""
308 """Publish the code request on the pyin stream."""
309
309
310 self.session.send(self.iopub_socket, u'pyin',
310 self.session.send(self.iopub_socket, u'pyin',
311 {u'code':code, u'execution_count': execution_count},
311 {u'code':code, u'execution_count': execution_count},
312 parent=parent, ident=self._topic('pyin')
312 parent=parent, ident=self._topic('pyin')
313 )
313 )
314
314
315 def _publish_status(self, status, parent=None):
315 def _publish_status(self, status, parent=None):
316 """send status (busy/idle) on IOPub"""
316 """send status (busy/idle) on IOPub"""
317 self.session.send(self.iopub_socket,
317 self.session.send(self.iopub_socket,
318 u'status',
318 u'status',
319 {u'execution_state': status},
319 {u'execution_state': status},
320 parent=parent,
320 parent=parent,
321 ident=self._topic('status'),
321 ident=self._topic('status'),
322 )
322 )
323
323
324
324
def execute_request(self, stream, ident, parent):
    """Handle an execute_request: run code in the shell and reply.

    Publishes busy/idle status around the execution, re-broadcasts the
    input on IOPub (unless silent), evaluates user_variables/expressions
    on success, and sends an execute_reply with status metadata.
    """
    self._publish_status(u'busy', parent)

    try:
        content = parent[u'content']
        code = content[u'code']
        silent = content[u'silent']
    except Exception:
        # Malformed message: log and give up.  Narrowed from a bare
        # except so KeyboardInterrupt/SystemExit are not swallowed.
        self.log.error("Got bad msg: ")
        self.log.error("%s", parent)
        return

    md = self._make_metadata(parent['metadata'])

    shell = self.shell  # we'll need this a lot here

    # Replace raw_input. Note that is not sufficient to replace
    # raw_input in the user namespace.
    if content.get('allow_stdin', False):
        raw_input = lambda prompt='': self._raw_input(prompt, ident, parent)
    else:
        raw_input = lambda prompt='': self._no_raw_input()

    if py3compat.PY3:
        __builtin__.input = raw_input
    else:
        __builtin__.raw_input = raw_input

    # Set the parent message of the display hook and out streams.
    shell.displayhook.set_parent(parent)
    shell.display_pub.set_parent(parent)
    sys.stdout.set_parent(parent)
    sys.stderr.set_parent(parent)

    # Re-broadcast our input for the benefit of listening clients, and
    # start computing output
    if not silent:
        self._publish_pyin(code, parent, shell.execution_count)

    reply_content = {}
    try:
        # FIXME: the shell calls the exception handler itself.
        shell.run_cell(code, store_history=not silent, silent=silent)
    except:
        status = u'error'
        # FIXME: this code right now isn't being used yet by default,
        # because the run_cell() call above directly fires off exception
        # reporting.  This code, therefore, is only active in the scenario
        # where runlines itself has an unhandled exception.  We need to
        # uniformize this, for all exception construction to come from a
        # single location in the codebase.
        etype, evalue, tb = sys.exc_info()
        tb_list = traceback.format_exception(etype, evalue, tb)
        reply_content.update(shell._showtraceback(etype, evalue, tb_list))
    else:
        status = u'ok'

    reply_content[u'status'] = status

    # Return the execution counter so clients can display prompts
    reply_content['execution_count'] = shell.execution_count - 1

    # FIXME - fish exception info out of shell, possibly left there by
    # runlines.  We'll need to clean up this logic later.
    if shell._reply_content is not None:
        reply_content.update(shell._reply_content)
        e_info = dict(engine_uuid=self.ident, engine_id=self.int_id,
                      method='execute')
        reply_content['engine_info'] = e_info
        # reset after use
        shell._reply_content = None

    # At this point, we can tell whether the main code execution succeeded
    # or not.  If it did, we proceed to evaluate user_variables/expressions
    if reply_content['status'] == 'ok':
        reply_content[u'user_variables'] = \
            shell.user_variables(content.get(u'user_variables', []))
        reply_content[u'user_expressions'] = \
            shell.user_expressions(content.get(u'user_expressions', {}))
    else:
        # If there was an error, don't even try to compute variables or
        # expressions
        reply_content[u'user_variables'] = {}
        reply_content[u'user_expressions'] = {}

    # Payloads should be retrieved regardless of outcome, so we can both
    # recover partial output (that could have been generated early in a
    # block, before an error) and clear the payload system always.
    reply_content[u'payload'] = shell.payload_manager.read_payload()
    # Be aggressive about clearing the payload because we don't want
    # it to sit in memory until the next execute_request comes in.
    shell.payload_manager.clear_payload()

    # Flush output before sending the reply.
    sys.stdout.flush()
    sys.stderr.flush()
    # FIXME: on rare occasions, the flush doesn't seem to make it to the
    # clients... This seems to mitigate the problem, but we definitely need
    # to better understand what's going on.
    if self._execute_sleep:
        time.sleep(self._execute_sleep)

    # Send the reply.
    reply_content = json_clean(reply_content)

    md['status'] = reply_content['status']
    # .get: an error raised by run_cell itself (not through the shell's
    # own reporting) may not have populated 'ename' at all.
    if reply_content['status'] == 'error' and \
            reply_content.get('ename') == 'UnmetDependency':
        md['dependencies_met'] = False

    reply_msg = self.session.send(stream, u'execute_reply',
                                  reply_content, parent, metadata=md,
                                  ident=ident)

    self.log.debug("%s", reply_msg)

    if not silent and reply_msg['content']['status'] == u'error':
        self._abort_queues()

    self._publish_status(u'idle', parent)
446
446
def complete_request(self, stream, ident, parent):
    """Handle a complete_request: run tab-completion and reply on IOPub."""
    matched_text, match_list = self._complete(parent)
    reply = {
        'matches': match_list,
        'matched_text': matched_text,
        'status': 'ok',
    }
    reply = json_clean(reply)
    completion_msg = self.session.send(stream, 'complete_reply',
                                       reply, parent, ident)
    self.log.debug("%s", completion_msg)
456
456
def object_info_request(self, stream, ident, parent):
    """Handle an object_info_request: inspect a name and send the info back."""
    content = parent['content']
    detail_level = content.get('detail_level', 0)
    object_info = self.shell.object_inspect(content['oname'],
                                            detail_level=detail_level)
    # Before we send this object over, we scrub it for JSON usage
    oinfo = json_clean(object_info)
    msg = self.session.send(stream, 'object_info_reply',
                            oinfo, parent, ident)
    self.log.debug("%s", msg)
467
467
def history_request(self, stream, ident, parent):
    """Handle a history_request: query the history manager and reply.

    Supports 'tail', 'range', and 'search' access types; anything else
    yields an empty history list.
    """
    content = parent['content']
    # We need to pull these out, as passing **kwargs doesn't work with
    # unicode keys before Python 2.6.5.
    hist_access_type = content['hist_access_type']
    raw = content['raw']
    output = content['output']

    manager = self.shell.history_manager
    if hist_access_type == 'tail':
        hist = manager.get_tail(content['n'], raw=raw, output=output,
                                include_latest=True)
    elif hist_access_type == 'range':
        hist = manager.get_range(content['session'], content['start'],
                                 content['stop'], raw=raw, output=output)
    elif hist_access_type == 'search':
        hist = manager.search(content['pattern'], raw=raw, output=output)
    else:
        hist = []

    hist = list(hist)
    reply = json_clean({'history': hist})
    msg = self.session.send(stream, 'history_reply',
                            reply, parent, ident)
    self.log.debug("Sending history reply with %i entries", len(hist))
499
499
def connect_request(self, stream, ident, parent):
    """Handle a connect_request: reply with the kernel's recorded ports."""
    if self._recorded_ports is None:
        content = {}
    else:
        content = self._recorded_ports.copy()
    msg = self.session.send(stream, 'connect_reply',
                            content, parent, ident)
    self.log.debug("%s", msg)
508
508
def shutdown_request(self, stream, ident, parent):
    """Handle a shutdown_request: ack the caller, broadcast, stop the loop."""
    self.shell.exit_now = True
    content = dict(status='ok')
    content.update(parent['content'])
    self.session.send(stream, u'shutdown_reply', content, parent, ident=ident)
    # same content, but different msg_id for broadcasting on IOPub
    self._shutdown_message = self.session.msg(u'shutdown_reply', content, parent)

    self._at_shutdown()
    # call sys.exit after a short delay
    loop = ioloop.IOLoop.instance()
    loop.add_timeout(time.time() + 0.1, loop.stop)
523
523
524 #---------------------------------------------------------------------------
524 #---------------------------------------------------------------------------
525 # Engine methods
525 # Engine methods
526 #---------------------------------------------------------------------------
526 #---------------------------------------------------------------------------
527
527
528 def apply_request(self, stream, ident, parent):
528 def apply_request(self, stream, ident, parent):
529 try:
529 try:
530 content = parent[u'content']
530 content = parent[u'content']
531 bufs = parent[u'buffers']
531 bufs = parent[u'buffers']
532 msg_id = parent['header']['msg_id']
532 msg_id = parent['header']['msg_id']
533 except:
533 except:
534 self.log.error("Got bad msg: %s", parent, exc_info=True)
534 self.log.error("Got bad msg: %s", parent, exc_info=True)
535 return
535 return
536
536
537 self._publish_status(u'busy', parent)
537 self._publish_status(u'busy', parent)
538
538
539 # Set the parent message of the display hook and out streams.
539 # Set the parent message of the display hook and out streams.
540 shell = self.shell
540 shell = self.shell
541 shell.displayhook.set_parent(parent)
541 shell.displayhook.set_parent(parent)
542 shell.display_pub.set_parent(parent)
542 shell.display_pub.set_parent(parent)
543 sys.stdout.set_parent(parent)
543 sys.stdout.set_parent(parent)
544 sys.stderr.set_parent(parent)
544 sys.stderr.set_parent(parent)
545
545
546 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
546 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
547 # self.iopub_socket.send(pyin_msg)
547 # self.iopub_socket.send(pyin_msg)
548 # self.session.send(self.iopub_socket, u'pyin', {u'code':code},parent=parent)
548 # self.session.send(self.iopub_socket, u'pyin', {u'code':code},parent=parent)
549 md = self._make_metadata(parent['metadata'])
549 md = self._make_metadata(parent['metadata'])
550 try:
550 try:
551 working = shell.user_ns
551 working = shell.user_ns
552
552
553 prefix = "_"+str(msg_id).replace("-","")+"_"
553 prefix = "_"+str(msg_id).replace("-","")+"_"
554
554
555 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
555 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
556
556
557 fname = getattr(f, '__name__', 'f')
557 fname = getattr(f, '__name__', 'f')
558
558
559 fname = prefix+"f"
559 fname = prefix+"f"
560 argname = prefix+"args"
560 argname = prefix+"args"
561 kwargname = prefix+"kwargs"
561 kwargname = prefix+"kwargs"
562 resultname = prefix+"result"
562 resultname = prefix+"result"
563
563
564 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
564 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
565 # print ns
565 # print ns
566 working.update(ns)
566 working.update(ns)
567 code = "%s = %s(*%s,**%s)" % (resultname, fname, argname, kwargname)
567 code = "%s = %s(*%s,**%s)" % (resultname, fname, argname, kwargname)
568 try:
568 try:
569 exec code in shell.user_global_ns, shell.user_ns
569 exec code in shell.user_global_ns, shell.user_ns
570 result = working.get(resultname)
570 result = working.get(resultname)
571 finally:
571 finally:
572 for key in ns.iterkeys():
572 for key in ns.iterkeys():
573 working.pop(key)
573 working.pop(key)
574
574
575 packed_result,buf = serialize_object(result)
575 result_buf = serialize_object(result)
576 result_buf = [packed_result]+buf
576
577 except:
577 except:
578 # invoke IPython traceback formatting
578 # invoke IPython traceback formatting
579 shell.showtraceback()
579 shell.showtraceback()
580 # FIXME - fish exception info out of shell, possibly left there by
580 # FIXME - fish exception info out of shell, possibly left there by
581 # run_code. We'll need to clean up this logic later.
581 # run_code. We'll need to clean up this logic later.
582 reply_content = {}
582 reply_content = {}
583 if shell._reply_content is not None:
583 if shell._reply_content is not None:
584 reply_content.update(shell._reply_content)
584 reply_content.update(shell._reply_content)
585 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method='apply')
585 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method='apply')
586 reply_content['engine_info'] = e_info
586 reply_content['engine_info'] = e_info
587 # reset after use
587 # reset after use
588 shell._reply_content = None
588 shell._reply_content = None
589
589
590 self.session.send(self.iopub_socket, u'pyerr', reply_content, parent=parent,
590 self.session.send(self.iopub_socket, u'pyerr', reply_content, parent=parent,
591 ident=self._topic('pyerr'))
591 ident=self._topic('pyerr'))
592 result_buf = []
592 result_buf = []
593
593
594 if reply_content['ename'] == 'UnmetDependency':
594 if reply_content['ename'] == 'UnmetDependency':
595 md['dependencies_met'] = False
595 md['dependencies_met'] = False
596 else:
596 else:
597 reply_content = {'status' : 'ok'}
597 reply_content = {'status' : 'ok'}
598
598
599 # put 'ok'/'error' status in header, for scheduler introspection:
599 # put 'ok'/'error' status in header, for scheduler introspection:
600 md['status'] = reply_content['status']
600 md['status'] = reply_content['status']
601
601
602 # flush i/o
602 # flush i/o
603 sys.stdout.flush()
603 sys.stdout.flush()
604 sys.stderr.flush()
604 sys.stderr.flush()
605
605
606 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
606 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
607 parent=parent, ident=ident,buffers=result_buf, metadata=md)
607 parent=parent, ident=ident,buffers=result_buf, metadata=md)
608
608
609 self._publish_status(u'idle', parent)
609 self._publish_status(u'idle', parent)
610
610
611 #---------------------------------------------------------------------------
611 #---------------------------------------------------------------------------
612 # Control messages
612 # Control messages
613 #---------------------------------------------------------------------------
613 #---------------------------------------------------------------------------
614
614
def abort_request(self, stream, ident, parent):
    """Abort specific msg(s) by id, or all queued requests if none given.

    Registers the given ids in self.aborted; with no ids, drains the
    shell streams via _abort_queues.  Always sends an abort_reply.
    """
    msg_ids = parent['content'].get('msg_ids', None)
    if isinstance(msg_ids, basestring):
        msg_ids = [msg_ids]
    if not msg_ids:
        # FIX: was self.abort_queues() — the method defined in this file
        # (and used by execute_request) is _abort_queues.
        self._abort_queues()
    else:
        # FIX: guard the loop — msg_ids may be None, which previously
        # raised TypeError when iterated.
        for mid in msg_ids:
            self.aborted.add(str(mid))

    content = dict(status='ok')
    reply_msg = self.session.send(stream, 'abort_reply', content=content,
                                  parent=parent, ident=ident)
    self.log.debug("%s", reply_msg)
629
629
def clear_request(self, stream, idents, parent):
    """Clear the user namespace, then acknowledge with a clear_reply."""
    self.shell.reset(False)
    reply_content = dict(status='ok')
    msg = self.session.send(stream, 'clear_reply', ident=idents,
                            parent=parent, content=reply_content)
635
635
636
636
637 #---------------------------------------------------------------------------
637 #---------------------------------------------------------------------------
638 # Protected interface
638 # Protected interface
639 #---------------------------------------------------------------------------
639 #---------------------------------------------------------------------------
640
640
641
641
def _wrap_exception(self, method=None):
    """Wrap the current exception as IPython.parallel error content."""
    # import here, because _wrap_exception is only used in parallel,
    # and parallel has higher min pyzmq version
    from IPython.parallel.error import wrap_exception
    e_info = dict(
        engine_uuid=self.ident,
        engine_id=self.int_id,
        method=method,
    )
    return wrap_exception(e_info)
649
649
def _topic(self, topic):
    """Return the prefixed IOPub topic for this kernel/engine, as bytes."""
    if self.int_id < 0:
        base = "kernel.%s" % self.ident
    else:
        base = "engine.%i" % self.int_id
    return py3compat.cast_bytes("%s.%s" % (base, topic))
658
658
def _abort_queues(self):
    """Abort every registered (non-falsy) shell stream."""
    for shell_stream in self.shell_streams:
        if not shell_stream:
            continue
        self._abort_queue(shell_stream)
663
663
def _abort_queue(self, stream):
    """Drain a shell stream, replying 'aborted' to every queued request.

    Returns once the stream's socket has no pending message.
    """
    poller = zmq.Poller()
    poller.register(stream.socket, zmq.POLLIN)
    while True:
        idents, msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
        if msg is None:
            return

        self.log.info("Aborting:")
        self.log.info("%s", msg)
        msg_type = msg['header']['msg_type']
        reply_type = msg_type.split('_')[0] + '_reply'

        status = {'status': 'aborted'}
        md = {'engine': self.ident}
        md.update(status)
        # FIX: keyword was misspelled 'meatadata', so the metadata never
        # reached session.send as intended.
        reply_msg = self.session.send(stream, reply_type, metadata=md,
                    content=status, parent=msg, ident=idents)
        self.log.debug("%s", reply_msg)
        # We need to wait a bit for requests to come in. This can probably
        # be set shorter for true asynchronous clients.
        poller.poll(50)
686
686
687
687
def _no_raw_input(self):
    """Raise StdinNotImplementedError: the active frontend has no stdin."""
    message = ("raw_input was called, but this "
               "frontend does not support stdin.")
    raise StdinNotImplementedError(message)
693
693
def _raw_input(self, prompt, ident, parent):
    """Request a line of input from the frontend over the stdin socket.

    Blocks until the frontend replies.  Returns the reply value, or ''
    for a malformed reply; raises EOFError on EOF ('\\x04').
    """
    # Flush output before making the request.
    sys.stderr.flush()
    sys.stdout.flush()

    # Send the input request.
    content = json_clean(dict(prompt=prompt))
    self.session.send(self.stdin_socket, u'input_request', content, parent,
                      ident=ident)

    # Await a response.
    while True:
        try:
            ident, reply = self.session.recv(self.stdin_socket, 0)
        except Exception:
            self.log.warn("Invalid Message:", exc_info=True)
        else:
            break
    try:
        value = reply['content']['value']
    except Exception:
        # Narrowed from a bare except: a malformed reply is logged and
        # treated as empty input; interrupts still propagate.
        self.log.error("Got bad raw_input reply: ")
        self.log.error("%s", parent)
        value = ''
    if value == '\x04':
        # EOF
        raise EOFError
    return value
722
722
def _complete(self, msg):
    """Run tab-completion for a complete_request message.

    Returns (matched_text, matches) from the shell completer.
    """
    c = msg['content']
    try:
        cpos = int(c['cursor_pos'])
    except (KeyError, TypeError, ValueError):
        # Narrowed from a bare except to the conversions that can
        # actually fail here.  If we don't get something that we can
        # convert to an integer, at least attempt the completion guessing
        # the cursor is at the end of the text, if there's any, and
        # otherwise of the line.
        cpos = len(c['text'])
        if cpos == 0:
            cpos = len(c['line'])
    return self.shell.complete(c['text'], c['line'], cpos)
735
735
def _object_info(self, context):
    """Return a docstring-only info dict for the symbol named by context."""
    symbol, leftover = self._symbol_from_context(context)
    doc = ''
    if symbol is not None and not leftover:
        doc = getattr(symbol, '__doc__', '')
    return dict(docstring=doc)
744
744
def _symbol_from_context(self, context):
    """Resolve a dotted-name context against user_ns, then builtins.

    Returns (symbol, remaining_context): symbol is None when the base
    name is not found; remaining_context holds the attribute names that
    could not be resolved.
    """
    if not context:
        return None, context

    base_name = context[0]
    symbol = self.shell.user_ns.get(base_name, None)
    if symbol is None:
        symbol = __builtin__.__dict__.get(base_name, None)
    if symbol is None:
        return None, context

    # walk attribute access as far as it resolves
    rest = context[1:]
    for i, name in enumerate(rest):
        attr = getattr(symbol, name, None)
        if attr is None:
            return symbol, rest[i:]
        symbol = attr

    return symbol, []
765
765
def _at_shutdown(self):
    """Actions taken at shutdown by the kernel, called by python's atexit."""
    # io.rprint("Kernel at_shutdown") # dbg
    if self._shutdown_message is not None:
        self.session.send(self.iopub_socket, self._shutdown_message,
                          ident=self._topic('shutdown'))
        self.log.debug("%s", self._shutdown_message)
    for shell_stream in self.shell_streams:
        shell_stream.flush(zmq.POLLOUT)
774
774
775 #-----------------------------------------------------------------------------
775 #-----------------------------------------------------------------------------
776 # Aliases and Flags for the IPKernelApp
776 # Aliases and Flags for the IPKernelApp
777 #-----------------------------------------------------------------------------
777 #-----------------------------------------------------------------------------
778
778
# Start from the generic kernel flags and layer the shell flags on top.
flags = dict(kernel_flags)
flags.update(shell_flags)

# helper for registering boolean on/off flag pairs
addflag = lambda *args: flags.update(boolean_flag(*args))

flags['pylab'] = (
    {'IPKernelApp' : {'pylab' : 'auto'}},
    """Pre-load matplotlib and numpy for interactive use with
    the default matplotlib backend."""
)

# command-line aliases: kernel aliases extended with the shell's
aliases = dict(kernel_aliases)
aliases.update(shell_aliases)
792
792
793 #-----------------------------------------------------------------------------
793 #-----------------------------------------------------------------------------
794 # The IPKernelApp class
794 # The IPKernelApp class
795 #-----------------------------------------------------------------------------
795 #-----------------------------------------------------------------------------
796
796
797 class IPKernelApp(KernelApp, InteractiveShellApp):
797 class IPKernelApp(KernelApp, InteractiveShellApp):
798 name = 'ipkernel'
798 name = 'ipkernel'
799
799
800 aliases = Dict(aliases)
800 aliases = Dict(aliases)
801 flags = Dict(flags)
801 flags = Dict(flags)
802 classes = [Kernel, ZMQInteractiveShell, ProfileDir, Session]
802 classes = [Kernel, ZMQInteractiveShell, ProfileDir, Session]
803
803
804 @catch_config_error
804 @catch_config_error
805 def initialize(self, argv=None):
805 def initialize(self, argv=None):
806 super(IPKernelApp, self).initialize(argv)
806 super(IPKernelApp, self).initialize(argv)
807 self.init_path()
807 self.init_path()
808 self.init_shell()
808 self.init_shell()
809 self.init_gui_pylab()
809 self.init_gui_pylab()
810 self.init_extensions()
810 self.init_extensions()
811 self.init_code()
811 self.init_code()
812
812
813 def init_kernel(self):
813 def init_kernel(self):
814
814
815 shell_stream = ZMQStream(self.shell_socket)
815 shell_stream = ZMQStream(self.shell_socket)
816
816
817 kernel = Kernel(config=self.config, session=self.session,
817 kernel = Kernel(config=self.config, session=self.session,
818 shell_streams=[shell_stream],
818 shell_streams=[shell_stream],
819 iopub_socket=self.iopub_socket,
819 iopub_socket=self.iopub_socket,
820 stdin_socket=self.stdin_socket,
820 stdin_socket=self.stdin_socket,
821 log=self.log,
821 log=self.log,
822 profile_dir=self.profile_dir,
822 profile_dir=self.profile_dir,
823 )
823 )
824 self.kernel = kernel
824 self.kernel = kernel
825 kernel.record_ports(self.ports)
825 kernel.record_ports(self.ports)
826 shell = kernel.shell
826 shell = kernel.shell
827
827
828 def init_gui_pylab(self):
828 def init_gui_pylab(self):
829 """Enable GUI event loop integration, taking pylab into account."""
829 """Enable GUI event loop integration, taking pylab into account."""
830
830
831 # Provide a wrapper for :meth:`InteractiveShellApp.init_gui_pylab`
831 # Provide a wrapper for :meth:`InteractiveShellApp.init_gui_pylab`
832 # to ensure that any exception is printed straight to stderr.
832 # to ensure that any exception is printed straight to stderr.
833 # Normally _showtraceback associates the reply with an execution,
833 # Normally _showtraceback associates the reply with an execution,
834 # which means frontends will never draw it, as this exception
834 # which means frontends will never draw it, as this exception
835 # is not associated with any execute request.
835 # is not associated with any execute request.
836
836
837 shell = self.shell
837 shell = self.shell
838 _showtraceback = shell._showtraceback
838 _showtraceback = shell._showtraceback
839 try:
839 try:
840 # replace pyerr-sending traceback with stderr
840 # replace pyerr-sending traceback with stderr
841 def print_tb(etype, evalue, stb):
841 def print_tb(etype, evalue, stb):
842 print ("GUI event loop or pylab initialization failed",
842 print ("GUI event loop or pylab initialization failed",
843 file=io.stderr)
843 file=io.stderr)
844 print (shell.InteractiveTB.stb2text(stb), file=io.stderr)
844 print (shell.InteractiveTB.stb2text(stb), file=io.stderr)
845 shell._showtraceback = print_tb
845 shell._showtraceback = print_tb
846 InteractiveShellApp.init_gui_pylab(self)
846 InteractiveShellApp.init_gui_pylab(self)
847 finally:
847 finally:
848 shell._showtraceback = _showtraceback
848 shell._showtraceback = _showtraceback
849
849
850 def init_shell(self):
850 def init_shell(self):
851 self.shell = self.kernel.shell
851 self.shell = self.kernel.shell
852 self.shell.configurables.append(self)
852 self.shell.configurables.append(self)
853
853
854
854
855 #-----------------------------------------------------------------------------
855 #-----------------------------------------------------------------------------
856 # Kernel main and launch functions
856 # Kernel main and launch functions
857 #-----------------------------------------------------------------------------
857 #-----------------------------------------------------------------------------
858
858
859 def launch_kernel(*args, **kwargs):
859 def launch_kernel(*args, **kwargs):
860 """Launches a localhost IPython kernel, binding to the specified ports.
860 """Launches a localhost IPython kernel, binding to the specified ports.
861
861
862 This function simply calls entry_point.base_launch_kernel with the right
862 This function simply calls entry_point.base_launch_kernel with the right
863 first command to start an ipkernel. See base_launch_kernel for arguments.
863 first command to start an ipkernel. See base_launch_kernel for arguments.
864
864
865 Returns
865 Returns
866 -------
866 -------
867 A tuple of form:
867 A tuple of form:
868 (kernel_process, shell_port, iopub_port, stdin_port, hb_port)
868 (kernel_process, shell_port, iopub_port, stdin_port, hb_port)
869 where kernel_process is a Popen object and the ports are integers.
869 where kernel_process is a Popen object and the ports are integers.
870 """
870 """
871 return base_launch_kernel('from IPython.zmq.ipkernel import main; main()',
871 return base_launch_kernel('from IPython.zmq.ipkernel import main; main()',
872 *args, **kwargs)
872 *args, **kwargs)
873
873
874
874
875 def embed_kernel(module=None, local_ns=None, **kwargs):
875 def embed_kernel(module=None, local_ns=None, **kwargs):
876 """Embed and start an IPython kernel in a given scope.
876 """Embed and start an IPython kernel in a given scope.
877
877
878 Parameters
878 Parameters
879 ----------
879 ----------
880 module : ModuleType, optional
880 module : ModuleType, optional
881 The module to load into IPython globals (default: caller)
881 The module to load into IPython globals (default: caller)
882 local_ns : dict, optional
882 local_ns : dict, optional
883 The namespace to load into IPython user namespace (default: caller)
883 The namespace to load into IPython user namespace (default: caller)
884
884
885 kwargs : various, optional
885 kwargs : various, optional
886 Further keyword args are relayed to the KernelApp constructor,
886 Further keyword args are relayed to the KernelApp constructor,
887 allowing configuration of the Kernel. Will only have an effect
887 allowing configuration of the Kernel. Will only have an effect
888 on the first embed_kernel call for a given process.
888 on the first embed_kernel call for a given process.
889
889
890 """
890 """
891 # get the app if it exists, or set it up if it doesn't
891 # get the app if it exists, or set it up if it doesn't
892 if IPKernelApp.initialized():
892 if IPKernelApp.initialized():
893 app = IPKernelApp.instance()
893 app = IPKernelApp.instance()
894 else:
894 else:
895 app = IPKernelApp.instance(**kwargs)
895 app = IPKernelApp.instance(**kwargs)
896 app.initialize([])
896 app.initialize([])
897 # Undo unnecessary sys module mangling from init_sys_modules.
897 # Undo unnecessary sys module mangling from init_sys_modules.
898 # This would not be necessary if we could prevent it
898 # This would not be necessary if we could prevent it
899 # in the first place by using a different InteractiveShell
899 # in the first place by using a different InteractiveShell
900 # subclass, as in the regular embed case.
900 # subclass, as in the regular embed case.
901 main = app.kernel.shell._orig_sys_modules_main_mod
901 main = app.kernel.shell._orig_sys_modules_main_mod
902 if main is not None:
902 if main is not None:
903 sys.modules[app.kernel.shell._orig_sys_modules_main_name] = main
903 sys.modules[app.kernel.shell._orig_sys_modules_main_name] = main
904
904
905 # load the calling scope if not given
905 # load the calling scope if not given
906 (caller_module, caller_locals) = extract_module_locals(1)
906 (caller_module, caller_locals) = extract_module_locals(1)
907 if module is None:
907 if module is None:
908 module = caller_module
908 module = caller_module
909 if local_ns is None:
909 if local_ns is None:
910 local_ns = caller_locals
910 local_ns = caller_locals
911
911
912 app.kernel.user_module = module
912 app.kernel.user_module = module
913 app.kernel.user_ns = local_ns
913 app.kernel.user_ns = local_ns
914 app.shell.set_completer_frame()
914 app.shell.set_completer_frame()
915 app.start()
915 app.start()
916
916
917 def main():
917 def main():
918 """Run an IPKernel as an application"""
918 """Run an IPKernel as an application"""
919 app = IPKernelApp.instance()
919 app = IPKernelApp.instance()
920 app.initialize()
920 app.initialize()
921 app.start()
921 app.start()
922
922
923
923
924 if __name__ == '__main__':
924 if __name__ == '__main__':
925 main()
925 main()
@@ -1,179 +1,175 b''
1 """serialization utilities for apply messages
1 """serialization utilities for apply messages
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 # Standard library imports
18 # Standard library imports
19 import logging
19 import logging
20 import os
20 import os
21 import re
21 import re
22 import socket
22 import socket
23 import sys
23 import sys
24
24
25 try:
25 try:
26 import cPickle
26 import cPickle
27 pickle = cPickle
27 pickle = cPickle
28 except:
28 except:
29 cPickle = None
29 cPickle = None
30 import pickle
30 import pickle
31
31
32
32
33 # IPython imports
33 # IPython imports
34 from IPython.utils import py3compat
34 from IPython.utils import py3compat
35 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
35 from IPython.utils.pickleutil import (
36 can, uncan, can_sequence, uncan_sequence, CannedObject
37 )
36 from IPython.utils.newserialized import serialize, unserialize
38 from IPython.utils.newserialized import serialize, unserialize
37
39
38 if py3compat.PY3:
40 if py3compat.PY3:
39 buffer = memoryview
41 buffer = memoryview
40
42
41 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
42 # Serialization Functions
44 # Serialization Functions
43 #-----------------------------------------------------------------------------
45 #-----------------------------------------------------------------------------
44
46
45 def serialize_object(obj, threshold=64e-6):
47 # maximum items to iterate through in a container
48 MAX_ITEMS = 64
49
50 def _extract_buffers(obj, threshold=1024):
51 """extract buffers larger than a certain threshold"""
52 buffers = []
53 if isinstance(obj, CannedObject) and obj.buffers:
54 for i,buf in enumerate(obj.buffers):
55 if len(buf) > threshold:
56 # buffer larger than threshold, prevent pickling
57 obj.buffers[i] = None
58 buffers.append(buf)
59 elif isinstance(buf, buffer):
60 # buffer too small for separate send, coerce to bytes
61 # because pickling buffer objects just results in broken pointers
62 obj.buffers[i] = bytes(buf)
63 return buffers
64
65 def _restore_buffers(obj, buffers):
66 """restore buffers extracted by """
67 if isinstance(obj, CannedObject) and obj.buffers:
68 for i,buf in enumerate(obj.buffers):
69 if buf is None:
70 obj.buffers[i] = buffers.pop(0)
71
72 def serialize_object(obj, threshold=1024):
46 """Serialize an object into a list of sendable buffers.
73 """Serialize an object into a list of sendable buffers.
47
74
48 Parameters
75 Parameters
49 ----------
76 ----------
50
77
51 obj : object
78 obj : object
52 The object to be serialized
79 The object to be serialized
53 threshold : float
80 threshold : int
54 The threshold for not double-pickling the content.
81 The threshold (in bytes) for pulling out data buffers
55
82 to avoid pickling them.
56
83
57 Returns
84 Returns
58 -------
85 -------
59 ('pmd', [bufs]) :
86 [bufs] : list of buffers representing the serialized object.
60 where pmd is the pickled metadata wrapper,
61 bufs is a list of data buffers
62 """
87 """
63 databuffers = []
88 buffers = []
64 if isinstance(obj, (list, tuple)):
89 if isinstance(obj, (list, tuple)) and len(obj) < MAX_ITEMS:
65 clist = canSequence(obj)
90 cobj = can_sequence(obj)
66 slist = map(serialize, clist)
91 for c in cobj:
67 for s in slist:
92 buffers.extend(_extract_buffers(c, threshold))
68 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
93 elif isinstance(obj, dict) and len(obj) < MAX_ITEMS:
69 databuffers.append(s.getData())
94 cobj = {}
70 s.data = None
71 return pickle.dumps(slist,-1), databuffers
72 elif isinstance(obj, dict):
73 sobj = {}
74 for k in sorted(obj.iterkeys()):
95 for k in sorted(obj.iterkeys()):
75 s = serialize(can(obj[k]))
96 c = can(obj[k])
76 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
97 buffers.extend(_extract_buffers(c, threshold))
77 databuffers.append(s.getData())
98 cobj[k] = c
78 s.data = None
79 sobj[k] = s
80 return pickle.dumps(sobj,-1),databuffers
81 else:
99 else:
82 s = serialize(can(obj))
100 cobj = can(obj)
83 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
101 buffers.extend(_extract_buffers(cobj, threshold))
84 databuffers.append(s.getData())
102
85 s.data = None
103 buffers.insert(0, pickle.dumps(cobj,-1))
86 return pickle.dumps(s,-1),databuffers
104 return buffers
87
105
88
106 def unserialize_object(buffers, g=None):
89 def unserialize_object(bufs):
107 """reconstruct an object serialized by serialize_object from data buffers.
90 """reconstruct an object serialized by serialize_object from data buffers."""
108
91 bufs = list(bufs)
109 Parameters
92 sobj = pickle.loads(bufs.pop(0))
110 ----------
93 if isinstance(sobj, (list, tuple)):
111
94 for s in sobj:
112 bufs : list of buffers/bytes
95 if s.data is None:
113
96 s.data = bufs.pop(0)
114 g : globals to be used when uncanning
97 return uncanSequence(map(unserialize, sobj)), bufs
115
98 elif isinstance(sobj, dict):
116 Returns
117 -------
118
119 (newobj, bufs) : unpacked object, and the list of remaining unused buffers.
120 """
121 bufs = list(buffers)
122 canned = pickle.loads(bufs.pop(0))
123 if isinstance(canned, (list, tuple)) and len(canned) < MAX_ITEMS:
124 for c in canned:
125 _restore_buffers(c, bufs)
126 newobj = uncan_sequence(canned, g)
127 elif isinstance(canned, dict) and len(canned) < MAX_ITEMS:
99 newobj = {}
128 newobj = {}
100 for k in sorted(sobj.iterkeys()):
129 for k in sorted(canned.iterkeys()):
101 s = sobj[k]
130 c = canned[k]
102 if s.data is None:
131 _restore_buffers(c, bufs)
103 s.data = bufs.pop(0)
132 newobj[k] = uncan(c, g)
104 newobj[k] = uncan(unserialize(s))
105 return newobj, bufs
106 else:
133 else:
107 if sobj.data is None:
134 _restore_buffers(canned, bufs)
108 sobj.data = bufs.pop(0)
135 newobj = uncan(canned, g)
109 return uncan(unserialize(sobj)), bufs
136
137 return newobj, bufs
110
138
111 def pack_apply_message(f, args, kwargs, threshold=64e-6):
139 def pack_apply_message(f, args, kwargs, threshold=1024):
112 """pack up a function, args, and kwargs to be sent over the wire
140 """pack up a function, args, and kwargs to be sent over the wire
113 as a series of buffers. Any object whose data is larger than `threshold`
141 as a series of buffers. Any object whose data is larger than `threshold`
114 will not have their data copied (currently only numpy arrays support zero-copy)"""
142 will not have their data copied (currently only numpy arrays support zero-copy)
143 """
115 msg = [pickle.dumps(can(f),-1)]
144 msg = [pickle.dumps(can(f),-1)]
116 databuffers = [] # for large objects
145 databuffers = [] # for large objects
117 sargs, bufs = serialize_object(args,threshold)
146 sargs = serialize_object(args,threshold)
118 msg.append(sargs)
147 msg.append(sargs[0])
119 databuffers.extend(bufs)
148 databuffers.extend(sargs[1:])
120 skwargs, bufs = serialize_object(kwargs,threshold)
149 skwargs = serialize_object(kwargs,threshold)
121 msg.append(skwargs)
150 msg.append(skwargs[0])
122 databuffers.extend(bufs)
151 databuffers.extend(skwargs[1:])
123 msg.extend(databuffers)
152 msg.extend(databuffers)
124 return msg
153 return msg
125
154
126 def unpack_apply_message(bufs, g=None, copy=True):
155 def unpack_apply_message(bufs, g=None, copy=True):
127 """unpack f,args,kwargs from buffers packed by pack_apply_message()
156 """unpack f,args,kwargs from buffers packed by pack_apply_message()
128 Returns: original f,args,kwargs"""
157 Returns: original f,args,kwargs"""
129 bufs = list(bufs) # allow us to pop
158 bufs = list(bufs) # allow us to pop
130 assert len(bufs) >= 3, "not enough buffers!"
159 assert len(bufs) >= 3, "not enough buffers!"
131 if not copy:
160 if not copy:
132 for i in range(3):
161 for i in range(3):
133 bufs[i] = bufs[i].bytes
162 bufs[i] = bufs[i].bytes
134 cf = pickle.loads(bufs.pop(0))
163 f = uncan(pickle.loads(bufs.pop(0)), g)
135 sargs = list(pickle.loads(bufs.pop(0)))
164 # sargs = bufs.pop(0)
136 skwargs = dict(pickle.loads(bufs.pop(0)))
165 # pop kwargs out, so first n-elements are args, serialized
137 # print sargs, skwargs
166 skwargs = bufs.pop(1)
138 f = uncan(cf, g)
167 args, bufs = unserialize_object(bufs, g)
139 for sa in sargs:
168 # put skwargs back in as the first element
140 if sa.data is None:
169 bufs.insert(0, skwargs)
141 m = bufs.pop(0)
170 kwargs, bufs = unserialize_object(bufs, g)
142 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
143 # always use a buffer, until memoryviews get sorted out
144 sa.data = buffer(m)
145 # disable memoryview support
146 # if copy:
147 # sa.data = buffer(m)
148 # else:
149 # sa.data = m.buffer
150 else:
151 if copy:
152 sa.data = m
153 else:
154 sa.data = m.bytes
155
171
156 args = uncanSequence(map(unserialize, sargs), g)
172 assert not bufs, "Shouldn't be any data left over"
157 kwargs = {}
158 for k in sorted(skwargs.iterkeys()):
159 sa = skwargs[k]
160 if sa.data is None:
161 m = bufs.pop(0)
162 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
163 # always use a buffer, until memoryviews get sorted out
164 sa.data = buffer(m)
165 # disable memoryview support
166 # if copy:
167 # sa.data = buffer(m)
168 # else:
169 # sa.data = m.buffer
170 else:
171 if copy:
172 sa.data = m
173 else:
174 sa.data = m.bytes
175
176 kwargs[k] = uncan(unserialize(sa), g)
177
173
178 return f,args,kwargs
174 return f,args,kwargs
179
175
General Comments 0
You need to be logged in to leave comments. Login now