##// END OF EJS Templates
Backport PR #6029: add pickleutil.PICKLE_PROTOCOL...
MinRK -
Show More
@@ -1,198 +1,185
1 """serialization utilities for apply messages
1 """serialization utilities for apply messages"""
2
2
3 Authors:
3 # Copyright (c) IPython Development Team.
4
4 # Distributed under the terms of the Modified BSD License.
5 * Min RK
6 """
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
13
14 #-----------------------------------------------------------------------------
15 # Imports
16 #-----------------------------------------------------------------------------
17
5
18 try:
6 try:
19 import cPickle
7 import cPickle
20 pickle = cPickle
8 pickle = cPickle
21 except:
9 except:
22 cPickle = None
10 cPickle = None
23 import pickle
11 import pickle
24
12
25
26 # IPython imports
13 # IPython imports
27 from IPython.utils import py3compat
14 from IPython.utils import py3compat
28 from IPython.utils.data import flatten
15 from IPython.utils.data import flatten
29 from IPython.utils.pickleutil import (
16 from IPython.utils.pickleutil import (
30 can, uncan, can_sequence, uncan_sequence, CannedObject,
17 can, uncan, can_sequence, uncan_sequence, CannedObject,
31 istype, sequence_types,
18 istype, sequence_types, PICKLE_PROTOCOL,
32 )
19 )
33
20
34 if py3compat.PY3:
21 if py3compat.PY3:
35 buffer = memoryview
22 buffer = memoryview
36
23
37 #-----------------------------------------------------------------------------
24 #-----------------------------------------------------------------------------
38 # Serialization Functions
25 # Serialization Functions
39 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
40
27
41 # default values for the thresholds:
28 # default values for the thresholds:
42 MAX_ITEMS = 64
29 MAX_ITEMS = 64
43 MAX_BYTES = 1024
30 MAX_BYTES = 1024
44
31
45 def _extract_buffers(obj, threshold=MAX_BYTES):
32 def _extract_buffers(obj, threshold=MAX_BYTES):
46 """extract buffers larger than a certain threshold"""
33 """extract buffers larger than a certain threshold"""
47 buffers = []
34 buffers = []
48 if isinstance(obj, CannedObject) and obj.buffers:
35 if isinstance(obj, CannedObject) and obj.buffers:
49 for i,buf in enumerate(obj.buffers):
36 for i,buf in enumerate(obj.buffers):
50 if len(buf) > threshold:
37 if len(buf) > threshold:
51 # buffer larger than threshold, prevent pickling
38 # buffer larger than threshold, prevent pickling
52 obj.buffers[i] = None
39 obj.buffers[i] = None
53 buffers.append(buf)
40 buffers.append(buf)
54 elif isinstance(buf, buffer):
41 elif isinstance(buf, buffer):
55 # buffer too small for separate send, coerce to bytes
42 # buffer too small for separate send, coerce to bytes
56 # because pickling buffer objects just results in broken pointers
43 # because pickling buffer objects just results in broken pointers
57 obj.buffers[i] = bytes(buf)
44 obj.buffers[i] = bytes(buf)
58 return buffers
45 return buffers
59
46
60 def _restore_buffers(obj, buffers):
47 def _restore_buffers(obj, buffers):
61 """restore buffers extracted by """
48 """restore buffers extracted by """
62 if isinstance(obj, CannedObject) and obj.buffers:
49 if isinstance(obj, CannedObject) and obj.buffers:
63 for i,buf in enumerate(obj.buffers):
50 for i,buf in enumerate(obj.buffers):
64 if buf is None:
51 if buf is None:
65 obj.buffers[i] = buffers.pop(0)
52 obj.buffers[i] = buffers.pop(0)
66
53
67 def serialize_object(obj, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
54 def serialize_object(obj, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
68 """Serialize an object into a list of sendable buffers.
55 """Serialize an object into a list of sendable buffers.
69
56
70 Parameters
57 Parameters
71 ----------
58 ----------
72
59
73 obj : object
60 obj : object
74 The object to be serialized
61 The object to be serialized
75 buffer_threshold : int
62 buffer_threshold : int
76 The threshold (in bytes) for pulling out data buffers
63 The threshold (in bytes) for pulling out data buffers
77 to avoid pickling them.
64 to avoid pickling them.
78 item_threshold : int
65 item_threshold : int
79 The maximum number of items over which canning will iterate.
66 The maximum number of items over which canning will iterate.
80 Containers (lists, dicts) larger than this will be pickled without
67 Containers (lists, dicts) larger than this will be pickled without
81 introspection.
68 introspection.
82
69
83 Returns
70 Returns
84 -------
71 -------
85 [bufs] : list of buffers representing the serialized object.
72 [bufs] : list of buffers representing the serialized object.
86 """
73 """
87 buffers = []
74 buffers = []
88 if istype(obj, sequence_types) and len(obj) < item_threshold:
75 if istype(obj, sequence_types) and len(obj) < item_threshold:
89 cobj = can_sequence(obj)
76 cobj = can_sequence(obj)
90 for c in cobj:
77 for c in cobj:
91 buffers.extend(_extract_buffers(c, buffer_threshold))
78 buffers.extend(_extract_buffers(c, buffer_threshold))
92 elif istype(obj, dict) and len(obj) < item_threshold:
79 elif istype(obj, dict) and len(obj) < item_threshold:
93 cobj = {}
80 cobj = {}
94 for k in sorted(obj):
81 for k in sorted(obj):
95 c = can(obj[k])
82 c = can(obj[k])
96 buffers.extend(_extract_buffers(c, buffer_threshold))
83 buffers.extend(_extract_buffers(c, buffer_threshold))
97 cobj[k] = c
84 cobj[k] = c
98 else:
85 else:
99 cobj = can(obj)
86 cobj = can(obj)
100 buffers.extend(_extract_buffers(cobj, buffer_threshold))
87 buffers.extend(_extract_buffers(cobj, buffer_threshold))
101
88
102 buffers.insert(0, pickle.dumps(cobj,-1))
89 buffers.insert(0, pickle.dumps(cobj, PICKLE_PROTOCOL))
103 return buffers
90 return buffers
104
91
105 def unserialize_object(buffers, g=None):
92 def unserialize_object(buffers, g=None):
106 """reconstruct an object serialized by serialize_object from data buffers.
93 """reconstruct an object serialized by serialize_object from data buffers.
107
94
108 Parameters
95 Parameters
109 ----------
96 ----------
110
97
111 bufs : list of buffers/bytes
98 bufs : list of buffers/bytes
112
99
113 g : globals to be used when uncanning
100 g : globals to be used when uncanning
114
101
115 Returns
102 Returns
116 -------
103 -------
117
104
118 (newobj, bufs) : unpacked object, and the list of remaining unused buffers.
105 (newobj, bufs) : unpacked object, and the list of remaining unused buffers.
119 """
106 """
120 bufs = list(buffers)
107 bufs = list(buffers)
121 pobj = bufs.pop(0)
108 pobj = bufs.pop(0)
122 if not isinstance(pobj, bytes):
109 if not isinstance(pobj, bytes):
123 # a zmq message
110 # a zmq message
124 pobj = bytes(pobj)
111 pobj = bytes(pobj)
125 canned = pickle.loads(pobj)
112 canned = pickle.loads(pobj)
126 if istype(canned, sequence_types) and len(canned) < MAX_ITEMS:
113 if istype(canned, sequence_types) and len(canned) < MAX_ITEMS:
127 for c in canned:
114 for c in canned:
128 _restore_buffers(c, bufs)
115 _restore_buffers(c, bufs)
129 newobj = uncan_sequence(canned, g)
116 newobj = uncan_sequence(canned, g)
130 elif istype(canned, dict) and len(canned) < MAX_ITEMS:
117 elif istype(canned, dict) and len(canned) < MAX_ITEMS:
131 newobj = {}
118 newobj = {}
132 for k in sorted(canned):
119 for k in sorted(canned):
133 c = canned[k]
120 c = canned[k]
134 _restore_buffers(c, bufs)
121 _restore_buffers(c, bufs)
135 newobj[k] = uncan(c, g)
122 newobj[k] = uncan(c, g)
136 else:
123 else:
137 _restore_buffers(canned, bufs)
124 _restore_buffers(canned, bufs)
138 newobj = uncan(canned, g)
125 newobj = uncan(canned, g)
139
126
140 return newobj, bufs
127 return newobj, bufs
141
128
142 def pack_apply_message(f, args, kwargs, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
129 def pack_apply_message(f, args, kwargs, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
143 """pack up a function, args, and kwargs to be sent over the wire
130 """pack up a function, args, and kwargs to be sent over the wire
144
131
145 Each element of args/kwargs will be canned for special treatment,
132 Each element of args/kwargs will be canned for special treatment,
146 but inspection will not go any deeper than that.
133 but inspection will not go any deeper than that.
147
134
148 Any object whose data is larger than `threshold` will not have their data copied
135 Any object whose data is larger than `threshold` will not have their data copied
149 (only numpy arrays and bytes/buffers support zero-copy)
136 (only numpy arrays and bytes/buffers support zero-copy)
150
137
151 Message will be a list of bytes/buffers of the format:
138 Message will be a list of bytes/buffers of the format:
152
139
153 [ cf, pinfo, <arg_bufs>, <kwarg_bufs> ]
140 [ cf, pinfo, <arg_bufs>, <kwarg_bufs> ]
154
141
155 With length at least two + len(args) + len(kwargs)
142 With length at least two + len(args) + len(kwargs)
156 """
143 """
157
144
158 arg_bufs = flatten(serialize_object(arg, buffer_threshold, item_threshold) for arg in args)
145 arg_bufs = flatten(serialize_object(arg, buffer_threshold, item_threshold) for arg in args)
159
146
160 kw_keys = sorted(kwargs.keys())
147 kw_keys = sorted(kwargs.keys())
161 kwarg_bufs = flatten(serialize_object(kwargs[key], buffer_threshold, item_threshold) for key in kw_keys)
148 kwarg_bufs = flatten(serialize_object(kwargs[key], buffer_threshold, item_threshold) for key in kw_keys)
162
149
163 info = dict(nargs=len(args), narg_bufs=len(arg_bufs), kw_keys=kw_keys)
150 info = dict(nargs=len(args), narg_bufs=len(arg_bufs), kw_keys=kw_keys)
164
151
165 msg = [pickle.dumps(can(f),-1)]
152 msg = [pickle.dumps(can(f), PICKLE_PROTOCOL)]
166 msg.append(pickle.dumps(info, -1))
153 msg.append(pickle.dumps(info, PICKLE_PROTOCOL))
167 msg.extend(arg_bufs)
154 msg.extend(arg_bufs)
168 msg.extend(kwarg_bufs)
155 msg.extend(kwarg_bufs)
169
156
170 return msg
157 return msg
171
158
172 def unpack_apply_message(bufs, g=None, copy=True):
159 def unpack_apply_message(bufs, g=None, copy=True):
173 """unpack f,args,kwargs from buffers packed by pack_apply_message()
160 """unpack f,args,kwargs from buffers packed by pack_apply_message()
174 Returns: original f,args,kwargs"""
161 Returns: original f,args,kwargs"""
175 bufs = list(bufs) # allow us to pop
162 bufs = list(bufs) # allow us to pop
176 assert len(bufs) >= 2, "not enough buffers!"
163 assert len(bufs) >= 2, "not enough buffers!"
177 if not copy:
164 if not copy:
178 for i in range(2):
165 for i in range(2):
179 bufs[i] = bufs[i].bytes
166 bufs[i] = bufs[i].bytes
180 f = uncan(pickle.loads(bufs.pop(0)), g)
167 f = uncan(pickle.loads(bufs.pop(0)), g)
181 info = pickle.loads(bufs.pop(0))
168 info = pickle.loads(bufs.pop(0))
182 arg_bufs, kwarg_bufs = bufs[:info['narg_bufs']], bufs[info['narg_bufs']:]
169 arg_bufs, kwarg_bufs = bufs[:info['narg_bufs']], bufs[info['narg_bufs']:]
183
170
184 args = []
171 args = []
185 for i in range(info['nargs']):
172 for i in range(info['nargs']):
186 arg, arg_bufs = unserialize_object(arg_bufs, g)
173 arg, arg_bufs = unserialize_object(arg_bufs, g)
187 args.append(arg)
174 args.append(arg)
188 args = tuple(args)
175 args = tuple(args)
189 assert not arg_bufs, "Shouldn't be any arg bufs left over"
176 assert not arg_bufs, "Shouldn't be any arg bufs left over"
190
177
191 kwargs = {}
178 kwargs = {}
192 for key in info['kw_keys']:
179 for key in info['kw_keys']:
193 kwarg, kwarg_bufs = unserialize_object(kwarg_bufs, g)
180 kwarg, kwarg_bufs = unserialize_object(kwarg_bufs, g)
194 kwargs[key] = kwarg
181 kwargs[key] = kwarg
195 assert not kwarg_bufs, "Shouldn't be any kwarg bufs left over"
182 assert not kwarg_bufs, "Shouldn't be any kwarg bufs left over"
196
183
197 return f,args,kwargs
184 return f,args,kwargs
198
185
@@ -1,850 +1,851
1 """Session object for building, serializing, sending, and receiving messages in
1 """Session object for building, serializing, sending, and receiving messages in
2 IPython. The Session object supports serialization, HMAC signatures, and
2 IPython. The Session object supports serialization, HMAC signatures, and
3 metadata on messages.
3 metadata on messages.
4
4
5 Also defined here are utilities for working with Sessions:
5 Also defined here are utilities for working with Sessions:
6 * A SessionFactory to be used as a base class for configurables that work with
6 * A SessionFactory to be used as a base class for configurables that work with
7 Sessions.
7 Sessions.
8 * A Message object for convenience that allows attribute-access to the msg dict.
8 * A Message object for convenience that allows attribute-access to the msg dict.
9
9
10 Authors:
10 Authors:
11
11
12 * Min RK
12 * Min RK
13 * Brian Granger
13 * Brian Granger
14 * Fernando Perez
14 * Fernando Perez
15 """
15 """
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 # Copyright (C) 2010-2011 The IPython Development Team
17 # Copyright (C) 2010-2011 The IPython Development Team
18 #
18 #
19 # Distributed under the terms of the BSD License. The full license is in
19 # Distributed under the terms of the BSD License. The full license is in
20 # the file COPYING, distributed as part of this software.
20 # the file COPYING, distributed as part of this software.
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22
22
23 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
24 # Imports
24 # Imports
25 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
26
26
27 import hashlib
27 import hashlib
28 import hmac
28 import hmac
29 import logging
29 import logging
30 import os
30 import os
31 import pprint
31 import pprint
32 import random
32 import random
33 import uuid
33 import uuid
34 from datetime import datetime
34 from datetime import datetime
35
35
36 try:
36 try:
37 import cPickle
37 import cPickle
38 pickle = cPickle
38 pickle = cPickle
39 except:
39 except:
40 cPickle = None
40 cPickle = None
41 import pickle
41 import pickle
42
42
43 import zmq
43 import zmq
44 from zmq.utils import jsonapi
44 from zmq.utils import jsonapi
45 from zmq.eventloop.ioloop import IOLoop
45 from zmq.eventloop.ioloop import IOLoop
46 from zmq.eventloop.zmqstream import ZMQStream
46 from zmq.eventloop.zmqstream import ZMQStream
47
47
48 from IPython.config.configurable import Configurable, LoggingConfigurable
48 from IPython.config.configurable import Configurable, LoggingConfigurable
49 from IPython.utils import io
49 from IPython.utils import io
50 from IPython.utils.importstring import import_item
50 from IPython.utils.importstring import import_item
51 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
51 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
52 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
52 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
53 iteritems)
53 iteritems)
54 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
54 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
55 DottedObjectName, CUnicode, Dict, Integer,
55 DottedObjectName, CUnicode, Dict, Integer,
56 TraitError,
56 TraitError,
57 )
57 )
58 from IPython.utils.pickleutil import PICKLE_PROTOCOL
58 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
59 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
59
60
60 #-----------------------------------------------------------------------------
61 #-----------------------------------------------------------------------------
61 # utility functions
62 # utility functions
62 #-----------------------------------------------------------------------------
63 #-----------------------------------------------------------------------------
63
64
64 def squash_unicode(obj):
65 def squash_unicode(obj):
65 """coerce unicode back to bytestrings."""
66 """coerce unicode back to bytestrings."""
66 if isinstance(obj,dict):
67 if isinstance(obj,dict):
67 for key in obj.keys():
68 for key in obj.keys():
68 obj[key] = squash_unicode(obj[key])
69 obj[key] = squash_unicode(obj[key])
69 if isinstance(key, unicode_type):
70 if isinstance(key, unicode_type):
70 obj[squash_unicode(key)] = obj.pop(key)
71 obj[squash_unicode(key)] = obj.pop(key)
71 elif isinstance(obj, list):
72 elif isinstance(obj, list):
72 for i,v in enumerate(obj):
73 for i,v in enumerate(obj):
73 obj[i] = squash_unicode(v)
74 obj[i] = squash_unicode(v)
74 elif isinstance(obj, unicode_type):
75 elif isinstance(obj, unicode_type):
75 obj = obj.encode('utf8')
76 obj = obj.encode('utf8')
76 return obj
77 return obj
77
78
78 #-----------------------------------------------------------------------------
79 #-----------------------------------------------------------------------------
79 # globals and defaults
80 # globals and defaults
80 #-----------------------------------------------------------------------------
81 #-----------------------------------------------------------------------------
81
82
82 # ISO8601-ify datetime objects
83 # ISO8601-ify datetime objects
83 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default)
84 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default)
84 json_unpacker = lambda s: jsonapi.loads(s)
85 json_unpacker = lambda s: jsonapi.loads(s)
85
86
86 pickle_packer = lambda o: pickle.dumps(squash_dates(o),-1)
87 pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
87 pickle_unpacker = pickle.loads
88 pickle_unpacker = pickle.loads
88
89
89 default_packer = json_packer
90 default_packer = json_packer
90 default_unpacker = json_unpacker
91 default_unpacker = json_unpacker
91
92
92 DELIM = b"<IDS|MSG>"
93 DELIM = b"<IDS|MSG>"
93 # singleton dummy tracker, which will always report as done
94 # singleton dummy tracker, which will always report as done
94 DONE = zmq.MessageTracker()
95 DONE = zmq.MessageTracker()
95
96
96 #-----------------------------------------------------------------------------
97 #-----------------------------------------------------------------------------
97 # Mixin tools for apps that use Sessions
98 # Mixin tools for apps that use Sessions
98 #-----------------------------------------------------------------------------
99 #-----------------------------------------------------------------------------
99
100
100 session_aliases = dict(
101 session_aliases = dict(
101 ident = 'Session.session',
102 ident = 'Session.session',
102 user = 'Session.username',
103 user = 'Session.username',
103 keyfile = 'Session.keyfile',
104 keyfile = 'Session.keyfile',
104 )
105 )
105
106
106 session_flags = {
107 session_flags = {
107 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
108 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
108 'keyfile' : '' }},
109 'keyfile' : '' }},
109 """Use HMAC digests for authentication of messages.
110 """Use HMAC digests for authentication of messages.
110 Setting this flag will generate a new UUID to use as the HMAC key.
111 Setting this flag will generate a new UUID to use as the HMAC key.
111 """),
112 """),
112 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
113 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
113 """Don't authenticate messages."""),
114 """Don't authenticate messages."""),
114 }
115 }
115
116
116 def default_secure(cfg):
117 def default_secure(cfg):
117 """Set the default behavior for a config environment to be secure.
118 """Set the default behavior for a config environment to be secure.
118
119
119 If Session.key/keyfile have not been set, set Session.key to
120 If Session.key/keyfile have not been set, set Session.key to
120 a new random UUID.
121 a new random UUID.
121 """
122 """
122
123
123 if 'Session' in cfg:
124 if 'Session' in cfg:
124 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
125 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
125 return
126 return
126 # key/keyfile not specified, generate new UUID:
127 # key/keyfile not specified, generate new UUID:
127 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
128 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
128
129
129
130
130 #-----------------------------------------------------------------------------
131 #-----------------------------------------------------------------------------
131 # Classes
132 # Classes
132 #-----------------------------------------------------------------------------
133 #-----------------------------------------------------------------------------
133
134
134 class SessionFactory(LoggingConfigurable):
135 class SessionFactory(LoggingConfigurable):
135 """The Base class for configurables that have a Session, Context, logger,
136 """The Base class for configurables that have a Session, Context, logger,
136 and IOLoop.
137 and IOLoop.
137 """
138 """
138
139
139 logname = Unicode('')
140 logname = Unicode('')
140 def _logname_changed(self, name, old, new):
141 def _logname_changed(self, name, old, new):
141 self.log = logging.getLogger(new)
142 self.log = logging.getLogger(new)
142
143
143 # not configurable:
144 # not configurable:
144 context = Instance('zmq.Context')
145 context = Instance('zmq.Context')
145 def _context_default(self):
146 def _context_default(self):
146 return zmq.Context.instance()
147 return zmq.Context.instance()
147
148
148 session = Instance('IPython.kernel.zmq.session.Session')
149 session = Instance('IPython.kernel.zmq.session.Session')
149
150
150 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
151 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
151 def _loop_default(self):
152 def _loop_default(self):
152 return IOLoop.instance()
153 return IOLoop.instance()
153
154
154 def __init__(self, **kwargs):
155 def __init__(self, **kwargs):
155 super(SessionFactory, self).__init__(**kwargs)
156 super(SessionFactory, self).__init__(**kwargs)
156
157
157 if self.session is None:
158 if self.session is None:
158 # construct the session
159 # construct the session
159 self.session = Session(**kwargs)
160 self.session = Session(**kwargs)
160
161
161
162
162 class Message(object):
163 class Message(object):
163 """A simple message object that maps dict keys to attributes.
164 """A simple message object that maps dict keys to attributes.
164
165
165 A Message can be created from a dict and a dict from a Message instance
166 A Message can be created from a dict and a dict from a Message instance
166 simply by calling dict(msg_obj)."""
167 simply by calling dict(msg_obj)."""
167
168
168 def __init__(self, msg_dict):
169 def __init__(self, msg_dict):
169 dct = self.__dict__
170 dct = self.__dict__
170 for k, v in iteritems(dict(msg_dict)):
171 for k, v in iteritems(dict(msg_dict)):
171 if isinstance(v, dict):
172 if isinstance(v, dict):
172 v = Message(v)
173 v = Message(v)
173 dct[k] = v
174 dct[k] = v
174
175
175 # Having this iterator lets dict(msg_obj) work out of the box.
176 # Having this iterator lets dict(msg_obj) work out of the box.
176 def __iter__(self):
177 def __iter__(self):
177 return iter(iteritems(self.__dict__))
178 return iter(iteritems(self.__dict__))
178
179
179 def __repr__(self):
180 def __repr__(self):
180 return repr(self.__dict__)
181 return repr(self.__dict__)
181
182
182 def __str__(self):
183 def __str__(self):
183 return pprint.pformat(self.__dict__)
184 return pprint.pformat(self.__dict__)
184
185
185 def __contains__(self, k):
186 def __contains__(self, k):
186 return k in self.__dict__
187 return k in self.__dict__
187
188
188 def __getitem__(self, k):
189 def __getitem__(self, k):
189 return self.__dict__[k]
190 return self.__dict__[k]
190
191
191
192
192 def msg_header(msg_id, msg_type, username, session):
193 def msg_header(msg_id, msg_type, username, session):
193 date = datetime.now()
194 date = datetime.now()
194 return locals()
195 return locals()
195
196
196 def extract_header(msg_or_header):
197 def extract_header(msg_or_header):
197 """Given a message or header, return the header."""
198 """Given a message or header, return the header."""
198 if not msg_or_header:
199 if not msg_or_header:
199 return {}
200 return {}
200 try:
201 try:
201 # See if msg_or_header is the entire message.
202 # See if msg_or_header is the entire message.
202 h = msg_or_header['header']
203 h = msg_or_header['header']
203 except KeyError:
204 except KeyError:
204 try:
205 try:
205 # See if msg_or_header is just the header
206 # See if msg_or_header is just the header
206 h = msg_or_header['msg_id']
207 h = msg_or_header['msg_id']
207 except KeyError:
208 except KeyError:
208 raise
209 raise
209 else:
210 else:
210 h = msg_or_header
211 h = msg_or_header
211 if not isinstance(h, dict):
212 if not isinstance(h, dict):
212 h = dict(h)
213 h = dict(h)
213 return h
214 return h
214
215
215 class Session(Configurable):
216 class Session(Configurable):
216 """Object for handling serialization and sending of messages.
217 """Object for handling serialization and sending of messages.
217
218
218 The Session object handles building messages and sending them
219 The Session object handles building messages and sending them
219 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
220 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
220 other over the network via Session objects, and only need to work with the
221 other over the network via Session objects, and only need to work with the
221 dict-based IPython message spec. The Session will handle
222 dict-based IPython message spec. The Session will handle
222 serialization/deserialization, security, and metadata.
223 serialization/deserialization, security, and metadata.
223
224
224 Sessions support configurable serialiization via packer/unpacker traits,
225 Sessions support configurable serialiization via packer/unpacker traits,
225 and signing with HMAC digests via the key/keyfile traits.
226 and signing with HMAC digests via the key/keyfile traits.
226
227
227 Parameters
228 Parameters
228 ----------
229 ----------
229
230
230 debug : bool
231 debug : bool
231 whether to trigger extra debugging statements
232 whether to trigger extra debugging statements
232 packer/unpacker : str : 'json', 'pickle' or import_string
233 packer/unpacker : str : 'json', 'pickle' or import_string
233 importstrings for methods to serialize message parts. If just
234 importstrings for methods to serialize message parts. If just
234 'json' or 'pickle', predefined JSON and pickle packers will be used.
235 'json' or 'pickle', predefined JSON and pickle packers will be used.
235 Otherwise, the entire importstring must be used.
236 Otherwise, the entire importstring must be used.
236
237
237 The functions must accept at least valid JSON input, and output *bytes*.
238 The functions must accept at least valid JSON input, and output *bytes*.
238
239
239 For example, to use msgpack:
240 For example, to use msgpack:
240 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
241 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
241 pack/unpack : callables
242 pack/unpack : callables
242 You can also set the pack/unpack callables for serialization directly.
243 You can also set the pack/unpack callables for serialization directly.
243 session : bytes
244 session : bytes
244 the ID of this Session object. The default is to generate a new UUID.
245 the ID of this Session object. The default is to generate a new UUID.
245 username : unicode
246 username : unicode
246 username added to message headers. The default is to ask the OS.
247 username added to message headers. The default is to ask the OS.
247 key : bytes
248 key : bytes
248 The key used to initialize an HMAC signature. If unset, messages
249 The key used to initialize an HMAC signature. If unset, messages
249 will not be signed or checked.
250 will not be signed or checked.
250 keyfile : filepath
251 keyfile : filepath
251 The file containing a key. If this is set, `key` will be initialized
252 The file containing a key. If this is set, `key` will be initialized
252 to the contents of the file.
253 to the contents of the file.
253
254
254 """
255 """
255
256
256 debug=Bool(False, config=True, help="""Debug output in the Session""")
257 debug=Bool(False, config=True, help="""Debug output in the Session""")
257
258
258 packer = DottedObjectName('json',config=True,
259 packer = DottedObjectName('json',config=True,
259 help="""The name of the packer for serializing messages.
260 help="""The name of the packer for serializing messages.
260 Should be one of 'json', 'pickle', or an import name
261 Should be one of 'json', 'pickle', or an import name
261 for a custom callable serializer.""")
262 for a custom callable serializer.""")
262 def _packer_changed(self, name, old, new):
263 def _packer_changed(self, name, old, new):
263 if new.lower() == 'json':
264 if new.lower() == 'json':
264 self.pack = json_packer
265 self.pack = json_packer
265 self.unpack = json_unpacker
266 self.unpack = json_unpacker
266 self.unpacker = new
267 self.unpacker = new
267 elif new.lower() == 'pickle':
268 elif new.lower() == 'pickle':
268 self.pack = pickle_packer
269 self.pack = pickle_packer
269 self.unpack = pickle_unpacker
270 self.unpack = pickle_unpacker
270 self.unpacker = new
271 self.unpacker = new
271 else:
272 else:
272 self.pack = import_item(str(new))
273 self.pack = import_item(str(new))
273
274
274 unpacker = DottedObjectName('json', config=True,
275 unpacker = DottedObjectName('json', config=True,
275 help="""The name of the unpacker for unserializing messages.
276 help="""The name of the unpacker for unserializing messages.
276 Only used with custom functions for `packer`.""")
277 Only used with custom functions for `packer`.""")
277 def _unpacker_changed(self, name, old, new):
278 def _unpacker_changed(self, name, old, new):
278 if new.lower() == 'json':
279 if new.lower() == 'json':
279 self.pack = json_packer
280 self.pack = json_packer
280 self.unpack = json_unpacker
281 self.unpack = json_unpacker
281 self.packer = new
282 self.packer = new
282 elif new.lower() == 'pickle':
283 elif new.lower() == 'pickle':
283 self.pack = pickle_packer
284 self.pack = pickle_packer
284 self.unpack = pickle_unpacker
285 self.unpack = pickle_unpacker
285 self.packer = new
286 self.packer = new
286 else:
287 else:
287 self.unpack = import_item(str(new))
288 self.unpack = import_item(str(new))
288
289
289 session = CUnicode(u'', config=True,
290 session = CUnicode(u'', config=True,
290 help="""The UUID identifying this session.""")
291 help="""The UUID identifying this session.""")
291 def _session_default(self):
292 def _session_default(self):
292 u = unicode_type(uuid.uuid4())
293 u = unicode_type(uuid.uuid4())
293 self.bsession = u.encode('ascii')
294 self.bsession = u.encode('ascii')
294 return u
295 return u
295
296
296 def _session_changed(self, name, old, new):
297 def _session_changed(self, name, old, new):
297 self.bsession = self.session.encode('ascii')
298 self.bsession = self.session.encode('ascii')
298
299
299 # bsession is the session as bytes
300 # bsession is the session as bytes
300 bsession = CBytes(b'')
301 bsession = CBytes(b'')
301
302
302 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
303 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
303 help="""Username for the Session. Default is your system username.""",
304 help="""Username for the Session. Default is your system username.""",
304 config=True)
305 config=True)
305
306
306 metadata = Dict({}, config=True,
307 metadata = Dict({}, config=True,
307 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
308 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
308
309
309 # message signature related traits:
310 # message signature related traits:
310
311
311 key = CBytes(b'', config=True,
312 key = CBytes(b'', config=True,
312 help="""execution key, for extra authentication.""")
313 help="""execution key, for extra authentication.""")
313 def _key_changed(self, name, old, new):
314 def _key_changed(self, name, old, new):
314 if new:
315 if new:
315 self.auth = hmac.HMAC(new, digestmod=self.digest_mod)
316 self.auth = hmac.HMAC(new, digestmod=self.digest_mod)
316 else:
317 else:
317 self.auth = None
318 self.auth = None
318
319
319 signature_scheme = Unicode('hmac-sha256', config=True,
320 signature_scheme = Unicode('hmac-sha256', config=True,
320 help="""The digest scheme used to construct the message signatures.
321 help="""The digest scheme used to construct the message signatures.
321 Must have the form 'hmac-HASH'.""")
322 Must have the form 'hmac-HASH'.""")
322 def _signature_scheme_changed(self, name, old, new):
323 def _signature_scheme_changed(self, name, old, new):
323 if not new.startswith('hmac-'):
324 if not new.startswith('hmac-'):
324 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
325 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
325 hash_name = new.split('-', 1)[1]
326 hash_name = new.split('-', 1)[1]
326 try:
327 try:
327 self.digest_mod = getattr(hashlib, hash_name)
328 self.digest_mod = getattr(hashlib, hash_name)
328 except AttributeError:
329 except AttributeError:
329 raise TraitError("hashlib has no such attribute: %s" % hash_name)
330 raise TraitError("hashlib has no such attribute: %s" % hash_name)
330
331
331 digest_mod = Any()
332 digest_mod = Any()
332 def _digest_mod_default(self):
333 def _digest_mod_default(self):
333 return hashlib.sha256
334 return hashlib.sha256
334
335
335 auth = Instance(hmac.HMAC)
336 auth = Instance(hmac.HMAC)
336
337
337 digest_history = Set()
338 digest_history = Set()
338 digest_history_size = Integer(2**16, config=True,
339 digest_history_size = Integer(2**16, config=True,
339 help="""The maximum number of digests to remember.
340 help="""The maximum number of digests to remember.
340
341
341 The digest history will be culled when it exceeds this value.
342 The digest history will be culled when it exceeds this value.
342 """
343 """
343 )
344 )
344
345
345 keyfile = Unicode('', config=True,
346 keyfile = Unicode('', config=True,
346 help="""path to file containing execution key.""")
347 help="""path to file containing execution key.""")
347 def _keyfile_changed(self, name, old, new):
348 def _keyfile_changed(self, name, old, new):
348 with open(new, 'rb') as f:
349 with open(new, 'rb') as f:
349 self.key = f.read().strip()
350 self.key = f.read().strip()
350
351
351 # for protecting against sends from forks
352 # for protecting against sends from forks
352 pid = Integer()
353 pid = Integer()
353
354
354 # serialization traits:
355 # serialization traits:
355
356
356 pack = Any(default_packer) # the actual packer function
357 pack = Any(default_packer) # the actual packer function
357 def _pack_changed(self, name, old, new):
358 def _pack_changed(self, name, old, new):
358 if not callable(new):
359 if not callable(new):
359 raise TypeError("packer must be callable, not %s"%type(new))
360 raise TypeError("packer must be callable, not %s"%type(new))
360
361
361 unpack = Any(default_unpacker) # the actual packer function
362 unpack = Any(default_unpacker) # the actual packer function
362 def _unpack_changed(self, name, old, new):
363 def _unpack_changed(self, name, old, new):
363 # unpacker is not checked - it is assumed to be
364 # unpacker is not checked - it is assumed to be
364 if not callable(new):
365 if not callable(new):
365 raise TypeError("unpacker must be callable, not %s"%type(new))
366 raise TypeError("unpacker must be callable, not %s"%type(new))
366
367
367 # thresholds:
368 # thresholds:
368 copy_threshold = Integer(2**16, config=True,
369 copy_threshold = Integer(2**16, config=True,
369 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
370 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
370 buffer_threshold = Integer(MAX_BYTES, config=True,
371 buffer_threshold = Integer(MAX_BYTES, config=True,
371 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
372 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
372 item_threshold = Integer(MAX_ITEMS, config=True,
373 item_threshold = Integer(MAX_ITEMS, config=True,
373 help="""The maximum number of items for a container to be introspected for custom serialization.
374 help="""The maximum number of items for a container to be introspected for custom serialization.
374 Containers larger than this are pickled outright.
375 Containers larger than this are pickled outright.
375 """
376 """
376 )
377 )
377
378
378
379
379 def __init__(self, **kwargs):
380 def __init__(self, **kwargs):
380 """create a Session object
381 """create a Session object
381
382
382 Parameters
383 Parameters
383 ----------
384 ----------
384
385
385 debug : bool
386 debug : bool
386 whether to trigger extra debugging statements
387 whether to trigger extra debugging statements
387 packer/unpacker : str : 'json', 'pickle' or import_string
388 packer/unpacker : str : 'json', 'pickle' or import_string
388 importstrings for methods to serialize message parts. If just
389 importstrings for methods to serialize message parts. If just
389 'json' or 'pickle', predefined JSON and pickle packers will be used.
390 'json' or 'pickle', predefined JSON and pickle packers will be used.
390 Otherwise, the entire importstring must be used.
391 Otherwise, the entire importstring must be used.
391
392
392 The functions must accept at least valid JSON input, and output
393 The functions must accept at least valid JSON input, and output
393 *bytes*.
394 *bytes*.
394
395
395 For example, to use msgpack:
396 For example, to use msgpack:
396 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
397 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
397 pack/unpack : callables
398 pack/unpack : callables
398 You can also set the pack/unpack callables for serialization
399 You can also set the pack/unpack callables for serialization
399 directly.
400 directly.
400 session : unicode (must be ascii)
401 session : unicode (must be ascii)
401 the ID of this Session object. The default is to generate a new
402 the ID of this Session object. The default is to generate a new
402 UUID.
403 UUID.
403 bsession : bytes
404 bsession : bytes
404 The session as bytes
405 The session as bytes
405 username : unicode
406 username : unicode
406 username added to message headers. The default is to ask the OS.
407 username added to message headers. The default is to ask the OS.
407 key : bytes
408 key : bytes
408 The key used to initialize an HMAC signature. If unset, messages
409 The key used to initialize an HMAC signature. If unset, messages
409 will not be signed or checked.
410 will not be signed or checked.
410 signature_scheme : str
411 signature_scheme : str
411 The message digest scheme. Currently must be of the form 'hmac-HASH',
412 The message digest scheme. Currently must be of the form 'hmac-HASH',
412 where 'HASH' is a hashing function available in Python's hashlib.
413 where 'HASH' is a hashing function available in Python's hashlib.
413 The default is 'hmac-sha256'.
414 The default is 'hmac-sha256'.
414 This is ignored if 'key' is empty.
415 This is ignored if 'key' is empty.
415 keyfile : filepath
416 keyfile : filepath
416 The file containing a key. If this is set, `key` will be
417 The file containing a key. If this is set, `key` will be
417 initialized to the contents of the file.
418 initialized to the contents of the file.
418 """
419 """
419 super(Session, self).__init__(**kwargs)
420 super(Session, self).__init__(**kwargs)
420 self._check_packers()
421 self._check_packers()
421 self.none = self.pack({})
422 self.none = self.pack({})
422 # ensure self._session_default() if necessary, so bsession is defined:
423 # ensure self._session_default() if necessary, so bsession is defined:
423 self.session
424 self.session
424 self.pid = os.getpid()
425 self.pid = os.getpid()
425
426
426 @property
427 @property
427 def msg_id(self):
428 def msg_id(self):
428 """always return new uuid"""
429 """always return new uuid"""
429 return str(uuid.uuid4())
430 return str(uuid.uuid4())
430
431
431 def _check_packers(self):
432 def _check_packers(self):
432 """check packers for datetime support."""
433 """check packers for datetime support."""
433 pack = self.pack
434 pack = self.pack
434 unpack = self.unpack
435 unpack = self.unpack
435
436
436 # check simple serialization
437 # check simple serialization
437 msg = dict(a=[1,'hi'])
438 msg = dict(a=[1,'hi'])
438 try:
439 try:
439 packed = pack(msg)
440 packed = pack(msg)
440 except Exception as e:
441 except Exception as e:
441 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
442 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
442 if self.packer == 'json':
443 if self.packer == 'json':
443 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
444 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
444 else:
445 else:
445 jsonmsg = ""
446 jsonmsg = ""
446 raise ValueError(
447 raise ValueError(
447 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
448 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
448 )
449 )
449
450
450 # ensure packed message is bytes
451 # ensure packed message is bytes
451 if not isinstance(packed, bytes):
452 if not isinstance(packed, bytes):
452 raise ValueError("message packed to %r, but bytes are required"%type(packed))
453 raise ValueError("message packed to %r, but bytes are required"%type(packed))
453
454
454 # check that unpack is pack's inverse
455 # check that unpack is pack's inverse
455 try:
456 try:
456 unpacked = unpack(packed)
457 unpacked = unpack(packed)
457 assert unpacked == msg
458 assert unpacked == msg
458 except Exception as e:
459 except Exception as e:
459 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
460 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
460 if self.packer == 'json':
461 if self.packer == 'json':
461 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
462 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
462 else:
463 else:
463 jsonmsg = ""
464 jsonmsg = ""
464 raise ValueError(
465 raise ValueError(
465 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
466 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
466 )
467 )
467
468
468 # check datetime support
469 # check datetime support
469 msg = dict(t=datetime.now())
470 msg = dict(t=datetime.now())
470 try:
471 try:
471 unpacked = unpack(pack(msg))
472 unpacked = unpack(pack(msg))
472 if isinstance(unpacked['t'], datetime):
473 if isinstance(unpacked['t'], datetime):
473 raise ValueError("Shouldn't deserialize to datetime")
474 raise ValueError("Shouldn't deserialize to datetime")
474 except Exception:
475 except Exception:
475 self.pack = lambda o: pack(squash_dates(o))
476 self.pack = lambda o: pack(squash_dates(o))
476 self.unpack = lambda s: unpack(s)
477 self.unpack = lambda s: unpack(s)
477
478
478 def msg_header(self, msg_type):
479 def msg_header(self, msg_type):
479 return msg_header(self.msg_id, msg_type, self.username, self.session)
480 return msg_header(self.msg_id, msg_type, self.username, self.session)
480
481
481 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
482 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
482 """Return the nested message dict.
483 """Return the nested message dict.
483
484
484 This format is different from what is sent over the wire. The
485 This format is different from what is sent over the wire. The
485 serialize/unserialize methods converts this nested message dict to the wire
486 serialize/unserialize methods converts this nested message dict to the wire
486 format, which is a list of message parts.
487 format, which is a list of message parts.
487 """
488 """
488 msg = {}
489 msg = {}
489 header = self.msg_header(msg_type) if header is None else header
490 header = self.msg_header(msg_type) if header is None else header
490 msg['header'] = header
491 msg['header'] = header
491 msg['msg_id'] = header['msg_id']
492 msg['msg_id'] = header['msg_id']
492 msg['msg_type'] = header['msg_type']
493 msg['msg_type'] = header['msg_type']
493 msg['parent_header'] = {} if parent is None else extract_header(parent)
494 msg['parent_header'] = {} if parent is None else extract_header(parent)
494 msg['content'] = {} if content is None else content
495 msg['content'] = {} if content is None else content
495 msg['metadata'] = self.metadata.copy()
496 msg['metadata'] = self.metadata.copy()
496 if metadata is not None:
497 if metadata is not None:
497 msg['metadata'].update(metadata)
498 msg['metadata'].update(metadata)
498 return msg
499 return msg
499
500
500 def sign(self, msg_list):
501 def sign(self, msg_list):
501 """Sign a message with HMAC digest. If no auth, return b''.
502 """Sign a message with HMAC digest. If no auth, return b''.
502
503
503 Parameters
504 Parameters
504 ----------
505 ----------
505 msg_list : list
506 msg_list : list
506 The [p_header,p_parent,p_content] part of the message list.
507 The [p_header,p_parent,p_content] part of the message list.
507 """
508 """
508 if self.auth is None:
509 if self.auth is None:
509 return b''
510 return b''
510 h = self.auth.copy()
511 h = self.auth.copy()
511 for m in msg_list:
512 for m in msg_list:
512 h.update(m)
513 h.update(m)
513 return str_to_bytes(h.hexdigest())
514 return str_to_bytes(h.hexdigest())
514
515
515 def serialize(self, msg, ident=None):
516 def serialize(self, msg, ident=None):
516 """Serialize the message components to bytes.
517 """Serialize the message components to bytes.
517
518
518 This is roughly the inverse of unserialize. The serialize/unserialize
519 This is roughly the inverse of unserialize. The serialize/unserialize
519 methods work with full message lists, whereas pack/unpack work with
520 methods work with full message lists, whereas pack/unpack work with
520 the individual message parts in the message list.
521 the individual message parts in the message list.
521
522
522 Parameters
523 Parameters
523 ----------
524 ----------
524 msg : dict or Message
525 msg : dict or Message
525 The nexted message dict as returned by the self.msg method.
526 The nexted message dict as returned by the self.msg method.
526
527
527 Returns
528 Returns
528 -------
529 -------
529 msg_list : list
530 msg_list : list
530 The list of bytes objects to be sent with the format::
531 The list of bytes objects to be sent with the format::
531
532
532 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
533 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
533 p_metadata, p_content, buffer1, buffer2, ...]
534 p_metadata, p_content, buffer1, buffer2, ...]
534
535
535 In this list, the ``p_*`` entities are the packed or serialized
536 In this list, the ``p_*`` entities are the packed or serialized
536 versions, so if JSON is used, these are utf8 encoded JSON strings.
537 versions, so if JSON is used, these are utf8 encoded JSON strings.
537 """
538 """
538 content = msg.get('content', {})
539 content = msg.get('content', {})
539 if content is None:
540 if content is None:
540 content = self.none
541 content = self.none
541 elif isinstance(content, dict):
542 elif isinstance(content, dict):
542 content = self.pack(content)
543 content = self.pack(content)
543 elif isinstance(content, bytes):
544 elif isinstance(content, bytes):
544 # content is already packed, as in a relayed message
545 # content is already packed, as in a relayed message
545 pass
546 pass
546 elif isinstance(content, unicode_type):
547 elif isinstance(content, unicode_type):
547 # should be bytes, but JSON often spits out unicode
548 # should be bytes, but JSON often spits out unicode
548 content = content.encode('utf8')
549 content = content.encode('utf8')
549 else:
550 else:
550 raise TypeError("Content incorrect type: %s"%type(content))
551 raise TypeError("Content incorrect type: %s"%type(content))
551
552
552 real_message = [self.pack(msg['header']),
553 real_message = [self.pack(msg['header']),
553 self.pack(msg['parent_header']),
554 self.pack(msg['parent_header']),
554 self.pack(msg['metadata']),
555 self.pack(msg['metadata']),
555 content,
556 content,
556 ]
557 ]
557
558
558 to_send = []
559 to_send = []
559
560
560 if isinstance(ident, list):
561 if isinstance(ident, list):
561 # accept list of idents
562 # accept list of idents
562 to_send.extend(ident)
563 to_send.extend(ident)
563 elif ident is not None:
564 elif ident is not None:
564 to_send.append(ident)
565 to_send.append(ident)
565 to_send.append(DELIM)
566 to_send.append(DELIM)
566
567
567 signature = self.sign(real_message)
568 signature = self.sign(real_message)
568 to_send.append(signature)
569 to_send.append(signature)
569
570
570 to_send.extend(real_message)
571 to_send.extend(real_message)
571
572
572 return to_send
573 return to_send
573
574
574 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
575 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
575 buffers=None, track=False, header=None, metadata=None):
576 buffers=None, track=False, header=None, metadata=None):
576 """Build and send a message via stream or socket.
577 """Build and send a message via stream or socket.
577
578
578 The message format used by this function internally is as follows:
579 The message format used by this function internally is as follows:
579
580
580 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
581 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
581 buffer1,buffer2,...]
582 buffer1,buffer2,...]
582
583
583 The serialize/unserialize methods convert the nested message dict into this
584 The serialize/unserialize methods convert the nested message dict into this
584 format.
585 format.
585
586
586 Parameters
587 Parameters
587 ----------
588 ----------
588
589
589 stream : zmq.Socket or ZMQStream
590 stream : zmq.Socket or ZMQStream
590 The socket-like object used to send the data.
591 The socket-like object used to send the data.
591 msg_or_type : str or Message/dict
592 msg_or_type : str or Message/dict
592 Normally, msg_or_type will be a msg_type unless a message is being
593 Normally, msg_or_type will be a msg_type unless a message is being
593 sent more than once. If a header is supplied, this can be set to
594 sent more than once. If a header is supplied, this can be set to
594 None and the msg_type will be pulled from the header.
595 None and the msg_type will be pulled from the header.
595
596
596 content : dict or None
597 content : dict or None
597 The content of the message (ignored if msg_or_type is a message).
598 The content of the message (ignored if msg_or_type is a message).
598 header : dict or None
599 header : dict or None
599 The header dict for the message (ignored if msg_to_type is a message).
600 The header dict for the message (ignored if msg_to_type is a message).
600 parent : Message or dict or None
601 parent : Message or dict or None
601 The parent or parent header describing the parent of this message
602 The parent or parent header describing the parent of this message
602 (ignored if msg_or_type is a message).
603 (ignored if msg_or_type is a message).
603 ident : bytes or list of bytes
604 ident : bytes or list of bytes
604 The zmq.IDENTITY routing path.
605 The zmq.IDENTITY routing path.
605 metadata : dict or None
606 metadata : dict or None
606 The metadata describing the message
607 The metadata describing the message
607 buffers : list or None
608 buffers : list or None
608 The already-serialized buffers to be appended to the message.
609 The already-serialized buffers to be appended to the message.
609 track : bool
610 track : bool
610 Whether to track. Only for use with Sockets, because ZMQStream
611 Whether to track. Only for use with Sockets, because ZMQStream
611 objects cannot track messages.
612 objects cannot track messages.
612
613
613
614
614 Returns
615 Returns
615 -------
616 -------
616 msg : dict
617 msg : dict
617 The constructed message.
618 The constructed message.
618 """
619 """
619 if not isinstance(stream, zmq.Socket):
620 if not isinstance(stream, zmq.Socket):
620 # ZMQStreams and dummy sockets do not support tracking.
621 # ZMQStreams and dummy sockets do not support tracking.
621 track = False
622 track = False
622
623
623 if isinstance(msg_or_type, (Message, dict)):
624 if isinstance(msg_or_type, (Message, dict)):
624 # We got a Message or message dict, not a msg_type so don't
625 # We got a Message or message dict, not a msg_type so don't
625 # build a new Message.
626 # build a new Message.
626 msg = msg_or_type
627 msg = msg_or_type
627 else:
628 else:
628 msg = self.msg(msg_or_type, content=content, parent=parent,
629 msg = self.msg(msg_or_type, content=content, parent=parent,
629 header=header, metadata=metadata)
630 header=header, metadata=metadata)
630 if not os.getpid() == self.pid:
631 if not os.getpid() == self.pid:
631 io.rprint("WARNING: attempted to send message from fork")
632 io.rprint("WARNING: attempted to send message from fork")
632 io.rprint(msg)
633 io.rprint(msg)
633 return
634 return
634 buffers = [] if buffers is None else buffers
635 buffers = [] if buffers is None else buffers
635 to_send = self.serialize(msg, ident)
636 to_send = self.serialize(msg, ident)
636 to_send.extend(buffers)
637 to_send.extend(buffers)
637 longest = max([ len(s) for s in to_send ])
638 longest = max([ len(s) for s in to_send ])
638 copy = (longest < self.copy_threshold)
639 copy = (longest < self.copy_threshold)
639
640
640 if buffers and track and not copy:
641 if buffers and track and not copy:
641 # only really track when we are doing zero-copy buffers
642 # only really track when we are doing zero-copy buffers
642 tracker = stream.send_multipart(to_send, copy=False, track=True)
643 tracker = stream.send_multipart(to_send, copy=False, track=True)
643 else:
644 else:
644 # use dummy tracker, which will be done immediately
645 # use dummy tracker, which will be done immediately
645 tracker = DONE
646 tracker = DONE
646 stream.send_multipart(to_send, copy=copy)
647 stream.send_multipart(to_send, copy=copy)
647
648
648 if self.debug:
649 if self.debug:
649 pprint.pprint(msg)
650 pprint.pprint(msg)
650 pprint.pprint(to_send)
651 pprint.pprint(to_send)
651 pprint.pprint(buffers)
652 pprint.pprint(buffers)
652
653
653 msg['tracker'] = tracker
654 msg['tracker'] = tracker
654
655
655 return msg
656 return msg
656
657
657 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
658 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
658 """Send a raw message via ident path.
659 """Send a raw message via ident path.
659
660
660 This method is used to send a already serialized message.
661 This method is used to send a already serialized message.
661
662
662 Parameters
663 Parameters
663 ----------
664 ----------
664 stream : ZMQStream or Socket
665 stream : ZMQStream or Socket
665 The ZMQ stream or socket to use for sending the message.
666 The ZMQ stream or socket to use for sending the message.
666 msg_list : list
667 msg_list : list
667 The serialized list of messages to send. This only includes the
668 The serialized list of messages to send. This only includes the
668 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
669 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
669 the message.
670 the message.
670 ident : ident or list
671 ident : ident or list
671 A single ident or a list of idents to use in sending.
672 A single ident or a list of idents to use in sending.
672 """
673 """
673 to_send = []
674 to_send = []
674 if isinstance(ident, bytes):
675 if isinstance(ident, bytes):
675 ident = [ident]
676 ident = [ident]
676 if ident is not None:
677 if ident is not None:
677 to_send.extend(ident)
678 to_send.extend(ident)
678
679
679 to_send.append(DELIM)
680 to_send.append(DELIM)
680 to_send.append(self.sign(msg_list))
681 to_send.append(self.sign(msg_list))
681 to_send.extend(msg_list)
682 to_send.extend(msg_list)
682 stream.send_multipart(to_send, flags, copy=copy)
683 stream.send_multipart(to_send, flags, copy=copy)
683
684
684 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
685 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
685 """Receive and unpack a message.
686 """Receive and unpack a message.
686
687
687 Parameters
688 Parameters
688 ----------
689 ----------
689 socket : ZMQStream or Socket
690 socket : ZMQStream or Socket
690 The socket or stream to use in receiving.
691 The socket or stream to use in receiving.
691
692
692 Returns
693 Returns
693 -------
694 -------
694 [idents], msg
695 [idents], msg
695 [idents] is a list of idents and msg is a nested message dict of
696 [idents] is a list of idents and msg is a nested message dict of
696 same format as self.msg returns.
697 same format as self.msg returns.
697 """
698 """
698 if isinstance(socket, ZMQStream):
699 if isinstance(socket, ZMQStream):
699 socket = socket.socket
700 socket = socket.socket
700 try:
701 try:
701 msg_list = socket.recv_multipart(mode, copy=copy)
702 msg_list = socket.recv_multipart(mode, copy=copy)
702 except zmq.ZMQError as e:
703 except zmq.ZMQError as e:
703 if e.errno == zmq.EAGAIN:
704 if e.errno == zmq.EAGAIN:
704 # We can convert EAGAIN to None as we know in this case
705 # We can convert EAGAIN to None as we know in this case
705 # recv_multipart won't return None.
706 # recv_multipart won't return None.
706 return None,None
707 return None,None
707 else:
708 else:
708 raise
709 raise
709 # split multipart message into identity list and message dict
710 # split multipart message into identity list and message dict
710 # invalid large messages can cause very expensive string comparisons
711 # invalid large messages can cause very expensive string comparisons
711 idents, msg_list = self.feed_identities(msg_list, copy)
712 idents, msg_list = self.feed_identities(msg_list, copy)
712 try:
713 try:
713 return idents, self.unserialize(msg_list, content=content, copy=copy)
714 return idents, self.unserialize(msg_list, content=content, copy=copy)
714 except Exception as e:
715 except Exception as e:
715 # TODO: handle it
716 # TODO: handle it
716 raise e
717 raise e
717
718
718 def feed_identities(self, msg_list, copy=True):
719 def feed_identities(self, msg_list, copy=True):
719 """Split the identities from the rest of the message.
720 """Split the identities from the rest of the message.
720
721
721 Feed until DELIM is reached, then return the prefix as idents and
722 Feed until DELIM is reached, then return the prefix as idents and
722 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
723 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
723 but that would be silly.
724 but that would be silly.
724
725
725 Parameters
726 Parameters
726 ----------
727 ----------
727 msg_list : a list of Message or bytes objects
728 msg_list : a list of Message or bytes objects
728 The message to be split.
729 The message to be split.
729 copy : bool
730 copy : bool
730 flag determining whether the arguments are bytes or Messages
731 flag determining whether the arguments are bytes or Messages
731
732
732 Returns
733 Returns
733 -------
734 -------
734 (idents, msg_list) : two lists
735 (idents, msg_list) : two lists
735 idents will always be a list of bytes, each of which is a ZMQ
736 idents will always be a list of bytes, each of which is a ZMQ
736 identity. msg_list will be a list of bytes or zmq.Messages of the
737 identity. msg_list will be a list of bytes or zmq.Messages of the
737 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
738 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
738 should be unpackable/unserializable via self.unserialize at this
739 should be unpackable/unserializable via self.unserialize at this
739 point.
740 point.
740 """
741 """
741 if copy:
742 if copy:
742 idx = msg_list.index(DELIM)
743 idx = msg_list.index(DELIM)
743 return msg_list[:idx], msg_list[idx+1:]
744 return msg_list[:idx], msg_list[idx+1:]
744 else:
745 else:
745 failed = True
746 failed = True
746 for idx,m in enumerate(msg_list):
747 for idx,m in enumerate(msg_list):
747 if m.bytes == DELIM:
748 if m.bytes == DELIM:
748 failed = False
749 failed = False
749 break
750 break
750 if failed:
751 if failed:
751 raise ValueError("DELIM not in msg_list")
752 raise ValueError("DELIM not in msg_list")
752 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
753 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
753 return [m.bytes for m in idents], msg_list
754 return [m.bytes for m in idents], msg_list
754
755
755 def _add_digest(self, signature):
756 def _add_digest(self, signature):
756 """add a digest to history to protect against replay attacks"""
757 """add a digest to history to protect against replay attacks"""
757 if self.digest_history_size == 0:
758 if self.digest_history_size == 0:
758 # no history, never add digests
759 # no history, never add digests
759 return
760 return
760
761
761 self.digest_history.add(signature)
762 self.digest_history.add(signature)
762 if len(self.digest_history) > self.digest_history_size:
763 if len(self.digest_history) > self.digest_history_size:
763 # threshold reached, cull 10%
764 # threshold reached, cull 10%
764 self._cull_digest_history()
765 self._cull_digest_history()
765
766
766 def _cull_digest_history(self):
767 def _cull_digest_history(self):
767 """cull the digest history
768 """cull the digest history
768
769
769 Removes a randomly selected 10% of the digest history
770 Removes a randomly selected 10% of the digest history
770 """
771 """
771 current = len(self.digest_history)
772 current = len(self.digest_history)
772 n_to_cull = max(int(current // 10), current - self.digest_history_size)
773 n_to_cull = max(int(current // 10), current - self.digest_history_size)
773 if n_to_cull >= current:
774 if n_to_cull >= current:
774 self.digest_history = set()
775 self.digest_history = set()
775 return
776 return
776 to_cull = random.sample(self.digest_history, n_to_cull)
777 to_cull = random.sample(self.digest_history, n_to_cull)
777 self.digest_history.difference_update(to_cull)
778 self.digest_history.difference_update(to_cull)
778
779
    def unserialize(self, msg_list, content=True, copy=True):
        """Unserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/unserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_metadata,p_content,buffer1,buffer2,...].
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            Whether to return the bytes (True), or the non-copying Message
            object in each place (False).

        Returns
        -------
        msg : dict
            The nested message dict with top-level keys [header, parent_header,
            content, buffers].

        Raises
        ------
        ValueError
            If the message is unsigned, replayed, or has a bad signature
            (only when ``self.auth`` is set).
        TypeError
            If fewer than five message parts are supplied.
        """
        minlen = 5
        message = {}
        if not copy:
            # message parts are frame objects; extract the raw bytes
            # for signing and unpacking below
            for i in range(minlen):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            # verify the HMAC signature before trusting any content
            signature = msg_list[0]
            if not signature:
                raise ValueError("Unsigned Message")
            if signature in self.digest_history:
                # replay protection: a signature may be accepted only once
                raise ValueError("Duplicate Signature: %r" % signature)
            self._add_digest(signature)
            # sign exactly the four serialized dict parts, as serialize did
            check = self.sign(msg_list[1:5])
            if not signature == check:
                raise ValueError("Invalid Signature: %r" % signature)
        if not len(msg_list) >= minlen:
            raise TypeError("malformed message, must have at least %i elements"%minlen)
        header = self.unpack(msg_list[1])
        message['header'] = extract_dates(header)
        message['msg_id'] = header['msg_id']
        message['msg_type'] = header['msg_type']
        message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
        message['metadata'] = self.unpack(msg_list[3])
        if content:
            message['content'] = self.unpack(msg_list[4])
        else:
            # leave content packed, e.g. for relaying without a round-trip
            message['content'] = msg_list[4]

        # everything beyond the first five parts is opaque buffers
        message['buffers'] = msg_list[5:]
        return message
834
835
def test_msg2obj():
    """Smoke-test attribute/item access on Message wrappers of plain dicts."""
    source = dict(x=1)
    msg = Message(source)
    assert msg.x == source['x']

    # nested dicts become nested attribute access
    source['y'] = dict(z=1)
    msg = Message(source)
    assert msg.y.z == source['y']['z']

    k1, k2 = 'y', 'z'
    assert msg[k1][k2] == source[k1][k2]

    # round-trip back to a plain dict
    roundtrip = dict(msg)
    assert source['x'] == roundtrip['x']
    assert source['y']['z'] == roundtrip['y']['z']
850
851
@@ -1,385 +1,390
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """Pickle related utilities. Perhaps this should be called 'can'."""
3 """Pickle related utilities. Perhaps this should be called 'can'."""
4
4
5 # Copyright (c) IPython Development Team.
5 # Copyright (c) IPython Development Team.
6 # Distributed under the terms of the Modified BSD License.
6 # Distributed under the terms of the Modified BSD License.
7
7
8 import copy
8 import copy
9 import logging
9 import logging
10 import sys
10 import sys
11 from types import FunctionType
11 from types import FunctionType
12
12
13 try:
13 try:
14 import cPickle as pickle
14 import cPickle as pickle
15 except ImportError:
15 except ImportError:
16 import pickle
16 import pickle
17
17
18 from . import codeutil # This registers a hook when it's imported
18 from . import codeutil # This registers a hook when it's imported
19 from . import py3compat
19 from . import py3compat
20 from .importstring import import_item
20 from .importstring import import_item
21 from .py3compat import string_types, iteritems
21 from .py3compat import string_types, iteritems
22
22
23 from IPython.config import Application
23 from IPython.config import Application
24
24
25 if py3compat.PY3:
25 if py3compat.PY3:
26 buffer = memoryview
26 buffer = memoryview
27 class_type = type
27 class_type = type
28 closure_attr = '__closure__'
28 closure_attr = '__closure__'
29 else:
29 else:
30 from types import ClassType
30 from types import ClassType
31 class_type = (type, ClassType)
31 class_type = (type, ClassType)
32 closure_attr = 'func_closure'
32 closure_attr = 'func_closure'
33
33
# Prefer the interpreter's default protocol (Python 3); Python 2's pickle
# has no DEFAULT_PROTOCOL, so fall back to the highest available there.
try:
    PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
except AttributeError:
    PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL
38
34 #-------------------------------------------------------------------------------
39 #-------------------------------------------------------------------------------
35 # Functions
40 # Functions
36 #-------------------------------------------------------------------------------
41 #-------------------------------------------------------------------------------
37
42
38
43
def use_dill():
    """use dill to expand serialization support

    adds support for object methods and closures to serialization.
    """
    # import dill causes most of the magic
    import dill

    # dill doesn't work with cPickle,
    # tell the two relevant modules to use plain pickle

    # rebind this module's `pickle` so dumps/loads below go through dill
    global pickle
    pickle = dill

    try:
        from IPython.kernel.zmq import serialize
    except ImportError:
        # zmq serialization not available; only this module is patched
        pass
    else:
        serialize.pickle = dill

    # disable special function handling, let dill take care of it
    can_map.pop(FunctionType, None)
62
67
63
68
64 #-------------------------------------------------------------------------------
69 #-------------------------------------------------------------------------------
65 # Classes
70 # Classes
66 #-------------------------------------------------------------------------------
71 #-------------------------------------------------------------------------------
67
72
68
73
class CannedObject(object):
    """Generic canned-object wrapper used by :func:`can` / :func:`uncan`."""

    def __init__(self, obj, keys=None, hook=None):
        """can an object for safe pickling

        Parameters
        ==========

        obj:
            The object to be canned
        keys: list (optional)
            list of attribute names that will be explicitly canned / uncanned
        hook: callable (optional)
            An optional extra callable,
            which can do additional processing of the uncanned object.

        large data may be offloaded into the buffers list,
        used for zero-copy transfers.
        """
        # avoid the shared mutable-default-argument pitfall (was `keys=[]`);
        # behavior is unchanged for all callers
        self.keys = keys if keys is not None else []
        # shallow copy, so canned attributes don't clobber the caller's object
        self.obj = copy.copy(obj)
        self.hook = can(hook)
        for key in self.keys:
            setattr(self.obj, key, can(getattr(obj, key)))

        self.buffers = []

    def get_object(self, g=None):
        """Restore the object, uncanning tracked attributes.

        Parameters
        ----------
        g : dict (optional)
            Namespace in which to resolve references during uncanning.
        """
        if g is None:
            g = {}
        obj = self.obj
        for key in self.keys:
            setattr(obj, key, uncan(getattr(obj, key), g))

        if self.hook:
            # uncan the hook itself before invoking it on the restored object
            self.hook = uncan(self.hook, g)
            self.hook(obj, g)
        return self.obj
106
111
107
112
class Reference(CannedObject):
    """object for wrapping a remote reference by name."""

    def __init__(self, name):
        if not isinstance(name, string_types):
            raise TypeError("illegal name: %r"%name)
        self.name = name
        self.buffers = []

    def __repr__(self):
        return "<Reference: %r>"%self.name

    def get_object(self, g=None):
        # resolve the stored name in the given namespace (empty by default)
        namespace = {} if g is None else g
        return eval(self.name, namespace)
124
129
125
130
class CannedFunction(CannedObject):
    """Canned form of a plain function: its code object plus canned defaults."""

    def __init__(self, f):
        self._check_type(f)
        self.code = f.__code__
        defaults = f.__defaults__
        self.defaults = [can(d) for d in defaults] if defaults else None

        if getattr(f, closure_attr, None):
            raise ValueError("Sorry, cannot pickle functions with closures.")
        self.module = f.__module__ or '__main__'
        self.__name__ = f.__name__
        self.buffers = []

    def _check_type(self, obj):
        assert isinstance(obj, FunctionType), "Not a function type"

    def get_object(self, g=None):
        # try to load function back into its module:
        if not self.module.startswith('__'):
            __import__(self.module)
            g = sys.modules[self.module].__dict__

        if g is None:
            g = {}
        defaults = None
        if self.defaults:
            defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
        return FunctionType(self.code, g, self.__name__, defaults)
159
164
class CannedClass(CannedObject):
    # Canned form of an (interactively defined) class: its name, canned
    # class dict, and canned parents, reconstructed via type() on uncan.

    def __init__(self, cls):
        self._check_type(cls)
        self.name = cls.__name__
        # Python 2 old-style classes are not instances of `type`
        self.old_style = not isinstance(cls, type)
        self._canned_dict = {}
        for k,v in cls.__dict__.items():
            if k not in ('__weakref__', '__dict__'):
                self._canned_dict[k] = can(v)
        if self.old_style:
            # old-style classes have no mro()
            mro = []
        else:
            mro = cls.mro()

        # can all ancestors, skipping the class itself (mro[0])
        self.parents = [ can(c) for c in mro[1:] ]
        self.buffers = []

    def _check_type(self, obj):
        assert isinstance(obj, class_type), "Not a class type"

    def get_object(self, g=None):
        # rebuild the class from uncanned parents and class dict
        parents = tuple(uncan(p, g) for p in self.parents)
        return type(self.name, parents, uncan_dict(self._canned_dict, g=g))
184
189
class CannedArray(CannedObject):
    """Canned numpy array: zero-copy buffer when possible, pickle otherwise."""

    def __init__(self, obj):
        from numpy import ascontiguousarray
        self.shape = obj.shape
        # structured dtypes need the full descr; plain dtypes just the str
        self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
        self.pickled = False
        if sum(obj.shape) == 0:
            self.pickled = True
        elif obj.dtype == 'O':
            # can't handle object dtype with buffer approach
            self.pickled = True
        elif obj.dtype.fields and any(dt == 'O' for dt,sz in obj.dtype.fields.values()):
            self.pickled = True
        if self.pickled:
            # fall back to a plain pickle payload
            self.buffers = [pickle.dumps(obj, PICKLE_PROTOCOL)]
        else:
            # ensure a contiguous layout so the raw buffer is usable directly
            obj = ascontiguousarray(obj, dtype=None)
            self.buffers = [buffer(obj)]

    def get_object(self, g=None):
        from numpy import frombuffer
        payload = self.buffers[0]
        if self.pickled:
            # shape/dtype were captured inside the pickle itself
            return pickle.loads(payload)
        return frombuffer(payload, dtype=self.dtype).reshape(self.shape)
214
219
215
220
class CannedBytes(CannedObject):
    """Canned wrapper that ships a raw bytes payload via the buffers list."""

    # subclasses override this to rebuild a different buffer type
    wrap = bytes

    def __init__(self, obj):
        self.buffers = [obj]

    def get_object(self, g=None):
        return self.wrap(self.buffers[0])
224
229
class CannedBuffer(CannedBytes):
    """Canned wrapper for buffer/memoryview payloads.

    Bug fix: this was declared with ``def`` instead of ``class``, which made
    ``CannedBuffer`` a function returning ``None`` and left ``wrap = buffer``
    as dead local code. Declaring it as a ``CannedBytes`` subclass restores
    the intended behavior: ``get_object`` rebuilds via ``wrap`` (= buffer on
    Python 2, memoryview on Python 3).
    """
    wrap = buffer
227
232
228 #-------------------------------------------------------------------------------
233 #-------------------------------------------------------------------------------
229 # Functions
234 # Functions
230 #-------------------------------------------------------------------------------
235 #-------------------------------------------------------------------------------
231
236
def _logger():
    """get the logger for the current Application

    the root logger will be used if no Application is running
    """
    if Application.initialized():
        return Application.instance().log
    logger = logging.getLogger()
    if not logger.handlers:
        # ensure log output is actually visible outside an Application
        logging.basicConfig()
    return logger
245
250
def _import_mapping(mapping, original=None):
    """Import any string keys in a type mapping, in place.

    Keys that fail to import are dropped; failures are logged only for
    user-added entries (those not present in *original*).
    """
    log = _logger()
    log.debug("Importing canning map")
    for key in list(mapping):
        if not isinstance(key, string_types):
            continue
        try:
            cls = import_item(key)
        except Exception:
            if original and key not in original:
                # only message on user-added classes
                log.error("canning class not importable: %r", key, exc_info=True)
            mapping.pop(key)
        else:
            # replace the string key with the imported class
            mapping[cls] = mapping.pop(key)
263
268
def istype(obj, check):
    """like isinstance(obj, check), but strict

    This won't catch subclasses.
    """
    if isinstance(check, tuple):
        # strict membership: the exact type must appear in the tuple
        return any(type(obj) is cls for cls in check)
    return type(obj) is check
276
281
def can(obj):
    """prepare an object for pickling"""

    import_needed = False

    for cls,canner in iteritems(can_map):
        if isinstance(cls, string_types):
            # a not-yet-imported mapping key remains; trigger lazy import below
            import_needed = True
            break
        elif istype(obj, cls):
            return canner(obj)

    if import_needed:
        # perform can_map imports, then try again
        # this will usually only happen once
        _import_mapping(can_map, _original_can_map)
        return can(obj)

    # no canner matched: the object is pickled as-is
    return obj
296
301
def can_class(obj):
    """Can interactively defined classes (from __main__); pass others through."""
    if isinstance(obj, class_type) and obj.__module__ == '__main__':
        return CannedClass(obj)
    return obj
302
307
def can_dict(obj):
    """can the *values* of a dict"""
    if istype(obj, dict):
        # keys pass through untouched; only values are canned
        return {k: can(v) for k, v in iteritems(obj)}
    return obj
312
317
# sequence types whose *elements* are canned individually
sequence_types = (list, tuple, set)

def can_sequence(obj):
    """can the elements of a sequence"""
    if istype(obj, sequence_types):
        # rebuild the same concrete sequence type from canned elements
        cls = type(obj)
        return cls(can(item) for item in obj)
    return obj
322
327
def uncan(obj, g=None):
    """invert canning"""

    import_needed = False
    for cls,uncanner in iteritems(uncan_map):
        if isinstance(cls, string_types):
            # a not-yet-imported mapping key remains; trigger lazy import below
            import_needed = True
            break
        elif isinstance(obj, cls):
            # NOTE: isinstance (not istype) here, so CannedObject subclasses
            # all hit the CannedObject uncanner
            return uncanner(obj, g)

    if import_needed:
        # perform uncan_map imports, then try again
        # this will usually only happen once
        _import_mapping(uncan_map, _original_uncan_map)
        return uncan(obj, g)

    # no uncanner matched: the object was never canned
    return obj
341
346
def uncan_dict(obj, g=None):
    """uncan the values of a dict (inverse of can_dict)"""
    if istype(obj, dict):
        return {k: uncan(v, g) for k, v in iteritems(obj)}
    return obj
350
355
def uncan_sequence(obj, g=None):
    """uncan the elements of a sequence (inverse of can_sequence)"""
    if istype(obj, sequence_types):
        cls = type(obj)
        return cls(uncan(item, g) for item in obj)
    return obj
357
362
def _uncan_dependent_hook(dep, g=None):
    # post-uncan hook: re-check the dependency on the destination side
    dep.check_dependency()
360
365
def can_dependent(obj):
    """Can a dependent object, re-checking its dependency on uncan."""
    tracked_attrs = ('f', 'df')
    return CannedObject(obj, keys=tracked_attrs, hook=_uncan_dependent_hook)
363
368
364 #-------------------------------------------------------------------------------
369 #-------------------------------------------------------------------------------
365 # API dictionaries
370 # API dictionaries
366 #-------------------------------------------------------------------------------
371 #-------------------------------------------------------------------------------
367
372
# These dicts can be extended for custom serialization of new objects

# type -> canner; string keys are imported lazily on first use
# (see _import_mapping), so numpy / IPython.parallel stay optional
can_map = {
    'IPython.parallel.dependent' : can_dependent,
    'numpy.ndarray' : CannedArray,
    FunctionType : CannedFunction,
    bytes : CannedBytes,
    buffer : CannedBuffer,
    class_type : can_class,
}

# canned type -> uncanner (matched via isinstance, so one entry covers
# every CannedObject subclass)
uncan_map = {
    CannedObject : lambda obj, g: obj.get_object(g),
}

# for use in _import_mapping: snapshot the stock mappings so import
# failures are only reported for user-added entries
_original_can_map = can_map.copy()
_original_uncan_map = uncan_map.copy()
390 _original_uncan_map = uncan_map.copy()
General Comments 0
You need to be logged in to leave comments. Login now