Show More
@@ -0,0 +1,158 b'' | |||
|
1 | # encoding: utf-8 | |
|
2 | ||
|
3 | """Classes used in scattering and gathering sequences. | |
|
4 | ||
|
5 | Scattering consists of partitioning a sequence and sending the various | |
|
6 | pieces to individual nodes in a cluster. | |
|
7 | """ | |
|
8 | ||
|
9 | __docformat__ = "restructuredtext en" | |
|
10 | ||
|
11 | #------------------------------------------------------------------------------- | |
|
12 | # Copyright (C) 2008 The IPython Development Team | |
|
13 | # | |
|
14 | # Distributed under the terms of the BSD License. The full license is in | |
|
15 | # the file COPYING, distributed as part of this software. | |
|
16 | #------------------------------------------------------------------------------- | |
|
17 | ||
|
18 | #------------------------------------------------------------------------------- | |
|
19 | # Imports | |
|
20 | #------------------------------------------------------------------------------- | |
|
21 | ||
|
22 | import types | |
|
23 | ||
|
24 | from IPython.utils.data import flatten as utils_flatten | |
|
25 | ||
|
26 | #------------------------------------------------------------------------------- | |
|
27 | # Figure out which array packages are present and their array types | |
|
28 | #------------------------------------------------------------------------------- | |
|
29 | ||
|
# Registry of the array packages that could be imported in this environment.
# Each entry is a dict with the imported module under 'module' and the array
# class taken from that module under 'type'.  Packages that fail to import
# are simply left out of the registry.
arrayModules = []

try:
    import Numeric
except ImportError:
    pass
else:
    arrayModules.append(dict(module=Numeric, type=Numeric.arraytype))

try:
    import numpy
except ImportError:
    pass
else:
    arrayModules.append(dict(module=numpy, type=numpy.ndarray))

try:
    import numarray
except ImportError:
    pass
else:
    arrayModules.append(dict(module=numarray,
                             type=numarray.numarraycore.NumArray))
|
50 | ||
|
class Map:
    """A class for partitioning a sequence using a map.

    Partitions are contiguous slices of the sequence.  `joinPartitions`
    puts a list of such partitions back together by concatenating them.
    """

    def getPartition(self, seq, p, q):
        """Returns the pth partition of q partitions of seq.

        The first ``len(seq) % q`` partitions receive one extra element, so
        partition sizes differ by at most one.

        If `p` is outside ``0 <= p < q`` a message is printed and None is
        returned.
        """
        # Test for error conditions here
        if p < 0 or p >= q:
            # Single-argument print() behaves the same on Python 2 and 3.
            print("No partition exists.")
            return

        # divmod floors on every Python version, unlike a bare `/`, so the
        # slice bounds below are always integers.
        basesize, remainder = divmod(len(seq), q)

        # Only the bounds of partition p are needed; there is no reason to
        # build the bounds of all q partitions first.
        if p < remainder:
            lo = p * (basesize + 1)
            hi = lo + basesize + 1
        else:
            lo = p * basesize + remainder
            hi = lo + basesize

        return seq[lo:hi]

    def joinPartitions(self, listOfPartitions):
        """Join partitions produced by `getPartition` (see `concatenate`)."""
        return self.concatenate(listOfPartitions)

    def concatenate(self, listOfPartitions):
        """Concatenate a list of partitions into a single object.

        The type of the first partition decides how the join is done:

        * an array type registered in `arrayModules` is joined with that
          module's own ``concatenate``;
        * a list or tuple is passed to `utils_flatten`
          (presumably flattening one level of nesting -- see
          IPython.utils.data.flatten);
        * anything else (e.g. scalars) is returned unchanged.

        An empty `listOfPartitions` is returned unchanged.
        """
        # Nothing to inspect: avoid indexing into an empty list.
        if len(listOfPartitions) == 0:
            return listOfPartitions

        testObject = listOfPartitions[0]
        # First see if we have a known array type
        for m in arrayModules:
            if isinstance(testObject, m['type']):
                return m['module'].concatenate(listOfPartitions)
        # Next try for Python sequence types
        if isinstance(testObject, (list, tuple)):
            return utils_flatten(listOfPartitions)
        # If we have scalars, just return listOfPartitions
        return listOfPartitions
