##// END OF EJS Templates
better serialization for parallel code...
MinRK -
Show More
@@ -0,0 +1,115 b''
1 """test serialization tools"""
2
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
5 #
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
9
10 #-------------------------------------------------------------------------------
11 # Imports
12 #-------------------------------------------------------------------------------
13
14 import pickle
15
16 import nose.tools as nt
17
# from unittest import TestCase
19 from IPython.zmq.serialize import serialize_object, unserialize_object
20 from IPython.testing import decorators as dec
21 from IPython.utils.pickleutil import CannedArray
22
def roundtrip(obj):
    """Serialize *obj* and deserialize it again, returning the copy.

    Asserts that deserialization consumes every buffer (no remainder).
    """
    serialized = serialize_object(obj)
    restored, leftover = unserialize_object(serialized)
    nt.assert_equals(leftover, [])
    return restored
class C(object):
    """Dummy class whose attributes are populated from keyword arguments."""

    def __init__(self, **kwargs):
        # attach each keyword argument as an instance attribute
        for name, val in kwargs.iteritems():
            setattr(self, name, val)
@dec.parametric
def test_roundtrip_simple():
    """simple builtin objects survive a serialize/unserialize cycle"""
    cases = [
        'hello',
        dict(a='b', b=10),
        [1, 2, 'hi'],
        (b'123', 'hello'),
    ]
    for case in cases:
        restored = roundtrip(case)
        yield nt.assert_equals(case, restored)
47
@dec.parametric
def test_roundtrip_nested():
    """nested container structures survive a serialize/unserialize cycle"""
    cases = [
        dict(a=range(5), b={1: b'hello'}),
        [range(5), [range(3), (1, [b'whoda'])]],
    ]
    for case in cases:
        restored = roundtrip(case)
        yield nt.assert_equals(case, restored)
56
@dec.parametric
def test_roundtrip_buffered():
    """large byte strings are extracted into exactly one extra buffer"""
    cases = [
        dict(a=b"x" * 1025),
        b"hello" * 500,
        [b"hello" * 501, 1, 2, 3],
    ]
    for case in cases:
        bufs = serialize_object(case)
        # one pickled-metadata buffer plus one extracted data buffer
        yield nt.assert_equals(len(bufs), 2)
        restored, leftover = unserialize_object(bufs)
        yield nt.assert_equals(leftover, [])
        yield nt.assert_equals(case, restored)
69
@dec.parametric
@dec.skip_without('numpy')
def test_numpy():
    """numpy arrays roundtrip across a range of shapes and dtypes"""
    import numpy
    from numpy.testing.utils import assert_array_equal
    shapes = ((), (0,), (100,), (1024, 10), (10, 8, 6, 5))
    dtypes = ('uint8', 'float64', 'int32', [('int16', 'float32')])
    for shape in shapes:
        for dtype in dtypes:
            A = numpy.empty(shape, dtype=dtype)
            B, leftover = unserialize_object(serialize_object(A))
            yield nt.assert_equals(leftover, [])
            yield assert_array_equal(A, B)
82
@dec.parametric
@dec.skip_without('numpy')
def test_numpy_in_seq():
    """numpy arrays nested in a tuple are canned as CannedArray"""
    import numpy
    from numpy.testing.utils import assert_array_equal
    for shape in ((), (0,), (100,), (1024,10), (10,8,6,5)):
        for dtype in ('uint8', 'float64', 'int32', [('int16', 'float32')]):
            A = numpy.empty(shape, dtype=dtype)
            bufs = serialize_object((A,1,2,b'hello'))
            canned = pickle.loads(bufs[0])
            # BUG FIX: the original `nt.assert_true(canned[0], CannedArray)`
            # passed CannedArray as the failure *message* and only checked
            # truthiness of canned[0]; assert the intended type relationship.
            yield nt.assert_true(isinstance(canned[0], CannedArray))
            tup, r = unserialize_object(bufs)
            B = tup[0]
            yield nt.assert_equals(r, [])
            yield assert_array_equal(A,B)
98
@dec.parametric
@dec.skip_without('numpy')
def test_numpy_in_dict():
    """numpy arrays nested in a dict are canned as CannedArray"""
    import numpy
    from numpy.testing.utils import assert_array_equal
    for shape in ((), (0,), (100,), (1024,10), (10,8,6,5)):
        for dtype in ('uint8', 'float64', 'int32', [('int16', 'float32')]):
            A = numpy.empty(shape, dtype=dtype)
            bufs = serialize_object(dict(a=A,b=1,c=range(20)))
            canned = pickle.loads(bufs[0])
            # BUG FIX: the original `nt.assert_true(canned['a'], CannedArray)`
            # passed CannedArray as the failure *message* and only checked
            # truthiness of canned['a']; assert the intended type relationship.
            yield nt.assert_true(isinstance(canned['a'], CannedArray))
            d, r = unserialize_object(bufs)
            B = d['a']
            yield nt.assert_equals(r, [])
            yield assert_array_equal(A,B)
114
115
@@ -1,597 +1,597 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import platform
21 21 import time
22 22 from tempfile import mktemp
23 23 from StringIO import StringIO
24 24
25 25 import zmq
26 26 from nose import SkipTest
27 27
28 28 from IPython.testing import decorators as dec
29 29 from IPython.testing.ipunittest import ParametricTestCase
30 30
31 31 from IPython import parallel as pmod
32 32 from IPython.parallel import error
33 33 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
34 34 from IPython.parallel import DirectView
35 35 from IPython.parallel.util import interactive
36 36
37 37 from IPython.parallel.tests import add_engines
38 38
39 39 from .clienttest import ClusterTestCase, crash, wait, skip_without
40 40
41 41 def setup():
42 42 add_engines(3, total=True)
43 43
class TestView(ClusterTestCase, ParametricTestCase):
    """Tests for view (DirectView) behavior against a running test cluster."""

    def setUp(self):
        # On Win XP, wait for resource cleanup, else parallel test group fails
        if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
            # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
            time.sleep(2)
        super(TestView, self).setUp()
52 52
53 53 def test_z_crash_mux(self):
54 54 """test graceful handling of engine death (direct)"""
55 55 raise SkipTest("crash tests disabled, due to undesirable crash reports")
56 56 # self.add_engines(1)
57 57 eid = self.client.ids[-1]
58 58 ar = self.client[eid].apply_async(crash)
59 59 self.assertRaisesRemote(error.EngineError, ar.get, 10)
60 60 eid = ar.engine_id
61 61 tic = time.time()
62 62 while eid in self.client.ids and time.time()-tic < 5:
63 63 time.sleep(.01)
64 64 self.client.spin()
65 65 self.assertFalse(eid in self.client.ids, "Engine should have died")
66 66
    def test_push_pull(self):
        """test pushing and pulling"""
        data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
        t = self.client.ids[-1]
        v = self.client[t]
        push = v.push
        pull = v.pull
        v.block=True
        nengines = len(self.client)
        # single-engine roundtrip: pushed dict comes back equal
        push({'data':data})
        d = pull('data')
        self.assertEqual(d, data)
        # all-engine roundtrip: pull returns one value per engine
        self.client[:].push({'data':data})
        d = self.client[:].pull('data', block=True)
        self.assertEqual(d, nengines*[data])
        # non-blocking push/pull return AsyncResult objects
        ar = push({'data':data}, block=False)
        self.assertTrue(isinstance(ar, AsyncResult))
        r = ar.get()
        ar = self.client[:].pull('data', block=False)
        self.assertTrue(isinstance(ar, AsyncResult))
        r = ar.get()
        self.assertEqual(r, nengines*[data])
        # pulling a tuple of names yields a list of values per engine
        self.client[:].push(dict(a=10,b=20))
        r = self.client[:].pull(('a','b'), block=True)
        self.assertEqual(r, nengines*[[10,20]])
    def test_push_pull_function(self):
        "test pushing and pulling functions"
        def testf(x):
            return 2.0*x

        t = self.client.ids[-1]
        v = self.client[t]
        v.block=True
        push = v.push
        pull = v.pull
        execute = v.execute
        # a pushed function can be pulled back and still computes the same
        push({'testf':testf})
        r = pull('testf')
        self.assertEqual(r(1.0), testf(1.0))
        # and it is callable from remote code
        execute('r = testf(10)')
        r = pull('r')
        self.assertEqual(r, testf(10))
        # non-blocking variants across all engines
        ar = self.client[:].push({'testf':testf}, block=False)
        ar.get()
        ar = self.client[:].pull('testf', block=False)
        rlist = ar.get()
        for r in rlist:
            self.assertEqual(r(1.0), testf(1.0))
        # pull multiple names at once: remotely-defined g and pushed testf
        execute("def g(x): return x*x")
        r = pull(('testf','g'))
        self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
    def test_push_function_globals(self):
        """test that pushed functions have access to globals"""
        @interactive
        def geta():
            # 'a' is resolved in the engine's user namespace, not here
            return a
        # self.add_engines(1)
        v = self.client[-1]
        v.block=True
        v['f'] = geta
        # 'a' not yet defined on the engine -> remote NameError
        self.assertRaisesRemote(NameError, v.execute, 'b=f()')
        # after defining 'a' remotely, f() sees it via the engine globals
        v.execute('a=5')
        v.execute('b=f()')
        self.assertEqual(v['b'], 5)
    def test_push_function_defaults(self):
        """test that pushed functions preserve default args"""
        def echo(a=10):
            return a
        v = self.client[-1]
        v.block=True
        v['f'] = echo
        # calling with no args remotely must use the default a=10
        v.execute('b=f()')
        self.assertEqual(v['b'], 10)
