diff --git a/IPython/parallel/controller/dependency.py b/IPython/parallel/controller/dependency.py index 0499762..b6d4561 100644 --- a/IPython/parallel/controller/dependency.py +++ b/IPython/parallel/controller/dependency.py @@ -5,18 +5,21 @@ Authors: * Min RK """ #----------------------------------------------------------------------------- -# Copyright (C) 2010-2011 The IPython Development Team +# Copyright (C) 2013 The IPython Development Team # # Distributed under the terms of the BSD License. The full license is in # the file COPYING, distributed as part of this software. #----------------------------------------------------------------------------- +import sys + from types import ModuleType from IPython.parallel.client.asyncresult import AsyncResult from IPython.parallel.error import UnmetDependency from IPython.parallel.util import interactive from IPython.utils import py3compat +from IPython.utils.pickleutil import can, uncan class depend(object): """Dependency decorator, for use with tasks. @@ -60,8 +63,9 @@ class dependent(object): self.dkwargs = dkwargs def __call__(self, *args, **kwargs): - # if hasattr(self.f, 'func_globals') and hasattr(self.df, 'func_globals'): - # self.df.func_globals = self.f.func_globals + user_ns = sys.modules['__main__'].__dict__ + for key, value in self.dkwargs.items(): + self.dkwargs[key] = uncan(value, user_ns) if self.df(*self.dargs, **self.dkwargs) is False: raise UnmetDependency() return self.f(*args, **kwargs) @@ -72,41 +76,62 @@ class dependent(object): return self.func_name @interactive -def _require(*names): +def _require(*modules, **mapping): """Helper for @require decorator.""" from IPython.parallel.error import UnmetDependency + from IPython.utils.pickleutil import uncan user_ns = globals() - for name in names: - if name in user_ns: - continue + for name in modules: try: - exec 'import %s'%name in user_ns + exec 'import %s' % name in user_ns except ImportError: raise UnmetDependency(name) + + for name, cobj in mapping.items(): + user_ns[name] = uncan(cobj, user_ns) return True -def require(*mods): - """Simple decorator for requiring names to be importable. +def require(*objects, **mapping): + """Simple decorator for requiring local objects and modules to be available + when the decorated function is called on the engine. + + Modules specified by name or passed directly will be imported + prior to calling the decorated function. + + Objects other than modules will be pushed as a part of the task. + Functions can be passed positionally, + and will be pushed to the engine with their __name__. + Other objects can be passed by keyword arg. Examples -------- In [1]: @require('numpy') ...: def norm(a): - ...: import numpy ...: return numpy.linalg.norm(a,2) + + In [2]: foo = lambda x: x*x + In [3]: @require(foo) + ...: def bar(a): + ...: return foo(1-a) """ names = [] - for mod in mods: - if isinstance(mod, ModuleType): - mod = mod.__name__ + for obj in objects: + if isinstance(obj, ModuleType): + obj = obj.__name__ - if isinstance(mod, basestring): - names.append(mod) + if isinstance(obj, basestring): + names.append(obj) + elif hasattr(obj, '__name__'): + mapping[obj.__name__] = obj else: - raise TypeError("names must be modules or module names, not %s"%type(mod)) + raise TypeError("Objects other than modules and functions " + "must be passed by kwarg, but got: %s" % type(obj) + ) - return depend(_require, *names) + for name, obj in mapping.items(): + mapping[name] = can(obj) + return depend(_require, *names, **mapping) class Dependency(set): """An object for representing a set of msg_id dependencies.