|
93 | ||
|
class RoundRobinMap(Map):
    """Partitions a sequence in a round robin fashion.

    Partition p of q holds elements p, p+q, p+2q, ... of the sequence.

    This currently does not work!
    """

    def getPartition(self, seq, p, q):
        """Return every qth element of `seq`, starting at index `p`."""
        return seq[p:len(seq):q]

    def joinPartitions(self, listOfPartitions):
        """Interleave round-robin partitions back into a single sequence.

        The type of the first partition decides how the join is done: known
        array types go through `flatten_array`, lists and tuples through
        `flatten_list`, and anything else is returned unchanged.  An empty
        `listOfPartitions` is returned unchanged.
        """
        # Nothing to inspect: avoid indexing into an empty list.
        if len(listOfPartitions) == 0:
            return listOfPartitions

        testObject = listOfPartitions[0]
        # First see if we have a known array type
        for m in arrayModules:
            if isinstance(testObject, m['type']):
                return self.flatten_array(m['type'], listOfPartitions)
        if isinstance(testObject, (list, tuple)):
            return self.flatten_list(listOfPartitions)
        return listOfPartitions

    def flatten_array(self, klass, listOfPartitions):
        """Interleave array partitions along their first axis.

        A new `klass` instance is created whose first dimension is the sum
        of the partitions' first dimensions (the remaining dimensions are
        copied from the first partition); partition p is then written to
        rows p, p+q, p+2q, ... where q is the number of partitions.
        """
        test = listOfPartitions[0]
        shape = list(test.shape)
        shape[0] = sum(part.shape[0] for part in listOfPartitions)
        # NOTE(review): only the shape is passed to the constructor, so the
        # element type of the result is whatever `klass` defaults to --
        # confirm this matches the element type of the partitions.
        A = klass(shape)
        N = shape[0]
        q = len(listOfPartitions)
        for p, part in enumerate(listOfPartitions):
            A[p:N:q] = part
        return A

    def flatten_list(self, listOfPartitions):
        """Interleave list/tuple partitions into one flat list.

        Element i of every partition that has one is emitted, for each i up
        to the length of the first partition.  This relies on the first
        partition being at least as long as the others, which holds for
        partitions produced by `getPartition`.
        """
        flat = []
        for i in range(len(listOfPartitions[0])):
            flat.extend([part[i] for part in listOfPartitions if len(part) > i])
        return flat
|
145 | ||
|
def mappable(obj):
    """Tell whether `obj` can be partitioned by the map classes.

    Tuples and lists qualify, as does any instance of an array type
    registered in `arrayModules`.
    """
    if isinstance(obj, (tuple, list)):
        return True
    return any(isinstance(obj, entry['type']) for entry in arrayModules)
|
154 | ||
|
# Lookup from the one-letter distribution code to the partitioning class:
# 'b' selects Map, 'r' selects RoundRobinMap.
dists = dict(b=Map, r=RoundRobinMap)
|
156 | ||
|
157 | ||
|
158 |
@@ -28,6 +28,7 b' import streamsession as ss' | |||
|
28 | 28 | from view import DirectView, LoadBalancedView |
|
29 | 29 | from dependency import Dependency, depend, require |
|
30 | 30 | import error |
|
31 | import map as Map | |
|
31 | 32 | |
|
32 | 33 | #-------------------------------------------------------------------------- |
|
33 | 34 | # helpers for implementing old MEC API via client.apply |
@@ -92,6 +93,18 b' def remote(client, bound=False, block=None, targets=None):' | |||
|
92 | 93 | return RemoteFunction(client, f, bound, block, targets) |
|
93 | 94 | return remote_function |
|
94 | 95 | |
|
def parallel(client, dist='b', bound=False, block=None, targets='all'):
    """Build a decorator that wraps a function in a ParallelFunction.

    The returned decorator hands `client`, the decorated function and the
    `dist`, `bound`, `block` and `targets` settings to ParallelFunction.

    This can be used for map:

        >>> @parallel(client, block=True)
        ... def func(a):
        ...     pass
    """
    def parallel_function(f):
        return ParallelFunction(client, f, dist=dist, bound=bound,
                                block=block, targets=targets)
    return parallel_function
