add pickleutil.PICKLE_PROTOCOL...
MinRK
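The commit replaces the hard-coded pickle protocol -1 (which always selects pickle.HIGHEST_PROTOCOL) with a shared PICKLE_PROTOCOL constant. A minimal, self-contained sketch of the selection logic added to pickleutil: Python 3 exposes pickle.DEFAULT_PROTOCOL, Python 2 does not, so the code falls back to HIGHEST_PROTOCOL there.

    import pickle

    try:
        # Python 3: a deliberately conservative default protocol
        PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
    except AttributeError:
        # Python 2 has no DEFAULT_PROTOCOL; use the highest available
        PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL

    data = pickle.dumps({'a': 1}, PICKLE_PROTOCOL)
    assert pickle.loads(data) == {'a': 1}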
@@ -1,198 +1,185 b''
1 """serialization utilities for apply messages
1 """serialization utilities for apply messages"""
2 2
3 Authors:
4
5 * Min RK
6 """
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
13
14 #-----------------------------------------------------------------------------
15 # Imports
16 #-----------------------------------------------------------------------------
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
17 5
18 6 try:
19 7 import cPickle
20 8 pickle = cPickle
21 9 except ImportError:
22 10 cPickle = None
23 11 import pickle
24 12
25
26 13 # IPython imports
27 14 from IPython.utils import py3compat
28 15 from IPython.utils.data import flatten
29 16 from IPython.utils.pickleutil import (
30 17 can, uncan, can_sequence, uncan_sequence, CannedObject,
31 istype, sequence_types,
18 istype, sequence_types, PICKLE_PROTOCOL,
32 19 )
33 20
34 21 if py3compat.PY3:
35 22 buffer = memoryview
36 23
37 24 #-----------------------------------------------------------------------------
38 25 # Serialization Functions
39 26 #-----------------------------------------------------------------------------
40 27
41 28 # default values for the thresholds:
42 29 MAX_ITEMS = 64
43 30 MAX_BYTES = 1024
44 31
45 32 def _extract_buffers(obj, threshold=MAX_BYTES):
46 33 """extract buffers larger than a certain threshold"""
47 34 buffers = []
48 35 if isinstance(obj, CannedObject) and obj.buffers:
49 36 for i,buf in enumerate(obj.buffers):
50 37 if len(buf) > threshold:
51 38 # buffer larger than threshold, prevent pickling
52 39 obj.buffers[i] = None
53 40 buffers.append(buf)
54 41 elif isinstance(buf, buffer):
55 42 # buffer too small for separate send, coerce to bytes
56 43 # because pickling buffer objects just results in broken pointers
57 44 obj.buffers[i] = bytes(buf)
58 45 return buffers
59 46
60 47 def _restore_buffers(obj, buffers):
61 48 """restore buffers extracted by """
62 49 if isinstance(obj, CannedObject) and obj.buffers:
63 50 for i,buf in enumerate(obj.buffers):
64 51 if buf is None:
65 52 obj.buffers[i] = buffers.pop(0)
66 53
67 54 def serialize_object(obj, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
68 55 """Serialize an object into a list of sendable buffers.
69 56
70 57 Parameters
71 58 ----------
72 59
73 60 obj : object
74 61 The object to be serialized
75 62 buffer_threshold : int
76 63 The threshold (in bytes) for pulling out data buffers
77 64 to avoid pickling them.
78 65 item_threshold : int
79 66 The maximum number of items over which canning will iterate.
80 67 Containers (lists, dicts) larger than this will be pickled without
81 68 introspection.
82 69
83 70 Returns
84 71 -------
85 72 [bufs] : list of buffers representing the serialized object.
86 73 """
87 74 buffers = []
88 75 if istype(obj, sequence_types) and len(obj) < item_threshold:
89 76 cobj = can_sequence(obj)
90 77 for c in cobj:
91 78 buffers.extend(_extract_buffers(c, buffer_threshold))
92 79 elif istype(obj, dict) and len(obj) < item_threshold:
93 80 cobj = {}
94 81 for k in sorted(obj):
95 82 c = can(obj[k])
96 83 buffers.extend(_extract_buffers(c, buffer_threshold))
97 84 cobj[k] = c
98 85 else:
99 86 cobj = can(obj)
100 87 buffers.extend(_extract_buffers(cobj, buffer_threshold))
101 88
102 buffers.insert(0, pickle.dumps(cobj,-1))
89 buffers.insert(0, pickle.dumps(cobj, PICKLE_PROTOCOL))
103 90 return buffers
104 91
105 92 def unserialize_object(buffers, g=None):
106 93 """reconstruct an object serialized by serialize_object from data buffers.
107 94
108 95 Parameters
109 96 ----------
110 97
111 98 buffers : list of buffers/bytes
112 99
113 100 g : globals to be used when uncanning
114 101
115 102 Returns
116 103 -------
117 104
118 105 (newobj, bufs) : unpacked object, and the list of remaining unused buffers.
119 106 """
120 107 bufs = list(buffers)
121 108 pobj = bufs.pop(0)
122 109 if not isinstance(pobj, bytes):
123 110 # a zmq message
124 111 pobj = bytes(pobj)
125 112 canned = pickle.loads(pobj)
126 113 if istype(canned, sequence_types) and len(canned) < MAX_ITEMS:
127 114 for c in canned:
128 115 _restore_buffers(c, bufs)
129 116 newobj = uncan_sequence(canned, g)
130 117 elif istype(canned, dict) and len(canned) < MAX_ITEMS:
131 118 newobj = {}
132 119 for k in sorted(canned):
133 120 c = canned[k]
134 121 _restore_buffers(c, bufs)
135 122 newobj[k] = uncan(c, g)
136 123 else:
137 124 _restore_buffers(canned, bufs)
138 125 newobj = uncan(canned, g)
139 126
140 127 return newobj, bufs
141 128
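For reference, a minimal round trip through serialize_object/unserialize_object; this sketch assumes the module is importable as IPython.kernel.zmq.serialize (the path used by the session import below).

    from IPython.kernel.zmq.serialize import serialize_object, unserialize_object

    # bufs[0] is the pickled canned object; buffers larger than
    # buffer_threshold would follow as extra list entries
    bufs = serialize_object({'x': 1, 'y': list(range(5))})
    obj, remaining = unserialize_object(bufs)
    assert obj == {'x': 1, 'y': list(range(5))}
    assert remaining == []  # no extracted buffers left over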
142 129 def pack_apply_message(f, args, kwargs, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
143 130 """pack up a function, args, and kwargs to be sent over the wire
144 131
145 132 Each element of args/kwargs will be canned for special treatment,
146 133 but inspection will not go any deeper than that.
147 134
148 135 Any object whose data is larger than `buffer_threshold` will not have its data copied
149 136 (only numpy arrays and bytes/buffers support zero-copy)
150 137
151 138 Message will be a list of bytes/buffers of the format:
152 139
153 140 [ cf, pinfo, <arg_bufs>, <kwarg_bufs> ]
154 141
155 142 With length at least two + len(args) + len(kwargs)
156 143 """
157 144
158 145 arg_bufs = flatten(serialize_object(arg, buffer_threshold, item_threshold) for arg in args)
159 146
160 147 kw_keys = sorted(kwargs.keys())
161 148 kwarg_bufs = flatten(serialize_object(kwargs[key], buffer_threshold, item_threshold) for key in kw_keys)
162 149
163 150 info = dict(nargs=len(args), narg_bufs=len(arg_bufs), kw_keys=kw_keys)
164 151
165 msg = [pickle.dumps(can(f),-1)]
166 msg.append(pickle.dumps(info, -1))
152 msg = [pickle.dumps(can(f), PICKLE_PROTOCOL)]
153 msg.append(pickle.dumps(info, PICKLE_PROTOCOL))
167 154 msg.extend(arg_bufs)
168 155 msg.extend(kwarg_bufs)
169 156
170 157 return msg
171 158
172 159 def unpack_apply_message(bufs, g=None, copy=True):
173 160 """unpack f,args,kwargs from buffers packed by pack_apply_message()
174 161 Returns: original f,args,kwargs"""
175 162 bufs = list(bufs) # allow us to pop
176 163 assert len(bufs) >= 2, "not enough buffers!"
177 164 if not copy:
178 165 for i in range(2):
179 166 bufs[i] = bufs[i].bytes
180 167 f = uncan(pickle.loads(bufs.pop(0)), g)
181 168 info = pickle.loads(bufs.pop(0))
182 169 arg_bufs, kwarg_bufs = bufs[:info['narg_bufs']], bufs[info['narg_bufs']:]
183 170
184 171 args = []
185 172 for i in range(info['nargs']):
186 173 arg, arg_bufs = unserialize_object(arg_bufs, g)
187 174 args.append(arg)
188 175 args = tuple(args)
189 176 assert not arg_bufs, "Shouldn't be any arg bufs left over"
190 177
191 178 kwargs = {}
192 179 for key in info['kw_keys']:
193 180 kwarg, kwarg_bufs = unserialize_object(kwarg_bufs, g)
194 181 kwargs[key] = kwarg
195 182 assert not kwarg_bufs, "Shouldn't be any kwarg bufs left over"
196 183
197 184 return f,args,kwargs
198 185
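A sketch of the apply-message round trip defined by the two functions above, under the same import-path assumption:

    from IPython.kernel.zmq.serialize import pack_apply_message, unpack_apply_message

    def add(a, b=1):
        return a + b

    # msg is [pickled f, pickled info, arg bufs..., kwarg bufs...]
    msg = pack_apply_message(add, (2,), {'b': 3})
    f, args, kwargs = unpack_apply_message(msg)
    assert f(*args, **kwargs) == 5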
@@ -1,856 +1,857 b''
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9 9 """
10 10
11 11 # Copyright (c) IPython Development Team.
12 12 # Distributed under the terms of the Modified BSD License.
13 13
14 14 import hashlib
15 15 import hmac
16 16 import logging
17 17 import os
18 18 import pprint
19 19 import random
20 20 import uuid
21 21 from datetime import datetime
22 22
23 23 try:
24 24 import cPickle
25 25 pickle = cPickle
26 26 except ImportError:
27 27 cPickle = None
28 28 import pickle
29 29
30 30 import zmq
31 31 from zmq.utils import jsonapi
32 32 from zmq.eventloop.ioloop import IOLoop
33 33 from zmq.eventloop.zmqstream import ZMQStream
34 34
35 35 from IPython.core.release import kernel_protocol_version, kernel_protocol_version_info
36 36 from IPython.config.configurable import Configurable, LoggingConfigurable
37 37 from IPython.utils import io
38 38 from IPython.utils.importstring import import_item
39 39 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
40 40 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
41 41 iteritems)
42 42 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
43 43 DottedObjectName, CUnicode, Dict, Integer,
44 44 TraitError,
45 45 )
46 from IPython.utils.pickleutil import PICKLE_PROTOCOL
46 47 from IPython.kernel.adapter import adapt
47 48 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
48 49
49 50 #-----------------------------------------------------------------------------
50 51 # utility functions
51 52 #-----------------------------------------------------------------------------
52 53
53 54 def squash_unicode(obj):
54 55 """coerce unicode back to bytestrings."""
55 56 if isinstance(obj,dict):
56 57 for key in obj.keys():
57 58 obj[key] = squash_unicode(obj[key])
58 59 if isinstance(key, unicode_type):
59 60 obj[squash_unicode(key)] = obj.pop(key)
60 61 elif isinstance(obj, list):
61 62 for i,v in enumerate(obj):
62 63 obj[i] = squash_unicode(v)
63 64 elif isinstance(obj, unicode_type):
64 65 obj = obj.encode('utf8')
65 66 return obj
66 67
67 68 #-----------------------------------------------------------------------------
68 69 # globals and defaults
69 70 #-----------------------------------------------------------------------------
70 71
71 72 # ISO8601-ify datetime objects
72 73 # allow unicode
73 74 # disallow nan, because it's not actually valid JSON
74 75 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
75 76 ensure_ascii=False, allow_nan=False,
76 77 )
77 78 json_unpacker = lambda s: jsonapi.loads(s)
78 79
79 pickle_packer = lambda o: pickle.dumps(squash_dates(o),-1)
80 pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
80 81 pickle_unpacker = pickle.loads
81 82
82 83 default_packer = json_packer
83 84 default_unpacker = json_unpacker
84 85
85 86 DELIM = b"<IDS|MSG>"
86 87 # singleton dummy tracker, which will always report as done
87 88 DONE = zmq.MessageTracker()
88 89
89 90 #-----------------------------------------------------------------------------
90 91 # Mixin tools for apps that use Sessions
91 92 #-----------------------------------------------------------------------------
92 93
93 94 session_aliases = dict(
94 95 ident = 'Session.session',
95 96 user = 'Session.username',
96 97 keyfile = 'Session.keyfile',
97 98 )
98 99
99 100 session_flags = {
100 101 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
101 102 'keyfile' : '' }},
102 103 """Use HMAC digests for authentication of messages.
103 104 Setting this flag will generate a new UUID to use as the HMAC key.
104 105 """),
105 106 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
106 107 """Don't authenticate messages."""),
107 108 }
108 109
109 110 def default_secure(cfg):
110 111 """Set the default behavior for a config environment to be secure.
111 112
112 113 If Session.key/keyfile have not been set, set Session.key to
113 114 a new random UUID.
114 115 """
115 116
116 117 if 'Session' in cfg:
117 118 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
118 119 return
119 120 # key/keyfile not specified, generate new UUID:
120 121 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
121 122
122 123
123 124 #-----------------------------------------------------------------------------
124 125 # Classes
125 126 #-----------------------------------------------------------------------------
126 127
127 128 class SessionFactory(LoggingConfigurable):
128 129 """The Base class for configurables that have a Session, Context, logger,
129 130 and IOLoop.
130 131 """
131 132
132 133 logname = Unicode('')
133 134 def _logname_changed(self, name, old, new):
134 135 self.log = logging.getLogger(new)
135 136
136 137 # not configurable:
137 138 context = Instance('zmq.Context')
138 139 def _context_default(self):
139 140 return zmq.Context.instance()
140 141
141 142 session = Instance('IPython.kernel.zmq.session.Session')
142 143
143 144 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
144 145 def _loop_default(self):
145 146 return IOLoop.instance()
146 147
147 148 def __init__(self, **kwargs):
148 149 super(SessionFactory, self).__init__(**kwargs)
149 150
150 151 if self.session is None:
151 152 # construct the session
152 153 self.session = Session(**kwargs)
153 154
154 155
155 156 class Message(object):
156 157 """A simple message object that maps dict keys to attributes.
157 158
158 159 A Message can be created from a dict and a dict from a Message instance
159 160 simply by calling dict(msg_obj)."""
160 161
161 162 def __init__(self, msg_dict):
162 163 dct = self.__dict__
163 164 for k, v in iteritems(dict(msg_dict)):
164 165 if isinstance(v, dict):
165 166 v = Message(v)
166 167 dct[k] = v
167 168
168 169 # Having this iterator lets dict(msg_obj) work out of the box.
169 170 def __iter__(self):
170 171 return iter(iteritems(self.__dict__))
171 172
172 173 def __repr__(self):
173 174 return repr(self.__dict__)
174 175
175 176 def __str__(self):
176 177 return pprint.pformat(self.__dict__)
177 178
178 179 def __contains__(self, k):
179 180 return k in self.__dict__
180 181
181 182 def __getitem__(self, k):
182 183 return self.__dict__[k]
183 184
184 185
185 186 def msg_header(msg_id, msg_type, username, session):
186 187 date = datetime.now()
187 188 version = kernel_protocol_version
188 189 return locals()
189 190
190 191 def extract_header(msg_or_header):
191 192 """Given a message or header, return the header."""
192 193 if not msg_or_header:
193 194 return {}
194 195 try:
195 196 # See if msg_or_header is the entire message.
196 197 h = msg_or_header['header']
197 198 except KeyError:
198 199 try:
199 200 # See if msg_or_header is just the header
200 201 h = msg_or_header['msg_id']
201 202 except KeyError:
202 203 raise
203 204 else:
204 205 h = msg_or_header
205 206 if not isinstance(h, dict):
206 207 h = dict(h)
207 208 return h
208 209
209 210 class Session(Configurable):
210 211 """Object for handling serialization and sending of messages.
211 212
212 213 The Session object handles building messages and sending them
213 214 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
214 215 other over the network via Session objects, and only need to work with the
215 216 dict-based IPython message spec. The Session will handle
216 217 serialization/deserialization, security, and metadata.
217 218
218 219 Sessions support configurable serialization via packer/unpacker traits,
219 220 and signing with HMAC digests via the key/keyfile traits.
220 221
221 222 Parameters
222 223 ----------
223 224
224 225 debug : bool
225 226 whether to trigger extra debugging statements
226 227 packer/unpacker : str : 'json', 'pickle' or import_string
227 228 importstrings for methods to serialize message parts. If just
228 229 'json' or 'pickle', predefined JSON and pickle packers will be used.
229 230 Otherwise, the entire importstring must be used.
230 231
231 232 The functions must accept at least valid JSON input, and output *bytes*.
232 233
233 234 For example, to use msgpack:
234 235 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
235 236 pack/unpack : callables
236 237 You can also set the pack/unpack callables for serialization directly.
237 238 session : bytes
238 239 the ID of this Session object. The default is to generate a new UUID.
239 240 username : unicode
240 241 username added to message headers. The default is to ask the OS.
241 242 key : bytes
242 243 The key used to initialize an HMAC signature. If unset, messages
243 244 will not be signed or checked.
244 245 keyfile : filepath
245 246 The file containing a key. If this is set, `key` will be initialized
246 247 to the contents of the file.
247 248
248 249 """
249 250
250 251 debug=Bool(False, config=True, help="""Debug output in the Session""")
251 252
252 253 packer = DottedObjectName('json',config=True,
253 254 help="""The name of the packer for serializing messages.
254 255 Should be one of 'json', 'pickle', or an import name
255 256 for a custom callable serializer.""")
256 257 def _packer_changed(self, name, old, new):
257 258 if new.lower() == 'json':
258 259 self.pack = json_packer
259 260 self.unpack = json_unpacker
260 261 self.unpacker = new
261 262 elif new.lower() == 'pickle':
262 263 self.pack = pickle_packer
263 264 self.unpack = pickle_unpacker
264 265 self.unpacker = new
265 266 else:
266 267 self.pack = import_item(str(new))
267 268
268 269 unpacker = DottedObjectName('json', config=True,
269 270 help="""The name of the unpacker for unserializing messages.
270 271 Only used with custom functions for `packer`.""")
271 272 def _unpacker_changed(self, name, old, new):
272 273 if new.lower() == 'json':
273 274 self.pack = json_packer
274 275 self.unpack = json_unpacker
275 276 self.packer = new
276 277 elif new.lower() == 'pickle':
277 278 self.pack = pickle_packer
278 279 self.unpack = pickle_unpacker
279 280 self.packer = new
280 281 else:
281 282 self.unpack = import_item(str(new))
282 283
283 284 session = CUnicode(u'', config=True,
284 285 help="""The UUID identifying this session.""")
285 286 def _session_default(self):
286 287 u = unicode_type(uuid.uuid4())
287 288 self.bsession = u.encode('ascii')
288 289 return u
289 290
290 291 def _session_changed(self, name, old, new):
291 292 self.bsession = self.session.encode('ascii')
292 293
293 294 # bsession is the session as bytes
294 295 bsession = CBytes(b'')
295 296
296 297 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
297 298 help="""Username for the Session. Default is your system username.""",
298 299 config=True)
299 300
300 301 metadata = Dict({}, config=True,
301 302 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
302 303
303 304 # if 0, no adapting to do.
304 305 adapt_version = Integer(0)
305 306
306 307 # message signature related traits:
307 308
308 309 key = CBytes(b'', config=True,
309 310 help="""execution key, for extra authentication.""")
310 311 def _key_changed(self):
311 312 self._new_auth()
312 313
313 314 signature_scheme = Unicode('hmac-sha256', config=True,
314 315 help="""The digest scheme used to construct the message signatures.
315 316 Must have the form 'hmac-HASH'.""")
316 317 def _signature_scheme_changed(self, name, old, new):
317 318 if not new.startswith('hmac-'):
318 319 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
319 320 hash_name = new.split('-', 1)[1]
320 321 try:
321 322 self.digest_mod = getattr(hashlib, hash_name)
322 323 except AttributeError:
323 324 raise TraitError("hashlib has no such attribute: %s" % hash_name)
324 325 self._new_auth()
325 326
326 327 digest_mod = Any()
327 328 def _digest_mod_default(self):
328 329 return hashlib.sha256
329 330
330 331 auth = Instance(hmac.HMAC)
331 332
332 333 def _new_auth(self):
333 334 if self.key:
334 335 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
335 336 else:
336 337 self.auth = None
337 338
338 339 digest_history = Set()
339 340 digest_history_size = Integer(2**16, config=True,
340 341 help="""The maximum number of digests to remember.
341 342
342 343 The digest history will be culled when it exceeds this value.
343 344 """
344 345 )
345 346
346 347 keyfile = Unicode('', config=True,
347 348 help="""path to file containing execution key.""")
348 349 def _keyfile_changed(self, name, old, new):
349 350 with open(new, 'rb') as f:
350 351 self.key = f.read().strip()
351 352
352 353 # for protecting against sends from forks
353 354 pid = Integer()
354 355
355 356 # serialization traits:
356 357
357 358 pack = Any(default_packer) # the actual packer function
358 359 def _pack_changed(self, name, old, new):
359 360 if not callable(new):
360 361 raise TypeError("packer must be callable, not %s"%type(new))
361 362
362 363 unpack = Any(default_unpacker) # the actual unpacker function
363 364 def _unpack_changed(self, name, old, new):
364 365 # unpacker is not checked against packer - it is assumed to be its inverse
365 366 if not callable(new):
366 367 raise TypeError("unpacker must be callable, not %s"%type(new))
367 368
368 369 # thresholds:
369 370 copy_threshold = Integer(2**16, config=True,
370 371 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
371 372 buffer_threshold = Integer(MAX_BYTES, config=True,
372 373 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
373 374 item_threshold = Integer(MAX_ITEMS, config=True,
374 375 help="""The maximum number of items for a container to be introspected for custom serialization.
375 376 Containers larger than this are pickled outright.
376 377 """
377 378 )
378 379
379 380
380 381 def __init__(self, **kwargs):
381 382 """create a Session object
382 383
383 384 Parameters
384 385 ----------
385 386
386 387 debug : bool
387 388 whether to trigger extra debugging statements
388 389 packer/unpacker : str : 'json', 'pickle' or import_string
389 390 importstrings for methods to serialize message parts. If just
390 391 'json' or 'pickle', predefined JSON and pickle packers will be used.
391 392 Otherwise, the entire importstring must be used.
392 393
393 394 The functions must accept at least valid JSON input, and output
394 395 *bytes*.
395 396
396 397 For example, to use msgpack:
397 398 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
398 399 pack/unpack : callables
399 400 You can also set the pack/unpack callables for serialization
400 401 directly.
401 402 session : unicode (must be ascii)
402 403 the ID of this Session object. The default is to generate a new
403 404 UUID.
404 405 bsession : bytes
405 406 The session as bytes
406 407 username : unicode
407 408 username added to message headers. The default is to ask the OS.
408 409 key : bytes
409 410 The key used to initialize an HMAC signature. If unset, messages
410 411 will not be signed or checked.
411 412 signature_scheme : str
412 413 The message digest scheme. Currently must be of the form 'hmac-HASH',
413 414 where 'HASH' is a hashing function available in Python's hashlib.
414 415 The default is 'hmac-sha256'.
415 416 This is ignored if 'key' is empty.
416 417 keyfile : filepath
417 418 The file containing a key. If this is set, `key` will be
418 419 initialized to the contents of the file.
419 420 """
420 421 super(Session, self).__init__(**kwargs)
421 422 self._check_packers()
422 423 self.none = self.pack({})
423 424 # ensure self._session_default() if necessary, so bsession is defined:
424 425 self.session
425 426 self.pid = os.getpid()
426 427
427 428 @property
428 429 def msg_id(self):
429 430 """always return new uuid"""
430 431 return str(uuid.uuid4())
431 432
432 433 def _check_packers(self):
433 434 """check packers for datetime support."""
434 435 pack = self.pack
435 436 unpack = self.unpack
436 437
437 438 # check simple serialization
438 439 msg = dict(a=[1,'hi'])
439 440 try:
440 441 packed = pack(msg)
441 442 except Exception as e:
442 443 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
443 444 if self.packer == 'json':
444 445 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
445 446 else:
446 447 jsonmsg = ""
447 448 raise ValueError(
448 449 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
449 450 )
450 451
451 452 # ensure packed message is bytes
452 453 if not isinstance(packed, bytes):
453 454 raise ValueError("message packed to %r, but bytes are required"%type(packed))
454 455
455 456 # check that unpack is pack's inverse
456 457 try:
457 458 unpacked = unpack(packed)
458 459 assert unpacked == msg
459 460 except Exception as e:
460 461 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
461 462 if self.packer == 'json':
462 463 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
463 464 else:
464 465 jsonmsg = ""
465 466 raise ValueError(
466 467 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
467 468 )
468 469
469 470 # check datetime support
470 471 msg = dict(t=datetime.now())
471 472 try:
472 473 unpacked = unpack(pack(msg))
473 474 if isinstance(unpacked['t'], datetime):
474 475 raise ValueError("Shouldn't deserialize to datetime")
475 476 except Exception:
476 477 self.pack = lambda o: pack(squash_dates(o))
477 478 self.unpack = lambda s: unpack(s)
478 479
479 480 def msg_header(self, msg_type):
480 481 return msg_header(self.msg_id, msg_type, self.username, self.session)
481 482
482 483 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
483 484 """Return the nested message dict.
484 485
485 486 This format is different from what is sent over the wire. The
486 487 serialize/unserialize methods convert this nested message dict to the wire
487 488 format, which is a list of message parts.
488 489 """
489 490 msg = {}
490 491 header = self.msg_header(msg_type) if header is None else header
491 492 msg['header'] = header
492 493 msg['msg_id'] = header['msg_id']
493 494 msg['msg_type'] = header['msg_type']
494 495 msg['parent_header'] = {} if parent is None else extract_header(parent)
495 496 msg['content'] = {} if content is None else content
496 497 msg['metadata'] = self.metadata.copy()
497 498 if metadata is not None:
498 499 msg['metadata'].update(metadata)
499 500 return msg
500 501
501 502 def sign(self, msg_list):
502 503 """Sign a message with HMAC digest. If no auth, return b''.
503 504
504 505 Parameters
505 506 ----------
506 507 msg_list : list
507 508 The [p_header,p_parent,p_content] part of the message list.
508 509 """
509 510 if self.auth is None:
510 511 return b''
511 512 h = self.auth.copy()
512 513 for m in msg_list:
513 514 h.update(m)
514 515 return str_to_bytes(h.hexdigest())
515 516
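An equivalent standalone computation of the signature, with a hypothetical key and placeholder message parts; note that sign() works on a copy so the shared HMAC state is never mutated:

    import hashlib
    import hmac

    key = b'hypothetical-session-key'
    auth = hmac.HMAC(key, digestmod=hashlib.sha256)

    h = auth.copy()  # copy first: the shared auth object must stay clean
    for part in (b'p_header', b'p_parent', b'p_metadata', b'p_content'):
        h.update(part)
    signature = h.hexdigest().encode('ascii')  # bytes, as sign() returns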
516 517 def serialize(self, msg, ident=None):
517 518 """Serialize the message components to bytes.
518 519
519 520 This is roughly the inverse of unserialize. The serialize/unserialize
520 521 methods work with full message lists, whereas pack/unpack work with
521 522 the individual message parts in the message list.
522 523
523 524 Parameters
524 525 ----------
525 526 msg : dict or Message
526 527 The nested message dict as returned by the self.msg method.
527 528
528 529 Returns
529 530 -------
530 531 msg_list : list
531 532 The list of bytes objects to be sent with the format::
532 533
533 534 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
534 535 p_metadata, p_content, buffer1, buffer2, ...]
535 536
536 537 In this list, the ``p_*`` entities are the packed or serialized
537 538 versions, so if JSON is used, these are utf8 encoded JSON strings.
538 539 """
539 540 content = msg.get('content', {})
540 541 if content is None:
541 542 content = self.none
542 543 elif isinstance(content, dict):
543 544 content = self.pack(content)
544 545 elif isinstance(content, bytes):
545 546 # content is already packed, as in a relayed message
546 547 pass
547 548 elif isinstance(content, unicode_type):
548 549 # should be bytes, but JSON often spits out unicode
549 550 content = content.encode('utf8')
550 551 else:
551 552 raise TypeError("Content incorrect type: %s"%type(content))
552 553
553 554 real_message = [self.pack(msg['header']),
554 555 self.pack(msg['parent_header']),
555 556 self.pack(msg['metadata']),
556 557 content,
557 558 ]
558 559
559 560 to_send = []
560 561
561 562 if isinstance(ident, list):
562 563 # accept list of idents
563 564 to_send.extend(ident)
564 565 elif ident is not None:
565 566 to_send.append(ident)
566 567 to_send.append(DELIM)
567 568
568 569 signature = self.sign(real_message)
569 570 to_send.append(signature)
570 571
571 572 to_send.extend(real_message)
572 573
573 574 return to_send
574 575
575 576 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
576 577 buffers=None, track=False, header=None, metadata=None):
577 578 """Build and send a message via stream or socket.
578 579
579 580 The message format used by this function internally is as follows:
580 581
581 582 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
582 583 buffer1,buffer2,...]
583 584
584 585 The serialize/unserialize methods convert the nested message dict into this
585 586 format.
586 587
587 588 Parameters
588 589 ----------
589 590
590 591 stream : zmq.Socket or ZMQStream
591 592 The socket-like object used to send the data.
592 593 msg_or_type : str or Message/dict
593 594 Normally, msg_or_type will be a msg_type unless a message is being
594 595 sent more than once. If a header is supplied, this can be set to
595 596 None and the msg_type will be pulled from the header.
596 597
597 598 content : dict or None
598 599 The content of the message (ignored if msg_or_type is a message).
599 600 header : dict or None
600 601 The header dict for the message (ignored if msg_or_type is a message).
601 602 parent : Message or dict or None
602 603 The parent or parent header describing the parent of this message
603 604 (ignored if msg_or_type is a message).
604 605 ident : bytes or list of bytes
605 606 The zmq.IDENTITY routing path.
606 607 metadata : dict or None
607 608 The metadata describing the message
608 609 buffers : list or None
609 610 The already-serialized buffers to be appended to the message.
610 611 track : bool
611 612 Whether to track. Only for use with Sockets, because ZMQStream
612 613 objects cannot track messages.
613 614
614 615
615 616 Returns
616 617 -------
617 618 msg : dict
618 619 The constructed message.
619 620 """
620 621 if not isinstance(stream, zmq.Socket):
621 622 # ZMQStreams and dummy sockets do not support tracking.
622 623 track = False
623 624
624 625 if isinstance(msg_or_type, (Message, dict)):
625 626 # We got a Message or message dict, not a msg_type so don't
626 627 # build a new Message.
627 628 msg = msg_or_type
628 629 else:
629 630 msg = self.msg(msg_or_type, content=content, parent=parent,
630 631 header=header, metadata=metadata)
631 632 if not os.getpid() == self.pid:
632 633 io.rprint("WARNING: attempted to send message from fork")
633 634 io.rprint(msg)
634 635 return
635 636 buffers = [] if buffers is None else buffers
636 637 if self.adapt_version:
637 638 msg = adapt(msg, self.adapt_version)
638 639 to_send = self.serialize(msg, ident)
639 640 to_send.extend(buffers)
640 641 longest = max([ len(s) for s in to_send ])
641 642 copy = (longest < self.copy_threshold)
642 643
643 644 if buffers and track and not copy:
644 645 # only really track when we are doing zero-copy buffers
645 646 tracker = stream.send_multipart(to_send, copy=False, track=True)
646 647 else:
647 648 # use dummy tracker, which will be done immediately
648 649 tracker = DONE
649 650 stream.send_multipart(to_send, copy=copy)
650 651
651 652 if self.debug:
652 653 pprint.pprint(msg)
653 654 pprint.pprint(to_send)
654 655 pprint.pprint(buffers)
655 656
656 657 msg['tracker'] = tracker
657 658
658 659 return msg
659 660
660 661 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
661 662 """Send a raw message via ident path.
662 663
663 664 This method is used to send an already serialized message.
664 665
665 666 Parameters
666 667 ----------
667 668 stream : ZMQStream or Socket
668 669 The ZMQ stream or socket to use for sending the message.
669 670 msg_list : list
670 671 The serialized list of messages to send. This only includes the
671 672 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
672 673 the message.
673 674 ident : ident or list
674 675 A single ident or a list of idents to use in sending.
675 676 """
676 677 to_send = []
677 678 if isinstance(ident, bytes):
678 679 ident = [ident]
679 680 if ident is not None:
680 681 to_send.extend(ident)
681 682
682 683 to_send.append(DELIM)
683 684 to_send.append(self.sign(msg_list))
684 685 to_send.extend(msg_list)
685 686 stream.send_multipart(to_send, flags, copy=copy)
686 687
687 688 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
688 689 """Receive and unpack a message.
689 690
690 691 Parameters
691 692 ----------
692 693 socket : ZMQStream or Socket
693 694 The socket or stream to use in receiving.
694 695
695 696 Returns
696 697 -------
697 698 [idents], msg
698 699 [idents] is a list of idents and msg is a nested message dict of
699 700 same format as self.msg returns.
700 701 """
701 702 if isinstance(socket, ZMQStream):
702 703 socket = socket.socket
703 704 try:
704 705 msg_list = socket.recv_multipart(mode, copy=copy)
705 706 except zmq.ZMQError as e:
706 707 if e.errno == zmq.EAGAIN:
707 708 # We can convert EAGAIN to None as we know in this case
708 709 # recv_multipart won't return None.
709 710 return None,None
710 711 else:
711 712 raise
712 713 # split multipart message into identity list and message dict
713 714 # invalid large messages can cause very expensive string comparisons
714 715 idents, msg_list = self.feed_identities(msg_list, copy)
715 716 try:
716 717 return idents, self.unserialize(msg_list, content=content, copy=copy)
717 718 except Exception as e:
718 719 # TODO: handle it
719 720 raise e
720 721
721 722 def feed_identities(self, msg_list, copy=True):
722 723 """Split the identities from the rest of the message.
723 724
724 725 Feed until DELIM is reached, then return the prefix as idents and
725 726 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
726 727 but that would be silly.
727 728
728 729 Parameters
729 730 ----------
730 731 msg_list : a list of Message or bytes objects
731 732 The message to be split.
732 733 copy : bool
733 734 flag determining whether the arguments are bytes or Messages
734 735
735 736 Returns
736 737 -------
737 738 (idents, msg_list) : two lists
738 739 idents will always be a list of bytes, each of which is a ZMQ
739 740 identity. msg_list will be a list of bytes or zmq.Messages of the
740 741 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
741 742 should be unpackable/unserializable via self.unserialize at this
742 743 point.
743 744 """
744 745 if copy:
745 746 idx = msg_list.index(DELIM)
746 747 return msg_list[:idx], msg_list[idx+1:]
747 748 else:
748 749 failed = True
749 750 for idx,m in enumerate(msg_list):
750 751 if m.bytes == DELIM:
751 752 failed = False
752 753 break
753 754 if failed:
754 755 raise ValueError("DELIM not in msg_list")
755 756 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
756 757 return [m.bytes for m in idents], msg_list
757 758
758 759 def _add_digest(self, signature):
759 760 """add a digest to history to protect against replay attacks"""
760 761 if self.digest_history_size == 0:
761 762 # no history, never add digests
762 763 return
763 764
764 765 self.digest_history.add(signature)
765 766 if len(self.digest_history) > self.digest_history_size:
766 767 # threshold reached, cull 10%
767 768 self._cull_digest_history()
768 769
769 770 def _cull_digest_history(self):
770 771 """cull the digest history
771 772
772 773 Removes a randomly selected 10% of the digest history
773 774 """
774 775 current = len(self.digest_history)
775 776 n_to_cull = max(int(current // 10), current - self.digest_history_size)
776 777 if n_to_cull >= current:
777 778 self.digest_history = set()
778 779 return
779 780 to_cull = random.sample(self.digest_history, n_to_cull)
780 781 self.digest_history.difference_update(to_cull)
781 782
782 783 def unserialize(self, msg_list, content=True, copy=True):
783 784 """Unserialize a msg_list to a nested message dict.
784 785
785 786 This is roughly the inverse of serialize. The serialize/unserialize
786 787 methods work with full message lists, whereas pack/unpack work with
787 788 the individual message parts in the message list.
788 789
789 790 Parameters
790 791 ----------
791 792 msg_list : list of bytes or Message objects
792 793 The list of message parts of the form [HMAC,p_header,p_parent,
793 794 p_metadata,p_content,buffer1,buffer2,...].
794 795 content : bool (True)
795 796 Whether to unpack the content dict (True), or leave it packed
796 797 (False).
797 798 copy : bool (True)
798 799 Whether to return the bytes (True), or the non-copying Message
799 800 object in each place (False).
800 801
801 802 Returns
802 803 -------
803 804 msg : dict
804 805 The nested message dict with top-level keys [header, parent_header,
805 806 content, buffers].
806 807 """
807 808 minlen = 5
808 809 message = {}
809 810 if not copy:
810 811 for i in range(minlen):
811 812 msg_list[i] = msg_list[i].bytes
812 813 if self.auth is not None:
813 814 signature = msg_list[0]
814 815 if not signature:
815 816 raise ValueError("Unsigned Message")
816 817 if signature in self.digest_history:
817 818 raise ValueError("Duplicate Signature: %r" % signature)
818 819 self._add_digest(signature)
819 820 check = self.sign(msg_list[1:5])
820 821 if not signature == check:
821 822 raise ValueError("Invalid Signature: %r" % signature)
822 823 if not len(msg_list) >= minlen:
823 824 raise TypeError("malformed message, must have at least %i elements"%minlen)
824 825 header = self.unpack(msg_list[1])
825 826 message['header'] = extract_dates(header)
826 827 message['msg_id'] = header['msg_id']
827 828 message['msg_type'] = header['msg_type']
828 829 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
829 830 message['metadata'] = self.unpack(msg_list[3])
830 831 if content:
831 832 message['content'] = self.unpack(msg_list[4])
832 833 else:
833 834 message['content'] = msg_list[4]
834 835
835 836 message['buffers'] = msg_list[5:]
836 837 # print("received: %s: %s\n %s" % (message['msg_type'], message['header'], message['content']))
837 838 # adapt to the current version
838 839 return adapt(message)
839 840 # print("adapted: %s: %s\n %s" % (adapted['msg_type'], adapted['header'], adapted['content']))
840 841
841 842 def test_msg2obj():
842 843 am = dict(x=1)
843 844 ao = Message(am)
844 845 assert ao.x == am['x']
845 846
846 847 am['y'] = dict(z=1)
847 848 ao = Message(am)
848 849 assert ao.y.z == am['y']['z']
849 850
850 851 k1, k2 = 'y', 'z'
851 852 assert ao[k1][k2] == am[k1][k2]
852 853
853 854 am2 = dict(ao)
854 855 assert am['x'] == am2['x']
855 856 assert am['y']['z'] == am2['y']['z']
856 857
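A minimal end-to-end sketch of a signed round trip without any socket, assuming the module lives at IPython.kernel.zmq.session (the path used by the Instance trait above):

    from IPython.kernel.zmq.session import Session

    s = Session(key=b'secret')
    msg = s.msg('kernel_info_request', content={})
    wire = s.serialize(msg)            # [DELIM, HMAC, p_header, p_parent, ...]
    idents, parts = s.feed_identities(wire)
    received = s.unserialize(parts)
    assert received['header']['msg_id'] == msg['header']['msg_id']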
@@ -1,433 +1,438 b''
1 1 # encoding: utf-8
2 2 """Pickle related utilities. Perhaps this should be called 'can'."""
3 3
4 4 # Copyright (c) IPython Development Team.
5 5 # Distributed under the terms of the Modified BSD License.
6 6
7 7 import copy
8 8 import logging
9 9 import sys
10 10 from types import FunctionType
11 11
12 12 try:
13 13 import cPickle as pickle
14 14 except ImportError:
15 15 import pickle
16 16
17 17 from . import codeutil # This registers a hook when it's imported
18 18 from . import py3compat
19 19 from .importstring import import_item
20 20 from .py3compat import string_types, iteritems
21 21
22 22 from IPython.config import Application
23 23
24 24 if py3compat.PY3:
25 25 buffer = memoryview
26 26 class_type = type
27 27 else:
28 28 from types import ClassType
29 29 class_type = (type, ClassType)
30 30
31 try:
32 PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
33 except AttributeError:
34 PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL
35
31 36 def _get_cell_type(a=None):
32 37 """the type of a closure cell doesn't seem to be importable,
33 38 so just create one
34 39 """
35 40 def inner():
36 41 return a
37 42 return type(py3compat.get_closure(inner)[0])
38 43
39 44 cell_type = _get_cell_type()
40 45
41 46 #-------------------------------------------------------------------------------
42 47 # Functions
43 48 #-------------------------------------------------------------------------------
44 49
45 50
46 51 def use_dill():
47 52 """use dill to expand serialization support
48 53
49 54 adds support for object methods and closures to serialization.
50 55 """
51 56 # import dill causes most of the magic
52 57 import dill
53 58
54 59 # dill doesn't work with cPickle,
55 60 # tell the two relevant modules to use plain pickle
56 61
57 62 global pickle
58 63 pickle = dill
59 64
60 65 try:
61 66 from IPython.kernel.zmq import serialize
62 67 except ImportError:
63 68 pass
64 69 else:
65 70 serialize.pickle = dill
66 71
67 72 # disable special function handling, let dill take care of it
68 73 can_map.pop(FunctionType, None)
69 74
70 75 def use_cloudpickle():
71 76 """use cloudpickle to expand serialization support
72 77
73 78 adds support for object methods and closures to serialization.
74 79 """
75 80 from cloud.serialization import cloudpickle
76 81
77 82 global pickle
78 83 pickle = cloudpickle
79 84
80 85 try:
81 86 from IPython.kernel.zmq import serialize
82 87 except ImportError:
83 88 pass
84 89 else:
85 90 serialize.pickle = cloudpickle
86 91
87 92 # disable special function handling, let cloudpickle take care of it
88 93 can_map.pop(FunctionType, None)
89 94
90 95
91 96 #-------------------------------------------------------------------------------
92 97 # Classes
93 98 #-------------------------------------------------------------------------------
94 99
95 100
96 101 class CannedObject(object):
97 102 def __init__(self, obj, keys=[], hook=None):
98 103 """can an object for safe pickling
99 104
100 105 Parameters
101 106 ----------
102 107
103 108 obj:
104 109 The object to be canned
105 110 keys: list (optional)
106 111 list of attribute names that will be explicitly canned / uncanned
107 112 hook: callable (optional)
108 113 An optional extra callable,
109 114 which can do additional processing of the uncanned object.
110 115
111 116 large data may be offloaded into the buffers list,
112 117 used for zero-copy transfers.
113 118 """
114 119 self.keys = keys
115 120 self.obj = copy.copy(obj)
116 121 self.hook = can(hook)
117 122 for key in keys:
118 123 setattr(self.obj, key, can(getattr(obj, key)))
119 124
120 125 self.buffers = []
121 126
122 127 def get_object(self, g=None):
123 128 if g is None:
124 129 g = {}
125 130 obj = self.obj
126 131 for key in self.keys:
127 132 setattr(obj, key, uncan(getattr(obj, key), g))
128 133
129 134 if self.hook:
130 135 self.hook = uncan(self.hook, g)
131 136 self.hook(obj, g)
132 137 return self.obj
133 138
134 139
135 140 class Reference(CannedObject):
136 141 """object for wrapping a remote reference by name."""
137 142 def __init__(self, name):
138 143 if not isinstance(name, string_types):
139 144 raise TypeError("illegal name: %r"%name)
140 145 self.name = name
141 146 self.buffers = []
142 147
143 148 def __repr__(self):
144 149 return "<Reference: %r>"%self.name
145 150
146 151 def get_object(self, g=None):
147 152 if g is None:
148 153 g = {}
149 154
150 155 return eval(self.name, g)
151 156
152 157
153 158 class CannedCell(CannedObject):
154 159 """Can a closure cell"""
155 160 def __init__(self, cell):
156 161 self.cell_contents = can(cell.cell_contents)
157 162
158 163 def get_object(self, g=None):
159 164 cell_contents = uncan(self.cell_contents, g)
160 165 def inner():
161 166 return cell_contents
162 167 return py3compat.get_closure(inner)[0]
163 168
164 169
165 170 class CannedFunction(CannedObject):
166 171
167 172 def __init__(self, f):
168 173 self._check_type(f)
169 174 self.code = f.__code__
170 175 if f.__defaults__:
171 176 self.defaults = [ can(fd) for fd in f.__defaults__ ]
172 177 else:
173 178 self.defaults = None
174 179
175 180 closure = py3compat.get_closure(f)
176 181 if closure:
177 182 self.closure = tuple( can(cell) for cell in closure )
178 183 else:
179 184 self.closure = None
180 185
181 186 self.module = f.__module__ or '__main__'
182 187 self.__name__ = f.__name__
183 188 self.buffers = []
184 189
185 190 def _check_type(self, obj):
186 191 assert isinstance(obj, FunctionType), "Not a function type"
187 192
188 193 def get_object(self, g=None):
189 194 # try to load function back into its module:
190 195 if not self.module.startswith('__'):
191 196 __import__(self.module)
192 197 g = sys.modules[self.module].__dict__
193 198
194 199 if g is None:
195 200 g = {}
196 201 if self.defaults:
197 202 defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
198 203 else:
199 204 defaults = None
200 205 if self.closure:
201 206 closure = tuple(uncan(cell, g) for cell in self.closure)
202 207 else:
203 208 closure = None
204 209 newFunc = FunctionType(self.code, g, self.__name__, defaults, closure)
205 210 return newFunc
206 211
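A sketch of canning a closure with the classes above; this relies on the codeutil import at the top of the module, which registers the hook that makes code objects picklable:

    import pickle
    from IPython.utils.pickleutil import can, uncan

    def make_adder(n):
        def add(x):
            return x + n
        return add

    canned = can(make_adder(3))   # CannedFunction; the cell becomes a CannedCell
    wire = pickle.dumps(canned)
    restored = uncan(pickle.loads(wire))
    assert restored(4) == 7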
207 212 class CannedClass(CannedObject):
208 213
209 214 def __init__(self, cls):
210 215 self._check_type(cls)
211 216 self.name = cls.__name__
212 217 self.old_style = not isinstance(cls, type)
213 218 self._canned_dict = {}
214 219 for k,v in cls.__dict__.items():
215 220 if k not in ('__weakref__', '__dict__'):
216 221 self._canned_dict[k] = can(v)
217 222 if self.old_style:
218 223 mro = []
219 224 else:
220 225 mro = cls.mro()
221 226
222 227 self.parents = [ can(c) for c in mro[1:] ]
223 228 self.buffers = []
224 229
225 230 def _check_type(self, obj):
226 231 assert isinstance(obj, class_type), "Not a class type"
227 232
228 233 def get_object(self, g=None):
229 234 parents = tuple(uncan(p, g) for p in self.parents)
230 235 return type(self.name, parents, uncan_dict(self._canned_dict, g=g))
231 236
232 237 class CannedArray(CannedObject):
233 238 def __init__(self, obj):
234 239 from numpy import ascontiguousarray
235 240 self.shape = obj.shape
236 241 self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
237 242 self.pickled = False
238 243 if sum(obj.shape) == 0:
239 244 self.pickled = True
240 245 elif obj.dtype == 'O':
241 246 # can't handle object dtype with buffer approach
242 247 self.pickled = True
243 248 elif obj.dtype.fields and any(dt == 'O' for dt,sz in obj.dtype.fields.values()):
244 249 self.pickled = True
245 250 if self.pickled:
246 251 # just pickle it
247 self.buffers = [pickle.dumps(obj, -1)]
252 self.buffers = [pickle.dumps(obj, PICKLE_PROTOCOL)]
248 253 else:
249 254 # ensure contiguous
250 255 obj = ascontiguousarray(obj, dtype=None)
251 256 self.buffers = [buffer(obj)]
252 257
253 258 def get_object(self, g=None):
254 259 from numpy import frombuffer
255 260 data = self.buffers[0]
256 261 if self.pickled:
257 262 # no shape, we just pickled it
258 263 return pickle.loads(data)
259 264 else:
260 265 return frombuffer(data, dtype=self.dtype).reshape(self.shape)
261 266
262 267
263 268 class CannedBytes(CannedObject):
264 269 wrap = bytes
265 270 def __init__(self, obj):
266 271 self.buffers = [obj]
267 272
268 273 def get_object(self, g=None):
269 274 data = self.buffers[0]
270 275 return self.wrap(data)
271 276
272 277 class CannedBuffer(CannedBytes):
273 278 wrap = buffer
274 279
275 280 #-------------------------------------------------------------------------------
276 281 # Functions
277 282 #-------------------------------------------------------------------------------
278 283
279 284 def _logger():
280 285 """get the logger for the current Application
281 286
282 287 the root logger will be used if no Application is running
283 288 """
284 289 if Application.initialized():
285 290 logger = Application.instance().log
286 291 else:
287 292 logger = logging.getLogger()
288 293 if not logger.handlers:
289 294 logging.basicConfig()
290 295
291 296 return logger
292 297
293 298 def _import_mapping(mapping, original=None):
294 299 """import any string-keys in a type mapping
295 300
296 301 """
297 302 log = _logger()
298 303 log.debug("Importing canning map")
299 304 for key,value in list(mapping.items()):
300 305 if isinstance(key, string_types):
301 306 try:
302 307 cls = import_item(key)
303 308 except Exception:
304 309 if original and key not in original:
305 310 # only message on user-added classes
306 311 log.error("canning class not importable: %r", key, exc_info=True)
307 312 mapping.pop(key)
308 313 else:
309 314 mapping[cls] = mapping.pop(key)
310 315
311 316 def istype(obj, check):
312 317 """like isinstance(obj, check), but strict
313 318
314 319 This won't catch subclasses.
315 320 """
316 321 if isinstance(check, tuple):
317 322 for cls in check:
318 323 if type(obj) is cls:
319 324 return True
320 325 return False
321 326 else:
322 327 return type(obj) is check
323 328
324 329 def can(obj):
325 330 """prepare an object for pickling"""
326 331
327 332 import_needed = False
328 333
329 334 for cls,canner in iteritems(can_map):
330 335 if isinstance(cls, string_types):
331 336 import_needed = True
332 337 break
333 338 elif istype(obj, cls):
334 339 return canner(obj)
335 340
336 341 if import_needed:
337 342 # perform can_map imports, then try again
338 343 # this will usually only happen once
339 344 _import_mapping(can_map, _original_can_map)
340 345 return can(obj)
341 346
342 347 return obj
343 348
344 349 def can_class(obj):
345 350 if isinstance(obj, class_type) and obj.__module__ == '__main__':
346 351 return CannedClass(obj)
347 352 else:
348 353 return obj
349 354
350 355 def can_dict(obj):
351 356 """can the *values* of a dict"""
352 357 if istype(obj, dict):
353 358 newobj = {}
354 359 for k, v in iteritems(obj):
355 360 newobj[k] = can(v)
356 361 return newobj
357 362 else:
358 363 return obj
359 364
360 365 sequence_types = (list, tuple, set)
361 366
362 367 def can_sequence(obj):
363 368 """can the elements of a sequence"""
364 369 if istype(obj, sequence_types):
365 370 t = type(obj)
366 371 return t([can(i) for i in obj])
367 372 else:
368 373 return obj
369 374
370 375 def uncan(obj, g=None):
371 376 """invert canning"""
372 377
373 378 import_needed = False
374 379 for cls,uncanner in iteritems(uncan_map):
375 380 if isinstance(cls, string_types):
376 381 import_needed = True
377 382 break
378 383 elif isinstance(obj, cls):
379 384 return uncanner(obj, g)
380 385
381 386 if import_needed:
382 387 # perform uncan_map imports, then try again
383 388 # this will usually only happen once
384 389 _import_mapping(uncan_map, _original_uncan_map)
385 390 return uncan(obj, g)
386 391
387 392 return obj
388 393
389 394 def uncan_dict(obj, g=None):
390 395 if istype(obj, dict):
391 396 newobj = {}
392 397 for k, v in iteritems(obj):
393 398 newobj[k] = uncan(v,g)
394 399 return newobj
395 400 else:
396 401 return obj
397 402
398 403 def uncan_sequence(obj, g=None):
399 404 if istype(obj, sequence_types):
400 405 t = type(obj)
401 406 return t([uncan(i,g) for i in obj])
402 407 else:
403 408 return obj
404 409
405 410 def _uncan_dependent_hook(dep, g=None):
406 411 dep.check_dependency()
407 412
408 413 def can_dependent(obj):
409 414 return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook)
410 415
411 416 #-------------------------------------------------------------------------------
412 417 # API dictionaries
413 418 #-------------------------------------------------------------------------------
414 419
415 420 # These dicts can be extended for custom serialization of new objects
416 421
417 422 can_map = {
418 423 'IPython.parallel.dependent' : can_dependent,
419 424 'numpy.ndarray' : CannedArray,
420 425 FunctionType : CannedFunction,
421 426 bytes : CannedBytes,
422 427 buffer : CannedBuffer,
423 428 cell_type : CannedCell,
424 429 class_type : can_class,
425 430 }
426 431
427 432 uncan_map = {
428 433 CannedObject : lambda obj, g: obj.get_object(g),
429 434 }
430 435
431 436 # for use in _import_mapping:
432 437 _original_can_map = can_map.copy()
433 438 _original_uncan_map = uncan_map.copy()
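As the comment above notes, can_map and uncan_map are the extension points for custom serialization. A hypothetical registration for a user-defined class:

    from IPython.utils.pickleutil import CannedObject, can, uncan, can_map

    class Point(object):
        def __init__(self, x, y):
            self.x, self.y = x, y

    # hypothetical canner: explicitly can the x/y attributes.
    # no uncan_map entry is needed: CannedObject is already registered there.
    can_map[Point] = lambda obj: CannedObject(obj, keys=('x', 'y'))

    p = uncan(can(Point(1, 2)))
    assert (p.x, p.y) == (1, 2)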