##// END OF EJS Templates
Merge pull request #6029 from minrk/dumps-protocol...
Thomas Kluyver -
r17067:4a94456c merge
parent child Browse files
Show More
@@ -1,198 +1,185 b''
1 """serialization utilities for apply messages
1 """serialization utilities for apply messages"""
2
2
3 Authors:
3 # Copyright (c) IPython Development Team.
4
4 # Distributed under the terms of the Modified BSD License.
5 * Min RK
6 """
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
13
14 #-----------------------------------------------------------------------------
15 # Imports
16 #-----------------------------------------------------------------------------
17
5
18 try:
6 try:
19 import cPickle
7 import cPickle
20 pickle = cPickle
8 pickle = cPickle
21 except:
9 except:
22 cPickle = None
10 cPickle = None
23 import pickle
11 import pickle
24
12
25
26 # IPython imports
13 # IPython imports
27 from IPython.utils import py3compat
14 from IPython.utils import py3compat
28 from IPython.utils.data import flatten
15 from IPython.utils.data import flatten
29 from IPython.utils.pickleutil import (
16 from IPython.utils.pickleutil import (
30 can, uncan, can_sequence, uncan_sequence, CannedObject,
17 can, uncan, can_sequence, uncan_sequence, CannedObject,
31 istype, sequence_types,
18 istype, sequence_types, PICKLE_PROTOCOL,
32 )
19 )
33
20
34 if py3compat.PY3:
21 if py3compat.PY3:
35 buffer = memoryview
22 buffer = memoryview
36
23
37 #-----------------------------------------------------------------------------
24 #-----------------------------------------------------------------------------
38 # Serialization Functions
25 # Serialization Functions
39 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
40
27
41 # default values for the thresholds:
28 # default values for the thresholds:
42 MAX_ITEMS = 64
29 MAX_ITEMS = 64
43 MAX_BYTES = 1024
30 MAX_BYTES = 1024
44
31
45 def _extract_buffers(obj, threshold=MAX_BYTES):
32 def _extract_buffers(obj, threshold=MAX_BYTES):
46 """extract buffers larger than a certain threshold"""
33 """extract buffers larger than a certain threshold"""
47 buffers = []
34 buffers = []
48 if isinstance(obj, CannedObject) and obj.buffers:
35 if isinstance(obj, CannedObject) and obj.buffers:
49 for i,buf in enumerate(obj.buffers):
36 for i,buf in enumerate(obj.buffers):
50 if len(buf) > threshold:
37 if len(buf) > threshold:
51 # buffer larger than threshold, prevent pickling
38 # buffer larger than threshold, prevent pickling
52 obj.buffers[i] = None
39 obj.buffers[i] = None
53 buffers.append(buf)
40 buffers.append(buf)
54 elif isinstance(buf, buffer):
41 elif isinstance(buf, buffer):
55 # buffer too small for separate send, coerce to bytes
42 # buffer too small for separate send, coerce to bytes
56 # because pickling buffer objects just results in broken pointers
43 # because pickling buffer objects just results in broken pointers
57 obj.buffers[i] = bytes(buf)
44 obj.buffers[i] = bytes(buf)
58 return buffers
45 return buffers
59
46
60 def _restore_buffers(obj, buffers):
47 def _restore_buffers(obj, buffers):
61 """restore buffers extracted by """
48 """restore buffers extracted by """
62 if isinstance(obj, CannedObject) and obj.buffers:
49 if isinstance(obj, CannedObject) and obj.buffers:
63 for i,buf in enumerate(obj.buffers):
50 for i,buf in enumerate(obj.buffers):
64 if buf is None:
51 if buf is None:
65 obj.buffers[i] = buffers.pop(0)
52 obj.buffers[i] = buffers.pop(0)
66
53
67 def serialize_object(obj, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
54 def serialize_object(obj, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
68 """Serialize an object into a list of sendable buffers.
55 """Serialize an object into a list of sendable buffers.
69
56
70 Parameters
57 Parameters
71 ----------
58 ----------
72
59
73 obj : object
60 obj : object
74 The object to be serialized
61 The object to be serialized
75 buffer_threshold : int
62 buffer_threshold : int
76 The threshold (in bytes) for pulling out data buffers
63 The threshold (in bytes) for pulling out data buffers
77 to avoid pickling them.
64 to avoid pickling them.
78 item_threshold : int
65 item_threshold : int
79 The maximum number of items over which canning will iterate.
66 The maximum number of items over which canning will iterate.
80 Containers (lists, dicts) larger than this will be pickled without
67 Containers (lists, dicts) larger than this will be pickled without
81 introspection.
68 introspection.
82
69
83 Returns
70 Returns
84 -------
71 -------
85 [bufs] : list of buffers representing the serialized object.
72 [bufs] : list of buffers representing the serialized object.
86 """
73 """
87 buffers = []
74 buffers = []
88 if istype(obj, sequence_types) and len(obj) < item_threshold:
75 if istype(obj, sequence_types) and len(obj) < item_threshold:
89 cobj = can_sequence(obj)
76 cobj = can_sequence(obj)
90 for c in cobj:
77 for c in cobj:
91 buffers.extend(_extract_buffers(c, buffer_threshold))
78 buffers.extend(_extract_buffers(c, buffer_threshold))
92 elif istype(obj, dict) and len(obj) < item_threshold:
79 elif istype(obj, dict) and len(obj) < item_threshold:
93 cobj = {}
80 cobj = {}
94 for k in sorted(obj):
81 for k in sorted(obj):
95 c = can(obj[k])
82 c = can(obj[k])
96 buffers.extend(_extract_buffers(c, buffer_threshold))
83 buffers.extend(_extract_buffers(c, buffer_threshold))
97 cobj[k] = c
84 cobj[k] = c
98 else:
85 else:
99 cobj = can(obj)
86 cobj = can(obj)
100 buffers.extend(_extract_buffers(cobj, buffer_threshold))
87 buffers.extend(_extract_buffers(cobj, buffer_threshold))
101
88
102 buffers.insert(0, pickle.dumps(cobj,-1))
89 buffers.insert(0, pickle.dumps(cobj, PICKLE_PROTOCOL))
103 return buffers
90 return buffers
104
91
105 def unserialize_object(buffers, g=None):
92 def unserialize_object(buffers, g=None):
106 """reconstruct an object serialized by serialize_object from data buffers.
93 """reconstruct an object serialized by serialize_object from data buffers.
107
94
108 Parameters
95 Parameters
109 ----------
96 ----------
110
97
111 bufs : list of buffers/bytes
98 bufs : list of buffers/bytes
112
99
113 g : globals to be used when uncanning
100 g : globals to be used when uncanning
114
101
115 Returns
102 Returns
116 -------
103 -------
117
104
118 (newobj, bufs) : unpacked object, and the list of remaining unused buffers.
105 (newobj, bufs) : unpacked object, and the list of remaining unused buffers.
119 """
106 """
120 bufs = list(buffers)
107 bufs = list(buffers)
121 pobj = bufs.pop(0)
108 pobj = bufs.pop(0)
122 if not isinstance(pobj, bytes):
109 if not isinstance(pobj, bytes):
123 # a zmq message
110 # a zmq message
124 pobj = bytes(pobj)
111 pobj = bytes(pobj)
125 canned = pickle.loads(pobj)
112 canned = pickle.loads(pobj)
126 if istype(canned, sequence_types) and len(canned) < MAX_ITEMS:
113 if istype(canned, sequence_types) and len(canned) < MAX_ITEMS:
127 for c in canned:
114 for c in canned:
128 _restore_buffers(c, bufs)
115 _restore_buffers(c, bufs)
129 newobj = uncan_sequence(canned, g)
116 newobj = uncan_sequence(canned, g)
130 elif istype(canned, dict) and len(canned) < MAX_ITEMS:
117 elif istype(canned, dict) and len(canned) < MAX_ITEMS:
131 newobj = {}
118 newobj = {}
132 for k in sorted(canned):
119 for k in sorted(canned):
133 c = canned[k]
120 c = canned[k]
134 _restore_buffers(c, bufs)
121 _restore_buffers(c, bufs)
135 newobj[k] = uncan(c, g)
122 newobj[k] = uncan(c, g)
136 else:
123 else:
137 _restore_buffers(canned, bufs)
124 _restore_buffers(canned, bufs)
138 newobj = uncan(canned, g)
125 newobj = uncan(canned, g)
139
126
140 return newobj, bufs
127 return newobj, bufs
141
128
142 def pack_apply_message(f, args, kwargs, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
129 def pack_apply_message(f, args, kwargs, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
143 """pack up a function, args, and kwargs to be sent over the wire
130 """pack up a function, args, and kwargs to be sent over the wire
144
131
145 Each element of args/kwargs will be canned for special treatment,
132 Each element of args/kwargs will be canned for special treatment,
146 but inspection will not go any deeper than that.
133 but inspection will not go any deeper than that.
147
134
148 Any object whose data is larger than `threshold` will not have their data copied
135 Any object whose data is larger than `threshold` will not have their data copied
149 (only numpy arrays and bytes/buffers support zero-copy)
136 (only numpy arrays and bytes/buffers support zero-copy)
150
137
151 Message will be a list of bytes/buffers of the format:
138 Message will be a list of bytes/buffers of the format:
152
139
153 [ cf, pinfo, <arg_bufs>, <kwarg_bufs> ]
140 [ cf, pinfo, <arg_bufs>, <kwarg_bufs> ]
154
141
155 With length at least two + len(args) + len(kwargs)
142 With length at least two + len(args) + len(kwargs)
156 """
143 """
157
144
158 arg_bufs = flatten(serialize_object(arg, buffer_threshold, item_threshold) for arg in args)
145 arg_bufs = flatten(serialize_object(arg, buffer_threshold, item_threshold) for arg in args)
159
146
160 kw_keys = sorted(kwargs.keys())
147 kw_keys = sorted(kwargs.keys())
161 kwarg_bufs = flatten(serialize_object(kwargs[key], buffer_threshold, item_threshold) for key in kw_keys)
148 kwarg_bufs = flatten(serialize_object(kwargs[key], buffer_threshold, item_threshold) for key in kw_keys)
162
149
163 info = dict(nargs=len(args), narg_bufs=len(arg_bufs), kw_keys=kw_keys)
150 info = dict(nargs=len(args), narg_bufs=len(arg_bufs), kw_keys=kw_keys)
164
151
165 msg = [pickle.dumps(can(f),-1)]
152 msg = [pickle.dumps(can(f), PICKLE_PROTOCOL)]
166 msg.append(pickle.dumps(info, -1))
153 msg.append(pickle.dumps(info, PICKLE_PROTOCOL))
167 msg.extend(arg_bufs)
154 msg.extend(arg_bufs)
168 msg.extend(kwarg_bufs)
155 msg.extend(kwarg_bufs)
169
156
170 return msg
157 return msg
171
158
172 def unpack_apply_message(bufs, g=None, copy=True):
159 def unpack_apply_message(bufs, g=None, copy=True):
173 """unpack f,args,kwargs from buffers packed by pack_apply_message()
160 """unpack f,args,kwargs from buffers packed by pack_apply_message()
174 Returns: original f,args,kwargs"""
161 Returns: original f,args,kwargs"""
175 bufs = list(bufs) # allow us to pop
162 bufs = list(bufs) # allow us to pop
176 assert len(bufs) >= 2, "not enough buffers!"
163 assert len(bufs) >= 2, "not enough buffers!"
177 if not copy:
164 if not copy:
178 for i in range(2):
165 for i in range(2):
179 bufs[i] = bufs[i].bytes
166 bufs[i] = bufs[i].bytes
180 f = uncan(pickle.loads(bufs.pop(0)), g)
167 f = uncan(pickle.loads(bufs.pop(0)), g)
181 info = pickle.loads(bufs.pop(0))
168 info = pickle.loads(bufs.pop(0))
182 arg_bufs, kwarg_bufs = bufs[:info['narg_bufs']], bufs[info['narg_bufs']:]
169 arg_bufs, kwarg_bufs = bufs[:info['narg_bufs']], bufs[info['narg_bufs']:]
183
170
184 args = []
171 args = []
185 for i in range(info['nargs']):
172 for i in range(info['nargs']):
186 arg, arg_bufs = unserialize_object(arg_bufs, g)
173 arg, arg_bufs = unserialize_object(arg_bufs, g)
187 args.append(arg)
174 args.append(arg)
188 args = tuple(args)
175 args = tuple(args)
189 assert not arg_bufs, "Shouldn't be any arg bufs left over"
176 assert not arg_bufs, "Shouldn't be any arg bufs left over"
190
177
191 kwargs = {}
178 kwargs = {}
192 for key in info['kw_keys']:
179 for key in info['kw_keys']:
193 kwarg, kwarg_bufs = unserialize_object(kwarg_bufs, g)
180 kwarg, kwarg_bufs = unserialize_object(kwarg_bufs, g)
194 kwargs[key] = kwarg
181 kwargs[key] = kwarg
195 assert not kwarg_bufs, "Shouldn't be any kwarg bufs left over"
182 assert not kwarg_bufs, "Shouldn't be any kwarg bufs left over"
196
183
197 return f,args,kwargs
184 return f,args,kwargs
198
185
@@ -1,856 +1,857 b''
1 """Session object for building, serializing, sending, and receiving messages in
1 """Session object for building, serializing, sending, and receiving messages in
2 IPython. The Session object supports serialization, HMAC signatures, and
2 IPython. The Session object supports serialization, HMAC signatures, and
3 metadata on messages.
3 metadata on messages.
4
4
5 Also defined here are utilities for working with Sessions:
5 Also defined here are utilities for working with Sessions:
6 * A SessionFactory to be used as a base class for configurables that work with
6 * A SessionFactory to be used as a base class for configurables that work with
7 Sessions.
7 Sessions.
8 * A Message object for convenience that allows attribute-access to the msg dict.
8 * A Message object for convenience that allows attribute-access to the msg dict.
9 """
9 """
10
10
11 # Copyright (c) IPython Development Team.
11 # Copyright (c) IPython Development Team.
12 # Distributed under the terms of the Modified BSD License.
12 # Distributed under the terms of the Modified BSD License.
13
13
14 import hashlib
14 import hashlib
15 import hmac
15 import hmac
16 import logging
16 import logging
17 import os
17 import os
18 import pprint
18 import pprint
19 import random
19 import random
20 import uuid
20 import uuid
21 from datetime import datetime
21 from datetime import datetime
22
22
23 try:
23 try:
24 import cPickle
24 import cPickle
25 pickle = cPickle
25 pickle = cPickle
26 except:
26 except:
27 cPickle = None
27 cPickle = None
28 import pickle
28 import pickle
29
29
30 import zmq
30 import zmq
31 from zmq.utils import jsonapi
31 from zmq.utils import jsonapi
32 from zmq.eventloop.ioloop import IOLoop
32 from zmq.eventloop.ioloop import IOLoop
33 from zmq.eventloop.zmqstream import ZMQStream
33 from zmq.eventloop.zmqstream import ZMQStream
34
34
35 from IPython.core.release import kernel_protocol_version, kernel_protocol_version_info
35 from IPython.core.release import kernel_protocol_version, kernel_protocol_version_info
36 from IPython.config.configurable import Configurable, LoggingConfigurable
36 from IPython.config.configurable import Configurable, LoggingConfigurable
37 from IPython.utils import io
37 from IPython.utils import io
38 from IPython.utils.importstring import import_item
38 from IPython.utils.importstring import import_item
39 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
39 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
40 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
40 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
41 iteritems)
41 iteritems)
42 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
42 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
43 DottedObjectName, CUnicode, Dict, Integer,
43 DottedObjectName, CUnicode, Dict, Integer,
44 TraitError,
44 TraitError,
45 )
45 )
46 from IPython.utils.pickleutil import PICKLE_PROTOCOL
46 from IPython.kernel.adapter import adapt
47 from IPython.kernel.adapter import adapt
47 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
48 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
48
49
49 #-----------------------------------------------------------------------------
50 #-----------------------------------------------------------------------------
50 # utility functions
51 # utility functions
51 #-----------------------------------------------------------------------------
52 #-----------------------------------------------------------------------------
52
53
53 def squash_unicode(obj):
54 def squash_unicode(obj):
54 """coerce unicode back to bytestrings."""
55 """coerce unicode back to bytestrings."""
55 if isinstance(obj,dict):
56 if isinstance(obj,dict):
56 for key in obj.keys():
57 for key in obj.keys():
57 obj[key] = squash_unicode(obj[key])
58 obj[key] = squash_unicode(obj[key])
58 if isinstance(key, unicode_type):
59 if isinstance(key, unicode_type):
59 obj[squash_unicode(key)] = obj.pop(key)
60 obj[squash_unicode(key)] = obj.pop(key)
60 elif isinstance(obj, list):
61 elif isinstance(obj, list):
61 for i,v in enumerate(obj):
62 for i,v in enumerate(obj):
62 obj[i] = squash_unicode(v)
63 obj[i] = squash_unicode(v)
63 elif isinstance(obj, unicode_type):
64 elif isinstance(obj, unicode_type):
64 obj = obj.encode('utf8')
65 obj = obj.encode('utf8')
65 return obj
66 return obj
66
67
67 #-----------------------------------------------------------------------------
68 #-----------------------------------------------------------------------------
68 # globals and defaults
69 # globals and defaults
69 #-----------------------------------------------------------------------------
70 #-----------------------------------------------------------------------------
70
71
71 # ISO8601-ify datetime objects
72 # ISO8601-ify datetime objects
72 # allow unicode
73 # allow unicode
73 # disallow nan, because it's not actually valid JSON
74 # disallow nan, because it's not actually valid JSON
74 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
75 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
75 ensure_ascii=False, allow_nan=False,
76 ensure_ascii=False, allow_nan=False,
76 )
77 )
77 json_unpacker = lambda s: jsonapi.loads(s)
78 json_unpacker = lambda s: jsonapi.loads(s)
78
79
79 pickle_packer = lambda o: pickle.dumps(squash_dates(o),-1)
80 pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
80 pickle_unpacker = pickle.loads
81 pickle_unpacker = pickle.loads
81
82
82 default_packer = json_packer
83 default_packer = json_packer
83 default_unpacker = json_unpacker
84 default_unpacker = json_unpacker
84
85
85 DELIM = b"<IDS|MSG>"
86 DELIM = b"<IDS|MSG>"
86 # singleton dummy tracker, which will always report as done
87 # singleton dummy tracker, which will always report as done
87 DONE = zmq.MessageTracker()
88 DONE = zmq.MessageTracker()
88
89
89 #-----------------------------------------------------------------------------
90 #-----------------------------------------------------------------------------
90 # Mixin tools for apps that use Sessions
91 # Mixin tools for apps that use Sessions
91 #-----------------------------------------------------------------------------
92 #-----------------------------------------------------------------------------
92
93
93 session_aliases = dict(
94 session_aliases = dict(
94 ident = 'Session.session',
95 ident = 'Session.session',
95 user = 'Session.username',
96 user = 'Session.username',
96 keyfile = 'Session.keyfile',
97 keyfile = 'Session.keyfile',
97 )
98 )
98
99
99 session_flags = {
100 session_flags = {
100 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
101 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
101 'keyfile' : '' }},
102 'keyfile' : '' }},
102 """Use HMAC digests for authentication of messages.
103 """Use HMAC digests for authentication of messages.
103 Setting this flag will generate a new UUID to use as the HMAC key.
104 Setting this flag will generate a new UUID to use as the HMAC key.
104 """),
105 """),
105 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
106 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
106 """Don't authenticate messages."""),
107 """Don't authenticate messages."""),
107 }
108 }
108
109
109 def default_secure(cfg):
110 def default_secure(cfg):
110 """Set the default behavior for a config environment to be secure.
111 """Set the default behavior for a config environment to be secure.
111
112
112 If Session.key/keyfile have not been set, set Session.key to
113 If Session.key/keyfile have not been set, set Session.key to
113 a new random UUID.
114 a new random UUID.
114 """
115 """
115
116
116 if 'Session' in cfg:
117 if 'Session' in cfg:
117 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
118 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
118 return
119 return
119 # key/keyfile not specified, generate new UUID:
120 # key/keyfile not specified, generate new UUID:
120 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
121 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
121
122
122
123
123 #-----------------------------------------------------------------------------
124 #-----------------------------------------------------------------------------
124 # Classes
125 # Classes
125 #-----------------------------------------------------------------------------
126 #-----------------------------------------------------------------------------
126
127
127 class SessionFactory(LoggingConfigurable):
128 class SessionFactory(LoggingConfigurable):
128 """The Base class for configurables that have a Session, Context, logger,
129 """The Base class for configurables that have a Session, Context, logger,
129 and IOLoop.
130 and IOLoop.
130 """
131 """
131
132
132 logname = Unicode('')
133 logname = Unicode('')
133 def _logname_changed(self, name, old, new):
134 def _logname_changed(self, name, old, new):
134 self.log = logging.getLogger(new)
135 self.log = logging.getLogger(new)
135
136
136 # not configurable:
137 # not configurable:
137 context = Instance('zmq.Context')
138 context = Instance('zmq.Context')
138 def _context_default(self):
139 def _context_default(self):
139 return zmq.Context.instance()
140 return zmq.Context.instance()
140
141
141 session = Instance('IPython.kernel.zmq.session.Session')
142 session = Instance('IPython.kernel.zmq.session.Session')
142
143
143 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
144 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
144 def _loop_default(self):
145 def _loop_default(self):
145 return IOLoop.instance()
146 return IOLoop.instance()
146
147
147 def __init__(self, **kwargs):
148 def __init__(self, **kwargs):
148 super(SessionFactory, self).__init__(**kwargs)
149 super(SessionFactory, self).__init__(**kwargs)
149
150
150 if self.session is None:
151 if self.session is None:
151 # construct the session
152 # construct the session
152 self.session = Session(**kwargs)
153 self.session = Session(**kwargs)
153
154
154
155
155 class Message(object):
156 class Message(object):
156 """A simple message object that maps dict keys to attributes.
157 """A simple message object that maps dict keys to attributes.
157
158
158 A Message can be created from a dict and a dict from a Message instance
159 A Message can be created from a dict and a dict from a Message instance
159 simply by calling dict(msg_obj)."""
160 simply by calling dict(msg_obj)."""
160
161
161 def __init__(self, msg_dict):
162 def __init__(self, msg_dict):
162 dct = self.__dict__
163 dct = self.__dict__
163 for k, v in iteritems(dict(msg_dict)):
164 for k, v in iteritems(dict(msg_dict)):
164 if isinstance(v, dict):
165 if isinstance(v, dict):
165 v = Message(v)
166 v = Message(v)
166 dct[k] = v
167 dct[k] = v
167
168
168 # Having this iterator lets dict(msg_obj) work out of the box.
169 # Having this iterator lets dict(msg_obj) work out of the box.
169 def __iter__(self):
170 def __iter__(self):
170 return iter(iteritems(self.__dict__))
171 return iter(iteritems(self.__dict__))
171
172
172 def __repr__(self):
173 def __repr__(self):
173 return repr(self.__dict__)
174 return repr(self.__dict__)
174
175
175 def __str__(self):
176 def __str__(self):
176 return pprint.pformat(self.__dict__)
177 return pprint.pformat(self.__dict__)
177
178
178 def __contains__(self, k):
179 def __contains__(self, k):
179 return k in self.__dict__
180 return k in self.__dict__
180
181
181 def __getitem__(self, k):
182 def __getitem__(self, k):
182 return self.__dict__[k]
183 return self.__dict__[k]
183
184
184
185
185 def msg_header(msg_id, msg_type, username, session):
186 def msg_header(msg_id, msg_type, username, session):
186 date = datetime.now()
187 date = datetime.now()
187 version = kernel_protocol_version
188 version = kernel_protocol_version
188 return locals()
189 return locals()
189
190
190 def extract_header(msg_or_header):
191 def extract_header(msg_or_header):
191 """Given a message or header, return the header."""
192 """Given a message or header, return the header."""
192 if not msg_or_header:
193 if not msg_or_header:
193 return {}
194 return {}
194 try:
195 try:
195 # See if msg_or_header is the entire message.
196 # See if msg_or_header is the entire message.
196 h = msg_or_header['header']
197 h = msg_or_header['header']
197 except KeyError:
198 except KeyError:
198 try:
199 try:
199 # See if msg_or_header is just the header
200 # See if msg_or_header is just the header
200 h = msg_or_header['msg_id']
201 h = msg_or_header['msg_id']
201 except KeyError:
202 except KeyError:
202 raise
203 raise
203 else:
204 else:
204 h = msg_or_header
205 h = msg_or_header
205 if not isinstance(h, dict):
206 if not isinstance(h, dict):
206 h = dict(h)
207 h = dict(h)
207 return h
208 return h
208
209
209 class Session(Configurable):
210 class Session(Configurable):
210 """Object for handling serialization and sending of messages.
211 """Object for handling serialization and sending of messages.
211
212
212 The Session object handles building messages and sending them
213 The Session object handles building messages and sending them
213 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
214 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
214 other over the network via Session objects, and only need to work with the
215 other over the network via Session objects, and only need to work with the
215 dict-based IPython message spec. The Session will handle
216 dict-based IPython message spec. The Session will handle
216 serialization/deserialization, security, and metadata.
217 serialization/deserialization, security, and metadata.
217
218
218 Sessions support configurable serialiization via packer/unpacker traits,
219 Sessions support configurable serialiization via packer/unpacker traits,
219 and signing with HMAC digests via the key/keyfile traits.
220 and signing with HMAC digests via the key/keyfile traits.
220
221
221 Parameters
222 Parameters
222 ----------
223 ----------
223
224
224 debug : bool
225 debug : bool
225 whether to trigger extra debugging statements
226 whether to trigger extra debugging statements
226 packer/unpacker : str : 'json', 'pickle' or import_string
227 packer/unpacker : str : 'json', 'pickle' or import_string
227 importstrings for methods to serialize message parts. If just
228 importstrings for methods to serialize message parts. If just
228 'json' or 'pickle', predefined JSON and pickle packers will be used.
229 'json' or 'pickle', predefined JSON and pickle packers will be used.
229 Otherwise, the entire importstring must be used.
230 Otherwise, the entire importstring must be used.
230
231
231 The functions must accept at least valid JSON input, and output *bytes*.
232 The functions must accept at least valid JSON input, and output *bytes*.
232
233
233 For example, to use msgpack:
234 For example, to use msgpack:
234 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
235 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
235 pack/unpack : callables
236 pack/unpack : callables
236 You can also set the pack/unpack callables for serialization directly.
237 You can also set the pack/unpack callables for serialization directly.
237 session : bytes
238 session : bytes
238 the ID of this Session object. The default is to generate a new UUID.
239 the ID of this Session object. The default is to generate a new UUID.
239 username : unicode
240 username : unicode
240 username added to message headers. The default is to ask the OS.
241 username added to message headers. The default is to ask the OS.
241 key : bytes
242 key : bytes
242 The key used to initialize an HMAC signature. If unset, messages
243 The key used to initialize an HMAC signature. If unset, messages
243 will not be signed or checked.
244 will not be signed or checked.
244 keyfile : filepath
245 keyfile : filepath
245 The file containing a key. If this is set, `key` will be initialized
246 The file containing a key. If this is set, `key` will be initialized
246 to the contents of the file.
247 to the contents of the file.
247
248
248 """
249 """
249
250
250 debug=Bool(False, config=True, help="""Debug output in the Session""")
251 debug=Bool(False, config=True, help="""Debug output in the Session""")
251
252
252 packer = DottedObjectName('json',config=True,
253 packer = DottedObjectName('json',config=True,
253 help="""The name of the packer for serializing messages.
254 help="""The name of the packer for serializing messages.
254 Should be one of 'json', 'pickle', or an import name
255 Should be one of 'json', 'pickle', or an import name
255 for a custom callable serializer.""")
256 for a custom callable serializer.""")
256 def _packer_changed(self, name, old, new):
257 def _packer_changed(self, name, old, new):
257 if new.lower() == 'json':
258 if new.lower() == 'json':
258 self.pack = json_packer
259 self.pack = json_packer
259 self.unpack = json_unpacker
260 self.unpack = json_unpacker
260 self.unpacker = new
261 self.unpacker = new
261 elif new.lower() == 'pickle':
262 elif new.lower() == 'pickle':
262 self.pack = pickle_packer
263 self.pack = pickle_packer
263 self.unpack = pickle_unpacker
264 self.unpack = pickle_unpacker
264 self.unpacker = new
265 self.unpacker = new
265 else:
266 else:
266 self.pack = import_item(str(new))
267 self.pack = import_item(str(new))
267
268
268 unpacker = DottedObjectName('json', config=True,
269 unpacker = DottedObjectName('json', config=True,
269 help="""The name of the unpacker for unserializing messages.
270 help="""The name of the unpacker for unserializing messages.
270 Only used with custom functions for `packer`.""")
271 Only used with custom functions for `packer`.""")
271 def _unpacker_changed(self, name, old, new):
272 def _unpacker_changed(self, name, old, new):
272 if new.lower() == 'json':
273 if new.lower() == 'json':
273 self.pack = json_packer
274 self.pack = json_packer
274 self.unpack = json_unpacker
275 self.unpack = json_unpacker
275 self.packer = new
276 self.packer = new
276 elif new.lower() == 'pickle':
277 elif new.lower() == 'pickle':
277 self.pack = pickle_packer
278 self.pack = pickle_packer
278 self.unpack = pickle_unpacker
279 self.unpack = pickle_unpacker
279 self.packer = new
280 self.packer = new
280 else:
281 else:
281 self.unpack = import_item(str(new))
282 self.unpack = import_item(str(new))
session = CUnicode(u'', config=True,
    help="""The UUID identifying this session.""")
def _session_default(self):
    # Generate a fresh UUID and eagerly cache its ascii-bytes form.
    u = unicode_type(uuid.uuid4())
    self.bsession = u.encode('ascii')
    return u

def _session_changed(self, name, old, new):
    # Keep the bytes form in sync with the unicode session id.
    self.bsession = self.session.encode('ascii')

# bsession is the session as bytes
bsession = CBytes(b'')

username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
    help="""Username for the Session. Default is your system username.""",
    config=True)

metadata = Dict({}, config=True,
    help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")

# if 0, no adapting to do.
adapt_version = Integer(0)
305
306
306 # message signature related traits:
307 # message signature related traits:
307
308
308 key = CBytes(b'', config=True,
309 key = CBytes(b'', config=True,
309 help="""execution key, for extra authentication.""")
310 help="""execution key, for extra authentication.""")
310 def _key_changed(self):
311 def _key_changed(self):
311 self._new_auth()
312 self._new_auth()
312
313
313 signature_scheme = Unicode('hmac-sha256', config=True,
314 signature_scheme = Unicode('hmac-sha256', config=True,
314 help="""The digest scheme used to construct the message signatures.
315 help="""The digest scheme used to construct the message signatures.
315 Must have the form 'hmac-HASH'.""")
316 Must have the form 'hmac-HASH'.""")
316 def _signature_scheme_changed(self, name, old, new):
317 def _signature_scheme_changed(self, name, old, new):
317 if not new.startswith('hmac-'):
318 if not new.startswith('hmac-'):
318 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
319 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
319 hash_name = new.split('-', 1)[1]
320 hash_name = new.split('-', 1)[1]
320 try:
321 try:
321 self.digest_mod = getattr(hashlib, hash_name)
322 self.digest_mod = getattr(hashlib, hash_name)
322 except AttributeError:
323 except AttributeError:
323 raise TraitError("hashlib has no such attribute: %s" % hash_name)
324 raise TraitError("hashlib has no such attribute: %s" % hash_name)
324 self._new_auth()
325 self._new_auth()
325
326
326 digest_mod = Any()
327 digest_mod = Any()
327 def _digest_mod_default(self):
328 def _digest_mod_default(self):
328 return hashlib.sha256
329 return hashlib.sha256
329
330
330 auth = Instance(hmac.HMAC)
331 auth = Instance(hmac.HMAC)
331
332
332 def _new_auth(self):
333 def _new_auth(self):
333 if self.key:
334 if self.key:
334 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
335 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
335 else:
336 else:
336 self.auth = None
337 self.auth = None
337
338
338 digest_history = Set()
339 digest_history = Set()
339 digest_history_size = Integer(2**16, config=True,
340 digest_history_size = Integer(2**16, config=True,
340 help="""The maximum number of digests to remember.
341 help="""The maximum number of digests to remember.
341
342
342 The digest history will be culled when it exceeds this value.
343 The digest history will be culled when it exceeds this value.
343 """
344 """
344 )
345 )
345
346
346 keyfile = Unicode('', config=True,
347 keyfile = Unicode('', config=True,
347 help="""path to file containing execution key.""")
348 help="""path to file containing execution key.""")
348 def _keyfile_changed(self, name, old, new):
349 def _keyfile_changed(self, name, old, new):
349 with open(new, 'rb') as f:
350 with open(new, 'rb') as f:
350 self.key = f.read().strip()
351 self.key = f.read().strip()
351
352
352 # for protecting against sends from forks
353 # for protecting against sends from forks
353 pid = Integer()
354 pid = Integer()
354
355
355 # serialization traits:
356 # serialization traits:
356
357
357 pack = Any(default_packer) # the actual packer function
358 pack = Any(default_packer) # the actual packer function
358 def _pack_changed(self, name, old, new):
359 def _pack_changed(self, name, old, new):
359 if not callable(new):
360 if not callable(new):
360 raise TypeError("packer must be callable, not %s"%type(new))
361 raise TypeError("packer must be callable, not %s"%type(new))
361
362
362 unpack = Any(default_unpacker) # the actual packer function
363 unpack = Any(default_unpacker) # the actual packer function
363 def _unpack_changed(self, name, old, new):
364 def _unpack_changed(self, name, old, new):
364 # unpacker is not checked - it is assumed to be
365 # unpacker is not checked - it is assumed to be
365 if not callable(new):
366 if not callable(new):
366 raise TypeError("unpacker must be callable, not %s"%type(new))
367 raise TypeError("unpacker must be callable, not %s"%type(new))
367
368
368 # thresholds:
369 # thresholds:
369 copy_threshold = Integer(2**16, config=True,
370 copy_threshold = Integer(2**16, config=True,
370 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
371 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
371 buffer_threshold = Integer(MAX_BYTES, config=True,
372 buffer_threshold = Integer(MAX_BYTES, config=True,
372 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
373 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
373 item_threshold = Integer(MAX_ITEMS, config=True,
374 item_threshold = Integer(MAX_ITEMS, config=True,
374 help="""The maximum number of items for a container to be introspected for custom serialization.
375 help="""The maximum number of items for a container to be introspected for custom serialization.
375 Containers larger than this are pickled outright.
376 Containers larger than this are pickled outright.
376 """
377 """
377 )
378 )
378
379
379
380
def __init__(self, **kwargs):
    """create a Session object

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts.  If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output
        *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization
        directly.
    session : unicode (must be ascii)
        the ID of this Session object.  The default is to generate a new
        UUID.
    bsession : bytes
        The session as bytes
    username : unicode
        username added to message headers.  The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature.  If unset, messages
        will not be signed or checked.
    signature_scheme : str
        The message digest scheme. Currently must be of the form 'hmac-HASH',
        where 'HASH' is a hashing function available in Python's hashlib.
        The default is 'hmac-sha256'.
        This is ignored if 'key' is empty.
    keyfile : filepath
        The file containing a key.  If this is set, `key` will be
        initialized to the contents of the file.
    """
    super(Session, self).__init__(**kwargs)
    self._check_packers()
    # cache the packed form of an empty dict, used when content is None
    self.none = self.pack({})
    # ensure self._session_default() if necessary, so bsession is defined:
    self.session
    # remember the creating process, to detect sends from forks
    self.pid = os.getpid()
@property
def msg_id(self):
    """always return new uuid"""
    return str(uuid.uuid4())
432 def _check_packers(self):
433 def _check_packers(self):
433 """check packers for datetime support."""
434 """check packers for datetime support."""
434 pack = self.pack
435 pack = self.pack
435 unpack = self.unpack
436 unpack = self.unpack
436
437
437 # check simple serialization
438 # check simple serialization
438 msg = dict(a=[1,'hi'])
439 msg = dict(a=[1,'hi'])
439 try:
440 try:
440 packed = pack(msg)
441 packed = pack(msg)
441 except Exception as e:
442 except Exception as e:
442 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
443 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
443 if self.packer == 'json':
444 if self.packer == 'json':
444 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
445 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
445 else:
446 else:
446 jsonmsg = ""
447 jsonmsg = ""
447 raise ValueError(
448 raise ValueError(
448 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
449 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
449 )
450 )
450
451
451 # ensure packed message is bytes
452 # ensure packed message is bytes
452 if not isinstance(packed, bytes):
453 if not isinstance(packed, bytes):
453 raise ValueError("message packed to %r, but bytes are required"%type(packed))
454 raise ValueError("message packed to %r, but bytes are required"%type(packed))
454
455
455 # check that unpack is pack's inverse
456 # check that unpack is pack's inverse
456 try:
457 try:
457 unpacked = unpack(packed)
458 unpacked = unpack(packed)
458 assert unpacked == msg
459 assert unpacked == msg
459 except Exception as e:
460 except Exception as e:
460 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
461 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
461 if self.packer == 'json':
462 if self.packer == 'json':
462 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
463 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
463 else:
464 else:
464 jsonmsg = ""
465 jsonmsg = ""
465 raise ValueError(
466 raise ValueError(
466 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
467 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
467 )
468 )
468
469
469 # check datetime support
470 # check datetime support
470 msg = dict(t=datetime.now())
471 msg = dict(t=datetime.now())
471 try:
472 try:
472 unpacked = unpack(pack(msg))
473 unpacked = unpack(pack(msg))
473 if isinstance(unpacked['t'], datetime):
474 if isinstance(unpacked['t'], datetime):
474 raise ValueError("Shouldn't deserialize to datetime")
475 raise ValueError("Shouldn't deserialize to datetime")
475 except Exception:
476 except Exception:
476 self.pack = lambda o: pack(squash_dates(o))
477 self.pack = lambda o: pack(squash_dates(o))
477 self.unpack = lambda s: unpack(s)
478 self.unpack = lambda s: unpack(s)
478
479
def msg_header(self, msg_type):
    """Build a fresh header dict for a message of type ``msg_type``.

    Delegates to the module-level ``msg_header`` function (which this
    method deliberately shadows), with a new msg_id each call.
    """
    return msg_header(self.msg_id, msg_type, self.username, self.session)
def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
    """Return the nested message dict.

    This format is different from what is sent over the wire. The
    serialize/unserialize methods converts this nested message dict to the wire
    format, which is a list of message parts.
    """
    msg = {}
    header = self.msg_header(msg_type) if header is None else header
    msg['header'] = header
    # msg_id/msg_type are duplicated at the top level for convenience
    msg['msg_id'] = header['msg_id']
    msg['msg_type'] = header['msg_type']
    msg['parent_header'] = {} if parent is None else extract_header(parent)
    msg['content'] = {} if content is None else content
    # start from the session-default metadata, then overlay per-message keys
    msg['metadata'] = self.metadata.copy()
    if metadata is not None:
        msg['metadata'].update(metadata)
    return msg
def sign(self, msg_list):
    """Sign a message with HMAC digest. If no auth, return b''.

    Parameters
    ----------
    msg_list : list
        The [p_header,p_parent,p_content] part of the message list.
    """
    if self.auth is None:
        return b''
    # copy so the shared HMAC object is never mutated
    h = self.auth.copy()
    for m in msg_list:
        h.update(m)
    return str_to_bytes(h.hexdigest())
def serialize(self, msg, ident=None):
    """Serialize the message components to bytes.

    This is roughly the inverse of unserialize. The serialize/unserialize
    methods work with full message lists, whereas pack/unpack work with
    the individual message parts in the message list.

    Parameters
    ----------
    msg : dict or Message
        The nested message dict as returned by the self.msg method.

    Returns
    -------
    msg_list : list
        The list of bytes objects to be sent with the format::

            [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
             p_metadata, p_content, buffer1, buffer2, ...]

        In this list, the ``p_*`` entities are the packed or serialized
        versions, so if JSON is used, these are utf8 encoded JSON strings.
    """
    content = msg.get('content', {})
    if content is None:
        content = self.none
    elif isinstance(content, dict):
        content = self.pack(content)
    elif isinstance(content, bytes):
        # content is already packed, as in a relayed message
        pass
    elif isinstance(content, unicode_type):
        # should be bytes, but JSON often spits out unicode
        content = content.encode('utf8')
    else:
        raise TypeError("Content incorrect type: %s"%type(content))

    real_message = [self.pack(msg['header']),
                    self.pack(msg['parent_header']),
                    self.pack(msg['metadata']),
                    content,
    ]

    to_send = []

    if isinstance(ident, list):
        # accept list of idents
        to_send.extend(ident)
    elif ident is not None:
        to_send.append(ident)
    to_send.append(DELIM)

    # signature covers only the four packed message parts
    signature = self.sign(real_message)
    to_send.append(signature)

    to_send.extend(real_message)

    return to_send
def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
         buffers=None, track=False, header=None, metadata=None):
    """Build and send a message via stream or socket.

    The message format used by this function internally is as follows:

    [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
     buffer1,buffer2,...]

    The serialize/unserialize methods convert the nested message dict into this
    format.

    Parameters
    ----------

    stream : zmq.Socket or ZMQStream
        The socket-like object used to send the data.
    msg_or_type : str or Message/dict
        Normally, msg_or_type will be a msg_type unless a message is being
        sent more than once. If a header is supplied, this can be set to
        None and the msg_type will be pulled from the header.

    content : dict or None
        The content of the message (ignored if msg_or_type is a message).
    header : dict or None
        The header dict for the message (ignored if msg_to_type is a message).
    parent : Message or dict or None
        The parent or parent header describing the parent of this message
        (ignored if msg_or_type is a message).
    ident : bytes or list of bytes
        The zmq.IDENTITY routing path.
    metadata : dict or None
        The metadata describing the message
    buffers : list or None
        The already-serialized buffers to be appended to the message.
    track : bool
        Whether to track.  Only for use with Sockets, because ZMQStream
        objects cannot track messages.


    Returns
    -------
    msg : dict
        The constructed message.
    """
    if not isinstance(stream, zmq.Socket):
        # ZMQStreams and dummy sockets do not support tracking.
        track = False

    if isinstance(msg_or_type, (Message, dict)):
        # We got a Message or message dict, not a msg_type so don't
        # build a new Message.
        msg = msg_or_type
    else:
        msg = self.msg(msg_or_type, content=content, parent=parent,
                       header=header, metadata=metadata)
    if not os.getpid() == self.pid:
        # refuse to send from a forked child: the zmq context is not
        # fork-safe and the message would interleave with the parent's
        io.rprint("WARNING: attempted to send message from fork")
        io.rprint(msg)
        return
    buffers = [] if buffers is None else buffers
    if self.adapt_version:
        msg = adapt(msg, self.adapt_version)
    to_send = self.serialize(msg, ident)
    to_send.extend(buffers)
    # zero-copy only pays off for large frames
    longest = max([ len(s) for s in to_send ])
    copy = (longest < self.copy_threshold)

    if buffers and track and not copy:
        # only really track when we are doing zero-copy buffers
        tracker = stream.send_multipart(to_send, copy=False, track=True)
    else:
        # use dummy tracker, which will be done immediately
        tracker = DONE
        stream.send_multipart(to_send, copy=copy)

    if self.debug:
        pprint.pprint(msg)
        pprint.pprint(to_send)
        pprint.pprint(buffers)

    msg['tracker'] = tracker

    return msg
def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
    """Send a raw message via ident path.

    This method is used to send a already serialized message.

    Parameters
    ----------
    stream : ZMQStream or Socket
        The ZMQ stream or socket to use for sending the message.
    msg_list : list
        The serialized list of messages to send. This only includes the
        [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
        the message.
    ident : ident or list
        A single ident or a list of idents to use in sending.
    """
    to_send = []
    if isinstance(ident, bytes):
        ident = [ident]
    if ident is not None:
        to_send.extend(ident)

    to_send.append(DELIM)
    # re-sign: msg_list carries no signature of its own
    to_send.append(self.sign(msg_list))
    to_send.extend(msg_list)
    stream.send_multipart(to_send, flags, copy=copy)
def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
    """Receive and unpack a message.

    Parameters
    ----------
    socket : ZMQStream or Socket
        The socket or stream to use in receiving.

    Returns
    -------
    [idents], msg
        [idents] is a list of idents and msg is a nested message dict of
        same format as self.msg returns.
    """
    if isinstance(socket, ZMQStream):
        socket = socket.socket
    try:
        msg_list = socket.recv_multipart(mode, copy=copy)
    except zmq.ZMQError as e:
        if e.errno == zmq.EAGAIN:
            # We can convert EAGAIN to None as we know in this case
            # recv_multipart won't return None.
            return None,None
        else:
            raise
    # split multipart message into identity list and message dict
    # invalid large messages can cause very expensive string comparisons
    idents, msg_list = self.feed_identities(msg_list, copy)
    try:
        return idents, self.unserialize(msg_list, content=content, copy=copy)
    except Exception as e:
        # TODO: handle it
        raise e
def feed_identities(self, msg_list, copy=True):
    """Split the identities from the rest of the message.

    Feed until DELIM is reached, then return the prefix as idents and
    remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
    but that would be silly.

    Parameters
    ----------
    msg_list : a list of Message or bytes objects
        The message to be split.
    copy : bool
        flag determining whether the arguments are bytes or Messages

    Returns
    -------
    (idents, msg_list) : two lists
        idents will always be a list of bytes, each of which is a ZMQ
        identity. msg_list will be a list of bytes or zmq.Messages of the
        form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
        should be unpackable/unserializable via self.unserialize at this
        point.
    """
    if copy:
        # frames are plain bytes: list.index finds the delimiter directly
        idx = msg_list.index(DELIM)
        return msg_list[:idx], msg_list[idx+1:]
    else:
        # frames are zmq.Message objects: compare their .bytes payloads
        failed = True
        for idx,m in enumerate(msg_list):
            if m.bytes == DELIM:
                failed = False
                break
        if failed:
            raise ValueError("DELIM not in msg_list")
        idents, msg_list = msg_list[:idx], msg_list[idx+1:]
        return [m.bytes for m in idents], msg_list
758 def _add_digest(self, signature):
759 def _add_digest(self, signature):
759 """add a digest to history to protect against replay attacks"""
760 """add a digest to history to protect against replay attacks"""
760 if self.digest_history_size == 0:
761 if self.digest_history_size == 0:
761 # no history, never add digests
762 # no history, never add digests
762 return
763 return
763
764
764 self.digest_history.add(signature)
765 self.digest_history.add(signature)
765 if len(self.digest_history) > self.digest_history_size:
766 if len(self.digest_history) > self.digest_history_size:
766 # threshold reached, cull 10%
767 # threshold reached, cull 10%
767 self._cull_digest_history()
768 self._cull_digest_history()
768
769
769 def _cull_digest_history(self):
770 def _cull_digest_history(self):
770 """cull the digest history
771 """cull the digest history
771
772
772 Removes a randomly selected 10% of the digest history
773 Removes a randomly selected 10% of the digest history
773 """
774 """
774 current = len(self.digest_history)
775 current = len(self.digest_history)
775 n_to_cull = max(int(current // 10), current - self.digest_history_size)
776 n_to_cull = max(int(current // 10), current - self.digest_history_size)
776 if n_to_cull >= current:
777 if n_to_cull >= current:
777 self.digest_history = set()
778 self.digest_history = set()
778 return
779 return
779 to_cull = random.sample(self.digest_history, n_to_cull)
780 to_cull = random.sample(self.digest_history, n_to_cull)
780 self.digest_history.difference_update(to_cull)
781 self.digest_history.difference_update(to_cull)
781
782
def unserialize(self, msg_list, content=True, copy=True):
    """Unserialize a msg_list to a nested message dict.

    This is roughly the inverse of serialize. The serialize/unserialize
    methods work with full message lists, whereas pack/unpack work with
    the individual message parts in the message list.

    Parameters
    ----------
    msg_list : list of bytes or Message objects
        The list of message parts of the form [HMAC,p_header,p_parent,
        p_metadata,p_content,buffer1,buffer2,...].
    content : bool (True)
        Whether to unpack the content dict (True), or leave it packed
        (False).
    copy : bool (True)
        Whether to return the bytes (True), or the non-copying Message
        object in each place (False).

    Returns
    -------
    msg : dict
        The nested message dict with top-level keys [header, parent_header,
        content, buffers].

    Raises
    ------
    ValueError
        if the message is unsigned, its signature was already seen
        (replay), or the HMAC check fails.
    TypeError
        if fewer than five message parts are supplied.
    """
    minlen = 5
    message = {}
    if not copy:
        # non-copying recv: the framing parts are zmq Message objects;
        # extract their bytes for signing/unpacking
        for i in range(minlen):
            msg_list[i] = msg_list[i].bytes
    if self.auth is not None:
        signature = msg_list[0]
        if not signature:
            raise ValueError("Unsigned Message")
        # replay protection: reject any signature we have seen before
        if signature in self.digest_history:
            raise ValueError("Duplicate Signature: %r" % signature)
        self._add_digest(signature)
        # recompute the HMAC over header/parent/metadata/content
        check = self.sign(msg_list[1:5])
        if not signature == check:
            raise ValueError("Invalid Signature: %r" % signature)
    if not len(msg_list) >= minlen:
        raise TypeError("malformed message, must have at least %i elements"%minlen)
    header = self.unpack(msg_list[1])
    message['header'] = extract_dates(header)
    message['msg_id'] = header['msg_id']
    message['msg_type'] = header['msg_type']
    message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
    message['metadata'] = self.unpack(msg_list[3])
    if content:
        message['content'] = self.unpack(msg_list[4])
    else:
        # leave content packed; caller may unpack lazily
        message['content'] = msg_list[4]

    message['buffers'] = msg_list[5:]
    # adapt to the current message-protocol version
    return adapt(message)
840
841
def test_msg2obj():
    """Smoke-test Message: attribute access, item access, and dict round-trip."""
    source = dict(x=1)
    msg = Message(source)
    assert msg.x == source['x']

    source['y'] = dict(z=1)
    msg = Message(source)
    assert msg.y.z == source['y']['z']

    outer, inner = 'y', 'z'
    assert msg[outer][inner] == source[outer][inner]

    # converting back to a plain dict preserves the values
    roundtrip = dict(msg)
    assert roundtrip['x'] == source['x']
    assert roundtrip['y']['z'] == source['y']['z']
@@ -1,420 +1,425 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Pickle related utilities. Perhaps this should be called 'can'."""
2 """Pickle related utilities. Perhaps this should be called 'can'."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 import copy
7 import copy
8 import logging
8 import logging
9 import sys
9 import sys
10 from types import FunctionType
10 from types import FunctionType
11
11
12 try:
12 try:
13 import cPickle as pickle
13 import cPickle as pickle
14 except ImportError:
14 except ImportError:
15 import pickle
15 import pickle
16
16
17 from . import codeutil # This registers a hook when it's imported
17 from . import codeutil # This registers a hook when it's imported
18 from . import py3compat
18 from . import py3compat
19 from .importstring import import_item
19 from .importstring import import_item
20 from .py3compat import string_types, iteritems
20 from .py3compat import string_types, iteritems
21
21
22 from IPython.config import Application
22 from IPython.config import Application
23 from IPython.utils.log import get_logger
23 from IPython.utils.log import get_logger
24
24
if py3compat.PY3:
    # Python 3 has no builtin ``buffer``; memoryview covers its uses here.
    buffer = memoryview
    # all Python 3 classes are instances of ``type``
    class_type = type
else:
    # Python 2 old-style classes have their own type, so isinstance
    # checks against "a class" must accept both
    from types import ClassType
    class_type = (type, ClassType)
31
31
# Pickle protocol used everywhere in this module: the interpreter's
# default protocol where it is exposed (Python 3), otherwise the highest
# protocol this Python supports.
PICKLE_PROTOCOL = getattr(pickle, 'DEFAULT_PROTOCOL', pickle.HIGHEST_PROTOCOL)
36
def _get_cell_type(a=None):
    """the type of a closure cell doesn't seem to be importable,
    so just create one
    """
    # a lambda closing over ``a`` produces exactly one closure cell
    probe = lambda: a
    return type(py3compat.get_closure(probe)[0])

# the concrete closure-cell type, used as a key in can_map
cell_type = _get_cell_type()
41
46
42 #-------------------------------------------------------------------------------
47 #-------------------------------------------------------------------------------
43 # Functions
48 # Functions
44 #-------------------------------------------------------------------------------
49 #-------------------------------------------------------------------------------
45
50
46
51
def use_dill():
    """use dill to expand serialization support

    adds support for object methods and closures to serialization.
    """
    import dill  # the import itself installs most of dill's hooks

    global pickle
    # dill doesn't work with cPickle, so make this module (and the
    # kernel's serialize module, if present) use dill in place of pickle
    pickle = dill

    try:
        from IPython.kernel.zmq import serialize
    except ImportError:
        pass
    else:
        serialize.pickle = dill

    # let dill handle functions itself; drop our special-case canner
    can_map.pop(FunctionType, None)
70
75
def use_cloudpickle():
    """use cloudpickle to expand serialization support

    adds support for object methods and closures to serialization.
    """
    from cloud.serialization import cloudpickle

    global pickle
    # route this module's pickling (and the kernel's serialize module,
    # if present) through cloudpickle
    pickle = cloudpickle

    try:
        from IPython.kernel.zmq import serialize
    except ImportError:
        pass
    else:
        serialize.pickle = cloudpickle

    # let cloudpickle handle functions itself; drop our special-case canner
    can_map.pop(FunctionType, None)
90
95
91
96
92 #-------------------------------------------------------------------------------
97 #-------------------------------------------------------------------------------
93 # Classes
98 # Classes
94 #-------------------------------------------------------------------------------
99 #-------------------------------------------------------------------------------
95
100
96
101
class CannedObject(object):
    """Generic wrapper that makes an object safe to pickle by canning
    selected attributes and optionally running a hook on uncan."""

    def __init__(self, obj, keys=None, hook=None):
        """can an object for safe pickling

        Parameters
        ==========

        obj:
            The object to be canned
        keys: list (optional)
            list of attribute names that will be explicitly canned / uncanned
        hook: callable (optional)
            An optional extra callable,
            which can do additional processing of the uncanned object.

        large data may be offloaded into the buffers list,
        used for zero-copy transfers.
        """
        # avoid the shared-mutable-default pitfall of ``keys=[]``
        self.keys = keys if keys is not None else []
        # shallow copy so canned attributes don't clobber the original
        self.obj = copy.copy(obj)
        self.hook = can(hook)
        for key in self.keys:
            setattr(self.obj, key, can(getattr(obj, key)))

        self.buffers = []

    def get_object(self, g=None):
        """Reconstruct the object, uncanning attributes listed in ``keys``.

        Parameters
        ==========
        g: dict (optional)
            namespace in which uncanned references are resolved
        """
        if g is None:
            g = {}
        obj = self.obj
        for key in self.keys:
            setattr(obj, key, uncan(getattr(obj, key), g))

        if self.hook:
            self.hook = uncan(self.hook, g)
            # let the hook finish reconstruction (e.g. dependency checks)
            self.hook(obj, g)
        return self.obj
134
139
135
140
class Reference(CannedObject):
    """object for wrapping a remote reference by name."""

    def __init__(self, name):
        if not isinstance(name, string_types):
            raise TypeError("illegal name: %r"%name)
        self.name = name
        self.buffers = []

    def __repr__(self):
        return "<Reference: %r>"%self.name

    def get_object(self, g=None):
        # resolve the stored name in the supplied namespace
        namespace = {} if g is None else g
        return eval(self.name, namespace)
152
157
153
158
class CannedCell(CannedObject):
    """Can a closure cell"""

    def __init__(self, cell):
        self.cell_contents = can(cell.cell_contents)

    def get_object(self, g=None):
        contents = uncan(self.cell_contents, g)
        # build a fresh closure over the uncanned value and return its cell
        def inner():
            return contents
        return py3compat.get_closure(inner)[0]
164
169
165
170
class CannedFunction(CannedObject):
    """Can a plain function: its code object, canned defaults, and closure."""

    def __init__(self, f):
        self._check_type(f)
        self.code = f.__code__
        self.defaults = (
            [can(default) for default in f.__defaults__]
            if f.__defaults__ else None
        )
        closure = py3compat.get_closure(f)
        self.closure = tuple(can(cell) for cell in closure) if closure else None

        self.module = f.__module__ or '__main__'
        self.__name__ = f.__name__
        self.buffers = []

    def _check_type(self, obj):
        assert isinstance(obj, FunctionType), "Not a function type"

    def get_object(self, g=None):
        # try to load function back into its module:
        if not self.module.startswith('__'):
            __import__(self.module)
            g = sys.modules[self.module].__dict__

        if g is None:
            g = {}
        defaults = (
            tuple(uncan(cfd, g) for cfd in self.defaults)
            if self.defaults else None
        )
        closure = (
            tuple(uncan(cell, g) for cell in self.closure)
            if self.closure else None
        )
        # rebuild the function from its parts in the chosen namespace
        return FunctionType(self.code, g, self.__name__, defaults, closure)
207
212
class CannedClass(CannedObject):
    """Can an interactively-defined class: its dict and its parent classes."""

    def __init__(self, cls):
        self._check_type(cls)
        self.name = cls.__name__
        self.old_style = not isinstance(cls, type)
        self._canned_dict = {}
        for attr, value in cls.__dict__.items():
            # these two can't be meaningfully canned or re-applied
            if attr not in ('__weakref__', '__dict__'):
                self._canned_dict[attr] = can(value)
        # old-style classes have no mro()
        mro = [] if self.old_style else cls.mro()

        self.parents = [can(parent) for parent in mro[1:]]
        self.buffers = []

    def _check_type(self, obj):
        assert isinstance(obj, class_type), "Not a class type"

    def get_object(self, g=None):
        parents = tuple(uncan(p, g) for p in self.parents)
        return type(self.name, parents, uncan_dict(self._canned_dict, g=g))
232
237
class CannedArray(CannedObject):
    """Can a numpy array, preferring a zero-copy buffer over pickling."""

    def __init__(self, obj):
        from numpy import ascontiguousarray
        self.shape = obj.shape
        self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
        # fall back to pickling whenever the raw-buffer approach can't work:
        # empty arrays, object dtype, or structured dtypes with object fields
        self.pickled = bool(
            sum(obj.shape) == 0
            or obj.dtype == 'O'
            or (obj.dtype.fields
                and any(dt == 'O' for dt, sz in obj.dtype.fields.values()))
        )
        if self.pickled:
            # just pickle it
            self.buffers = [pickle.dumps(obj, PICKLE_PROTOCOL)]
        else:
            # ensure contiguous memory so the buffer view is valid
            obj = ascontiguousarray(obj, dtype=None)
            self.buffers = [buffer(obj)]

    def get_object(self, g=None):
        from numpy import frombuffer
        data = self.buffers[0]
        if self.pickled:
            # no shape stored: the whole array was pickled
            return pickle.loads(data)
        return frombuffer(data, dtype=self.dtype).reshape(self.shape)
262
267
263
268
class CannedBytes(CannedObject):
    """Can a bytes object by shipping it as a raw buffer."""
    # callable used to rewrap the buffer on uncan
    wrap = bytes

    def __init__(self, obj):
        self.buffers = [obj]

    def get_object(self, g=None):
        return self.wrap(self.buffers[0])
272
277
class CannedBuffer(CannedBytes):
    """Like CannedBytes, but rewraps as a buffer/memoryview on uncan.

    BUG FIX: this was previously declared with ``def`` instead of
    ``class``, which made CannedBuffer a function taking one argument and
    returning None — so canning a buffer produced None instead of a
    CannedBuffer instance.
    """
    wrap = buffer
275
280
276 #-------------------------------------------------------------------------------
281 #-------------------------------------------------------------------------------
277 # Functions
282 # Functions
278 #-------------------------------------------------------------------------------
283 #-------------------------------------------------------------------------------
279
284
def _import_mapping(mapping, original=None):
    """import any string-keys in a type mapping

    """
    log = get_logger()
    log.debug("Importing canning map")
    for key, value in list(mapping.items()):
        if not isinstance(key, string_types):
            continue
        try:
            cls = import_item(key)
        except Exception:
            if original and key not in original:
                # only message on user-added classes
                log.error("canning class not importable: %r", key, exc_info=True)
            mapping.pop(key)
        else:
            # replace the string key with the imported class
            mapping[cls] = mapping.pop(key)
297
302
def istype(obj, check):
    """like isinstance(obj, check), but strict

    This won't catch subclasses.
    """
    # normalize to a tuple of candidate types, then test exact type identity
    candidates = check if isinstance(check, tuple) else (check,)
    return any(type(obj) is cls for cls in candidates)
310
315
def can(obj):
    """prepare an object for pickling"""
    import_needed = False

    for cls, canner in iteritems(can_map):
        if isinstance(cls, string_types):
            # a key is still an unimported dotted name:
            # resolve the whole map below, then retry
            import_needed = True
            break
        if istype(obj, cls):
            return canner(obj)

    if import_needed:
        # perform can_map imports, then try again
        # this will usually only happen once
        _import_mapping(can_map, _original_can_map)
        return can(obj)

    return obj
330
335
def can_class(obj):
    """Can classes defined interactively in __main__; pass others through."""
    if isinstance(obj, class_type) and obj.__module__ == '__main__':
        return CannedClass(obj)
    return obj
336
341
def can_dict(obj):
    """can the *values* of a dict"""
    if not istype(obj, dict):
        return obj
    # keys pass through untouched; only values are canned
    return dict((k, can(v)) for k, v in iteritems(obj))
346
351
# exact sequence types whose elements are canned/uncanned recursively
sequence_types = (list, tuple, set)


def can_sequence(obj):
    """can the elements of a sequence"""
    if not istype(obj, sequence_types):
        return obj
    # rebuild the same concrete type with canned elements
    return type(obj)(can(element) for element in obj)
356
361
def uncan(obj, g=None):
    """invert canning"""
    import_needed = False

    for cls, uncanner in iteritems(uncan_map):
        if isinstance(cls, string_types):
            # a key is still an unimported dotted name:
            # resolve the whole map below, then retry
            import_needed = True
            break
        if isinstance(obj, cls):
            return uncanner(obj, g)

    if import_needed:
        # perform uncan_map imports, then try again
        # this will usually only happen once
        _import_mapping(uncan_map, _original_uncan_map)
        return uncan(obj, g)

    return obj
375
380
def uncan_dict(obj, g=None):
    """uncan the values of a dict, resolving references in namespace g"""
    if not istype(obj, dict):
        return obj
    return dict((k, uncan(v, g)) for k, v in iteritems(obj))
384
389
def uncan_sequence(obj, g=None):
    """uncan the elements of a sequence, resolving references in namespace g"""
    if not istype(obj, sequence_types):
        return obj
    # rebuild the same concrete type with uncanned elements
    return type(obj)(uncan(item, g) for item in obj)
391
396
392 def _uncan_dependent_hook(dep, g=None):
397 def _uncan_dependent_hook(dep, g=None):
393 dep.check_dependency()
398 dep.check_dependency()
394
399
def can_dependent(obj):
    """can a dependent object: explicitly can its 'f' and 'df' attributes,
    and re-check its dependencies when uncanned"""
    return CannedObject(obj, hook=_uncan_dependent_hook, keys=('f', 'df'))
397
402
398 #-------------------------------------------------------------------------------
403 #-------------------------------------------------------------------------------
399 # API dictionaries
404 # API dictionaries
400 #-------------------------------------------------------------------------------
405 #-------------------------------------------------------------------------------
401
406
# These dicts can be extended for custom serialization of new objects

# type -> canner callable. String keys are dotted import paths that are
# resolved lazily by _import_mapping the first time can() encounters one,
# so importing this module doesn't pull in numpy or IPython.parallel.
can_map = {
    'IPython.parallel.dependent' : can_dependent,
    'numpy.ndarray' : CannedArray,
    FunctionType : CannedFunction,
    bytes : CannedBytes,
    buffer : CannedBuffer,
    cell_type : CannedCell,
    class_type : can_class,
}

# canned type -> uncanner callable; every CannedObject subclass is
# handled by the single get_object entry (uncan uses isinstance)
uncan_map = {
    CannedObject : lambda obj, g: obj.get_object(g),
}

# for use in _import_mapping: snapshots taken before any user additions,
# so import failures are only reported for user-added entries
_original_can_map = can_map.copy()
_original_uncan_map = uncan_map.copy()
General Comments 0
You need to be logged in to leave comments. Login now