|
107 | ||
|
95 | 108 | #-------------------------------------------------------------------------- |
|
96 | 109 | # Classes |
|
97 | 110 | #-------------------------------------------------------------------------- |
@@ -133,6 +146,103 b' class RemoteFunction(object):' | |||
|
133 | 146 | block=self.block, targets=self.targets, bound=self.bound) |
|
134 | 147 | |
|
135 | 148 | |
|
class ParallelFunction(RemoteFunction):
    """Class for mapping a function to sequences.

    Calling an instance partitions every input sequence with the map object
    selected by `dist`, submits one `client.apply` call per partition, and
    gathers the results through a PendingMapResult.
    """
    def __init__(self, client, f, dist='b', bound=False, block=None, targets='all'):
        super(ParallelFunction, self).__init__(client, f, bound, block, targets)
        # `dist` is a key of Map.dists; an unknown key raises KeyError.
        mapClass = Map.dists[dist]
        self.mapObject = mapClass()

    def __call__(self, *sequences):
        """Apply the wrapped function to partitions of `sequences`.

        All sequences must have the same length, otherwise ValueError is
        raised.  When `self.block` is true the call waits and returns the
        joined result; otherwise the PendingMapResult is returned.
        """
        len_0 = len(sequences[0])
        for s in sequences:
            if len(s) != len_0:
                raise ValueError('all sequences must have equal length')

        if self.targets is None:
            # load-balanced: one partition (and one task) per element
            engines = [None] * len_0
        else:
            # multiplexed: one partition per target engine
            engines = self.client._build_targets(self.targets)[-1]

        nparts = len(engines)
        msg_ids = []
        for index, engineid in enumerate(engines):
            args = [self.mapObject.getPartition(seq, index, nparts)
                    for seq in sequences]
            mid = self.client.apply(self.func, args=args, block=False,
                                    bound=self.bound,
                                    targets=engineid)
            msg_ids.append(mid)

        # Build the pending result unconditionally: it is the return value
        # of the non-blocking path as well as the thing waited on when
        # blocking.
        pending = PendingMapResult(self.client, msg_ids, self.mapObject)
        if self.block:
            pending.wait()
            return pending.result
        return pending
|
186 | ||
|
187 | ||
|
class PendingResult(object):
    """Class for representing results of non-blocking calls.

    Holds the client and the message ids of the submitted requests.  The
    `result` property fetches the results once the requests are done and
    caches the reconstructed value.
    """
    def __init__(self, client, msg_ids):
        self.client = client
        self.msg_ids = msg_ids
        self._result = None
        self.done = False

    def __repr__(self):
        if self.done:
            return "<%s: finished>" % (self.__class__.__name__)
        else:
            return "<%s: %r>" % (self.__class__.__name__, self.msg_ids)

    @property
    def result(self):
        """The reconstructed result of the call.

        Raises error.ResultNotCompleted if the requests have not finished.
        Remote exceptions are surfaced by error.collect_exceptions.
        """
        # A cached value of None is indistinguishable from "not cached", so
        # such a result is re-collected on each access.
        if self._result is not None:
            return self._result
        if not self.done:
            self.wait(0)
        if not self.done:
            raise error.ResultNotCompleted

        # Build a real list: a lazy map() object (Python 3) has no len()
        # and would break reconstruct_result.
        results = [self.client.results.get(mid) for mid in self.msg_ids]
        results = error.collect_exceptions(results, 'get_result')
        self._result = self.reconstruct_result(results)
        return self._result

    def reconstruct_result(self, res):
        """
        Override me in subclasses for turning a list of results
        into the expected form.

        By default a single-element list is unwrapped to its element and
        any other list is returned as is.
        """
        if len(res) == 1:
            return res[0]
        else:
            return res

    def wait(self, timout=-1):
        """Update and return `self.done` from `client.barrier`.

        NOTE(review): `timout` (sic) is accepted but never forwarded to
        `client.barrier`, so it currently has no effect.  The misspelled
        name is kept so existing keyword callers do not break.
        """
        self.done = self.client.barrier(self.msg_ids)
        return self.done
|
229 | ||
|
class PendingMapResult(PendingResult):
    """Pending result of a non-blocking gather or map.

    The per-partition results are joined with the map object supplied at
    construction, so the final value has the gathered form.
    """

    def __init__(self, client, msg_ids, mapObject):
        super(PendingMapResult, self).__init__(client, msg_ids)
        self.mapObject = mapObject

    def reconstruct_result(self, res):
        """Join the list of partition results via the map object."""
        joined = self.mapObject.joinPartitions(res)
        return joined
|
243 | ||
|
244 | ||
|
245 | ||
|
136 | 246 | class AbortedTask(object): |
|
137 | 247 | """A basic wrapper object describing an aborted task.""" |
|
138 | 248 | def __init__(self, msg_id): |
@@ -498,6 +608,17 b' class Client(object):' | |||
|
498 | 608 | # Begin public methods |
|
499 | 609 | #-------------------------------------------------------------------------- |
|
500 | 610 | |
|
@property
def remote(self):
    """Decorator for turning a function into a RemoteFunction on this client.

    Accessing the property calls the module-level ``remote`` helper with
    this client and the client's current ``block`` setting, so the value
    can be applied directly as a decorator:

        >>> @client.remote
        ... def f():
        ...     import os
        ...     print (os.getpid())
    """
    blocking = self.block
    return remote(self, block=blocking)
