##// END OF EJS Templates
Do not flatten unicode on unpacking
Thomas Kluyver -
Show More
@@ -1,627 +1,627 b''
1 1 #!/usr/bin/env python
2 2 """Session object for building, serializing, sending, and receiving messages in
3 3 IPython. The Session object supports serialization, HMAC signatures, and
4 4 metadata on messages.
5 5
6 6 Also defined here are utilities for working with Sessions:
7 7 * A SessionFactory to be used as a base class for configurables that work with
8 8 Sessions.
9 9 * A Message object for convenience that allows attribute-access to the msg dict.
10 10
11 11 Authors:
12 12
13 13 * Min RK
14 14 * Brian Granger
15 15 * Fernando Perez
16 16 """
17 17 #-----------------------------------------------------------------------------
18 18 # Copyright (C) 2010-2011 The IPython Development Team
19 19 #
20 20 # Distributed under the terms of the BSD License. The full license is in
21 21 # the file COPYING, distributed as part of this software.
22 22 #-----------------------------------------------------------------------------
23 23
24 24 #-----------------------------------------------------------------------------
25 25 # Imports
26 26 #-----------------------------------------------------------------------------
27 27
28 28 import hmac
29 29 import logging
30 30 import os
31 31 import pprint
32 32 import uuid
33 33 from datetime import datetime
34 34
35 35 try:
36 36 import cPickle
37 37 pickle = cPickle
38 38 except:
39 39 cPickle = None
40 40 import pickle
41 41
42 42 import zmq
43 43 from zmq.utils import jsonapi
44 44 from zmq.eventloop.ioloop import IOLoop
45 45 from zmq.eventloop.zmqstream import ZMQStream
46 46
47 47 from IPython.config.configurable import Configurable, LoggingConfigurable
48 48 from IPython.utils.importstring import import_item
49 49 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
50 50 from IPython.utils.traitlets import CStr, Unicode, Bool, Any, Instance, Set
51 51
52 52 #-----------------------------------------------------------------------------
53 53 # utility functions
54 54 #-----------------------------------------------------------------------------
55 55
def squash_unicode(obj):
    """Recursively coerce unicode text back to UTF-8 bytestrings.

    Dicts and lists are converted *in place* (dict keys and values, list
    items) and the same container is returned.  A unicode string is returned
    as its UTF-8 encoding.  Any other object is returned unchanged.
    """
    try:
        text_type = unicode  # Python 2
    except NameError:
        text_type = str  # Python 3: ``str`` is the unicode text type

    if isinstance(obj, dict):
        # Iterate over a snapshot of the keys: entries are re-keyed inside
        # the loop, which must not happen while walking a live key view.
        for key in list(obj.keys()):
            obj[key] = squash_unicode(obj[key])
            if isinstance(key, text_type):
                obj[squash_unicode(key)] = obj.pop(key)
    elif isinstance(obj, list):
        for i, v in enumerate(obj):
            obj[i] = squash_unicode(v)
    elif isinstance(obj, text_type):
        obj = obj.encode('utf8')
    return obj
69 69
70 70 #-----------------------------------------------------------------------------
71 71 # globals and defaults
72 72 #-----------------------------------------------------------------------------
# Keyword under which the active JSON backend accepts the fallback
# serializer: 'on_unknown' for jsonlib, 'default' for everything else.
key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'


def json_packer(obj):
    """Serialize `obj` with jsonapi, passing `date_default` as the fallback hook."""
    return jsonapi.dumps(obj, **{key: date_default})


def json_unpacker(s):
    """Deserialize JSON `s` and run the result through `extract_dates`.

    Unicode strings in the result are left as-is (they are no longer
    flattened to bytestrings on unpacking).
    """
    return extract_dates(jsonapi.loads(s))


def pickle_packer(o):
    """Serialize `o` with pickle, using the highest available protocol (-1)."""
    return pickle.dumps(o, -1)


# NOTE(security): pickle.loads can execute arbitrary code; only use the
# pickle unpacker with peers that are fully trusted.
pickle_unpacker = pickle.loads

default_packer = json_packer
default_unpacker = json_unpacker


# Frame separating the routing identities from the message proper.
DELIM = "<IDS|MSG>"
85 85
86 86 #-----------------------------------------------------------------------------
87 87 # Classes
88 88 #-----------------------------------------------------------------------------
89 89
class SessionFactory(LoggingConfigurable):
    """Base class for configurables that carry a Session, a zmq Context,
    a logger, and an IOLoop.
    """

    logname = Unicode('')

    def _logname_changed(self, name, old, new):
        # Re-point self.log at the logger registered under the new name.
        self.log = logging.getLogger(new)

    # not configurable:
    context = Instance('zmq.Context')

    def _context_default(self):
        return zmq.Context.instance()

    session = Instance('IPython.zmq.session.Session')

    loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)

    def _loop_default(self):
        return IOLoop.instance()

    def __init__(self, **kwargs):
        super(SessionFactory, self).__init__(**kwargs)

        if self.session is not None:
            # a Session was supplied; nothing more to do
            return
        # No Session given: build one from the same keyword arguments.
        self.session = Session(**kwargs)
