##// END OF EJS Templates
add utils.pickleutil.use_dill...
MinRK -
Show More
@@ -1,352 +1,382 b''
1 1 # encoding: utf-8
2 2
3 3 """Pickle related utilities. Perhaps this should be called 'can'."""
4 4
5 5 __docformat__ = "restructuredtext en"
6 6
7 7 #-------------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-------------------------------------------------------------------------------
13 13
14 14 #-------------------------------------------------------------------------------
15 15 # Imports
16 16 #-------------------------------------------------------------------------------
17 17
18 18 import copy
19 19 import logging
20 20 import sys
21 21 from types import FunctionType
22 22
23 23 try:
24 24 import cPickle as pickle
25 25 except ImportError:
26 26 import pickle
27 27
28 28 from . import codeutil # This registers a hook when it's imported
29 29 from . import py3compat
30 30 from .importstring import import_item
31 31 from .py3compat import string_types, iteritems
32 32
33 33 from IPython.config import Application
34 34
35 35 if py3compat.PY3:
36 36 buffer = memoryview
37 37 class_type = type
38 38 else:
39 39 from types import ClassType
40 40 class_type = (type, ClassType)
41 41
42 42 #-------------------------------------------------------------------------------
43 # Functions
44 #-------------------------------------------------------------------------------
45
46
47 def use_dill():
48 """use dill to expand serialization support
49
50 adds support for object methods and closures to serialization.
51 """
52 # import dill causes most of the magic
53 import dill
54
55 # dill doesn't work with cPickle,
56 # tell the two relevant modules to use plain pickle
57
58 global pickle
59 import pickle
60
61 try:
62 from IPython.kernel.zmq import serialize
63 except ImportError:
64 pass
65 else:
66 serialize.pickle = pickle
67
68 # disable special function handling, let dill take care of it
69 can_map.pop(FunctionType, None)
70
71
72 #-------------------------------------------------------------------------------
43 73 # Classes
44 74 #-------------------------------------------------------------------------------
45 75
46 76
47 77 class CannedObject(object):
48 78 def __init__(self, obj, keys=[], hook=None):
49 79 """can an object for safe pickling
50 80
51 81 Parameters
52 82 ==========
53 83
54 84 obj:
55 85 The object to be canned
56 86 keys: list (optional)
57 87 list of attribute names that will be explicitly canned / uncanned
58 88 hook: callable (optional)
59 89 An optional extra callable,
60 90 which can do additional processing of the uncanned object.
61 91
62 92 large data may be offloaded into the buffers list,
63 93 used for zero-copy transfers.
64 94 """
65 95 self.keys = keys
66 96 self.obj = copy.copy(obj)
67 97 self.hook = can(hook)
68 98 for key in keys:
69 99 setattr(self.obj, key, can(getattr(obj, key)))
70 100
71 101 self.buffers = []
72 102
73 103 def get_object(self, g=None):
74 104 if g is None:
75 105 g = {}
76 106 obj = self.obj
77 107 for key in self.keys:
78 108 setattr(obj, key, uncan(getattr(obj, key), g))
79 109
80 110 if self.hook:
81 111 self.hook = uncan(self.hook, g)
82 112 self.hook(obj, g)
83 113 return self.obj
84 114
85 115
86 116 class Reference(CannedObject):
87 117 """object for wrapping a remote reference by name."""
88 118 def __init__(self, name):
89 119 if not isinstance(name, string_types):
90 120 raise TypeError("illegal name: %r"%name)
91 121 self.name = name
92 122 self.buffers = []
93 123
94 124 def __repr__(self):
95 125 return "<Reference: %r>"%self.name
96 126
97 127 def get_object(self, g=None):
98 128 if g is None:
99 129 g = {}
100 130
101 131 return eval(self.name, g)
102 132
103 133
104 134 class CannedFunction(CannedObject):
105 135
106 136 def __init__(self, f):
107 137 self._check_type(f)
108 138 self.code = f.__code__
109 139 if f.__defaults__:
110 140 self.defaults = [ can(fd) for fd in f.__defaults__ ]
111 141 else:
112 142 self.defaults = None
113 143 self.module = f.__module__ or '__main__'
114 144 self.__name__ = f.__name__
115 145 self.buffers = []
116 146
117 147 def _check_type(self, obj):
118 148 assert isinstance(obj, FunctionType), "Not a function type"
119 149
120 150 def get_object(self, g=None):
121 151 # try to load function back into its module:
122 152 if not self.module.startswith('__'):
123 153 __import__(self.module)
124 154 g = sys.modules[self.module].__dict__
125 155
126 156 if g is None:
127 157 g = {}
128 158 if self.defaults:
129 159 defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
130 160 else:
131 161 defaults = None
132 162 newFunc = FunctionType(self.code, g, self.__name__, defaults)
133 163 return newFunc
134 164
135 165 class CannedClass(CannedObject):
136 166
137 167 def __init__(self, cls):
138 168 self._check_type(cls)
139 169 self.name = cls.__name__
140 170 self.old_style = not isinstance(cls, type)
141 171 self._canned_dict = {}
142 172 for k,v in cls.__dict__.items():
143 173 if k not in ('__weakref__', '__dict__'):
144 174 self._canned_dict[k] = can(v)
145 175 if self.old_style:
146 176 mro = []
147 177 else:
148 178 mro = cls.mro()
149 179
150 180 self.parents = [ can(c) for c in mro[1:] ]
151 181 self.buffers = []
152 182
153 183 def _check_type(self, obj):
154 184 assert isinstance(obj, class_type), "Not a class type"
155 185
156 186 def get_object(self, g=None):
157 187 parents = tuple(uncan(p, g) for p in self.parents)
158 188 return type(self.name, parents, uncan_dict(self._canned_dict, g=g))
159 189
160 190 class CannedArray(CannedObject):
161 191 def __init__(self, obj):
162 192 from numpy import ascontiguousarray
163 193 self.shape = obj.shape
164 194 self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
165 195 if sum(obj.shape) == 0:
166 196 # just pickle it
167 197 self.buffers = [pickle.dumps(obj, -1)]
168 198 else:
169 199 # ensure contiguous
170 200 obj = ascontiguousarray(obj, dtype=None)
171 201 self.buffers = [buffer(obj)]
172 202
173 203 def get_object(self, g=None):
174 204 from numpy import frombuffer
175 205 data = self.buffers[0]
176 206 if sum(self.shape) == 0:
177 207 # no shape, we just pickled it
178 208 return pickle.loads(data)
179 209 else:
180 210 return frombuffer(data, dtype=self.dtype).reshape(self.shape)
181 211
182 212
183 213 class CannedBytes(CannedObject):
184 214 wrap = bytes
185 215 def __init__(self, obj):
186 216 self.buffers = [obj]
187 217
188 218 def get_object(self, g=None):
189 219 data = self.buffers[0]
190 220 return self.wrap(data)
191 221
192 222 def CannedBuffer(CannedBytes):
193 223 wrap = buffer
194 224
195 225 #-------------------------------------------------------------------------------
196 226 # Functions
197 227 #-------------------------------------------------------------------------------
198 228
199 229 def _logger():
200 230 """get the logger for the current Application
201 231
202 232 the root logger will be used if no Application is running
203 233 """
204 234 if Application.initialized():
205 235 logger = Application.instance().log
206 236 else:
207 237 logger = logging.getLogger()
208 238 if not logger.handlers:
209 239 logging.basicConfig()
210 240
211 241 return logger
212 242
213 243 def _import_mapping(mapping, original=None):
214 244 """import any string-keys in a type mapping
215 245
216 246 """
217 247 log = _logger()
218 248 log.debug("Importing canning map")
219 249 for key,value in list(mapping.items()):
220 250 if isinstance(key, string_types):
221 251 try:
222 252 cls = import_item(key)
223 253 except Exception:
224 254 if original and key not in original:
225 255 # only message on user-added classes
226 256 log.error("canning class not importable: %r", key, exc_info=True)
227 257 mapping.pop(key)
228 258 else:
229 259 mapping[cls] = mapping.pop(key)
230 260
231 261 def istype(obj, check):
232 262 """like isinstance(obj, check), but strict
233 263
234 264 This won't catch subclasses.
235 265 """
236 266 if isinstance(check, tuple):
237 267 for cls in check:
238 268 if type(obj) is cls:
239 269 return True
240 270 return False
241 271 else:
242 272 return type(obj) is check
243 273
244 274 def can(obj):
245 275 """prepare an object for pickling"""
246 276
247 277 import_needed = False
248 278
249 279 for cls,canner in iteritems(can_map):
250 280 if isinstance(cls, string_types):
251 281 import_needed = True
252 282 break
253 283 elif istype(obj, cls):
254 284 return canner(obj)
255 285
256 286 if import_needed:
257 287 # perform can_map imports, then try again
258 288 # this will usually only happen once
259 289 _import_mapping(can_map, _original_can_map)
260 290 return can(obj)
261 291
262 292 return obj
263 293
264 294 def can_class(obj):
265 295 if isinstance(obj, class_type) and obj.__module__ == '__main__':
266 296 return CannedClass(obj)
267 297 else:
268 298 return obj
269 299
270 300 def can_dict(obj):
271 301 """can the *values* of a dict"""
272 302 if istype(obj, dict):
273 303 newobj = {}
274 304 for k, v in iteritems(obj):
275 305 newobj[k] = can(v)
276 306 return newobj
277 307 else:
278 308 return obj
279 309
280 310 sequence_types = (list, tuple, set)
281 311
282 312 def can_sequence(obj):
283 313 """can the elements of a sequence"""
284 314 if istype(obj, sequence_types):
285 315 t = type(obj)
286 316 return t([can(i) for i in obj])
287 317 else:
288 318 return obj
289 319
290 320 def uncan(obj, g=None):
291 321 """invert canning"""
292 322
293 323 import_needed = False
294 324 for cls,uncanner in iteritems(uncan_map):
295 325 if isinstance(cls, string_types):
296 326 import_needed = True
297 327 break
298 328 elif isinstance(obj, cls):
299 329 return uncanner(obj, g)
300 330
301 331 if import_needed:
302 332 # perform uncan_map imports, then try again
303 333 # this will usually only happen once
304 334 _import_mapping(uncan_map, _original_uncan_map)
305 335 return uncan(obj, g)
306 336
307 337 return obj
308 338
309 339 def uncan_dict(obj, g=None):
310 340 if istype(obj, dict):
311 341 newobj = {}
312 342 for k, v in iteritems(obj):
313 343 newobj[k] = uncan(v,g)
314 344 return newobj
315 345 else:
316 346 return obj
317 347
318 348 def uncan_sequence(obj, g=None):
319 349 if istype(obj, sequence_types):
320 350 t = type(obj)
321 351 return t([uncan(i,g) for i in obj])
322 352 else:
323 353 return obj
324 354
325 355 def _uncan_dependent_hook(dep, g=None):
326 356 dep.check_dependency()
327 357
328 358 def can_dependent(obj):
329 359 return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook)
330 360
331 361 #-------------------------------------------------------------------------------
332 362 # API dictionaries
333 363 #-------------------------------------------------------------------------------
334 364
335 365 # These dicts can be extended for custom serialization of new objects
336 366
337 367 can_map = {
338 368 'IPython.parallel.dependent' : can_dependent,
339 369 'numpy.ndarray' : CannedArray,
340 370 FunctionType : CannedFunction,
341 371 bytes : CannedBytes,
342 372 buffer : CannedBuffer,
343 373 class_type : can_class,
344 374 }
345 375
346 376 uncan_map = {
347 377 CannedObject : lambda obj, g: obj.get_object(g),
348 378 }
349 379
350 380 # for use in _import_mapping:
351 381 _original_can_map = can_map.copy()
352 382 _original_uncan_map = uncan_map.copy()
General Comments 0
You need to be logged in to leave comments. Login now