diff --git a/IPython/zmq/parallel/__init__.py b/IPython/zmq/parallel/__init__.py index f577c35..2dbf93d 100644 --- a/IPython/zmq/parallel/__init__.py +++ b/IPython/zmq/parallel/__init__.py @@ -10,14 +10,15 @@ # Imports #----------------------------------------------------------------------------- -# from .asyncresult import * -# from .client import Client -# from .dependency import * -# from .remotefunction import * -# from .view import * - import zmq if zmq.__version__ < '2.1.3': raise ImportError("IPython.zmq.parallel requires pyzmq/0MQ >= 2.1.3, you appear to have %s"%zmq.__version__) +from .asyncresult import * +from .client import Client +from .dependency import * +from .remotefunction import * +from .view import * + + diff --git a/IPython/zmq/parallel/client.py b/IPython/zmq/parallel/client.py index 2045266..6bddd96 100644 --- a/IPython/zmq/parallel/client.py +++ b/IPython/zmq/parallel/client.py @@ -245,8 +245,6 @@ class Client(HasTraits): _mux_socket=Instance('zmq.Socket') _task_socket=Instance('zmq.Socket') _task_scheme=Str() - _balanced_views=Dict() - _direct_views=Dict() _closed = False _ignored_control_replies=Int(0) _ignored_hub_replies=Int(0) @@ -389,7 +387,20 @@ class Client(HasTraits): else: raise TypeError("%r not valid str target, must be 'all'"%(targets)) elif isinstance(targets, int): + if targets < 0: + targets = self.ids[targets] + if targets not in self.ids: + raise IndexError("No such engine: %i"%targets) targets = [targets] + + if isinstance(targets, slice): + indices = range(len(self._ids))[targets] + ids = self.ids + targets = [ ids[i] for i in indices ] + + if not isinstance(targets, (tuple, list, xrange)): + raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets))) + return [self._engines[t] for t in targets], list(targets) def _connect(self, sshserver, ssh_kwargs, timeout): @@ -688,7 +699,7 @@ class Client(HasTraits): if not isinstance(key, (int, slice, tuple, list, xrange)): raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key))) else: - return self._get_view(key, balanced=False) + return self.direct_view(key) #-------------------------------------------------------------------------- # Begin public methods @@ -962,31 +973,6 @@ class Client(HasTraits): # construct a View object #-------------------------------------------------------------------------- - def _cache_view(self, targets, balanced): - """save views, so subsequent requests don't create new objects.""" - if balanced: - # validate whether we can run - if not self._task_socket: - msg = "Task farming is disabled" - if self._task_scheme == 'pure': - msg += " because the pure ZMQ scheduler cannot handle" - msg += " disappearing engines." - raise RuntimeError(msg) - socket = self._task_socket - view_class = LoadBalancedView - view_cache = self._balanced_views - else: - socket = self._mux_socket - view_class = DirectView - view_cache = self._direct_views - - # use str, since often targets will be a list - key = str(targets) - if key not in view_cache: - view_cache[key] = view_class(client=self, socket=socket, targets=targets) - - return view_cache[key] - def load_balanced_view(self, targets=None): """construct a DirectView object. @@ -999,7 +985,9 @@ class Client(HasTraits): targets: list,slice,int,etc. [default: use all engines] The subset of engines across which to load-balance """ - return self._get_view(targets, balanced=True) + if targets is None: + targets = self._build_targets(targets)[1] + return LoadBalancedView(client=self, socket=self._task_socket, targets=targets) def direct_view(self, targets='all'): """construct a DirectView object. @@ -1013,49 +1001,11 @@ class Client(HasTraits): targets: list,slice,int,etc. [default: use all engines] The engines to use for the View """ - return self._get_view(targets, balanced=False) - - def _get_view(self, targets, balanced): - """Method for constructing View objects. - - If no arguments are specified, create a LoadBalancedView - using all engines. If only `targets` specified, it will - be a DirectView. This method is the underlying implementation - of ``client.__getitem__``. - - Parameters - ---------- - - targets: list,slice,int,etc. [default: use all engines] - The engines to use for the View - balanced : bool [default: False if targets specified, True else] - whether to build a LoadBalancedView or a DirectView - - """ - - if targets in (None,'all'): - if balanced: - return self._cache_view(None,True) - else: - targets = slice(None) - - if isinstance(targets, int): - if targets < 0: - targets = self.ids[targets] - if targets not in self.ids: - raise IndexError("No such engine: %i"%targets) - return self._cache_view(targets, balanced) - - if isinstance(targets, slice): - indices = range(len(self.ids))[targets] - ids = sorted(self._ids) - targets = [ ids[i] for i in indices ] - - if isinstance(targets, (tuple, list, xrange)): - _,targets = self._build_targets(list(targets)) - return self._cache_view(targets, balanced) - else: - raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets))) + single = isinstance(targets, int) + targets = self._build_targets(targets)[1] + if single: + targets = targets[0] + return DirectView(client=self, socket=self._mux_socket, targets=targets) #-------------------------------------------------------------------------- # Data movement (TO BE REMOVED) diff --git a/IPython/zmq/parallel/dependency.py b/IPython/zmq/parallel/dependency.py index 9216709..bbd60f3 100644 --- a/IPython/zmq/parallel/dependency.py +++ b/IPython/zmq/parallel/dependency.py @@ -6,6 +6,8 @@ # the file COPYING, distributed as part of this software. #----------------------------------------------------------------------------- +from types import ModuleType + from .asyncresult import AsyncResult from .error import UnmetDependency from .util import interactive @@ -76,7 +78,7 @@ def _require(*names): raise UnmetDependency(name) return True -def require(*names): +def require(*mods): """Simple decorator for requiring names to be importable. Examples @@ -87,6 +89,16 @@ def require(*names): ...: import numpy ...: return numpy.linalg.norm(a,2) """ + names = [] + for mod in mods: + if isinstance(mod, ModuleType): + mod = mod.__name__ + + if isinstance(mod, basestring): + names.append(mod) + else: + raise TypeError("names must be modules or module names, not %s"%type(mod)) + return depend(_require, *names) class Dependency(set): diff --git a/IPython/zmq/parallel/launcher.py b/IPython/zmq/parallel/launcher.py index 1e8e13e..c6f591e 100644 --- a/IPython/zmq/parallel/launcher.py +++ b/IPython/zmq/parallel/launcher.py @@ -48,9 +48,9 @@ from IPython.utils.process import find_cmd, pycmd2argv, FindCmdError from .factory import LoggingFactory -# load winhpcjob from IPython.kernel +# load winhpcjob only on Windows try: - from IPython.kernel.winhpcjob import ( + from .winhpcjob import ( IPControllerTask, IPEngineTask, IPControllerJob, IPEngineSetJob ) diff --git a/IPython/zmq/parallel/tests/test_client.py b/IPython/zmq/parallel/tests/test_client.py index 47bebca..9a14c3c 100644 --- a/IPython/zmq/parallel/tests/test_client.py +++ b/IPython/zmq/parallel/tests/test_client.py @@ -61,15 +61,6 @@ class TestClient(ClusterTestCase): self.assertEquals(v.targets, targets[-1]) self.assertRaises(TypeError, lambda : self.client[None]) - def test_view_cache(self): - """test that multiple view requests return the same object""" - v = self.client[:2] - v2 =self.client[:2] - self.assertTrue(v is v2) - v = self.client.load_balanced_view() - v2 = self.client.load_balanced_view(targets=None) - self.assertTrue(v is v2) - def test_targets(self): """test various valid targets arguments""" build = self.client._build_targets diff --git a/IPython/zmq/parallel/tests/test_view.py b/IPython/zmq/parallel/tests/test_view.py index 4c1679f..91a9820 100644 --- a/IPython/zmq/parallel/tests/test_view.py +++ b/IPython/zmq/parallel/tests/test_view.py @@ -285,3 +285,17 @@ class TestView(ClusterTestCase): self.assertFalse(view.block) self.assertTrue(view.block) + def test_importer(self): + view = self.client[-1] + view.clear(block=True) + with view.importer: + import re + + @interactive + def findall(pat, s): + # this globals() step isn't necessary in real code + # only to prevent a closure in the test + return globals()['re'].findall(pat, s) + + self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split()) + diff --git a/IPython/zmq/parallel/view.py b/IPython/zmq/parallel/view.py index 74baa22..c198062 100644 --- a/IPython/zmq/parallel/view.py +++ b/IPython/zmq/parallel/view.py @@ -10,13 +10,16 @@ # Imports #----------------------------------------------------------------------------- +import imp +import sys import warnings from contextlib import contextmanager +from types import ModuleType import zmq from IPython.testing import decorators as testdec -from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance +from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat from IPython.external.decorator import decorator @@ -94,48 +97,36 @@ class View(HasTraits): abort, shutdown """ + # flags block=Bool(False) track=Bool(True) + targets = Any() + history=List() outstanding = Set() results = Dict() client = Instance('IPython.zmq.parallel.client.Client') _socket = Instance('zmq.Socket') - _ntargets = Int(1) - _flag_names = List(['block', 'track']) + _flag_names = List(['targets', 'block', 'track']) _targets = Any() _idents = Any() - def __init__(self, client=None, socket=None, targets=None): + def __init__(self, client=None, socket=None, **flags): super(View, self).__init__(client=client, _socket=socket) - self._ntargets = 1 if isinstance(targets, (int,type(None))) else len(targets) self.block = client.block - self._idents, self._targets = self.client._build_targets(targets) - if targets is None or isinstance(targets, int): - self._targets = targets - for name in self._flag_names: - # set flags, if they haven't been set yet - setattr(self, name, getattr(self, name, None)) + self.set_flags(**flags) assert not self.__class__ is View, "Don't use base View objects, use subclasses" def __repr__(self): - strtargets = str(self._targets) + strtargets = str(self.targets) if len(strtargets) > 16: strtargets = strtargets[:12]+'...]' return "<%s %s>"%(self.__class__.__name__, strtargets) - @property - def targets(self): - return self._targets - - @targets.setter - def targets(self, value): - raise AttributeError("Cannot set View `targets` after construction!") - def set_flags(self, **kwargs): """set my attribute flags by keyword. @@ -182,9 +173,11 @@ class View(HasTraits): saved_flags[f] = getattr(self, f) self.set_flags(**kwargs) # yield to the with-statement block - yield - # postflight: restore saved flags - self.set_flags(**saved_flags) + try: + yield + finally: + # postflight: restore saved flags + self.set_flags(**saved_flags) #---------------------------------------------------------------- @@ -258,7 +251,7 @@ class View(HasTraits): jobs = self.history return self.client.wait(jobs, timeout) - def abort(self, jobs=None, block=None): + def abort(self, jobs=None, targets=None, block=None): """Abort jobs on my engines. Parameters @@ -269,16 +262,18 @@ class View(HasTraits): else: abort specific msg_id(s). """ block = block if block is not None else self.block - return self.client.abort(jobs=jobs, targets=self._targets, block=block) + targets = targets if targets is not None else self.targets + return self.client.abort(jobs=jobs, targets=targets, block=block) - def queue_status(self, verbose=False): + def queue_status(self, targets=None, verbose=False): """Fetch the Queue status of my engines""" - return self.client.queue_status(targets=self._targets, verbose=verbose) + targets = targets if targets is not None else self.targets + return self.client.queue_status(targets=targets, verbose=verbose) def purge_results(self, jobs=[], targets=[]): """Instruct the controller to forget specific results.""" if targets is None or targets == 'all': - targets = self._targets + targets = self.targets return self.client.purge_results(jobs=jobs, targets=targets) @spin_after @@ -377,11 +372,104 @@ class DirectView(View): def __init__(self, client=None, socket=None, targets=None): super(DirectView, self).__init__(client=client, socket=socket, targets=targets) + + @property + def importer(self): + """sync_imports(local=True) as a property. + See sync_imports for details. + + In [10]: with v.importer: + ....: import numpy + ....: + importing numpy on engine(s) + + """ + return self.sync_imports(True) + + @contextmanager + def sync_imports(self, local=True): + """Context Manager for performing simultaneous local and remote imports. + + 'import x as y' will *not* work. The 'as y' part will simply be ignored. + + >>> with view.sync_imports(): + ... from numpy import recarray + importing recarray from numpy on engine(s) + + """ + import __builtin__ + local_import = __builtin__.__import__ + modules = set() + results = [] + @util.interactive + def remote_import(name, fromlist, level): + """the function to be passed to apply, that actually performs the import + on the engine, and loads up the user namespace. + """ + import sys + user_ns = globals() + mod = __import__(name, fromlist=fromlist, level=level) + if fromlist: + for key in fromlist: + user_ns[key] = getattr(mod, key) + else: + user_ns[name] = sys.modules[name] + + def view_import(name, globals={}, locals={}, fromlist=[], level=-1): + """the drop-in replacement for __import__, that optionally imports + locally as well. + """ + # don't override nested imports + save_import = __builtin__.__import__ + __builtin__.__import__ = local_import + + if imp.lock_held(): + # this is a side-effect import, don't do it remotely, or even + # ignore the local effects + return local_import(name, globals, locals, fromlist, level) + + imp.acquire_lock() + if local: + mod = local_import(name, globals, locals, fromlist, level) + else: + raise NotImplementedError("remote-only imports not yet implemented") + imp.release_lock() + + key = name+':'+','.join(fromlist or []) + if level == -1 and key not in modules: + modules.add(key) + if fromlist: + print "importing %s from %s on engine(s)"%(','.join(fromlist), name) + else: + print "importing %s on engine(s)"%name + results.append(self.apply_async(remote_import, name, fromlist, level)) + # restore override + __builtin__.__import__ = save_import + + return mod + + # override __import__ + __builtin__.__import__ = view_import + try: + # enter the block + yield + except ImportError: + if not local: + # ignore import errors if not doing local imports + pass + finally: + # always restore __import__ + __builtin__.__import__ = local_import + + for r in results: + # raise possible remote ImportErrors here + r.get() + @sync_results @save_ids - def _really_apply(self, f, args=None, kwargs=None, block=None, track=None): + def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None): """calls f(*args, **kwargs) on remote engines, returning the result. This method sets all of `apply`'s flags via this View's attributes. @@ -395,6 +483,8 @@ class DirectView(View): kwargs : dict [default: empty] + targets : target list [default: self.targets] + where to run block : bool [default: self.block] whether to block track : bool [default: self.track] @@ -414,16 +504,19 @@ class DirectView(View): kwargs = {} if kwargs is None else kwargs block = self.block if block is None else block track = self.track if track is None else track + targets = self.targets if targets is None else targets + + _idents = self.client._build_targets(targets)[0] msg_ids = [] trackers = [] - for ident in self._idents: + for ident in _idents: msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track, ident=ident) if track: trackers.append(msg['tracker']) msg_ids.append(msg['msg_id']) tracker = None if track is False else zmq.MessageTracker(*trackers) - ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=self._targets, tracker=tracker) + ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker) if block: try: return ar.get() @@ -474,7 +567,7 @@ class DirectView(View): pf = ParallelFunction(self, f, block=block, **kwargs) return pf.map(*sequences) - def execute(self, code, block=None): + def execute(self, code, targets=None, block=None): """Executes `code` on `targets` in blocking or nonblocking manner. ``execute`` is always `bound` (affects engine namespace) @@ -488,9 +581,9 @@ class DirectView(View): whether or not to wait until done to return default: self.block """ - return self._really_apply(util._execute, args=(code,), block=block) + return self._really_apply(util._execute, args=(code,), block=block, targets=targets) - def run(self, filename, block=None): + def run(self, filename, targets=None, block=None): """Execute contents of `filename` on my engine(s). This simply reads the contents of the file and calls `execute`. @@ -512,7 +605,7 @@ class DirectView(View): # add newline in case of trailing indented whitespace # which will cause SyntaxError code = f.read()+'\n' - return self.execute(code, block=block) + return self.execute(code, block=block, targets=targets) def update(self, ns): """update remote namespace with dict `ns` @@ -521,7 +614,7 @@ class DirectView(View): """ return self.push(ns, block=self.block, track=self.track) - def push(self, ns, block=None, track=None): + def push(self, ns, targets=None, block=None, track=None): """update remote namespace with dict `ns` Parameters @@ -536,10 +629,11 @@ class DirectView(View): block = block if block is not None else self.block track = track if track is not None else self.track + targets = targets if targets is not None else self.targets # applier = self.apply_sync if block else self.apply_async if not isinstance(ns, dict): raise TypeError("Must be a dict, not %s"%type(ns)) - return self._really_apply(util._push, (ns,),block=block, track=track) + return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets) def get(self, key_s): """get object(s) by `key_s` from remote namespace @@ -549,13 +643,14 @@ class DirectView(View): # block = block if block is not None else self.block return self.pull(key_s, block=True) - def pull(self, names, block=True): + def pull(self, names, targets=None, block=True): """get object(s) by `name` from remote namespace will return one object if it is a key. can also take a list of keys, in which case it will return a list of objects. """ block = block if block is not None else self.block + targets = targets if targets is not None else self.targets applier = self.apply_sync if block else self.apply_async if isinstance(names, basestring): pass @@ -565,26 +660,27 @@ class DirectView(View): raise TypeError("keys must be str, not type %r"%type(key)) else: raise TypeError("names must be strs, not %r"%names) - return applier(util._pull, names) + return self._really_apply(util._pull, (names,), block=block, targets=targets) - def scatter(self, key, seq, dist='b', flatten=False, block=None, track=None): + def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None): """ Partition a Python sequence and send the partitions to a set of engines. """ block = block if block is not None else self.block track = track if track is not None else self.track - targets = self._targets + targets = targets if targets is not None else self.targets + mapObject = Map.dists[dist]() nparts = len(targets) msg_ids = [] trackers = [] for index, engineid in enumerate(targets): - push = self.client[engineid].push partition = mapObject.getPartition(seq, index, nparts) if flatten and len(partition) == 1: - r = push({key: partition[0]}, block=False, track=track) + ns = {key: partition[0]} else: - r = push({key: partition},block=False, track=track) + ns = {key: partition} + r = self.push(ns, block=False, track=track, targets=engineid) msg_ids.extend(r.msg_ids) if track: trackers.append(r._tracker) @@ -602,16 +698,17 @@ class DirectView(View): @sync_results @save_ids - def gather(self, key, dist='b', block=None): + def gather(self, key, dist='b', targets=None, block=None): """ Gather a partitioned sequence on a set of engines as a single local seq. """ block = block if block is not None else self.block + targets = targets if targets is not None else self.targets mapObject = Map.dists[dist]() msg_ids = [] - for index, engineid in enumerate(self._targets): - - msg_ids.extend(self.client[engineid].pull(key, block=False).msg_ids) + + for index, engineid in enumerate(targets): + msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids) r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather') @@ -628,15 +725,17 @@ class DirectView(View): def __setitem__(self,key, value): self.update({key:value}) - def clear(self, block=False): + def clear(self, targets=None, block=False): """Clear the remote namespaces on my engines.""" block = block if block is not None else self.block - return self.client.clear(targets=self._targets, block=block) + targets = targets if targets is not None else self.targets + return self.client.clear(targets=targets, block=block) - def kill(self, block=True): + def kill(self, targets=None, block=True): """Kill my engines.""" block = block if block is not None else self.block - return self.client.kill(targets=self._targets, block=block) + targets = targets if targets is not None else self.targets + return self.client.kill(targets=targets, block=block) #---------------------------------------- # activate for %px,%autopx magics @@ -684,15 +783,16 @@ class LoadBalancedView(View): """ - _flag_names = ['block', 'track', 'follow', 'after', 'timeout'] + follow=Any() + after=Any() + timeout=CFloat() - def __init__(self, client=None, socket=None, targets=None): - super(LoadBalancedView, self).__init__(client=client, socket=socket, targets=targets) - self._ntargets = 1 + _task_scheme = Any() + _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout']) + + def __init__(self, client=None, socket=None, **flags): + super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags) self._task_scheme=client._task_scheme - if targets is None: - self._targets = None - self._idents=[] def _validate_dependency(self, dep): """validate a dependency. @@ -786,7 +886,8 @@ class LoadBalancedView(View): @sync_results @save_ids def _really_apply(self, f, args=None, kwargs=None, block=None, track=None, - after=None, follow=None, timeout=None): + after=None, follow=None, timeout=None, + targets=None): """calls f(*args, **kwargs) on a remote engine, returning the result. This method temporarily sets all of `apply`'s flags for a single call. @@ -844,9 +945,16 @@ class LoadBalancedView(View): after = self.after if after is None else after follow = self.follow if follow is None else follow timeout = self.timeout if timeout is None else timeout + targets = self.targets if targets is None else targets + + if targets is None: + idents = [] + else: + idents = self.client._build_targets(targets)[0] + after = self._render_dependency(after) follow = self._render_dependency(follow) - subheader = dict(after=after, follow=follow, timeout=timeout, targets=self._idents) + subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents) msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track, subheader=subheader) @@ -916,5 +1024,5 @@ class LoadBalancedView(View): pf = ParallelFunction(self, f, block=block, chunksize=chunksize) return pf.map(*sequences) - + __all__ = ['LoadBalancedView', 'DirectView'] \ No newline at end of file