116 116
117 117
class Message(object):
    """A simple message object that maps dict keys to attributes.

    A Message can be created from a dict and a dict from a Message instance
    simply by calling dict(msg_obj).  Values that are themselves dicts are
    wrapped in nested Message objects; other values are stored unchanged.
    """

    def __init__(self, msg_dict):
        dct = self.__dict__
        # items() (rather than the Python-2-only iteritems()) keeps this
        # working on both Python 2 and Python 3.
        for k, v in dict(msg_dict).items():
            if isinstance(v, dict):
                v = Message(v)
            dct[k] = v

    # Having this iterator lets dict(msg_obj) work out of the box.
    def __iter__(self):
        """Iterate over (key, value) pairs of the stored attributes."""
        return iter(self.__dict__.items())

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return pprint.pformat(self.__dict__)

    def __contains__(self, k):
        return k in self.__dict__

    def __getitem__(self, k):
        """Item access mirrors attribute access; raises KeyError if missing."""
        return self.__dict__[k]
146 146
147 147
def msg_header(msg_id, msg_type, username, session):
    """Return a header dict holding the four arguments plus a 'date' entry
    set to the current ``datetime.now()``.
    """
    return {
        'msg_id': msg_id,
        'msg_type': msg_type,
        'username': username,
        'session': session,
        'date': datetime.now(),
    }
151 151
def extract_header(msg_or_header):
    """Return the header of a message, or the header itself if given one.

    A falsy argument yields an empty dict.  An object with a 'header' entry
    is treated as a full message; otherwise it must have a 'msg_id' entry
    and is treated as the header (KeyError is raised if it has neither).
    The result is always a plain dict.
    """
    if not msg_or_header:
        return {}
    try:
        # Full message: the header lives under 'header'.
        header = msg_or_header['header']
    except KeyError:
        # Bare header: probe for 'msg_id'; a KeyError here propagates.
        msg_or_header['msg_id']
        header = msg_or_header
    return header if isinstance(header, dict) else dict(header)