144 144 def test_get_result(self):
145 145 """test getting results from the Hub."""
146 146 c = pmod.Client(profile='iptest')
147 147 # self.add_engines(1)
148 148 t = c.ids[-1]
149 149 v = c[t]
150 150 v2 = self.client[t]
151 151 ar = v.apply_async(wait, 1)
152 152 # give the monitor time to notice the message
153 153 time.sleep(.25)
154 154 ahr = v2.get_result(ar.msg_ids)
155 155 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 156 self.assertEqual(ahr.get(), ar.get())
157 157 ar2 = v2.get_result(ar.msg_ids)
158 158 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 159 c.spin()
160 160 c.close()
161 161
162 162 def test_run_newline(self):
163 163 """test that run appends newline to files"""
164 164 tmpfile = mktemp()
165 165 with open(tmpfile, 'w') as f:
166 166 f.write("""def g():
167 167 return 5
168 168 """)
169 169 v = self.client[-1]
170 170 v.run(tmpfile, block=True)
171 171 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
172 172
173 173 def test_apply_tracked(self):
174 174 """test tracking for apply"""
175 175 # self.add_engines(1)
176 176 t = self.client.ids[-1]
177 177 v = self.client[t]
178 178 v.block=False
179 179 def echo(n=1024*1024, **kwargs):
180 180 with v.temp_flags(**kwargs):
181 181 return v.apply(lambda x: x, 'x'*n)
182 182 ar = echo(1, track=False)
183 183 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
184 184 self.assertTrue(ar.sent)
185 185 ar = echo(track=True)
186 186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 187 self.assertEqual(ar.sent, ar._tracker.done)
188 188 ar._tracker.wait()
189 189 self.assertTrue(ar.sent)
190 190
191 191 def test_push_tracked(self):
192 192 t = self.client.ids[-1]
193 193 ns = dict(x='x'*1024*1024)
194 194 v = self.client[t]
195 195 ar = v.push(ns, block=False, track=False)
196 196 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
197 197 self.assertTrue(ar.sent)
198 198
199 199 ar = v.push(ns, block=False, track=True)
200 200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 201 ar._tracker.wait()
202 202 self.assertEqual(ar.sent, ar._tracker.done)
203 203 self.assertTrue(ar.sent)
204 204 ar.get()
205 205
206 206 def test_scatter_tracked(self):
207 207 t = self.client.ids
208 208 x='x'*1024*1024
209 209 ar = self.client[t].scatter('x', x, block=False, track=False)
210 210 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
211 211 self.assertTrue(ar.sent)
212 212
213 213 ar = self.client[t].scatter('x', x, block=False, track=True)
214 214 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
215 215 self.assertEqual(ar.sent, ar._tracker.done)
216 216 ar._tracker.wait()
217 217 self.assertTrue(ar.sent)
218 218 ar.get()
219 219
220 220 def test_remote_reference(self):
221 221 v = self.client[-1]
222 222 v['a'] = 123
223 223 ra = pmod.Reference('a')
224 224 b = v.apply_sync(lambda x: x, ra)
225 225 self.assertEqual(b, 123)
226 226
227 227
228 228 def test_scatter_gather(self):
229 229 view = self.client[:]
230 230 seq1 = range(16)
231 231 view.scatter('a', seq1)
232 232 seq2 = view.gather('a', block=True)
233 233 self.assertEqual(seq2, seq1)
234 234 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
235 235
236 236 @skip_without('numpy')
237 237 def test_scatter_gather_numpy(self):
238 238 import numpy
239 239 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
240 240 view = self.client[:]
241 241 a = numpy.arange(64)
242 view.scatter('a', a)
242 view.scatter('a', a, block=True)
243 243 b = view.gather('a', block=True)
244 244 assert_array_equal(b, a)
245 245
246 246 def test_scatter_gather_lazy(self):
247 247 """scatter/gather with targets='all'"""
248 248 view = self.client.direct_view(targets='all')
249 249 x = range(64)
250 250 view.scatter('x', x)
251 251 gathered = view.gather('x', block=True)
252 252 self.assertEqual(gathered, x)
253 253
254 254
255 255 @dec.known_failure_py3
256 256 @skip_without('numpy')
257 257 def test_push_numpy_nocopy(self):
258 258 import numpy
259 259 view = self.client[:]
260 260 a = numpy.arange(64)
261 261 view['A'] = a
262 262 @interactive
263 263 def check_writeable(x):
264 264 return x.flags.writeable
265 265
266 266 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
267 267 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
268 268
269 269 view.push(dict(B=a))
270 270 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
271 271 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
272 272
273 273 @skip_without('numpy')
274 274 def test_apply_numpy(self):
275 275 """view.apply(f, ndarray)"""
276 276 import numpy
277 277 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
278 278
279 279 A = numpy.random.random((100,100))
280 280 view = self.client[-1]
281 281 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
282 282 B = A.astype(dt)
283 283 C = view.apply_sync(lambda x:x, B)
284 284 assert_array_equal(B,C)
285 285
286 286 @skip_without('numpy')
287 287 def test_push_pull_recarray(self):
288 288 """push/pull recarrays"""
289 289 import numpy
290 290 from numpy.testing.utils import assert_array_equal
291 291
292 292 view = self.client[-1]
293 293
294 294 R = numpy.array([
295 295 (1, 'hi', 0.),
296 296 (2**30, 'there', 2.5),
297 297 (-99999, 'world', -12345.6789),
298 298 ], [('n', int), ('s', '|S10'), ('f', float)])
299 299
300 300 view['RR'] = R
301 301 R2 = view['RR']
302 302
303 303 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
304 304 self.assertEqual(r_dtype, R.dtype)
305 305 self.assertEqual(r_shape, R.shape)
306 306 self.assertEqual(R2.dtype, R.dtype)
307 307 self.assertEqual(R2.shape, R.shape)
308 308 assert_array_equal(R2, R)
309 309
310 310 def test_map(self):
311 311 view = self.client[:]
312 312 def f(x):
313 313 return x**2
314 314 data = range(16)
315 315 r = view.map_sync(f, data)
316 316 self.assertEqual(r, map(f, data))
317 317
318 318 def test_map_iterable(self):
319 319 """test map on iterables (direct)"""
320 320 view = self.client[:]
321 321 # 101 is prime, so it won't be evenly distributed
322 322 arr = range(101)
323 323 # ensure it will be an iterator, even in Python 3
324 324 it = iter(arr)
325 325 r = view.map_sync(lambda x:x, arr)
326 326 self.assertEqual(r, list(arr))
327 327
328 def test_scatterGatherNonblocking(self):
328 def test_scatter_gather_nonblocking(self):
329 329 data = range(16)
330 330 view = self.client[:]
331 331 view.scatter('a', data, block=False)
332 332 ar = view.gather('a', block=False)
333 333 self.assertEqual(ar.get(), data)
334 334
335 335 @skip_without('numpy')
336 336 def test_scatter_gather_numpy_nonblocking(self):
337 337 import numpy
338 338 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
339 339 a = numpy.arange(64)
340 340 view = self.client[:]
341 341 ar = view.scatter('a', a, block=False)
342 342 self.assertTrue(isinstance(ar, AsyncResult))
343 343 amr = view.gather('a', block=False)
344 344 self.assertTrue(isinstance(amr, AsyncMapResult))
345 345 assert_array_equal(amr.get(), a)
346 346
347 347 def test_execute(self):
348 348 view = self.client[:]
349 349 # self.client.debug=True
350 350 execute = view.execute
351 351 ar = execute('c=30', block=False)
352 352 self.assertTrue(isinstance(ar, AsyncResult))
353 353 ar = execute('d=[0,1,2]', block=False)
354 354 self.client.wait(ar, 1)
355 355 self.assertEqual(len(ar.get()), len(self.client))
356 356 for c in view['c']:
357 357 self.assertEqual(c, 30)
358 358
359 359 def test_abort(self):
360 360 view = self.client[-1]
361 361 ar = view.execute('import time; time.sleep(1)', block=False)
362 362 ar2 = view.apply_async(lambda : 2)
363 363 ar3 = view.apply_async(lambda : 3)
364 364 view.abort(ar2)
365 365 view.abort(ar3.msg_ids)
366 366 self.assertRaises(error.TaskAborted, ar2.get)
367 367 self.assertRaises(error.TaskAborted, ar3.get)
368 368
369 369 def test_abort_all(self):
370 370 """view.abort() aborts all outstanding tasks"""
371 371 view = self.client[-1]
372 372 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
373 373 view.abort()
374 374 view.wait(timeout=5)
375 375 for ar in ars[5:]:
376 376 self.assertRaises(error.TaskAborted, ar.get)
377 377
378 378 def test_temp_flags(self):
379 379 view = self.client[-1]
380 380 view.block=True
381 381 with view.temp_flags(block=False):
382 382 self.assertFalse(view.block)
383 383 self.assertTrue(view.block)
384 384
385 385 @dec.known_failure_py3
386 386 def test_importer(self):
387 387 view = self.client[-1]
388 388 view.clear(block=True)
389 389 with view.importer:
390 390 import re
391 391
392 392 @interactive
393 393 def findall(pat, s):
394 394 # this globals() step isn't necessary in real code
395 395 # only to prevent a closure in the test
396 396 re = globals()['re']
397 397 return re.findall(pat, s)
398 398
399 399 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
400 400
401 401 def test_unicode_execute(self):
402 402 """test executing unicode strings"""
403 403 v = self.client[-1]
404 404 v.block=True
405 405 if sys.version_info[0] >= 3:
406 406 code="a='é'"
407 407 else:
408 408 code=u"a=u'é'"
409 409 v.execute(code)
410 410 self.assertEqual(v['a'], u'é')
411 411
412 412 def test_unicode_apply_result(self):
413 413 """test unicode apply results"""
414 414 v = self.client[-1]
415 415 r = v.apply_sync(lambda : u'é')
416 416 self.assertEqual(r, u'é')
417 417
    def test_unicode_apply_arg(self):
        """test passing unicode arguments to apply"""
        v = self.client[-1]

        @interactive
        def check_unicode(a, check):
            # verify the unicode arg arrives as unicode and the bytes arg as bytes
            assert isinstance(a, unicode), "%r is not unicode"%a
            assert isinstance(check, bytes), "%r is not bytes"%check
            assert a.encode('utf8') == check, "%s != %s"%(a,check)

        for s in [ u'é', u'ßø®∫',u'asdf' ]:
            try:
                v.apply_sync(check_unicode, s, s.encode('utf8'))
            except error.RemoteError as e:
                # surface remote assertion failures as local test failures
                if e.ename == 'AssertionError':
                    self.fail(e.evalue)
                else:
                    raise e
