streamsession.py
483 lines
| 16.2 KiB
| text/x-python
|
PythonLexer
MinRK
|
r3539 | #!/usr/bin/env python | ||
"""edited session.py to work with streams, and move msg_type to the header | ||||
""" | ||||
MinRK
|
r3660 | #----------------------------------------------------------------------------- | ||
# Copyright (C) 2010-2011 The IPython Development Team | ||||
# | ||||
# Distributed under the terms of the BSD License. The full license is in | ||||
# the file COPYING, distributed as part of this software. | ||||
#----------------------------------------------------------------------------- | ||||
MinRK
|
r3539 | |||
MinRK
|
r4000 | #----------------------------------------------------------------------------- | ||
# Imports | ||||
#----------------------------------------------------------------------------- | ||||
MinRK
|
r3539 | |||
MinRK
|
r4000 | import hmac | ||
MinRK
|
r3539 | import os | ||
MinRK
|
r3631 | import pprint | ||
MinRK
|
r3539 | import uuid | ||
MinRK
|
r3556 | from datetime import datetime | ||
MinRK
|
r3539 | |||
MinRK
|
r3631 | try: | ||
import cPickle | ||||
pickle = cPickle | ||||
except: | ||||
cPickle = None | ||||
import pickle | ||||
MinRK
|
r3539 | import zmq | ||
from zmq.utils import jsonapi | ||||
from zmq.eventloop.zmqstream import ZMQStream | ||||
MinRK
|
r3985 | from IPython.config.configurable import Configurable | ||
from IPython.utils.importstring import import_item | ||||
MinRK
|
r4000 | from IPython.utils.traitlets import CStr, Unicode, Bool, Any, Instance, Set | ||
MinRK
|
r3985 | |||
MinRK
|
r3644 | from .util import ISO8601 | ||
MinRK
|
r3583 | |||
MinRK
|
r4000 | #----------------------------------------------------------------------------- | ||
# utility functions | ||||
#----------------------------------------------------------------------------- | ||||
MinRK
|
r3985 | |||
MinRK
|
def squash_unicode(obj):
    """Coerce unicode back to bytestrings, recursively.

    Dicts and lists are updated in place (dict keys as well as values) and
    the same object is returned; a unicode object is returned utf8-encoded;
    anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        # iterate over a snapshot of the keys: unicode keys are popped and
        # re-inserted under their encoded form while we walk the dict
        for key in list(obj):
            obj[key] = squash_unicode(obj[key])
            if isinstance(key, unicode):
                obj[squash_unicode(key)] = obj.pop(key)
        return obj
    if isinstance(obj, list):
        obj[:] = [squash_unicode(item) for item in obj]
        return obj
    if isinstance(obj, unicode):
        return obj.encode('utf8')
    return obj
MinRK
|
r3780 | def _date_default(obj): | ||
if isinstance(obj, datetime): | ||||
return obj.strftime(ISO8601) | ||||
else: | ||||
raise TypeError("%r is not JSON serializable"%obj) | ||||
MinRK
|
r4000 | #----------------------------------------------------------------------------- | ||
# globals and defaults | ||||
#----------------------------------------------------------------------------- | ||||
MinRK
|
# jsonlib calls its fallback-serializer hook `on_unknown`; the other json
# backends jsonapi may wrap are handed it as `default`.
if jsonapi.jsonmod.__name__ == 'jsonlib':
    _default_key = 'on_unknown'
else:
    _default_key = 'default'


def json_packer(obj):
    """Serialize obj to JSON, rendering datetimes via _date_default."""
    return jsonapi.dumps(obj, **{_default_key: _date_default})


def json_unpacker(s):
    """Deserialize JSON, coercing unicode back to bytestrings."""
    return squash_unicode(jsonapi.loads(s))


def pickle_packer(o):
    """Pickle o with the highest available protocol."""
    return pickle.dumps(o, -1)


# NOTE: pickle.loads can execute arbitrary code; the pickle serializer is
# only safe between trusted peers.
pickle_unpacker = pickle.loads

default_packer = json_packer
default_unpacker = json_unpacker

# separates the routing identities from the message frames on the wire
DELIM = "<IDS|MSG>"
MinRK
|
r4000 | #----------------------------------------------------------------------------- | ||
# Classes | ||||
#----------------------------------------------------------------------------- | ||||
MinRK
|
class Message(object):
    """A simple message object that maps dict keys to attributes.

    A Message can be created from a dict, and a dict from a Message
    instance simply by calling dict(msg_obj).  Nested dicts are turned
    into nested Messages.
    """

    def __init__(self, msg_dict):
        attrs = self.__dict__
        for key, value in dict(msg_dict).items():
            attrs[key] = Message(value) if isinstance(value, dict) else value

    def __iter__(self):
        # yielding (key, value) pairs is what lets dict(msg_obj) work
        return iter(self.__dict__.items())

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return pprint.pformat(self.__dict__)

    def __contains__(self, k):
        return k in self.__dict__

    def __getitem__(self, k):
        return self.__dict__[k]
def msg_header(msg_id, msg_type, username, session):
    """Build a message header dict.

    Returns a dict holding the four arguments under their own names, plus
    'date': the current local time formatted with ISO8601.
    """
    date = datetime.now().strftime(ISO8601)
    # Spell the keys out rather than returning locals(): any temporary added
    # to this function later would otherwise silently leak into every header.
    return {
        'msg_id': msg_id,
        'msg_type': msg_type,
        'username': username,
        'session': session,
        'date': date,
    }
def extract_header(msg_or_header):
    """Given a message or header, return the header.

    A falsy argument yields {}.  A full message is recognised by its
    'header' key; otherwise the argument must itself be a header, recognised
    by its 'msg_id' key.  A non-dict header (e.g. a Message) is converted
    with dict().

    Raises
    ------
    KeyError
        if the argument has neither a 'header' nor a 'msg_id' key
    """
    if not msg_or_header:
        return {}
    try:
        h = msg_or_header['header']
    except KeyError:
        # not a whole message; this lookup lets KeyError escape when the
        # object is not a header either
        msg_or_header['msg_id']
        h = msg_or_header
    if not isinstance(h, dict):
        h = dict(h)
    return h
MinRK
|
class StreamSession(Configurable):
    """Tweaked version of IPython.zmq.session.Session, for development in Parallel.

    Builds, signs, serializes, sends, receives and unpacks messages over
    zmq Sockets or ZMQStreams.  As produced by `serialize` (and consumed by
    `feed_identities` / `unpack_message`), a message on the wire is::

        [ident, ..., DELIM, signature, header, parent_header, content, buffer, ...]

    header and parent_header are serialized with `pack`; content is packed
    unless it is already bytes/unicode (e.g. a relayed message); signature
    is the HMAC hexdigest of those three frames, or b'' when no `key` is set.
    """

    debug = Bool(False, config=True, help="""Debug output in the StreamSession""")

    # -- serialization -------------------------------------------------------

    packer = Unicode('json', config=True,
        help="""The name of the packer for serializing messages.
        Should be one of 'json', 'pickle', or an import name
        for a custom serializer.""")

    def _packer_changed(self, name, old, new):
        if not self._use_builtin_serializer(new):
            self.pack = import_item(new)

    unpacker = Unicode('json', config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")

    def _unpacker_changed(self, name, old, new):
        if not self._use_builtin_serializer(new):
            self.unpack = import_item(new)

    def _use_builtin_serializer(self, name):
        """Install the 'json' or 'pickle' pack/unpack pair if `name` is one of them.

        Both functions are set together, so the builtin serializers are
        always paired no matter whether `packer` or `unpacker` was changed.
        Returns True if `name` named a builtin, False otherwise (in which
        case nothing is changed).
        """
        kind = name.lower()
        if kind == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
        elif kind == 'pickle':
            # NOTE: pickle_unpacker is pickle.loads, which can execute
            # arbitrary code; only use this between trusted peers.
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
        else:
            return False
        return True

    session = CStr('', config=True,
        help="""The UUID identifying this session.""")

    def _session_default(self):
        return bytes(uuid.uuid4())

    username = Unicode(os.environ.get('USER', 'username'), config=True,
        help="""Username for the Session. Default is your system username.""")

    # -- message signature related traits ------------------------------------

    key = CStr('', config=True,
        help="""execution key, for extra authentication.""")

    def _key_changed(self, name, old, new):
        # an empty key disables signing entirely (see `sign`)
        # NOTE(review): no digestmod is passed, so hmac's default digest is
        # used (MD5 on Python 2); changing it would presumably break
        # compatibility with existing peers.
        if new:
            self.auth = hmac.HMAC(new)
        else:
            self.auth = None

    auth = Instance(hmac.HMAC)
    # not referenced anywhere else in this class
    counters = Instance('collections.defaultdict', (int,))
    # signatures already seen by unpack_message, used to reject replays.
    # NOTE(review): never pruned, so it grows for the session's lifetime.
    digest_history = Set()

    keyfile = Unicode('', config=True,
        help="""path to file containing execution key.""")

    def _keyfile_changed(self, name, old, new):
        # surrounding whitespace (e.g. a trailing newline) is not part of the key
        with open(new, 'rb') as f:
            self.key = f.read().strip()

    pack = Any(default_packer)  # the actual packer function

    def _pack_changed(self, name, old, new):
        if not callable(new):
            raise TypeError("packer must be callable, not %s" % type(new))

    unpack = Any(default_unpacker)  # the actual unpacker function

    def _unpack_changed(self, name, old, new):
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s" % type(new))

    def __init__(self, **kwargs):
        super(StreamSession, self).__init__(**kwargs)
        # packed empty dict, sent by serialize() when a message's content is None
        self.none = self.pack({})

    @property
    def msg_id(self):
        """always return new uuid"""
        return str(uuid.uuid4())

    def msg_header(self, msg_type):
        """Return a fresh header dict for a message of type `msg_type`."""
        # this name resolves to the module-level msg_header function
        return msg_header(self.msg_id, msg_type, self.username, self.session)

    def msg(self, msg_type, content=None, parent=None, subheader=None):
        """Build a (not yet serialized) message dict.

        Parameters
        ----------
        msg_type : str
        content : dict or None
            the message content; None becomes {}
        parent : Message or dict or None
            the parent message or header; only its header is kept
        subheader : dict or None
            extra keys merged into this message's header

        Returns
        -------
        dict with keys 'header', 'msg_id', 'parent_header', 'msg_type', 'content'
        """
        msg = {}
        msg['header'] = self.msg_header(msg_type)
        msg['msg_id'] = msg['header']['msg_id']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['msg_type'] = msg_type
        msg['content'] = {} if content is None else content
        # merged after msg_id has been copied out of the header above
        sub = {} if subheader is None else subheader
        msg['header'].update(sub)
        return msg

    def check_key(self, msg_or_header):
        """Check that a message's header has the right key.

        Always True when no key is configured.
        """
        if not self.key:
            return True
        header = extract_header(msg_or_header)
        return header.get('key', '') == self.key

    def sign(self, msg):
        """Sign a message with HMAC digest. If no auth, return b''.

        `msg` is the sequence of already-serialized frames to sign.
        """
        if self.auth is None:
            return b''
        # work on a copy so the configured HMAC object is never updated
        h = self.auth.copy()
        for m in msg:
            h.update(m)
        return h.hexdigest()

    def serialize(self, msg, ident=None):
        """Serialize a message dict into the list of frames to send.

        Parameters
        ----------
        msg : dict
            message with 'header', 'parent_header' and optionally 'content'
        ident : bytes or list of bytes or None
            routing identities to prepend

        Returns
        -------
        list : [ident..., DELIM, signature, header, parent_header, content]

        Raises
        ------
        TypeError
            if content is not None, a dict, bytes or unicode
        """
        content = msg.get('content', {})
        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, unicode):
            # should be bytes, but JSON often spits out unicode
            content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s"%type(content))

        real_message = [self.pack(msg['header']),
                        self.pack(msg['parent_header']),
                        content,
        ]

        to_send = []
        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        signature = self.sign(real_message)
        to_send.append(signature)
        to_send.extend(real_message)

        return to_send

    def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None, track=False):
        """Build and send a message via stream or socket.

        Parameters
        ----------
        stream : zmq.Socket or ZMQStream
            the socket-like object used to send the data
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once.
        content : dict or None
            the content of the message (ignored if msg_or_type is a message)
        buffers : list or None
            the already-serialized buffers to be appended to the message
        parent : Message or dict or None
            the parent or parent header describing the parent of this message
            (ignored if msg_or_type is a message)
        subheader : dict or None
            extra header keys for this message's header
            (ignored if msg_or_type is a message)
        ident : bytes or list of bytes
            the zmq.IDENTITY routing path
        track : bool
            whether to track.  Only for use with Sockets, because ZMQStream
            objects cannot track messages.

        Returns
        -------
        msg : message dict
            the constructed (or passed-in) message.  The value returned by
            the final send call is stored in msg['tracker']; with track=True
            on a Socket this is pyzmq's MessageTracker.

        Raises
        ------
        TypeError
            if stream is neither a Socket nor a ZMQStream, or if track=True
            is requested with a ZMQStream
        """
        if not isinstance(stream, (zmq.Socket, ZMQStream)):
            raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
        elif track and isinstance(stream, ZMQStream):
            raise TypeError("ZMQStream cannot track messages")

        if isinstance(msg_or_type, (Message, dict)):
            # we got a Message, not a msg_type: don't build a new one.
            # NOTE(review): serialize() calls msg.get() and msg['tracker'] is
            # assigned below; Message supports neither, so in practice only a
            # plain dict works here.
            msg = msg_or_type
        else:
            msg = self.msg(msg_or_type, content, parent, subheader)

        buffers = [] if buffers is None else buffers
        to_send = self.serialize(msg, ident)
        flag = 0
        if buffers:
            # more frames follow; only the final buffer's send is tracked
            flag = zmq.SNDMORE
            _track = False
        else:
            _track = track
        # `track` is passed only when requested, presumably because
        # ZMQStream's send methods do not accept it
        if track:
            tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
        else:
            tracker = stream.send_multipart(to_send, flag, copy=False)
        for b in buffers[:-1]:
            stream.send(b, flag, copy=False)
        if buffers:
            if track:
                tracker = stream.send(buffers[-1], copy=False, track=track)
            else:
                tracker = stream.send(buffers[-1], copy=False)

        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        msg['tracker'] = tracker

        return msg

    def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        The frames are prefixed with the ident path, DELIM and their
        signature, mirroring the layout produced by `serialize`.

        Parameters
        ----------
        msg : list of sendable buffers
            already-serialized frames; the whole list is signed.
            NOTE(review): unpack_message verifies only the three frames after
            the signature, so msg should presumably be exactly
            [header, parent_header, content] for the signature to validate.
        """
        to_send = []
        if isinstance(ident, bytes):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        to_send.append(self.sign(msg))
        to_send.extend(msg)
        # send the assembled frames, not the bare `msg` (which would drop the
        # idents, DELIM and signature built above)
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """receives and unpacks a message
        returns [idents], msg

        Returns None instead if the receive fails with EAGAIN (no message
        waiting in non-blocking mode).
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None
            else:
                raise
        # split the routing identities off at DELIM
        idents, msg = self.feed_identities(msg, copy)
        try:
            return idents, self.unpack_message(msg, content=content, copy=copy)
        except Exception:
            print (idents, msg)
            # TODO: handle it
            # bare raise keeps the original traceback (`raise e` resets it on Python 2)
            raise

    def feed_identities(self, msg, copy=True):
        """feed until DELIM is reached, then return the prefix as idents and remainder as
        msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.

        Parameters
        ----------
        msg : a list of Message or bytes objects
            the message to be split
        copy : bool
            flag determining whether the arguments are bytes or Messages

        Returns
        -------
        (idents,msg) : two lists
            idents will always be a list of bytes - the identity prefix
            msg will be a list of bytes or Messages, unchanged from input
            msg should be unpackable via self.unpack_message at this point.

        Raises
        ------
        ValueError
            if DELIM does not occur in msg
        """
        if copy:
            idx = msg.index(DELIM)
            return msg[:idx], msg[idx+1:]
        else:
            for idx, m in enumerate(msg):
                if m.bytes == DELIM:
                    break
            else:
                raise ValueError("DELIM not in msg")
            idents, msg = msg[:idx], msg[idx+1:]
            return [m.bytes for m in idents], msg

    def unpack_message(self, msg, content=True, copy=True):
        """Return a message object from the format
        sent by self.send.

        Parameters
        ----------
        msg : list of bytes or zmq Messages
            the frames after DELIM: [signature, header, parent_header,
            content, buffer, ...].  With copy=False the first four entries
            are replaced in place by their bytes.
        content : bool (True)
            whether to unpack the content dict (True),
            or leave it serialized (False)
        copy : bool (True)
            whether to return the bytes (True),
            or the non-copying Message object in each place (False)

        Returns
        -------
        dict with keys 'header', 'msg_type', 'parent_header', 'content', 'buffers'

        Raises
        ------
        TypeError
            if msg has fewer than 4 frames
        ValueError
            if signing is enabled and the signature was already seen or
            does not match
        """
        minlen = 4
        # check the length before indexing msg[0..3] below, so a short message
        # raises this error rather than IndexError or a bogus signature failure
        if not len(msg) >= minlen:
            raise TypeError("malformed message, must have at least %i elements"%minlen)
        message = {}
        if not copy:
            for i in range(minlen):
                msg[i] = msg[i].bytes
        if self.auth is not None:
            signature = msg[0]
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r"%signature)
            self.digest_history.add(signature)
            check = self.sign(msg[1:4])
            # NOTE(review): `==` is not a constant-time comparison
            if not signature == check:
                raise ValueError("Invalid Signature: %r"%signature)
        message['header'] = self.unpack(msg[1])
        message['msg_type'] = message['header']['msg_type']
        message['parent_header'] = self.unpack(msg[2])
        if content:
            message['content'] = self.unpack(msg[3])
        else:
            message['content'] = msg[3]

        message['buffers'] = msg[4:]
        return message
def test_msg2obj():
    """Round-trip a dict through Message and back to a dict."""
    src = {'x': 1}
    obj = Message(src)
    assert obj.x == src['x']

    src['y'] = {'z': 1}
    obj = Message(src)
    # nested dicts support attribute access...
    assert obj.y.z == src['y']['z']
    # ...as well as item access
    outer, inner = 'y', 'z'
    assert obj[outer][inner] == src[outer][inner]

    back = dict(obj)
    assert src['x'] == back['x']
    assert src['y']['z'] == back['y']['z']