##// END OF EJS Templates
add map/scatter/gather/ParallelFunction from kernel
MinRK -
Show More
@@ -0,0 +1,158 b''
1 # encoding: utf-8
2
3 """Classes used in scattering and gathering sequences.
4
5 Scattering consists of partitioning a sequence and sending the various
6 pieces to individual nodes in a cluster.
7 """
8
9 __docformat__ = "restructuredtext en"
10
11 #-------------------------------------------------------------------------------
12 # Copyright (C) 2008 The IPython Development Team
13 #
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
16 #-------------------------------------------------------------------------------
17
18 #-------------------------------------------------------------------------------
19 # Imports
20 #-------------------------------------------------------------------------------
21
22 import types
23
24 from IPython.utils.data import flatten as utils_flatten
25
26 #-------------------------------------------------------------------------------
27 # Figure out which array packages are present and their array types
28 #-------------------------------------------------------------------------------
29
30 arrayModules = []
31 try:
32 import Numeric
33 except ImportError:
34 pass
35 else:
36 arrayModules.append({'module':Numeric, 'type':Numeric.arraytype})
37 try:
38 import numpy
39 except ImportError:
40 pass
41 else:
42 arrayModules.append({'module':numpy, 'type':numpy.ndarray})
43 try:
44 import numarray
45 except ImportError:
46 pass
47 else:
48 arrayModules.append({'module':numarray,
49 'type':numarray.numarraycore.NumArray})
50
51 class Map:
52 """A class for partitioning a sequence using a map."""
53
54 def getPartition(self, seq, p, q):
55 """Returns the pth partition of q partitions of seq."""
56
57 # Test for error conditions here
58 if p<0 or p>=q:
59 print "No partition exists."
60 return
61
62 remainder = len(seq)%q
63 basesize = len(seq)/q
64 hi = []
65 lo = []
66 for n in range(q):
67 if n < remainder:
68 lo.append(n * (basesize + 1))
69 hi.append(lo[-1] + basesize + 1)
70 else:
71 lo.append(n*basesize + remainder)
72 hi.append(lo[-1] + basesize)
73
74
75 result = seq[lo[p]:hi[p]]
76 return result
77
78 def joinPartitions(self, listOfPartitions):
79 return self.concatenate(listOfPartitions)
80
81 def concatenate(self, listOfPartitions):
82 testObject = listOfPartitions[0]
83 # First see if we have a known array type
84 for m in arrayModules:
85 #print m
86 if isinstance(testObject, m['type']):
87 return m['module'].concatenate(listOfPartitions)
88 # Next try for Python sequence types
89 if isinstance(testObject, (types.ListType, types.TupleType)):
90 return utils_flatten(listOfPartitions)
91 # If we have scalars, just return listOfPartitions
92 return listOfPartitions
93
94 class RoundRobinMap(Map):
95 """Partitions a sequence in a roun robin fashion.
96
97 This currently does not work!
98 """
99
100 def getPartition(self, seq, p, q):
101 # if not isinstance(seq,(list,tuple)):
102 # raise NotImplementedError("cannot RR partition type %s"%type(seq))
103 return seq[p:len(seq):q]
104 #result = []
105 #for i in range(p,len(seq),q):
106 # result.append(seq[i])
107 #return result
108
109 def joinPartitions(self, listOfPartitions):
110 testObject = listOfPartitions[0]
111 # First see if we have a known array type
112 for m in arrayModules:
113 #print m
114 if isinstance(testObject, m['type']):
115 return self.flatten_array(m['type'], listOfPartitions)
116 if isinstance(testObject, (types.ListType, types.TupleType)):
117 return self.flatten_list(listOfPartitions)
118 return listOfPartitions
119
120 def flatten_array(self, klass, listOfPartitions):
121 test = listOfPartitions[0]
122 shape = list(test.shape)
123 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
124 A = klass(shape)
125 N = shape[0]
126 q = len(listOfPartitions)
127 for p,part in enumerate(listOfPartitions):
128 A[p:N:q] = part
129 return A
130
131 def flatten_list(self, listOfPartitions):
132 flat = []
133 for i in range(len(listOfPartitions[0])):
134 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
135 return flat
136 #lengths = [len(x) for x in listOfPartitions]
137 #maxPartitionLength = len(listOfPartitions[0])
138 #numberOfPartitions = len(listOfPartitions)
139 #concat = self.concatenate(listOfPartitions)
140 #totalLength = len(concat)
141 #result = []
142 #for i in range(maxPartitionLength):
143 # result.append(concat[i:totalLength:maxPartitionLength])
144 # return self.concatenate(listOfPartitions)
145
146 def mappable(obj):
147 """return whether an object is mappable or not."""
148 if isinstance(obj, (tuple,list)):
149 return True
150 for m in arrayModules:
151 if isinstance(obj,m['type']):
152 return True
153 return False
154
155 dists = {'b':Map,'r':RoundRobinMap}
156
157
158
@@ -28,6 +28,7 b' import streamsession as ss'
28 28 from view import DirectView, LoadBalancedView
29 29 from dependency import Dependency, depend, require
30 30 import error
31 import map as Map
31 32
32 33 #--------------------------------------------------------------------------
33 34 # helpers for implementing old MEC API via client.apply
@@ -92,6 +93,18 b' def remote(client, bound=False, block=None, targets=None):'
92 93 return RemoteFunction(client, f, bound, block, targets)
93 94 return remote_function
94 95
96 def parallel(client, dist='b', bound=False, block=None, targets='all'):
97 """Turn a function into a parallel remote function.
98
99 This method can be used for map:
100
101 >>> @parallel(client,block=True)
102 def func(a)
103 """
104 def parallel_function(f):
105 return ParallelFunction(client, f, dist, bound, block, targets)
106 return parallel_function
107
95 108 #--------------------------------------------------------------------------
96 109 # Classes
97 110 #--------------------------------------------------------------------------
@@ -133,6 +146,103 b' class RemoteFunction(object):'
133 146 block=self.block, targets=self.targets, bound=self.bound)
134 147
135 148
149 class ParallelFunction(RemoteFunction):
150 """Class for mapping a function to sequences."""
151 def __init__(self, client, f, dist='b', bound=False, block=None, targets='all'):
152 super(ParallelFunction, self).__init__(client,f,bound,block,targets)
153 mapClass = Map.dists[dist]
154 self.mapObject = mapClass()
155
156 def __call__(self, *sequences):
157 len_0 = len(sequences[0])
158 for s in sequences:
159 if len(s)!=len_0:
160 raise ValueError('all sequences must have equal length')
161
162 if self.targets is None:
163 # load-balanced:
164 engines = [None]*len_0
165 else:
166 # multiplexed:
167 engines = self.client._build_targets(self.targets)[-1]
168
169 nparts = len(engines)
170 msg_ids = []
171 for index, engineid in enumerate(engines):
172 args = []
173 for seq in sequences:
174 args.append(self.mapObject.getPartition(seq, index, nparts))
175 mid = self.client.apply(self.func, args=args, block=False,
176 bound=self.bound,
177 targets=engineid)
178 msg_ids.append(mid)
179
180 if self.block:
181 dg = PendingMapResult(self.client, msg_ids, self.mapObject)
182 dg.wait()
183 return dg.result
184 else:
185 return dg
186
187
188 class PendingResult(object):
189 """Class for representing results of non-blocking calls."""
190 def __init__(self, client, msg_ids):
191 self.client = client
192 self.msg_ids = msg_ids
193 self._result = None
194 self.done = False
195
196 def __repr__(self):
197 if self.done:
198 return "<%s: finished>"%(self.__class__.__name__)
199 else:
200 return "<%s: %r>"%(self.__class__.__name__,self.msg_ids)
201
202 @property
203 def result(self):
204 if self._result is not None:
205 return self._result
206 if not self.done:
207 self.wait(0)
208 if self.done:
209 results = map(self.client.results.get, self.msg_ids)
210 results = error.collect_exceptions(results, 'get_result')
211 self._result = self.reconstruct_result(results)
212 return self._result
213 else:
214 raise error.ResultNotCompleted
215
216 def reconstruct_result(self, res):
217 """
218 Override me in subclasses for turning a list of results
219 into the expected form.
220 """
221 if len(res) == 1:
222 return res[0]
223 else:
224 return res
225
226 def wait(self, timout=-1):
227 self.done = self.client.barrier(self.msg_ids)
228 return self.done
229
230 class PendingMapResult(PendingResult):
231 """Class for representing results of non-blocking gathers.
232
233 This will properly reconstruct the gather.
234 """
235
236 def __init__(self, client, msg_ids, mapObject):
237 self.mapObject = mapObject
238 PendingResult.__init__(self, client, msg_ids)
239
240 def reconstruct_result(self, res):
241 """Perform the gather on the actual results."""
242 return self.mapObject.joinPartitions(res)
243
244
245
136 246 class AbortedTask(object):
137 247 """A basic wrapper object describing an aborted task."""
138 248 def __init__(self, msg_id):
@@ -498,6 +608,17 b' class Client(object):'
498 608 # Begin public methods
499 609 #--------------------------------------------------------------------------
500 610
611 @property
612 def remote(self):
613 """property for convenient RemoteFunction generation.
614
615 >>> @client.remote
616 ... def f():
617 import os
618 print (os.getpid())
619 """
620 return remote(self, block=self.block)
621
501 622 def spin(self):
502 623 """Flush any registration notifications and execution results
503 624 waiting in the ZMQ queue.
@@ -784,7 +905,7 b' class Client(object):'
784 905 self.barrier(msg_id)
785 906 return self._maybe_raise(self.results[msg_id])
786 907 else:
787 return msg_id
908 return PendingResult(self, [msg_id])
788 909
789 910 def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
790 911 after=None, follow=None):
@@ -814,10 +935,7 b' class Client(object):'
814 935 if block:
815 936 self.barrier(msg_ids)
816 937 else:
817 if len(msg_ids) == 1:
818 return msg_ids[0]
819 else:
820 return msg_ids
938 return PendingResult(self, msg_ids)
821 939 if len(msg_ids) == 1:
822 940 return self._maybe_raise(self.results[msg_ids[0]])
823 941 else:
@@ -826,12 +944,17 b' class Client(object):'
826 944 result[target] = self.results[mid]
827 945 return error.collect_exceptions(result, f.__name__)
828 946
947 @defaultblock
948 def map(self, f, sequences, targets=None, block=None, bound=False):
949 pf = ParallelFunction(self,f,block=block,bound=bound,targets=targets)
950 return pf(*sequences)
951
829 952 #--------------------------------------------------------------------------
830 953 # Data movement
831 954 #--------------------------------------------------------------------------
832 955
833 956 @defaultblock
834 def push(self, ns, targets=None, block=None):
957 def push(self, ns, targets='all', block=None):
835 958 """Push the contents of `ns` into the namespace on `target`"""
836 959 if not isinstance(ns, dict):
837 960 raise TypeError("Must be a dict, not %s"%type(ns))
@@ -839,7 +962,7 b' class Client(object):'
839 962 return result
840 963
841 964 @defaultblock
842 def pull(self, keys, targets=None, block=True):
965 def pull(self, keys, targets='all', block=True):
843 966 """Pull objects from `target`'s namespace by `keys`"""
844 967 if isinstance(keys, str):
845 968 pass
@@ -850,6 +973,48 b' class Client(object):'
850 973 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
851 974 return result
852 975
976 @defaultblock
977 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
978 """
979 Partition a Python sequence and send the partitions to a set of engines.
980 """
981 targets = self._build_targets(targets)[-1]
982 mapObject = Map.dists[dist]()
983 nparts = len(targets)
984 msg_ids = []
985 for index, engineid in enumerate(targets):
986 partition = mapObject.getPartition(seq, index, nparts)
987 if flatten and len(partition) == 1:
988 mid = self.push({key: partition[0]}, targets=engineid, block=False)
989 else:
990 mid = self.push({key: partition}, targets=engineid, block=False)
991 msg_ids.append(mid)
992 r = PendingResult(self, msg_ids)
993 if block:
994 r.wait()
995 return
996 else:
997 return r
998
999 @defaultblock
1000 def gather(self, key, dist='b', targets='all', block=True):
1001 """
1002 Gather a partitioned sequence on a set of engines as a single local seq.
1003 """
1004
1005 targets = self._build_targets(targets)[-1]
1006 mapObject = Map.dists[dist]()
1007 msg_ids = []
1008 for index, engineid in enumerate(targets):
1009 msg_ids.append(self.pull(key, targets=engineid,block=False))
1010
1011 r = PendingMapResult(self, msg_ids, mapObject)
1012 if block:
1013 r.wait()
1014 return r.result
1015 else:
1016 return r
1017
853 1018 #--------------------------------------------------------------------------
854 1019 # Query methods
855 1020 #--------------------------------------------------------------------------
@@ -985,4 +1150,16 b' class AsynClient(Client):'
985 1150 for stream in (self.queue_stream, self.notifier_stream,
986 1151 self.task_stream, self.control_stream):
987 1152 stream.flush()
988
1153
1154 __all__ = [ 'Client',
1155 'depend',
1156 'require',
1157 'remote',
1158 'parallel',
1159 'RemoteFunction',
1160 'ParallelFunction',
1161 'DirectView',
1162 'LoadBalancedView',
1163 'PendingResult',
1164 'PendingMapResult'
1165 ]
@@ -247,11 +247,15 b' class CompositeError(KernelError):'
247 247 et,ev,tb = sys.exc_info()
248 248
249 249
250 def collect_exceptions(rdict, method):
250 def collect_exceptions(rdict_or_list, method):
251 251 """check a result dict for errors, and raise CompositeError if any exist.
252 252 Passthrough otherwise."""
253 253 elist = []
254 for r in rdict.values():
254 if isinstance(rdict_or_list, dict):
255 rlist = rdict_or_list.values()
256 else:
257 rlist = rdict_or_list
258 for r in rlist:
255 259 if isinstance(r, RemoteError):
256 260 en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
257 261 # Sometimes we could have CompositeError in our list. Just take
@@ -264,7 +268,7 b' def collect_exceptions(rdict, method):'
264 268 else:
265 269 elist.append((en, ev, etb, ei))
266 270 if len(elist)==0:
267 return rdict
271 return rdict_or_list
268 272 else:
269 273 msg = "one or more exceptions from call to method: %s" % (method)
270 274 # This silliness is needed so the debugger has access to the exception
@@ -228,6 +228,27 b' class DirectView(View):'
228 228 block = block if block is not None else self.block
229 229 return self.client.pull(key_s, block=block, targets=self.targets)
230 230
231 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None):
232 """
233 Partition a Python sequence and send the partitions to a set of engines.
234 """
235 block = block if block is not None else self.block
236 if targets is None:
237 targets = self.targets
238
239 return self.client.scatter(key, seq, dist=dist, flatten=flatten,
240 targets=targets, block=block)
241
242 def gather(self, key, dist='b', targets=None, block=True):
243 """
244 Gather a partitioned sequence on a set of engines as a single local seq.
245 """
246 block = block if block is not None else self.block
247 if targets is None:
248 targets = self.targets
249
250 return self.client.gather(key, dist=dist, targets=targets, block=block)
251
231 252 def __getitem__(self, key):
232 253 return self.get(key)
233 254
General Comments 0
You need to be logged in to leave comments. Login now