436 436
437 437 def test_map_reference(self):
438 438 """view.map(<Reference>, *seqs) should work"""
439 439 v = self.client[:]
440 440 v.scatter('n', self.client.ids, flatten=True)
441 441 v.execute("f = lambda x,y: x*y")
442 442 rf = pmod.Reference('f')
443 443 nlist = list(range(10))
444 444 mlist = nlist[::-1]
445 445 expected = [ m*n for m,n in zip(mlist, nlist) ]
446 446 result = v.map_sync(rf, mlist, nlist)
447 447 self.assertEqual(result, expected)
448 448
449 449 def test_apply_reference(self):
450 450 """view.apply(<Reference>, *args) should work"""
451 451 v = self.client[:]
452 452 v.scatter('n', self.client.ids, flatten=True)
453 453 v.execute("f = lambda x: n*x")
454 454 rf = pmod.Reference('f')
455 455 result = v.apply_sync(rf, 5)
456 456 expected = [ 5*id for id in self.client.ids ]
457 457 self.assertEqual(result, expected)
458 458
459 459 def test_eval_reference(self):
460 460 v = self.client[self.client.ids[0]]
461 461 v['g'] = range(5)
462 462 rg = pmod.Reference('g[0]')
463 463 echo = lambda x:x
464 464 self.assertEqual(v.apply_sync(echo, rg), 0)
465 465
466 466 def test_reference_nameerror(self):
467 467 v = self.client[self.client.ids[0]]
468 468 r = pmod.Reference('elvis_has_left')
469 469 echo = lambda x:x
470 470 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
471 471
472 472 def test_single_engine_map(self):
473 473 e0 = self.client[self.client.ids[0]]
474 474 r = range(5)
475 475 check = [ -1*i for i in r ]
476 476 result = e0.map_sync(lambda x: -1*x, r)
477 477 self.assertEqual(result, check)
478 478
479 479 def test_len(self):
480 480 """len(view) makes sense"""
481 481 e0 = self.client[self.client.ids[0]]
482 482 yield self.assertEqual(len(e0), 1)
483 483 v = self.client[:]
484 484 yield self.assertEqual(len(v), len(self.client.ids))
485 485 v = self.client.direct_view('all')
486 486 yield self.assertEqual(len(v), len(self.client.ids))
487 487 v = self.client[:2]
488 488 yield self.assertEqual(len(v), 2)
489 489 v = self.client[:1]
490 490 yield self.assertEqual(len(v), 1)
491 491 v = self.client.load_balanced_view()
492 492 yield self.assertEqual(len(v), len(self.client.ids))
493 493 # parametric tests seem to require manual closing?
494 494 self.client.close()
495 495
496 496
497 497 # begin execute tests
498 498
499 499 def test_execute_reply(self):
500 500 e0 = self.client[self.client.ids[0]]
501 501 e0.block = True
502 502 ar = e0.execute("5", silent=False)
503 503 er = ar.get()
504 504 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
505 505 self.assertEqual(er.pyout['data']['text/plain'], '5')
506 506
507 507 def test_execute_reply_stdout(self):
508 508 e0 = self.client[self.client.ids[0]]
509 509 e0.block = True
510 510 ar = e0.execute("print (5)", silent=False)
511 511 er = ar.get()
512 512 self.assertEqual(er.stdout.strip(), '5')
513 513
514 514 def test_execute_pyout(self):
515 515 """execute triggers pyout with silent=False"""
516 516 view = self.client[:]
517 517 ar = view.execute("5", silent=False, block=True)
518 518
519 519 expected = [{'text/plain' : '5'}] * len(view)
520 520 mimes = [ out['data'] for out in ar.pyout ]
521 521 self.assertEqual(mimes, expected)
522 522
523 523 def test_execute_silent(self):
524 524 """execute does not trigger pyout with silent=True"""
525 525 view = self.client[:]
526 526 ar = view.execute("5", block=True)
527 527 expected = [None] * len(view)
528 528 self.assertEqual(ar.pyout, expected)
529 529
530 530 def test_execute_magic(self):
531 531 """execute accepts IPython commands"""
532 532 view = self.client[:]
533 533 view.execute("a = 5")
534 534 ar = view.execute("%whos", block=True)
535 535 # this will raise, if that failed
536 536 ar.get(5)
537 537 for stdout in ar.stdout:
538 538 lines = stdout.splitlines()
539 539 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
540 540 found = False
541 541 for line in lines[2:]:
542 542 split = line.split()
543 543 if split == ['a', 'int', '5']:
544 544 found = True
545 545 break
546 546 self.assertTrue(found, "whos output wrong: %s" % stdout)
547 547
548 548 def test_execute_displaypub(self):
549 549 """execute tracks display_pub output"""
550 550 view = self.client[:]
551 551 view.execute("from IPython.core.display import *")
552 552 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
553 553
554 554 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
555 555 for outputs in ar.outputs:
556 556 mimes = [ out['data'] for out in outputs ]
557 557 self.assertEqual(mimes, expected)
558 558
559 559 def test_apply_displaypub(self):
560 560 """apply tracks display_pub output"""
561 561 view = self.client[:]
562 562 view.execute("from IPython.core.display import *")
563 563
564 564 @interactive
565 565 def publish():
566 566 [ display(i) for i in range(5) ]
567 567
568 568 ar = view.apply_async(publish)
569 569 ar.get(5)
570 570 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
571 571 for outputs in ar.outputs:
572 572 mimes = [ out['data'] for out in outputs ]
573 573 self.assertEqual(mimes, expected)
574 574
575 575 def test_execute_raises(self):
576 576 """exceptions in execute requests raise appropriately"""
577 577 view = self.client[-1]
578 578 ar = view.execute("1/0")
579 579 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
580 580
581 581 @dec.skipif_not_matplotlib
582 582 def test_magic_pylab(self):
583 583 """%pylab works on engines"""
584 584 view = self.client[-1]
585 585 ar = view.execute("%pylab inline")
586 586 # at least check if this raised:
587 587 reply = ar.get(5)
588 588 # include imports, in case user config
589 589 ar = view.execute("plot(rand(100))", silent=False)
590 590 reply = ar.get(5)
591 591 self.assertEqual(len(reply.outputs), 1)
592 592 output = reply.outputs[0]
593 593 self.assertTrue("data" in output)
594 594 data = output['data']
595 595 self.assertTrue("image/png" in data)
596 596
597 597
@@ -1,358 +1,352 b''
1 1 """some generic utilities for dealing with classes, urls, and serialization
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 # Standard library imports.
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23 import socket
24 24 import sys
25 25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 26 try:
27 27 from signal import SIGKILL
28 28 except ImportError:
29 29 SIGKILL=None
30 30
31 31 try:
32 32 import cPickle
33 33 pickle = cPickle
34 34 except:
35 35 cPickle = None
36 36 import pickle
37 37
38 38 # System library imports
39 39 import zmq
40 40 from zmq.log import handlers
41 41
42 42 from IPython.external.decorator import decorator
43 43
44 44 # IPython imports
45 45 from IPython.config.application import Application
46 from IPython.utils import py3compat
47 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
48 from IPython.utils.newserialized import serialize, unserialize
49 46 from IPython.zmq.log import EnginePUBHandler
50 47 from IPython.zmq.serialize import (
51 48 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
52 49 )
53 50
54 if py3compat.PY3:
55 buffer = memoryview
56
57 51 #-----------------------------------------------------------------------------
58 52 # Classes
59 53 #-----------------------------------------------------------------------------
60 54
class Namespace(dict):
    """Subclass of dict for attribute access to keys."""

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # missing keys surface as NameError, mirroring namespace lookup
        if key in self.iterkeys():
            return self[key]
        else:
            raise NameError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        # refuse to shadow dict's own API (e.g. keys, items)
        if hasattr(dict, key):
            raise KeyError("Cannot override dict keys %r"%key)
        self[key] = value
76 70
77 71
class ReverseDict(dict):
    """simple double-keyed subset of dict methods.

    Lookups fall back to the reverse (value -> key) mapping, so either
    side of a pair can be used as the key.
    """

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # mirror mapping from value back to key
        self._reverse = dict()
        for key, value in self.iteritems():
            self._reverse[value] = key

    def __getitem__(self, key):
        # try the forward mapping first, then the reverse one
        try:
            return dict.__getitem__(self, key)
        except KeyError:
            return self._reverse[key]

    def __setitem__(self, key, value):
        # keys must be unique across both directions
        if key in self._reverse:
            raise KeyError("Can't have key %r on both sides!"%key)
        dict.__setitem__(self, key, value)
        self._reverse[value] = key

    def pop(self, key):
        # keep the reverse mapping in sync on removal
        value = dict.pop(self, key)
        self._reverse.pop(value)
        return value

    def get(self, key, default=None):
        """dict.get honoring the two-way lookup of __getitem__."""
        try:
            return self[key]
        except KeyError:
            return default
109 103
110 104 #-----------------------------------------------------------------------------
111 105 # Functions
112 106 #-----------------------------------------------------------------------------
113 107
@decorator
def log_errors(f, self, *args, **kwargs):
    """decorator to log unhandled exceptions raised in a method.

    For use wrapping on_recv callbacks, so that exceptions
    do not cause the stream to be closed.
    """
    try:
        return f(self, *args, **kwargs)
    except Exception:
        # swallow after logging: the wrapped callback must not propagate
        self.log.error("Uncaught exception in %r" % f, exc_info=True)
125 119
126 120
def is_url(url):
    """boolean check for whether a string is a zmq url"""
    if '://' not in url:
        return False
    # valid only if the scheme is one zeromq understands
    scheme = url.split('://', 1)[0]
    return scheme.lower() in ('tcp', 'pgm', 'epgm', 'ipc', 'inproc')
135 129
def validate_url(url):
    """validate a url for zeromq

    Raises TypeError for non-string input, AssertionError for a
    malformed url; returns True otherwise.  Only tcp urls are fully
    validated currently.
    """
    if not isinstance(url, basestring):
        raise TypeError("url must be a string, not %r"%type(url))
    url = url.lower()

    proto_addr = url.split('://')
    assert len(proto_addr) == 2, 'Invalid url: %r'%url
    proto, addr = proto_addr
    assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto

    # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
    # author: Remi Sabourin
    pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')

    if proto == 'tcp':
        lis = addr.split(':')
        assert len(lis) == 2, 'Invalid url: %r'%url
        addr,s_port = lis
        try:
            port = int(s_port)
        except ValueError:
            # BUG FIX: the original formatted `port` here, but `port` is
            # unbound when int() fails; report the offending string s_port.
            raise AssertionError("Invalid port %r in url: %r"%(s_port, url))

        assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url

    else:
        # only validate tcp urls currently
        pass

    return True
167 161
168 162
def validate_url_container(container):
    """validate a potentially nested collection of urls."""
    # a bare string is a single url
    if isinstance(container, basestring):
        url = container
        return validate_url(url)
    elif isinstance(container, dict):
        # for dicts, the urls are the values
        container = container.itervalues()

    # recurse into each element of the (possibly nested) collection
    for element in container:
        validate_url_container(element)
180 174
def split_url(url):
    """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
    parts = url.split('://')
    assert len(parts) == 2, 'Invalid url: %r'%url
    proto, addr = parts
    pieces = addr.split(':')
    assert len(pieces) == 2, 'Invalid url: %r'%url
    host, s_port = pieces
    return proto, host, s_port
