##// END OF EJS Templates
Client -> HasTraits, update examples with API tweaks
MinRK -
Show More
@@ -24,6 +24,8 b' import zmq'
24 24 # from zmq.eventloop import ioloop, zmqstream
25 25
26 26 from IPython.utils.path import get_ipython_dir
27 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
28 Dict, List, Bool, Str, Set)
27 29 from IPython.external.decorator import decorator
28 30 from IPython.external.ssh import tunnel
29 31
@@ -147,7 +149,7 b' class Metadata(dict):'
147 149 raise KeyError(key)
148 150
149 151
150 class Client(object):
152 class Client(HasTraits):
151 153 """A semi-synchronous client to the IPython ZMQ controller
152 154
153 155 Parameters
@@ -247,31 +249,41 b' class Client(object):'
247 249 """
248 250
249 251
250 _connected=False
251 _ssh=False
252 _engines=None
253 _registration_socket=None
254 _query_socket=None
255 _control_socket=None
256 _iopub_socket=None
257 _notification_socket=None
258 _mux_socket=None
259 _task_socket=None
260 _task_scheme=None
261 block = False
262 outstanding=None
263 results = None
264 history = None
265 debug = False
266 targets = None
252 block = Bool(False)
253 outstanding=Set()
254 results = Dict()
255 metadata = Dict()
256 history = List()
257 debug = Bool(False)
258 profile=CUnicode('default')
259
260 _ids = List()
261 _connected=Bool(False)
262 _ssh=Bool(False)
263 _context = Instance('zmq.Context')
264 _config = Dict()
265 _engines=Instance(ReverseDict, (), {})
266 _registration_socket=Instance('zmq.Socket')
267 _query_socket=Instance('zmq.Socket')
268 _control_socket=Instance('zmq.Socket')
269 _iopub_socket=Instance('zmq.Socket')
270 _notification_socket=Instance('zmq.Socket')
271 _mux_socket=Instance('zmq.Socket')
272 _task_socket=Instance('zmq.Socket')
273 _task_scheme=Str()
274 _balanced_views=Dict()
275 _direct_views=Dict()
276 _closed = False
267 277
268 278 def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
269 279 context=None, username=None, debug=False, exec_key=None,
270 280 sshserver=None, sshkey=None, password=None, paramiko=None,
271 281 ):
282 super(Client, self).__init__(debug=debug, profile=profile)
272 283 if context is None:
273 284 context = zmq.Context()
274 self.context = context
285 self._context = context
286
275 287
276 288 self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
277 289 if self._cd is not None:
@@ -325,20 +337,14 b' class Client(object):'
325 337 self.session = ss.StreamSession(**key_arg)
326 338 else:
327 339 self.session = ss.StreamSession(username, **key_arg)
328 self._registration_socket = self.context.socket(zmq.XREQ)
340 self._registration_socket = self._context.socket(zmq.XREQ)
329 341 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
330 342 if self._ssh:
331 343 tunnel.tunnel_connection(self._registration_socket, url, sshserver, **ssh_kwargs)
332 344 else:
333 345 self._registration_socket.connect(url)
334 self._engines = ReverseDict()
335 self._ids = []
336 self.outstanding=set()
337 self.results = {}
338 self.metadata = {}
339 self.history = []
340 self.debug = debug
341 self.session.debug = debug
346
347 self.session.debug = self.debug
342 348
343 349 self._notification_handlers = {'registration_notification' : self._register_engine,
344 350 'unregistration_notification' : self._unregister_engine,
@@ -370,6 +376,14 b' class Client(object):'
370 376 """Always up-to-date ids property."""
371 377 self._flush_notifications()
372 378 return self._ids
379
def close(self):
    """Close every socket held by this client; safe to call repeatedly.

    Every attribute whose name ends in ``socket`` is looked up and, when it
    holds an object, its ``close()`` method is called.  ``_closed`` is then
    set so that later calls return immediately.
    """
    if self._closed:
        return
    for name in dir(self):
        if not name.endswith('socket'):
            continue
        sock = getattr(self, name)
        # A socket attribute is only populated when the matching channel was
        # connected; unconnected ones are still None and must be skipped.
        if sock is not None:
            sock.close()
    self._closed = True
373 387
374 388 def _update_engines(self, engines):
375 389 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
@@ -436,28 +450,28 b' class Client(object):'
436 450 self._config['registration'] = dict(content)
437 451 if content.status == 'ok':
438 452 if content.mux:
439 self._mux_socket = self.context.socket(zmq.PAIR)
453 self._mux_socket = self._context.socket(zmq.PAIR)
440 454 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
441 455 connect_socket(self._mux_socket, content.mux)
442 456 if content.task:
443 457 self._task_scheme, task_addr = content.task
444 self._task_socket = self.context.socket(zmq.PAIR)
458 self._task_socket = self._context.socket(zmq.PAIR)
445 459 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
446 460 connect_socket(self._task_socket, task_addr)
447 461 if content.notification:
448 self._notification_socket = self.context.socket(zmq.SUB)
462 self._notification_socket = self._context.socket(zmq.SUB)
449 463 connect_socket(self._notification_socket, content.notification)
450 464 self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
451 465 if content.query:
452 self._query_socket = self.context.socket(zmq.PAIR)
466 self._query_socket = self._context.socket(zmq.PAIR)
453 467 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
454 468 connect_socket(self._query_socket, content.query)
455 469 if content.control:
456 self._control_socket = self.context.socket(zmq.PAIR)
470 self._control_socket = self._context.socket(zmq.PAIR)
457 471 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
458 472 connect_socket(self._control_socket, content.control)
459 473 if content.iopub:
460 self._iopub_socket = self.context.socket(zmq.SUB)
474 self._iopub_socket = self._context.socket(zmq.SUB)
461 475 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, '')
462 476 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
463 477 connect_socket(self._iopub_socket, content.iopub)
@@ -636,9 +650,13 b' class Client(object):'
636 650 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
637 651
638 652 #--------------------------------------------------------------------------
639 # getitem
653 # len, getitem
640 654 #--------------------------------------------------------------------------
641 655
656 def __len__(self):
657 """len(client) returns # of engines."""
658 return len(self.ids)
659
642 660 def __getitem__(self, key):
643 661 """index access returns DirectView multiplexer objects
644 662
@@ -929,8 +947,9 b' class Client(object):'
929 947 else:
930 948 return list of results, matching `targets`
931 949 """
932
950 assert not self._closed, "cannot use me anymore, I'm closed!"
933 951 # defaults:
952 block = block if block is not None else self.block
934 953 args = args if args is not None else []
935 954 kwargs = kwargs if kwargs is not None else {}
936 955
@@ -955,7 +974,7 b' class Client(object):'
955 974 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
956 975
957 976 options = dict(bound=bound, block=block, targets=targets)
958
977
959 978 if balanced:
960 979 return self._apply_balanced(f, args, kwargs, timeout=timeout,
961 980 after=after, follow=follow, **options)
@@ -966,7 +985,7 b' class Client(object):'
966 985 else:
967 986 return self._apply_direct(f, args, kwargs, **options)
968 987
969 def _apply_balanced(self, f, args, kwargs, bound=True, block=None, targets=None,
988 def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
970 989 after=None, follow=None, timeout=None):
971 990 """call f(*args, **kwargs) remotely in a load-balanced manner.
972 991
@@ -974,8 +993,9 b' class Client(object):'
974 993 Not to be called directly!
975 994 """
976 995
977 for kwarg in (bound, block, targets):
978 assert kwarg is not None, "kwarg %r must be specified!"%kwarg
996 loc = locals()
997 for name in ('bound', 'block'):
998 assert loc[name] is not None, "kwarg %r must be specified!"%name
979 999
980 1000 if self._task_socket is None:
981 1001 msg = "Task farming is disabled"
@@ -1030,9 +1050,9 b' class Client(object):'
1030 1050 This is a private method, see `apply` for details.
1031 1051 Not to be called directly!
1032 1052 """
1033
1034 for kwarg in (bound, block, targets):
1035 assert kwarg is not None, "kwarg %r must be specified!"%kwarg
1053 loc = locals()
1054 for name in ('bound', 'block', 'targets'):
1055 assert loc[name] is not None, "kwarg %r must be specified!"%name
1036 1056
1037 1057 idents,targets = self._build_targets(targets)
1038 1058
@@ -1058,35 +1078,65 b' class Client(object):'
1058 1078 return ar
1059 1079
1060 1080 #--------------------------------------------------------------------------
1061 # decorators
1081 # construct a View object
1062 1082 #--------------------------------------------------------------------------
1063 1083
1064 1084 @defaultblock
1065 def parallel(self, bound=True, targets='all', block=None):
1066 """Decorator for making a ParallelFunction."""
1067 return parallel(self, bound=bound, targets=targets, block=block)
1085 def remote(self, bound=True, block=None, targets=None, balanced=None):
1086 """Decorator for making a RemoteFunction"""
1087 return remote(self, bound=bound, targets=targets, block=block, balanced=balanced)
1068 1088
1069 1089 @defaultblock
1070 def remote(self, bound=True, targets='all', block=None):
1071 """Decorator for making a RemoteFunction."""
1072 return remote(self, bound=bound, targets=targets, block=block)
def parallel(self, dist='b', bound=True, block=None, targets=None, balanced=None):
    """Decorator for making a ParallelFunction.

    Thin wrapper: passes this client and all keyword arguments on to the
    module-level ``parallel`` decorator factory and returns its result.
    """
    # Forward `dist` explicitly; otherwise the factory falls back to its own
    # default and the caller's choice is ignored.
    return parallel(self, dist=dist, bound=bound, targets=targets, block=block, balanced=balanced)
1073 1093
1074 def view(self, targets=None, balanced=False):
1075 """Method for constructing View objects"""
1094 def _cache_view(self, targets, balanced):
1095 """save views, so subsequent requests don't create new objects."""
1096 if balanced:
1097 view_class = LoadBalancedView
1098 view_cache = self._balanced_views
1099 else:
1100 view_class = DirectView
1101 view_cache = self._direct_views
1102
1103 # use str, since often targets will be a list
1104 key = str(targets)
1105 if key not in view_cache:
1106 view_cache[key] = view_class(client=self, targets=targets)
1107
1108 return view_cache[key]
1109
1110 def view(self, targets=None, balanced=None):
1111 """Method for constructing View objects.
1112
1113 If no arguments are specified, create a LoadBalancedView
1114 using all engines. If only `targets` specified, it will
1115 be a DirectView. This method is the underlying implementation
1116 of ``client.__getitem__``.
1117
1118 Parameters
1119 ----------
1120
1121 targets: list,slice,int,etc. [default: use all engines]
1122 The engines to use for the View
1123 balanced : bool [default: False if targets specified, True else]
1124 whether to build a LoadBalancedView or a DirectView
1125
1126 """
1127
1128 balanced = (targets is None) if balanced is None else balanced
1129
1076 1130 if targets is None:
1077 1131 if balanced:
1078 return LoadBalancedView(client=self)
1132 return self._cache_view(None,True)
1079 1133 else:
1080 1134 targets = slice(None)
1081 1135
1082 if balanced:
1083 view_class = LoadBalancedView
1084 else:
1085 view_class = DirectView
1086 1136 if isinstance(targets, int):
1087 1137 if targets not in self.ids:
1088 1138 raise IndexError("No such engine: %i"%targets)
1089 return view_class(client=self, targets=targets)
1139 return self._cache_view(targets, balanced)
1090 1140
1091 1141 if isinstance(targets, slice):
1092 1142 indices = range(len(self.ids))[targets]
@@ -1095,7 +1145,7 b' class Client(object):'
1095 1145
1096 1146 if isinstance(targets, (tuple, list, xrange)):
1097 1147 _,targets = self._build_targets(list(targets))
1098 return view_class(client=self, targets=targets)
1148 return self._cache_view(targets, balanced)
1099 1149 else:
1100 1150 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
1101 1151
@@ -10,6 +10,8 b''
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 import warnings
14
13 15 import map as Map
14 16 from asyncresult import AsyncMapResult
15 17
@@ -17,7 +19,7 b' from asyncresult import AsyncMapResult'
17 19 # Decorators
18 20 #-----------------------------------------------------------------------------
19 21
20 def remote(client, bound=False, block=None, targets=None, balanced=None):
22 def remote(client, bound=True, block=None, targets=None, balanced=None):
21 23 """Turn a function into a remote function.
22 24
23 25 This method can be used for map:
@@ -29,7 +31,7 b' def remote(client, bound=False, block=None, targets=None, balanced=None):'
29 31 return RemoteFunction(client, f, bound, block, targets, balanced)
30 32 return remote_function
31 33
32 def parallel(client, dist='b', bound=False, block=None, targets='all', balanced=None):
34 def parallel(client, dist='b', bound=True, block=None, targets='all', balanced=None):
33 35 """Turn a function into a parallel remote function.
34 36
35 37 This method can be used for map:
@@ -93,8 +95,10 b' class RemoteFunction(object):'
93 95
94 96 class ParallelFunction(RemoteFunction):
95 97 """Class for mapping a function to sequences."""
96 def __init__(self, client, f, dist='b', bound=False, block=None, targets='all', balanced=None):
98 def __init__(self, client, f, dist='b', bound=False, block=None, targets='all', balanced=None, chunk_size=None):
97 99 super(ParallelFunction, self).__init__(client,f,bound,block,targets,balanced)
100 self.chunk_size = chunk_size
101
98 102 mapClass = Map.dists[dist]
99 103 self.mapObject = mapClass()
100 104
@@ -106,12 +110,18 b' class ParallelFunction(RemoteFunction):'
106 110 raise ValueError(msg)
107 111
108 112 if self.balanced:
109 targets = [self.targets]*len_0
113 if self.chunk_size:
114 nparts = len_0/self.chunk_size + int(len_0%self.chunk_size > 0)
115 else:
116 nparts = len_0
117 targets = [self.targets]*nparts
110 118 else:
119 if self.chunk_size:
120 warnings.warn("`chunk_size` is ignored when `balanced=False", UserWarning)
111 121 # multiplexed:
112 122 targets = self.client._build_targets(self.targets)[-1]
123 nparts = len(targets)
113 124
114 nparts = len(targets)
115 125 msg_ids = []
116 126 # my_f = lambda *a: map(self.func, *a)
117 127 for index, t in enumerate(targets):
@@ -132,7 +142,7 b' class ParallelFunction(RemoteFunction):'
132 142 else:
133 143 f=self.func
134 144 ar = self.client.apply(f, args=args, block=False, bound=self.bound,
135 targets=targets, balanced=self.balanced)
145 targets=t, balanced=self.balanced)
136 146
137 147 msg_ids.append(ar.msg_ids[0])
138 148
@@ -134,7 +134,7 b' class View(HasTraits):'
134 134 raise KeyError("Invalid name: %r"%key)
135 135 for name in ('block', 'bound'):
136 136 if name in kwargs:
137 setattr(self, name, kwargs)
137 setattr(self, name, kwargs[name])
138 138
139 139 #----------------------------------------------------------------
140 140 # wrappers for client methods:
@@ -249,16 +249,49 b' class View(HasTraits):'
249 249 return self.client.purge_results(msg_ids=msg_ids, targets=targets)
250 250
251 251 #-------------------------------------------------------------------
252 # Map
253 #-------------------------------------------------------------------
254
255 def map(self, f, *sequences, **kwargs):
256 """override in subclasses"""
257 raise NotImplementedError
258
259 def map_async(self, f, *sequences, **kwargs):
260 """Parallel version of builtin `map`, using this view's engines.
261
262 This is equivalent to map(...block=False)
263
264 See `map` for details.
265 """
266 if 'block' in kwargs:
267 raise TypeError("map_async doesn't take a `block` keyword argument.")
268 kwargs['block'] = False
269 return self.map(f,*sequences,**kwargs)
270
271 def map_sync(self, f, *sequences, **kwargs):
272 """Parallel version of builtin `map`, using this view's engines.
273
274 This is equivalent to map(...block=True)
275
276 See `map` for details.
277 """
278 if 'block' in kwargs:
279 raise TypeError("map_sync doesn't take a `block` keyword argument.")
280 kwargs['block'] = True
281 return self.map(f,*sequences,**kwargs)
282
283 #-------------------------------------------------------------------
252 284 # Decorators
253 285 #-------------------------------------------------------------------
254 def parallel(self, bound=True, block=True):
255 """Decorator for making a ParallelFunction"""
256 return parallel(self.client, bound=bound, targets=self.targets, block=block, balanced=self._balanced)
257 286
258 287 def remote(self, bound=True, block=True):
259 288 """Decorator for making a RemoteFunction"""
260 return parallel(self.client, bound=bound, targets=self.targets, block=block, balanced=self._balanced)
289 return remote(self.client, bound=bound, targets=self.targets, block=block, balanced=self._balanced)
261 290
def parallel(self, dist='b', bound=True, block=None):
    """Decorator for making a ParallelFunction.

    Uses this view's client, targets and balanced flag.  When `block` is
    None the view's own `block` setting is used.
    """
    block = self.block if block is None else block
    # Forward `dist` explicitly; otherwise the module-level factory falls
    # back to its own default and the caller's choice is ignored.
    return parallel(self.client, dist=dist, bound=bound, targets=self.targets, block=block, balanced=self._balanced)
