Review notes (free-text preamble; git apply and patch skip text before the first "diff --git" header).

Purpose: let IPython's pickle helpers (can/uncan) serialize functions whose code has free or cell variables (closures).
* IPython/utils/codeutil.py: reduce_code no longer raises ValueError("Sorry, cannot pickle code objects with closures").
  It now appends co_freevars and co_cellvars to the argument list passed to types.CodeType.
  The old multi-line license banner and __docformat__ are replaced by a two-line copyright header.
* IPython/utils/pickleutil.py: adds module-level cell_type, obtained by building a throwaway closure in _get_cell_type.
  Adds a CannedCell wrapper that cans cell.cell_contents and rebuilds a fresh cell in get_object.
  CannedFunction now cans each closure cell and passes the uncanned cells to FunctionType.
  cell_type is registered in can_map so can() dispatches cells to CannedCell.
* IPython/utils/py3compat.py: adds get_closure(f), which returns f.__closure__ on Python 3 and f.func_closure on Python 2.
* IPython/utils/tests/test_pickleutil.py (new file): round-trips four functions through can -> pickle -> uncan and compares results:
  one with no closure, one with a comprehension-created closure (per its comment, Python 3 only), one with a nested closure, and one closing over a test-local variable.

NOTE(review): in this copy the patch's newlines and indentation appear to have been collapsed into a few very long lines.
  As stored it will not apply with git apply or patch; regenerate it from the source commit rather than hand-repairing whitespace.
NOTE(review): CannedCell.__init__ reads cell.cell_contents. Python raises ValueError for an empty (not-yet-bound) cell.
  Confirm whether closures over later-assigned variables need handling.
NOTE(review): CannedFunction cans every closure cell via can(), and CannedCell cans its contents via can().
  A closure whose cell holds the function itself (e.g. a recursive nested function) looks like it would recurse without bound,
  unless can() guards against cycles somewhere not shown here. Confirm.
NOTE(review): the new test file ends without a trailing newline (see the "No newline at end of file" marker at the end).