190 184
def disambiguate_ip_address(ip, location=None):
    """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
    ones, based on the location (default interpretation of location is localhost)."""
    # concrete addresses pass straight through
    if ip not in ('0.0.0.0', '*'):
        return ip
    try:
        external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
    except (socket.gaierror, IndexError):
        # couldn't identify this machine, assume localhost
        external_ips = []
    if location is None or location in external_ips or not external_ips:
        # If location is unspecified or cannot be determined, assume local
        return '127.0.0.1'
    if location:
        return location
    return ip
206 200
def disambiguate_url(url, location=None):
    """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
    ones, based on the location (default interpretation is localhost).

    This is for zeromq urls, such as tcp://*:10101."""
    try:
        proto, ip, port = split_url(url)
    except AssertionError:
        # probably not tcp url; could be ipc, etc.
        return url
    resolved = disambiguate_ip_address(ip, location)
    return "%s://%s:%s" % (proto, resolved, port)
221 215
222 216
223 217 #--------------------------------------------------------------------------
224 218 # helpers for implementing old MEC API via view.apply
225 219 #--------------------------------------------------------------------------
226 220
def interactive(f):
    """decorator for making functions appear as interactively defined.
    This results in the function being linked to the user_ns as globals()
    instead of the module globals().
    """
    # engines treat '__main__' functions as interactively defined
    setattr(f, '__module__', '__main__')
    return f
234 228
@interactive
def _push(**ns):
    """helper method for implementing `client.push` via `client.apply`"""
    # bind each pushed name directly into the (user) namespace
    user_ns = globals()
    user_ns.update(ns)
239 233
@interactive
def _pull(keys):
    """helper method for implementing `client.pull` via `client.apply`"""
    user_ns = globals()
    if not isinstance(keys, (list, tuple, set)):
        # single name: return its value
        if keys not in user_ns:
            raise NameError("name '%s' is not defined"%keys)
        return user_ns.get(keys)
    # collection of names: validate all of them, then return values in order
    for key in keys:
        if key not in user_ns:
            raise NameError("name '%s' is not defined"%key)
    return map(user_ns.get, keys)
253 247
@interactive
def _execute(code):
    """helper method for implementing `client.execute` via `client.apply`

    Runs ``code`` in this function's globals, which `@interactive` links
    to the engine's user namespace.
    """
    # use the function form of exec, which behaves identically in Python 2
    # but is also valid Python 3 syntax (the old ``exec code in globals()``
    # statement is a SyntaxError on Python 3)
    exec(code, globals())
258 252
259 253 #--------------------------------------------------------------------------
260 254 # extra process management utilities
261 255 #--------------------------------------------------------------------------
262 256
# ports already handed out by select_random_ports in this process,
# tracked so repeated calls never return the same port twice
_random_ports = set()
264 258
def select_random_ports(n):
    """Selects and return n random ports that are available."""
    # first grab n sockets bound to OS-chosen free ports, keeping them all
    # open so the same port cannot be assigned twice within this call
    sockets = []
    for _ in range(n):
        sock = socket.socket()
        sock.bind(('', 0))
        # retry if this port was already handed out by a previous call
        while sock.getsockname()[1] in _random_ports:
            sock.close()
            sock = socket.socket()
            sock.bind(('', 0))
        sockets.append(sock)
    # then release the sockets and record/return their port numbers
    ports = []
    for sock in sockets:
        port = sock.getsockname()[1]
        sock.close()
        ports.append(port)
        _random_ports.add(port)
    return ports
282 276
def signal_children(children):
    """Relay interupt/term signals to children, for more solid process cleanup."""
    def _relay(sig, frame):
        log = Application.instance().log
        log.critical("Got signal %i, terminating children..."%sig)
        for child in children:
            child.terminate()
        # exit status 0 for a plain Ctrl-C, nonzero for other fatal signals
        sys.exit(sig != SIGINT)
        # sys.exit(sig)

    # install the relay handler for each fatal signal
    for sig in (SIGINT, SIGABRT, SIGTERM):
        signal(sig, _relay)
295 289
def generate_exec_key(keyfile):
    """Write a fresh random (uuid4) exec key to ``keyfile``."""
    import uuid
    with open(keyfile, 'w') as f:
        # f.write('ipython-key ')
        f.write(str(uuid.uuid4()) + '\n')
    # set user-only RW permissions (0600)
    # this will have no effect on Windows
    os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
305 299
306 300
def integer_loglevel(loglevel):
    """Coerce a loglevel to an int, resolving names like 'DEBUG' via logging."""
    try:
        return int(loglevel)
    except ValueError:
        # not numeric: look the level name up on the logging module
        if isinstance(loglevel, str):
            return getattr(logging, loglevel)
    return loglevel
314 308
def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
    """Attach a zmq PUBHandler publishing to ``iface`` onto logger ``logname``."""
    logger = logging.getLogger(logname)
    if any(isinstance(h, handlers.PUBHandler) for h in logger.handlers):
        # don't add a second PUBHandler
        return
    level = integer_loglevel(loglevel)
    lsock = context.socket(zmq.PUB)
    lsock.connect(iface)
    handler = handlers.PUBHandler(lsock)
    handler.setLevel(level)
    handler.root_topic = root
    logger.addHandler(handler)
    logger.setLevel(level)
328 322
def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
    """Attach an EnginePUBHandler for ``engine`` onto the root logger."""
    logger = logging.getLogger()
    if any(isinstance(h, handlers.PUBHandler) for h in logger.handlers):
        # don't add a second PUBHandler
        return
    level = integer_loglevel(loglevel)
    lsock = context.socket(zmq.PUB)
    lsock.connect(iface)
    handler = EnginePUBHandler(engine, lsock)
    handler.setLevel(level)
    logger.addHandler(handler)
    logger.setLevel(level)
    return logger
342 336
def local_logger(logname, loglevel=logging.DEBUG):
    """Attach a timestamped StreamHandler to logger ``logname`` and return it.

    Idempotent: if the logger already has a StreamHandler, no second handler
    is added.  Previously the early-exit path returned None while the normal
    path returned the logger; now both paths return the logger, so repeated
    calls are safe for callers that use the return value.
    """
    loglevel = integer_loglevel(loglevel)
    logger = logging.getLogger(logname)
    if any(isinstance(h, logging.StreamHandler) for h in logger.handlers):
        # don't add a second StreamHandler, but still hand back the logger
        return logger
    handler = logging.StreamHandler()
    handler.setLevel(loglevel)
    formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
                datefmt="%Y-%m-%d %H:%M:%S")
    handler.setFormatter(formatter)

    logger.addHandler(handler)
    logger.setLevel(loglevel)
    return logger
358 352
@@ -1,151 +1,232 b''
1 1 # encoding: utf-8
2 2
3 3 """Pickle related utilities. Perhaps this should be called 'can'."""
4 4
5 5 __docformat__ = "restructuredtext en"
6 6
7 7 #-------------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-------------------------------------------------------------------------------
13 13
14 14 #-------------------------------------------------------------------------------
15 15 # Imports
16 16 #-------------------------------------------------------------------------------
17 17
18 18 import copy
19 19 import sys
20 20 from types import FunctionType
21 21
22 try:
23 import cPickle as pickle
24 except ImportError:
25 import pickle
26
27 try:
28 import numpy
29 except:
30 numpy = None
31
22 32 import codeutil
33 import py3compat
34 from importstring import import_item
35
36 if py3compat.PY3:
37 buffer = memoryview
23 38
24 39 #-------------------------------------------------------------------------------
25 40 # Classes
26 41 #-------------------------------------------------------------------------------
27 42
28 43
class CannedObject(object):
    """Picklable stand-in for an object whose attributes need canning."""

    def __init__(self, obj, keys=[]):
        # `keys` lists attribute names that must themselves be canned
        self.keys = keys
        self.obj = copy.copy(obj)
        for key in keys:
            canned = can(getattr(obj, key))
            setattr(self.obj, key, canned)

        self.buffers = []

    def get_object(self, g=None):
        """Reconstruct the object, uncanning marked attributes against `g`."""
        g = {} if g is None else g
        obj = self.obj
        for key in self.keys:
            setattr(obj, key, uncan(getattr(obj, key), g))
        return obj
59
43 60
class Reference(CannedObject):
    """object for wrapping a remote reference by name."""
    def __init__(self, name):
        if not isinstance(name, basestring):
            raise TypeError("illegal name: %r"%name)
        self.name = name
        self.buffers = []

    def __repr__(self):
        return "<Reference: %r>"%self.name

    def get_object(self, g=None):
        g = {} if g is None else g
        # resolve the stored name in the target namespace
        return eval(self.name, g)
59 77
60 78
class CannedFunction(CannedObject):
    """Picklable wrapper for a plain Python function."""

    def __init__(self, f):
        self._check_type(f)
        # capture the pieces needed to rebuild the function remotely
        self.code = f.func_code
        self.defaults = f.func_defaults
        self.module = f.__module__ or '__main__'
        self.__name__ = f.__name__
        self.buffers = []

    def _check_type(self, obj):
        assert isinstance(obj, FunctionType), "Not a function type"

    def get_object(self, g=None):
        # try to load function back into its module:
        if not self.module.startswith('__'):
            try:
                __import__(self.module)
            except ImportError:
                pass
            else:
                g = sys.modules[self.module].__dict__

        if g is None:
            g = {}
        return FunctionType(self.code, g, self.__name__, self.defaults)
87 106
107
class CannedArray(CannedObject):
    """Picklable wrapper for a numpy ndarray, shipping data as a raw buffer."""
    def __init__(self, obj):
        self.shape = obj.shape
        self.dtype = obj.dtype
        if sum(obj.shape) == 0:
            # degenerate shape: just pickle it, a buffer gains nothing
            self.buffers = [pickle.dumps(obj, -1)]
        else:
            # ship the raw memory; ensure contiguous layout first
            contiguous = numpy.ascontiguousarray(obj, dtype=None)
            self.buffers = [buffer(contiguous)]

    def get_object(self, g=None):
        data = self.buffers[0]
        if sum(self.shape) == 0:
            # no shape, we just pickled it
            return pickle.loads(data)
        return numpy.frombuffer(data, dtype=self.dtype).reshape(self.shape)
