##// END OF EJS Templates
pickle length-0 arrays.
MinRK -
Show More
@@ -1,167 +1,169
1 1 # encoding: utf-8
2 2 # -*- test-case-name: IPython.kernel.test.test_newserialized -*-
3 3
4 4 """Refactored serialization classes and interfaces."""
5 5
6 6 __docformat__ = "restructuredtext en"
7 7
8 8 # Tell nose to skip this module
9 9 __test__ = {}
10 10
11 11 #-------------------------------------------------------------------------------
12 12 # Copyright (C) 2008 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-------------------------------------------------------------------------------
17 17
18 18 #-------------------------------------------------------------------------------
19 19 # Imports
20 20 #-------------------------------------------------------------------------------
21 21
22 22 import cPickle as pickle
23 23
24 24 try:
25 25 import numpy
26 26 except ImportError:
27 27 pass
28 28
29 29 class SerializationError(Exception):
30 30 pass
31 31
32 32 #-----------------------------------------------------------------------------
33 33 # Classes and functions
34 34 #-----------------------------------------------------------------------------
35 35
36 36 class ISerialized:
37 37
38 38 def getData():
39 39 """"""
40 40
41 41 def getDataSize(units=10.0**6):
42 42 """"""
43 43
44 44 def getTypeDescriptor():
45 45 """"""
46 46
47 47 def getMetadata():
48 48 """"""
49 49
50 50
51 51 class IUnSerialized:
52 52
53 53 def getObject():
54 54 """"""
55 55
56 56 class Serialized(object):
57 57
58 58 # implements(ISerialized)
59 59
60 60 def __init__(self, data, typeDescriptor, metadata={}):
61 61 self.data = data
62 62 self.typeDescriptor = typeDescriptor
63 63 self.metadata = metadata
64 64
65 65 def getData(self):
66 66 return self.data
67 67
68 68 def getDataSize(self, units=10.0**6):
69 69 return len(self.data)/units
70 70
71 71 def getTypeDescriptor(self):
72 72 return self.typeDescriptor
73 73
74 74 def getMetadata(self):
75 75 return self.metadata
76 76
77 77
78 78 class UnSerialized(object):
79 79
80 80 # implements(IUnSerialized)
81 81
82 82 def __init__(self, obj):
83 83 self.obj = obj
84 84
85 85 def getObject(self):
86 86 return self.obj
87 87
88 88
89 89 class SerializeIt(object):
90 90
91 91 # implements(ISerialized)
92 92
93 93 def __init__(self, unSerialized):
94 94 self.data = None
95 95 self.obj = unSerialized.getObject()
96 96 if globals().has_key('numpy') and isinstance(self.obj, numpy.ndarray):
97 if len(self.obj) == 0: # length 0 arrays can't be reconstructed
98 raise SerializationError("You cannot send a length 0 array")
97 if len(self.obj.shape) == 0: # length 0 arrays are just pickled
98 self.typeDescriptor = 'pickle'
99 self.metadata = {}
100 else:
99 101 self.obj = numpy.ascontiguousarray(self.obj, dtype=None)
100 102 self.typeDescriptor = 'ndarray'
101 103 self.metadata = {'shape':self.obj.shape,
102 104 'dtype':self.obj.dtype.str}
103 105 elif isinstance(self.obj, bytes):
104 106 self.typeDescriptor = 'bytes'
105 107 self.metadata = {}
106 108 elif isinstance(self.obj, buffer):
107 109 self.typeDescriptor = 'buffer'
108 110 self.metadata = {}
109 111 else:
110 112 self.typeDescriptor = 'pickle'
111 113 self.metadata = {}
112 114 self._generateData()
113 115
114 116 def _generateData(self):
115 117 if self.typeDescriptor == 'ndarray':
116 118 self.data = numpy.getbuffer(self.obj)
117 119 elif self.typeDescriptor in ('bytes', 'buffer'):
118 120 self.data = self.obj
119 121 elif self.typeDescriptor == 'pickle':
120 122 self.data = pickle.dumps(self.obj, -1)
121 123 else:
122 124 raise SerializationError("Really wierd serialization error.")
123 125 del self.obj
124 126
125 127 def getData(self):
126 128 return self.data
127 129
128 130 def getDataSize(self, units=10.0**6):
129 131 return 1.0*len(self.data)/units
130 132
131 133 def getTypeDescriptor(self):
132 134 return self.typeDescriptor
133 135
134 136 def getMetadata(self):
135 137 return self.metadata
136 138
137 139
138 140 class UnSerializeIt(UnSerialized):
139 141
140 142 # implements(IUnSerialized)
141 143
142 144 def __init__(self, serialized):
143 145 self.serialized = serialized
144 146
145 147 def getObject(self):
146 148 typeDescriptor = self.serialized.getTypeDescriptor()
147 149 if globals().has_key('numpy') and typeDescriptor == 'ndarray':
148 150 buf = self.serialized.getData()
149 151 if isinstance(buf, (buffer,bytes)):
150 152 result = numpy.frombuffer(buf, dtype = self.serialized.metadata['dtype'])
151 153 else:
152 154 # memoryview
153 155 result = numpy.array(buf, dtype = self.serialized.metadata['dtype'])
154 156 result.shape = self.serialized.metadata['shape']
155 157 elif typeDescriptor == 'pickle':
156 158 result = pickle.loads(self.serialized.getData())
157 159 elif typeDescriptor in ('bytes', 'buffer'):
158 160 result = self.serialized.getData()
159 161 else:
160 162 raise SerializationError("Really wierd serialization error.")
161 163 return result
162 164
163 165 def serialize(obj):
164 166 return SerializeIt(UnSerialized(obj))
165 167
166 168 def unserialize(serialized):
167 169 return UnSerializeIt(serialized).getObject()
@@ -1,139 +1,138
1 1 #!/usr/bin/env python
2 2 """A simple engine that talks to a controller over 0MQ.
3 3 it handles registration, etc. and launches a kernel
4 4 connected to the Controller's Schedulers.
5 5 """
6 6 from __future__ import print_function
7 7
8 8 import sys
9 9 import time
10 10
11 11 import zmq
12 12 from zmq.eventloop import ioloop, zmqstream
13 13
14 14 # internal
15 15 from IPython.utils.traitlets import Instance, Str, Dict, Int, Type, CFloat
16 16 # from IPython.utils.localinterfaces import LOCALHOST
17 17
18 18 from . import heartmonitor
19 19 from .factory import RegistrationFactory
20 20 from .streamkernel import Kernel
21 21 from .streamsession import Message
22 22 from .util import disambiguate_url
23 23
24 24 class EngineFactory(RegistrationFactory):
25 25 """IPython engine"""
26 26
27 27 # configurables:
28 28 user_ns=Dict(config=True)
29 29 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True)
30 30 display_hook_factory=Type('IPython.zmq.displayhook.DisplayHook', config=True)
31 31 location=Str(config=True)
32 32 timeout=CFloat(2,config=True)
33 33
34 34 # not configurable:
35 35 id=Int(allow_none=True)
36 36 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
37 37 kernel=Instance(Kernel)
38 38
39 39
40 40 def __init__(self, **kwargs):
41 41 super(EngineFactory, self).__init__(**kwargs)
42 42 ctx = self.context
43 43
44 44 reg = ctx.socket(zmq.PAIR)
45 45 reg.setsockopt(zmq.IDENTITY, self.ident)
46 46 reg.connect(self.url)
47 47 self.registrar = zmqstream.ZMQStream(reg, self.loop)
48 48
49 49 def register(self):
50 50 """send the registration_request"""
51 51
52 52 self.log.info("registering")
53 53 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
54 54 self.registrar.on_recv(self.complete_registration)
55 55 # print (self.session.key)
56 56 self.session.send(self.registrar, "registration_request",content=content)
57 57
58 58 def complete_registration(self, msg):
59 59 # print msg
60 60 self._abort_dc.stop()
61 61 ctx = self.context
62 62 loop = self.loop
63 63 identity = self.ident
64 print (identity)
65 64
66 65 idents,msg = self.session.feed_identities(msg)
67 66 msg = Message(self.session.unpack_message(msg))
68 67
69 68 if msg.content.status == 'ok':
70 69 self.id = int(msg.content.id)
71 70
72 71 # create Shell Streams (MUX, Task, etc.):
73 72 queue_addr = msg.content.mux
74 73 shell_addrs = [ str(queue_addr) ]
75 74 task_addr = msg.content.task
76 75 if task_addr:
77 76 shell_addrs.append(str(task_addr))
78 77 shell_streams = []
79 78 for addr in shell_addrs:
80 79 stream = zmqstream.ZMQStream(ctx.socket(zmq.PAIR), loop)
81 80 stream.setsockopt(zmq.IDENTITY, identity)
82 81 stream.connect(disambiguate_url(addr, self.location))
83 82 shell_streams.append(stream)
84 83
85 84 # control stream:
86 85 control_addr = str(msg.content.control)
87 86 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.PAIR), loop)
88 87 control_stream.setsockopt(zmq.IDENTITY, identity)
89 88 control_stream.connect(disambiguate_url(control_addr, self.location))
90 89
91 90 # create iopub stream:
92 91 iopub_addr = msg.content.iopub
93 92 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
94 93 iopub_stream.setsockopt(zmq.IDENTITY, identity)
95 94 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
96 95
97 96 # launch heartbeat
98 97 hb_addrs = msg.content.heartbeat
99 98 # print (hb_addrs)
100 99
101 100 # # Redirect input streams and set a display hook.
102 101 if self.out_stream_factory:
103 102 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
104 103 sys.stdout.topic = 'engine.%i.stdout'%self.id
105 104 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
106 105 sys.stderr.topic = 'engine.%i.stderr'%self.id
107 106 if self.display_hook_factory:
108 107 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
109 108 sys.displayhook.topic = 'engine.%i.pyout'%self.id
110 109
111 110 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
112 111 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
113 112 loop=loop, user_ns = self.user_ns, logname=self.log.name)
114 113 self.kernel.start()
115 114 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
116 115 heart = heartmonitor.Heart(*map(str, hb_addrs), heart_id=identity)
117 116 # ioloop.DelayedCallback(heart.start, 1000, self.loop).start()
118 117 heart.start()
119 118
120 119
121 120 else:
122 121 self.log.fatal("Registration Failed: %s"%msg)
123 122 raise Exception("Registration Failed: %s"%msg)
124 123
125 124 self.log.info("Completed registration with id %i"%self.id)
126 125
127 126
128 127 def abort(self):
129 128 self.log.fatal("Registration timed out")
130 129 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
131 130 time.sleep(1)
132 131 sys.exit(255)
133 132
134 133 def start(self):
135 134 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
136 135 dc.start()
137 136 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
138 137 self._abort_dc.start()
139 138
General Comments 0
You need to be logged in to leave comments. Login now