##// END OF EJS Templates
split serialize step of Session.send into separate method...
MinRK -
Show More
@@ -1,410 +1,416 b''
1 1 #!/usr/bin/env python
2 2 """edited session.py to work with streams, and move msg_type to the header
3 3 """
4 4 #-----------------------------------------------------------------------------
5 5 # Copyright (C) 2010-2011 The IPython Development Team
6 6 #
7 7 # Distributed under the terms of the BSD License. The full license is in
8 8 # the file COPYING, distributed as part of this software.
9 9 #-----------------------------------------------------------------------------
10 10
11 11
12 12 import os
13 13 import pprint
14 14 import uuid
15 15 from datetime import datetime
16 16
17 17 try:
18 18 import cPickle
19 19 pickle = cPickle
20 20 except:
21 21 cPickle = None
22 22 import pickle
23 23
24 24 import zmq
25 25 from zmq.utils import jsonapi
26 26 from zmq.eventloop.zmqstream import ZMQStream
27 27
28 28 from .util import ISO8601
29 29
30 30 def squash_unicode(obj):
31 31 """coerce unicode back to bytestrings."""
32 32 if isinstance(obj,dict):
33 33 for key in obj.keys():
34 34 obj[key] = squash_unicode(obj[key])
35 35 if isinstance(key, unicode):
36 36 obj[squash_unicode(key)] = obj.pop(key)
37 37 elif isinstance(obj, list):
38 38 for i,v in enumerate(obj):
39 39 obj[i] = squash_unicode(v)
40 40 elif isinstance(obj, unicode):
41 41 obj = obj.encode('utf8')
42 42 return obj
43 43
44 44 def _date_default(obj):
45 45 if isinstance(obj, datetime):
46 46 return obj.strftime(ISO8601)
47 47 else:
48 48 raise TypeError("%r is not JSON serializable"%obj)
49 49
50 50 _default_key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
51 51 json_packer = lambda obj: jsonapi.dumps(obj, **{_default_key:_date_default})
52 52 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
53 53
54 54 pickle_packer = lambda o: pickle.dumps(o,-1)
55 55 pickle_unpacker = pickle.loads
56 56
57 57 default_packer = json_packer
58 58 default_unpacker = json_unpacker
59 59
60 60
61 61 DELIM="<IDS|MSG>"
62 62
63 63 class Message(object):
64 64 """A simple message object that maps dict keys to attributes.
65 65
66 66 A Message can be created from a dict and a dict from a Message instance
67 67 simply by calling dict(msg_obj)."""
68 68
69 69 def __init__(self, msg_dict):
70 70 dct = self.__dict__
71 71 for k, v in dict(msg_dict).iteritems():
72 72 if isinstance(v, dict):
73 73 v = Message(v)
74 74 dct[k] = v
75 75
76 76 # Having this iterator lets dict(msg_obj) work out of the box.
77 77 def __iter__(self):
78 78 return iter(self.__dict__.iteritems())
79 79
80 80 def __repr__(self):
81 81 return repr(self.__dict__)
82 82
83 83 def __str__(self):
84 84 return pprint.pformat(self.__dict__)
85 85
86 86 def __contains__(self, k):
87 87 return k in self.__dict__
88 88
89 89 def __getitem__(self, k):
90 90 return self.__dict__[k]
91 91
92 92
93 93 def msg_header(msg_id, msg_type, username, session):
94 94 date=datetime.now().strftime(ISO8601)
95 95 return locals()
96 96
97 97 def extract_header(msg_or_header):
98 98 """Given a message or header, return the header."""
99 99 if not msg_or_header:
100 100 return {}
101 101 try:
102 102 # See if msg_or_header is the entire message.
103 103 h = msg_or_header['header']
104 104 except KeyError:
105 105 try:
106 106 # See if msg_or_header is just the header
107 107 h = msg_or_header['msg_id']
108 108 except KeyError:
109 109 raise
110 110 else:
111 111 h = msg_or_header
112 112 if not isinstance(h, dict):
113 113 h = dict(h)
114 114 return h
115 115
116 116 class StreamSession(object):
117 117 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
118 118 debug=False
119 119 key=None
120 120
121 121 def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
122 122 if username is None:
123 123 username = os.environ.get('USER','username')
124 124 self.username = username
125 125 if session is None:
126 126 self.session = str(uuid.uuid4())
127 127 else:
128 128 self.session = session
129 129 self.msg_id = str(uuid.uuid4())
130 130 if packer is None:
131 131 self.pack = default_packer
132 132 else:
133 133 if not callable(packer):
134 134 raise TypeError("packer must be callable, not %s"%type(packer))
135 135 self.pack = packer
136 136
137 137 if unpacker is None:
138 138 self.unpack = default_unpacker
139 139 else:
140 140 if not callable(unpacker):
141 141 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
142 142 self.unpack = unpacker
143 143
144 144 if key is not None and keyfile is not None:
145 145 raise TypeError("Must specify key OR keyfile, not both")
146 146 if keyfile is not None:
147 147 with open(keyfile) as f:
148 148 self.key = f.read().strip()
149 149 else:
150 150 self.key = key
151 151 if isinstance(self.key, unicode):
152 152 self.key = self.key.encode('utf8')
153 153 # print key, keyfile, self.key
154 154 self.none = self.pack({})
155 155
156 156 def msg_header(self, msg_type):
157 157 h = msg_header(self.msg_id, msg_type, self.username, self.session)
158 158 self.msg_id = str(uuid.uuid4())
159 159 return h
160 160
161 161 def msg(self, msg_type, content=None, parent=None, subheader=None):
162 162 msg = {}
163 163 msg['header'] = self.msg_header(msg_type)
164 164 msg['msg_id'] = msg['header']['msg_id']
165 165 msg['parent_header'] = {} if parent is None else extract_header(parent)
166 166 msg['msg_type'] = msg_type
167 167 msg['content'] = {} if content is None else content
168 168 sub = {} if subheader is None else subheader
169 169 msg['header'].update(sub)
170 170 return msg
171 171
172 172 def check_key(self, msg_or_header):
173 173 """Check that a message's header has the right key"""
174 174 if self.key is None:
175 175 return True
176 176 header = extract_header(msg_or_header)
177 177 return header.get('key', None) == self.key
178 178
179
180 def serialize(self, msg, ident=None):
181 content = msg.get('content', {})
182 if content is None:
183 content = self.none
184 elif isinstance(content, dict):
185 content = self.pack(content)
186 elif isinstance(content, bytes):
187 # content is already packed, as in a relayed message
188 pass
189 else:
190 raise TypeError("Content incorrect type: %s"%type(content))
191
192 to_send = []
193
194 if isinstance(ident, list):
195 # accept list of idents
196 to_send.extend(ident)
197 elif ident is not None:
198 to_send.append(ident)
199 to_send.append(DELIM)
200 if self.key is not None:
201 to_send.append(self.key)
202 to_send.append(self.pack(msg['header']))
203 to_send.append(self.pack(msg['parent_header']))
204 to_send.append(content)
205
206 return to_send
179 207
180 208 def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None, track=False):
181 209 """Build and send a message via stream or socket.
182 210
183 211 Parameters
184 212 ----------
185 213
186 214 stream : zmq.Socket or ZMQStream
187 215 the socket-like object used to send the data
188 216 msg_or_type : str or Message/dict
189 217 Normally, msg_or_type will be a msg_type unless a message is being sent more
190 218 than once.
191 219
192 220 content : dict or None
193 221 the content of the message (ignored if msg_or_type is a message)
194 222 buffers : list or None
195 223 the already-serialized buffers to be appended to the message
196 224 parent : Message or dict or None
197 225 the parent or parent header describing the parent of this message
198 226 subheader : dict or None
199 227 extra header keys for this message's header
200 228 ident : bytes or list of bytes
201 229 the zmq.IDENTITY routing path
202 230 track : bool
203 231 whether to track. Only for use with Sockets, because ZMQStream objects cannot track messages.
204 232
205 233 Returns
206 234 -------
207 235 msg : message dict
208 236 the constructed message
209 237 (msg,tracker) : (message dict, MessageTracker)
210 238 if track=True, then a 2-tuple will be returned, the first element being the constructed
211 239 message, and the second being the MessageTracker
212 240
213 241 """
214 242
215 243 if not isinstance(stream, (zmq.Socket, ZMQStream)):
216 244 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
217 245 elif track and isinstance(stream, ZMQStream):
218 246 raise TypeError("ZMQStream cannot track messages")
219 247
220 248 if isinstance(msg_or_type, (Message, dict)):
221 249 # we got a Message, not a msg_type
222 250 # don't build a new Message
223 251 msg = msg_or_type
224 content = msg['content']
225 252 else:
226 253 msg = self.msg(msg_or_type, content, parent, subheader)
227 254
228 255 buffers = [] if buffers is None else buffers
229 to_send = []
230 if isinstance(ident, list):
231 # accept list of idents
232 to_send.extend(ident)
233 elif ident is not None:
234 to_send.append(ident)
235 to_send.append(DELIM)
236 if self.key is not None:
237 to_send.append(self.key)
238 to_send.append(self.pack(msg['header']))
239 to_send.append(self.pack(msg['parent_header']))
240
241 if content is None:
242 content = self.none
243 elif isinstance(content, dict):
244 content = self.pack(content)
245 elif isinstance(content, bytes):
246 # content is already packed, as in a relayed message
247 pass
248 else:
249 raise TypeError("Content incorrect type: %s"%type(content))
250 to_send.append(content)
256 to_send = self.serialize(msg, ident)
251 257 flag = 0
252 258 if buffers:
253 259 flag = zmq.SNDMORE
254 260 _track = False
255 261 else:
256 262 _track=track
257 263 if track:
258 264 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
259 265 else:
260 266 tracker = stream.send_multipart(to_send, flag, copy=False)
261 267 for b in buffers[:-1]:
262 268 stream.send(b, flag, copy=False)
263 269 if buffers:
264 270 if track:
265 271 tracker = stream.send(buffers[-1], copy=False, track=track)
266 272 else:
267 273 tracker = stream.send(buffers[-1], copy=False)
268 274
269 275 # omsg = Message(msg)
270 276 if self.debug:
271 277 pprint.pprint(msg)
272 278 pprint.pprint(to_send)
273 279 pprint.pprint(buffers)
274 280
275 281 msg['tracker'] = tracker
276 282
277 283 return msg
278 284
279 285 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
280 286 """Send a raw message via ident path.
281 287
282 288 Parameters
283 289 ----------
284 290 msg : list of sendable buffers"""
285 291 to_send = []
286 292 if isinstance(ident, bytes):
287 293 ident = [ident]
288 294 if ident is not None:
289 295 to_send.extend(ident)
290 296 to_send.append(DELIM)
291 297 if self.key is not None:
292 298 to_send.append(self.key)
293 299 to_send.extend(msg)
294 300 stream.send_multipart(msg, flags, copy=copy)
295 301
296 302 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
297 303 """receives and unpacks a message
298 304 returns [idents], msg"""
299 305 if isinstance(socket, ZMQStream):
300 306 socket = socket.socket
301 307 try:
302 308 msg = socket.recv_multipart(mode)
303 309 except zmq.ZMQError as e:
304 310 if e.errno == zmq.EAGAIN:
305 311 # We can convert EAGAIN to None as we know in this case
306 312 # recv_multipart won't return None.
307 313 return None
308 314 else:
309 315 raise
310 316 # return an actual Message object
311 317 # determine the number of idents by trying to unpack them.
312 318 # this is terrible:
313 319 idents, msg = self.feed_identities(msg, copy)
314 320 try:
315 321 return idents, self.unpack_message(msg, content=content, copy=copy)
316 322 except Exception as e:
317 323 print (idents, msg)
318 324 # TODO: handle it
319 325 raise e
320 326
321 327 def feed_identities(self, msg, copy=True):
322 328 """feed until DELIM is reached, then return the prefix as idents and remainder as
323 329 msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.
324 330
325 331 Parameters
326 332 ----------
327 333 msg : a list of Message or bytes objects
328 334 the message to be split
329 335 copy : bool
330 336 flag determining whether the arguments are bytes or Messages
331 337
332 338 Returns
333 339 -------
334 340 (idents,msg) : two lists
335 341 idents will always be a list of bytes - the indentity prefix
336 342 msg will be a list of bytes or Messages, unchanged from input
337 343 msg should be unpackable via self.unpack_message at this point.
338 344 """
339 345 ikey = int(self.key is not None)
340 346 minlen = 3 + ikey
341 347 msg = list(msg)
342 348 idents = []
343 349 while len(msg) > minlen:
344 350 if copy:
345 351 s = msg[0]
346 352 else:
347 353 s = msg[0].bytes
348 354 if s == DELIM:
349 355 msg.pop(0)
350 356 break
351 357 else:
352 358 idents.append(s)
353 359 msg.pop(0)
354 360
355 361 return idents, msg
356 362
357 363 def unpack_message(self, msg, content=True, copy=True):
358 364 """Return a message object from the format
359 365 sent by self.send.
360 366
361 367 Parameters:
362 368 -----------
363 369
364 370 content : bool (True)
365 371 whether to unpack the content dict (True),
366 372 or leave it serialized (False)
367 373
368 374 copy : bool (True)
369 375 whether to return the bytes (True),
370 376 or the non-copying Message object in each place (False)
371 377
372 378 """
373 379 ikey = int(self.key is not None)
374 380 minlen = 3 + ikey
375 381 message = {}
376 382 if not copy:
377 383 for i in range(minlen):
378 384 msg[i] = msg[i].bytes
379 385 if ikey:
380 386 if not self.key == msg[0]:
381 387 raise KeyError("Invalid Session Key: %s"%msg[0])
382 388 if not len(msg) >= minlen:
383 389 raise TypeError("malformed message, must have at least %i elements"%minlen)
384 390 message['header'] = self.unpack(msg[ikey+0])
385 391 message['msg_type'] = message['header']['msg_type']
386 392 message['parent_header'] = self.unpack(msg[ikey+1])
387 393 if content:
388 394 message['content'] = self.unpack(msg[ikey+2])
389 395 else:
390 396 message['content'] = msg[ikey+2]
391 397
392 398 message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
393 399 return message
394 400
395 401
396 402 def test_msg2obj():
397 403 am = dict(x=1)
398 404 ao = Message(am)
399 405 assert ao.x == am['x']
400 406
401 407 am['y'] = dict(z=1)
402 408 ao = Message(am)
403 409 assert ao.y.z == am['y']['z']
404 410
405 411 k1, k2 = 'y', 'z'
406 412 assert ao[k1][k2] == am[k1][k2]
407 413
408 414 am2 = dict(ao)
409 415 assert am['x'] == am2['x']
410 416 assert am['y']['z'] == am2['y']['z']
General Comments 0
You need to be logged in to leave comments. Login now