##// END OF EJS Templates
str(etype)
MinRK -
Show More
@@ -1,531 +1,531 b''
1 1 #!/usr/bin/env python
2 2 """edited session.py to work with streams, and move msg_type to the header
3 3 """
4 4
5 5
6 6 import os
7 7 import sys
8 8 import traceback
9 9 import pprint
10 10 import uuid
11 11 from datetime import datetime
12 12
13 13 import zmq
14 14 from zmq.utils import jsonapi
15 15 from zmq.eventloop.zmqstream import ZMQStream
16 16
17 17 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
18 18 from IPython.utils.newserialized import serialize, unserialize
19 19
20 20 try:
21 21 import cPickle
22 22 pickle = cPickle
23 23 except:
24 24 cPickle = None
25 25 import pickle
26 26
27 27 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
28 28 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
29 29 if json_name in ('jsonlib', 'jsonlib2'):
30 30 use_json = True
31 31 elif json_name:
32 32 if cPickle is None:
33 33 use_json = True
34 34 else:
35 35 use_json = False
36 36 else:
37 37 use_json = False
38 38
39 39 def squash_unicode(obj):
40 40 if isinstance(obj,dict):
41 41 for key in obj.keys():
42 42 obj[key] = squash_unicode(obj[key])
43 43 if isinstance(key, unicode):
44 44 obj[squash_unicode(key)] = obj.pop(key)
45 45 elif isinstance(obj, list):
46 46 for i,v in enumerate(obj):
47 47 obj[i] = squash_unicode(v)
48 48 elif isinstance(obj, unicode):
49 49 obj = obj.encode('utf8')
50 50 return obj
51 51
52 52 if use_json:
53 53 default_packer = jsonapi.dumps
54 54 default_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
55 55 else:
56 56 default_packer = lambda o: pickle.dumps(o,-1)
57 57 default_unpacker = pickle.loads
58 58
59 59
60 60 DELIM="<IDS|MSG>"
61 61 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
62 62
63 63 def wrap_exception():
64 64 etype, evalue, tb = sys.exc_info()
65 65 tb = traceback.format_exception(etype, evalue, tb)
66 66 exc_content = {
67 67 'status' : 'error',
68 68 'traceback' : [ line.encode('utf8') for line in tb ],
69 'etype' : etype.encode('utf8'),
69 'etype' : str(etype).encode('utf8'),
70 70 'evalue' : evalue.encode('utf8')
71 71 }
72 72 return exc_content
73 73
74 74 class KernelError(Exception):
75 75 pass
76 76
77 77 def unwrap_exception(content):
78 78 err = KernelError(content['etype'], content['evalue'])
79 79 err.evalue = content['evalue']
80 80 err.etype = content['etype']
81 81 err.traceback = ''.join(content['traceback'])
82 82 return err
83 83
84 84
85 85 class Message(object):
86 86 """A simple message object that maps dict keys to attributes.
87 87
88 88 A Message can be created from a dict and a dict from a Message instance
89 89 simply by calling dict(msg_obj)."""
90 90
91 91 def __init__(self, msg_dict):
92 92 dct = self.__dict__
93 93 for k, v in dict(msg_dict).iteritems():
94 94 if isinstance(v, dict):
95 95 v = Message(v)
96 96 dct[k] = v
97 97
98 98 # Having this iterator lets dict(msg_obj) work out of the box.
99 99 def __iter__(self):
100 100 return iter(self.__dict__.iteritems())
101 101
102 102 def __repr__(self):
103 103 return repr(self.__dict__)
104 104
105 105 def __str__(self):
106 106 return pprint.pformat(self.__dict__)
107 107
108 108 def __contains__(self, k):
109 109 return k in self.__dict__
110 110
111 111 def __getitem__(self, k):
112 112 return self.__dict__[k]
113 113
114 114
115 115 def msg_header(msg_id, msg_type, username, session):
116 116 date=datetime.now().strftime(ISO8601)
117 117 return locals()
118 118
119 119 def extract_header(msg_or_header):
120 120 """Given a message or header, return the header."""
121 121 if not msg_or_header:
122 122 return {}
123 123 try:
124 124 # See if msg_or_header is the entire message.
125 125 h = msg_or_header['header']
126 126 except KeyError:
127 127 try:
128 128 # See if msg_or_header is just the header
129 129 h = msg_or_header['msg_id']
130 130 except KeyError:
131 131 raise
132 132 else:
133 133 h = msg_or_header
134 134 if not isinstance(h, dict):
135 135 h = dict(h)
136 136 return h
137 137
138 138 def rekey(dikt):
139 139 """Rekey a dict that has been forced to use str keys where there should be
140 140 ints by json. This belongs in the jsonutil added by fperez."""
141 141 for k in dikt.iterkeys():
142 142 if isinstance(k, str):
143 143 ik=fk=None
144 144 try:
145 145 ik = int(k)
146 146 except ValueError:
147 147 try:
148 148 fk = float(k)
149 149 except ValueError:
150 150 continue
151 151 if ik is not None:
152 152 nk = ik
153 153 else:
154 154 nk = fk
155 155 if nk in dikt:
156 156 raise KeyError("already have key %r"%nk)
157 157 dikt[nk] = dikt.pop(k)
158 158 return dikt
159 159
160 160 def serialize_object(obj, threshold=64e-6):
161 161 """Serialize an object into a list of sendable buffers.
162 162
163 163 Parameters
164 164 ----------
165 165
166 166 obj : object
167 167 The object to be serialized
168 168 threshold : float
169 169 The threshold for not double-pickling the content.
170 170
171 171
172 172 Returns
173 173 -------
174 174 ('pmd', [bufs]) :
175 175 where pmd is the pickled metadata wrapper,
176 176 bufs is a list of data buffers
177 177 """
178 178 databuffers = []
179 179 if isinstance(obj, (list, tuple)):
180 180 clist = canSequence(obj)
181 181 slist = map(serialize, clist)
182 182 for s in slist:
183 183 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
184 184 databuffers.append(s.getData())
185 185 s.data = None
186 186 return pickle.dumps(slist,-1), databuffers
187 187 elif isinstance(obj, dict):
188 188 sobj = {}
189 189 for k in sorted(obj.iterkeys()):
190 190 s = serialize(can(obj[k]))
191 191 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
192 192 databuffers.append(s.getData())
193 193 s.data = None
194 194 sobj[k] = s
195 195 return pickle.dumps(sobj,-1),databuffers
196 196 else:
197 197 s = serialize(can(obj))
198 198 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
199 199 databuffers.append(s.getData())
200 200 s.data = None
201 201 return pickle.dumps(s,-1),databuffers
202 202
203 203
204 204 def unserialize_object(bufs):
205 205 """reconstruct an object serialized by serialize_object from data buffers"""
206 206 bufs = list(bufs)
207 207 sobj = pickle.loads(bufs.pop(0))
208 208 if isinstance(sobj, (list, tuple)):
209 209 for s in sobj:
210 210 if s.data is None:
211 211 s.data = bufs.pop(0)
212 212 return uncanSequence(map(unserialize, sobj))
213 213 elif isinstance(sobj, dict):
214 214 newobj = {}
215 215 for k in sorted(sobj.iterkeys()):
216 216 s = sobj[k]
217 217 if s.data is None:
218 218 s.data = bufs.pop(0)
219 219 newobj[k] = uncan(unserialize(s))
220 220 return newobj
221 221 else:
222 222 if sobj.data is None:
223 223 sobj.data = bufs.pop(0)
224 224 return uncan(unserialize(sobj))
225 225
226 226 def pack_apply_message(f, args, kwargs, threshold=64e-6):
227 227 """pack up a function, args, and kwargs to be sent over the wire
228 228 as a series of buffers. Any object whose data is larger than `threshold`
229 229 will not have their data copied (currently only numpy arrays support zero-copy)"""
230 230 msg = [pickle.dumps(can(f),-1)]
231 231 databuffers = [] # for large objects
232 232 sargs, bufs = serialize_object(args,threshold)
233 233 msg.append(sargs)
234 234 databuffers.extend(bufs)
235 235 skwargs, bufs = serialize_object(kwargs,threshold)
236 236 msg.append(skwargs)
237 237 databuffers.extend(bufs)
238 238 msg.extend(databuffers)
239 239 return msg
240 240
241 241 def unpack_apply_message(bufs, g=None, copy=True):
242 242 """unpack f,args,kwargs from buffers packed by pack_apply_message()
243 243 Returns: original f,args,kwargs"""
244 244 bufs = list(bufs) # allow us to pop
245 245 assert len(bufs) >= 3, "not enough buffers!"
246 246 if not copy:
247 247 for i in range(3):
248 248 bufs[i] = bufs[i].bytes
249 249 cf = pickle.loads(bufs.pop(0))
250 250 sargs = list(pickle.loads(bufs.pop(0)))
251 251 skwargs = dict(pickle.loads(bufs.pop(0)))
252 252 # print sargs, skwargs
253 253 f = uncan(cf, g)
254 254 for sa in sargs:
255 255 if sa.data is None:
256 256 m = bufs.pop(0)
257 257 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
258 258 if copy:
259 259 sa.data = buffer(m)
260 260 else:
261 261 sa.data = m.buffer
262 262 else:
263 263 if copy:
264 264 sa.data = m
265 265 else:
266 266 sa.data = m.bytes
267 267
268 268 args = uncanSequence(map(unserialize, sargs), g)
269 269 kwargs = {}
270 270 for k in sorted(skwargs.iterkeys()):
271 271 sa = skwargs[k]
272 272 if sa.data is None:
273 273 sa.data = bufs.pop(0)
274 274 kwargs[k] = uncan(unserialize(sa), g)
275 275
276 276 return f,args,kwargs
277 277
278 278 class StreamSession(object):
279 279 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
280 280 debug=False
281 281 key=None
282 282
283 283 def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
284 284 if username is None:
285 285 username = os.environ.get('USER','username')
286 286 self.username = username
287 287 if session is None:
288 288 self.session = str(uuid.uuid4())
289 289 else:
290 290 self.session = session
291 291 self.msg_id = str(uuid.uuid4())
292 292 if packer is None:
293 293 self.pack = default_packer
294 294 else:
295 295 if not callable(packer):
296 296 raise TypeError("packer must be callable, not %s"%type(packer))
297 297 self.pack = packer
298 298
299 299 if unpacker is None:
300 300 self.unpack = default_unpacker
301 301 else:
302 302 if not callable(unpacker):
303 303 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
304 304 self.unpack = unpacker
305 305
306 306 if key is not None and keyfile is not None:
307 307 raise TypeError("Must specify key OR keyfile, not both")
308 308 if keyfile is not None:
309 309 with open(keyfile) as f:
310 310 self.key = f.read().strip()
311 311 else:
312 312 self.key = key
313 313 # print key, keyfile, self.key
314 314 self.none = self.pack({})
315 315
316 316 def msg_header(self, msg_type):
317 317 h = msg_header(self.msg_id, msg_type, self.username, self.session)
318 318 self.msg_id = str(uuid.uuid4())
319 319 return h
320 320
321 321 def msg(self, msg_type, content=None, parent=None, subheader=None):
322 322 msg = {}
323 323 msg['header'] = self.msg_header(msg_type)
324 324 msg['msg_id'] = msg['header']['msg_id']
325 325 msg['parent_header'] = {} if parent is None else extract_header(parent)
326 326 msg['msg_type'] = msg_type
327 327 msg['content'] = {} if content is None else content
328 328 sub = {} if subheader is None else subheader
329 329 msg['header'].update(sub)
330 330 return msg
331 331
332 332 def check_key(self, msg_or_header):
333 333 """Check that a message's header has the right key"""
334 334 if self.key is None:
335 335 return True
336 336 header = extract_header(msg_or_header)
337 337 return header.get('key', None) == self.key
338 338
339 339
340 340 def send(self, stream, msg_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
341 341 """Build and send a message via stream or socket.
342 342
343 343 Parameters
344 344 ----------
345 345
346 346 stream : zmq.Socket or ZMQStream
347 347 the socket-like object used to send the data
348 348 msg_type : str or Message/dict
349 349 Normally, msg_type will be
350 350
351 351
352 352
353 353 Returns
354 354 -------
355 355 (msg,sent) : tuple
356 356 msg : Message
357 357 the nice wrapped dict-like object containing the headers
358 358
359 359 """
360 360 if isinstance(msg_type, (Message, dict)):
361 361 # we got a Message, not a msg_type
362 362 # don't build a new Message
363 363 msg = msg_type
364 364 content = msg['content']
365 365 else:
366 366 msg = self.msg(msg_type, content, parent, subheader)
367 367 buffers = [] if buffers is None else buffers
368 368 to_send = []
369 369 if isinstance(ident, list):
370 370 # accept list of idents
371 371 to_send.extend(ident)
372 372 elif ident is not None:
373 373 to_send.append(ident)
374 374 to_send.append(DELIM)
375 375 if self.key is not None:
376 376 to_send.append(self.key)
377 377 to_send.append(self.pack(msg['header']))
378 378 to_send.append(self.pack(msg['parent_header']))
379 379
380 380 if content is None:
381 381 content = self.none
382 382 elif isinstance(content, dict):
383 383 content = self.pack(content)
384 384 elif isinstance(content, str):
385 385 # content is already packed, as in a relayed message
386 386 pass
387 387 else:
388 388 raise TypeError("Content incorrect type: %s"%type(content))
389 389 to_send.append(content)
390 390 flag = 0
391 391 if buffers:
392 392 flag = zmq.SNDMORE
393 393 stream.send_multipart(to_send, flag, copy=False)
394 394 for b in buffers[:-1]:
395 395 stream.send(b, flag, copy=False)
396 396 if buffers:
397 397 stream.send(buffers[-1], copy=False)
398 398 omsg = Message(msg)
399 399 if self.debug:
400 400 pprint.pprint(omsg)
401 401 pprint.pprint(to_send)
402 402 pprint.pprint(buffers)
403 403 return omsg
404 404
405 405 def send_raw(self, stream, msg, flags=0, copy=True, idents=None):
406 406 """Send a raw message via idents.
407 407
408 408 Parameters
409 409 ----------
410 410 msg : list of sendable buffers"""
411 411 to_send = []
412 412 if isinstance(ident, str):
413 413 ident = [ident]
414 414 if ident is not None:
415 415 to_send.extend(ident)
416 416 to_send.append(DELIM)
417 417 if self.key is not None:
418 418 to_send.append(self.key)
419 419 to_send.extend(msg)
420 420 stream.send_multipart(msg, flags, copy=copy)
421 421
422 422 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
423 423 """receives and unpacks a message
424 424 returns [idents], msg"""
425 425 if isinstance(socket, ZMQStream):
426 426 socket = socket.socket
427 427 try:
428 428 msg = socket.recv_multipart(mode)
429 429 except zmq.ZMQError as e:
430 430 if e.errno == zmq.EAGAIN:
431 431 # We can convert EAGAIN to None as we know in this case
432 432 # recv_json won't return None.
433 433 return None
434 434 else:
435 435 raise
436 436 # return an actual Message object
437 437 # determine the number of idents by trying to unpack them.
438 438 # this is terrible:
439 439 idents, msg = self.feed_identities(msg, copy)
440 440 try:
441 441 return idents, self.unpack_message(msg, content=content, copy=copy)
442 442 except Exception as e:
443 443 print (idents, msg)
444 444 # TODO: handle it
445 445 raise e
446 446
447 447 def feed_identities(self, msg, copy=True):
448 448 """This is a completely horrible thing, but it strips the zmq
449 449 ident prefixes off of a message. It will break if any identities
450 450 are unpackable by self.unpack."""
451 451 msg = list(msg)
452 452 idents = []
453 453 while len(msg) > 3:
454 454 if copy:
455 455 s = msg[0]
456 456 else:
457 457 s = msg[0].bytes
458 458 if s == DELIM:
459 459 msg.pop(0)
460 460 break
461 461 else:
462 462 idents.append(s)
463 463 msg.pop(0)
464 464
465 465 return idents, msg
466 466
467 467 def unpack_message(self, msg, content=True, copy=True):
468 468 """Return a message object from the format
469 469 sent by self.send.
470 470
471 471 Parameters:
472 472 -----------
473 473
474 474 content : bool (True)
475 475 whether to unpack the content dict (True),
476 476 or leave it serialized (False)
477 477
478 478 copy : bool (True)
479 479 whether to return the bytes (True),
480 480 or the non-copying Message object in each place (False)
481 481
482 482 """
483 483 ikey = int(self.key is not None)
484 484 minlen = 3 + ikey
485 485 if not len(msg) >= minlen:
486 486 raise TypeError("malformed message, must have at least %i elements"%minlen)
487 487 message = {}
488 488 if not copy:
489 489 for i in range(minlen):
490 490 msg[i] = msg[i].bytes
491 491 if ikey:
492 492 if not self.key == msg[0]:
493 493 raise KeyError("Invalid Session Key: %s"%msg[0])
494 494 message['header'] = self.unpack(msg[ikey+0])
495 495 message['msg_type'] = message['header']['msg_type']
496 496 message['parent_header'] = self.unpack(msg[ikey+1])
497 497 if content:
498 498 message['content'] = self.unpack(msg[ikey+2])
499 499 else:
500 500 message['content'] = msg[ikey+2]
501 501
502 502 # message['buffers'] = msg[3:]
503 503 # else:
504 504 # message['header'] = self.unpack(msg[0].bytes)
505 505 # message['msg_type'] = message['header']['msg_type']
506 506 # message['parent_header'] = self.unpack(msg[1].bytes)
507 507 # if content:
508 508 # message['content'] = self.unpack(msg[2].bytes)
509 509 # else:
510 510 # message['content'] = msg[2].bytes
511 511
512 512 message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
513 513 return message
514 514
515 515
516 516
517 517 def test_msg2obj():
518 518 am = dict(x=1)
519 519 ao = Message(am)
520 520 assert ao.x == am['x']
521 521
522 522 am['y'] = dict(z=1)
523 523 ao = Message(am)
524 524 assert ao.y.z == am['y']['z']
525 525
526 526 k1, k2 = 'y', 'z'
527 527 assert ao[k1][k2] == am[k1][k2]
528 528
529 529 am2 = dict(ao)
530 530 assert am['x'] == am2['x']
531 531 assert am['y']['z'] == am2['y']['z']
General Comments 0
You need to be logged in to leave comments. Login now