##// END OF EJS Templates
pickle length-0 arrays.
MinRK -
Show More
@@ -1,167 +1,169
1 # encoding: utf-8
1 # encoding: utf-8
2 # -*- test-case-name: IPython.kernel.test.test_newserialized -*-
2 # -*- test-case-name: IPython.kernel.test.test_newserialized -*-
3
3
4 """Refactored serialization classes and interfaces."""
4 """Refactored serialization classes and interfaces."""
5
5
6 __docformat__ = "restructuredtext en"
6 __docformat__ = "restructuredtext en"
7
7
8 # Tell nose to skip this module
8 # Tell nose to skip this module
9 __test__ = {}
9 __test__ = {}
10
10
11 #-------------------------------------------------------------------------------
11 #-------------------------------------------------------------------------------
12 # Copyright (C) 2008 The IPython Development Team
12 # Copyright (C) 2008 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-------------------------------------------------------------------------------
16 #-------------------------------------------------------------------------------
17
17
18 #-------------------------------------------------------------------------------
18 #-------------------------------------------------------------------------------
19 # Imports
19 # Imports
20 #-------------------------------------------------------------------------------
20 #-------------------------------------------------------------------------------
21
21
22 import cPickle as pickle
22 import cPickle as pickle
23
23
24 try:
24 try:
25 import numpy
25 import numpy
26 except ImportError:
26 except ImportError:
27 pass
27 pass
28
28
29 class SerializationError(Exception):
29 class SerializationError(Exception):
30 pass
30 pass
31
31
32 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
33 # Classes and functions
33 # Classes and functions
34 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
35
35
36 class ISerialized:
36 class ISerialized:
37
37
38 def getData():
38 def getData():
39 """"""
39 """"""
40
40
41 def getDataSize(units=10.0**6):
41 def getDataSize(units=10.0**6):
42 """"""
42 """"""
43
43
44 def getTypeDescriptor():
44 def getTypeDescriptor():
45 """"""
45 """"""
46
46
47 def getMetadata():
47 def getMetadata():
48 """"""
48 """"""
49
49
50
50
51 class IUnSerialized:
51 class IUnSerialized:
52
52
53 def getObject():
53 def getObject():
54 """"""
54 """"""
55
55
56 class Serialized(object):
56 class Serialized(object):
57
57
58 # implements(ISerialized)
58 # implements(ISerialized)
59
59
60 def __init__(self, data, typeDescriptor, metadata={}):
60 def __init__(self, data, typeDescriptor, metadata={}):
61 self.data = data
61 self.data = data
62 self.typeDescriptor = typeDescriptor
62 self.typeDescriptor = typeDescriptor
63 self.metadata = metadata
63 self.metadata = metadata
64
64
65 def getData(self):
65 def getData(self):
66 return self.data
66 return self.data
67
67
68 def getDataSize(self, units=10.0**6):
68 def getDataSize(self, units=10.0**6):
69 return len(self.data)/units
69 return len(self.data)/units
70
70
71 def getTypeDescriptor(self):
71 def getTypeDescriptor(self):
72 return self.typeDescriptor
72 return self.typeDescriptor
73
73
74 def getMetadata(self):
74 def getMetadata(self):
75 return self.metadata
75 return self.metadata
76
76
77
77
78 class UnSerialized(object):
78 class UnSerialized(object):
79
79
80 # implements(IUnSerialized)
80 # implements(IUnSerialized)
81
81
82 def __init__(self, obj):
82 def __init__(self, obj):
83 self.obj = obj
83 self.obj = obj
84
84
85 def getObject(self):
85 def getObject(self):
86 return self.obj
86 return self.obj
87
87
88
88
89 class SerializeIt(object):
89 class SerializeIt(object):
90
90
91 # implements(ISerialized)
91 # implements(ISerialized)
92
92
93 def __init__(self, unSerialized):
93 def __init__(self, unSerialized):
94 self.data = None
94 self.data = None
95 self.obj = unSerialized.getObject()
95 self.obj = unSerialized.getObject()
96 if globals().has_key('numpy') and isinstance(self.obj, numpy.ndarray):
96 if globals().has_key('numpy') and isinstance(self.obj, numpy.ndarray):
97 if len(self.obj) == 0: # length 0 arrays can't be reconstructed
97 if len(self.obj.shape) == 0: # length 0 arrays are just pickled
98 raise SerializationError("You cannot send a length 0 array")
98 self.typeDescriptor = 'pickle'
99 self.metadata = {}
100 else:
99 self.obj = numpy.ascontiguousarray(self.obj, dtype=None)
101 self.obj = numpy.ascontiguousarray(self.obj, dtype=None)
100 self.typeDescriptor = 'ndarray'
102 self.typeDescriptor = 'ndarray'
101 self.metadata = {'shape':self.obj.shape,
103 self.metadata = {'shape':self.obj.shape,
102 'dtype':self.obj.dtype.str}
104 'dtype':self.obj.dtype.str}
103 elif isinstance(self.obj, bytes):
105 elif isinstance(self.obj, bytes):
104 self.typeDescriptor = 'bytes'
106 self.typeDescriptor = 'bytes'
105 self.metadata = {}
107 self.metadata = {}
106 elif isinstance(self.obj, buffer):
108 elif isinstance(self.obj, buffer):
107 self.typeDescriptor = 'buffer'
109 self.typeDescriptor = 'buffer'
108 self.metadata = {}
110 self.metadata = {}
109 else:
111 else:
110 self.typeDescriptor = 'pickle'
112 self.typeDescriptor = 'pickle'
111 self.metadata = {}
113 self.metadata = {}
112 self._generateData()
114 self._generateData()
113
115
114 def _generateData(self):
116 def _generateData(self):
115 if self.typeDescriptor == 'ndarray':
117 if self.typeDescriptor == 'ndarray':
116 self.data = numpy.getbuffer(self.obj)
118 self.data = numpy.getbuffer(self.obj)
117 elif self.typeDescriptor in ('bytes', 'buffer'):
119 elif self.typeDescriptor in ('bytes', 'buffer'):
118 self.data = self.obj
120 self.data = self.obj
119 elif self.typeDescriptor == 'pickle':
121 elif self.typeDescriptor == 'pickle':
120 self.data = pickle.dumps(self.obj, -1)
122 self.data = pickle.dumps(self.obj, -1)
121 else:
123 else:
122 raise SerializationError("Really wierd serialization error.")
124 raise SerializationError("Really wierd serialization error.")
123 del self.obj
125 del self.obj
124
126
125 def getData(self):
127 def getData(self):
126 return self.data
128 return self.data
127
129
128 def getDataSize(self, units=10.0**6):
130 def getDataSize(self, units=10.0**6):
129 return 1.0*len(self.data)/units
131 return 1.0*len(self.data)/units
130
132
131 def getTypeDescriptor(self):
133 def getTypeDescriptor(self):
132 return self.typeDescriptor
134 return self.typeDescriptor
133
135
134 def getMetadata(self):
136 def getMetadata(self):
135 return self.metadata
137 return self.metadata
136
138
137
139
138 class UnSerializeIt(UnSerialized):
140 class UnSerializeIt(UnSerialized):
139
141
140 # implements(IUnSerialized)
142 # implements(IUnSerialized)
141
143
142 def __init__(self, serialized):
144 def __init__(self, serialized):
143 self.serialized = serialized
145 self.serialized = serialized
144
146
145 def getObject(self):
147 def getObject(self):
146 typeDescriptor = self.serialized.getTypeDescriptor()
148 typeDescriptor = self.serialized.getTypeDescriptor()
147 if globals().has_key('numpy') and typeDescriptor == 'ndarray':
149 if globals().has_key('numpy') and typeDescriptor == 'ndarray':
148 buf = self.serialized.getData()
150 buf = self.serialized.getData()
149 if isinstance(buf, (buffer,bytes)):
151 if isinstance(buf, (buffer,bytes)):
150 result = numpy.frombuffer(buf, dtype = self.serialized.metadata['dtype'])
152 result = numpy.frombuffer(buf, dtype = self.serialized.metadata['dtype'])
151 else:
153 else:
152 # memoryview
154 # memoryview
153 result = numpy.array(buf, dtype = self.serialized.metadata['dtype'])
155 result = numpy.array(buf, dtype = self.serialized.metadata['dtype'])
154 result.shape = self.serialized.metadata['shape']
156 result.shape = self.serialized.metadata['shape']
155 elif typeDescriptor == 'pickle':
157 elif typeDescriptor == 'pickle':
156 result = pickle.loads(self.serialized.getData())
158 result = pickle.loads(self.serialized.getData())
157 elif typeDescriptor in ('bytes', 'buffer'):
159 elif typeDescriptor in ('bytes', 'buffer'):
158 result = self.serialized.getData()
160 result = self.serialized.getData()
159 else:
161 else:
160 raise SerializationError("Really wierd serialization error.")
162 raise SerializationError("Really wierd serialization error.")
161 return result
163 return result
162
164
163 def serialize(obj):
165 def serialize(obj):
164 return SerializeIt(UnSerialized(obj))
166 return SerializeIt(UnSerialized(obj))
165
167
166 def unserialize(serialized):
168 def unserialize(serialized):
167 return UnSerializeIt(serialized).getObject()
169 return UnSerializeIt(serialized).getObject()
@@ -1,139 +1,138
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """A simple engine that talks to a controller over 0MQ.
2 """A simple engine that talks to a controller over 0MQ.
3 it handles registration, etc. and launches a kernel
3 it handles registration, etc. and launches a kernel
4 connected to the Controller's Schedulers.
4 connected to the Controller's Schedulers.
5 """
5 """
6 from __future__ import print_function
6 from __future__ import print_function
7
7
8 import sys
8 import sys
9 import time
9 import time
10
10
11 import zmq
11 import zmq
12 from zmq.eventloop import ioloop, zmqstream
12 from zmq.eventloop import ioloop, zmqstream
13
13
14 # internal
14 # internal
15 from IPython.utils.traitlets import Instance, Str, Dict, Int, Type, CFloat
15 from IPython.utils.traitlets import Instance, Str, Dict, Int, Type, CFloat
16 # from IPython.utils.localinterfaces import LOCALHOST
16 # from IPython.utils.localinterfaces import LOCALHOST
17
17
18 from . import heartmonitor
18 from . import heartmonitor
19 from .factory import RegistrationFactory
19 from .factory import RegistrationFactory
20 from .streamkernel import Kernel
20 from .streamkernel import Kernel
21 from .streamsession import Message
21 from .streamsession import Message
22 from .util import disambiguate_url
22 from .util import disambiguate_url
23
23
24 class EngineFactory(RegistrationFactory):
24 class EngineFactory(RegistrationFactory):
25 """IPython engine"""
25 """IPython engine"""
26
26
27 # configurables:
27 # configurables:
28 user_ns=Dict(config=True)
28 user_ns=Dict(config=True)
29 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True)
29 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True)
30 display_hook_factory=Type('IPython.zmq.displayhook.DisplayHook', config=True)
30 display_hook_factory=Type('IPython.zmq.displayhook.DisplayHook', config=True)
31 location=Str(config=True)
31 location=Str(config=True)
32 timeout=CFloat(2,config=True)
32 timeout=CFloat(2,config=True)
33
33
34 # not configurable:
34 # not configurable:
35 id=Int(allow_none=True)
35 id=Int(allow_none=True)
36 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
36 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
37 kernel=Instance(Kernel)
37 kernel=Instance(Kernel)
38
38
39
39
40 def __init__(self, **kwargs):
40 def __init__(self, **kwargs):
41 super(EngineFactory, self).__init__(**kwargs)
41 super(EngineFactory, self).__init__(**kwargs)
42 ctx = self.context
42 ctx = self.context
43
43
44 reg = ctx.socket(zmq.PAIR)
44 reg = ctx.socket(zmq.PAIR)
45 reg.setsockopt(zmq.IDENTITY, self.ident)
45 reg.setsockopt(zmq.IDENTITY, self.ident)
46 reg.connect(self.url)
46 reg.connect(self.url)
47 self.registrar = zmqstream.ZMQStream(reg, self.loop)
47 self.registrar = zmqstream.ZMQStream(reg, self.loop)
48
48
49 def register(self):
49 def register(self):
50 """send the registration_request"""
50 """send the registration_request"""
51
51
52 self.log.info("registering")
52 self.log.info("registering")
53 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
53 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
54 self.registrar.on_recv(self.complete_registration)
54 self.registrar.on_recv(self.complete_registration)
55 # print (self.session.key)
55 # print (self.session.key)
56 self.session.send(self.registrar, "registration_request",content=content)
56 self.session.send(self.registrar, "registration_request",content=content)
57
57
58 def complete_registration(self, msg):
58 def complete_registration(self, msg):
59 # print msg
59 # print msg
60 self._abort_dc.stop()
60 self._abort_dc.stop()
61 ctx = self.context
61 ctx = self.context
62 loop = self.loop
62 loop = self.loop
63 identity = self.ident
63 identity = self.ident
64 print (identity)
65
64
66 idents,msg = self.session.feed_identities(msg)
65 idents,msg = self.session.feed_identities(msg)
67 msg = Message(self.session.unpack_message(msg))
66 msg = Message(self.session.unpack_message(msg))
68
67
69 if msg.content.status == 'ok':
68 if msg.content.status == 'ok':
70 self.id = int(msg.content.id)
69 self.id = int(msg.content.id)
71
70
72 # create Shell Streams (MUX, Task, etc.):
71 # create Shell Streams (MUX, Task, etc.):
73 queue_addr = msg.content.mux
72 queue_addr = msg.content.mux
74 shell_addrs = [ str(queue_addr) ]
73 shell_addrs = [ str(queue_addr) ]
75 task_addr = msg.content.task
74 task_addr = msg.content.task
76 if task_addr:
75 if task_addr:
77 shell_addrs.append(str(task_addr))
76 shell_addrs.append(str(task_addr))
78 shell_streams = []
77 shell_streams = []
79 for addr in shell_addrs:
78 for addr in shell_addrs:
80 stream = zmqstream.ZMQStream(ctx.socket(zmq.PAIR), loop)
79 stream = zmqstream.ZMQStream(ctx.socket(zmq.PAIR), loop)
81 stream.setsockopt(zmq.IDENTITY, identity)
80 stream.setsockopt(zmq.IDENTITY, identity)
82 stream.connect(disambiguate_url(addr, self.location))
81 stream.connect(disambiguate_url(addr, self.location))
83 shell_streams.append(stream)
82 shell_streams.append(stream)
84
83
85 # control stream:
84 # control stream:
86 control_addr = str(msg.content.control)
85 control_addr = str(msg.content.control)
87 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.PAIR), loop)
86 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.PAIR), loop)
88 control_stream.setsockopt(zmq.IDENTITY, identity)
87 control_stream.setsockopt(zmq.IDENTITY, identity)
89 control_stream.connect(disambiguate_url(control_addr, self.location))
88 control_stream.connect(disambiguate_url(control_addr, self.location))
90
89
91 # create iopub stream:
90 # create iopub stream:
92 iopub_addr = msg.content.iopub
91 iopub_addr = msg.content.iopub
93 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
92 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
94 iopub_stream.setsockopt(zmq.IDENTITY, identity)
93 iopub_stream.setsockopt(zmq.IDENTITY, identity)
95 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
94 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
96
95
97 # launch heartbeat
96 # launch heartbeat
98 hb_addrs = msg.content.heartbeat
97 hb_addrs = msg.content.heartbeat
99 # print (hb_addrs)
98 # print (hb_addrs)
100
99
101 # # Redirect input streams and set a display hook.
100 # # Redirect input streams and set a display hook.
102 if self.out_stream_factory:
101 if self.out_stream_factory:
103 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
102 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
104 sys.stdout.topic = 'engine.%i.stdout'%self.id
103 sys.stdout.topic = 'engine.%i.stdout'%self.id
105 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
104 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
106 sys.stderr.topic = 'engine.%i.stderr'%self.id
105 sys.stderr.topic = 'engine.%i.stderr'%self.id
107 if self.display_hook_factory:
106 if self.display_hook_factory:
108 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
107 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
109 sys.displayhook.topic = 'engine.%i.pyout'%self.id
108 sys.displayhook.topic = 'engine.%i.pyout'%self.id
110
109
111 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
110 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
112 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
111 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
113 loop=loop, user_ns = self.user_ns, logname=self.log.name)
112 loop=loop, user_ns = self.user_ns, logname=self.log.name)
114 self.kernel.start()
113 self.kernel.start()
115 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
114 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
116 heart = heartmonitor.Heart(*map(str, hb_addrs), heart_id=identity)
115 heart = heartmonitor.Heart(*map(str, hb_addrs), heart_id=identity)
117 # ioloop.DelayedCallback(heart.start, 1000, self.loop).start()
116 # ioloop.DelayedCallback(heart.start, 1000, self.loop).start()
118 heart.start()
117 heart.start()
119
118
120
119
121 else:
120 else:
122 self.log.fatal("Registration Failed: %s"%msg)
121 self.log.fatal("Registration Failed: %s"%msg)
123 raise Exception("Registration Failed: %s"%msg)
122 raise Exception("Registration Failed: %s"%msg)
124
123
125 self.log.info("Completed registration with id %i"%self.id)
124 self.log.info("Completed registration with id %i"%self.id)
126
125
127
126
128 def abort(self):
127 def abort(self):
129 self.log.fatal("Registration timed out")
128 self.log.fatal("Registration timed out")
130 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
129 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
131 time.sleep(1)
130 time.sleep(1)
132 sys.exit(255)
131 sys.exit(255)
133
132
134 def start(self):
133 def start(self):
135 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
134 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
136 dc.start()
135 dc.start()
137 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
136 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
138 self._abort_dc.start()
137 self._abort_dc.start()
139
138
General Comments 0
You need to be logged in to leave comments. Login now