From 6e68a6cd91f20e4bcef6eb57b7e126c3c9bbfaa4 2012-07-21 06:11:41
From: MinRK
Date: 2012-07-21 06:11:41
Subject: [PATCH] better serialization for parallel code

* no longer use newserialized
* no longer double-pickle anything
* can/pickleutil handles arrays
* pickleutil has [un]can_map dicts for extensibility
* don't can each element of long sequences

---
diff --git a/IPython/parallel/tests/test_view.py b/IPython/parallel/tests/test_view.py
index e13399c..fca0956 100644
--- a/IPython/parallel/tests/test_view.py
+++ b/IPython/parallel/tests/test_view.py
@@ -239,7 +239,7 @@ class TestView(ClusterTestCase, ParametricTestCase):
         from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
         view = self.client[:]
         a = numpy.arange(64)
-        view.scatter('a', a)
+        view.scatter('a', a, block=True)
         b = view.gather('a', block=True)
         assert_array_equal(b, a)
 
@@ -325,7 +325,7 @@ class TestView(ClusterTestCase, ParametricTestCase):
         r = view.map_sync(lambda x:x, arr)
         self.assertEqual(r, list(arr))
 
-    def test_scatterGatherNonblocking(self):
+    def test_scatter_gather_nonblocking(self):
        data = range(16)
        view = self.client[:]
        view.scatter('a', data, block=False)
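Note: the scatter call above now passes block=True, so the assignment is
guaranteed to have completed before the gather runs. A minimal sketch of the
same pattern outside the test suite, assuming a running cluster (e.g. one
started with `ipcluster start -n 4`):

    from IPython.parallel import Client
    import numpy

    rc = Client()
    view = rc[:]                       # a DirectView on all engines
    a = numpy.arange(64)
    view.scatter('a', a, block=True)   # wait until every engine holds its chunk
    b = view.gather('a', block=True)   # reassemble; b == a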
diff --git a/IPython/parallel/util.py b/IPython/parallel/util.py
index 61cbb5d..0e02fed 100644
--- a/IPython/parallel/util.py
+++ b/IPython/parallel/util.py
@@ -43,17 +43,11 @@ from IPython.external.decorator import decorator
 # IPython imports
 from IPython.config.application import Application
-from IPython.utils import py3compat
-from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
-from IPython.utils.newserialized import serialize, unserialize
 from IPython.zmq.log import EnginePUBHandler
 from IPython.zmq.serialize import (
     unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
 )
 
-if py3compat.PY3:
-    buffer = memoryview
-
 #-----------------------------------------------------------------------------
 # Classes
 #-----------------------------------------------------------------------------
diff --git a/IPython/utils/pickleutil.py b/IPython/utils/pickleutil.py
index 695d880..ed78ada 100644
--- a/IPython/utils/pickleutil.py
+++ b/IPython/utils/pickleutil.py
@@ -19,7 +19,22 @@ import copy
 import sys
 from types import FunctionType
 
+try:
+    import cPickle as pickle
+except ImportError:
+    import pickle
+
+try:
+    import numpy
+except:
+    numpy = None
+
 import codeutil
+import py3compat
+from importstring import import_item
+
+if py3compat.PY3:
+    buffer = memoryview
 
 #-------------------------------------------------------------------------------
 # Classes
@@ -32,14 +47,16 @@ class CannedObject(object):
         self.obj = copy.copy(obj)
         for key in keys:
             setattr(self.obj, key, can(getattr(obj, key)))
-
-    def getObject(self, g=None):
+
+        self.buffers = []
+
+    def get_object(self, g=None):
         if g is None:
-            g = globals()
+            g = {}
         for key in self.keys:
             setattr(self.obj, key, uncan(getattr(self.obj, key), g))
         return self.obj
+
 
 class Reference(CannedObject):
     """object for wrapping a remote reference by name."""
@@ -47,13 +64,14 @@
         if not isinstance(name, basestring):
             raise TypeError("illegal name: %r"%name)
         self.name = name
+        self.buffers = []
 
     def __repr__(self):
         return "<Reference: %r>"%self.name
 
-    def getObject(self, g=None):
+    def get_object(self, g=None):
         if g is None:
-            g = globals()
+            g = {}
         return eval(self.name, g)
 
 
@@ -61,16 +79,17 @@
 class CannedFunction(CannedObject):
 
     def __init__(self, f):
-        self._checkType(f)
+        self._check_type(f)
         self.code = f.func_code
         self.defaults = f.func_defaults
         self.module = f.__module__ or '__main__'
         self.__name__ = f.__name__
+        self.buffers = []
 
-    def _checkType(self, obj):
+    def _check_type(self, obj):
         assert isinstance(obj, FunctionType), "Not a function type"
 
-    def getObject(self, g=None):
+    def get_object(self, g=None):
         # try to load function back into its module:
         if not self.module.startswith('__'):
             try:
@@ -81,30 +100,65 @@ class CannedFunction(CannedObject):
                 g = sys.modules[self.module].__dict__
 
         if g is None:
-            g = globals()
+            g = {}
         newFunc = FunctionType(self.code, g, self.__name__, self.defaults)
         return newFunc
 
+
+class CannedArray(CannedObject):
+    def __init__(self, obj):
+        self.shape = obj.shape
+        self.dtype = obj.dtype
+        if sum(obj.shape) == 0:
+            # just pickle it
+            self.buffers = [pickle.dumps(obj, -1)]
+        else:
+            # ensure contiguous
+            obj = numpy.ascontiguousarray(obj, dtype=None)
+            self.buffers = [buffer(obj)]
+
+    def get_object(self, g=None):
+        data = self.buffers[0]
+        if sum(self.shape) == 0:
+            # no shape, we just pickled it
+            return pickle.loads(data)
+        else:
+            return numpy.frombuffer(data, dtype=self.dtype).reshape(self.shape)
+
+
+class CannedBytes(CannedObject):
+    wrap = bytes
+    def __init__(self, obj):
+        self.buffers = [obj]
+
+    def get_object(self, g=None):
+        data = self.buffers[0]
+        return self.wrap(data)
+
+
+class CannedBuffer(CannedBytes):
+    wrap = buffer
+
 #-------------------------------------------------------------------------------
 # Functions
 #-------------------------------------------------------------------------------
 
-def can(obj):
-    # import here to prevent module-level circular imports
-    from IPython.parallel import dependent
-    if isinstance(obj, dependent):
-        keys = ('f','df')
-        return CannedObject(obj, keys=keys)
-    elif isinstance(obj, FunctionType):
-        return CannedFunction(obj)
-    elif isinstance(obj,dict):
-        return canDict(obj)
-    elif isinstance(obj, (list,tuple)):
-        return canSequence(obj)
-    else:
-        return obj
 
-def canDict(obj):
+def can(obj):
+    """prepare an object for pickling"""
+    for cls,canner in can_map.iteritems():
+        if isinstance(cls, basestring):
+            try:
+                cls = import_item(cls)
+            except Exception:
+                # not importable
+                print "not importable: %r" % cls
+                continue
+        if isinstance(obj, cls):
+            return canner(obj)
+    return obj
+
+def can_dict(obj):
+    """can the *values* of a dict"""
     if isinstance(obj, dict):
         newobj = {}
         for k, v in obj.iteritems():
@@ -113,7 +167,8 @@
     else:
         return obj
 
-def canSequence(obj):
+def can_sequence(obj):
+    """can the elements of a sequence"""
     if isinstance(obj, (list, tuple)):
         t = type(obj)
         return t([can(i) for i in obj])
@@ -121,16 +176,20 @@ def canSequence(obj):
         return obj
 
 def uncan(obj, g=None):
-    if isinstance(obj, CannedObject):
-        return obj.getObject(g)
-    elif isinstance(obj,dict):
-        return uncanDict(obj, g)
-    elif isinstance(obj, (list,tuple)):
-        return uncanSequence(obj, g)
-    else:
-        return obj
-
-def uncanDict(obj, g=None):
+    """invert canning"""
+    for cls,uncanner in uncan_map.iteritems():
+        if isinstance(cls, basestring):
+            try:
+                cls = import_item(cls)
+            except Exception:
+                # not importable
+                print "not importable: %r" % cls
+                continue
+        if isinstance(obj, cls):
+            return uncanner(obj, g)
+    return obj
+
+def uncan_dict(obj, g=None):
     if isinstance(obj, dict):
         newobj = {}
         for k, v in obj.iteritems():
@@ -139,7 +198,7 @@
     else:
         return obj
 
-def uncanSequence(obj, g=None):
+def uncan_sequence(obj, g=None):
     if isinstance(obj, (list, tuple)):
         t = type(obj)
         return t([uncan(i,g) for i in obj])
@@ -147,5 +206,27 @@ def uncanSequence(obj, g=None):
     return obj
 
-def rebindFunctionGlobals(f, glbls):
-    return FunctionType(f.func_code, glbls)
+#-------------------------------------------------------------------------------
+# API dictionary
+#-------------------------------------------------------------------------------
+
+# These dicts can be extended for custom serialization of new objects
+
+can_map = {
+    'IPython.parallel.dependent' : lambda obj: CannedObject(obj, keys=('f','df')),
+    'numpy.ndarray' : CannedArray,
+    FunctionType : CannedFunction,
+    bytes : CannedBytes,
+    buffer : CannedBuffer,
+    # dict : can_dict,
+    # list : can_sequence,
+    # tuple : can_sequence,
+}
+
+uncan_map = {
+    CannedObject : lambda obj, g: obj.get_object(g),
+    # dict : uncan_dict,
+    # list : uncan_sequence,
+    # tuple : uncan_sequence,
+}
+
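The [un]can_map dicts above are the new extension point. A hypothetical
sketch of registering a custom type (Point and CannedPoint are illustrative
names, not part of this patch); no uncan_map entry is needed, because the
existing CannedObject entry already dispatches to get_object on any subclass:

    from IPython.utils import pickleutil
    from IPython.utils.pickleutil import CannedObject

    class Point(object):
        def __init__(self, x, y):
            self.x, self.y = x, y

    class CannedPoint(CannedObject):
        def __init__(self, obj):
            self.buffers = []         # required: _extract_buffers inspects this
            self.xy = (obj.x, obj.y)  # capture just the state we need

        def get_object(self, g=None):
            return Point(*self.xy)

    pickleutil.can_map[Point] = CannedPoint

For real use both classes must be importable on the receiving side, since the
canned instance itself still travels by pickle.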
diff --git a/IPython/zmq/ipkernel.py b/IPython/zmq/ipkernel.py
index 21b521d..09b43f6 100755
--- a/IPython/zmq/ipkernel.py
+++ b/IPython/zmq/ipkernel.py
@@ -572,8 +572,8 @@ class Kernel(Configurable):
             for key in ns.iterkeys():
                 working.pop(key)
 
-            packed_result,buf = serialize_object(result)
-            result_buf = [packed_result]+buf
+            result_buf = serialize_object(result)
+
         except:
             # invoke IPython traceback formatting
             shell.showtraceback()
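This two-line change is the visible effect of the new return convention:
serialize_object no longer returns a (pickled-metadata, buffer-list) pair,
but a single flat list whose first element is the pickle of the canned
object. A small sketch of the difference:

    from IPython.zmq.serialize import serialize_object

    result = dict(x=1, y='hello')

    # before: packed_result, buf = serialize_object(result)
    #         result_buf = [packed_result] + buf
    # after:
    result_buf = serialize_object(result)
    # result_buf[0] is the pickled canned object;
    # result_buf[1:] holds any buffers that exceeded the threshold (none here)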
""" - databuffers = [] - if isinstance(obj, (list, tuple)): - clist = canSequence(obj) - slist = map(serialize, clist) - for s in slist: - if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold: - databuffers.append(s.getData()) - s.data = None - return pickle.dumps(slist,-1), databuffers - elif isinstance(obj, dict): - sobj = {} + buffers = [] + if isinstance(obj, (list, tuple)) and len(obj) < MAX_ITEMS: + cobj = can_sequence(obj) + for c in cobj: + buffers.extend(_extract_buffers(c, threshold)) + elif isinstance(obj, dict) and len(obj) < MAX_ITEMS: + cobj = {} for k in sorted(obj.iterkeys()): - s = serialize(can(obj[k])) - if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold: - databuffers.append(s.getData()) - s.data = None - sobj[k] = s - return pickle.dumps(sobj,-1),databuffers + c = can(obj[k]) + buffers.extend(_extract_buffers(c, threshold)) + cobj[k] = c else: - s = serialize(can(obj)) - if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold: - databuffers.append(s.getData()) - s.data = None - return pickle.dumps(s,-1),databuffers - - -def unserialize_object(bufs): - """reconstruct an object serialized by serialize_object from data buffers.""" - bufs = list(bufs) - sobj = pickle.loads(bufs.pop(0)) - if isinstance(sobj, (list, tuple)): - for s in sobj: - if s.data is None: - s.data = bufs.pop(0) - return uncanSequence(map(unserialize, sobj)), bufs - elif isinstance(sobj, dict): + cobj = can(obj) + buffers.extend(_extract_buffers(cobj, threshold)) + + buffers.insert(0, pickle.dumps(cobj,-1)) + return buffers + +def unserialize_object(buffers, g=None): + """reconstruct an object serialized by serialize_object from data buffers. + + Parameters + ---------- + + bufs : list of buffers/bytes + + g : globals to be used when uncanning + + Returns + ------- + + (newobj, bufs) : unpacked object, and the list of remaining unused buffers. + """ + bufs = list(buffers) + canned = pickle.loads(bufs.pop(0)) + if isinstance(canned, (list, tuple)) and len(canned) < MAX_ITEMS: + for c in canned: + _restore_buffers(c, bufs) + newobj = uncan_sequence(canned, g) + elif isinstance(canned, dict) and len(canned) < MAX_ITEMS: newobj = {} - for k in sorted(sobj.iterkeys()): - s = sobj[k] - if s.data is None: - s.data = bufs.pop(0) - newobj[k] = uncan(unserialize(s)) - return newobj, bufs + for k in sorted(canned.iterkeys()): + c = canned[k] + _restore_buffers(c, bufs) + newobj[k] = uncan(c, g) else: - if sobj.data is None: - sobj.data = bufs.pop(0) - return uncan(unserialize(sobj)), bufs + _restore_buffers(canned, bufs) + newobj = uncan(canned, g) + + return newobj, bufs -def pack_apply_message(f, args, kwargs, threshold=64e-6): +def pack_apply_message(f, args, kwargs, threshold=1024): """pack up a function, args, and kwargs to be sent over the wire as a series of buffers. 
@@ -131,49 +160,16 @@ def unpack_apply_message(bufs, g=None, copy=True):
     if not copy:
         for i in range(3):
             bufs[i] = bufs[i].bytes
-    cf = pickle.loads(bufs.pop(0))
-    sargs = list(pickle.loads(bufs.pop(0)))
-    skwargs = dict(pickle.loads(bufs.pop(0)))
-    # print sargs, skwargs
-    f = uncan(cf, g)
-    for sa in sargs:
-        if sa.data is None:
-            m = bufs.pop(0)
-            if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
-                # always use a buffer, until memoryviews get sorted out
-                sa.data = buffer(m)
-                # disable memoryview support
-                # if copy:
-                #     sa.data = buffer(m)
-                # else:
-                #     sa.data = m.buffer
-            else:
-                if copy:
-                    sa.data = m
-                else:
-                    sa.data = m.bytes
+    f = uncan(pickle.loads(bufs.pop(0)), g)
+    # sargs = bufs.pop(0)
+    # pop kwargs out, so first n-elements are args, serialized
+    skwargs = bufs.pop(1)
+    args, bufs = unserialize_object(bufs, g)
+    # put skwargs back in as the first element
+    bufs.insert(0, skwargs)
+    kwargs, bufs = unserialize_object(bufs, g)
-    args = uncanSequence(map(unserialize, sargs), g)
-    kwargs = {}
-    for k in sorted(skwargs.iterkeys()):
-        sa = skwargs[k]
-        if sa.data is None:
-            m = bufs.pop(0)
-            if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
-                # always use a buffer, until memoryviews get sorted out
-                sa.data = buffer(m)
-                # disable memoryview support
-                # if copy:
-                #     sa.data = buffer(m)
-                # else:
-                #     sa.data = m.buffer
-            else:
-                if copy:
-                    sa.data = m
-                else:
-                    sa.data = m.bytes
-
-        kwargs[k] = uncan(unserialize(sa), g)
+    assert not bufs, "Shouldn't be any data left over"
 
     return f,args,kwargs
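unpack_apply_message now leans on unserialize_object instead of repeating the
buffer bookkeeping inline, so an apply message is just a flat list:
[pickled f, canned args, canned kwargs, extracted data buffers...]. A hedged
roundtrip sketch, assuming numpy is available (the 8KiB array exceeds the
1024-byte threshold, so its data rides as a separate zero-copy buffer):

    import numpy
    from IPython.zmq.serialize import pack_apply_message, unpack_apply_message

    def scale(a, factor=1):
        return a * factor

    A = numpy.arange(1024, dtype='float64')   # 8192 bytes of array data
    msg = pack_apply_message(scale, (A,), dict(factor=2), threshold=1024)
    # msg[0]=f, msg[1]=args, msg[2]=kwargs, msg[3:]=A's raw buffer

    f, args, kwargs = unpack_apply_message(msg)
    assert (f(*args, **kwargs) == A * 2).all()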
diff --git a/IPython/zmq/tests/test_serialize.py b/IPython/zmq/tests/test_serialize.py
new file mode 100644
index 0000000..c1e0429
--- /dev/null
+++ b/IPython/zmq/tests/test_serialize.py
@@ -0,0 +1,115 @@
+"""test serialization tools"""
+
+#-------------------------------------------------------------------------------
+#  Copyright (C) 2011  The IPython Development Team
+#
+#  Distributed under the terms of the BSD License.  The full license is in
+#  the file COPYING, distributed as part of this software.
+#-------------------------------------------------------------------------------
+
+#-------------------------------------------------------------------------------
+# Imports
+#-------------------------------------------------------------------------------
+
+import pickle
+
+import nose.tools as nt
+
+# from unittest import TestCase
+from IPython.zmq.serialize import serialize_object, unserialize_object
+from IPython.testing import decorators as dec
+from IPython.utils.pickleutil import CannedArray
+
+def roundtrip(obj):
+    """roundtrip an object through serialization"""
+    bufs = serialize_object(obj)
+    obj2, remainder = unserialize_object(bufs)
+    nt.assert_equals(remainder, [])
+    return obj2
+
+class C(object):
+    """dummy class for testing attribute access"""
+
+    def __init__(self, **kwargs):
+        for key,value in kwargs.iteritems():
+            setattr(self, key, value)
+
+@dec.parametric
+def test_roundtrip_simple():
+    for obj in [
+        'hello',
+        dict(a='b', b=10),
+        [1,2,'hi'],
+        (b'123', 'hello'),
+    ]:
+        obj2 = roundtrip(obj)
+        yield nt.assert_equals(obj, obj2)
+
+@dec.parametric
+def test_roundtrip_nested():
+    for obj in [
+        dict(a=range(5), b={1:b'hello'}),
+        [range(5),[range(3),(1,[b'whoda'])]],
+    ]:
+        obj2 = roundtrip(obj)
+        yield nt.assert_equals(obj, obj2)
+
+@dec.parametric
+def test_roundtrip_buffered():
+    for obj in [
+        dict(a=b"x"*1025),
+        b"hello"*500,
+        [b"hello"*501, 1,2,3]
+    ]:
+        bufs = serialize_object(obj)
+        yield nt.assert_equals(len(bufs), 2)
+        obj2, remainder = unserialize_object(bufs)
+        yield nt.assert_equals(remainder, [])
+        yield nt.assert_equals(obj, obj2)
+
+@dec.parametric
+@dec.skip_without('numpy')
+def test_numpy():
+    import numpy
+    from numpy.testing.utils import assert_array_equal
+    for shape in ((), (0,), (100,), (1024,10), (10,8,6,5)):
+        for dtype in ('uint8', 'float64', 'int32', [('int16', 'float32')]):
+            A = numpy.empty(shape, dtype=dtype)
+            bufs = serialize_object(A)
+            B, r = unserialize_object(bufs)
+            yield nt.assert_equals(r, [])
+            yield assert_array_equal(A,B)
+
+@dec.parametric
+@dec.skip_without('numpy')
+def test_numpy_in_seq():
+    import numpy
+    from numpy.testing.utils import assert_array_equal
+    for shape in ((), (0,), (100,), (1024,10), (10,8,6,5)):
+        for dtype in ('uint8', 'float64', 'int32', [('int16', 'float32')]):
+            A = numpy.empty(shape, dtype=dtype)
+            bufs = serialize_object((A,1,2,b'hello'))
+            canned = pickle.loads(bufs[0])
+            yield nt.assert_true(isinstance(canned[0], CannedArray))
+            tup, r = unserialize_object(bufs)
+            B = tup[0]
+            yield nt.assert_equals(r, [])
+            yield assert_array_equal(A,B)
+
+@dec.parametric
+@dec.skip_without('numpy')
+def test_numpy_in_dict():
+    import numpy
+    from numpy.testing.utils import assert_array_equal
+    for shape in ((), (0,), (100,), (1024,10), (10,8,6,5)):
+        for dtype in ('uint8', 'float64', 'int32', [('int16', 'float32')]):
+            A = numpy.empty(shape, dtype=dtype)
+            bufs = serialize_object(dict(a=A,b=1,c=range(20)))
+            canned = pickle.loads(bufs[0])
+            yield nt.assert_true(isinstance(canned['a'], CannedArray))
+            d, r = unserialize_object(bufs)
+            B = d['a']
+            yield nt.assert_equals(r, [])
+            yield assert_array_equal(A,B)
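One behavior the new tests don't cover directly is the last bullet of the
commit message: containers with MAX_ITEMS (64) or more elements are no longer
canned element by element, they are pickled whole in a single pass. A hedged
sketch of the consequence:

    from IPython.zmq.serialize import serialize_object, unserialize_object

    short_list = range(4)       # < MAX_ITEMS: each element is canned
    long_list = range(1000)     # >= MAX_ITEMS: pickled whole, one pass

    for seq in (short_list, long_list):
        bufs = serialize_object(seq)
        obj, leftover = unserialize_object(bufs)
        assert obj == seq and leftover == []

The tradeoff is that a large buffer hiding inside a 64+-element container is
pickled along with it rather than extracted for zero-copy transfer.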