170 170
class Session(Configurable):
    """Object for handling serialization and sending of messages.

    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects.  Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle
    serialization/deserialization, security, and metadata.

    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts.  If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : bytes
        the ID of this Session object.  The default is to generate a new UUID.
    username : unicode
        username added to message headers.  The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature.  If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key.  If this is set, `key` will be initialized
        to the contents of the file.

    """

    debug = Bool(False, config=True, help="""Debug output in the Session""")

    packer = Unicode('json', config=True,
            help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")

    def _packer_changed(self, name, old, new):
        # 'json' / 'pickle' select a matched pack/unpack pair; anything else
        # is an import string for the pack callable only.
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
        else:
            self.pack = import_item(str(new))

    unpacker = Unicode('json', config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")

    def _unpacker_changed(self, name, old, new):
        # Mirrors _packer_changed: the predefined names set both directions,
        # an import string sets only the unpack callable.
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
        else:
            self.unpack = import_item(str(new))

    session = CStr('', config=True,
        help="""The UUID identifying this session.""")

    def _session_default(self):
        return bytes(uuid.uuid4())

    username = Unicode(os.environ.get('USER', 'username'), config=True,
        help="""Username for the Session. Default is your system username.""")

    # message signature related traits:
    key = CStr('', config=True,
        help="""execution key, for extra authentication.""")

    def _key_changed(self, name, old, new):
        # A non-empty key enables signing; an empty key disables it.
        if new:
            self.auth = hmac.HMAC(new)
        else:
            self.auth = None

    auth = Instance(hmac.HMAC)
    # signatures already seen by unpack_message (used to reject duplicates)
    digest_history = Set()

    keyfile = Unicode('', config=True,
        help="""path to file containing execution key.""")

    def _keyfile_changed(self, name, old, new):
        with open(new, 'rb') as f:
            self.key = f.read().strip()

    pack = Any(default_packer)  # the actual packer function

    def _pack_changed(self, name, old, new):
        if not callable(new):
            raise TypeError("packer must be callable, not %s" % type(new))

    unpack = Any(default_unpacker)  # the actual unpacker function

    def _unpack_changed(self, name, old, new):
        # only callability is checked here; _check_packers verifies that the
        # unpacker can actually read the packer's output
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s" % type(new))

    def __init__(self, **kwargs):
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts.  If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : bytes
            the ID of this Session object.  The default is to generate a new
            UUID.
        username : unicode
            username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If unset, messages
            will not be signed or checked.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be
            initialized to the contents of the file.
        """
        super(Session, self).__init__(**kwargs)
        self._check_packers()
        # pre-packed empty dict, used by serialize() when content is None
        self.none = self.pack({})

    @property
    def msg_id(self):
        """always return new uuid"""
        return str(uuid.uuid4())

    def _check_packers(self):
        """check packers for binary data and datetime support.

        Raises ValueError if the packer cannot serialize a simple message,
        does not produce bytes, or if the unpacker cannot read the packer's
        output.  If a datetime cannot make the pack/unpack round trip, both
        callables are wrapped with squash_dates / extract_dates.
        """
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg = dict(a=[1, 'hi'])
        try:
            packed = pack(msg)
        except Exception:
            raise ValueError("packer could not serialize a simple message")

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required" % type(packed))

        # check that unpack can at least read pack's output
        try:
            unpack(packed)
        except Exception:
            raise ValueError("unpacker could not handle the packer's output")

        # check datetime support
        msg = dict(t=datetime.now())
        try:
            unpack(pack(msg))
        except Exception:
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: extract_dates(unpack(s))

    def msg_header(self, msg_type):
        """Return a new header for `msg_type`, with a fresh msg_id and this
        Session's username and session id."""
        return msg_header(self.msg_id, msg_type, self.username, self.session)

    def msg(self, msg_type, content=None, parent=None, subheader=None):
        """Build (but do not send) a message dict.

        `content` defaults to an empty dict, `parent` may be a message or a
        header, and the keys of `subheader` are merged into the new header.
        """
        msg = {}
        msg['header'] = self.msg_header(msg_type)
        msg['msg_id'] = msg['header']['msg_id']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['msg_type'] = msg_type
        msg['content'] = {} if content is None else content
        sub = {} if subheader is None else subheader
        msg['header'].update(sub)
        return msg

    def sign(self, msg):
        """Sign a message with HMAC digest. If no auth, return b''.

        `msg` is an iterable of bytes parts; each one is fed, in order, into
        a copy of the keyed HMAC object and the hex digest is returned.
        """
        if self.auth is None:
            return b''
        h = self.auth.copy()
        for m in msg:
            h.update(m)
        return h.hexdigest()

    def serialize(self, msg, ident=None):
        """Serialize the message components to bytes.

        The result is laid out as
        ``[ident..., DELIM, signature, header, parent_header, content]``.

        Parameters
        ----------
        msg : dict
            a message with 'header' and 'parent_header' entries and an
            optional 'content' entry (dict, already-packed bytes, unicode,
            or None)
        ident : bytes or list of bytes or None
            routing identities placed in front of DELIM

        Returns
        -------

        list of bytes objects

        Raises
        ------
        TypeError
            if the content is none of the accepted types
        """
        content = msg.get('content', {})
        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, unicode):
            # should be bytes, but JSON often spits out unicode
            content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s" % type(content))

        real_message = [self.pack(msg['header']),
                        self.pack(msg['parent_header']),
                        content
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)

        return to_send

    def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
             buffers=None, subheader=None, track=False):
        """Build and send a message via stream or socket.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            the socket-like object used to send the data
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once.

        content : dict or None
            the content of the message (ignored if msg_or_type is a message)
        parent : Message or dict or None
            the parent or parent header describing the parent of this message
        ident : bytes or list of bytes
            the zmq.IDENTITY routing path
        subheader : dict or None
            extra header keys for this message's header
        buffers : list or None
            the already-serialized buffers to be appended to the message
        track : bool
            whether to track.  Only for use with Sockets,
            because ZMQStream objects cannot track messages.

        Returns
        -------
        msg : message dict
            the constructed message.  The return value of the final send
            call is stored under ``msg['tracker']``.

        Raises
        ------
        TypeError
            if `stream` is neither a Socket nor a ZMQStream, or if tracking
            is requested on a ZMQStream
        """

        if not isinstance(stream, (zmq.Socket, ZMQStream)):
            raise TypeError("stream must be Socket or ZMQStream, not %r" % type(stream))
        elif track and isinstance(stream, ZMQStream):
            raise TypeError("ZMQStream cannot track messages")

        if isinstance(msg_or_type, (Message, dict)):
            # we got a Message, not a msg_type
            # don't build a new Message
            msg = msg_or_type
        else:
            msg = self.msg(msg_or_type, content, parent, subheader)

        buffers = [] if buffers is None else buffers
        to_send = self.serialize(msg, ident)
        flag = 0
        if buffers:
            # more frames follow; tracking is deferred to the last buffer
            flag = zmq.SNDMORE
            _track = False
        else:
            _track = track
        # the track keyword is only passed when tracking was requested
        if track:
            tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
        else:
            tracker = stream.send_multipart(to_send, flag, copy=False)
        for b in buffers[:-1]:
            stream.send(b, flag, copy=False)
        if buffers:
            # final frame: sent without SNDMORE to terminate the message
            if track:
                tracker = stream.send(buffers[-1], copy=False, track=track)
            else:
                tracker = stream.send(buffers[-1], copy=False)

        # omsg = Message(msg)
        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        msg['tracker'] = tracker

        return msg

    def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        The frames sent are ``[ident..., DELIM, signature] + msg``.

        Parameters
        ----------
        stream : socket-like object with a send_multipart method
        msg : list of sendable buffers
        flags : int
            flags passed through to send_multipart
        copy : bool
            copy flag passed through to send_multipart
        ident : bytes or list of bytes or None
            routing identities placed in front of DELIM
        """
        to_send = []
        if isinstance(ident, bytes):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        to_send.append(self.sign(msg))
        to_send.extend(msg)
        # send the assembled frames, not the bare msg: otherwise the idents,
        # delimiter and signature built above would be silently dropped
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """receives and unpacks a message

        returns [idents], msg -- or (None, None) if the receive fails
        with EAGAIN."""
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            # honour `copy` here so the frames match what feed_identities
            # and unpack_message expect for the same flag
            msg = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None, None
            else:
                raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg = self.feed_identities(msg, copy)
        try:
            return idents, self.unpack_message(msg, content=content, copy=copy)
        except Exception:
            print((idents, msg))
            # TODO: handle it
            # bare raise keeps the original traceback
            raise

    def feed_identities(self, msg, copy=True):
        """feed until DELIM is reached, then return the prefix as idents and
        remainder as msg. This is easily broken by setting an IDENT to DELIM,
        but that would be silly.

        Parameters
        ----------
        msg : a list of Message or bytes objects
            the message to be split
        copy : bool
            flag determining whether the arguments are bytes or Messages

        Returns
        -------
        (idents,msg) : two lists
            idents will always be a list of bytes - the identity prefix
            msg will be a list of bytes or Messages, unchanged from input
            msg should be unpackable via self.unpack_message at this point.

        Raises
        ------
        ValueError
            if DELIM does not occur in msg
        """
        if copy:
            idx = msg.index(DELIM)
            return msg[:idx], msg[idx+1:]
        else:
            failed = True
            for idx, m in enumerate(msg):
                if m.bytes == DELIM:
                    failed = False
                    break
            if failed:
                raise ValueError("DELIM not in msg")
            idents, msg = msg[:idx], msg[idx+1:]
            return [m.bytes for m in idents], msg

    def unpack_message(self, msg, content=True, copy=True):
        """Return a message object from the format
        sent by self.send.

        `msg` is the frame list after the delimiter:
        ``[signature, header, parent_header, content, buffers...]``.

        Parameters:
        -----------

        content : bool (True)
            whether to unpack the content dict (True),
            or leave it serialized (False)

        copy : bool (True)
            whether to return the bytes (True),
            or the non-copying Message object in each place (False)

        Raises
        ------
        TypeError
            if msg has fewer than four parts
        ValueError
            if signing is enabled and the signature was already seen or
            does not match
        """
        minlen = 4
        # validate the length before touching msg[0..3], so that a short
        # message reports "malformed" rather than an IndexError
        if not len(msg) >= minlen:
            raise TypeError("malformed message, must have at least %i elements" % minlen)
        message = {}
        if not copy:
            for i in range(minlen):
                msg[i] = msg[i].bytes
        if self.auth is not None:
            signature = msg[0]
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            self.digest_history.add(signature)
            check = self.sign(msg[1:4])
            if not signature == check:
                raise ValueError("Invalid Signature: %r" % signature)
        message['header'] = self.unpack(msg[1])
        message['msg_type'] = message['header']['msg_type']
        message['parent_header'] = self.unpack(msg[2])
        if content:
            message['content'] = self.unpack(msg[3])
        else:
            message['content'] = msg[3]

        message['buffers'] = msg[4:]
        return message
611 611
def test_msg2obj():
    """Check Message attribute access, item access, and dict() conversion."""
    source = dict(x=1)
    wrapped = Message(source)
    assert wrapped.x == source['x']

    # nested dicts become nested Messages
    source['y'] = dict(z=1)
    wrapped = Message(source)
    assert wrapped.y.z == source['y']['z']

    # item access works at both levels
    outer, inner = 'y', 'z'
    assert wrapped[outer][inner] == source[outer][inner]

    # converting back with dict() keeps the same values
    restored = dict(wrapped)
    assert source['x'] == restored['x']
    assert source['y']['z'] == restored['y']['z']
627 627
General Comments 0
You need to be logged in to leave comments. Login now