##// END OF EJS Templates
avoid unnecessary imports of numpy...
MinRK -
Show More
@@ -1,166 +1,129
1 1 # encoding: utf-8
2 2
3 3 """Classes used in scattering and gathering sequences.
4 4
5 5 Scattering consists of partitioning a sequence and sending the various
6 6 pieces to individual nodes in a cluster.
7
8
9 Authors:
10
11 * Brian Granger
12 * MinRK
13
14 7 """
15 8
16 #-------------------------------------------------------------------------------
17 # Copyright (C) 2008-2011 The IPython Development Team
18 #
19 # Distributed under the terms of the BSD License. The full license is in
20 # the file COPYING, distributed as part of this software.
21 #-------------------------------------------------------------------------------
22
23 #-------------------------------------------------------------------------------
24 # Imports
25 #-------------------------------------------------------------------------------
9 # Copyright (c) IPython Development Team.
10 # Distributed under the terms of the Modified BSD License.
26 11
27 12 from __future__ import division
28 13
14 import sys
29 15 from itertools import islice
30 16
31 17 from IPython.utils.data import flatten as utils_flatten
32 18
33 #-------------------------------------------------------------------------------
34 # Figure out which array packages are present and their array types
35 #-------------------------------------------------------------------------------
36
37 arrayModules = []
38 try:
39 import Numeric
40 except ImportError:
41 pass
42 else:
43 arrayModules.append({'module':Numeric, 'type':Numeric.arraytype})
44 try:
45 import numpy
46 except ImportError:
47 pass
48 else:
49 arrayModules.append({'module':numpy, 'type':numpy.ndarray})
50 try:
51 import numarray
52 except ImportError:
53 pass
54 else:
55 arrayModules.append({'module':numarray,
56 'type':numarray.numarraycore.NumArray})
19
20 numpy = None
21
22 def is_array(obj):
23 """Is an object a numpy array?
24
25 Avoids importing numpy until it is requested
26 """
27 global numpy
28 if 'numpy' not in sys.modules:
29 return False
30
31 if numpy is None:
32 import numpy
33 return isinstance(obj, numpy.ndarray)
57 34
58 35 class Map(object):
59 36 """A class for partitioning a sequence using a map."""
60 37
61 38 def getPartition(self, seq, p, q, n=None):
62 39 """Returns the pth partition of q partitions of seq.
63 40
64 41 The length can be specified as `n`,
65 42 otherwise it is the value of `len(seq)`
66 43 """
67 44 n = len(seq) if n is None else n
68 45 # Test for error conditions here
69 46 if p<0 or p>=q:
70 47 raise ValueError("must have 0 <= p <= q, but have p=%s,q=%s" % (p, q))
71 48
72 49 remainder = n % q
73 50 basesize = n // q
74 51
75 52 if p < remainder:
76 53 low = p * (basesize + 1)
77 54 high = low + basesize + 1
78 55 else:
79 56 low = p * basesize + remainder
80 57 high = low + basesize
81 58
82 59 try:
83 60 result = seq[low:high]
84 61 except TypeError:
85 62 # some objects (iterators) can't be sliced,
86 63 # use islice:
87 64 result = list(islice(seq, low, high))
88 65
89 66 return result
90 67
91 68 def joinPartitions(self, listOfPartitions):
92 69 return self.concatenate(listOfPartitions)
93
70
94 71 def concatenate(self, listOfPartitions):
95 72 testObject = listOfPartitions[0]
96 73 # First see if we have a known array type
97 for m in arrayModules:
98 #print m
99 if isinstance(testObject, m['type']):
100 return m['module'].concatenate(listOfPartitions)
74 if is_array(testObject):
75 return numpy.concatenate(listOfPartitions)
101 76 # Next try for Python sequence types
102 77 if isinstance(testObject, (list, tuple)):
103 78 return utils_flatten(listOfPartitions)
104 79 # If we have scalars, just return listOfPartitions
105 80 return listOfPartitions
106 81
107 82 class RoundRobinMap(Map):
108 83 """Partitions a sequence in a round robin fashion.
109 84
110 85 This currently does not work!
111 86 """
112 87
113 88 def getPartition(self, seq, p, q, n=None):
114 89 n = len(seq) if n is None else n
115 90 return seq[p:n:q]
116 91
117 92 def joinPartitions(self, listOfPartitions):
118 93 testObject = listOfPartitions[0]
119 94 # First see if we have a known array type
120 for m in arrayModules:
121 #print m
122 if isinstance(testObject, m['type']):
123 return self.flatten_array(m['type'], listOfPartitions)
95 if is_array(testObject):
96 return self.flatten_array(listOfPartitions)
124 97 if isinstance(testObject, (list, tuple)):
125 98 return self.flatten_list(listOfPartitions)
126 99 return listOfPartitions
127 100
128 def flatten_array(self, klass, listOfPartitions):
101 def flatten_array(self, listOfPartitions):
129 102 test = listOfPartitions[0]
130 103 shape = list(test.shape)
131 104 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
132 A = klass(shape)
105 A = numpy.ndarray(shape)
133 106 N = shape[0]
134 107 q = len(listOfPartitions)
135 108 for p,part in enumerate(listOfPartitions):
136 109 A[p:N:q] = part
137 110 return A
138 111
139 112 def flatten_list(self, listOfPartitions):
140 113 flat = []
141 114 for i in range(len(listOfPartitions[0])):
142 115 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
143 116 return flat
144 #lengths = [len(x) for x in listOfPartitions]
145 #maxPartitionLength = len(listOfPartitions[0])
146 #numberOfPartitions = len(listOfPartitions)
147 #concat = self.concatenate(listOfPartitions)
148 #totalLength = len(concat)
149 #result = []
150 #for i in range(maxPartitionLength):
151 # result.append(concat[i:totalLength:maxPartitionLength])
152 # return self.concatenate(listOfPartitions)
153 117
154 118 def mappable(obj):
155 119 """return whether an object is mappable or not."""
156 120 if isinstance(obj, (tuple,list)):
157 121 return True
158 for m in arrayModules:
159 if isinstance(obj,m['type']):
160 return True
122 if is_array(obj):
123 return True
161 124 return False
162 125
163 126 dists = {'b':Map,'r':RoundRobinMap}
164 127
165 128
166 129
General Comments 0
You need to be logged in to leave comments. Login now