262 295
263 296
264 297 class DirectView(View):
@@ -325,17 +358,10 b' class DirectView(View):'
325 358 raise TypeError("invalid keyword arg, %r"%k)
326 359
327 360 assert len(sequences) > 0, "must have some sequences to map onto!"
328 pf = ParallelFunction(self.client, f, block=block,
329 bound=bound, targets=self.targets, balanced=False)
361 pf = ParallelFunction(self.client, f, block=block, bound=bound,
362 targets=self.targets, balanced=False)
330 363 return pf.map(*sequences)
331 364
332 def map_async(self, f, *sequences, **kwargs):
333 """Parallel version of builtin `map`, using this view's engines."""
334 if 'block' in kwargs:
335 raise TypeError("map_async doesn't take a `block` keyword argument.")
336 kwargs['block'] = True
337 return self.map(f,*sequences,**kwargs)
338
339 365 @sync_results
340 366 @save_ids
341 367 def execute(self, code, block=True):
@@ -446,12 +472,12 b' class LoadBalancedView(View):'
446 472
447 473 """
448 474
449 _apply_name = 'apply_balanced'
450 475 _default_names = ['block', 'bound', 'follow', 'after', 'timeout']
451 476
452 477 def __init__(self, client=None, targets=None):
453 478 super(LoadBalancedView, self).__init__(client=client, targets=targets)
454 479 self._ntargets = 1
480 self._balanced = True
455 481
456 482 def _validate_dependency(self, dep):
457 483 """validate a dependency.
@@ -547,26 +573,20 b' class LoadBalancedView(View):'
547 573
548 574 """
549 575
576 # default
550 577 block = kwargs.get('block', self.block)
551 578 bound = kwargs.get('bound', self.bound)
579 chunk_size = kwargs.get('chunk_size', 1)
580
581 keyset = set(kwargs.keys())
582 extra_keys = keyset.difference_update(set(['block', 'bound', 'chunk_size']))
583 if extra_keys:
584 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
552 585
553 586 assert len(sequences) > 0, "must have some sequences to map onto!"
554 587
555 588 pf = ParallelFunction(self.client, f, block=block, bound=bound,
556 targets=self.targets, balanced=True)
589 targets=self.targets, balanced=True,
590 chunk_size=chunk_size)
557 591 return pf.map(*sequences)
558 592
559 def map_async(self, f, *sequences, **kwargs):
560 """Parallel version of builtin `map`, using this view's engines.
561
562 This is equivalent to map(...block=False)
563
564 See `map` for details.
565 """
566
567 if 'block' in kwargs:
568 raise TypeError("map_async doesn't take a `block` keyword argument.")
569 kwargs['block'] = True
570 return self.map(f,*sequences,**kwargs)
571
572
@@ -76,6 +76,7 b' def main(nodes, edges):'
76 76 in-degree on the y (just for spread). All arrows must
77 77 point at least slightly to the right if the graph is valid.
78 78 """
79 import pylab
79 80 from matplotlib.dates import date2num
80 81 from matplotlib.cm import gist_rainbow
81 82 print "building DAG"
@@ -99,7 +100,15 b' def main(nodes, edges):'
99 100 pos[node] = (start, runtime)
100 101 colors[node] = md.engine_id
101 102 validate_tree(G, results)
102 nx.draw(G, pos, node_list = colors.keys(), node_color=colors.values(), cmap=gist_rainbow)
103 nx.draw(G, pos, node_list=colors.keys(), node_color=colors.values(), cmap=gist_rainbow,
104 with_labels=False)
105 x,y = zip(*pos.values())
106 xmin,ymin = map(min, (x,y))
107 xmax,ymax = map(max, (x,y))
108 xscale = xmax-xmin
109 yscale = ymax-ymin
110 pylab.xlim(xmin-xscale*.1,xmax+xscale*.1)
111 pylab.ylim(ymin-yscale*.1,ymax+yscale*.1)
103 112 return G,results
104 113
105 114 if __name__ == '__main__':
@@ -49,7 +49,7 b' c = client.Client(profile=cluster_profile)'
49 49
50 50 # A LoadBalancedView is an interface to the engines that provides dynamic load
51 51 # balancing at the expense of not knowing which engine will execute the code.
52 view = c[None]
52 view = c.view()
53 53
54 54 # Initialize the common code on the engines. This Python module has the
55 55 # price_options function that prices the options.
@@ -27,15 +27,17 b" filestring = 'pi200m.ascii.%(i)02dof20'"
27 27 files = [filestring % {'i':i} for i in range(1,16)]
28 28
29 29 # Connect to the IPython cluster
30 c = client.Client()
30 c = client.Client(profile='edison')
31 31 c.run('pidigits.py')
32 32
33 33 # the number of engines
34 n = len(c.ids)
35 id0 = list(c.ids)[0]
34 n = len(c)
35 id0 = c.ids[0]
36 v = c[:]
37 v.set_flags(bound=True,block=True)
36 38 # fetch the pi-files
37 39 print "downloading %i files of pi"%n
38 c.map(fetch_pi_file, files[:n])
40 v.map(fetch_pi_file, files[:n])
39 41 print "done"
40 42
41 43 # Run 10m digits on 1 engine
@@ -48,8 +50,7 b' print "Digits per second (1 core, 10m digits): ", digits_per_second1'
48 50
49 51 # Run n*10m digits on all engines
50 52 t1 = clock()
51 c.block=True
52 freqs_all = c.map(compute_two_digit_freqs, files[:n])
53 freqs_all = v.map(compute_two_digit_freqs, files[:n])
53 54 freqs150m = reduce_freqs(freqs_all)
54 55 t2 = clock()
55 56 digits_per_second8 = n*10.0e6/(t2-t1)
@@ -18,9 +18,6 b' should be equal.'
18 18 # Import statements
19 19 from __future__ import division, with_statement
20 20
21 import os
22 import urllib
23
24 21 import numpy as np
25 22 from matplotlib import pyplot as plt
26 23
@@ -30,6 +27,7 b' def fetch_pi_file(filename):'
30 27 """This will download a segment of pi from super-computing.org
31 28 if the file is not already present.
32 29 """
30 import os, urllib
33 31 ftpdir="ftp://pi.super-computing.org/.2/pi200m/"
34 32 if os.path.exists(filename):
35 33 # we already have it
@@ -135,6 +135,7 b' calculation can also be run by simply typing the commands from'
135 135 # We simply pass Client the name of the cluster profile we
136 136 # are using.
137 137 In [2]: c = client.Client(profile='mycluster')
138 In [3]: view = c.view(balanced=True)
138 139
139 140 In [3]: c.ids
140 141 Out[3]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
General Comments 0
You need to be logged in to leave comments. Login now