Backport PR #6029: add pickleutil.PICKLE_PROTOCOL...
MinRK
@@ -1,198 +1,185
1 """serialization utilities for apply messages
1 """serialization utilities for apply messages"""
2 2
3 Authors:
4
5 * Min RK
6 """
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
13
14 #-----------------------------------------------------------------------------
15 # Imports
16 #-----------------------------------------------------------------------------
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
17 5
18 6 try:
19 7 import cPickle
20 8 pickle = cPickle
21 9 except:
22 10 cPickle = None
23 11 import pickle
24 12
25
26 13 # IPython imports
27 14 from IPython.utils import py3compat
28 15 from IPython.utils.data import flatten
29 16 from IPython.utils.pickleutil import (
30 17 can, uncan, can_sequence, uncan_sequence, CannedObject,
31 istype, sequence_types,
18 istype, sequence_types, PICKLE_PROTOCOL,
32 19 )
33 20
34 21 if py3compat.PY3:
35 22 buffer = memoryview
36 23
37 24 #-----------------------------------------------------------------------------
38 25 # Serialization Functions
39 26 #-----------------------------------------------------------------------------
40 27
41 28 # default values for the thresholds:
42 29 MAX_ITEMS = 64
43 30 MAX_BYTES = 1024
44 31
45 32 def _extract_buffers(obj, threshold=MAX_BYTES):
46 33 """extract buffers larger than a certain threshold"""
47 34 buffers = []
48 35 if isinstance(obj, CannedObject) and obj.buffers:
49 36 for i,buf in enumerate(obj.buffers):
50 37 if len(buf) > threshold:
51 38 # buffer larger than threshold, prevent pickling
52 39 obj.buffers[i] = None
53 40 buffers.append(buf)
54 41 elif isinstance(buf, buffer):
55 42 # buffer too small for separate send, coerce to bytes
56 43 # because pickling buffer objects just results in broken pointers
57 44 obj.buffers[i] = bytes(buf)
58 45 return buffers
59 46
60 47 def _restore_buffers(obj, buffers):
61 48 """restore buffers extracted by """
62 49 if isinstance(obj, CannedObject) and obj.buffers:
63 50 for i,buf in enumerate(obj.buffers):
64 51 if buf is None:
65 52 obj.buffers[i] = buffers.pop(0)
66 53
67 54 def serialize_object(obj, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
68 55 """Serialize an object into a list of sendable buffers.
69 56
70 57 Parameters
71 58 ----------
72 59
73 60 obj : object
74 61 The object to be serialized
75 62 buffer_threshold : int
76 63 The threshold (in bytes) for pulling out data buffers
77 64 to avoid pickling them.
78 65 item_threshold : int
79 66 The maximum number of items over which canning will iterate.
80 67 Containers (lists, dicts) larger than this will be pickled without
81 68 introspection.
82 69
83 70 Returns
84 71 -------
85 72 [bufs] : list of buffers representing the serialized object.
86 73 """
87 74 buffers = []
88 75 if istype(obj, sequence_types) and len(obj) < item_threshold:
89 76 cobj = can_sequence(obj)
90 77 for c in cobj:
91 78 buffers.extend(_extract_buffers(c, buffer_threshold))
92 79 elif istype(obj, dict) and len(obj) < item_threshold:
93 80 cobj = {}
94 81 for k in sorted(obj):
95 82 c = can(obj[k])
96 83 buffers.extend(_extract_buffers(c, buffer_threshold))
97 84 cobj[k] = c
98 85 else:
99 86 cobj = can(obj)
100 87 buffers.extend(_extract_buffers(cobj, buffer_threshold))
101 88
102 buffers.insert(0, pickle.dumps(cobj,-1))
89 buffers.insert(0, pickle.dumps(cobj, PICKLE_PROTOCOL))
103 90 return buffers
104 91
105 92 def unserialize_object(buffers, g=None):
106 93 """reconstruct an object serialized by serialize_object from data buffers.
107 94
108 95 Parameters
109 96 ----------
110 97
111 98     buffers : list of buffers/bytes
112 99
113 100 g : globals to be used when uncanning
114 101
115 102 Returns
116 103 -------
117 104
118 105 (newobj, bufs) : unpacked object, and the list of remaining unused buffers.
119 106 """
120 107 bufs = list(buffers)
121 108 pobj = bufs.pop(0)
122 109 if not isinstance(pobj, bytes):
123 110 # a zmq message
124 111 pobj = bytes(pobj)
125 112 canned = pickle.loads(pobj)
126 113 if istype(canned, sequence_types) and len(canned) < MAX_ITEMS:
127 114 for c in canned:
128 115 _restore_buffers(c, bufs)
129 116 newobj = uncan_sequence(canned, g)
130 117 elif istype(canned, dict) and len(canned) < MAX_ITEMS:
131 118 newobj = {}
132 119 for k in sorted(canned):
133 120 c = canned[k]
134 121 _restore_buffers(c, bufs)
135 122 newobj[k] = uncan(c, g)
136 123 else:
137 124 _restore_buffers(canned, bufs)
138 125 newobj = uncan(canned, g)
139 126
140 127 return newobj, bufs
141 128
142 129 def pack_apply_message(f, args, kwargs, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
143 130 """pack up a function, args, and kwargs to be sent over the wire
144 131
145 132 Each element of args/kwargs will be canned for special treatment,
146 133 but inspection will not go any deeper than that.
147 134
148 135     Any object whose data is larger than `threshold` will not have its data copied
149 136 (only numpy arrays and bytes/buffers support zero-copy)
150 137
151 138 Message will be a list of bytes/buffers of the format:
152 139
153 140 [ cf, pinfo, <arg_bufs>, <kwarg_bufs> ]
154 141
155 142 With length at least two + len(args) + len(kwargs)
156 143 """
157 144
158 145 arg_bufs = flatten(serialize_object(arg, buffer_threshold, item_threshold) for arg in args)
159 146
160 147 kw_keys = sorted(kwargs.keys())
161 148 kwarg_bufs = flatten(serialize_object(kwargs[key], buffer_threshold, item_threshold) for key in kw_keys)
162 149
163 150 info = dict(nargs=len(args), narg_bufs=len(arg_bufs), kw_keys=kw_keys)
164 151
165 msg = [pickle.dumps(can(f),-1)]
166 msg.append(pickle.dumps(info, -1))
152 msg = [pickle.dumps(can(f), PICKLE_PROTOCOL)]
153 msg.append(pickle.dumps(info, PICKLE_PROTOCOL))
167 154 msg.extend(arg_bufs)
168 155 msg.extend(kwarg_bufs)
169 156
170 157 return msg
171 158
172 159 def unpack_apply_message(bufs, g=None, copy=True):
173 160 """unpack f,args,kwargs from buffers packed by pack_apply_message()
174 161 Returns: original f,args,kwargs"""
175 162 bufs = list(bufs) # allow us to pop
176 163 assert len(bufs) >= 2, "not enough buffers!"
177 164 if not copy:
178 165 for i in range(2):
179 166 bufs[i] = bufs[i].bytes
180 167 f = uncan(pickle.loads(bufs.pop(0)), g)
181 168 info = pickle.loads(bufs.pop(0))
182 169 arg_bufs, kwarg_bufs = bufs[:info['narg_bufs']], bufs[info['narg_bufs']:]
183 170
184 171 args = []
185 172 for i in range(info['nargs']):
186 173 arg, arg_bufs = unserialize_object(arg_bufs, g)
187 174 args.append(arg)
188 175 args = tuple(args)
189 176 assert not arg_bufs, "Shouldn't be any arg bufs left over"
190 177
191 178 kwargs = {}
192 179 for key in info['kw_keys']:
193 180 kwarg, kwarg_bufs = unserialize_object(kwarg_bufs, g)
194 181 kwargs[key] = kwarg
195 182 assert not kwarg_bufs, "Shouldn't be any kwarg bufs left over"
196 183
197 184 return f,args,kwargs
198 185
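For orientation, a minimal round-trip sketch of the serialize helpers changed above. The module path assumes the IPython 1.x/2.x layout this backport targets; the object shape is illustrative:

    from IPython.kernel.zmq.serialize import serialize_object, unserialize_object

    obj = {'small': list(range(4)), 'big': b'x' * 2048}  # 'big' exceeds MAX_BYTES (1024)
    bufs = serialize_object(obj)
    # bufs[0] is the canned dict, pickled with PICKLE_PROTOCOL;
    # bufs[1] is the extracted 2048-byte payload, sent without pickling.
    restored, leftover = unserialize_object(bufs)
    assert restored == obj and leftover == []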
@@ -1,850 +1,851
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9 9
10 10 Authors:
11 11
12 12 * Min RK
13 13 * Brian Granger
14 14 * Fernando Perez
15 15 """
16 16 #-----------------------------------------------------------------------------
17 17 # Copyright (C) 2010-2011 The IPython Development Team
18 18 #
19 19 # Distributed under the terms of the BSD License. The full license is in
20 20 # the file COPYING, distributed as part of this software.
21 21 #-----------------------------------------------------------------------------
22 22
23 23 #-----------------------------------------------------------------------------
24 24 # Imports
25 25 #-----------------------------------------------------------------------------
26 26
27 27 import hashlib
28 28 import hmac
29 29 import logging
30 30 import os
31 31 import pprint
32 32 import random
33 33 import uuid
34 34 from datetime import datetime
35 35
36 36 try:
37 37 import cPickle
38 38 pickle = cPickle
39 39 except:
40 40 cPickle = None
41 41 import pickle
42 42
43 43 import zmq
44 44 from zmq.utils import jsonapi
45 45 from zmq.eventloop.ioloop import IOLoop
46 46 from zmq.eventloop.zmqstream import ZMQStream
47 47
48 48 from IPython.config.configurable import Configurable, LoggingConfigurable
49 49 from IPython.utils import io
50 50 from IPython.utils.importstring import import_item
51 51 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
52 52 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
53 53 iteritems)
54 54 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
55 55 DottedObjectName, CUnicode, Dict, Integer,
56 56 TraitError,
57 57 )
58 from IPython.utils.pickleutil import PICKLE_PROTOCOL
58 59 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
59 60
60 61 #-----------------------------------------------------------------------------
61 62 # utility functions
62 63 #-----------------------------------------------------------------------------
63 64
64 65 def squash_unicode(obj):
65 66 """coerce unicode back to bytestrings."""
66 67 if isinstance(obj,dict):
67 68 for key in obj.keys():
68 69 obj[key] = squash_unicode(obj[key])
69 70 if isinstance(key, unicode_type):
70 71 obj[squash_unicode(key)] = obj.pop(key)
71 72 elif isinstance(obj, list):
72 73 for i,v in enumerate(obj):
73 74 obj[i] = squash_unicode(v)
74 75 elif isinstance(obj, unicode_type):
75 76 obj = obj.encode('utf8')
76 77 return obj
77 78
78 79 #-----------------------------------------------------------------------------
79 80 # globals and defaults
80 81 #-----------------------------------------------------------------------------
81 82
82 83 # ISO8601-ify datetime objects
83 84 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default)
84 85 json_unpacker = lambda s: jsonapi.loads(s)
85 86
86 pickle_packer = lambda o: pickle.dumps(squash_dates(o),-1)
87 pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
87 88 pickle_unpacker = pickle.loads
88 89
89 90 default_packer = json_packer
90 91 default_unpacker = json_unpacker
91 92
92 93 DELIM = b"<IDS|MSG>"
93 94 # singleton dummy tracker, which will always report as done
94 95 DONE = zmq.MessageTracker()
95 96
96 97 #-----------------------------------------------------------------------------
97 98 # Mixin tools for apps that use Sessions
98 99 #-----------------------------------------------------------------------------
99 100
100 101 session_aliases = dict(
101 102 ident = 'Session.session',
102 103 user = 'Session.username',
103 104 keyfile = 'Session.keyfile',
104 105 )
105 106
106 107 session_flags = {
107 108 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
108 109 'keyfile' : '' }},
109 110 """Use HMAC digests for authentication of messages.
110 111 Setting this flag will generate a new UUID to use as the HMAC key.
111 112 """),
112 113 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
113 114 """Don't authenticate messages."""),
114 115 }
115 116
116 117 def default_secure(cfg):
117 118 """Set the default behavior for a config environment to be secure.
118 119
119 120 If Session.key/keyfile have not been set, set Session.key to
120 121 a new random UUID.
121 122 """
122 123
123 124 if 'Session' in cfg:
124 125 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
125 126 return
126 127 # key/keyfile not specified, generate new UUID:
127 128 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
128 129
129 130
130 131 #-----------------------------------------------------------------------------
131 132 # Classes
132 133 #-----------------------------------------------------------------------------
133 134
134 135 class SessionFactory(LoggingConfigurable):
135 136 """The Base class for configurables that have a Session, Context, logger,
136 137 and IOLoop.
137 138 """
138 139
139 140 logname = Unicode('')
140 141 def _logname_changed(self, name, old, new):
141 142 self.log = logging.getLogger(new)
142 143
143 144 # not configurable:
144 145 context = Instance('zmq.Context')
145 146 def _context_default(self):
146 147 return zmq.Context.instance()
147 148
148 149 session = Instance('IPython.kernel.zmq.session.Session')
149 150
150 151 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
151 152 def _loop_default(self):
152 153 return IOLoop.instance()
153 154
154 155 def __init__(self, **kwargs):
155 156 super(SessionFactory, self).__init__(**kwargs)
156 157
157 158 if self.session is None:
158 159 # construct the session
159 160 self.session = Session(**kwargs)
160 161
161 162
162 163 class Message(object):
163 164 """A simple message object that maps dict keys to attributes.
164 165
165 166 A Message can be created from a dict and a dict from a Message instance
166 167 simply by calling dict(msg_obj)."""
167 168
168 169 def __init__(self, msg_dict):
169 170 dct = self.__dict__
170 171 for k, v in iteritems(dict(msg_dict)):
171 172 if isinstance(v, dict):
172 173 v = Message(v)
173 174 dct[k] = v
174 175
175 176 # Having this iterator lets dict(msg_obj) work out of the box.
176 177 def __iter__(self):
177 178 return iter(iteritems(self.__dict__))
178 179
179 180 def __repr__(self):
180 181 return repr(self.__dict__)
181 182
182 183 def __str__(self):
183 184 return pprint.pformat(self.__dict__)
184 185
185 186 def __contains__(self, k):
186 187 return k in self.__dict__
187 188
188 189 def __getitem__(self, k):
189 190 return self.__dict__[k]
190 191
191 192
192 193 def msg_header(msg_id, msg_type, username, session):
193 194 date = datetime.now()
194 195 return locals()
195 196
196 197 def extract_header(msg_or_header):
197 198 """Given a message or header, return the header."""
198 199 if not msg_or_header:
199 200 return {}
200 201 try:
201 202 # See if msg_or_header is the entire message.
202 203 h = msg_or_header['header']
203 204 except KeyError:
204 205 try:
205 206 # See if msg_or_header is just the header
206 207 h = msg_or_header['msg_id']
207 208 except KeyError:
208 209 raise
209 210 else:
210 211 h = msg_or_header
211 212 if not isinstance(h, dict):
212 213 h = dict(h)
213 214 return h
214 215
215 216 class Session(Configurable):
216 217 """Object for handling serialization and sending of messages.
217 218
218 219 The Session object handles building messages and sending them
219 220 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
220 221 other over the network via Session objects, and only need to work with the
221 222 dict-based IPython message spec. The Session will handle
222 223 serialization/deserialization, security, and metadata.
223 224
224 225     Sessions support configurable serialization via packer/unpacker traits,
225 226 and signing with HMAC digests via the key/keyfile traits.
226 227
227 228 Parameters
228 229 ----------
229 230
230 231 debug : bool
231 232 whether to trigger extra debugging statements
232 233 packer/unpacker : str : 'json', 'pickle' or import_string
233 234 importstrings for methods to serialize message parts. If just
234 235 'json' or 'pickle', predefined JSON and pickle packers will be used.
235 236 Otherwise, the entire importstring must be used.
236 237
237 238 The functions must accept at least valid JSON input, and output *bytes*.
238 239
239 240 For example, to use msgpack:
240 241 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
241 242 pack/unpack : callables
242 243 You can also set the pack/unpack callables for serialization directly.
243 244 session : bytes
244 245 the ID of this Session object. The default is to generate a new UUID.
245 246 username : unicode
246 247 username added to message headers. The default is to ask the OS.
247 248 key : bytes
248 249 The key used to initialize an HMAC signature. If unset, messages
249 250 will not be signed or checked.
250 251 keyfile : filepath
251 252 The file containing a key. If this is set, `key` will be initialized
252 253 to the contents of the file.
253 254
254 255 """
255 256
256 257 debug=Bool(False, config=True, help="""Debug output in the Session""")
257 258
258 259 packer = DottedObjectName('json',config=True,
259 260 help="""The name of the packer for serializing messages.
260 261 Should be one of 'json', 'pickle', or an import name
261 262 for a custom callable serializer.""")
262 263 def _packer_changed(self, name, old, new):
263 264 if new.lower() == 'json':
264 265 self.pack = json_packer
265 266 self.unpack = json_unpacker
266 267 self.unpacker = new
267 268 elif new.lower() == 'pickle':
268 269 self.pack = pickle_packer
269 270 self.unpack = pickle_unpacker
270 271 self.unpacker = new
271 272 else:
272 273 self.pack = import_item(str(new))
273 274
274 275 unpacker = DottedObjectName('json', config=True,
275 276 help="""The name of the unpacker for unserializing messages.
276 277 Only used with custom functions for `packer`.""")
277 278 def _unpacker_changed(self, name, old, new):
278 279 if new.lower() == 'json':
279 280 self.pack = json_packer
280 281 self.unpack = json_unpacker
281 282 self.packer = new
282 283 elif new.lower() == 'pickle':
283 284 self.pack = pickle_packer
284 285 self.unpack = pickle_unpacker
285 286 self.packer = new
286 287 else:
287 288 self.unpack = import_item(str(new))
288 289
289 290 session = CUnicode(u'', config=True,
290 291 help="""The UUID identifying this session.""")
291 292 def _session_default(self):
292 293 u = unicode_type(uuid.uuid4())
293 294 self.bsession = u.encode('ascii')
294 295 return u
295 296
296 297 def _session_changed(self, name, old, new):
297 298 self.bsession = self.session.encode('ascii')
298 299
299 300 # bsession is the session as bytes
300 301 bsession = CBytes(b'')
301 302
302 303 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
303 304 help="""Username for the Session. Default is your system username.""",
304 305 config=True)
305 306
306 307 metadata = Dict({}, config=True,
307 308 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
308 309
309 310 # message signature related traits:
310 311
311 312 key = CBytes(b'', config=True,
312 313 help="""execution key, for extra authentication.""")
313 314 def _key_changed(self, name, old, new):
314 315 if new:
315 316 self.auth = hmac.HMAC(new, digestmod=self.digest_mod)
316 317 else:
317 318 self.auth = None
318 319
319 320 signature_scheme = Unicode('hmac-sha256', config=True,
320 321 help="""The digest scheme used to construct the message signatures.
321 322 Must have the form 'hmac-HASH'.""")
322 323 def _signature_scheme_changed(self, name, old, new):
323 324 if not new.startswith('hmac-'):
324 325 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
325 326 hash_name = new.split('-', 1)[1]
326 327 try:
327 328 self.digest_mod = getattr(hashlib, hash_name)
328 329 except AttributeError:
329 330 raise TraitError("hashlib has no such attribute: %s" % hash_name)
330 331
331 332 digest_mod = Any()
332 333 def _digest_mod_default(self):
333 334 return hashlib.sha256
334 335
335 336 auth = Instance(hmac.HMAC)
336 337
337 338 digest_history = Set()
338 339 digest_history_size = Integer(2**16, config=True,
339 340 help="""The maximum number of digests to remember.
340 341
341 342 The digest history will be culled when it exceeds this value.
342 343 """
343 344 )
344 345
345 346 keyfile = Unicode('', config=True,
346 347 help="""path to file containing execution key.""")
347 348 def _keyfile_changed(self, name, old, new):
348 349 with open(new, 'rb') as f:
349 350 self.key = f.read().strip()
350 351
351 352 # for protecting against sends from forks
352 353 pid = Integer()
353 354
354 355 # serialization traits:
355 356
356 357 pack = Any(default_packer) # the actual packer function
357 358 def _pack_changed(self, name, old, new):
358 359 if not callable(new):
359 360 raise TypeError("packer must be callable, not %s"%type(new))
360 361
361 362     unpack = Any(default_unpacker) # the actual unpacker function
362 363 def _unpack_changed(self, name, old, new):
363 364         # unpacker is not checked - it is assumed to be the inverse of pack
364 365 if not callable(new):
365 366 raise TypeError("unpacker must be callable, not %s"%type(new))
366 367
367 368 # thresholds:
368 369 copy_threshold = Integer(2**16, config=True,
369 370 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
370 371 buffer_threshold = Integer(MAX_BYTES, config=True,
371 372 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
372 373 item_threshold = Integer(MAX_ITEMS, config=True,
373 374 help="""The maximum number of items for a container to be introspected for custom serialization.
374 375 Containers larger than this are pickled outright.
375 376 """
376 377 )
377 378
378 379
379 380 def __init__(self, **kwargs):
380 381 """create a Session object
381 382
382 383 Parameters
383 384 ----------
384 385
385 386 debug : bool
386 387 whether to trigger extra debugging statements
387 388 packer/unpacker : str : 'json', 'pickle' or import_string
388 389 importstrings for methods to serialize message parts. If just
389 390 'json' or 'pickle', predefined JSON and pickle packers will be used.
390 391 Otherwise, the entire importstring must be used.
391 392
392 393 The functions must accept at least valid JSON input, and output
393 394 *bytes*.
394 395
395 396 For example, to use msgpack:
396 397 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
397 398 pack/unpack : callables
398 399 You can also set the pack/unpack callables for serialization
399 400 directly.
400 401 session : unicode (must be ascii)
401 402 the ID of this Session object. The default is to generate a new
402 403 UUID.
403 404 bsession : bytes
404 405 The session as bytes
405 406 username : unicode
406 407 username added to message headers. The default is to ask the OS.
407 408 key : bytes
408 409 The key used to initialize an HMAC signature. If unset, messages
409 410 will not be signed or checked.
410 411 signature_scheme : str
411 412 The message digest scheme. Currently must be of the form 'hmac-HASH',
412 413 where 'HASH' is a hashing function available in Python's hashlib.
413 414 The default is 'hmac-sha256'.
414 415 This is ignored if 'key' is empty.
415 416 keyfile : filepath
416 417 The file containing a key. If this is set, `key` will be
417 418 initialized to the contents of the file.
418 419 """
419 420 super(Session, self).__init__(**kwargs)
420 421 self._check_packers()
421 422 self.none = self.pack({})
422 423         # ensure self._session_default() is called if necessary, so bsession is defined:
423 424 self.session
424 425 self.pid = os.getpid()
425 426
426 427 @property
427 428 def msg_id(self):
428 429 """always return new uuid"""
429 430 return str(uuid.uuid4())
430 431
431 432 def _check_packers(self):
432 433 """check packers for datetime support."""
433 434 pack = self.pack
434 435 unpack = self.unpack
435 436
436 437 # check simple serialization
437 438 msg = dict(a=[1,'hi'])
438 439 try:
439 440 packed = pack(msg)
440 441 except Exception as e:
441 442 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
442 443 if self.packer == 'json':
443 444 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
444 445 else:
445 446 jsonmsg = ""
446 447 raise ValueError(
447 448 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
448 449 )
449 450
450 451 # ensure packed message is bytes
451 452 if not isinstance(packed, bytes):
452 453 raise ValueError("message packed to %r, but bytes are required"%type(packed))
453 454
454 455 # check that unpack is pack's inverse
455 456 try:
456 457 unpacked = unpack(packed)
457 458 assert unpacked == msg
458 459 except Exception as e:
459 460 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
460 461 if self.packer == 'json':
461 462 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
462 463 else:
463 464 jsonmsg = ""
464 465 raise ValueError(
465 466 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
466 467 )
467 468
468 469 # check datetime support
469 470 msg = dict(t=datetime.now())
470 471 try:
471 472 unpacked = unpack(pack(msg))
472 473 if isinstance(unpacked['t'], datetime):
473 474 raise ValueError("Shouldn't deserialize to datetime")
474 475 except Exception:
475 476 self.pack = lambda o: pack(squash_dates(o))
476 477 self.unpack = lambda s: unpack(s)
477 478
478 479 def msg_header(self, msg_type):
479 480 return msg_header(self.msg_id, msg_type, self.username, self.session)
480 481
481 482 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
482 483 """Return the nested message dict.
483 484
484 485 This format is different from what is sent over the wire. The
485 486         serialize/unserialize methods convert this nested message dict to the wire
486 487 format, which is a list of message parts.
487 488 """
488 489 msg = {}
489 490 header = self.msg_header(msg_type) if header is None else header
490 491 msg['header'] = header
491 492 msg['msg_id'] = header['msg_id']
492 493 msg['msg_type'] = header['msg_type']
493 494 msg['parent_header'] = {} if parent is None else extract_header(parent)
494 495 msg['content'] = {} if content is None else content
495 496 msg['metadata'] = self.metadata.copy()
496 497 if metadata is not None:
497 498 msg['metadata'].update(metadata)
498 499 return msg
499 500
500 501 def sign(self, msg_list):
501 502 """Sign a message with HMAC digest. If no auth, return b''.
502 503
503 504 Parameters
504 505 ----------
505 506 msg_list : list
506 507 The [p_header,p_parent,p_content] part of the message list.
507 508 """
508 509 if self.auth is None:
509 510 return b''
510 511 h = self.auth.copy()
511 512 for m in msg_list:
512 513 h.update(m)
513 514 return str_to_bytes(h.hexdigest())
514 515
515 516 def serialize(self, msg, ident=None):
516 517 """Serialize the message components to bytes.
517 518
518 519 This is roughly the inverse of unserialize. The serialize/unserialize
519 520 methods work with full message lists, whereas pack/unpack work with
520 521 the individual message parts in the message list.
521 522
522 523 Parameters
523 524 ----------
524 525 msg : dict or Message
524 525             The nested message dict as returned by the self.msg method.
526 527
527 528 Returns
528 529 -------
529 530 msg_list : list
530 531 The list of bytes objects to be sent with the format::
531 532
532 533 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
533 534 p_metadata, p_content, buffer1, buffer2, ...]
534 535
535 536 In this list, the ``p_*`` entities are the packed or serialized
536 537 versions, so if JSON is used, these are utf8 encoded JSON strings.
537 538 """
538 539 content = msg.get('content', {})
539 540 if content is None:
540 541 content = self.none
541 542 elif isinstance(content, dict):
542 543 content = self.pack(content)
543 544 elif isinstance(content, bytes):
544 545 # content is already packed, as in a relayed message
545 546 pass
546 547 elif isinstance(content, unicode_type):
547 548 # should be bytes, but JSON often spits out unicode
548 549 content = content.encode('utf8')
549 550 else:
550 551 raise TypeError("Content incorrect type: %s"%type(content))
551 552
552 553 real_message = [self.pack(msg['header']),
553 554 self.pack(msg['parent_header']),
554 555 self.pack(msg['metadata']),
555 556 content,
556 557 ]
557 558
558 559 to_send = []
559 560
560 561 if isinstance(ident, list):
561 562 # accept list of idents
562 563 to_send.extend(ident)
563 564 elif ident is not None:
564 565 to_send.append(ident)
565 566 to_send.append(DELIM)
566 567
567 568 signature = self.sign(real_message)
568 569 to_send.append(signature)
569 570
570 571 to_send.extend(real_message)
571 572
572 573 return to_send
573 574
574 575 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
575 576 buffers=None, track=False, header=None, metadata=None):
576 577 """Build and send a message via stream or socket.
577 578
578 579 The message format used by this function internally is as follows:
579 580
580 581         [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
581 582 buffer1,buffer2,...]
582 583
583 584 The serialize/unserialize methods convert the nested message dict into this
584 585 format.
585 586
586 587 Parameters
587 588 ----------
588 589
589 590 stream : zmq.Socket or ZMQStream
590 591 The socket-like object used to send the data.
591 592 msg_or_type : str or Message/dict
592 593 Normally, msg_or_type will be a msg_type unless a message is being
593 594 sent more than once. If a header is supplied, this can be set to
594 595 None and the msg_type will be pulled from the header.
595 596
596 597 content : dict or None
597 598 The content of the message (ignored if msg_or_type is a message).
598 599 header : dict or None
599 600             The header dict for the message (ignored if msg_or_type is a message).
600 601 parent : Message or dict or None
601 602 The parent or parent header describing the parent of this message
602 603 (ignored if msg_or_type is a message).
603 604 ident : bytes or list of bytes
604 605 The zmq.IDENTITY routing path.
605 606 metadata : dict or None
606 607 The metadata describing the message
607 608 buffers : list or None
608 609 The already-serialized buffers to be appended to the message.
609 610 track : bool
610 611 Whether to track. Only for use with Sockets, because ZMQStream
611 612 objects cannot track messages.
612 613
613 614
614 615 Returns
615 616 -------
616 617 msg : dict
617 618 The constructed message.
618 619 """
619 620 if not isinstance(stream, zmq.Socket):
620 621 # ZMQStreams and dummy sockets do not support tracking.
621 622 track = False
622 623
623 624 if isinstance(msg_or_type, (Message, dict)):
624 625 # We got a Message or message dict, not a msg_type so don't
625 626 # build a new Message.
626 627 msg = msg_or_type
627 628 else:
628 629 msg = self.msg(msg_or_type, content=content, parent=parent,
629 630 header=header, metadata=metadata)
630 631 if not os.getpid() == self.pid:
631 632 io.rprint("WARNING: attempted to send message from fork")
632 633 io.rprint(msg)
633 634 return
634 635 buffers = [] if buffers is None else buffers
635 636 to_send = self.serialize(msg, ident)
636 637 to_send.extend(buffers)
637 638 longest = max([ len(s) for s in to_send ])
638 639 copy = (longest < self.copy_threshold)
639 640
640 641 if buffers and track and not copy:
641 642 # only really track when we are doing zero-copy buffers
642 643 tracker = stream.send_multipart(to_send, copy=False, track=True)
643 644 else:
644 645 # use dummy tracker, which will be done immediately
645 646 tracker = DONE
646 647 stream.send_multipart(to_send, copy=copy)
647 648
648 649 if self.debug:
649 650 pprint.pprint(msg)
650 651 pprint.pprint(to_send)
651 652 pprint.pprint(buffers)
652 653
653 654 msg['tracker'] = tracker
654 655
655 656 return msg
656 657
657 658 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
658 659 """Send a raw message via ident path.
659 660
660 661         This method is used to send an already serialized message.
661 662
662 663 Parameters
663 664 ----------
664 665 stream : ZMQStream or Socket
665 666 The ZMQ stream or socket to use for sending the message.
666 667 msg_list : list
667 668 The serialized list of messages to send. This only includes the
668 669 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
669 670 the message.
670 671 ident : ident or list
671 672 A single ident or a list of idents to use in sending.
672 673 """
673 674 to_send = []
674 675 if isinstance(ident, bytes):
675 676 ident = [ident]
676 677 if ident is not None:
677 678 to_send.extend(ident)
678 679
679 680 to_send.append(DELIM)
680 681 to_send.append(self.sign(msg_list))
681 682 to_send.extend(msg_list)
682 683 stream.send_multipart(to_send, flags, copy=copy)
683 684
684 685 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
685 686 """Receive and unpack a message.
686 687
687 688 Parameters
688 689 ----------
689 690 socket : ZMQStream or Socket
690 691 The socket or stream to use in receiving.
691 692
692 693 Returns
693 694 -------
694 695 [idents], msg
695 696 [idents] is a list of idents and msg is a nested message dict of
696 697 same format as self.msg returns.
697 698 """
698 699 if isinstance(socket, ZMQStream):
699 700 socket = socket.socket
700 701 try:
701 702 msg_list = socket.recv_multipart(mode, copy=copy)
702 703 except zmq.ZMQError as e:
703 704 if e.errno == zmq.EAGAIN:
704 705 # We can convert EAGAIN to None as we know in this case
705 706 # recv_multipart won't return None.
706 707 return None,None
707 708 else:
708 709 raise
709 710 # split multipart message into identity list and message dict
710 711 # invalid large messages can cause very expensive string comparisons
711 712 idents, msg_list = self.feed_identities(msg_list, copy)
712 713 try:
713 714 return idents, self.unserialize(msg_list, content=content, copy=copy)
714 715 except Exception as e:
715 716 # TODO: handle it
716 717 raise e
717 718
718 719 def feed_identities(self, msg_list, copy=True):
719 720 """Split the identities from the rest of the message.
720 721
721 722 Feed until DELIM is reached, then return the prefix as idents and
722 723 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
723 724 but that would be silly.
724 725
725 726 Parameters
726 727 ----------
727 728 msg_list : a list of Message or bytes objects
728 729 The message to be split.
729 730 copy : bool
730 731 flag determining whether the arguments are bytes or Messages
731 732
732 733 Returns
733 734 -------
734 735 (idents, msg_list) : two lists
735 736 idents will always be a list of bytes, each of which is a ZMQ
736 737 identity. msg_list will be a list of bytes or zmq.Messages of the
737 738 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
738 739 should be unpackable/unserializable via self.unserialize at this
739 740 point.
740 741 """
741 742 if copy:
742 743 idx = msg_list.index(DELIM)
743 744 return msg_list[:idx], msg_list[idx+1:]
744 745 else:
745 746 failed = True
746 747 for idx,m in enumerate(msg_list):
747 748 if m.bytes == DELIM:
748 749 failed = False
749 750 break
750 751 if failed:
751 752 raise ValueError("DELIM not in msg_list")
752 753 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
753 754 return [m.bytes for m in idents], msg_list
754 755
755 756 def _add_digest(self, signature):
756 757 """add a digest to history to protect against replay attacks"""
757 758 if self.digest_history_size == 0:
758 759 # no history, never add digests
759 760 return
760 761
761 762 self.digest_history.add(signature)
762 763 if len(self.digest_history) > self.digest_history_size:
763 764 # threshold reached, cull 10%
764 765 self._cull_digest_history()
765 766
766 767 def _cull_digest_history(self):
767 768 """cull the digest history
768 769
769 770 Removes a randomly selected 10% of the digest history
770 771 """
771 772 current = len(self.digest_history)
772 773 n_to_cull = max(int(current // 10), current - self.digest_history_size)
773 774 if n_to_cull >= current:
774 775 self.digest_history = set()
775 776 return
776 777 to_cull = random.sample(self.digest_history, n_to_cull)
777 778 self.digest_history.difference_update(to_cull)
778 779
779 780 def unserialize(self, msg_list, content=True, copy=True):
780 781 """Unserialize a msg_list to a nested message dict.
781 782
782 783 This is roughly the inverse of serialize. The serialize/unserialize
783 784 methods work with full message lists, whereas pack/unpack work with
784 785 the individual message parts in the message list.
785 786
786 787 Parameters
787 788 ----------
788 789 msg_list : list of bytes or Message objects
789 790 The list of message parts of the form [HMAC,p_header,p_parent,
790 791 p_metadata,p_content,buffer1,buffer2,...].
791 792 content : bool (True)
792 793 Whether to unpack the content dict (True), or leave it packed
793 794 (False).
794 795 copy : bool (True)
795 796 Whether to return the bytes (True), or the non-copying Message
796 797 object in each place (False).
797 798
798 799 Returns
799 800 -------
800 801 msg : dict
801 802 The nested message dict with top-level keys [header, parent_header,
802 803 content, buffers].
803 804 """
804 805 minlen = 5
805 806 message = {}
806 807 if not copy:
807 808 for i in range(minlen):
808 809 msg_list[i] = msg_list[i].bytes
809 810 if self.auth is not None:
810 811 signature = msg_list[0]
811 812 if not signature:
812 813 raise ValueError("Unsigned Message")
813 814 if signature in self.digest_history:
814 815 raise ValueError("Duplicate Signature: %r" % signature)
815 816 self._add_digest(signature)
816 817 check = self.sign(msg_list[1:5])
817 818 if not signature == check:
818 819 raise ValueError("Invalid Signature: %r" % signature)
819 820 if not len(msg_list) >= minlen:
820 821 raise TypeError("malformed message, must have at least %i elements"%minlen)
821 822 header = self.unpack(msg_list[1])
822 823 message['header'] = extract_dates(header)
823 824 message['msg_id'] = header['msg_id']
824 825 message['msg_type'] = header['msg_type']
825 826 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
826 827 message['metadata'] = self.unpack(msg_list[3])
827 828 if content:
828 829 message['content'] = self.unpack(msg_list[4])
829 830 else:
830 831 message['content'] = msg_list[4]
831 832
832 833 message['buffers'] = msg_list[5:]
833 834 return message
834 835
835 836 def test_msg2obj():
836 837 am = dict(x=1)
837 838 ao = Message(am)
838 839 assert ao.x == am['x']
839 840
840 841 am['y'] = dict(z=1)
841 842 ao = Message(am)
842 843 assert ao.y.z == am['y']['z']
843 844
844 845 k1, k2 = 'y', 'z'
845 846 assert ao[k1][k2] == am[k1][k2]
846 847
847 848 am2 = dict(ao)
848 849 assert am['x'] == am2['x']
849 850 assert am['y']['z'] == am2['y']['z']
850 851
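A minimal sketch of the Session round trip this file implements, selecting the pickle packer so the new PICKLE_PROTOCOL constant is exercised. It assumes an IPython 1.x/2.x install with pyzmq; no sockets are opened:

    from IPython.kernel.zmq.session import Session

    s = Session(packer='pickle')  # installs pickle_packer/pickle_unpacker
    msg = s.msg('execute_request', content={'code': 'print(1)'})
    wire = s.serialize(msg)       # [DELIM, HMAC, p_header, p_parent, p_metadata, p_content]
    idents, parts = s.feed_identities(wire)
    assert s.unserialize(parts)['content'] == {'code': 'print(1)'}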
@@ -1,385 +1,390
1 1 # encoding: utf-8
2 2
3 3 """Pickle related utilities. Perhaps this should be called 'can'."""
4 4
5 5 # Copyright (c) IPython Development Team.
6 6 # Distributed under the terms of the Modified BSD License.
7 7
8 8 import copy
9 9 import logging
10 10 import sys
11 11 from types import FunctionType
12 12
13 13 try:
14 14 import cPickle as pickle
15 15 except ImportError:
16 16 import pickle
17 17
18 18 from . import codeutil # This registers a hook when it's imported
19 19 from . import py3compat
20 20 from .importstring import import_item
21 21 from .py3compat import string_types, iteritems
22 22
23 23 from IPython.config import Application
24 24
25 25 if py3compat.PY3:
26 26 buffer = memoryview
27 27 class_type = type
28 28 closure_attr = '__closure__'
29 29 else:
30 30 from types import ClassType
31 31 class_type = (type, ClassType)
32 32 closure_attr = 'func_closure'
33 33
34 try:
35 PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
36 except AttributeError:
37 PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL
38
34 39 #-------------------------------------------------------------------------------
35 40 # Functions
36 41 #-------------------------------------------------------------------------------
37 42
38 43
39 44 def use_dill():
40 45 """use dill to expand serialization support
41 46
42 47 adds support for object methods and closures to serialization.
43 48 """
44 49 # import dill causes most of the magic
45 50 import dill
46 51
47 52 # dill doesn't work with cPickle,
48 53 # tell the two relevant modules to use plain pickle
49 54
50 55 global pickle
51 56 pickle = dill
52 57
53 58 try:
54 59 from IPython.kernel.zmq import serialize
55 60 except ImportError:
56 61 pass
57 62 else:
58 63 serialize.pickle = dill
59 64
60 65 # disable special function handling, let dill take care of it
61 66 can_map.pop(FunctionType, None)
62 67
63 68
64 69 #-------------------------------------------------------------------------------
65 70 # Classes
66 71 #-------------------------------------------------------------------------------
67 72
68 73
69 74 class CannedObject(object):
70 75 def __init__(self, obj, keys=[], hook=None):
71 76 """can an object for safe pickling
72 77
73 78 Parameters
74 79 ==========
75 80
76 81 obj:
77 82 The object to be canned
78 83 keys: list (optional)
79 84 list of attribute names that will be explicitly canned / uncanned
80 85 hook: callable (optional)
81 86 An optional extra callable,
82 87 which can do additional processing of the uncanned object.
83 88
84 89 large data may be offloaded into the buffers list,
85 90 used for zero-copy transfers.
86 91 """
87 92 self.keys = keys
88 93 self.obj = copy.copy(obj)
89 94 self.hook = can(hook)
90 95 for key in keys:
91 96 setattr(self.obj, key, can(getattr(obj, key)))
92 97
93 98 self.buffers = []
94 99
95 100 def get_object(self, g=None):
96 101 if g is None:
97 102 g = {}
98 103 obj = self.obj
99 104 for key in self.keys:
100 105 setattr(obj, key, uncan(getattr(obj, key), g))
101 106
102 107 if self.hook:
103 108 self.hook = uncan(self.hook, g)
104 109 self.hook(obj, g)
105 110 return self.obj
106 111
107 112
108 113 class Reference(CannedObject):
109 114 """object for wrapping a remote reference by name."""
110 115 def __init__(self, name):
111 116 if not isinstance(name, string_types):
112 117 raise TypeError("illegal name: %r"%name)
113 118 self.name = name
114 119 self.buffers = []
115 120
116 121 def __repr__(self):
117 122 return "<Reference: %r>"%self.name
118 123
119 124 def get_object(self, g=None):
120 125 if g is None:
121 126 g = {}
122 127
123 128 return eval(self.name, g)
124 129
125 130
126 131 class CannedFunction(CannedObject):
127 132
128 133 def __init__(self, f):
129 134 self._check_type(f)
130 135 self.code = f.__code__
131 136 if f.__defaults__:
132 137 self.defaults = [ can(fd) for fd in f.__defaults__ ]
133 138 else:
134 139 self.defaults = None
135 140
136 141 if getattr(f, closure_attr, None):
137 142 raise ValueError("Sorry, cannot pickle functions with closures.")
138 143 self.module = f.__module__ or '__main__'
139 144 self.__name__ = f.__name__
140 145 self.buffers = []
141 146
142 147 def _check_type(self, obj):
143 148 assert isinstance(obj, FunctionType), "Not a function type"
144 149
145 150 def get_object(self, g=None):
146 151 # try to load function back into its module:
147 152 if not self.module.startswith('__'):
148 153 __import__(self.module)
149 154 g = sys.modules[self.module].__dict__
150 155
151 156 if g is None:
152 157 g = {}
153 158 if self.defaults:
154 159 defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
155 160 else:
156 161 defaults = None
157 162 newFunc = FunctionType(self.code, g, self.__name__, defaults)
158 163 return newFunc
159 164
160 165 class CannedClass(CannedObject):
161 166
162 167 def __init__(self, cls):
163 168 self._check_type(cls)
164 169 self.name = cls.__name__
165 170 self.old_style = not isinstance(cls, type)
166 171 self._canned_dict = {}
167 172 for k,v in cls.__dict__.items():
168 173 if k not in ('__weakref__', '__dict__'):
169 174 self._canned_dict[k] = can(v)
170 175 if self.old_style:
171 176 mro = []
172 177 else:
173 178 mro = cls.mro()
174 179
175 180 self.parents = [ can(c) for c in mro[1:] ]
176 181 self.buffers = []
177 182
178 183 def _check_type(self, obj):
179 184 assert isinstance(obj, class_type), "Not a class type"
180 185
181 186 def get_object(self, g=None):
182 187 parents = tuple(uncan(p, g) for p in self.parents)
183 188 return type(self.name, parents, uncan_dict(self._canned_dict, g=g))
184 189
185 190 class CannedArray(CannedObject):
186 191 def __init__(self, obj):
187 192 from numpy import ascontiguousarray
188 193 self.shape = obj.shape
189 194 self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
190 195 self.pickled = False
191 196 if sum(obj.shape) == 0:
192 197 self.pickled = True
193 198 elif obj.dtype == 'O':
194 199 # can't handle object dtype with buffer approach
195 200 self.pickled = True
196 201 elif obj.dtype.fields and any(dt == 'O' for dt,sz in obj.dtype.fields.values()):
197 202 self.pickled = True
198 203 if self.pickled:
199 204 # just pickle it
200 self.buffers = [pickle.dumps(obj, -1)]
205 self.buffers = [pickle.dumps(obj, PICKLE_PROTOCOL)]
201 206 else:
202 207 # ensure contiguous
203 208 obj = ascontiguousarray(obj, dtype=None)
204 209 self.buffers = [buffer(obj)]
205 210
206 211 def get_object(self, g=None):
207 212 from numpy import frombuffer
208 213 data = self.buffers[0]
209 214 if self.pickled:
210 215 # no shape, we just pickled it
211 216 return pickle.loads(data)
212 217 else:
213 218 return frombuffer(data, dtype=self.dtype).reshape(self.shape)
214 219
215 220
216 221 class CannedBytes(CannedObject):
217 222 wrap = bytes
218 223 def __init__(self, obj):
219 224 self.buffers = [obj]
220 225
221 226 def get_object(self, g=None):
222 227 data = self.buffers[0]
223 228 return self.wrap(data)
224 229
225 230 class CannedBuffer(CannedBytes):
226 231 wrap = buffer
227 232
228 233 #-------------------------------------------------------------------------------
229 234 # Functions
230 235 #-------------------------------------------------------------------------------
231 236
232 237 def _logger():
233 238 """get the logger for the current Application
234 239
235 240 the root logger will be used if no Application is running
236 241 """
237 242 if Application.initialized():
238 243 logger = Application.instance().log
239 244 else:
240 245 logger = logging.getLogger()
241 246 if not logger.handlers:
242 247 logging.basicConfig()
243 248
244 249 return logger
245 250
246 251 def _import_mapping(mapping, original=None):
247 252 """import any string-keys in a type mapping
248 253
249 254 """
250 255 log = _logger()
251 256 log.debug("Importing canning map")
252 257 for key,value in list(mapping.items()):
253 258 if isinstance(key, string_types):
254 259 try:
255 260 cls = import_item(key)
256 261 except Exception:
257 262 if original and key not in original:
258 263 # only message on user-added classes
259 264 log.error("canning class not importable: %r", key, exc_info=True)
260 265 mapping.pop(key)
261 266 else:
262 267 mapping[cls] = mapping.pop(key)
263 268
264 269 def istype(obj, check):
265 270 """like isinstance(obj, check), but strict
266 271
267 272 This won't catch subclasses.
268 273 """
269 274 if isinstance(check, tuple):
270 275 for cls in check:
271 276 if type(obj) is cls:
272 277 return True
273 278 return False
274 279 else:
275 280 return type(obj) is check
276 281
277 282 def can(obj):
278 283 """prepare an object for pickling"""
279 284
280 285 import_needed = False
281 286
282 287 for cls,canner in iteritems(can_map):
283 288 if isinstance(cls, string_types):
284 289 import_needed = True
285 290 break
286 291 elif istype(obj, cls):
287 292 return canner(obj)
288 293
289 294 if import_needed:
290 295 # perform can_map imports, then try again
291 296 # this will usually only happen once
292 297 _import_mapping(can_map, _original_can_map)
293 298 return can(obj)
294 299
295 300 return obj
296 301
297 302 def can_class(obj):
298 303 if isinstance(obj, class_type) and obj.__module__ == '__main__':
299 304 return CannedClass(obj)
300 305 else:
301 306 return obj
302 307
303 308 def can_dict(obj):
304 309 """can the *values* of a dict"""
305 310 if istype(obj, dict):
306 311 newobj = {}
307 312 for k, v in iteritems(obj):
308 313 newobj[k] = can(v)
309 314 return newobj
310 315 else:
311 316 return obj
312 317
313 318 sequence_types = (list, tuple, set)
314 319
315 320 def can_sequence(obj):
316 321 """can the elements of a sequence"""
317 322 if istype(obj, sequence_types):
318 323 t = type(obj)
319 324 return t([can(i) for i in obj])
320 325 else:
321 326 return obj
322 327
323 328 def uncan(obj, g=None):
324 329 """invert canning"""
325 330
326 331 import_needed = False
327 332 for cls,uncanner in iteritems(uncan_map):
328 333 if isinstance(cls, string_types):
329 334 import_needed = True
330 335 break
331 336 elif isinstance(obj, cls):
332 337 return uncanner(obj, g)
333 338
334 339 if import_needed:
335 340 # perform uncan_map imports, then try again
336 341 # this will usually only happen once
337 342 _import_mapping(uncan_map, _original_uncan_map)
338 343 return uncan(obj, g)
339 344
340 345 return obj
341 346
342 347 def uncan_dict(obj, g=None):
343 348 if istype(obj, dict):
344 349 newobj = {}
345 350 for k, v in iteritems(obj):
346 351 newobj[k] = uncan(v,g)
347 352 return newobj
348 353 else:
349 354 return obj
350 355
351 356 def uncan_sequence(obj, g=None):
352 357 if istype(obj, sequence_types):
353 358 t = type(obj)
354 359 return t([uncan(i,g) for i in obj])
355 360 else:
356 361 return obj
357 362
358 363 def _uncan_dependent_hook(dep, g=None):
359 364 dep.check_dependency()
360 365
361 366 def can_dependent(obj):
362 367 return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook)
363 368
364 369 #-------------------------------------------------------------------------------
365 370 # API dictionaries
366 371 #-------------------------------------------------------------------------------
367 372
368 373 # These dicts can be extended for custom serialization of new objects
369 374
370 375 can_map = {
371 376 'IPython.parallel.dependent' : can_dependent,
372 377 'numpy.ndarray' : CannedArray,
373 378 FunctionType : CannedFunction,
374 379 bytes : CannedBytes,
375 380 buffer : CannedBuffer,
376 381 class_type : can_class,
377 382 }
378 383
379 384 uncan_map = {
380 385 CannedObject : lambda obj, g: obj.get_object(g),
381 386 }
382 387
383 388 # for use in _import_mapping:
384 389 _original_can_map = can_map.copy()
385 390 _original_uncan_map = uncan_map.copy()
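The heart of the backport is the PICKLE_PROTOCOL fallback near the top of this file, sketched standalone below. Python 3 exposes pickle.DEFAULT_PROTOCOL, a deliberately conservative choice; Python 2 does not, so HIGHEST_PROTOCOL (protocol 2 there) is used instead. The likely motivation for replacing the hard-coded -1 (always HIGHEST_PROTOCOL) is to keep newer interpreters from emitting pickles in a protocol an older peer cannot read:

    import pickle

    try:
        PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL  # Python 3 only
    except AttributeError:
        PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL  # Python 2 fallback (protocol 2)

    blob = pickle.dumps({'x': 1}, PICKLE_PROTOCOL)
    assert pickle.loads(blob) == {'x': 1}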