From 154798bf001af5a3e6106564b5b74ae4d0338b4e 2011-04-08 00:38:21
From: MinRK
Date: 2011-04-08 00:38:21
Subject: [PATCH] Client -> HasTraits, update examples with API tweaks

---
diff --git a/IPython/zmq/parallel/client.py b/IPython/zmq/parallel/client.py
index 4c0f598..81b57a0 100644
--- a/IPython/zmq/parallel/client.py
+++ b/IPython/zmq/parallel/client.py
@@ -24,6 +24,8 @@ import zmq
 # from zmq.eventloop import ioloop, zmqstream
 
 from IPython.utils.path import get_ipython_dir
+from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
+                                    Dict, List, Bool, Str, Set)
 from IPython.external.decorator import decorator
 from IPython.external.ssh import tunnel
 
@@ -147,7 +149,7 @@ class Metadata(dict):
         raise KeyError(key)
 
 
-class Client(object):
+class Client(HasTraits):
     """A semi-synchronous client to the IPython ZMQ controller
 
     Parameters
@@ -247,31 +249,41 @@ class Client(object):
     """
 
-    _connected=False
-    _ssh=False
-    _engines=None
-    _registration_socket=None
-    _query_socket=None
-    _control_socket=None
-    _iopub_socket=None
-    _notification_socket=None
-    _mux_socket=None
-    _task_socket=None
-    _task_scheme=None
-    block = False
-    outstanding=None
-    results = None
-    history = None
-    debug = False
-    targets = None
+    block = Bool(False)
+    outstanding=Set()
+    results = Dict()
+    metadata = Dict()
+    history = List()
+    debug = Bool(False)
+    profile=CUnicode('default')
+
+    _ids = List()
+    _connected=Bool(False)
+    _ssh=Bool(False)
+    _context = Instance('zmq.Context')
+    _config = Dict()
+    _engines=Instance(ReverseDict, (), {})
+    _registration_socket=Instance('zmq.Socket')
+    _query_socket=Instance('zmq.Socket')
+    _control_socket=Instance('zmq.Socket')
+    _iopub_socket=Instance('zmq.Socket')
+    _notification_socket=Instance('zmq.Socket')
+    _mux_socket=Instance('zmq.Socket')
+    _task_socket=Instance('zmq.Socket')
+    _task_scheme=Str()
+    _balanced_views=Dict()
+    _direct_views=Dict()
+    _closed = False
 
     def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
                  context=None, username=None, debug=False, exec_key=None,
                  sshserver=None, sshkey=None, password=None, paramiko=None,
                  ):
+        super(Client, self).__init__(debug=debug, profile=profile)
         if context is None:
             context = zmq.Context()
-        self.context = context
+        self._context = context
+
         self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
         if self._cd is not None:
@@ -325,20 +337,14 @@ class Client(object):
             self.session = ss.StreamSession(**key_arg)
         else:
             self.session = ss.StreamSession(username, **key_arg)
-        self._registration_socket = self.context.socket(zmq.XREQ)
+        self._registration_socket = self._context.socket(zmq.XREQ)
         self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
         if self._ssh:
             tunnel.tunnel_connection(self._registration_socket, url, sshserver, **ssh_kwargs)
         else:
             self._registration_socket.connect(url)
-        self._engines = ReverseDict()
-        self._ids = []
-        self.outstanding=set()
-        self.results = {}
-        self.metadata = {}
-        self.history = []
-        self.debug = debug
-        self.session.debug = debug
+
+        self.session.debug = self.debug
 
         self._notification_handlers = {'registration_notification' : self._register_engine,
                                        'unregistration_notification' : self._unregister_engine,
@@ -370,6 +376,14 @@ class Client(object):
         """Always up-to-date ids property."""
         self._flush_notifications()
         return self._ids
+
+    def close(self):
+        if self._closed:
+            return
+        snames = filter(lambda n: n.endswith('socket'), dir(self))
+        for socket in map(lambda name: getattr(self, name), snames):
+            socket.close()
+        self._closed = True
 
     def _update_engines(self, engines):
         """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
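With Client now a HasTraits subclass, attributes such as block, profile, and debug are validated traits, and the new close() method releases the zmq sockets when the client is done. A minimal usage sketch, assuming a controller for the default profile is already running (the name `c` is only illustrative)::

    from IPython.zmq.parallel import client

    c = client.Client(profile='default')   # reads connection info for that profile
    print c.ids      # engine ids, refreshed via the notification socket
    print c.block    # Bool trait, defaults to False

    c.close()        # closes every *_socket attribute; further use is disallowed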
@@ -436,28 +450,28 @@ class Client(object):
         self._config['registration'] = dict(content)
         if content.status == 'ok':
             if content.mux:
-                self._mux_socket = self.context.socket(zmq.PAIR)
+                self._mux_socket = self._context.socket(zmq.PAIR)
                 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
                 connect_socket(self._mux_socket, content.mux)
             if content.task:
                 self._task_scheme, task_addr = content.task
-                self._task_socket = self.context.socket(zmq.PAIR)
+                self._task_socket = self._context.socket(zmq.PAIR)
                 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
                 connect_socket(self._task_socket, task_addr)
             if content.notification:
-                self._notification_socket = self.context.socket(zmq.SUB)
+                self._notification_socket = self._context.socket(zmq.SUB)
                 connect_socket(self._notification_socket, content.notification)
                 self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
             if content.query:
-                self._query_socket = self.context.socket(zmq.PAIR)
+                self._query_socket = self._context.socket(zmq.PAIR)
                 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
                 connect_socket(self._query_socket, content.query)
             if content.control:
-                self._control_socket = self.context.socket(zmq.PAIR)
+                self._control_socket = self._context.socket(zmq.PAIR)
                 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
                 connect_socket(self._control_socket, content.control)
             if content.iopub:
-                self._iopub_socket = self.context.socket(zmq.SUB)
+                self._iopub_socket = self._context.socket(zmq.SUB)
                 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, '')
                 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
                 connect_socket(self._iopub_socket, content.iopub)
@@ -636,9 +650,13 @@ class Client(object):
             msg = self.session.recv(sock, mode=zmq.NOBLOCK)
 
     #--------------------------------------------------------------------------
-    # getitem
+    # len, getitem
    #--------------------------------------------------------------------------
 
+    def __len__(self):
+        """len(client) returns # of engines."""
+        return len(self.ids)
+
     def __getitem__(self, key):
         """index access returns DirectView multiplexer objects
 
@@ -929,8 +947,9 @@ class Client(object):
         else:
             return list of results, matching `targets`
         """
-
+        assert not self._closed, "cannot use me anymore, I'm closed!"
         # defaults:
+        block = block if block is not None else self.block
         args = args if args is not None else []
         kwargs = kwargs if kwargs is not None else {}
 
@@ -955,7 +974,7 @@
             raise TypeError("kwargs must be dict, not %s"%type(kwargs))
 
         options = dict(bound=bound, block=block, targets=targets)
-
+
         if balanced:
             return self._apply_balanced(f, args, kwargs, timeout=timeout,
                                         after=after, follow=follow, **options)
@@ -966,7 +985,7 @@
         else:
             return self._apply_direct(f, args, kwargs, **options)
 
-    def _apply_balanced(self, f, args, kwargs, bound=True, block=None, targets=None,
+    def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
                         after=None, follow=None, timeout=None):
         """call f(*args, **kwargs) remotely in a load-balanced manner.
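The new __len__ and the existing __getitem__ make the client behave like a small container of engines, and apply() now falls back to the client-wide block trait when block=None. A sketch, assuming `c` is a connected client (the other names are illustrative)::

    print len(c)     # number of registered engines
    dv0  = c[0]      # DirectView on engine 0
    dall = c[:]      # DirectView on all currently registered engines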
""" - for kwarg in (bound, block, targets): - assert kwarg is not None, "kwarg %r must be specified!"%kwarg + loc = locals() + for name in ('bound', 'block'): + assert loc[name] is not None, "kwarg %r must be specified!"%name if self._task_socket is None: msg = "Task farming is disabled" @@ -1030,9 +1050,9 @@ class Client(object): This is a private method, see `apply` for details. Not to be called directly! """ - - for kwarg in (bound, block, targets): - assert kwarg is not None, "kwarg %r must be specified!"%kwarg + loc = locals() + for name in ('bound', 'block', 'targets'): + assert loc[name] is not None, "kwarg %r must be specified!"%name idents,targets = self._build_targets(targets) @@ -1058,35 +1078,65 @@ class Client(object): return ar #-------------------------------------------------------------------------- - # decorators + # construct a View object #-------------------------------------------------------------------------- @defaultblock - def parallel(self, bound=True, targets='all', block=None): - """Decorator for making a ParallelFunction.""" - return parallel(self, bound=bound, targets=targets, block=block) + def remote(self, bound=True, block=None, targets=None, balanced=None): + """Decorator for making a RemoteFunction""" + return remote(self, bound=bound, targets=targets, block=block, balanced=balanced) @defaultblock - def remote(self, bound=True, targets='all', block=None): - """Decorator for making a RemoteFunction.""" - return remote(self, bound=bound, targets=targets, block=block) + def parallel(self, dist='b', bound=True, block=None, targets=None, balanced=None): + """Decorator for making a ParallelFunction""" + return parallel(self, bound=bound, targets=targets, block=block, balanced=balanced) - def view(self, targets=None, balanced=False): - """Method for constructing View objects""" + def _cache_view(self, targets, balanced): + """save views, so subsequent requests don't create new objects.""" + if balanced: + view_class = LoadBalancedView + view_cache = self._balanced_views + else: + view_class = DirectView + view_cache = self._direct_views + + # use str, since often targets will be a list + key = str(targets) + if key not in view_cache: + view_cache[key] = view_class(client=self, targets=targets) + + return view_cache[key] + + def view(self, targets=None, balanced=None): + """Method for constructing View objects. + + If no arguments are specified, create a LoadBalancedView + using all engines. If only `targets` specified, it will + be a DirectView. This method is the underlying implementation + of ``client.__getitem__``. + + Parameters + ---------- + + targets: list,slice,int,etc. 
@@ -1058,35 +1078,65 @@
         return ar
 
     #--------------------------------------------------------------------------
-    # decorators
+    # construct a View object
     #--------------------------------------------------------------------------
 
     @defaultblock
-    def parallel(self, bound=True, targets='all', block=None):
-        """Decorator for making a ParallelFunction."""
-        return parallel(self, bound=bound, targets=targets, block=block)
+    def remote(self, bound=True, block=None, targets=None, balanced=None):
+        """Decorator for making a RemoteFunction"""
+        return remote(self, bound=bound, targets=targets, block=block, balanced=balanced)
 
     @defaultblock
-    def remote(self, bound=True, targets='all', block=None):
-        """Decorator for making a RemoteFunction."""
-        return remote(self, bound=bound, targets=targets, block=block)
+    def parallel(self, dist='b', bound=True, block=None, targets=None, balanced=None):
+        """Decorator for making a ParallelFunction"""
+        return parallel(self, dist=dist, bound=bound, targets=targets, block=block, balanced=balanced)
 
-    def view(self, targets=None, balanced=False):
-        """Method for constructing View objects"""
+    def _cache_view(self, targets, balanced):
+        """save views, so subsequent requests don't create new objects."""
+        if balanced:
+            view_class = LoadBalancedView
+            view_cache = self._balanced_views
+        else:
+            view_class = DirectView
+            view_cache = self._direct_views
+
+        # use str, since often targets will be a list
+        key = str(targets)
+        if key not in view_cache:
+            view_cache[key] = view_class(client=self, targets=targets)
+
+        return view_cache[key]
+
+    def view(self, targets=None, balanced=None):
+        """Method for constructing View objects.
+
+        If no arguments are specified, create a LoadBalancedView
+        using all engines.  If only `targets` is specified, it will
+        be a DirectView.  This method is the underlying implementation
+        of ``client.__getitem__``.
+
+        Parameters
+        ----------
+
+        targets : int, slice, or list of ints [default: use all engines]
+            The engines to use for the View
+        balanced : bool [default: False if targets specified, True otherwise]
+            whether to build a LoadBalancedView or a DirectView
+
+        """
+
+        balanced = (targets is None) if balanced is None else balanced
+
         if targets is None:
             if balanced:
-                return LoadBalancedView(client=self)
+                return self._cache_view(None, True)
             else:
                 targets = slice(None)
 
-        if balanced:
-            view_class = LoadBalancedView
-        else:
-            view_class = DirectView
 
         if isinstance(targets, int):
             if targets not in self.ids:
                 raise IndexError("No such engine: %i"%targets)
-            return view_class(client=self, targets=targets)
+            return self._cache_view(targets, balanced)
 
         if isinstance(targets, slice):
             indices = range(len(self.ids))[targets]
@@ -1095,7 +1145,7 @@ class Client(object):
 
         if isinstance(targets, (tuple, list, xrange)):
             _,targets = self._build_targets(list(targets))
-            return view_class(client=self, targets=targets)
+            return self._cache_view(targets, balanced)
         else:
             raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
 
diff --git a/IPython/zmq/parallel/remotefunction.py b/IPython/zmq/parallel/remotefunction.py
index 388a7e6..fa7283c 100644
--- a/IPython/zmq/parallel/remotefunction.py
+++ b/IPython/zmq/parallel/remotefunction.py
@@ -10,6 +10,8 @@
 # Imports
 #-----------------------------------------------------------------------------
 
+import warnings
+
 import map as Map
 
 from asyncresult import AsyncMapResult
 
@@ -17,7 +19,7 @@ from asyncresult import AsyncMapResult
 # Decorators
 #-----------------------------------------------------------------------------
 
-def remote(client, bound=False, block=None, targets=None, balanced=None):
+def remote(client, bound=True, block=None, targets=None, balanced=None):
     """Turn a function into a remote function.
 
     This method can be used for map:
@@ -29,7 +31,7 @@ def remote(client, bound=False, block=None, targets=None, balanced=None):
         return RemoteFunction(client, f, bound, block, targets, balanced)
     return remote_function
 
-def parallel(client, dist='b', bound=False, block=None, targets='all', balanced=None):
+def parallel(client, dist='b', bound=True, block=None, targets='all', balanced=None):
     """Turn a function into a parallel remote function.
 
     This method can be used for map:
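The caching added in _cache_view means repeated view() requests hand back the same object, and the client-level remote/parallel decorators now accept the same targets/balanced options. A sketch, with `c` a connected client::

    lbv = c.view()                 # no arguments: LoadBalancedView on all engines
    dv  = c.view(targets=[0, 1])   # targets given: DirectView on engines 0 and 1
    assert c.view(targets=[0, 1]) is dv   # served from the view cache

    @c.remote(block=True, balanced=True)
    def getpid():
        import os
        return os.getpid()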
@@ -93,8 +95,10 @@ class RemoteFunction(object):
 
 class ParallelFunction(RemoteFunction):
     """Class for mapping a function to sequences."""
-    def __init__(self, client, f, dist='b', bound=False, block=None, targets='all', balanced=None):
+    def __init__(self, client, f, dist='b', bound=False, block=None, targets='all', balanced=None, chunk_size=None):
         super(ParallelFunction, self).__init__(client,f,bound,block,targets,balanced)
+        self.chunk_size = chunk_size
+
         mapClass = Map.dists[dist]
         self.mapObject = mapClass()
 
@@ -106,12 +110,18 @@ class ParallelFunction(RemoteFunction):
             raise ValueError(msg)
 
         if self.balanced:
-            targets = [self.targets]*len_0
+            if self.chunk_size:
+                nparts = len_0/self.chunk_size + int(len_0%self.chunk_size > 0)
+            else:
+                nparts = len_0
+            targets = [self.targets]*nparts
         else:
+            if self.chunk_size:
+                warnings.warn("`chunk_size` is ignored when `balanced=False`", UserWarning)
             # multiplexed:
             targets = self.client._build_targets(self.targets)[-1]
+            nparts = len(targets)
-        nparts = len(targets)
 
         msg_ids = []
         # my_f = lambda *a: map(self.func, *a)
         for index, t in enumerate(targets):
@@ -132,7 +142,7 @@ class ParallelFunction(RemoteFunction):
             else:
                 f=self.func
             ar = self.client.apply(f, args=args, block=False, bound=self.bound,
-                                   targets=targets, balanced=self.balanced)
+                                   targets=t, balanced=self.balanced)
 
             msg_ids.append(ar.msg_ids[0])
 
diff --git a/IPython/zmq/parallel/view.py b/IPython/zmq/parallel/view.py
index fb039ab..1fc3f3f 100644
--- a/IPython/zmq/parallel/view.py
+++ b/IPython/zmq/parallel/view.py
@@ -134,7 +134,7 @@ class View(HasTraits):
                 raise KeyError("Invalid name: %r"%key)
         for name in ('block', 'bound'):
             if name in kwargs:
-                setattr(self, name, kwargs)
+                setattr(self, name, kwargs[name])
 
     #----------------------------------------------------------------
     # wrappers for client methods:
@@ -249,16 +249,49 @@ class View(HasTraits):
         return self.client.purge_results(msg_ids=msg_ids, targets=targets)
 
     #-------------------------------------------------------------------
+    # Map
+    #-------------------------------------------------------------------
+
+    def map(self, f, *sequences, **kwargs):
+        """override in subclasses"""
+        raise NotImplementedError
+
+    def map_async(self, f, *sequences, **kwargs):
+        """Parallel version of builtin `map`, using this view's engines.
+
+        This is equivalent to map(...block=False)
+
+        See `map` for details.
+        """
+        if 'block' in kwargs:
+            raise TypeError("map_async doesn't take a `block` keyword argument.")
+        kwargs['block'] = False
+        return self.map(f,*sequences,**kwargs)
+
+    def map_sync(self, f, *sequences, **kwargs):
+        """Parallel version of builtin `map`, using this view's engines.
+
+        This is equivalent to map(...block=True)
+
+        See `map` for details.
+        """
+        if 'block' in kwargs:
+            raise TypeError("map_sync doesn't take a `block` keyword argument.")
+        kwargs['block'] = True
+        return self.map(f,*sequences,**kwargs)
+
+    #-------------------------------------------------------------------
     # Decorators
     #-------------------------------------------------------------------
 
-    def parallel(self, bound=True, block=True):
-        """Decorator for making a ParallelFunction"""
-        return parallel(self.client, bound=bound, targets=self.targets, block=block, balanced=self._balanced)
 
     def remote(self, bound=True, block=True):
         """Decorator for making a RemoteFunction"""
-        return parallel(self.client, bound=bound, targets=self.targets, block=block, balanced=self._balanced)
+        return remote(self.client, bound=bound, targets=self.targets, block=block, balanced=self._balanced)
 
+    def parallel(self, dist='b', bound=True, block=None):
+        """Decorator for making a ParallelFunction"""
+        block = self.block if block is None else block
+        return parallel(self.client, dist=dist, bound=bound, targets=self.targets, block=block, balanced=self._balanced)
 
 class DirectView(View):
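map_async and map_sync are thin wrappers that force block=False/True before delegating to map(), and chunk_size controls how many tasks a balanced map submits: 10 items with chunk_size=4 become 10/4 + int(10%4 > 0) == 3 tasks. Illustrative only, assuming `c` is a connected client and the mapped function can be shipped to the engines::

    lview = c.view(balanced=True)

    ar = lview.map_async(lambda x: x*x, range(10), chunk_size=4)
    print ar.get()                                   # gather when ready

    squares = lview.map_sync(lambda x: x*x, range(10), chunk_size=4)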
+ """ + if 'block' in kwargs: + raise TypeError("map_sync doesn't take a `block` keyword argument.") + kwargs['block'] = True + return self.map(f,*sequences,**kwargs) + + #------------------------------------------------------------------- # Decorators #------------------------------------------------------------------- - def parallel(self, bound=True, block=True): - """Decorator for making a ParallelFunction""" - return parallel(self.client, bound=bound, targets=self.targets, block=block, balanced=self._balanced) def remote(self, bound=True, block=True): """Decorator for making a RemoteFunction""" - return parallel(self.client, bound=bound, targets=self.targets, block=block, balanced=self._balanced) + return remote(self.client, bound=bound, targets=self.targets, block=block, balanced=self._balanced) + def parallel(self, dist='b', bound=True, block=None): + """Decorator for making a ParallelFunction""" + block = self.block if block is None else block + return parallel(self.client, bound=bound, targets=self.targets, block=block, balanced=self._balanced) class DirectView(View): @@ -325,17 +358,10 @@ class DirectView(View): raise TypeError("invalid keyword arg, %r"%k) assert len(sequences) > 0, "must have some sequences to map onto!" - pf = ParallelFunction(self.client, f, block=block, - bound=bound, targets=self.targets, balanced=False) + pf = ParallelFunction(self.client, f, block=block, bound=bound, + targets=self.targets, balanced=False) return pf.map(*sequences) - def map_async(self, f, *sequences, **kwargs): - """Parallel version of builtin `map`, using this view's engines.""" - if 'block' in kwargs: - raise TypeError("map_async doesn't take a `block` keyword argument.") - kwargs['block'] = True - return self.map(f,*sequences,**kwargs) - @sync_results @save_ids def execute(self, code, block=True): @@ -446,12 +472,12 @@ class LoadBalancedView(View): """ - _apply_name = 'apply_balanced' _default_names = ['block', 'bound', 'follow', 'after', 'timeout'] def __init__(self, client=None, targets=None): super(LoadBalancedView, self).__init__(client=client, targets=targets) self._ntargets = 1 + self._balanced = True def _validate_dependency(self, dep): """validate a dependency. @@ -547,26 +573,20 @@ class LoadBalancedView(View): """ + # default block = kwargs.get('block', self.block) bound = kwargs.get('bound', self.bound) + chunk_size = kwargs.get('chunk_size', 1) + + keyset = set(kwargs.keys()) + extra_keys = keyset.difference_update(set(['block', 'bound', 'chunk_size'])) + if extra_keys: + raise TypeError("Invalid kwargs: %s"%list(extra_keys)) assert len(sequences) > 0, "must have some sequences to map onto!" pf = ParallelFunction(self.client, f, block=block, bound=bound, - targets=self.targets, balanced=True) + targets=self.targets, balanced=True, + chunk_size=chunk_size) return pf.map(*sequences) - def map_async(self, f, *sequences, **kwargs): - """Parallel version of builtin `map`, using this view's engines. - - This is equivalent to map(...block=False) - - See `map` for details. - """ - - if 'block' in kwargs: - raise TypeError("map_async doesn't take a `block` keyword argument.") - kwargs['block'] = True - return self.map(f,*sequences,**kwargs) - - diff --git a/docs/examples/newparallel/dagdeps.py b/docs/examples/newparallel/dagdeps.py index ee2bcc8..8fe96e3 100644 --- a/docs/examples/newparallel/dagdeps.py +++ b/docs/examples/newparallel/dagdeps.py @@ -76,6 +76,7 @@ def main(nodes, edges): in-degree on the y (just for spread). 
diff --git a/docs/examples/newparallel/dagdeps.py b/docs/examples/newparallel/dagdeps.py
index ee2bcc8..8fe96e3 100644
--- a/docs/examples/newparallel/dagdeps.py
+++ b/docs/examples/newparallel/dagdeps.py
@@ -76,6 +76,7 @@ def main(nodes, edges):
     in-degree on the y (just for spread).  All arrows must point at least slightly
     to the right if the graph is valid.
     """
+    import pylab
     from matplotlib.dates import date2num
     from matplotlib.cm import gist_rainbow
     print "building DAG"
@@ -99,7 +100,15 @@ def main(nodes, edges):
         pos[node] = (start, runtime)
         colors[node] = md.engine_id
     validate_tree(G, results)
-    nx.draw(G, pos, node_list = colors.keys(), node_color=colors.values(), cmap=gist_rainbow)
+    nx.draw(G, pos, node_list=colors.keys(), node_color=colors.values(), cmap=gist_rainbow,
+            with_labels=False)
+    x,y = zip(*pos.values())
+    xmin,ymin = map(min, (x,y))
+    xmax,ymax = map(max, (x,y))
+    xscale = xmax-xmin
+    yscale = ymax-ymin
+    pylab.xlim(xmin-xscale*.1,xmax+xscale*.1)
+    pylab.ylim(ymin-yscale*.1,ymax+yscale*.1)
     return G,results
 
 if __name__ == '__main__':
diff --git a/docs/examples/newparallel/mcdriver.py b/docs/examples/newparallel/mcdriver.py
index 7d9e905..f068c10 100644
--- a/docs/examples/newparallel/mcdriver.py
+++ b/docs/examples/newparallel/mcdriver.py
@@ -49,7 +49,7 @@ c = client.Client(profile=cluster_profile)
 
 # A LoadBalancedView is an interface to the engines that provides dynamic load
 # balancing at the expense of not knowing which engine will execute the code.
-view = c[None]
+view = c.view()
 
 # Initialize the common code on the engines. This Python module has the
 # price_options function that prices the options.
diff --git a/docs/examples/newparallel/parallelpi.py b/docs/examples/newparallel/parallelpi.py
index 54c11a0..6216fdf 100644
--- a/docs/examples/newparallel/parallelpi.py
+++ b/docs/examples/newparallel/parallelpi.py
@@ -27,15 +27,17 @@ filestring = 'pi200m.ascii.%(i)02dof20'
 files = [filestring % {'i':i} for i in range(1,16)]
 
 # Connect to the IPython cluster
-c = client.Client()
+c = client.Client(profile='edison')
 c.run('pidigits.py')
 
 # the number of engines
-n = len(c.ids)
-id0 = list(c.ids)[0]
+n = len(c)
+id0 = c.ids[0]
+v = c[:]
+v.set_flags(bound=True,block=True)
 # fetch the pi-files
 print "downloading %i files of pi"%n
-c.map(fetch_pi_file, files[:n])
+v.map(fetch_pi_file, files[:n])
 print "done"
 
 # Run 10m digits on 1 engine
@@ -48,8 +50,7 @@
 print "Digits per second (1 core, 10m digits): ", digits_per_second1
 
 # Run n*10m digits on all engines
 t1 = clock()
-c.block=True
-freqs_all = c.map(compute_two_digit_freqs, files[:n])
+freqs_all = v.map(compute_two_digit_freqs, files[:n])
 freqs150m = reduce_freqs(freqs_all)
 t2 = clock()
 digits_per_second8 = n*10.0e6/(t2-t1)
diff --git a/docs/examples/newparallel/pidigits.py b/docs/examples/newparallel/pidigits.py
index da27c85..6c58ae9 100644
--- a/docs/examples/newparallel/pidigits.py
+++ b/docs/examples/newparallel/pidigits.py
@@ -18,9 +18,6 @@ should be equal.
 # Import statements
 from __future__ import division, with_statement
 
-import os
-import urllib
-
 import numpy as np
 from matplotlib import pyplot as plt
 
@@ -30,6 +27,7 @@ def fetch_pi_file(filename):
     """This will download a segment of pi from super-computing.org
     if the file is not already present.
     """
+    import os, urllib
     ftpdir="ftp://pi.super-computing.org/.2/pi200m/"
     if os.path.exists(filename):
         # we already have it
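The parallelpi example above now drives everything through one DirectView. Condensed, the pattern is the following (the profile name is site-specific, and `compute_two_digit_freqs`/`files` come from the example itself)::

    c = client.Client()          # or profile='edison' on that particular cluster
    v = c[:]
    v.set_flags(bound=True, block=True)

    # with balanced=False, map sends one chunk per engine, so slicing
    # the file list to len(c) gives each engine exactly one file
    freqs_all = v.map(compute_two_digit_freqs, files[:len(c)])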
diff --git a/docs/source/parallelz/parallel_demos.txt b/docs/source/parallelz/parallel_demos.txt
index c1a872a..e6c0fbf 100644
--- a/docs/source/parallelz/parallel_demos.txt
+++ b/docs/source/parallelz/parallel_demos.txt
@@ -135,6 +135,7 @@ calculation can also be run by simply typing the commands from
     # We simply pass Client the name of the cluster profile we
     # are using.
     In [2]: c = client.Client(profile='mycluster')
+    In [3]: view = c.view(balanced=True)
 
     In [3]: c.ids
     Out[3]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
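The balanced view created in that session can then drive the map itself. A possible continuation, purely illustrative (prompt numbers, the mapped function, and its output are not from the original document, and the function must be serializable to the engines)::

    In [4]: ar = view.map_async(lambda x: x**2, range(32))

    In [5]: ar.get()[:4]
    Out[5]: [0, 1, 4, 9]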