|
621 | ||
|
501 | 622 | def spin(self): |
|
502 | 623 | """Flush any registration notifications and execution results |
|
503 | 624 | waiting in the ZMQ queue. |
@@ -784,7 +905,7 b' class Client(object):' | |||
|
784 | 905 | self.barrier(msg_id) |
|
785 | 906 | return self._maybe_raise(self.results[msg_id]) |
|
786 | 907 | else: |
|
787 | return msg_id | |
|
908 | return PendingResult(self, [msg_id]) | |
|
788 | 909 | |
|
789 | 910 | def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None, |
|
790 | 911 | after=None, follow=None): |
@@ -814,10 +935,7 b' class Client(object):' | |||
|
814 | 935 | if block: |
|
815 | 936 | self.barrier(msg_ids) |
|
816 | 937 | else: |
|
817 |
|
|
|
818 | return msg_ids[0] | |
|
819 | else: | |
|
820 | return msg_ids | |
|
938 | return PendingResult(self, msg_ids) | |
|
821 | 939 | if len(msg_ids) == 1: |
|
822 | 940 | return self._maybe_raise(self.results[msg_ids[0]]) |
|
823 | 941 | else: |
@@ -826,12 +944,17 b' class Client(object):' | |||
|
826 | 944 | result[target] = self.results[mid] |
|
827 | 945 | return error.collect_exceptions(result, f.__name__) |
|
828 | 946 | |
|
@defaultblock
def map(self, f, sequences, targets=None, block=None, bound=False):
    """Map `f` over `sequences` through a ParallelFunction.

    `sequences` is an iterable of sequences; each one is passed as a
    separate positional argument to the ParallelFunction call.
    """
    return ParallelFunction(self, f, block=block, bound=bound,
                            targets=targets)(*sequences)
|
951 | ||
|
829 | 952 | #-------------------------------------------------------------------------- |
|
830 | 953 | # Data movement |
|
831 | 954 | #-------------------------------------------------------------------------- |
|
832 | 955 | |
|
833 | 956 | @defaultblock |
|
834 |
def push(self, ns, targets= |
|
|
957 | def push(self, ns, targets='all', block=None): | |
|
835 | 958 | """Push the contents of `ns` into the namespace on `target`""" |
|
836 | 959 | if not isinstance(ns, dict): |
|
837 | 960 | raise TypeError("Must be a dict, not %s"%type(ns)) |
@@ -839,7 +962,7 b' class Client(object):' | |||
|
839 | 962 | return result |
|
840 | 963 | |
|
841 | 964 | @defaultblock |
|
842 |
def pull(self, keys, targets= |
|
|
965 | def pull(self, keys, targets='all', block=True): | |
|
843 | 966 | """Pull objects from `target`'s namespace by `keys`""" |
|
844 | 967 | if isinstance(keys, str): |
|
845 | 968 | pass |
@@ -850,6 +973,48 b' class Client(object):' | |||
|
850 | 973 | result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True) |
|
851 | 974 | return result |
|
852 | 975 | |
|
@defaultblock
def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
    """
    Partition a Python sequence and send the partitions to a set of engines.

    One partition is pushed under the name `key` to each target engine.
    With `flatten`, a partition of exactly one element is pushed as that
    bare element.  Returns None after waiting when `block` is true,
    otherwise the PendingResult tracking the pushes.
    """
    engine_ids = self._build_targets(targets)[-1]
    partitioner = Map.dists[dist]()
    count = len(engine_ids)
    pending_ids = []
    for position, engine in enumerate(engine_ids):
        chunk = partitioner.getPartition(seq, position, count)
        payload = chunk[0] if flatten and len(chunk) == 1 else chunk
        # NOTE(review): confirm what a non-blocking push returns here; the
        # values are used below as message ids.
        pending_ids.append(self.push({key: payload}, targets=engine,
                                     block=False))
    pending = PendingResult(self, pending_ids)
    if not block:
        return pending
    pending.wait()
