Show More
@@ -0,0 +1,218 | |||
|
1 | """ | |
|
2 | BinaryTree inter-engine communication class | |
|
3 | ||
|
4 | use from bintree_script.py | |
|
5 | ||
|
6 | Provides parallel [all]reduce functionality | |
|
7 | ||
|
8 | """ | |
|
9 | ||
|
10 | import cPickle as pickle | |
|
11 | import re | |
|
12 | import socket | |
|
13 | import uuid | |
|
14 | ||
|
15 | import zmq | |
|
16 | ||
|
17 | from IPython.parallel.util import disambiguate_url | |
|
18 | ||
|
19 | ||
|
20 | #---------------------------------------------------------------------------- | |
|
21 | # bintree-related construction/printing helpers | |
|
22 | #---------------------------------------------------------------------------- | |
|
23 | ||
|
24 | def bintree(ids, parent=None): | |
|
25 | """construct {child:parent} dict representation of a binary tree""" | |
|
26 | parents = {} | |
|
27 | n = len(ids) | |
|
28 | if n == 0: | |
|
29 | return parents | |
|
30 | root = ids[0] | |
|
31 | parents[root] = parent | |
|
32 | if len(ids) == 1: | |
|
33 | return parents | |
|
34 | else: | |
|
35 | ids = ids[1:] | |
|
36 | n = len(ids) | |
|
37 | left = bintree(ids[:n/2], parent=root) | |
|
38 | right = bintree(ids[n/2:], parent=root) | |
|
39 | parents.update(left) | |
|
40 | parents.update(right) | |
|
41 | return parents | |
|
42 | ||
|
43 | def reverse_bintree(parents): | |
|
44 | """construct {parent:[children]} dict from {child:parent}""" | |
|
45 | children = {} | |
|
46 | for child,parent in parents.iteritems(): | |
|
47 | if parent is None: | |
|
48 | children[None] = child | |
|
49 | continue | |
|
50 | elif parent not in children: | |
|
51 | children[parent] = [] | |
|
52 | children[parent].append(child) | |
|
53 | ||
|
54 | return children | |
|
55 | ||
|
56 | def depth(n, tree): | |
|
57 | """get depth of an element in the tree""" | |
|
58 | d = 0 | |
|
59 | parent = tree[n] | |
|
60 | while parent is not None: | |
|
61 | d += 1 | |
|
62 | parent = tree[parent] | |
|
63 | return d | |
|
64 | ||
|
65 | def print_bintree(tree, indent=' '): | |
|
66 | """print a binary tree""" | |
|
67 | for n in sorted(tree.keys()): | |
|
68 | print "%s%s" % (indent * depth(n,tree), n) | |
|
69 | ||
|
70 | #---------------------------------------------------------------------------- | |
|
71 | # Communicator class for a binary-tree map | |
|
72 | #---------------------------------------------------------------------------- | |
|
73 | ||
|
74 | ip_pat = re.compile(r'^\d+\.\d+\.\d+\.\d+$') | |
|
75 | ||
|
76 | def disambiguate_dns_url(url, location): | |
|
77 | """accept either IP address or dns name, and return IP""" | |
|
78 | if not ip_pat.match(location): | |
|
79 | location = socket.gethostbyname(location) | |
|
80 | return disambiguate_url(url, location) | |
|
81 | ||
|
82 | class BinaryTreeCommunicator(object): | |
|
83 | ||
|
84 | id = None | |
|
85 | pub = None | |
|
86 | sub = None | |
|
87 | downstream = None | |
|
88 | upstream = None | |
|
89 | pub_url = None | |
|
90 | tree_url = None | |
|
91 | ||
|
92 | def __init__(self, id, interface='tcp://*', root=False): | |
|
93 | self.id = id | |
|
94 | self.root = root | |
|
95 | ||
|
96 | # create context and sockets | |
|
97 | self._ctx = zmq.Context() | |
|
98 | if root: | |
|
99 | self.pub = self._ctx.socket(zmq.PUB) | |
|
100 | else: | |
|
101 | self.sub = self._ctx.socket(zmq.SUB) | |
|
102 | self.sub.setsockopt(zmq.SUBSCRIBE, b'') | |
|
103 | self.downstream = self._ctx.socket(zmq.PULL) | |
|
104 | self.upstream = self._ctx.socket(zmq.PUSH) | |
|
105 | ||
|
106 | # bind to ports | |
|
107 | interface_f = interface + ":%i" | |
|
108 | if self.root: | |
|
109 | pub_port = self.pub.bind_to_random_port(interface) | |
|
110 | self.pub_url = interface_f % pub_port | |
|
111 | ||
|
112 | tree_port = self.downstream.bind_to_random_port(interface) | |
|
113 | self.tree_url = interface_f % tree_port | |
|
114 | self.downstream_poller = zmq.Poller() | |
|
115 | self.downstream_poller.register(self.downstream, zmq.POLLIN) | |
|
116 | ||
|
117 | # guess first public IP from socket | |
|
118 | self.location = socket.gethostbyname_ex(socket.gethostname())[-1][0] | |
|
119 | ||
|
120 | def __del__(self): | |
|
121 | self.downstream.close() | |
|
122 | self.upstream.close() | |
|
123 | if self.root: | |
|
124 | self.pub.close() | |
|
125 | else: | |
|
126 | self.sub.close() | |
|
127 | self._ctx.term() | |
|
128 | ||
|
129 | @property | |
|
130 | def info(self): | |
|
131 | """return the connection info for this object's sockets.""" | |
|
132 | return (self.tree_url, self.location) | |
|
133 | ||
|
134 | def connect(self, peers, btree, pub_url, root_id=0): | |
|
135 | """connect to peers. `peers` will be a dict of 4-tuples, keyed by name. | |
|
136 | {peer : (ident, addr, pub_addr, location)} | |
|
137 | where peer is the name, ident is the XREP identity, addr,pub_addr are the | |
|
138 | """ | |
|
139 | ||
|
140 | # count the number of children we have | |
|
141 | self.nchildren = btree.values().count(self.id) | |
|
142 | ||
|
143 | if self.root: | |
|
144 | return # root only binds | |
|
145 | ||
|
146 | root_location = peers[root_id][-1] | |
|
147 | self.sub.connect(disambiguate_dns_url(pub_url, root_location)) | |
|
148 | ||
|
149 | parent = btree[self.id] | |
|
150 | ||
|
151 | tree_url, location = peers[parent] | |
|
152 | self.upstream.connect(disambiguate_dns_url(tree_url, location)) | |
|
153 | ||
|
154 | def serialize(self, obj): | |
|
155 | """serialize objects. | |
|
156 | ||
|
157 | Must return list of sendable buffers. | |
|
158 | ||
|
159 | Can be extended for more efficient/noncopying serialization of numpy arrays, etc. | |
|
160 | """ | |
|
161 | return [pickle.dumps(obj)] | |
|
162 | ||
|
163 | def unserialize(self, msg): | |
|
164 | """inverse of serialize""" | |
|
165 | return pickle.loads(msg[0]) | |
|
166 | ||
|
167 | def publish(self, value): | |
|
168 | assert self.root | |
|
169 | self.pub.send_multipart(self.serialize(value)) | |
|
170 | ||
|
171 | def consume(self): | |
|
172 | assert not self.root | |
|
173 | return self.unserialize(self.sub.recv_multipart()) | |
|
174 | ||
|
175 | def send_upstream(self, value, flags=0): | |
|
176 | assert not self.root | |
|
177 | self.upstream.send_multipart(self.serialize(value), flags=flags|zmq.NOBLOCK) | |
|
178 | ||
|
179 | def recv_downstream(self, flags=0, timeout=2000.): | |
|
180 | # wait for a message, so we won't block if there was a bug | |
|
181 | self.downstream_poller.poll(timeout) | |
|
182 | ||
|
183 | msg = self.downstream.recv_multipart(zmq.NOBLOCK|flags) | |
|
184 | return self.unserialize(msg) | |
|
185 | ||
|
186 | def reduce(self, f, value, flat=True, all=False): | |
|
187 | """parallel reduce on binary tree | |
|
188 | ||
|
189 | if flat: | |
|
190 | value is an entry in the sequence | |
|
191 | else: | |
|
192 | value is a list of entries in the sequence | |
|
193 | ||
|
194 | if all: | |
|
195 | broadcast final result to all nodes | |
|
196 | else: | |
|
197 | only root gets final result | |
|
198 | """ | |
|
199 | if not flat: | |
|
200 | value = reduce(f, value) | |
|
201 | ||
|
202 | for i in range(self.nchildren): | |
|
203 | value = f(value, self.recv_downstream()) | |
|
204 | ||
|
205 | if not self.root: | |
|
206 | self.send_upstream(value) | |
|
207 | ||
|
208 | if all: | |
|
209 | if self.root: | |
|
210 | self.publish(value) | |
|
211 | else: | |
|
212 | value = self.consume() | |
|
213 | return value | |
|
214 | ||
|
215 | def allreduce(self, f, value, flat=True): | |
|
216 | """parallel reduce followed by broadcast of the result""" | |
|
217 | return self.reduce(f, value, flat=flat, all=True) | |
|
218 |
@@ -0,0 +1,75 | |||
|
1 | """ | |
|
2 | Script for setting up and using [all]reduce with a binary-tree engine interconnect. | |
|
3 | ||
|
4 | usage: `python bintree_script.py` | |
|
5 | ||
|
6 | """ | |
|
7 | ||
|
8 | from IPython.parallel import Client, Reference | |
|
9 | ||
|
10 | ||
|
11 | # connect client and create views | |
|
12 | rc = Client() | |
|
13 | rc.block=True | |
|
14 | ids = rc.ids | |
|
15 | ||
|
16 | root_id = ids[0] | |
|
17 | root = rc[root_id] | |
|
18 | ||
|
19 | view = rc[:] | |
|
20 | ||
|
21 | # run bintree.py script defining bintree functions, etc. | |
|
22 | execfile('bintree.py') | |
|
23 | ||
|
24 | # generate binary tree of parents | |
|
25 | btree = bintree(ids) | |
|
26 | ||
|
27 | print "setting up binary tree interconnect:" | |
|
28 | print_bintree(btree) | |
|
29 | ||
|
30 | view.run('bintree.py') | |
|
31 | view.scatter('id', ids, flatten=True) | |
|
32 | view['root_id'] = root_id | |
|
33 | ||
|
34 | # create the Communicator objects on the engines | |
|
35 | view.execute('com = BinaryTreeCommunicator(id, root = id==root_id )') | |
|
36 | pub_url = root.apply_sync(lambda : com.pub_url) | |
|
37 | ||
|
38 | # gather the connection information into a dict | |
|
39 | ar = view.apply_async(lambda : com.info) | |
|
40 | peers = ar.get_dict() | |
|
41 | # this is a dict, keyed by engine ID, of the connection info for the EngineCommunicators | |
|
42 | ||
|
43 | # connect the engines to each other: | |
|
44 | def connect(com, peers, tree, pub_url, root_id): | |
|
45 | """this function will be called on the engines""" | |
|
46 | com.connect(peers, tree, pub_url, root_id) | |
|
47 | ||
|
48 | view.apply_sync(connect, Reference('com'), peers, btree, pub_url, root_id) | |
|
49 | ||
|
50 | # functions that can be used for reductions | |
|
51 | # max and min builtins can be used as well | |
|
52 | def add(a,b): | |
|
53 | """cumulative sum reduction""" | |
|
54 | return a+b | |
|
55 | ||
|
56 | def mul(a,b): | |
|
57 | """cumulative product reduction""" | |
|
58 | return a*b | |
|
59 | ||
|
60 | view['add'] = add | |
|
61 | view['mul'] = mul | |
|
62 | ||
|
63 | # scatter some data | |
|
64 | data = range(1000) | |
|
65 | view.scatter('data', data) | |
|
66 | ||
|
67 | # perform cumulative sum via allreduce | |
|
68 | view.execute("data_sum = com.allreduce(add, data, flat=False)") | |
|
69 | print "allreduce sum of data on all engines:", view['data_sum'] | |
|
70 | ||
|
71 | # perform cumulative sum *without* final broadcast | |
|
72 | # when not broadcasting with allreduce, the final result resides on the root node: | |
|
73 | view.execute("ids_sum = com.reduce(add, id, flat=True)") | |
|
74 | print "reduce sum of engine ids (not broadcast):", root['ids_sum'] | |
|
75 | print "partial result on each engine:", view['ids_sum'] |
General Comments 0
You need to be logged in to leave comments.
Login now