127
128
class CannedBytes(CannedObject):
    """Picklable wrapper that ships a bytes object as a raw buffer."""
    # subclasses override `wrap` to control the type handed back
    wrap = bytes

    def __init__(self, obj):
        self.buffers = [obj]

    def get_object(self, g=None):
        return self.wrap(self.buffers[0])
137
class CannedBuffer(CannedBytes):
    """Like CannedBytes, but restores a buffer view rather than bytes.

    NOTE: this was declared with ``def`` instead of ``class``, which made
    CannedBuffer a function returning None, so canning buffer objects via
    can_map silently produced None.
    """
    wrap = buffer
140
88 141 #-------------------------------------------------------------------------------
89 142 # Functions
90 143 #-------------------------------------------------------------------------------
91 144
def can(obj):
    """prepare an object for pickling"""
    for klass, canner in can_map.iteritems():
        if isinstance(klass, basestring):
            # registry keys may be dotted names, resolved lazily so that
            # e.g. numpy is only imported when an instance might exist
            try:
                klass = import_item(klass)
            except Exception:
                # not importable
                print("not importable: %r" % klass)
                continue
        if isinstance(obj, klass):
            return canner(obj)
    # nothing registered for this type: pass it through unchanged
    return obj
159
def can_dict(obj):
    """can the *values* of a dict"""
    if not isinstance(obj, dict):
        return obj
    return dict((k, can(v)) for k, v in obj.iteritems())
115 169
def can_sequence(obj):
    """can the elements of a sequence"""
    if not isinstance(obj, (list, tuple)):
        return obj
    # preserve the input's concrete type (list vs tuple)
    return type(obj)(can(item) for item in obj)
122 177
def uncan(obj, g=None):
    """invert canning"""
    for klass, uncanner in uncan_map.iteritems():
        if isinstance(klass, basestring):
            # resolve dotted-name registry keys lazily
            try:
                klass = import_item(klass)
            except Exception:
                # not importable
                print("not importable: %r" % klass)
                continue
        if isinstance(obj, klass):
            return uncanner(obj, g)
    # nothing registered: the object was never canned
    return obj
191
def uncan_dict(obj, g=None):
    """uncan the values of a dict, threading namespace `g` through."""
    if not isinstance(obj, dict):
        return obj
    return dict((k, uncan(v, g)) for k, v in obj.iteritems())
141 200
def uncan_sequence(obj, g=None):
    """uncan the elements of a list/tuple, preserving its concrete type."""
    if not isinstance(obj, (list, tuple)):
        return obj
    return type(obj)(uncan(item, g) for item in obj)
148 207
149 208
150 def rebindFunctionGlobals(f, glbls):
151 return FunctionType(f.func_code, glbls)
209 #-------------------------------------------------------------------------------
210 # API dictionary
211 #-------------------------------------------------------------------------------
212
213 # These dicts can be extended for custom serialization of new objects
214
can_map = {
    # string keys are resolved lazily via import_item, so numpy / IPython
    # are only imported if a matching instance actually appears
    'IPython.parallel.dependent' : lambda obj: CannedObject(obj, keys=('f','df')),
    'numpy.ndarray' : CannedArray,
    FunctionType : CannedFunction,
    bytes : CannedBytes,
    buffer : CannedBuffer,
    # dict : can_dict,
    # list : can_sequence,
    # tuple : can_sequence,
}

uncan_map = {
    CannedObject : lambda obj, g: obj.get_object(g),
    # dict : uncan_dict,
    # list : uncan_sequence,
    # tuple : uncan_sequence,
}
232
@@ -1,925 +1,925 b''
1 1 #!/usr/bin/env python
2 2 """A simple interactive kernel that talks to a frontend over 0MQ.
3 3
4 4 Things to do:
5 5
6 6 * Implement `set_parent` logic. Right before doing exec, the Kernel should
7 7 call set_parent on all the PUB objects with the message about to be executed.
8 8 * Implement random port and security key logic.
9 9 * Implement control messages.
10 10 * Implement event loop and poll version.
11 11 """
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16 from __future__ import print_function
17 17
18 18 # Standard library imports
19 19 import __builtin__
20 20 import atexit
21 21 import sys
22 22 import time
23 23 import traceback
24 24 import logging
25 25 import uuid
26 26
27 27 from datetime import datetime
28 28 from signal import (
29 29 signal, getsignal, default_int_handler, SIGINT, SIG_IGN
30 30 )
31 31
32 32 # System library imports
33 33 import zmq
34 34 from zmq.eventloop import ioloop
35 35 from zmq.eventloop.zmqstream import ZMQStream
36 36
37 37 # Local imports
38 38 from IPython.config.configurable import Configurable
39 39 from IPython.config.application import boolean_flag, catch_config_error
40 40 from IPython.core.application import ProfileDir
41 41 from IPython.core.error import StdinNotImplementedError
42 42 from IPython.core.shellapp import (
43 43 InteractiveShellApp, shell_flags, shell_aliases
44 44 )
45 45 from IPython.utils import io
46 46 from IPython.utils import py3compat
47 47 from IPython.utils.frame import extract_module_locals
48 48 from IPython.utils.jsonutil import json_clean
49 49 from IPython.utils.traitlets import (
50 50 Any, Instance, Float, Dict, CaselessStrEnum, List, Set, Integer, Unicode
51 51 )
52 52
53 53 from entry_point import base_launch_kernel
54 54 from kernelapp import KernelApp, kernel_flags, kernel_aliases
55 55 from serialize import serialize_object, unpack_apply_message
56 56 from session import Session, Message
57 57 from zmqshell import ZMQInteractiveShell
58 58
59 59
60 60 #-----------------------------------------------------------------------------
61 61 # Main kernel class
62 62 #-----------------------------------------------------------------------------
63 63
64 64 class Kernel(Configurable):
65 65
66 66 #---------------------------------------------------------------------------
67 67 # Kernel interface
68 68 #---------------------------------------------------------------------------
69 69
70 70 # attribute to override with a GUI
71 71 eventloop = Any(None)
def _eventloop_changed(self, name, old, new):
    """schedule call to eventloop from IOLoop"""
    # give the IOLoop a moment to settle before entering the GUI loop
    ioloop.IOLoop.instance().add_timeout(time.time()+0.1, self.enter_eventloop)
76 76
77 77 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
78 78 session = Instance(Session)
79 79 profile_dir = Instance('IPython.core.profiledir.ProfileDir')
80 80 shell_streams = List()
81 81 control_stream = Instance(ZMQStream)
82 82 iopub_socket = Instance(zmq.Socket)
83 83 stdin_socket = Instance(zmq.Socket)
84 84 log = Instance(logging.Logger)
85 85
86 86 user_module = Any()
87 87 def _user_module_changed(self, name, old, new):
88 88 if self.shell is not None:
89 89 self.shell.user_module = new
90 90
91 91 user_ns = Dict(default_value=None)
92 92 def _user_ns_changed(self, name, old, new):
93 93 if self.shell is not None:
94 94 self.shell.user_ns = new
95 95 self.shell.init_user_ns()
96 96
97 97 # identities:
98 98 int_id = Integer(-1)
99 99 ident = Unicode()
100 100
101 101 def _ident_default(self):
102 102 return unicode(uuid.uuid4())
103 103
104 104
105 105 # Private interface
106 106
107 107 # Time to sleep after flushing the stdout/err buffers in each execute
108 108 # cycle. While this introduces a hard limit on the minimal latency of the
109 109 # execute cycle, it helps prevent output synchronization problems for
110 110 # clients.
111 111 # Units are in seconds. The minimum zmq latency on local host is probably
112 112 # ~150 microseconds, set this to 500us for now. We may need to increase it
113 113 # a little if it's not enough after more interactive testing.
114 114 _execute_sleep = Float(0.0005, config=True)
115 115
116 116 # Frequency of the kernel's event loop.
117 117 # Units are in seconds, kernel subclasses for GUI toolkits may need to
118 118 # adapt to milliseconds.
119 119 _poll_interval = Float(0.05, config=True)
120 120
121 121 # If the shutdown was requested over the network, we leave here the
122 122 # necessary reply message so it can be sent by our registered atexit
123 123 # handler. This ensures that the reply is only sent to clients truly at
124 124 # the end of our shutdown process (which happens after the underlying
125 125 # IPython shell's own shutdown).
126 126 _shutdown_message = None
127 127
128 128 # This is a dict of port number that the kernel is listening on. It is set
129 129 # by record_ports and used by connect_request.
130 130 _recorded_ports = Dict()
131 131
132 132 # set of aborted msg_ids
133 133 aborted = Set()
134 134
135 135
def __init__(self, **kwargs):
    """Set up the shell and the msg_type -> handler dispatch tables."""
    super(Kernel, self).__init__(**kwargs)

    # Initialize the InteractiveShell subclass
    self.shell = ZMQInteractiveShell.instance(config=self.config,
        profile_dir = self.profile_dir,
        user_module = self.user_module,
        user_ns = self.user_ns,
    )
    self.shell.displayhook.session = self.session
    self.shell.displayhook.pub_socket = self.iopub_socket
    self.shell.displayhook.topic = self._topic('pyout')
    self.shell.display_pub.session = self.session
    self.shell.display_pub.pub_socket = self.iopub_socket

    # TMP - hack while developing
    self.shell._reply_content = None

    # Build dict of handlers for message types
    msg_types = [ 'execute_request', 'complete_request',
                  'object_info_request', 'history_request',
                  'connect_request', 'shutdown_request',
                  'apply_request',
                ]
    self.shell_handlers = dict((t, getattr(self, t)) for t in msg_types)

    # the control channel additionally accepts clear/abort requests
    control_msg_types = msg_types + [ 'clear_request', 'abort_request' ]
    self.control_handlers = dict((t, getattr(self, t)) for t in control_msg_types)
168 168
def dispatch_control(self, msg):
    """dispatch control requests"""
    idents, msg = self.session.feed_identities(msg, copy=False)
    try:
        msg = self.session.unserialize(msg, content=True, copy=False)
    except:
        self.log.error("Invalid Control Message", exc_info=True)
        return

    self.log.debug("Control received: %s", msg)

    msg_type = msg['header']['msg_type']

    handler = self.control_handlers.get(msg_type, None)
    if handler is None:
        self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type)
        return
    # handler failures must not take down the control channel
    try:
        handler(self.control_stream, idents, msg)
    except Exception:
        self.log.error("Exception in control handler:", exc_info=True)
