pickleutil.py
244 lines
| 6.6 KiB
| text/x-python
|
PythonLexer
|
r3539 | # encoding: utf-8 | ||
"""Pickle related utilities. Perhaps this should be called 'can'.""" | ||||
__docformat__ = "restructuredtext en" | ||||
#------------------------------------------------------------------------------- | ||||
|
r5390 | # Copyright (C) 2008-2011 The IPython Development Team | ||
|
r3539 | # | ||
# Distributed under the terms of the BSD License. The full license is in | ||||
# the file COPYING, distributed as part of this software. | ||||
#------------------------------------------------------------------------------- | ||||
#------------------------------------------------------------------------------- | ||||
# Imports | ||||
#------------------------------------------------------------------------------- | ||||
|
r3607 | import copy | ||
|
r8034 | import logging | ||
|
r3664 | import sys | ||
from types import FunctionType | ||||
|
r3607 | |||
|
r7967 | try: | ||
import cPickle as pickle | ||||
except ImportError: | ||||
import pickle | ||||
try: | ||||
import numpy | ||||
except: | ||||
numpy = None | ||||
|
r3557 | import codeutil | ||
|
r7967 | import py3compat | ||
from importstring import import_item | ||||
|
r8034 | from IPython.config import Application | ||
|
# Python 3 has no ``buffer`` builtin, so ``memoryview`` is bound to that name
# for the rest of this module.
if py3compat.PY3:
    buffer = memoryview
|
r3539 | |||
|
r3607 | #------------------------------------------------------------------------------- | ||
# Classes | ||||
#------------------------------------------------------------------------------- | ||||
|
class CannedObject(object):
    """Wrap a shallow copy of an object with selected attributes canned.

    Parameters
    ----------
    obj : object
        The object to wrap.  It is shallow-copied with ``copy.copy`` and the
        canned attribute values are set on that copy.
    keys : sequence of str, optional
        Names of the attributes of `obj` that are passed through ``can``.
        When omitted, no attribute is canned.
    """

    def __init__(self, obj, keys=None):
        # A literal ``[]`` default would be a single list object shared by
        # every instance created without explicit keys (it is stored on
        # ``self.keys``), so a fresh list is created per instance instead.
        if keys is None:
            keys = []
        self.keys = keys
        self.obj = copy.copy(obj)
        for key in keys:
            setattr(self.obj, key, can(getattr(obj, key)))

        self.buffers = []

    def get_object(self, g=None):
        """Uncan the attributes listed in ``self.keys`` and return the object.

        `g` is the namespace handed to ``uncan``; an empty dict is used when
        it is omitted.  The attributes are replaced on the wrapped copy
        itself, which is then returned.
        """
        if g is None:
            g = {}
        for key in self.keys:
            setattr(self.obj, key, uncan(getattr(self.obj, key), g))
        return self.obj
|
r7967 | |||
|
r3546 | |||
|
class Reference(CannedObject):
    """Stand-in for an object that is looked up by name when uncanned."""

    def __init__(self, name):
        if isinstance(name, basestring):
            self.name = name
            self.buffers = []
        else:
            raise TypeError("illegal name: %r" % name)

    def __repr__(self):
        return "<Reference: %r>" % self.name

    def get_object(self, g=None):
        """Evaluate the stored name in namespace `g` (a new dict if omitted)."""
        namespace = {} if g is None else g
        # NOTE(review): the name is evaluated as an arbitrary expression via
        # eval(); confirm that names only ever come from trusted peers.
        return eval(self.name, namespace)
|
r4872 | |||
|
r3546 | |||
|
class CannedFunction(CannedObject):
    """Canned form of a plain Python function.

    Stores the function's code object, its canned default values, its module
    name and its ``__name__``; `get_object` builds a new function from them.
    """

    def __init__(self, f):
        self._check_type(f)
        self.code = f.func_code
        func_defaults = f.func_defaults
        if func_defaults:
            self.defaults = [can(default) for default in func_defaults]
        else:
            self.defaults = None
        # Fall back to '__main__' when the function has no module name.
        self.module = f.__module__ or '__main__'
        self.__name__ = f.__name__
        self.buffers = []

    def _check_type(self, obj):
        """Assert that `obj` is a ``FunctionType`` instance."""
        assert isinstance(obj, FunctionType), "Not a function type"

    def get_object(self, g=None):
        """Rebuild the function and return it.

        If the recorded module name does not start with ``'__'`` and the
        module can be imported, that module's ``__dict__`` replaces `g` as
        the function's globals.  Otherwise `g` is used, or an empty dict
        when `g` is None.
        """
        module_name = self.module
        # try to load function back into its module:
        if not module_name.startswith('__'):
            try:
                __import__(module_name)
            except ImportError:
                pass
            else:
                g = sys.modules[module_name].__dict__

        if g is None:
            g = {}

        canned_defaults = self.defaults
        if canned_defaults:
            defaults = tuple(uncan(canned, g) for canned in canned_defaults)
        else:
            defaults = None
        return FunctionType(self.code, g, self.__name__, defaults)
