##// END OF EJS Templates
add pickleutil.PICKLE_PROTOCOL...
MinRK -
Show More
@@ -1,198 +1,185 b''
1 """serialization utilities for apply messages
1 """serialization utilities for apply messages"""
2
2
3 Authors:
3 # Copyright (c) IPython Development Team.
4
4 # Distributed under the terms of the Modified BSD License.
5 * Min RK
6 """
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
13
14 #-----------------------------------------------------------------------------
15 # Imports
16 #-----------------------------------------------------------------------------
17
5
18 try:
6 try:
19 import cPickle
7 import cPickle
20 pickle = cPickle
8 pickle = cPickle
21 except:
9 except:
22 cPickle = None
10 cPickle = None
23 import pickle
11 import pickle
24
12
25
26 # IPython imports
13 # IPython imports
27 from IPython.utils import py3compat
14 from IPython.utils import py3compat
28 from IPython.utils.data import flatten
15 from IPython.utils.data import flatten
29 from IPython.utils.pickleutil import (
16 from IPython.utils.pickleutil import (
30 can, uncan, can_sequence, uncan_sequence, CannedObject,
17 can, uncan, can_sequence, uncan_sequence, CannedObject,
31 istype, sequence_types,
18 istype, sequence_types, PICKLE_PROTOCOL,
32 )
19 )
33
20
34 if py3compat.PY3:
21 if py3compat.PY3:
35 buffer = memoryview
22 buffer = memoryview
36
23
37 #-----------------------------------------------------------------------------
24 #-----------------------------------------------------------------------------
38 # Serialization Functions
25 # Serialization Functions
39 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
40
27
41 # default values for the thresholds:
28 # default values for the thresholds:
42 MAX_ITEMS = 64
29 MAX_ITEMS = 64
43 MAX_BYTES = 1024
30 MAX_BYTES = 1024
44
31
45 def _extract_buffers(obj, threshold=MAX_BYTES):
32 def _extract_buffers(obj, threshold=MAX_BYTES):
46 """extract buffers larger than a certain threshold"""
33 """extract buffers larger than a certain threshold"""
47 buffers = []
34 buffers = []
48 if isinstance(obj, CannedObject) and obj.buffers:
35 if isinstance(obj, CannedObject) and obj.buffers:
49 for i,buf in enumerate(obj.buffers):
36 for i,buf in enumerate(obj.buffers):
50 if len(buf) > threshold:
37 if len(buf) > threshold:
51 # buffer larger than threshold, prevent pickling
38 # buffer larger than threshold, prevent pickling
52 obj.buffers[i] = None
39 obj.buffers[i] = None
53 buffers.append(buf)
40 buffers.append(buf)
54 elif isinstance(buf, buffer):
41 elif isinstance(buf, buffer):
55 # buffer too small for separate send, coerce to bytes
42 # buffer too small for separate send, coerce to bytes
56 # because pickling buffer objects just results in broken pointers
43 # because pickling buffer objects just results in broken pointers
57 obj.buffers[i] = bytes(buf)
44 obj.buffers[i] = bytes(buf)
58 return buffers
45 return buffers
59
46
60 def _restore_buffers(obj, buffers):
47 def _restore_buffers(obj, buffers):
61 """restore buffers extracted by """
48 """restore buffers extracted by """
62 if isinstance(obj, CannedObject) and obj.buffers:
49 if isinstance(obj, CannedObject) and obj.buffers:
63 for i,buf in enumerate(obj.buffers):
50 for i,buf in enumerate(obj.buffers):
64 if buf is None:
51 if buf is None:
65 obj.buffers[i] = buffers.pop(0)
52 obj.buffers[i] = buffers.pop(0)
66
53
67 def serialize_object(obj, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
54 def serialize_object(obj, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
68 """Serialize an object into a list of sendable buffers.
55 """Serialize an object into a list of sendable buffers.
69
56
70 Parameters
57 Parameters
71 ----------
58 ----------
72
59
73 obj : object
60 obj : object
74 The object to be serialized
61 The object to be serialized
75 buffer_threshold : int
62 buffer_threshold : int
76 The threshold (in bytes) for pulling out data buffers
63 The threshold (in bytes) for pulling out data buffers
77 to avoid pickling them.
64 to avoid pickling them.
78 item_threshold : int
65 item_threshold : int
79 The maximum number of items over which canning will iterate.
66 The maximum number of items over which canning will iterate.
80 Containers (lists, dicts) larger than this will be pickled without
67 Containers (lists, dicts) larger than this will be pickled without
81 introspection.
68 introspection.
82
69
83 Returns
70 Returns
84 -------
71 -------
85 [bufs] : list of buffers representing the serialized object.
72 [bufs] : list of buffers representing the serialized object.
86 """
73 """
87 buffers = []
74 buffers = []
88 if istype(obj, sequence_types) and len(obj) < item_threshold:
75 if istype(obj, sequence_types) and len(obj) < item_threshold:
89 cobj = can_sequence(obj)
76 cobj = can_sequence(obj)
90 for c in cobj:
77 for c in cobj:
91 buffers.extend(_extract_buffers(c, buffer_threshold))
78 buffers.extend(_extract_buffers(c, buffer_threshold))
92 elif istype(obj, dict) and len(obj) < item_threshold:
79 elif istype(obj, dict) and len(obj) < item_threshold:
93 cobj = {}
80 cobj = {}
94 for k in sorted(obj):
81 for k in sorted(obj):
95 c = can(obj[k])
82 c = can(obj[k])
96 buffers.extend(_extract_buffers(c, buffer_threshold))
83 buffers.extend(_extract_buffers(c, buffer_threshold))
97 cobj[k] = c
84 cobj[k] = c
98 else:
85 else:
99 cobj = can(obj)
86 cobj = can(obj)
100 buffers.extend(_extract_buffers(cobj, buffer_threshold))
87 buffers.extend(_extract_buffers(cobj, buffer_threshold))
101
88
102 buffers.insert(0, pickle.dumps(cobj,-1))
89 buffers.insert(0, pickle.dumps(cobj, PICKLE_PROTOCOL))
103 return buffers
90 return buffers
104
91
105 def unserialize_object(buffers, g=None):
92 def unserialize_object(buffers, g=None):
106 """reconstruct an object serialized by serialize_object from data buffers.
93 """reconstruct an object serialized by serialize_object from data buffers.
107
94
108 Parameters
95 Parameters
109 ----------
96 ----------
110
97
111 bufs : list of buffers/bytes
98 bufs : list of buffers/bytes
112
99
113 g : globals to be used when uncanning
100 g : globals to be used when uncanning
114
101
115 Returns
102 Returns
116 -------
103 -------
117
104
118 (newobj, bufs) : unpacked object, and the list of remaining unused buffers.
105 (newobj, bufs) : unpacked object, and the list of remaining unused buffers.
119 """
106 """
120 bufs = list(buffers)
107 bufs = list(buffers)
121 pobj = bufs.pop(0)
108 pobj = bufs.pop(0)
122 if not isinstance(pobj, bytes):
109 if not isinstance(pobj, bytes):
123 # a zmq message
110 # a zmq message
124 pobj = bytes(pobj)
111 pobj = bytes(pobj)
125 canned = pickle.loads(pobj)
112 canned = pickle.loads(pobj)
126 if istype(canned, sequence_types) and len(canned) < MAX_ITEMS:
113 if istype(canned, sequence_types) and len(canned) < MAX_ITEMS:
127 for c in canned:
114 for c in canned:
128 _restore_buffers(c, bufs)
115 _restore_buffers(c, bufs)
129 newobj = uncan_sequence(canned, g)
116 newobj = uncan_sequence(canned, g)
130 elif istype(canned, dict) and len(canned) < MAX_ITEMS:
117 elif istype(canned, dict) and len(canned) < MAX_ITEMS:
131 newobj = {}
118 newobj = {}
132 for k in sorted(canned):
119 for k in sorted(canned):
133 c = canned[k]
120 c = canned[k]
134 _restore_buffers(c, bufs)
121 _restore_buffers(c, bufs)
135 newobj[k] = uncan(c, g)
122 newobj[k] = uncan(c, g)
136 else:
123 else:
137 _restore_buffers(canned, bufs)
124 _restore_buffers(canned, bufs)
138 newobj = uncan(canned, g)
125 newobj = uncan(canned, g)
139
126
140 return newobj, bufs
127 return newobj, bufs
141
128
142 def pack_apply_message(f, args, kwargs, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
129 def pack_apply_message(f, args, kwargs, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
143 """pack up a function, args, and kwargs to be sent over the wire
130 """pack up a function, args, and kwargs to be sent over the wire
144
131
145 Each element of args/kwargs will be canned for special treatment,
132 Each element of args/kwargs will be canned for special treatment,
146 but inspection will not go any deeper than that.
133 but inspection will not go any deeper than that.
147
134
148 Any object whose data is larger than `threshold` will not have their data copied
135 Any object whose data is larger than `threshold` will not have their data copied
149 (only numpy arrays and bytes/buffers support zero-copy)
136 (only numpy arrays and bytes/buffers support zero-copy)
150
137
151 Message will be a list of bytes/buffers of the format:
138 Message will be a list of bytes/buffers of the format:
152
139
153 [ cf, pinfo, <arg_bufs>, <kwarg_bufs> ]
140 [ cf, pinfo, <arg_bufs>, <kwarg_bufs> ]
154
141
155 With length at least two + len(args) + len(kwargs)
142 With length at least two + len(args) + len(kwargs)
156 """
143 """
157
144
158 arg_bufs = flatten(serialize_object(arg, buffer_threshold, item_threshold) for arg in args)
145 arg_bufs = flatten(serialize_object(arg, buffer_threshold, item_threshold) for arg in args)
159
146
160 kw_keys = sorted(kwargs.keys())
147 kw_keys = sorted(kwargs.keys())
161 kwarg_bufs = flatten(serialize_object(kwargs[key], buffer_threshold, item_threshold) for key in kw_keys)
148 kwarg_bufs = flatten(serialize_object(kwargs[key], buffer_threshold, item_threshold) for key in kw_keys)
162
149
163 info = dict(nargs=len(args), narg_bufs=len(arg_bufs), kw_keys=kw_keys)
150 info = dict(nargs=len(args), narg_bufs=len(arg_bufs), kw_keys=kw_keys)
164
151
165 msg = [pickle.dumps(can(f),-1)]
152 msg = [pickle.dumps(can(f), PICKLE_PROTOCOL)]
166 msg.append(pickle.dumps(info, -1))
153 msg.append(pickle.dumps(info, PICKLE_PROTOCOL))
167 msg.extend(arg_bufs)
154 msg.extend(arg_bufs)
168 msg.extend(kwarg_bufs)
155 msg.extend(kwarg_bufs)
169
156
170 return msg
157 return msg
171
158
172 def unpack_apply_message(bufs, g=None, copy=True):
159 def unpack_apply_message(bufs, g=None, copy=True):
173 """unpack f,args,kwargs from buffers packed by pack_apply_message()
160 """unpack f,args,kwargs from buffers packed by pack_apply_message()
174 Returns: original f,args,kwargs"""
161 Returns: original f,args,kwargs"""
175 bufs = list(bufs) # allow us to pop
162 bufs = list(bufs) # allow us to pop
176 assert len(bufs) >= 2, "not enough buffers!"
163 assert len(bufs) >= 2, "not enough buffers!"
177 if not copy:
164 if not copy:
178 for i in range(2):
165 for i in range(2):
179 bufs[i] = bufs[i].bytes
166 bufs[i] = bufs[i].bytes
180 f = uncan(pickle.loads(bufs.pop(0)), g)
167 f = uncan(pickle.loads(bufs.pop(0)), g)
181 info = pickle.loads(bufs.pop(0))
168 info = pickle.loads(bufs.pop(0))
182 arg_bufs, kwarg_bufs = bufs[:info['narg_bufs']], bufs[info['narg_bufs']:]
169 arg_bufs, kwarg_bufs = bufs[:info['narg_bufs']], bufs[info['narg_bufs']:]
183
170
184 args = []
171 args = []
185 for i in range(info['nargs']):
172 for i in range(info['nargs']):
186 arg, arg_bufs = unserialize_object(arg_bufs, g)
173 arg, arg_bufs = unserialize_object(arg_bufs, g)
187 args.append(arg)
174 args.append(arg)
188 args = tuple(args)
175 args = tuple(args)
189 assert not arg_bufs, "Shouldn't be any arg bufs left over"
176 assert not arg_bufs, "Shouldn't be any arg bufs left over"
190
177
191 kwargs = {}
178 kwargs = {}
192 for key in info['kw_keys']:
179 for key in info['kw_keys']:
193 kwarg, kwarg_bufs = unserialize_object(kwarg_bufs, g)
180 kwarg, kwarg_bufs = unserialize_object(kwarg_bufs, g)
194 kwargs[key] = kwarg
181 kwargs[key] = kwarg
195 assert not kwarg_bufs, "Shouldn't be any kwarg bufs left over"
182 assert not kwarg_bufs, "Shouldn't be any kwarg bufs left over"
196
183
197 return f,args,kwargs
184 return f,args,kwargs
198
185
@@ -1,856 +1,857 b''
1 """Session object for building, serializing, sending, and receiving messages in
1 """Session object for building, serializing, sending, and receiving messages in
2 IPython. The Session object supports serialization, HMAC signatures, and
2 IPython. The Session object supports serialization, HMAC signatures, and
3 metadata on messages.
3 metadata on messages.
4
4
5 Also defined here are utilities for working with Sessions:
5 Also defined here are utilities for working with Sessions:
6 * A SessionFactory to be used as a base class for configurables that work with
6 * A SessionFactory to be used as a base class for configurables that work with
7 Sessions.
7 Sessions.
8 * A Message object for convenience that allows attribute-access to the msg dict.
8 * A Message object for convenience that allows attribute-access to the msg dict.
9 """
9 """
10
10
11 # Copyright (c) IPython Development Team.
11 # Copyright (c) IPython Development Team.
12 # Distributed under the terms of the Modified BSD License.
12 # Distributed under the terms of the Modified BSD License.
13
13
14 import hashlib
14 import hashlib
15 import hmac
15 import hmac
16 import logging
16 import logging
17 import os
17 import os
18 import pprint
18 import pprint
19 import random
19 import random
20 import uuid
20 import uuid
21 from datetime import datetime
21 from datetime import datetime
22
22
23 try:
23 try:
24 import cPickle
24 import cPickle
25 pickle = cPickle
25 pickle = cPickle
26 except:
26 except:
27 cPickle = None
27 cPickle = None
28 import pickle
28 import pickle
29
29
30 import zmq
30 import zmq
31 from zmq.utils import jsonapi
31 from zmq.utils import jsonapi
32 from zmq.eventloop.ioloop import IOLoop
32 from zmq.eventloop.ioloop import IOLoop
33 from zmq.eventloop.zmqstream import ZMQStream
33 from zmq.eventloop.zmqstream import ZMQStream
34
34
35 from IPython.core.release import kernel_protocol_version, kernel_protocol_version_info
35 from IPython.core.release import kernel_protocol_version, kernel_protocol_version_info
36 from IPython.config.configurable import Configurable, LoggingConfigurable
36 from IPython.config.configurable import Configurable, LoggingConfigurable
37 from IPython.utils import io
37 from IPython.utils import io
38 from IPython.utils.importstring import import_item
38 from IPython.utils.importstring import import_item
39 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
39 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
40 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
40 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
41 iteritems)
41 iteritems)
42 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
42 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
43 DottedObjectName, CUnicode, Dict, Integer,
43 DottedObjectName, CUnicode, Dict, Integer,
44 TraitError,
44 TraitError,
45 )
45 )
46 from IPython.utils.pickleutil import PICKLE_PROTOCOL
46 from IPython.kernel.adapter import adapt
47 from IPython.kernel.adapter import adapt
47 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
48 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
48
49
49 #-----------------------------------------------------------------------------
50 #-----------------------------------------------------------------------------
50 # utility functions
51 # utility functions
51 #-----------------------------------------------------------------------------
52 #-----------------------------------------------------------------------------
52
53
53 def squash_unicode(obj):
54 def squash_unicode(obj):
54 """coerce unicode back to bytestrings."""
55 """coerce unicode back to bytestrings."""
55 if isinstance(obj,dict):
56 if isinstance(obj,dict):
56 for key in obj.keys():
57 for key in obj.keys():
57 obj[key] = squash_unicode(obj[key])
58 obj[key] = squash_unicode(obj[key])
58 if isinstance(key, unicode_type):
59 if isinstance(key, unicode_type):
59 obj[squash_unicode(key)] = obj.pop(key)
60 obj[squash_unicode(key)] = obj.pop(key)
60 elif isinstance(obj, list):
61 elif isinstance(obj, list):
61 for i,v in enumerate(obj):
62 for i,v in enumerate(obj):
62 obj[i] = squash_unicode(v)
63 obj[i] = squash_unicode(v)
63 elif isinstance(obj, unicode_type):
64 elif isinstance(obj, unicode_type):
64 obj = obj.encode('utf8')
65 obj = obj.encode('utf8')
65 return obj
66 return obj
66
67
67 #-----------------------------------------------------------------------------
68 #-----------------------------------------------------------------------------
68 # globals and defaults
69 # globals and defaults
69 #-----------------------------------------------------------------------------
70 #-----------------------------------------------------------------------------
70
71
71 # ISO8601-ify datetime objects
72 # ISO8601-ify datetime objects
72 # allow unicode
73 # allow unicode
73 # disallow nan, because it's not actually valid JSON
74 # disallow nan, because it's not actually valid JSON
74 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
75 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
75 ensure_ascii=False, allow_nan=False,
76 ensure_ascii=False, allow_nan=False,
76 )
77 )
77 json_unpacker = lambda s: jsonapi.loads(s)
78 json_unpacker = lambda s: jsonapi.loads(s)
78
79
79 pickle_packer = lambda o: pickle.dumps(squash_dates(o),-1)
80 pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
80 pickle_unpacker = pickle.loads
81 pickle_unpacker = pickle.loads
81
82
82 default_packer = json_packer
83 default_packer = json_packer
83 default_unpacker = json_unpacker
84 default_unpacker = json_unpacker
84
85
85 DELIM = b"<IDS|MSG>"
86 DELIM = b"<IDS|MSG>"
86 # singleton dummy tracker, which will always report as done
87 # singleton dummy tracker, which will always report as done
87 DONE = zmq.MessageTracker()
88 DONE = zmq.MessageTracker()
88
89
89 #-----------------------------------------------------------------------------
90 #-----------------------------------------------------------------------------
90 # Mixin tools for apps that use Sessions
91 # Mixin tools for apps that use Sessions
91 #-----------------------------------------------------------------------------
92 #-----------------------------------------------------------------------------
92
93
93 session_aliases = dict(
94 session_aliases = dict(
94 ident = 'Session.session',
95 ident = 'Session.session',
95 user = 'Session.username',
96 user = 'Session.username',
96 keyfile = 'Session.keyfile',
97 keyfile = 'Session.keyfile',
97 )
98 )
98
99
99 session_flags = {
100 session_flags = {
100 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
101 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
101 'keyfile' : '' }},
102 'keyfile' : '' }},
102 """Use HMAC digests for authentication of messages.
103 """Use HMAC digests for authentication of messages.
103 Setting this flag will generate a new UUID to use as the HMAC key.
104 Setting this flag will generate a new UUID to use as the HMAC key.
104 """),
105 """),
105 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
106 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
106 """Don't authenticate messages."""),
107 """Don't authenticate messages."""),
107 }
108 }
108
109
109 def default_secure(cfg):
110 def default_secure(cfg):
110 """Set the default behavior for a config environment to be secure.
111 """Set the default behavior for a config environment to be secure.
111
112
112 If Session.key/keyfile have not been set, set Session.key to
113 If Session.key/keyfile have not been set, set Session.key to
113 a new random UUID.
114 a new random UUID.
114 """
115 """
115
116
116 if 'Session' in cfg:
117 if 'Session' in cfg:
117 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
118 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
118 return
119 return
119 # key/keyfile not specified, generate new UUID:
120 # key/keyfile not specified, generate new UUID:
120 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
121 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
121
122
122
123
123 #-----------------------------------------------------------------------------
124 #-----------------------------------------------------------------------------
124 # Classes
125 # Classes
125 #-----------------------------------------------------------------------------
126 #-----------------------------------------------------------------------------
126
127
127 class SessionFactory(LoggingConfigurable):
128 class SessionFactory(LoggingConfigurable):
128 """The Base class for configurables that have a Session, Context, logger,
129 """The Base class for configurables that have a Session, Context, logger,
129 and IOLoop.
130 and IOLoop.
130 """
131 """
131
132
132 logname = Unicode('')
133 logname = Unicode('')
133 def _logname_changed(self, name, old, new):
134 def _logname_changed(self, name, old, new):
134 self.log = logging.getLogger(new)
135 self.log = logging.getLogger(new)
135
136
136 # not configurable:
137 # not configurable:
137 context = Instance('zmq.Context')
138 context = Instance('zmq.Context')
138 def _context_default(self):
139 def _context_default(self):
139 return zmq.Context.instance()
140 return zmq.Context.instance()
140
141
141 session = Instance('IPython.kernel.zmq.session.Session')
142 session = Instance('IPython.kernel.zmq.session.Session')
142
143
143 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
144 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
144 def _loop_default(self):
145 def _loop_default(self):
145 return IOLoop.instance()
146 return IOLoop.instance()
146
147
147 def __init__(self, **kwargs):
148 def __init__(self, **kwargs):
148 super(SessionFactory, self).__init__(**kwargs)
149 super(SessionFactory, self).__init__(**kwargs)
149
150
150 if self.session is None:
151 if self.session is None:
151 # construct the session
152 # construct the session
152 self.session = Session(**kwargs)
153 self.session = Session(**kwargs)
153
154
154
155
155 class Message(object):
156 class Message(object):
156 """A simple message object that maps dict keys to attributes.
157 """A simple message object that maps dict keys to attributes.
157
158
158 A Message can be created from a dict and a dict from a Message instance
159 A Message can be created from a dict and a dict from a Message instance
159 simply by calling dict(msg_obj)."""
160 simply by calling dict(msg_obj)."""
160
161
161 def __init__(self, msg_dict):
162 def __init__(self, msg_dict):
162 dct = self.__dict__
163 dct = self.__dict__
163 for k, v in iteritems(dict(msg_dict)):
164 for k, v in iteritems(dict(msg_dict)):
164 if isinstance(v, dict):
165 if isinstance(v, dict):
165 v = Message(v)
166 v = Message(v)
166 dct[k] = v
167 dct[k] = v
167
168
168 # Having this iterator lets dict(msg_obj) work out of the box.
169 # Having this iterator lets dict(msg_obj) work out of the box.
169 def __iter__(self):
170 def __iter__(self):
170 return iter(iteritems(self.__dict__))
171 return iter(iteritems(self.__dict__))
171
172
172 def __repr__(self):
173 def __repr__(self):
173 return repr(self.__dict__)
174 return repr(self.__dict__)
174
175
175 def __str__(self):
176 def __str__(self):
176 return pprint.pformat(self.__dict__)
177 return pprint.pformat(self.__dict__)
177
178
178 def __contains__(self, k):
179 def __contains__(self, k):
179 return k in self.__dict__
180 return k in self.__dict__
180
181
181 def __getitem__(self, k):
182 def __getitem__(self, k):
182 return self.__dict__[k]
183 return self.__dict__[k]
183
184
184
185
185 def msg_header(msg_id, msg_type, username, session):
186 def msg_header(msg_id, msg_type, username, session):
186 date = datetime.now()
187 date = datetime.now()
187 version = kernel_protocol_version
188 version = kernel_protocol_version
188 return locals()
189 return locals()
189
190
190 def extract_header(msg_or_header):
191 def extract_header(msg_or_header):
191 """Given a message or header, return the header."""
192 """Given a message or header, return the header."""
192 if not msg_or_header:
193 if not msg_or_header:
193 return {}
194 return {}
194 try:
195 try:
195 # See if msg_or_header is the entire message.
196 # See if msg_or_header is the entire message.
196 h = msg_or_header['header']
197 h = msg_or_header['header']
197 except KeyError:
198 except KeyError:
198 try:
199 try:
199 # See if msg_or_header is just the header
200 # See if msg_or_header is just the header
200 h = msg_or_header['msg_id']
201 h = msg_or_header['msg_id']
201 except KeyError:
202 except KeyError:
202 raise
203 raise
203 else:
204 else:
204 h = msg_or_header
205 h = msg_or_header
205 if not isinstance(h, dict):
206 if not isinstance(h, dict):
206 h = dict(h)
207 h = dict(h)
207 return h
208 return h
208
209
209 class Session(Configurable):
210 class Session(Configurable):
210 """Object for handling serialization and sending of messages.
211 """Object for handling serialization and sending of messages.
211
212
212 The Session object handles building messages and sending them
213 The Session object handles building messages and sending them
213 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
214 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
214 other over the network via Session objects, and only need to work with the
215 other over the network via Session objects, and only need to work with the
215 dict-based IPython message spec. The Session will handle
216 dict-based IPython message spec. The Session will handle
216 serialization/deserialization, security, and metadata.
217 serialization/deserialization, security, and metadata.
217
218
218 Sessions support configurable serialiization via packer/unpacker traits,
219 Sessions support configurable serialiization via packer/unpacker traits,
219 and signing with HMAC digests via the key/keyfile traits.
220 and signing with HMAC digests via the key/keyfile traits.
220
221
221 Parameters
222 Parameters
222 ----------
223 ----------
223
224
224 debug : bool
225 debug : bool
225 whether to trigger extra debugging statements
226 whether to trigger extra debugging statements
226 packer/unpacker : str : 'json', 'pickle' or import_string
227 packer/unpacker : str : 'json', 'pickle' or import_string
227 importstrings for methods to serialize message parts. If just
228 importstrings for methods to serialize message parts. If just
228 'json' or 'pickle', predefined JSON and pickle packers will be used.
229 'json' or 'pickle', predefined JSON and pickle packers will be used.
229 Otherwise, the entire importstring must be used.
230 Otherwise, the entire importstring must be used.
230
231
231 The functions must accept at least valid JSON input, and output *bytes*.
232 The functions must accept at least valid JSON input, and output *bytes*.
232
233
233 For example, to use msgpack:
234 For example, to use msgpack:
234 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
235 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
235 pack/unpack : callables
236 pack/unpack : callables
236 You can also set the pack/unpack callables for serialization directly.
237 You can also set the pack/unpack callables for serialization directly.
237 session : bytes
238 session : bytes
238 the ID of this Session object. The default is to generate a new UUID.
239 the ID of this Session object. The default is to generate a new UUID.
239 username : unicode
240 username : unicode
240 username added to message headers. The default is to ask the OS.
241 username added to message headers. The default is to ask the OS.
241 key : bytes
242 key : bytes
242 The key used to initialize an HMAC signature. If unset, messages
243 The key used to initialize an HMAC signature. If unset, messages
243 will not be signed or checked.
244 will not be signed or checked.
244 keyfile : filepath
245 keyfile : filepath
245 The file containing a key. If this is set, `key` will be initialized
246 The file containing a key. If this is set, `key` will be initialized
246 to the contents of the file.
247 to the contents of the file.
247
248
248 """
249 """
249
250
250 debug=Bool(False, config=True, help="""Debug output in the Session""")
251 debug=Bool(False, config=True, help="""Debug output in the Session""")
251
252
252 packer = DottedObjectName('json',config=True,
253 packer = DottedObjectName('json',config=True,
253 help="""The name of the packer for serializing messages.
254 help="""The name of the packer for serializing messages.
254 Should be one of 'json', 'pickle', or an import name
255 Should be one of 'json', 'pickle', or an import name
255 for a custom callable serializer.""")
256 for a custom callable serializer.""")
256 def _packer_changed(self, name, old, new):
257 def _packer_changed(self, name, old, new):
257 if new.lower() == 'json':
258 if new.lower() == 'json':
258 self.pack = json_packer
259 self.pack = json_packer
259 self.unpack = json_unpacker
260 self.unpack = json_unpacker
260 self.unpacker = new
261 self.unpacker = new
261 elif new.lower() == 'pickle':
262 elif new.lower() == 'pickle':
262 self.pack = pickle_packer
263 self.pack = pickle_packer
263 self.unpack = pickle_unpacker
264 self.unpack = pickle_unpacker
264 self.unpacker = new
265 self.unpacker = new
265 else:
266 else:
266 self.pack = import_item(str(new))
267 self.pack = import_item(str(new))
267
268
268 unpacker = DottedObjectName('json', config=True,
269 unpacker = DottedObjectName('json', config=True,
269 help="""The name of the unpacker for unserializing messages.
270 help="""The name of the unpacker for unserializing messages.
270 Only used with custom functions for `packer`.""")
271 Only used with custom functions for `packer`.""")
271 def _unpacker_changed(self, name, old, new):
272 def _unpacker_changed(self, name, old, new):
272 if new.lower() == 'json':
273 if new.lower() == 'json':
273 self.pack = json_packer
274 self.pack = json_packer
274 self.unpack = json_unpacker
275 self.unpack = json_unpacker
275 self.packer = new
276 self.packer = new
276 elif new.lower() == 'pickle':
277 elif new.lower() == 'pickle':
277 self.pack = pickle_packer
278 self.pack = pickle_packer
278 self.unpack = pickle_unpacker
279 self.unpack = pickle_unpacker
279 self.packer = new
280 self.packer = new
280 else:
281 else:
281 self.unpack = import_item(str(new))
282 self.unpack = import_item(str(new))
282
283
283 session = CUnicode(u'', config=True,
284 session = CUnicode(u'', config=True,
284 help="""The UUID identifying this session.""")
285 help="""The UUID identifying this session.""")
285 def _session_default(self):
286 def _session_default(self):
286 u = unicode_type(uuid.uuid4())
287 u = unicode_type(uuid.uuid4())
287 self.bsession = u.encode('ascii')
288 self.bsession = u.encode('ascii')
288 return u
289 return u
289
290
290 def _session_changed(self, name, old, new):
291 def _session_changed(self, name, old, new):
291 self.bsession = self.session.encode('ascii')
292 self.bsession = self.session.encode('ascii')
292
293
293 # bsession is the session as bytes
294 # bsession is the session as bytes
294 bsession = CBytes(b'')
295 bsession = CBytes(b'')
295
296
# Username stamped into outgoing message headers; defaults to $USER.
username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
    help="""Username for the Session. Default is your system username.""",
    config=True)

# Default metadata merged into every message's metadata dict.
metadata = Dict({}, config=True,
    help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")

# if 0, no adapting to do.
adapt_version = Integer(0)
305
306
# message signature related traits:

# HMAC key; when empty, messages are neither signed nor checked.
key = CBytes(b'', config=True,
    help="""execution key, for extra authentication.""")
310 def _key_changed(self):
311 def _key_changed(self):
311 self._new_auth()
312 self._new_auth()
312
313
# Signature scheme name; validated by _signature_scheme_changed below.
signature_scheme = Unicode('hmac-sha256', config=True,
    help="""The digest scheme used to construct the message signatures.
    Must have the form 'hmac-HASH'.""")
316 def _signature_scheme_changed(self, name, old, new):
317 def _signature_scheme_changed(self, name, old, new):
317 if not new.startswith('hmac-'):
318 if not new.startswith('hmac-'):
318 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
319 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
319 hash_name = new.split('-', 1)[1]
320 hash_name = new.split('-', 1)[1]
320 try:
321 try:
321 self.digest_mod = getattr(hashlib, hash_name)
322 self.digest_mod = getattr(hashlib, hash_name)
322 except AttributeError:
323 except AttributeError:
323 raise TraitError("hashlib has no such attribute: %s" % hash_name)
324 raise TraitError("hashlib has no such attribute: %s" % hash_name)
324 self._new_auth()
325 self._new_auth()
325
326
# Hash constructor used for signing; kept in sync with signature_scheme.
digest_mod = Any()
327 def _digest_mod_default(self):
328 def _digest_mod_default(self):
328 return hashlib.sha256
329 return hashlib.sha256
329
330
# HMAC object used to sign messages (None when no key — see _new_auth).
auth = Instance(hmac.HMAC)
331
332
332 def _new_auth(self):
333 def _new_auth(self):
333 if self.key:
334 if self.key:
334 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
335 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
335 else:
336 else:
336 self.auth = None
337 self.auth = None
337
338
# Signatures already seen, used for replay protection (see _add_digest).
digest_history = Set()
digest_history_size = Integer(2**16, config=True,
    help="""The maximum number of digests to remember.

    The digest history will be culled when it exceeds this value.
    """
)
345
346
# Setting this loads `key` from the named file (see _keyfile_changed).
keyfile = Unicode('', config=True,
    help="""path to file containing execution key.""")
348 def _keyfile_changed(self, name, old, new):
349 def _keyfile_changed(self, name, old, new):
349 with open(new, 'rb') as f:
350 with open(new, 'rb') as f:
350 self.key = f.read().strip()
351 self.key = f.read().strip()
351
352
# for protecting against sends from forks (checked in send())
pid = Integer()
354
355
# serialization traits:

pack = Any(default_packer)  # the actual packer function
358 def _pack_changed(self, name, old, new):
359 def _pack_changed(self, name, old, new):
359 if not callable(new):
360 if not callable(new):
360 raise TypeError("packer must be callable, not %s"%type(new))
361 raise TypeError("packer must be callable, not %s"%type(new))
361
362
unpack = Any(default_unpacker)  # the actual unpacker function
363 def _unpack_changed(self, name, old, new):
364 def _unpack_changed(self, name, old, new):
364 # unpacker is not checked - it is assumed to be
365 # unpacker is not checked - it is assumed to be
365 if not callable(new):
366 if not callable(new):
366 raise TypeError("unpacker must be callable, not %s"%type(new))
367 raise TypeError("unpacker must be callable, not %s"%type(new))
367
368
# thresholds:
copy_threshold = Integer(2**16, config=True,
    help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
buffer_threshold = Integer(MAX_BYTES, config=True,
    help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
item_threshold = Integer(MAX_ITEMS, config=True,
    help="""The maximum number of items for a container to be introspected for custom serialization.
    Containers larger than this are pickled outright.
    """
)
378
379
379
380
def __init__(self, **kwargs):
    """Create a Session object.

    Parameters
    ----------
    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts. If just
        'json' or 'pickle', predefined JSON and pickle packers will be
        used. Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output
        *bytes*. For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization
        directly.
    session : unicode (must be ascii)
        the ID of this Session object. The default is to generate a new
        UUID.
    bsession : bytes
        The session as bytes
    username : unicode
        username added to message headers. The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature. If unset, messages
        will not be signed or checked.
    signature_scheme : str
        The message digest scheme. Currently must be of the form
        'hmac-HASH', where 'HASH' is a hashing function available in
        Python's hashlib. The default is 'hmac-sha256'.
        This is ignored if 'key' is empty.
    keyfile : filepath
        The file containing a key. If this is set, `key` will be
        initialized to the contents of the file.
    """
    super(Session, self).__init__(**kwargs)
    # validate/wrap pack & unpack before caching any packed values
    self._check_packers()
    # precomputed packed empty dict, substituted for None content
    self.none = self.pack({})
    # touch self.session to trigger _session_default so bsession is defined
    self.session
    # remember the creating pid so send() can detect forked children
    self.pid = os.getpid()
426
427
@property
def msg_id(self):
    """Return a fresh uuid4 string on every access."""
    return str(uuid.uuid4())
431
432
432 def _check_packers(self):
433 def _check_packers(self):
433 """check packers for datetime support."""
434 """check packers for datetime support."""
434 pack = self.pack
435 pack = self.pack
435 unpack = self.unpack
436 unpack = self.unpack
436
437
437 # check simple serialization
438 # check simple serialization
438 msg = dict(a=[1,'hi'])
439 msg = dict(a=[1,'hi'])
439 try:
440 try:
440 packed = pack(msg)
441 packed = pack(msg)
441 except Exception as e:
442 except Exception as e:
442 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
443 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
443 if self.packer == 'json':
444 if self.packer == 'json':
444 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
445 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
445 else:
446 else:
446 jsonmsg = ""
447 jsonmsg = ""
447 raise ValueError(
448 raise ValueError(
448 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
449 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
449 )
450 )
450
451
451 # ensure packed message is bytes
452 # ensure packed message is bytes
452 if not isinstance(packed, bytes):
453 if not isinstance(packed, bytes):
453 raise ValueError("message packed to %r, but bytes are required"%type(packed))
454 raise ValueError("message packed to %r, but bytes are required"%type(packed))
454
455
455 # check that unpack is pack's inverse
456 # check that unpack is pack's inverse
456 try:
457 try:
457 unpacked = unpack(packed)
458 unpacked = unpack(packed)
458 assert unpacked == msg
459 assert unpacked == msg
459 except Exception as e:
460 except Exception as e:
460 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
461 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
461 if self.packer == 'json':
462 if self.packer == 'json':
462 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
463 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
463 else:
464 else:
464 jsonmsg = ""
465 jsonmsg = ""
465 raise ValueError(
466 raise ValueError(
466 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
467 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
467 )
468 )
468
469
469 # check datetime support
470 # check datetime support
470 msg = dict(t=datetime.now())
471 msg = dict(t=datetime.now())
471 try:
472 try:
472 unpacked = unpack(pack(msg))
473 unpacked = unpack(pack(msg))
473 if isinstance(unpacked['t'], datetime):
474 if isinstance(unpacked['t'], datetime):
474 raise ValueError("Shouldn't deserialize to datetime")
475 raise ValueError("Shouldn't deserialize to datetime")
475 except Exception:
476 except Exception:
476 self.pack = lambda o: pack(squash_dates(o))
477 self.pack = lambda o: pack(squash_dates(o))
477 self.unpack = lambda s: unpack(s)
478 self.unpack = lambda s: unpack(s)
478
479
def msg_header(self, msg_type):
    """Build a fresh header dict (module-level msg_header helper) with a
    new msg_id and this session's username and id."""
    return msg_header(self.msg_id, msg_type, self.username, self.session)
481
482
def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
    """Return the nested message dict.

    This format is different from what is sent over the wire. The
    serialize/unserialize methods convert this nested message dict to
    the wire format, which is a list of message parts.
    """
    hdr = self.msg_header(msg_type) if header is None else header
    nested = {
        'header': hdr,
        'msg_id': hdr['msg_id'],
        'msg_type': hdr['msg_type'],
        'parent_header': {} if parent is None else extract_header(parent),
        'content': {} if content is None else content,
        # start from the session-wide defaults, then layer per-call metadata
        'metadata': self.metadata.copy(),
    }
    if metadata is not None:
        nested['metadata'].update(metadata)
    return nested
500
501
def sign(self, msg_list):
    """Sign a message with HMAC digest. If no auth, return b''.

    Parameters
    ----------
    msg_list : list
        The [p_header,p_parent,p_content] part of the message list.
    """
    if self.auth is None:
        return b''
    # work on a copy so the base HMAC state is never consumed
    digester = self.auth.copy()
    for part in msg_list:
        digester.update(part)
    return str_to_bytes(digester.hexdigest())
515
516
def serialize(self, msg, ident=None):
    """Serialize the message components to bytes.

    This is roughly the inverse of unserialize. The serialize/unserialize
    methods work with full message lists, whereas pack/unpack work with
    the individual message parts in the message list.

    Parameters
    ----------
    msg : dict or Message
        The nested message dict as returned by the self.msg method.

    Returns
    -------
    msg_list : list
        The list of bytes objects to be sent with the format::

            [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
             p_metadata, p_content, buffer1, buffer2, ...]

        In this list, the ``p_*`` entities are the packed or serialized
        versions, so if JSON is used, these are utf8 encoded JSON strings.
    """
    content = msg.get('content', {})
    if content is None:
        # substitute the precomputed packed empty dict
        content = self.none
    elif isinstance(content, dict):
        content = self.pack(content)
    elif isinstance(content, bytes):
        # content is already packed, as in a relayed message
        pass
    elif isinstance(content, unicode_type):
        # should be bytes, but JSON often spits out unicode
        content = content.encode('utf8')
    else:
        raise TypeError("Content incorrect type: %s"%type(content))

    real_message = [
        self.pack(msg['header']),
        self.pack(msg['parent_header']),
        self.pack(msg['metadata']),
        content,
    ]

    # routing prefix: zero or more idents, then the delimiter
    if isinstance(ident, list):
        to_send = list(ident)
    elif ident is not None:
        to_send = [ident]
    else:
        to_send = []
    to_send.append(DELIM)

    # signature covers exactly the four packed parts
    to_send.append(self.sign(real_message))
    to_send.extend(real_message)

    return to_send
574
575
def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
         buffers=None, track=False, header=None, metadata=None):
    """Build and send a message via stream or socket.

    The message format used by this function internally is as follows:

    [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
     buffer1,buffer2,...]

    The serialize/unserialize methods convert the nested message dict
    into this format.

    Parameters
    ----------
    stream : zmq.Socket or ZMQStream
        The socket-like object used to send the data.
    msg_or_type : str or Message/dict
        Normally, msg_or_type will be a msg_type unless a message is
        being sent more than once. If a header is supplied, this can be
        set to None and the msg_type will be pulled from the header.
    content : dict or None
        The content of the message (ignored if msg_or_type is a message).
    header : dict or None
        The header dict for the message (ignored if msg_or_type is a
        message).
    parent : Message or dict or None
        The parent or parent header describing the parent of this
        message (ignored if msg_or_type is a message).
    ident : bytes or list of bytes
        The zmq.IDENTITY routing path.
    metadata : dict or None
        The metadata describing the message
    buffers : list or None
        The already-serialized buffers to be appended to the message.
    track : bool
        Whether to track. Only for use with Sockets, because ZMQStream
        objects cannot track messages.

    Returns
    -------
    msg : dict
        The constructed message.
    """
    if not isinstance(stream, zmq.Socket):
        # ZMQStreams and dummy sockets do not support tracking.
        track = False

    if isinstance(msg_or_type, (Message, dict)):
        # already a Message or message dict; don't build a new one
        msg = msg_or_type
    else:
        msg = self.msg(msg_or_type, content=content, parent=parent,
                       header=header, metadata=metadata)

    # refuse to send from a forked child (pid recorded at __init__)
    if os.getpid() != self.pid:
        io.rprint("WARNING: attempted to send message from fork")
        io.rprint(msg)
        return

    buffers = [] if buffers is None else buffers
    if self.adapt_version:
        msg = adapt(msg, self.adapt_version)
    to_send = self.serialize(msg, ident)
    to_send.extend(buffers)

    # only bother with zero-copy for large frames
    longest = max([len(part) for part in to_send])
    copy = (longest < self.copy_threshold)

    if buffers and track and not copy:
        # only really track when we are doing zero-copy buffers
        tracker = stream.send_multipart(to_send, copy=False, track=True)
    else:
        # use dummy tracker, which will be done immediately
        tracker = DONE
        stream.send_multipart(to_send, copy=copy)

    if self.debug:
        pprint.pprint(msg)
        pprint.pprint(to_send)
        pprint.pprint(buffers)

    msg['tracker'] = tracker

    return msg
659
660
def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
    """Send a raw message via ident path.

    This method is used to send an already serialized message.

    Parameters
    ----------
    stream : ZMQStream or Socket
        The ZMQ stream or socket to use for sending the message.
    msg_list : list
        The serialized list of messages to send. This only includes the
        [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...]
        portion of the message.
    ident : ident or list
        A single ident or a list of idents to use in sending.
    """
    if isinstance(ident, bytes):
        ident = [ident]
    # routing prefix, delimiter, signature, then the message parts
    to_send = [] if ident is None else list(ident)
    to_send.append(DELIM)
    to_send.append(self.sign(msg_list))
    to_send.extend(msg_list)
    stream.send_multipart(to_send, flags, copy=copy)
686
687
def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
    """Receive and unpack a message.

    Parameters
    ----------
    socket : ZMQStream or Socket
        The socket or stream to use in receiving.

    Returns
    -------
    [idents], msg
        [idents] is a list of idents and msg is a nested message dict of
        same format as self.msg returns. Returns (None, None) when no
        message is available in non-blocking mode (EAGAIN).
    """
    if isinstance(socket, ZMQStream):
        socket = socket.socket
    try:
        msg_list = socket.recv_multipart(mode, copy=copy)
    except zmq.ZMQError as e:
        if e.errno == zmq.EAGAIN:
            # We can convert EAGAIN to None as we know in this case
            # recv_multipart won't return None.
            return None, None
        else:
            raise
    # split multipart message into identity list and message dict
    # invalid large messages can cause very expensive string comparisons
    idents, msg_list = self.feed_identities(msg_list, copy)
    try:
        return idents, self.unserialize(msg_list, content=content, copy=copy)
    except Exception:
        # TODO: handle it
        # bare raise preserves the original traceback (was `raise e`,
        # which re-raises from this frame and loses context)
        raise
720
721
def feed_identities(self, msg_list, copy=True):
    """Split the identities from the rest of the message.

    Feed until DELIM is reached, then return the prefix as idents and
    remainder as msg_list. This is easily broken by setting an IDENT to
    DELIM, but that would be silly.

    Parameters
    ----------
    msg_list : a list of Message or bytes objects
        The message to be split.
    copy : bool
        flag determining whether the arguments are bytes or Messages

    Returns
    -------
    (idents, msg_list) : two lists
        idents will always be a list of bytes, each of which is a ZMQ
        identity. msg_list will be a list of bytes or zmq.Messages of
        the form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...]
        and should be unpackable/unserializable via self.unserialize at
        this point.
    """
    if copy:
        # plain bytes: a direct index lookup finds the delimiter
        idx = msg_list.index(DELIM)
        return msg_list[:idx], msg_list[idx+1:]

    # zero-copy: entries are zmq.Message frames, compare their bytes
    for idx, frame in enumerate(msg_list):
        if frame.bytes == DELIM:
            break
    else:
        raise ValueError("DELIM not in msg_list")
    idents = [frame.bytes for frame in msg_list[:idx]]
    return idents, msg_list[idx+1:]
757
758
758 def _add_digest(self, signature):
759 def _add_digest(self, signature):
759 """add a digest to history to protect against replay attacks"""
760 """add a digest to history to protect against replay attacks"""
760 if self.digest_history_size == 0:
761 if self.digest_history_size == 0:
761 # no history, never add digests
762 # no history, never add digests
762 return
763 return
763
764
764 self.digest_history.add(signature)
765 self.digest_history.add(signature)
765 if len(self.digest_history) > self.digest_history_size:
766 if len(self.digest_history) > self.digest_history_size:
766 # threshold reached, cull 10%
767 # threshold reached, cull 10%
767 self._cull_digest_history()
768 self._cull_digest_history()
768
769
769 def _cull_digest_history(self):
770 def _cull_digest_history(self):
770 """cull the digest history
771 """cull the digest history
771
772
772 Removes a randomly selected 10% of the digest history
773 Removes a randomly selected 10% of the digest history
773 """
774 """
774 current = len(self.digest_history)
775 current = len(self.digest_history)
775 n_to_cull = max(int(current // 10), current - self.digest_history_size)
776 n_to_cull = max(int(current // 10), current - self.digest_history_size)
776 if n_to_cull >= current:
777 if n_to_cull >= current:
777 self.digest_history = set()
778 self.digest_history = set()
778 return
779 return
779 to_cull = random.sample(self.digest_history, n_to_cull)
780 to_cull = random.sample(self.digest_history, n_to_cull)
780 self.digest_history.difference_update(to_cull)
781 self.digest_history.difference_update(to_cull)
781
782
782 def unserialize(self, msg_list, content=True, copy=True):
783 def unserialize(self, msg_list, content=True, copy=True):
783 """Unserialize a msg_list to a nested message dict.
784 """Unserialize a msg_list to a nested message dict.
784
785
785 This is roughly the inverse of serialize. The serialize/unserialize
786 This is roughly the inverse of serialize. The serialize/unserialize
786 methods work with full message lists, whereas pack/unpack work with
787 methods work with full message lists, whereas pack/unpack work with
787 the individual message parts in the message list.
788 the individual message parts in the message list.
788
789
789 Parameters
790 Parameters
790 ----------
791 ----------
791 msg_list : list of bytes or Message objects
792 msg_list : list of bytes or Message objects
792 The list of message parts of the form [HMAC,p_header,p_parent,
793 The list of message parts of the form [HMAC,p_header,p_parent,
793 p_metadata,p_content,buffer1,buffer2,...].
794 p_metadata,p_content,buffer1,buffer2,...].
794 content : bool (True)
795 content : bool (True)
795 Whether to unpack the content dict (True), or leave it packed
796 Whether to unpack the content dict (True), or leave it packed
796 (False).
797 (False).
797 copy : bool (True)
798 copy : bool (True)
798 Whether to return the bytes (True), or the non-copying Message
799 Whether to return the bytes (True), or the non-copying Message
799 object in each place (False).
800 object in each place (False).
800
801
801 Returns
802 Returns
802 -------
803 -------
803 msg : dict
804 msg : dict
804 The nested message dict with top-level keys [header, parent_header,
805 The nested message dict with top-level keys [header, parent_header,
805 content, buffers].
806 content, buffers].
806 """
807 """
807 minlen = 5
808 minlen = 5
808 message = {}
809 message = {}
809 if not copy:
810 if not copy:
810 for i in range(minlen):
811 for i in range(minlen):
811 msg_list[i] = msg_list[i].bytes
812 msg_list[i] = msg_list[i].bytes
812 if self.auth is not None:
813 if self.auth is not None:
813 signature = msg_list[0]
814 signature = msg_list[0]
814 if not signature:
815 if not signature:
815 raise ValueError("Unsigned Message")
816 raise ValueError("Unsigned Message")
816 if signature in self.digest_history:
817 if signature in self.digest_history:
817 raise ValueError("Duplicate Signature: %r" % signature)
818 raise ValueError("Duplicate Signature: %r" % signature)
818 self._add_digest(signature)
819 self._add_digest(signature)
819 check = self.sign(msg_list[1:5])
820 check = self.sign(msg_list[1:5])
820 if not signature == check:
821 if not signature == check:
821 raise ValueError("Invalid Signature: %r" % signature)
822 raise ValueError("Invalid Signature: %r" % signature)
822 if not len(msg_list) >= minlen:
823 if not len(msg_list) >= minlen:
823 raise TypeError("malformed message, must have at least %i elements"%minlen)
824 raise TypeError("malformed message, must have at least %i elements"%minlen)
824 header = self.unpack(msg_list[1])
825 header = self.unpack(msg_list[1])
825 message['header'] = extract_dates(header)
826 message['header'] = extract_dates(header)
826 message['msg_id'] = header['msg_id']
827 message['msg_id'] = header['msg_id']
827 message['msg_type'] = header['msg_type']
828 message['msg_type'] = header['msg_type']
828 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
829 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
829 message['metadata'] = self.unpack(msg_list[3])
830 message['metadata'] = self.unpack(msg_list[3])
830 if content:
831 if content:
831 message['content'] = self.unpack(msg_list[4])
832 message['content'] = self.unpack(msg_list[4])
832 else:
833 else:
833 message['content'] = msg_list[4]
834 message['content'] = msg_list[4]
834
835
835 message['buffers'] = msg_list[5:]
836 message['buffers'] = msg_list[5:]
836 # print("received: %s: %s\n %s" % (message['msg_type'], message['header'], message['content']))
837 # print("received: %s: %s\n %s" % (message['msg_type'], message['header'], message['content']))
837 # adapt to the current version
838 # adapt to the current version
838 return adapt(message)
839 return adapt(message)
839 # print("adapted: %s: %s\n %s" % (adapted['msg_type'], adapted['header'], adapted['content']))
840 # print("adapted: %s: %s\n %s" % (adapted['msg_type'], adapted['header'], adapted['content']))
840
841
841 def test_msg2obj():
842 def test_msg2obj():
842 am = dict(x=1)
843 am = dict(x=1)
843 ao = Message(am)
844 ao = Message(am)
844 assert ao.x == am['x']
845 assert ao.x == am['x']
845
846
846 am['y'] = dict(z=1)
847 am['y'] = dict(z=1)
847 ao = Message(am)
848 ao = Message(am)
848 assert ao.y.z == am['y']['z']
849 assert ao.y.z == am['y']['z']
849
850
850 k1, k2 = 'y', 'z'
851 k1, k2 = 'y', 'z'
851 assert ao[k1][k2] == am[k1][k2]
852 assert ao[k1][k2] == am[k1][k2]
852
853
853 am2 = dict(ao)
854 am2 = dict(ao)
854 assert am['x'] == am2['x']
855 assert am['x'] == am2['x']
855 assert am['y']['z'] == am2['y']['z']
856 assert am['y']['z'] == am2['y']['z']
856
857
@@ -1,433 +1,438 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Pickle related utilities. Perhaps this should be called 'can'."""
2 """Pickle related utilities. Perhaps this should be called 'can'."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 import copy
7 import copy
8 import logging
8 import logging
9 import sys
9 import sys
10 from types import FunctionType
10 from types import FunctionType
11
11
12 try:
12 try:
13 import cPickle as pickle
13 import cPickle as pickle
14 except ImportError:
14 except ImportError:
15 import pickle
15 import pickle
16
16
17 from . import codeutil # This registers a hook when it's imported
17 from . import codeutil # This registers a hook when it's imported
18 from . import py3compat
18 from . import py3compat
19 from .importstring import import_item
19 from .importstring import import_item
20 from .py3compat import string_types, iteritems
20 from .py3compat import string_types, iteritems
21
21
22 from IPython.config import Application
22 from IPython.config import Application
23
23
24 if py3compat.PY3:
24 if py3compat.PY3:
25 buffer = memoryview
25 buffer = memoryview
26 class_type = type
26 class_type = type
27 else:
27 else:
28 from types import ClassType
28 from types import ClassType
29 class_type = (type, ClassType)
29 class_type = (type, ClassType)
30
30
31 try:
32 PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
33 except AttributeError:
34 PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL
35
31 def _get_cell_type(a=None):
36 def _get_cell_type(a=None):
32 """the type of a closure cell doesn't seem to be importable,
37 """the type of a closure cell doesn't seem to be importable,
33 so just create one
38 so just create one
34 """
39 """
35 def inner():
40 def inner():
36 return a
41 return a
37 return type(py3compat.get_closure(inner)[0])
42 return type(py3compat.get_closure(inner)[0])
38
43
39 cell_type = _get_cell_type()
44 cell_type = _get_cell_type()
40
45
41 #-------------------------------------------------------------------------------
46 #-------------------------------------------------------------------------------
42 # Functions
47 # Functions
43 #-------------------------------------------------------------------------------
48 #-------------------------------------------------------------------------------
44
49
45
50
46 def use_dill():
51 def use_dill():
47 """use dill to expand serialization support
52 """use dill to expand serialization support
48
53
49 adds support for object methods and closures to serialization.
54 adds support for object methods and closures to serialization.
50 """
55 """
51 # import dill causes most of the magic
56 # import dill causes most of the magic
52 import dill
57 import dill
53
58
54 # dill doesn't work with cPickle,
59 # dill doesn't work with cPickle,
55 # tell the two relevant modules to use plain pickle
60 # tell the two relevant modules to use plain pickle
56
61
57 global pickle
62 global pickle
58 pickle = dill
63 pickle = dill
59
64
60 try:
65 try:
61 from IPython.kernel.zmq import serialize
66 from IPython.kernel.zmq import serialize
62 except ImportError:
67 except ImportError:
63 pass
68 pass
64 else:
69 else:
65 serialize.pickle = dill
70 serialize.pickle = dill
66
71
67 # disable special function handling, let dill take care of it
72 # disable special function handling, let dill take care of it
68 can_map.pop(FunctionType, None)
73 can_map.pop(FunctionType, None)
69
74
70 def use_cloudpickle():
75 def use_cloudpickle():
71 """use cloudpickle to expand serialization support
76 """use cloudpickle to expand serialization support
72
77
73 adds support for object methods and closures to serialization.
78 adds support for object methods and closures to serialization.
74 """
79 """
75 from cloud.serialization import cloudpickle
80 from cloud.serialization import cloudpickle
76
81
77 global pickle
82 global pickle
78 pickle = cloudpickle
83 pickle = cloudpickle
79
84
80 try:
85 try:
81 from IPython.kernel.zmq import serialize
86 from IPython.kernel.zmq import serialize
82 except ImportError:
87 except ImportError:
83 pass
88 pass
84 else:
89 else:
85 serialize.pickle = cloudpickle
90 serialize.pickle = cloudpickle
86
91
87 # disable special function handling, let cloudpickle take care of it
92 # disable special function handling, let cloudpickle take care of it
88 can_map.pop(FunctionType, None)
93 can_map.pop(FunctionType, None)
89
94
90
95
91 #-------------------------------------------------------------------------------
96 #-------------------------------------------------------------------------------
92 # Classes
97 # Classes
93 #-------------------------------------------------------------------------------
98 #-------------------------------------------------------------------------------
94
99
95
100
96 class CannedObject(object):
101 class CannedObject(object):
97 def __init__(self, obj, keys=[], hook=None):
102 def __init__(self, obj, keys=[], hook=None):
98 """can an object for safe pickling
103 """can an object for safe pickling
99
104
100 Parameters
105 Parameters
101 ==========
106 ==========
102
107
103 obj:
108 obj:
104 The object to be canned
109 The object to be canned
105 keys: list (optional)
110 keys: list (optional)
106 list of attribute names that will be explicitly canned / uncanned
111 list of attribute names that will be explicitly canned / uncanned
107 hook: callable (optional)
112 hook: callable (optional)
108 An optional extra callable,
113 An optional extra callable,
109 which can do additional processing of the uncanned object.
114 which can do additional processing of the uncanned object.
110
115
111 large data may be offloaded into the buffers list,
116 large data may be offloaded into the buffers list,
112 used for zero-copy transfers.
117 used for zero-copy transfers.
113 """
118 """
114 self.keys = keys
119 self.keys = keys
115 self.obj = copy.copy(obj)
120 self.obj = copy.copy(obj)
116 self.hook = can(hook)
121 self.hook = can(hook)
117 for key in keys:
122 for key in keys:
118 setattr(self.obj, key, can(getattr(obj, key)))
123 setattr(self.obj, key, can(getattr(obj, key)))
119
124
120 self.buffers = []
125 self.buffers = []
121
126
122 def get_object(self, g=None):
127 def get_object(self, g=None):
123 if g is None:
128 if g is None:
124 g = {}
129 g = {}
125 obj = self.obj
130 obj = self.obj
126 for key in self.keys:
131 for key in self.keys:
127 setattr(obj, key, uncan(getattr(obj, key), g))
132 setattr(obj, key, uncan(getattr(obj, key), g))
128
133
129 if self.hook:
134 if self.hook:
130 self.hook = uncan(self.hook, g)
135 self.hook = uncan(self.hook, g)
131 self.hook(obj, g)
136 self.hook(obj, g)
132 return self.obj
137 return self.obj
133
138
134
139
135 class Reference(CannedObject):
140 class Reference(CannedObject):
136 """object for wrapping a remote reference by name."""
141 """object for wrapping a remote reference by name."""
137 def __init__(self, name):
142 def __init__(self, name):
138 if not isinstance(name, string_types):
143 if not isinstance(name, string_types):
139 raise TypeError("illegal name: %r"%name)
144 raise TypeError("illegal name: %r"%name)
140 self.name = name
145 self.name = name
141 self.buffers = []
146 self.buffers = []
142
147
143 def __repr__(self):
148 def __repr__(self):
144 return "<Reference: %r>"%self.name
149 return "<Reference: %r>"%self.name
145
150
146 def get_object(self, g=None):
151 def get_object(self, g=None):
147 if g is None:
152 if g is None:
148 g = {}
153 g = {}
149
154
150 return eval(self.name, g)
155 return eval(self.name, g)
151
156
152
157
153 class CannedCell(CannedObject):
158 class CannedCell(CannedObject):
154 """Can a closure cell"""
159 """Can a closure cell"""
155 def __init__(self, cell):
160 def __init__(self, cell):
156 self.cell_contents = can(cell.cell_contents)
161 self.cell_contents = can(cell.cell_contents)
157
162
158 def get_object(self, g=None):
163 def get_object(self, g=None):
159 cell_contents = uncan(self.cell_contents, g)
164 cell_contents = uncan(self.cell_contents, g)
160 def inner():
165 def inner():
161 return cell_contents
166 return cell_contents
162 return py3compat.get_closure(inner)[0]
167 return py3compat.get_closure(inner)[0]
163
168
164
169
165 class CannedFunction(CannedObject):
170 class CannedFunction(CannedObject):
166
171
167 def __init__(self, f):
172 def __init__(self, f):
168 self._check_type(f)
173 self._check_type(f)
169 self.code = f.__code__
174 self.code = f.__code__
170 if f.__defaults__:
175 if f.__defaults__:
171 self.defaults = [ can(fd) for fd in f.__defaults__ ]
176 self.defaults = [ can(fd) for fd in f.__defaults__ ]
172 else:
177 else:
173 self.defaults = None
178 self.defaults = None
174
179
175 closure = py3compat.get_closure(f)
180 closure = py3compat.get_closure(f)
176 if closure:
181 if closure:
177 self.closure = tuple( can(cell) for cell in closure )
182 self.closure = tuple( can(cell) for cell in closure )
178 else:
183 else:
179 self.closure = None
184 self.closure = None
180
185
181 self.module = f.__module__ or '__main__'
186 self.module = f.__module__ or '__main__'
182 self.__name__ = f.__name__
187 self.__name__ = f.__name__
183 self.buffers = []
188 self.buffers = []
184
189
185 def _check_type(self, obj):
190 def _check_type(self, obj):
186 assert isinstance(obj, FunctionType), "Not a function type"
191 assert isinstance(obj, FunctionType), "Not a function type"
187
192
188 def get_object(self, g=None):
193 def get_object(self, g=None):
189 # try to load function back into its module:
194 # try to load function back into its module:
190 if not self.module.startswith('__'):
195 if not self.module.startswith('__'):
191 __import__(self.module)
196 __import__(self.module)
192 g = sys.modules[self.module].__dict__
197 g = sys.modules[self.module].__dict__
193
198
194 if g is None:
199 if g is None:
195 g = {}
200 g = {}
196 if self.defaults:
201 if self.defaults:
197 defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
202 defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
198 else:
203 else:
199 defaults = None
204 defaults = None
200 if self.closure:
205 if self.closure:
201 closure = tuple(uncan(cell, g) for cell in self.closure)
206 closure = tuple(uncan(cell, g) for cell in self.closure)
202 else:
207 else:
203 closure = None
208 closure = None
204 newFunc = FunctionType(self.code, g, self.__name__, defaults, closure)
209 newFunc = FunctionType(self.code, g, self.__name__, defaults, closure)
205 return newFunc
210 return newFunc
206
211
207 class CannedClass(CannedObject):
212 class CannedClass(CannedObject):
208
213
209 def __init__(self, cls):
214 def __init__(self, cls):
210 self._check_type(cls)
215 self._check_type(cls)
211 self.name = cls.__name__
216 self.name = cls.__name__
212 self.old_style = not isinstance(cls, type)
217 self.old_style = not isinstance(cls, type)
213 self._canned_dict = {}
218 self._canned_dict = {}
214 for k,v in cls.__dict__.items():
219 for k,v in cls.__dict__.items():
215 if k not in ('__weakref__', '__dict__'):
220 if k not in ('__weakref__', '__dict__'):
216 self._canned_dict[k] = can(v)
221 self._canned_dict[k] = can(v)
217 if self.old_style:
222 if self.old_style:
218 mro = []
223 mro = []
219 else:
224 else:
220 mro = cls.mro()
225 mro = cls.mro()
221
226
222 self.parents = [ can(c) for c in mro[1:] ]
227 self.parents = [ can(c) for c in mro[1:] ]
223 self.buffers = []
228 self.buffers = []
224
229
225 def _check_type(self, obj):
230 def _check_type(self, obj):
226 assert isinstance(obj, class_type), "Not a class type"
231 assert isinstance(obj, class_type), "Not a class type"
227
232
228 def get_object(self, g=None):
233 def get_object(self, g=None):
229 parents = tuple(uncan(p, g) for p in self.parents)
234 parents = tuple(uncan(p, g) for p in self.parents)
230 return type(self.name, parents, uncan_dict(self._canned_dict, g=g))
235 return type(self.name, parents, uncan_dict(self._canned_dict, g=g))
231
236
232 class CannedArray(CannedObject):
237 class CannedArray(CannedObject):
233 def __init__(self, obj):
238 def __init__(self, obj):
234 from numpy import ascontiguousarray
239 from numpy import ascontiguousarray
235 self.shape = obj.shape
240 self.shape = obj.shape
236 self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
241 self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
237 self.pickled = False
242 self.pickled = False
238 if sum(obj.shape) == 0:
243 if sum(obj.shape) == 0:
239 self.pickled = True
244 self.pickled = True
240 elif obj.dtype == 'O':
245 elif obj.dtype == 'O':
241 # can't handle object dtype with buffer approach
246 # can't handle object dtype with buffer approach
242 self.pickled = True
247 self.pickled = True
243 elif obj.dtype.fields and any(dt == 'O' for dt,sz in obj.dtype.fields.values()):
248 elif obj.dtype.fields and any(dt == 'O' for dt,sz in obj.dtype.fields.values()):
244 self.pickled = True
249 self.pickled = True
245 if self.pickled:
250 if self.pickled:
246 # just pickle it
251 # just pickle it
247 self.buffers = [pickle.dumps(obj, -1)]
252 self.buffers = [pickle.dumps(obj, PICKLE_PROTOCOL)]
248 else:
253 else:
249 # ensure contiguous
254 # ensure contiguous
250 obj = ascontiguousarray(obj, dtype=None)
255 obj = ascontiguousarray(obj, dtype=None)
251 self.buffers = [buffer(obj)]
256 self.buffers = [buffer(obj)]
252
257
253 def get_object(self, g=None):
258 def get_object(self, g=None):
254 from numpy import frombuffer
259 from numpy import frombuffer
255 data = self.buffers[0]
260 data = self.buffers[0]
256 if self.pickled:
261 if self.pickled:
257 # no shape, we just pickled it
262 # no shape, we just pickled it
258 return pickle.loads(data)
263 return pickle.loads(data)
259 else:
264 else:
260 return frombuffer(data, dtype=self.dtype).reshape(self.shape)
265 return frombuffer(data, dtype=self.dtype).reshape(self.shape)
261
266
262
267
263 class CannedBytes(CannedObject):
268 class CannedBytes(CannedObject):
264 wrap = bytes
269 wrap = bytes
265 def __init__(self, obj):
270 def __init__(self, obj):
266 self.buffers = [obj]
271 self.buffers = [obj]
267
272
268 def get_object(self, g=None):
273 def get_object(self, g=None):
269 data = self.buffers[0]
274 data = self.buffers[0]
270 return self.wrap(data)
275 return self.wrap(data)
271
276
272 def CannedBuffer(CannedBytes):
277 def CannedBuffer(CannedBytes):
273 wrap = buffer
278 wrap = buffer
274
279
275 #-------------------------------------------------------------------------------
280 #-------------------------------------------------------------------------------
276 # Functions
281 # Functions
277 #-------------------------------------------------------------------------------
282 #-------------------------------------------------------------------------------
278
283
279 def _logger():
284 def _logger():
280 """get the logger for the current Application
285 """get the logger for the current Application
281
286
282 the root logger will be used if no Application is running
287 the root logger will be used if no Application is running
283 """
288 """
284 if Application.initialized():
289 if Application.initialized():
285 logger = Application.instance().log
290 logger = Application.instance().log
286 else:
291 else:
287 logger = logging.getLogger()
292 logger = logging.getLogger()
288 if not logger.handlers:
293 if not logger.handlers:
289 logging.basicConfig()
294 logging.basicConfig()
290
295
291 return logger
296 return logger
292
297
293 def _import_mapping(mapping, original=None):
298 def _import_mapping(mapping, original=None):
294 """import any string-keys in a type mapping
299 """import any string-keys in a type mapping
295
300
296 """
301 """
297 log = _logger()
302 log = _logger()
298 log.debug("Importing canning map")
303 log.debug("Importing canning map")
299 for key,value in list(mapping.items()):
304 for key,value in list(mapping.items()):
300 if isinstance(key, string_types):
305 if isinstance(key, string_types):
301 try:
306 try:
302 cls = import_item(key)
307 cls = import_item(key)
303 except Exception:
308 except Exception:
304 if original and key not in original:
309 if original and key not in original:
305 # only message on user-added classes
310 # only message on user-added classes
306 log.error("canning class not importable: %r", key, exc_info=True)
311 log.error("canning class not importable: %r", key, exc_info=True)
307 mapping.pop(key)
312 mapping.pop(key)
308 else:
313 else:
309 mapping[cls] = mapping.pop(key)
314 mapping[cls] = mapping.pop(key)
310
315
311 def istype(obj, check):
316 def istype(obj, check):
312 """like isinstance(obj, check), but strict
317 """like isinstance(obj, check), but strict
313
318
314 This won't catch subclasses.
319 This won't catch subclasses.
315 """
320 """
316 if isinstance(check, tuple):
321 if isinstance(check, tuple):
317 for cls in check:
322 for cls in check:
318 if type(obj) is cls:
323 if type(obj) is cls:
319 return True
324 return True
320 return False
325 return False
321 else:
326 else:
322 return type(obj) is check
327 return type(obj) is check
323
328
324 def can(obj):
329 def can(obj):
325 """prepare an object for pickling"""
330 """prepare an object for pickling"""
326
331
327 import_needed = False
332 import_needed = False
328
333
329 for cls,canner in iteritems(can_map):
334 for cls,canner in iteritems(can_map):
330 if isinstance(cls, string_types):
335 if isinstance(cls, string_types):
331 import_needed = True
336 import_needed = True
332 break
337 break
333 elif istype(obj, cls):
338 elif istype(obj, cls):
334 return canner(obj)
339 return canner(obj)
335
340
336 if import_needed:
341 if import_needed:
337 # perform can_map imports, then try again
342 # perform can_map imports, then try again
338 # this will usually only happen once
343 # this will usually only happen once
339 _import_mapping(can_map, _original_can_map)
344 _import_mapping(can_map, _original_can_map)
340 return can(obj)
345 return can(obj)
341
346
342 return obj
347 return obj
343
348
344 def can_class(obj):
349 def can_class(obj):
345 if isinstance(obj, class_type) and obj.__module__ == '__main__':
350 if isinstance(obj, class_type) and obj.__module__ == '__main__':
346 return CannedClass(obj)
351 return CannedClass(obj)
347 else:
352 else:
348 return obj
353 return obj
349
354
350 def can_dict(obj):
355 def can_dict(obj):
351 """can the *values* of a dict"""
356 """can the *values* of a dict"""
352 if istype(obj, dict):
357 if istype(obj, dict):
353 newobj = {}
358 newobj = {}
354 for k, v in iteritems(obj):
359 for k, v in iteritems(obj):
355 newobj[k] = can(v)
360 newobj[k] = can(v)
356 return newobj
361 return newobj
357 else:
362 else:
358 return obj
363 return obj
359
364
360 sequence_types = (list, tuple, set)
365 sequence_types = (list, tuple, set)
361
366
362 def can_sequence(obj):
367 def can_sequence(obj):
363 """can the elements of a sequence"""
368 """can the elements of a sequence"""
364 if istype(obj, sequence_types):
369 if istype(obj, sequence_types):
365 t = type(obj)
370 t = type(obj)
366 return t([can(i) for i in obj])
371 return t([can(i) for i in obj])
367 else:
372 else:
368 return obj
373 return obj
369
374
370 def uncan(obj, g=None):
375 def uncan(obj, g=None):
371 """invert canning"""
376 """invert canning"""
372
377
373 import_needed = False
378 import_needed = False
374 for cls,uncanner in iteritems(uncan_map):
379 for cls,uncanner in iteritems(uncan_map):
375 if isinstance(cls, string_types):
380 if isinstance(cls, string_types):
376 import_needed = True
381 import_needed = True
377 break
382 break
378 elif isinstance(obj, cls):
383 elif isinstance(obj, cls):
379 return uncanner(obj, g)
384 return uncanner(obj, g)
380
385
381 if import_needed:
386 if import_needed:
382 # perform uncan_map imports, then try again
387 # perform uncan_map imports, then try again
383 # this will usually only happen once
388 # this will usually only happen once
384 _import_mapping(uncan_map, _original_uncan_map)
389 _import_mapping(uncan_map, _original_uncan_map)
385 return uncan(obj, g)
390 return uncan(obj, g)
386
391
387 return obj
392 return obj
388
393
389 def uncan_dict(obj, g=None):
394 def uncan_dict(obj, g=None):
390 if istype(obj, dict):
395 if istype(obj, dict):
391 newobj = {}
396 newobj = {}
392 for k, v in iteritems(obj):
397 for k, v in iteritems(obj):
393 newobj[k] = uncan(v,g)
398 newobj[k] = uncan(v,g)
394 return newobj
399 return newobj
395 else:
400 else:
396 return obj
401 return obj
397
402
398 def uncan_sequence(obj, g=None):
403 def uncan_sequence(obj, g=None):
399 if istype(obj, sequence_types):
404 if istype(obj, sequence_types):
400 t = type(obj)
405 t = type(obj)
401 return t([uncan(i,g) for i in obj])
406 return t([uncan(i,g) for i in obj])
402 else:
407 else:
403 return obj
408 return obj
404
409
405 def _uncan_dependent_hook(dep, g=None):
410 def _uncan_dependent_hook(dep, g=None):
406 dep.check_dependency()
411 dep.check_dependency()
407
412
408 def can_dependent(obj):
413 def can_dependent(obj):
409 return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook)
414 return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook)
410
415
411 #-------------------------------------------------------------------------------
416 #-------------------------------------------------------------------------------
412 # API dictionaries
417 # API dictionaries
413 #-------------------------------------------------------------------------------
418 #-------------------------------------------------------------------------------
414
419
415 # These dicts can be extended for custom serialization of new objects
420 # These dicts can be extended for custom serialization of new objects
416
421
417 can_map = {
422 can_map = {
418 'IPython.parallel.dependent' : can_dependent,
423 'IPython.parallel.dependent' : can_dependent,
419 'numpy.ndarray' : CannedArray,
424 'numpy.ndarray' : CannedArray,
420 FunctionType : CannedFunction,
425 FunctionType : CannedFunction,
421 bytes : CannedBytes,
426 bytes : CannedBytes,
422 buffer : CannedBuffer,
427 buffer : CannedBuffer,
423 cell_type : CannedCell,
428 cell_type : CannedCell,
424 class_type : can_class,
429 class_type : can_class,
425 }
430 }
426
431
427 uncan_map = {
432 uncan_map = {
428 CannedObject : lambda obj, g: obj.get_object(g),
433 CannedObject : lambda obj, g: obj.get_object(g),
429 }
434 }
430
435
431 # for use in _import_mapping:
436 # for use in _import_mapping:
432 _original_can_map = can_map.copy()
437 _original_can_map = can_map.copy()
433 _original_uncan_map = uncan_map.copy()
438 _original_uncan_map = uncan_map.copy()
General Comments 0
You need to be logged in to leave comments. Login now