diff --git a/IPython/utils/codeutil.py b/IPython/utils/codeutil.py index a51af71..887bd90 100644 --- a/IPython/utils/codeutil.py +++ b/IPython/utils/codeutil.py @@ -10,18 +10,8 @@ we need to automate all of this so that functions themselves can be pickled. Reference: A. Tremols, P Cogolo, "Python Cookbook," p 302-305 """ -__docformat__ = "restructuredtext en" - -#------------------------------------------------------------------------------- -# Copyright (C) 2008-2011 The IPython Development Team -# -# Distributed under the terms of the BSD License. The full license is in -# the file COPYING, distributed as part of this software. -#------------------------------------------------------------------------------- - -#------------------------------------------------------------------------------- -# Imports -#------------------------------------------------------------------------------- +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. import sys import types @@ -34,12 +24,10 @@ def code_ctor(*args): return types.CodeType(*args) def reduce_code(co): - if co.co_freevars or co.co_cellvars: - raise ValueError("Sorry, cannot pickle code objects with closures") args = [co.co_argcount, co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code, co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name, co.co_firstlineno, - co.co_lnotab] + co.co_lnotab, co.co_freevars, co.co_cellvars] if sys.version_info[0] >= 3: args.insert(1, co.co_kwonlyargcount) return code_ctor, tuple(args) diff --git a/IPython/utils/pickleutil.py b/IPython/utils/pickleutil.py index d3fed87..41b3198 100644 --- a/IPython/utils/pickleutil.py +++ b/IPython/utils/pickleutil.py @@ -1,19 +1,8 @@ # encoding: utf-8 - """Pickle related utilities. 
Perhaps this should be called 'can'.""" -__docformat__ = "restructuredtext en" - -#------------------------------------------------------------------------------- -# Copyright (C) 2008-2011 The IPython Development Team -# -# Distributed under the terms of the BSD License. The full license is in -# the file COPYING, distributed as part of this software. -#------------------------------------------------------------------------------- - -#------------------------------------------------------------------------------- -# Imports -#------------------------------------------------------------------------------- +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. import copy import logging @@ -39,6 +28,16 @@ else: from types import ClassType class_type = (type, ClassType) +def _get_cell_type(a=None): + """the type of a closure cell doesn't seem to be importable, + so just create one + """ + def inner(): + return a + return type(py3compat.get_closure(inner)[0]) + +cell_type = _get_cell_type() + #------------------------------------------------------------------------------- # Functions #------------------------------------------------------------------------------- @@ -131,6 +130,18 @@ class Reference(CannedObject): return eval(self.name, g) +class CannedCell(CannedObject): + """Can a closure cell""" + def __init__(self, cell): + self.cell_contents = can(cell.cell_contents) + + def get_object(self, g=None): + cell_contents = uncan(self.cell_contents, g) + def inner(): + return cell_contents + return py3compat.get_closure(inner)[0] + + class CannedFunction(CannedObject): def __init__(self, f): @@ -140,6 +151,13 @@ class CannedFunction(CannedObject): self.defaults = [ can(fd) for fd in f.__defaults__ ] else: self.defaults = None + + closure = py3compat.get_closure(f) + if closure: + self.closure = tuple( can(cell) for cell in closure ) + else: + self.closure = None + self.module = f.__module__ or '__main__' self.__name__ = 
f.__name__ self.buffers = [] @@ -159,7 +177,11 @@ class CannedFunction(CannedObject): defaults = tuple(uncan(cfd, g) for cfd in self.defaults) else: defaults = None - newFunc = FunctionType(self.code, g, self.__name__, defaults) + if self.closure: + closure = tuple(uncan(cell, g) for cell in self.closure) + else: + closure = None + newFunc = FunctionType(self.code, g, self.__name__, defaults, closure) return newFunc class CannedClass(CannedObject): @@ -378,6 +400,7 @@ can_map = { FunctionType : CannedFunction, bytes : CannedBytes, buffer : CannedBuffer, + cell_type : CannedCell, class_type : can_class, } diff --git a/IPython/utils/py3compat.py b/IPython/utils/py3compat.py index 161c54e..201a2eb 100644 --- a/IPython/utils/py3compat.py +++ b/IPython/utils/py3compat.py @@ -131,6 +131,10 @@ if sys.version_info[0] >= 3: Accepts a string or a function, so it can be used as a decorator.""" return s.format(u='') + + def get_closure(f): + """Get a function's closure attribute""" + return f.__closure__ else: PY3 = False @@ -192,6 +196,9 @@ else: def doctest_refactor_print(func_or_str): return func_or_str + def get_closure(f): + """Get a function's closure attribute""" + return f.func_closure # Abstract u'abc' syntax: @_modify_str_or_docstring diff --git a/IPython/utils/tests/test_pickleutil.py b/IPython/utils/tests/test_pickleutil.py new file mode 100644 index 0000000..82fe59e --- /dev/null +++ b/IPython/utils/tests/test_pickleutil.py @@ -0,0 +1,62 @@ + +import pickle + +import nose.tools as nt +from IPython.utils import codeutil +from IPython.utils.pickleutil import can, uncan + +def interactive(f): + f.__module__ = '__main__' + return f + +def dumps(obj): + return pickle.dumps(can(obj)) + +def loads(obj): + return uncan(pickle.loads(obj)) + +def test_no_closure(): + @interactive + def foo(): + a = 5 + return a + + pfoo = dumps(foo) + bar = loads(pfoo) + nt.assert_equal(foo(), bar()) + +def test_generator_closure(): + # this only creates a closure on Python 3 + @interactive 
+ def foo(): + i = 'i' + r = [ i for j in (1,2) ] + return r + + pfoo = dumps(foo) + bar = loads(pfoo) + nt.assert_equal(foo(), bar()) + +def test_nested_closure(): + @interactive + def foo(): + i = 'i' + def g(): + return i + return g() + + pfoo = dumps(foo) + bar = loads(pfoo) + nt.assert_equal(foo(), bar()) + +def test_closure(): + i = 'i' + @interactive + def foo(): + return i + + pfoo = dumps(foo) + bar = loads(pfoo) + nt.assert_equal(foo(), bar()) + + \ No newline at end of file