|
r7967 | |||
class CannedArray(CannedObject):
    """Canned form of a numpy array, carrying its data in ``self.buffers``.

    The shape and a dtype description are stored as attributes.  When the
    sum of the shape is zero the whole array is pickled into the buffer;
    otherwise the buffer wraps a contiguous version of the array's data.
    """

    def __init__(self, obj):
        self.shape = obj.shape
        # Arrays whose dtype has fields keep the full field description;
        # all others keep the dtype's string code.
        if obj.dtype.fields:
            self.dtype = obj.dtype.descr
        else:
            self.dtype = obj.dtype.str
        if sum(obj.shape) != 0:
            # ensure contiguous
            contiguous = numpy.ascontiguousarray(obj, dtype=None)
            self.buffers = [buffer(contiguous)]
        else:
            # just pickle it
            self.buffers = [pickle.dumps(obj, -1)]

    def get_object(self, g=None):
        """Rebuild the array from the first buffer; `g` is not used."""
        data = self.buffers[0]
        if sum(self.shape) != 0:
            flat = numpy.frombuffer(data, dtype=self.dtype)
            return flat.reshape(self.shape)
        # no shape, we just pickled it
        return pickle.loads(data)
class CannedBytes(CannedObject):
    """Canned form of a bytes object, kept as the single entry of ``buffers``."""

    # Callable applied to the stored buffer when the object is rebuilt.
    wrap = bytes

    def __init__(self, obj):
        self.buffers = [obj]

    def get_object(self, g=None):
        """Return the first buffer passed through ``self.wrap``; `g` is unused."""
        return self.wrap(self.buffers[0])
class CannedBuffer(CannedBytes):
    """Canned form of a ``buffer``: stored like `CannedBytes`, rebuilt with ``buffer``.

    This must be a class.  Declared with ``def``, ``CannedBuffer(obj)`` is a
    plain function call that returns None, so canning a buffer produced None
    instead of a canned wrapper.
    """

    wrap = buffer
|
r3607 | #------------------------------------------------------------------------------- | ||
# Functions | ||||
#------------------------------------------------------------------------------- | ||||
|
def _error(*args, **kwargs):
    """Log an error, forwarding all arguments to the chosen logger's ``error``.

    The Application's logger is used when an Application has been
    initialized.  Otherwise the root logger is used, and
    ``logging.basicConfig()`` is called first if it has no handlers yet.
    """
    if not Application.initialized():
        logger = logging.getLogger()
        if not logger.handlers:
            logging.basicConfig()
    else:
        logger = Application.instance().log
    logger.error(*args, **kwargs)
|
r3539 | |||
|
def can(obj):
    """prepare an object for pickling

    Walks ``can_map`` and returns the result of the first canner whose class
    matches `obj`.  Keys given as strings are imported first; a key that
    cannot be imported is logged and skipped.  If nothing matches, `obj` is
    returned unchanged.
    """
    for key, canner in can_map.iteritems():
        klass = key
        if isinstance(key, basestring):
            try:
                klass = import_item(key)
            except Exception:
                _error("cannning class not importable: %r", key, exc_info=True)
                continue
        if isinstance(obj, klass):
            return canner(obj)
    return obj
def can_dict(obj):
    """can the *values* of a dict

    Returns a new dict with the same keys and canned values.  Anything that
    is not a dict is returned unchanged.
    """
    if not isinstance(obj, dict):
        return obj
    return dict((key, can(value)) for key, value in obj.iteritems())
|
def can_sequence(obj):
    """can the elements of a sequence

    Lists and tuples are rebuilt with the same type from their canned
    elements.  Any other object is returned unchanged.
    """
    if not isinstance(obj, (list, tuple)):
        return obj
    canned = [can(item) for item in obj]
    return type(obj)(canned)
def uncan(obj, g=None):
    """invert canning

    Walks ``uncan_map`` and returns the result of the first uncanner whose
    class matches `obj`, calling it with `obj` and the namespace `g`.  Keys
    given as strings are imported first; a key that cannot be imported is
    logged and skipped.  If nothing matches, `obj` is returned unchanged.
    """
    for key, uncanner in uncan_map.iteritems():
        klass = key
        if isinstance(key, basestring):
            try:
                klass = import_item(key)
            except Exception:
                _error("uncanning class not importable: %r", key, exc_info=True)
                continue
        if isinstance(obj, klass):
            return uncanner(obj, g)
    return obj
def uncan_dict(obj, g=None):
    """uncan the *values* of a dict

    Returns a new dict with the same keys and each value passed through
    ``uncan`` with namespace `g`.  Anything that is not a dict is returned
    unchanged.
    """
    if not isinstance(obj, dict):
        return obj
    return dict((key, uncan(value, g)) for key, value in obj.iteritems())
|
def uncan_sequence(obj, g=None):
    """uncan the elements of a sequence

    Lists and tuples are rebuilt with the same type, each element passed
    through ``uncan`` with namespace `g`.  Any other object is returned
    unchanged.
    """
    if not isinstance(obj, (list, tuple)):
        return obj
    restored = [uncan(item, g) for item in obj]
    return type(obj)(restored)
|
r7967 | #------------------------------------------------------------------------------- | ||
# API dictionary | ||||
#------------------------------------------------------------------------------- | ||||
# These dicts can be extended for custom serialization of new objects | ||||
# Maps a class (or the dotted import path of one, as a string) to the callable
# that cans instances of it.  String keys are imported by ``can`` at lookup
# time.
can_map = {
    'IPython.parallel.dependent': lambda obj: CannedObject(obj, keys=('f', 'df')),
    'numpy.ndarray': CannedArray,
    FunctionType: CannedFunction,
    bytes: CannedBytes,
    buffer: CannedBuffer,
}

# Maps a class to the callable that restores instances of it; each callable
# receives the canned object and the namespace ``g``.
uncan_map = {
    CannedObject: lambda obj, g: obj.get_object(g),
}