192 192
def dispatch_shell(self, stream, msg):
    """dispatch shell requests"""
    # flush control requests first
    if self.control_stream:
        self.control_stream.flush()

    idents, msg = self.session.feed_identities(msg, copy=False)
    try:
        msg = self.session.unserialize(msg, content=True, copy=False)
    except:
        self.log.error("Invalid Message", exc_info=True)
        return

    header = msg['header']
    msg_id = header['msg_id']
    msg_type = header['msg_type']

    # Print some info about this message and leave a '--->' marker, so it's
    # easier to trace visually the message chain when debugging.  Each
    # handler prints its message at the end.
    self.log.debug('\n*** MESSAGE TYPE:%s***', msg_type)
    self.log.debug(' Content: %s\n --->\n ', msg['content'])

    if msg_id in self.aborted:
        self.aborted.remove(msg_id)
        # is it safe to assume a msg_id will not be resubmitted?
        reply_type = msg_type.split('_')[0] + '_reply'
        status = {'status' : 'aborted'}
        md = {'engine' : self.ident}
        md.update(status)
        self.session.send(stream, reply_type, metadata=md,
                    content=status, parent=msg, ident=idents)
        return

    handler = self.shell_handlers.get(msg_type, None)
    if handler is None:
        self.log.error("UNKNOWN MESSAGE TYPE: %r", msg_type)
        return
    # ensure default_int_handler during handler call
    sig = signal(SIGINT, default_int_handler)
    try:
        handler(stream, idents, msg)
    except Exception:
        self.log.error("Exception in message handler:", exc_info=True)
    finally:
        signal(SIGINT, sig)
239 239
def enter_eventloop(self):
    """enter eventloop"""
    self.log.info("entering eventloop")
    # restore default_int_handler
    signal(SIGINT, default_int_handler)
    while self.eventloop is not None:
        try:
            self.eventloop(self)
        except KeyboardInterrupt:
            # Ctrl-C shouldn't crash the kernel; re-enter the loop
            self.log.error("KeyboardInterrupt caught in kernel")
        else:
            # eventloop exited cleanly, this means we should stop (right?)
            self.eventloop = None
    self.log.info("exiting eventloop")
    # if eventloop exits, IOLoop should stop
    ioloop.IOLoop.instance().stop()
259 259
def start(self):
    """register dispatchers for streams"""
    self.shell.exit_now = False
    if self.control_stream:
        self.control_stream.on_recv(self.dispatch_control, copy=False)

    def make_dispatcher(stream):
        # close over `stream` per-call so each dispatcher keeps its own
        def dispatcher(msg):
            return self.dispatch_shell(stream, msg)
        return dispatcher

    for shell_stream in self.shell_streams:
        shell_stream.on_recv(make_dispatcher(shell_stream), copy=False)
273 273
def do_one_iteration(self):
    """step eventloop just once"""
    if self.control_stream:
        self.control_stream.flush()
    for s in self.shell_streams:
        # handle at most one request per iteration
        s.flush(zmq.POLLIN, 1)
        s.flush(zmq.POLLOUT)
282 282
283 283
def record_ports(self, ports):
    """Record the ports that this kernel is using.

    The creator of the Kernel instance must call this methods if they
    want the :meth:`connect_request` method to return the port numbers.
    """
    # stashed for later retrieval by connect_request
    self._recorded_ports = ports
291 291
292 292 #---------------------------------------------------------------------------
293 293 # Kernel request handlers
294 294 #---------------------------------------------------------------------------
295 295
def _make_metadata(self, other=None):
    """init metadata dict, for execute/apply_reply"""
    md = {
        'dependencies_met' : True,
        'engine' : self.ident,
        'started': datetime.now(),
    }
    # caller-provided metadata overrides the defaults
    if other:
        md.update(other)
    return md
306 306
def _publish_pyin(self, code, parent, execution_count):
    """Publish the code request on the pyin stream."""
    content = {u'code':code, u'execution_count': execution_count}
    self.session.send(self.iopub_socket, u'pyin', content,
        parent=parent, ident=self._topic('pyin'))
314 314
def _publish_status(self, status, parent=None):
    """send status (busy/idle) on IOPub"""
    content = {u'execution_state': status}
    topic = self._topic('status')
    self.session.send(self.iopub_socket, u'status', content,
                      parent=parent, ident=topic)
323 323
324 324
def execute_request(self, stream, ident, parent):
    """handle an execute_request

    Runs the requested code in the shell, then sends an execute_reply on
    ``stream`` carrying status, execution count, user variables/expressions,
    and payloads.  Publishes busy/idle status and (unless silent) the input
    itself on IOPub.
    """

    self._publish_status(u'busy', parent)

    try:
        content = parent[u'content']
        code = content[u'code']
        silent = content[u'silent']
    except:
        self.log.error("Got bad msg: ")
        self.log.error("%s", parent)
        return

    md = self._make_metadata(parent['metadata'])

    shell = self.shell # we'll need this a lot here

    # Replace raw_input. Note that is not sufficient to replace
    # raw_input in the user namespace.
    if content.get('allow_stdin', False):
        raw_input = lambda prompt='': self._raw_input(prompt, ident, parent)
    else:
        # stdin not allowed by the frontend: raise instead of blocking
        raw_input = lambda prompt='' : self._no_raw_input()

    if py3compat.PY3:
        __builtin__.input = raw_input
    else:
        __builtin__.raw_input = raw_input

    # Set the parent message of the display hook and out streams.
    shell.displayhook.set_parent(parent)
    shell.display_pub.set_parent(parent)
    sys.stdout.set_parent(parent)
    sys.stderr.set_parent(parent)

    # Re-broadcast our input for the benefit of listening clients, and
    # start computing output
    if not silent:
        self._publish_pyin(code, parent, shell.execution_count)

    reply_content = {}
    try:
        # FIXME: the shell calls the exception handler itself.
        shell.run_cell(code, store_history=not silent, silent=silent)
    except:
        status = u'error'
        # FIXME: this code right now isn't being used yet by default,
        # because the run_cell() call above directly fires off exception
        # reporting. This code, therefore, is only active in the scenario
        # where runlines itself has an unhandled exception. We need to
        # uniformize this, for all exception construction to come from a
        # single location in the codbase.
        etype, evalue, tb = sys.exc_info()
        tb_list = traceback.format_exception(etype, evalue, tb)
        reply_content.update(shell._showtraceback(etype, evalue, tb_list))
    else:
        status = u'ok'

    reply_content[u'status'] = status

    # Return the execution counter so clients can display prompts
    reply_content['execution_count'] = shell.execution_count - 1

    # FIXME - fish exception info out of shell, possibly left there by
    # runlines. We'll need to clean up this logic later.
    if shell._reply_content is not None:
        reply_content.update(shell._reply_content)
        e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method='execute')
        reply_content['engine_info'] = e_info
        # reset after use
        shell._reply_content = None

    # At this point, we can tell whether the main code execution succeeded
    # or not. If it did, we proceed to evaluate user_variables/expressions
    if reply_content['status'] == 'ok':
        reply_content[u'user_variables'] = \
                     shell.user_variables(content.get(u'user_variables', []))
        reply_content[u'user_expressions'] = \
                     shell.user_expressions(content.get(u'user_expressions', {}))
    else:
        # If there was an error, don't even try to compute variables or
        # expressions
        reply_content[u'user_variables'] = {}
        reply_content[u'user_expressions'] = {}

    # Payloads should be retrieved regardless of outcome, so we can both
    # recover partial output (that could have been generated early in a
    # block, before an error) and clear the payload system always.
    reply_content[u'payload'] = shell.payload_manager.read_payload()
    # Be agressive about clearing the payload because we don't want
    # it to sit in memory until the next execute_request comes in.
    shell.payload_manager.clear_payload()

    # Flush output before sending the reply.
    sys.stdout.flush()
    sys.stderr.flush()
    # FIXME: on rare occasions, the flush doesn't seem to make it to the
    # clients... This seems to mitigate the problem, but we definitely need
    # to better understand what's going on.
    if self._execute_sleep:
        time.sleep(self._execute_sleep)

    # Send the reply.
    reply_content = json_clean(reply_content)

    md['status'] = reply_content['status']
    if reply_content['status'] == 'error' and \
                    reply_content['ename'] == 'UnmetDependency':
            md['dependencies_met'] = False

    reply_msg = self.session.send(stream, u'execute_reply',
                                  reply_content, parent, metadata=md,
                                  ident=ident)

    self.log.debug("%s", reply_msg)

    # an error aborts any queued requests, matching the scheduler's model
    if not silent and reply_msg['content']['status'] == u'error':
        self._abort_queues()

    self._publish_status(u'idle', parent)
446 446
def complete_request(self, stream, ident, parent):
    """Handle a complete_request: send tab-completion matches back."""
    txt, matches = self._complete(parent)
    # scrub for JSON before sending
    reply = json_clean({'matches' : matches,
                        'matched_text' : txt,
                        'status' : 'ok'})
    completion_msg = self.session.send(stream, 'complete_reply',
                                       reply, parent, ident)
    self.log.debug("%s", completion_msg)
456 456
def object_info_request(self, stream, ident, parent):
    """Handle an object_info_request: introspect a name in the shell."""
    content = parent['content']
    object_info = self.shell.object_inspect(content['oname'],
                    detail_level=content.get('detail_level', 0))
    # Before we send this object over, we scrub it for JSON usage
    oinfo = json_clean(object_info)
    msg = self.session.send(stream, 'object_info_reply',
                            oinfo, parent, ident)
    self.log.debug("%s", msg)
467 467
def history_request(self, stream, ident, parent):
    """Handle a history_request ('tail', 'range', or 'search' access)."""
    # We need to pull these out, as passing **kwargs doesn't work with
    # unicode keys before Python 2.6.5.
    content = parent['content']
    hist_access_type = content['hist_access_type']
    raw = content['raw']
    output = content['output']
    hm = self.shell.history_manager
    if hist_access_type == 'tail':
        hist = hm.get_tail(content['n'], raw=raw, output=output,
                           include_latest=True)
    elif hist_access_type == 'range':
        hist = hm.get_range(content['session'], content['start'],
                            content['stop'], raw=raw, output=output)
    elif hist_access_type == 'search':
        hist = hm.search(content['pattern'], raw=raw, output=output)
    else:
        # unknown access type: reply with an empty history
        hist = []
    hist = list(hist)
    reply = json_clean({'history' : hist})
    msg = self.session.send(stream, 'history_reply',
                            reply, parent, ident)
    self.log.debug("Sending history reply with %i entries", len(hist))
499 499
def connect_request(self, stream, ident, parent):
    """Handle a connect_request: report the ports this kernel listens on."""
    ports = self._recorded_ports
    content = ports.copy() if ports is not None else {}
    msg = self.session.send(stream, 'connect_reply',
                            content, parent, ident)
    self.log.debug("%s", msg)
