##// END OF EJS Templates
Client -> HasTraits, update examples with API tweaks
MinRK -
Show More
@@ -24,6 +24,8 b' import zmq'
24 # from zmq.eventloop import ioloop, zmqstream
24 # from zmq.eventloop import ioloop, zmqstream
25
25
26 from IPython.utils.path import get_ipython_dir
26 from IPython.utils.path import get_ipython_dir
27 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
28 Dict, List, Bool, Str, Set)
27 from IPython.external.decorator import decorator
29 from IPython.external.decorator import decorator
28 from IPython.external.ssh import tunnel
30 from IPython.external.ssh import tunnel
29
31
@@ -147,7 +149,7 b' class Metadata(dict):'
147 raise KeyError(key)
149 raise KeyError(key)
148
150
149
151
150 class Client(object):
152 class Client(HasTraits):
151 """A semi-synchronous client to the IPython ZMQ controller
153 """A semi-synchronous client to the IPython ZMQ controller
152
154
153 Parameters
155 Parameters
@@ -247,31 +249,41 b' class Client(object):'
247 """
249 """
248
250
249
251
250 _connected=False
252 block = Bool(False)
251 _ssh=False
253 outstanding=Set()
252 _engines=None
254 results = Dict()
253 _registration_socket=None
255 metadata = Dict()
254 _query_socket=None
256 history = List()
255 _control_socket=None
257 debug = Bool(False)
256 _iopub_socket=None
258 profile=CUnicode('default')
257 _notification_socket=None
259
258 _mux_socket=None
260 _ids = List()
259 _task_socket=None
261 _connected=Bool(False)
260 _task_scheme=None
262 _ssh=Bool(False)
261 block = False
263 _context = Instance('zmq.Context')
262 outstanding=None
264 _config = Dict()
263 results = None
265 _engines=Instance(ReverseDict, (), {})
264 history = None
266 _registration_socket=Instance('zmq.Socket')
265 debug = False
267 _query_socket=Instance('zmq.Socket')
266 targets = None
268 _control_socket=Instance('zmq.Socket')
269 _iopub_socket=Instance('zmq.Socket')
270 _notification_socket=Instance('zmq.Socket')
271 _mux_socket=Instance('zmq.Socket')
272 _task_socket=Instance('zmq.Socket')
273 _task_scheme=Str()
274 _balanced_views=Dict()
275 _direct_views=Dict()
276 _closed = False
267
277
268 def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
278 def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
269 context=None, username=None, debug=False, exec_key=None,
279 context=None, username=None, debug=False, exec_key=None,
270 sshserver=None, sshkey=None, password=None, paramiko=None,
280 sshserver=None, sshkey=None, password=None, paramiko=None,
271 ):
281 ):
282 super(Client, self).__init__(debug=debug, profile=profile)
272 if context is None:
283 if context is None:
273 context = zmq.Context()
284 context = zmq.Context()
274 self.context = context
285 self._context = context
286
275
287
276 self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
288 self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
277 if self._cd is not None:
289 if self._cd is not None:
@@ -325,20 +337,14 b' class Client(object):'
325 self.session = ss.StreamSession(**key_arg)
337 self.session = ss.StreamSession(**key_arg)
326 else:
338 else:
327 self.session = ss.StreamSession(username, **key_arg)
339 self.session = ss.StreamSession(username, **key_arg)
328 self._registration_socket = self.context.socket(zmq.XREQ)
340 self._registration_socket = self._context.socket(zmq.XREQ)
329 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
341 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
330 if self._ssh:
342 if self._ssh:
331 tunnel.tunnel_connection(self._registration_socket, url, sshserver, **ssh_kwargs)
343 tunnel.tunnel_connection(self._registration_socket, url, sshserver, **ssh_kwargs)
332 else:
344 else:
333 self._registration_socket.connect(url)
345 self._registration_socket.connect(url)
334 self._engines = ReverseDict()
346
335 self._ids = []
347 self.session.debug = self.debug
336 self.outstanding=set()
337 self.results = {}
338 self.metadata = {}
339 self.history = []
340 self.debug = debug
341 self.session.debug = debug
342
348
343 self._notification_handlers = {'registration_notification' : self._register_engine,
349 self._notification_handlers = {'registration_notification' : self._register_engine,
344 'unregistration_notification' : self._unregister_engine,
350 'unregistration_notification' : self._unregister_engine,
@@ -370,6 +376,14 b' class Client(object):'
370 """Always up-to-date ids property."""
376 """Always up-to-date ids property."""
371 self._flush_notifications()
377 self._flush_notifications()
372 return self._ids
378 return self._ids
379
380 def close(self):
381 if self._closed:
382 return
383 snames = filter(lambda n: n.endswith('socket'), dir(self))
384 for socket in map(lambda name: getattr(self, name), snames):
385 socket.close()
386 self._closed = True
373
387
374 def _update_engines(self, engines):
388 def _update_engines(self, engines):
375 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
389 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
@@ -436,28 +450,28 b' class Client(object):'
436 self._config['registration'] = dict(content)
450 self._config['registration'] = dict(content)
437 if content.status == 'ok':
451 if content.status == 'ok':
438 if content.mux:
452 if content.mux:
439 self._mux_socket = self.context.socket(zmq.PAIR)
453 self._mux_socket = self._context.socket(zmq.PAIR)
440 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
454 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
441 connect_socket(self._mux_socket, content.mux)
455 connect_socket(self._mux_socket, content.mux)
442 if content.task:
456 if content.task:
443 self._task_scheme, task_addr = content.task
457 self._task_scheme, task_addr = content.task
444 self._task_socket = self.context.socket(zmq.PAIR)
458 self._task_socket = self._context.socket(zmq.PAIR)
445 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
459 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
446 connect_socket(self._task_socket, task_addr)
460 connect_socket(self._task_socket, task_addr)
447 if content.notification:
461 if content.notification:
448 self._notification_socket = self.context.socket(zmq.SUB)
462 self._notification_socket = self._context.socket(zmq.SUB)
449 connect_socket(self._notification_socket, content.notification)
463 connect_socket(self._notification_socket, content.notification)
450 self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
464 self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
451 if content.query:
465 if content.query:
452 self._query_socket = self.context.socket(zmq.PAIR)
466 self._query_socket = self._context.socket(zmq.PAIR)
453 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
467 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
454 connect_socket(self._query_socket, content.query)
468 connect_socket(self._query_socket, content.query)
455 if content.control:
469 if content.control:
456 self._control_socket = self.context.socket(zmq.PAIR)
470 self._control_socket = self._context.socket(zmq.PAIR)
457 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
471 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
458 connect_socket(self._control_socket, content.control)
472 connect_socket(self._control_socket, content.control)
459 if content.iopub:
473 if content.iopub:
460 self._iopub_socket = self.context.socket(zmq.SUB)
474 self._iopub_socket = self._context.socket(zmq.SUB)
461 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, '')
475 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, '')
462 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
476 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
463 connect_socket(self._iopub_socket, content.iopub)
477 connect_socket(self._iopub_socket, content.iopub)
@@ -636,9 +650,13 b' class Client(object):'
636 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
650 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
637
651
638 #--------------------------------------------------------------------------
652 #--------------------------------------------------------------------------
639 # getitem
653 # len, getitem
640 #--------------------------------------------------------------------------
654 #--------------------------------------------------------------------------
641
655
656 def __len__(self):
657 """len(client) returns # of engines."""
658 return len(self.ids)
659
642 def __getitem__(self, key):
660 def __getitem__(self, key):
643 """index access returns DirectView multiplexer objects
661 """index access returns DirectView multiplexer objects
644
662
@@ -929,8 +947,9 b' class Client(object):'
929 else:
947 else:
930 return list of results, matching `targets`
948 return list of results, matching `targets`
931 """
949 """
932
950 assert not self._closed, "cannot use me anymore, I'm closed!"
933 # defaults:
951 # defaults:
952 block = block if block is not None else self.block
934 args = args if args is not None else []
953 args = args if args is not None else []
935 kwargs = kwargs if kwargs is not None else {}
954 kwargs = kwargs if kwargs is not None else {}
936
955
@@ -955,7 +974,7 b' class Client(object):'
955 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
974 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
956
975
957 options = dict(bound=bound, block=block, targets=targets)
976 options = dict(bound=bound, block=block, targets=targets)
958
977
959 if balanced:
978 if balanced:
960 return self._apply_balanced(f, args, kwargs, timeout=timeout,
979 return self._apply_balanced(f, args, kwargs, timeout=timeout,
961 after=after, follow=follow, **options)
980 after=after, follow=follow, **options)
@@ -966,7 +985,7 b' class Client(object):'
966 else:
985 else:
967 return self._apply_direct(f, args, kwargs, **options)
986 return self._apply_direct(f, args, kwargs, **options)
968
987
969 def _apply_balanced(self, f, args, kwargs, bound=True, block=None, targets=None,
988 def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
970 after=None, follow=None, timeout=None):
989 after=None, follow=None, timeout=None):
971 """call f(*args, **kwargs) remotely in a load-balanced manner.
990 """call f(*args, **kwargs) remotely in a load-balanced manner.
972
991
@@ -974,8 +993,9 b' class Client(object):'
974 Not to be called directly!
993 Not to be called directly!
975 """
994 """
976
995
977 for kwarg in (bound, block, targets):
996 loc = locals()
978 assert kwarg is not None, "kwarg %r must be specified!"%kwarg
997 for name in ('bound', 'block'):
998 assert loc[name] is not None, "kwarg %r must be specified!"%name
979
999
980 if self._task_socket is None:
1000 if self._task_socket is None:
981 msg = "Task farming is disabled"
1001 msg = "Task farming is disabled"
@@ -1030,9 +1050,9 b' class Client(object):'
1030 This is a private method, see `apply` for details.
1050 This is a private method, see `apply` for details.
1031 Not to be called directly!
1051 Not to be called directly!
1032 """
1052 """
1033
1053 loc = locals()
1034 for kwarg in (bound, block, targets):
1054 for name in ('bound', 'block', 'targets'):
1035 assert kwarg is not None, "kwarg %r must be specified!"%kwarg
1055 assert loc[name] is not None, "kwarg %r must be specified!"%name
1036
1056
1037 idents,targets = self._build_targets(targets)
1057 idents,targets = self._build_targets(targets)
1038
1058
@@ -1058,35 +1078,65 b' class Client(object):'
1058 return ar
1078 return ar
1059
1079
1060 #--------------------------------------------------------------------------
1080 #--------------------------------------------------------------------------
1061 # decorators
1081 # construct a View object
1062 #--------------------------------------------------------------------------
1082 #--------------------------------------------------------------------------
1063
1083
1064 @defaultblock
1084 @defaultblock
1065 def parallel(self, bound=True, targets='all', block=None):
1085 def remote(self, bound=True, block=None, targets=None, balanced=None):
1066 """Decorator for making a ParallelFunction."""
1086 """Decorator for making a RemoteFunction"""
1067 return parallel(self, bound=bound, targets=targets, block=block)
1087 return remote(self, bound=bound, targets=targets, block=block, balanced=balanced)
1068
1088
1069 @defaultblock
1089 @defaultblock
1070 def remote(self, bound=True, targets='all', block=None):
1090 def parallel(self, dist='b', bound=True, block=None, targets=None, balanced=None):
1071 """Decorator for making a RemoteFunction."""
1091 """Decorator for making a ParallelFunction"""
1072 return remote(self, bound=bound, targets=targets, block=block)
1092 return parallel(self, bound=bound, targets=targets, block=block, balanced=balanced)
1073
1093
1074 def view(self, targets=None, balanced=False):
1094 def _cache_view(self, targets, balanced):
1075 """Method for constructing View objects"""
1095 """save views, so subsequent requests don't create new objects."""
1096 if balanced:
1097 view_class = LoadBalancedView
1098 view_cache = self._balanced_views
1099 else:
1100 view_class = DirectView
1101 view_cache = self._direct_views
1102
1103 # use str, since often targets will be a list
1104 key = str(targets)
1105 if key not in view_cache:
1106 view_cache[key] = view_class(client=self, targets=targets)
1107
1108 return view_cache[key]
1109
1110 def view(self, targets=None, balanced=None):
1111 """Method for constructing View objects.
1112
1113 If no arguments are specified, create a LoadBalancedView
1114 using all engines. If only `targets` specified, it will
1115 be a DirectView. This method is the underlying implementation
1116 of ``client.__getitem__``.
1117
1118 Parameters
1119 ----------
1120
1121 targets: list,slice,int,etc. [default: use all engines]
1122 The engines to use for the View
1123 balanced : bool [default: False if targets specified, True else]
1124 whether to build a LoadBalancedView or a DirectView
1125
1126 """
1127
1128 balanced = (targets is None) if balanced is None else balanced
1129
1076 if targets is None:
1130 if targets is None:
1077 if balanced:
1131 if balanced:
1078 return LoadBalancedView(client=self)
1132 return self._cache_view(None,True)
1079 else:
1133 else:
1080 targets = slice(None)
1134 targets = slice(None)
1081
1135
1082 if balanced:
1083 view_class = LoadBalancedView
1084 else:
1085 view_class = DirectView
1086 if isinstance(targets, int):
1136 if isinstance(targets, int):
1087 if targets not in self.ids:
1137 if targets not in self.ids:
1088 raise IndexError("No such engine: %i"%targets)
1138 raise IndexError("No such engine: %i"%targets)
1089 return view_class(client=self, targets=targets)
1139 return self._cache_view(targets, balanced)
1090
1140
1091 if isinstance(targets, slice):
1141 if isinstance(targets, slice):
1092 indices = range(len(self.ids))[targets]
1142 indices = range(len(self.ids))[targets]
@@ -1095,7 +1145,7 b' class Client(object):'
1095
1145
1096 if isinstance(targets, (tuple, list, xrange)):
1146 if isinstance(targets, (tuple, list, xrange)):
1097 _,targets = self._build_targets(list(targets))
1147 _,targets = self._build_targets(list(targets))
1098 return view_class(client=self, targets=targets)
1148 return self._cache_view(targets, balanced)
1099 else:
1149 else:
1100 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
1150 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
1101
1151
@@ -10,6 +10,8 b''
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 import warnings
14
13 import map as Map
15 import map as Map
14 from asyncresult import AsyncMapResult
16 from asyncresult import AsyncMapResult
15
17
@@ -17,7 +19,7 b' from asyncresult import AsyncMapResult'
17 # Decorators
19 # Decorators
18 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
19
21
20 def remote(client, bound=False, block=None, targets=None, balanced=None):
22 def remote(client, bound=True, block=None, targets=None, balanced=None):
21 """Turn a function into a remote function.
23 """Turn a function into a remote function.
22
24
23 This method can be used for map:
25 This method can be used for map:
@@ -29,7 +31,7 b' def remote(client, bound=False, block=None, targets=None, balanced=None):'
29 return RemoteFunction(client, f, bound, block, targets, balanced)
31 return RemoteFunction(client, f, bound, block, targets, balanced)
30 return remote_function
32 return remote_function
31
33
32 def parallel(client, dist='b', bound=False, block=None, targets='all', balanced=None):
34 def parallel(client, dist='b', bound=True, block=None, targets='all', balanced=None):
33 """Turn a function into a parallel remote function.
35 """Turn a function into a parallel remote function.
34
36
35 This method can be used for map:
37 This method can be used for map:
@@ -93,8 +95,10 b' class RemoteFunction(object):'
93
95
94 class ParallelFunction(RemoteFunction):
96 class ParallelFunction(RemoteFunction):
95 """Class for mapping a function to sequences."""
97 """Class for mapping a function to sequences."""
96 def __init__(self, client, f, dist='b', bound=False, block=None, targets='all', balanced=None):
98 def __init__(self, client, f, dist='b', bound=False, block=None, targets='all', balanced=None, chunk_size=None):
97 super(ParallelFunction, self).__init__(client,f,bound,block,targets,balanced)
99 super(ParallelFunction, self).__init__(client,f,bound,block,targets,balanced)
100 self.chunk_size = chunk_size
101
98 mapClass = Map.dists[dist]
102 mapClass = Map.dists[dist]
99 self.mapObject = mapClass()
103 self.mapObject = mapClass()
100
104
@@ -106,12 +110,18 b' class ParallelFunction(RemoteFunction):'
106 raise ValueError(msg)
110 raise ValueError(msg)
107
111
108 if self.balanced:
112 if self.balanced:
109 targets = [self.targets]*len_0
113 if self.chunk_size:
114 nparts = len_0/self.chunk_size + int(len_0%self.chunk_size > 0)
115 else:
116 nparts = len_0
117 targets = [self.targets]*nparts
110 else:
118 else:
119 if self.chunk_size:
120 warnings.warn("`chunk_size` is ignored when `balanced=False", UserWarning)
111 # multiplexed:
121 # multiplexed:
112 targets = self.client._build_targets(self.targets)[-1]
122 targets = self.client._build_targets(self.targets)[-1]
123 nparts = len(targets)
113
124
114 nparts = len(targets)
115 msg_ids = []
125 msg_ids = []
116 # my_f = lambda *a: map(self.func, *a)
126 # my_f = lambda *a: map(self.func, *a)
117 for index, t in enumerate(targets):
127 for index, t in enumerate(targets):
@@ -132,7 +142,7 b' class ParallelFunction(RemoteFunction):'
132 else:
142 else:
133 f=self.func
143 f=self.func
134 ar = self.client.apply(f, args=args, block=False, bound=self.bound,
144 ar = self.client.apply(f, args=args, block=False, bound=self.bound,
135 targets=targets, balanced=self.balanced)
145 targets=t, balanced=self.balanced)
136
146
137 msg_ids.append(ar.msg_ids[0])
147 msg_ids.append(ar.msg_ids[0])
138
148
@@ -134,7 +134,7 b' class View(HasTraits):'
134 raise KeyError("Invalid name: %r"%key)
134 raise KeyError("Invalid name: %r"%key)
135 for name in ('block', 'bound'):
135 for name in ('block', 'bound'):
136 if name in kwargs:
136 if name in kwargs:
137 setattr(self, name, kwargs)
137 setattr(self, name, kwargs[name])
138
138
139 #----------------------------------------------------------------
139 #----------------------------------------------------------------
140 # wrappers for client methods:
140 # wrappers for client methods:
@@ -249,16 +249,49 b' class View(HasTraits):'
249 return self.client.purge_results(msg_ids=msg_ids, targets=targets)
249 return self.client.purge_results(msg_ids=msg_ids, targets=targets)
250
250
251 #-------------------------------------------------------------------
251 #-------------------------------------------------------------------
252 # Map
253 #-------------------------------------------------------------------
254
255 def map(self, f, *sequences, **kwargs):
256 """override in subclasses"""
257 raise NotImplementedError
258
259 def map_async(self, f, *sequences, **kwargs):
260 """Parallel version of builtin `map`, using this view's engines.
261
262 This is equivalent to map(...block=False)
263
264 See `map` for details.
265 """
266 if 'block' in kwargs:
267 raise TypeError("map_async doesn't take a `block` keyword argument.")
268 kwargs['block'] = False
269 return self.map(f,*sequences,**kwargs)
270
271 def map_sync(self, f, *sequences, **kwargs):
272 """Parallel version of builtin `map`, using this view's engines.
273
274 This is equivalent to map(...block=True)
275
276 See `map` for details.
277 """
278 if 'block' in kwargs:
279 raise TypeError("map_sync doesn't take a `block` keyword argument.")
280 kwargs['block'] = True
281 return self.map(f,*sequences,**kwargs)
282
283 #-------------------------------------------------------------------
252 # Decorators
284 # Decorators
253 #-------------------------------------------------------------------
285 #-------------------------------------------------------------------
254 def parallel(self, bound=True, block=True):
255 """Decorator for making a ParallelFunction"""
256 return parallel(self.client, bound=bound, targets=self.targets, block=block, balanced=self._balanced)
257
286
258 def remote(self, bound=True, block=True):
287 def remote(self, bound=True, block=True):
259 """Decorator for making a RemoteFunction"""
288 """Decorator for making a RemoteFunction"""
260 return parallel(self.client, bound=bound, targets=self.targets, block=block, balanced=self._balanced)
289 return remote(self.client, bound=bound, targets=self.targets, block=block, balanced=self._balanced)
261
290
291 def parallel(self, dist='b', bound=True, block=None):
292 """Decorator for making a ParallelFunction"""
293 block = self.block if block is None else block
294 return parallel(self.client, bound=bound, targets=self.targets, block=block, balanced=self._balanced)
262
295
263
296
264 class DirectView(View):
297 class DirectView(View):
@@ -325,17 +358,10 b' class DirectView(View):'
325 raise TypeError("invalid keyword arg, %r"%k)
358 raise TypeError("invalid keyword arg, %r"%k)
326
359
327 assert len(sequences) > 0, "must have some sequences to map onto!"
360 assert len(sequences) > 0, "must have some sequences to map onto!"
328 pf = ParallelFunction(self.client, f, block=block,
361 pf = ParallelFunction(self.client, f, block=block, bound=bound,
329 bound=bound, targets=self.targets, balanced=False)
362 targets=self.targets, balanced=False)
330 return pf.map(*sequences)
363 return pf.map(*sequences)
331
364
332 def map_async(self, f, *sequences, **kwargs):
333 """Parallel version of builtin `map`, using this view's engines."""
334 if 'block' in kwargs:
335 raise TypeError("map_async doesn't take a `block` keyword argument.")
336 kwargs['block'] = True
337 return self.map(f,*sequences,**kwargs)
338
339 @sync_results
365 @sync_results
340 @save_ids
366 @save_ids
341 def execute(self, code, block=True):
367 def execute(self, code, block=True):
@@ -446,12 +472,12 b' class LoadBalancedView(View):'
446
472
447 """
473 """
448
474
449 _apply_name = 'apply_balanced'
450 _default_names = ['block', 'bound', 'follow', 'after', 'timeout']
475 _default_names = ['block', 'bound', 'follow', 'after', 'timeout']
451
476
452 def __init__(self, client=None, targets=None):
477 def __init__(self, client=None, targets=None):
453 super(LoadBalancedView, self).__init__(client=client, targets=targets)
478 super(LoadBalancedView, self).__init__(client=client, targets=targets)
454 self._ntargets = 1
479 self._ntargets = 1
480 self._balanced = True
455
481
456 def _validate_dependency(self, dep):
482 def _validate_dependency(self, dep):
457 """validate a dependency.
483 """validate a dependency.
@@ -547,26 +573,20 b' class LoadBalancedView(View):'
547
573
548 """
574 """
549
575
576 # default
550 block = kwargs.get('block', self.block)
577 block = kwargs.get('block', self.block)
551 bound = kwargs.get('bound', self.bound)
578 bound = kwargs.get('bound', self.bound)
579 chunk_size = kwargs.get('chunk_size', 1)
580
581 keyset = set(kwargs.keys())
582 extra_keys = keyset.difference_update(set(['block', 'bound', 'chunk_size']))
583 if extra_keys:
584 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
552
585
553 assert len(sequences) > 0, "must have some sequences to map onto!"
586 assert len(sequences) > 0, "must have some sequences to map onto!"
554
587
555 pf = ParallelFunction(self.client, f, block=block, bound=bound,
588 pf = ParallelFunction(self.client, f, block=block, bound=bound,
556 targets=self.targets, balanced=True)
589 targets=self.targets, balanced=True,
590 chunk_size=chunk_size)
557 return pf.map(*sequences)
591 return pf.map(*sequences)
558
592
559 def map_async(self, f, *sequences, **kwargs):
560 """Parallel version of builtin `map`, using this view's engines.
561
562 This is equivalent to map(...block=False)
563
564 See `map` for details.
565 """
566
567 if 'block' in kwargs:
568 raise TypeError("map_async doesn't take a `block` keyword argument.")
569 kwargs['block'] = True
570 return self.map(f,*sequences,**kwargs)
571
572
@@ -76,6 +76,7 b' def main(nodes, edges):'
76 in-degree on the y (just for spread). All arrows must
76 in-degree on the y (just for spread). All arrows must
77 point at least slightly to the right if the graph is valid.
77 point at least slightly to the right if the graph is valid.
78 """
78 """
79 import pylab
79 from matplotlib.dates import date2num
80 from matplotlib.dates import date2num
80 from matplotlib.cm import gist_rainbow
81 from matplotlib.cm import gist_rainbow
81 print "building DAG"
82 print "building DAG"
@@ -99,7 +100,15 b' def main(nodes, edges):'
99 pos[node] = (start, runtime)
100 pos[node] = (start, runtime)
100 colors[node] = md.engine_id
101 colors[node] = md.engine_id
101 validate_tree(G, results)
102 validate_tree(G, results)
102 nx.draw(G, pos, node_list = colors.keys(), node_color=colors.values(), cmap=gist_rainbow)
103 nx.draw(G, pos, node_list=colors.keys(), node_color=colors.values(), cmap=gist_rainbow,
104 with_labels=False)
105 x,y = zip(*pos.values())
106 xmin,ymin = map(min, (x,y))
107 xmax,ymax = map(max, (x,y))
108 xscale = xmax-xmin
109 yscale = ymax-ymin
110 pylab.xlim(xmin-xscale*.1,xmax+xscale*.1)
111 pylab.ylim(ymin-yscale*.1,ymax+yscale*.1)
103 return G,results
112 return G,results
104
113
105 if __name__ == '__main__':
114 if __name__ == '__main__':
@@ -49,7 +49,7 b' c = client.Client(profile=cluster_profile)'
49
49
50 # A LoadBalancedView is an interface to the engines that provides dynamic load
50 # A LoadBalancedView is an interface to the engines that provides dynamic load
51 # balancing at the expense of not knowing which engine will execute the code.
51 # balancing at the expense of not knowing which engine will execute the code.
52 view = c[None]
52 view = c.view()
53
53
54 # Initialize the common code on the engines. This Python module has the
54 # Initialize the common code on the engines. This Python module has the
55 # price_options function that prices the options.
55 # price_options function that prices the options.
@@ -27,15 +27,17 b" filestring = 'pi200m.ascii.%(i)02dof20'"
27 files = [filestring % {'i':i} for i in range(1,16)]
27 files = [filestring % {'i':i} for i in range(1,16)]
28
28
29 # Connect to the IPython cluster
29 # Connect to the IPython cluster
30 c = client.Client()
30 c = client.Client(profile='edison')
31 c.run('pidigits.py')
31 c.run('pidigits.py')
32
32
33 # the number of engines
33 # the number of engines
34 n = len(c.ids)
34 n = len(c)
35 id0 = list(c.ids)[0]
35 id0 = c.ids[0]
36 v = c[:]
37 v.set_flags(bound=True,block=True)
36 # fetch the pi-files
38 # fetch the pi-files
37 print "downloading %i files of pi"%n
39 print "downloading %i files of pi"%n
38 c.map(fetch_pi_file, files[:n])
40 v.map(fetch_pi_file, files[:n])
39 print "done"
41 print "done"
40
42
41 # Run 10m digits on 1 engine
43 # Run 10m digits on 1 engine
@@ -48,8 +50,7 b' print "Digits per second (1 core, 10m digits): ", digits_per_second1'
48
50
49 # Run n*10m digits on all engines
51 # Run n*10m digits on all engines
50 t1 = clock()
52 t1 = clock()
51 c.block=True
53 freqs_all = v.map(compute_two_digit_freqs, files[:n])
52 freqs_all = c.map(compute_two_digit_freqs, files[:n])
53 freqs150m = reduce_freqs(freqs_all)
54 freqs150m = reduce_freqs(freqs_all)
54 t2 = clock()
55 t2 = clock()
55 digits_per_second8 = n*10.0e6/(t2-t1)
56 digits_per_second8 = n*10.0e6/(t2-t1)
@@ -18,9 +18,6 b' should be equal.'
18 # Import statements
18 # Import statements
19 from __future__ import division, with_statement
19 from __future__ import division, with_statement
20
20
21 import os
22 import urllib
23
24 import numpy as np
21 import numpy as np
25 from matplotlib import pyplot as plt
22 from matplotlib import pyplot as plt
26
23
@@ -30,6 +27,7 b' def fetch_pi_file(filename):'
30 """This will download a segment of pi from super-computing.org
27 """This will download a segment of pi from super-computing.org
31 if the file is not already present.
28 if the file is not already present.
32 """
29 """
30 import os, urllib
33 ftpdir="ftp://pi.super-computing.org/.2/pi200m/"
31 ftpdir="ftp://pi.super-computing.org/.2/pi200m/"
34 if os.path.exists(filename):
32 if os.path.exists(filename):
35 # we already have it
33 # we already have it
@@ -135,6 +135,7 b' calculation can also be run by simply typing the commands from'
135 # We simply pass Client the name of the cluster profile we
135 # We simply pass Client the name of the cluster profile we
136 # are using.
136 # are using.
137 In [2]: c = client.Client(profile='mycluster')
137 In [2]: c = client.Client(profile='mycluster')
138 In [3]: view = c.view(balanced=True)
138
139
139 In [3]: c.ids
140 In [3]: c.ids
140 Out[3]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
141 Out[3]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
General Comments 0
You need to be logged in to leave comments. Login now