|
998 | ||
|
@defaultblock
def gather(self, key, dist='b', targets='all', block=True):
    """
    Gather a partitioned sequence on a set of engines as a single local seq.

    `key` is pulled from every target engine and the pieces are joined by
    the map class selected with `dist`.  Returns the joined result after
    waiting when `block` is true, otherwise the PendingMapResult.
    """
    engine_ids = self._build_targets(targets)[-1]
    partitioner = Map.dists[dist]()
    # NOTE(review): confirm what a non-blocking pull returns here; the
    # values are used below as message ids.
    pending_ids = [self.pull(key, targets=engine, block=False)
                   for engine in engine_ids]

    pending = PendingMapResult(self, pending_ids, partitioner)
    if block:
        pending.wait()
        return pending.result
    return pending
|
1017 | ||
|
853 | 1018 | #-------------------------------------------------------------------------- |
|
854 | 1019 | # Query methods |
|
855 | 1020 | #-------------------------------------------------------------------------- |
@@ -985,4 +1150,16 b' class AsynClient(Client):' | |||
|
985 | 1150 | for stream in (self.queue_stream, self.notifier_stream, |
|
986 | 1151 | self.task_stream, self.control_stream): |
|
987 | 1152 | stream.flush() |
|
988 | ||
|
1153 | ||
|
1154 | __all__ = [ 'Client', | |
|
1155 | 'depend', | |
|
1156 | 'require', | |
|
1157 | 'remote', | |
|
1158 | 'parallel', | |
|
1159 | 'RemoteFunction', | |
|
1160 | 'ParallelFunction', | |
|
1161 | 'DirectView', | |
|
1162 | 'LoadBalancedView', | |
|
1163 | 'PendingResult', | |
|
1164 | 'PendingMapResult' | |
|
1165 | ] |
@@ -247,11 +247,15 b' class CompositeError(KernelError):' | |||
|
247 | 247 | et,ev,tb = sys.exc_info() |
|
248 | 248 | |
|
249 | 249 | |
|
250 | def collect_exceptions(rdict, method): | |
|
250 | def collect_exceptions(rdict_or_list, method): | |
|
251 | 251 | """check a result dict for errors, and raise CompositeError if any exist. |
|
252 | 252 | Passthrough otherwise.""" |
|
253 | 253 | elist = [] |
|
254 | for r in rdict.values(): | |
|
254 | if isinstance(rdict_or_list, dict): | |
|
255 | rlist = rdict_or_list.values() | |
|
256 | else: | |
|
257 | rlist = rdict_or_list | |
|
258 | for r in rlist: | |
|
255 | 259 | if isinstance(r, RemoteError): |
|
256 | 260 | en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info |
|
257 | 261 | # Sometimes we could have CompositeError in our list. Just take |
@@ -264,7 +268,7 b' def collect_exceptions(rdict, method):' | |||
|
264 | 268 | else: |
|
265 | 269 | elist.append((en, ev, etb, ei)) |
|
266 | 270 | if len(elist)==0: |
|
267 | return rdict | |
|
271 | return rdict_or_list | |
|
268 | 272 | else: |
|
269 | 273 | msg = "one or more exceptions from call to method: %s" % (method) |
|
270 | 274 | # This silliness is needed so the debugger has access to the exception |
@@ -228,6 +228,27 b' class DirectView(View):' | |||
|
228 | 228 | block = block if block is not None else self.block |
|
229 | 229 | return self.client.pull(key_s, block=block, targets=self.targets) |
|
230 | 230 | |
|
def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None):
    """
    Partition a Python sequence and send the partitions to a set of engines.

    `block` and `targets` fall back to this view's own settings when left
    as None; the work itself is delegated to the client's scatter.
    """
    if block is None:
        block = self.block
    if targets is None:
        targets = self.targets

    forwarded = dict(dist=dist, flatten=flatten, targets=targets, block=block)
    return self.client.scatter(key, seq, **forwarded)
|
241 | ||
|
def gather(self, key, dist='b', targets=None, block=True):
    """
    Gather a partitioned sequence on a set of engines as a single local seq.

    `targets` falls back to this view's targets when None.  `block`
    defaults to True, so the view's own `block` setting is only used when
    the caller passes ``block=None`` explicitly.
    """
    if block is None:
        block = self.block
    if targets is None:
        targets = self.targets

    forwarded = dict(dist=dist, targets=targets, block=block)
    return self.client.gather(key, **forwarded)
|
251 | ||
|
231 | 252 | def __getitem__(self, key): |
|
232 | 253 | return self.get(key) |
|
233 | 254 |
General Comments 0
You need to be logged in to leave comments.
Login now