508 508
def shutdown_request(self, stream, ident, parent):
    """Begin kernel shutdown: reply, broadcast on IOPub, and schedule exit."""
    self.shell.exit_now = True
    content = dict(status='ok')
    content.update(parent['content'])
    self.session.send(stream, u'shutdown_reply', content, parent, ident=ident)
    # same content, but different msg_id for broadcasting on IOPub
    self._shutdown_message = self.session.msg(u'shutdown_reply',
                                              content, parent)

    self._at_shutdown()
    # call sys.exit after a short delay, so the reply can still be delivered
    loop = ioloop.IOLoop.instance()
    loop.add_timeout(time.time() + 0.1, loop.stop)
523 523
524 524 #---------------------------------------------------------------------------
525 525 # Engine methods
526 526 #---------------------------------------------------------------------------
527 527
528 528 def apply_request(self, stream, ident, parent):
529 529 try:
530 530 content = parent[u'content']
531 531 bufs = parent[u'buffers']
532 532 msg_id = parent['header']['msg_id']
533 533 except:
534 534 self.log.error("Got bad msg: %s", parent, exc_info=True)
535 535 return
536 536
537 537 self._publish_status(u'busy', parent)
538 538
539 539 # Set the parent message of the display hook and out streams.
540 540 shell = self.shell
541 541 shell.displayhook.set_parent(parent)
542 542 shell.display_pub.set_parent(parent)
543 543 sys.stdout.set_parent(parent)
544 544 sys.stderr.set_parent(parent)
545 545
546 546 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
547 547 # self.iopub_socket.send(pyin_msg)
548 548 # self.session.send(self.iopub_socket, u'pyin', {u'code':code},parent=parent)
549 549 md = self._make_metadata(parent['metadata'])
550 550 try:
551 551 working = shell.user_ns
552 552
553 553 prefix = "_"+str(msg_id).replace("-","")+"_"
554 554
555 555 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
556 556
557 557 fname = getattr(f, '__name__', 'f')
558 558
559 559 fname = prefix+"f"
560 560 argname = prefix+"args"
561 561 kwargname = prefix+"kwargs"
562 562 resultname = prefix+"result"
563 563
564 564 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
565 565 # print ns
566 566 working.update(ns)
567 567 code = "%s = %s(*%s,**%s)" % (resultname, fname, argname, kwargname)
568 568 try:
569 569 exec code in shell.user_global_ns, shell.user_ns
570 570 result = working.get(resultname)
571 571 finally:
572 572 for key in ns.iterkeys():
573 573 working.pop(key)
574 574
575 packed_result,buf = serialize_object(result)
576 result_buf = [packed_result]+buf
575 result_buf = serialize_object(result)
576
577 577 except:
578 578 # invoke IPython traceback formatting
579 579 shell.showtraceback()
580 580 # FIXME - fish exception info out of shell, possibly left there by
581 581 # run_code. We'll need to clean up this logic later.
582 582 reply_content = {}
583 583 if shell._reply_content is not None:
584 584 reply_content.update(shell._reply_content)
585 585 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method='apply')
586 586 reply_content['engine_info'] = e_info
587 587 # reset after use
588 588 shell._reply_content = None
589 589
590 590 self.session.send(self.iopub_socket, u'pyerr', reply_content, parent=parent,
591 591 ident=self._topic('pyerr'))
592 592 result_buf = []
593 593
594 594 if reply_content['ename'] == 'UnmetDependency':
595 595 md['dependencies_met'] = False
596 596 else:
597 597 reply_content = {'status' : 'ok'}
598 598
599 599 # put 'ok'/'error' status in header, for scheduler introspection:
600 600 md['status'] = reply_content['status']
601 601
602 602 # flush i/o
603 603 sys.stdout.flush()
604 604 sys.stderr.flush()
605 605
606 606 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
607 607 parent=parent, ident=ident,buffers=result_buf, metadata=md)
608 608
609 609 self._publish_status(u'idle', parent)
610 610
611 611 #---------------------------------------------------------------------------
612 612 # Control messages
613 613 #---------------------------------------------------------------------------
614 614
def abort_request(self, stream, ident, parent):
    """abort a specific msg by id (or everything queued, if no ids given)"""
    msg_ids = parent['content'].get('msg_ids', None)
    if isinstance(msg_ids, basestring):
        msg_ids = [msg_ids]
    if not msg_ids:
        # fixed: the method defined on this class is _abort_queues, not
        # abort_queues; also normalize None -> [] so the loop below
        # doesn't raise TypeError when no ids were supplied
        self._abort_queues()
        msg_ids = []
    for mid in msg_ids:
        self.aborted.add(str(mid))

    content = dict(status='ok')
    reply_msg = self.session.send(stream, 'abort_reply', content=content,
                                  parent=parent, ident=ident)
    self.log.debug("%s", reply_msg)
629 629
def clear_request(self, stream, idents, parent):
    """Clear our namespace and acknowledge with a clear_reply."""
    self.shell.reset(False)
    msg = self.session.send(stream, 'clear_reply',
                            content=dict(status='ok'),
                            ident=idents, parent=parent)
635 635
636 636
637 637 #---------------------------------------------------------------------------
638 638 # Protected interface
639 639 #---------------------------------------------------------------------------
640 640
641 641
def _wrap_exception(self, method=None):
    """Format the current exception as an IPython.parallel error dict."""
    # import here, because _wrap_exception is only used in parallel,
    # and parallel has a higher minimum pyzmq version
    from IPython.parallel.error import wrap_exception
    e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
    return wrap_exception(e_info)
649 649
def _topic(self, topic):
    """prefixed topic for IOPub messages"""
    # engines have a non-negative integer id; plain kernels use their uuid
    if self.int_id >= 0:
        prefix = "engine.%i" % self.int_id
    else:
        prefix = "kernel.%s" % self.ident
    return py3compat.cast_bytes("%s.%s" % (prefix, topic))
658 658
def _abort_queues(self):
    """Abort every registered (non-empty) shell stream."""
    for stream in self.shell_streams:
        if not stream:
            continue
        self._abort_queue(stream)
663 663
def _abort_queue(self, stream):
    """Drain a shell stream, replying 'aborted' to every pending request."""
    poller = zmq.Poller()
    poller.register(stream.socket, zmq.POLLIN)
    while True:
        idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
        if msg is None:
            return

        self.log.info("Aborting:")
        self.log.info("%s", msg)
        msg_type = msg['header']['msg_type']
        reply_type = msg_type.split('_')[0] + '_reply'

        status = {'status' : 'aborted'}
        md = {'engine' : self.ident}
        md.update(status)
        # fixed typo: 'meatadata' was passed as an unknown keyword, so the
        # metadata dict never made it onto the reply
        reply_msg = self.session.send(stream, reply_type, metadata=md,
                    content=status, parent=msg, ident=idents)
        self.log.debug("%s", reply_msg)
        # We need to wait a bit for requests to come in. This can probably
        # be set shorter for true asynchronous clients.
        poller.poll(50)
686 686
687 687
def _no_raw_input(self):
    """Raise StdinNotImplementedError if active frontend doesn't support
    stdin."""
    message = ("raw_input was called, but this "
               "frontend does not support stdin.")
    raise StdinNotImplementedError(message)
693 693
def _raw_input(self, prompt, ident, parent):
    """Request one line of input from the frontend over the stdin socket."""
    # Flush output before making the request, so the prompt appears in order.
    sys.stderr.flush()
    sys.stdout.flush()

    # Send the input request.
    content = json_clean(dict(prompt=prompt))
    self.session.send(self.stdin_socket, u'input_request', content, parent,
                      ident=ident)

    # Await a response, skipping over any invalid messages.
    while True:
        try:
            ident, reply = self.session.recv(self.stdin_socket, 0)
        except Exception:
            self.log.warn("Invalid Message:", exc_info=True)
        else:
            break

    try:
        text = reply['content']['value']
    except:
        self.log.error("Got bad raw_input reply: ")
        self.log.error("%s", parent)
        text = ''

    if text == '\x04':
        # EOF (Ctrl-D from the frontend)
        raise EOFError
    return text
722 722
def _complete(self, msg):
    """Run tab-completion for a complete_request message's content."""
    content = msg['content']
    try:
        cursor = int(content['cursor_pos'])
    except:
        # If we don't get something we can convert to an integer, at least
        # attempt the completion guessing the cursor is at the end of the
        # text, if there's any, and otherwise of the line.
        cursor = len(content['text']) or len(content['line'])
    return self.shell.complete(content['text'], content['line'], cursor)
735 735
def _object_info(self, context):
    """Return a minimal object-info dict (docstring only) for *context*."""
    symbol, leftover = self._symbol_from_context(context)
    # only report a docstring when the full dotted name resolved
    if symbol is not None and not leftover:
        doc = getattr(symbol, '__doc__', '')
    else:
        doc = ''
    return dict(docstring=doc)
744 744
def _symbol_from_context(self, context):
    """Resolve a dotted name, given as a list of parts in *context*.

    Returns (symbol, leftover): the deepest object reached, and the
    trailing name parts that could not be resolved.
    """
    if not context:
        return None, context

    base_name = context[0]
    # look in the user namespace first, then in the builtins
    symbol = self.shell.user_ns.get(base_name, None)
    if symbol is None:
        symbol = __builtin__.__dict__.get(base_name, None)
    if symbol is None:
        return None, context

    remainder = context[1:]
    for i, name in enumerate(remainder):
        next_symbol = getattr(symbol, name, None)
        if next_symbol is None:
            return symbol, remainder[i:]
        symbol = next_symbol

    return symbol, []
765 765
def _at_shutdown(self):
    """Actions taken at shutdown by the kernel, called by python's atexit.
    """
    # io.rprint("Kernel at_shutdown") # dbg
    if self._shutdown_message is not None:
        self.session.send(self.iopub_socket, self._shutdown_message,
                          ident=self._topic('shutdown'))
        self.log.debug("%s", self._shutdown_message)
    # make sure any pending outgoing messages actually go out
    for stream in self.shell_streams:
        stream.flush(zmq.POLLOUT)
774 774
775 775 #-----------------------------------------------------------------------------
776 776 # Aliases and Flags for the IPKernelApp
777 777 #-----------------------------------------------------------------------------
778 778
# Merge kernel-level and shell-level command line flags/aliases for the
# IPKernelApp defined below.
flags = dict(kernel_flags)
flags.update(shell_flags)

addflag = lambda *args: flags.update(boolean_flag(*args))

flags['pylab'] = (
    {'IPKernelApp' : {'pylab' : 'auto'}},
    """Pre-load matplotlib and numpy for interactive use with
    the default matplotlib backend."""
)

aliases = dict(kernel_aliases)
aliases.update(shell_aliases)
792 792
793 793 #-----------------------------------------------------------------------------
794 794 # The IPKernelApp class
795 795 #-----------------------------------------------------------------------------
796 796
class IPKernelApp(KernelApp, InteractiveShellApp):
    """Application that launches a full IPython kernel."""
    name = 'ipkernel'

    aliases = Dict(aliases)
    flags = Dict(flags)
    classes = [Kernel, ZMQInteractiveShell, ProfileDir, Session]

    @catch_config_error
    def initialize(self, argv=None):
        super(IPKernelApp, self).initialize(argv)
        self.init_path()
        self.init_shell()
        self.init_gui_pylab()
        self.init_extensions()
        self.init_code()

    def init_kernel(self):
        """Create the Kernel object and wire it to our sockets."""
        shell_stream = ZMQStream(self.shell_socket)

        kernel = Kernel(config=self.config, session=self.session,
                        shell_streams=[shell_stream],
                        iopub_socket=self.iopub_socket,
                        stdin_socket=self.stdin_socket,
                        log=self.log,
                        profile_dir=self.profile_dir,
        )
        self.kernel = kernel
        kernel.record_ports(self.ports)
        shell = kernel.shell

    def init_gui_pylab(self):
        """Enable GUI event loop integration, taking pylab into account."""

        # Provide a wrapper for :meth:`InteractiveShellApp.init_gui_pylab`
        # to ensure that any exception is printed straight to stderr.
        # Normally _showtraceback associates the reply with an execution,
        # which means frontends will never draw it, as this exception
        # is not associated with any execute request.

        shell = self.shell
        original_showtraceback = shell._showtraceback
        try:
            # replace pyerr-sending traceback with stderr
            def print_tb(etype, evalue, stb):
                print ("GUI event loop or pylab initialization failed",
                       file=io.stderr)
                print (shell.InteractiveTB.stb2text(stb), file=io.stderr)
            shell._showtraceback = print_tb
            InteractiveShellApp.init_gui_pylab(self)
        finally:
            # always restore the real traceback handler
            shell._showtraceback = original_showtraceback

    def init_shell(self):
        """Adopt the kernel's shell as our own configurable."""
        self.shell = self.kernel.shell
        self.shell.configurables.append(self)
853 853
854 854
855 855 #-----------------------------------------------------------------------------
856 856 # Kernel main and launch functions
857 857 #-----------------------------------------------------------------------------
858 858
def launch_kernel(*args, **kwargs):
    """Launches a localhost IPython kernel, binding to the specified ports.

    This function simply calls entry_point.base_launch_kernel with the right
    first command to start an ipkernel.  See base_launch_kernel for arguments.

    Returns
    -------
    A tuple of form:
        (kernel_process, shell_port, iopub_port, stdin_port, hb_port)
    where kernel_process is a Popen object and the ports are integers.
    """
    start_cmd = 'from IPython.zmq.ipkernel import main; main()'
    return base_launch_kernel(start_cmd, *args, **kwargs)
873 873
874 874
def embed_kernel(module=None, local_ns=None, **kwargs):
    """Embed and start an IPython kernel in a given scope.

    Parameters
    ----------
    module : ModuleType, optional
        The module to load into IPython globals (default: caller)
    local_ns : dict, optional
        The namespace to load into IPython user namespace (default: caller)

    kwargs : various, optional
        Further keyword args are relayed to the KernelApp constructor,
        allowing configuration of the Kernel.  Will only have an effect
        on the first embed_kernel call for a given process.
    """
    # get the app if it exists, or set it up if it doesn't
    if IPKernelApp.initialized():
        app = IPKernelApp.instance()
    else:
        app = IPKernelApp.instance(**kwargs)
        app.initialize([])
        # Undo unnecessary sys module mangling from init_sys_modules.
        # This would not be necessary if we could prevent it
        # in the first place by using a different InteractiveShell
        # subclass, as in the regular embed case.
        main = app.kernel.shell._orig_sys_modules_main_mod
        if main is not None:
            sys.modules[app.kernel.shell._orig_sys_modules_main_name] = main

    # load the calling scope if not given
    (caller_module, caller_locals) = extract_module_locals(1)
    if module is None:
        module = caller_module
    if local_ns is None:
        local_ns = caller_locals

    app.kernel.user_module = module
    app.kernel.user_ns = local_ns
    app.shell.set_completer_frame()
    app.start()
916 916
def main():
    """Run an IPKernel as an application"""
    app = IPKernelApp.instance()
    app.initialize()
    app.start()


if __name__ == '__main__':
    main()
@@ -1,179 +1,175 b''
1 1 """serialization utilities for apply messages
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 # Standard library imports
19 19 import logging
20 20 import os
21 21 import re
22 22 import socket
23 23 import sys
24 24
25 25 try:
26 26 import cPickle
27 27 pickle = cPickle
28 28 except:
29 29 cPickle = None
30 30 import pickle
31 31
32 32
33 33 # IPython imports
34 34 from IPython.utils import py3compat
35 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
35 from IPython.utils.pickleutil import (
36 can, uncan, can_sequence, uncan_sequence, CannedObject
37 )
36 38 from IPython.utils.newserialized import serialize, unserialize
37 39
38 40 if py3compat.PY3:
39 41 buffer = memoryview
40 42
41 43 #-----------------------------------------------------------------------------
42 44 # Serialization Functions
43 45 #-----------------------------------------------------------------------------
44 46
# maximum number of items to iterate through in a container
MAX_ITEMS = 64

def _extract_buffers(obj, threshold=1024):
    """Pull data buffers larger than *threshold* out of a CannedObject.

    Oversized buffers are replaced with None in ``obj.buffers`` (so they
    are not pickled) and returned, to be transmitted separately.
    """
    extracted = []
    if isinstance(obj, CannedObject) and obj.buffers:
        for i, buf in enumerate(obj.buffers):
            if len(buf) > threshold:
                # buffer larger than threshold, prevent pickling
                obj.buffers[i] = None
                extracted.append(buf)
            elif isinstance(buf, buffer):
                # buffer too small for separate send, coerce to bytes
                # because pickling buffer objects just results in broken pointers
                obj.buffers[i] = bytes(buf)
    return extracted
64
def _restore_buffers(obj, buffers):
    """Inverse of _extract_buffers: refill the None slots in ``obj.buffers``
    from the front of *buffers*, consuming them in order."""
    if isinstance(obj, CannedObject) and obj.buffers:
        for i, buf in enumerate(obj.buffers):
            if buf is None:
                obj.buffers[i] = buffers.pop(0)
def serialize_object(obj, threshold=1024):
    """Serialize an object into a list of sendable buffers.

    Parameters
    ----------

    obj : object
        The object to be serialized
    threshold : int
        The threshold (in bytes) for pulling out data buffers
        to avoid pickling them.

    Returns
    -------
    [bufs] : list of buffers representing the serialized object.
    """
    buffers = []
    # only walk into small containers; big ones are canned whole
    if isinstance(obj, (list, tuple)) and len(obj) < MAX_ITEMS:
        canned = can_sequence(obj)
        for item in canned:
            buffers.extend(_extract_buffers(item, threshold))
    elif isinstance(obj, dict) and len(obj) < MAX_ITEMS:
        canned = {}
        # sorted keys make buffer order deterministic for the unpack side
        for key in sorted(obj):
            item = can(obj[key])
            buffers.extend(_extract_buffers(item, threshold))
            canned[key] = item
    else:
        canned = can(obj)
        buffers.extend(_extract_buffers(canned, threshold))

    # the pickled container goes first; the extracted buffers follow
    buffers.insert(0, pickle.dumps(canned, -1))
    return buffers
def unserialize_object(buffers, g=None):
    """reconstruct an object serialized by serialize_object from data buffers.

    Parameters
    ----------

    buffers : list of buffers/bytes

    g : globals to be used when uncanning

    Returns
    -------

    (newobj, bufs) : unpacked object, and the list of remaining unused buffers.
    """
    bufs = list(buffers)
    canned = pickle.loads(bufs.pop(0))
    if isinstance(canned, (list, tuple)) and len(canned) < MAX_ITEMS:
        for item in canned:
            _restore_buffers(item, bufs)
        newobj = uncan_sequence(canned, g)
    elif isinstance(canned, dict) and len(canned) < MAX_ITEMS:
        newobj = {}
        # same sorted-key order as serialize_object, so buffers line up
        for key in sorted(canned):
            item = canned[key]
            _restore_buffers(item, bufs)
            newobj[key] = uncan(item, g)
    else:
        _restore_buffers(canned, bufs)
        newobj = uncan(canned, g)

    return newobj, bufs
def pack_apply_message(f, args, kwargs, threshold=1024):
    """pack up a function, args, and kwargs to be sent over the wire
    as a series of buffers. Any object whose data is larger than `threshold`
    will not have their data copied (currently only numpy arrays support zero-copy)
    """
    # wire layout: [f, args-header, kwargs-header, arg buffers..., kwarg buffers...]
    msg = [pickle.dumps(can(f), -1)]
    databuffers = []  # for large objects
    for container in (args, kwargs):
        serialized = serialize_object(container, threshold)
        msg.append(serialized[0])
        databuffers.extend(serialized[1:])
    msg.extend(databuffers)
    return msg
def unpack_apply_message(bufs, g=None, copy=True):
    """unpack f,args,kwargs from buffers packed by pack_apply_message()
    Returns: original f,args,kwargs"""
    bufs = list(bufs)  # work on a copy, so we can pop
    assert len(bufs) >= 3, "not enough buffers!"
    if not copy:
        # zero-copy recv yields frame objects; take their bytes for unpickling
        for i in range(3):
            bufs[i] = bufs[i].bytes
    f = uncan(pickle.loads(bufs.pop(0)), g)
    # pop kwargs out, so first n-elements are args, serialized
    skwargs = bufs.pop(1)
    args, bufs = unserialize_object(bufs, g)
    # put skwargs back in as the first element
    bufs.insert(0, skwargs)
    kwargs, bufs = unserialize_object(bufs, g)

    assert not bufs, "Shouldn't be any data left over"

    return f,args,kwargs
General Comments 0
You need to be logged in to leave comments. Login now