Show More
@@ -0,0 +1,218 b'' | |||||
|
1 | """ | |||
|
2 | BinaryTree inter-engine communication class | |||
|
3 | ||||
|
4 | use from bintree_script.py | |||
|
5 | ||||
|
6 | Provides parallel [all]reduce functionality | |||
|
7 | ||||
|
8 | """ | |||
|
9 | ||||
|
10 | import cPickle as pickle | |||
|
11 | import re | |||
|
12 | import socket | |||
|
13 | import uuid | |||
|
14 | ||||
|
15 | import zmq | |||
|
16 | ||||
|
17 | from IPython.parallel.util import disambiguate_url | |||
|
18 | ||||
|
19 | ||||
|
20 | #---------------------------------------------------------------------------- | |||
|
21 | # bintree-related construction/printing helpers | |||
|
22 | #---------------------------------------------------------------------------- | |||
|
23 | ||||
|
def bintree(ids, parent=None):
    """Construct a {child: parent} dict representation of a binary tree.

    The first element of `ids` becomes the root of this (sub)tree; the
    remaining ids are split in half to build the left and right subtrees
    recursively.

    Parameters
    ----------
    ids : list
        Engine ids (or any hashable labels) to arrange into a tree.
    parent : hashable or None
        Parent of this subtree's root; None marks the overall root.

    Returns
    -------
    dict
        Mapping of each id to its parent id (the root maps to `parent`).
    """
    parents = {}
    if not ids:
        return parents
    root = ids[0]
    parents[root] = parent
    rest = ids[1:]
    if rest:
        # floor division: `/` on two ints is Python-2-only as a slice
        # index (it yields a float on Python 3); `//` is identical on both
        half = len(rest) // 2
        parents.update(bintree(rest[:half], parent=root))
        parents.update(bintree(rest[half:], parent=root))
    return parents
|
42 | ||||
|
def reverse_bintree(parents):
    """Construct a {parent: [children]} dict from a {child: parent} dict.

    The root (the one node whose parent is None) is stored under the key
    None as a bare id, not a list: ``children[None]`` is the root id.

    Parameters
    ----------
    parents : dict
        {child: parent} mapping, as produced by `bintree`.

    Returns
    -------
    dict
        {parent: [children]} mapping, plus the {None: root} entry.
    """
    children = {}
    # .items() (rather than the Python-2-only .iteritems()) behaves
    # identically here on both Python 2 and Python 3
    for child, parent in parents.items():
        if parent is None:
            children[None] = child
            continue
        # setdefault creates the child list on first sight of a parent
        children.setdefault(parent, []).append(child)

    return children
|
55 | ||||
|
def depth(n, tree):
    """Return the depth of node `n` in `tree` (a {child: parent} dict).

    The root has depth 0; each parent link followed adds one level.
    """
    # walk parent links upward until we reach the root (parent is None)
    level = 0
    node = tree[n]
    while node is not None:
        level += 1
        node = tree[node]
    return level
|
64 | ||||
|
def print_bintree(tree, indent=' '):
    """Print a binary tree, one node per line, indented by its depth.

    Parameters
    ----------
    tree : dict
        {child: parent} mapping, as produced by `bintree`.
    indent : str
        String repeated once per level of depth.
    """
    for n in sorted(tree.keys()):
        # single-argument print(...) is valid and behaves identically on
        # both Python 2 (parenthesized print statement) and Python 3
        print("%s%s" % (indent * depth(n, tree), n))
|
69 | ||||
|
70 | #---------------------------------------------------------------------------- | |||
|
71 | # Communicator class for a binary-tree map | |||
|
72 | #---------------------------------------------------------------------------- | |||
|
73 | ||||
|
# matches dotted-quad IPv4 addresses, e.g. '10.0.0.1'
ip_pat = re.compile(r'^\d+\.\d+\.\d+\.\d+$')

def disambiguate_dns_url(url, location):
    """accept either IP address or dns name, and return IP"""
    looks_like_ip = ip_pat.match(location) is not None
    if not looks_like_ip:
        # resolve the DNS name to an IP address
        location = socket.gethostbyname(location)
    return disambiguate_url(url, location)
|
81 | ||||
|
class BinaryTreeCommunicator(object):
    """Inter-engine communicator for a binary-tree interconnect.

    The root engine binds a PUB socket for broadcasting final results;
    every non-root engine connects a SUB socket to it.  Each engine also
    binds a PULL socket (`downstream`) to receive partial results from
    its children, and non-root engines connect a PUSH socket (`upstream`)
    to their parent's PULL socket.  Together these support tree-based
    [all]reduce across engines.
    """

    # engine id of this node in the tree
    id = None
    # PUB socket (root only): broadcasts the final result
    pub = None
    # SUB socket (non-root): receives the broadcast final result
    sub = None
    # PULL socket: receives partial results from this node's children
    downstream = None
    # PUSH socket (non-root): sends this node's partial result to its parent
    upstream = None
    # url of the root's PUB socket (set on the root only)
    pub_url = None
    # url of this node's PULL socket, for children to connect to
    tree_url = None

    def __init__(self, id, interface='tcp://*', root=False):
        """Create sockets and bind the tree (and, if root, pub) ports.

        Parameters
        ----------
        id : engine id of this node.
        interface : zmq interface to bind on, e.g. 'tcp://*'.
        root : whether this node is the root of the tree.
        """
        self.id = id
        self.root = root

        # create context and sockets
        self._ctx = zmq.Context()
        if root:
            self.pub = self._ctx.socket(zmq.PUB)
        else:
            self.sub = self._ctx.socket(zmq.SUB)
            # empty subscription: receive every broadcast message
            self.sub.setsockopt(zmq.SUBSCRIBE, b'')
        self.downstream = self._ctx.socket(zmq.PULL)
        self.upstream = self._ctx.socket(zmq.PUSH)

        # bind to ports
        interface_f = interface + ":%i"
        if self.root:
            pub_port = self.pub.bind_to_random_port(interface)
            self.pub_url = interface_f % pub_port

        tree_port = self.downstream.bind_to_random_port(interface)
        self.tree_url = interface_f % tree_port
        # poller lets recv_downstream wait with a timeout instead of
        # blocking forever if a child never sends
        self.downstream_poller = zmq.Poller()
        self.downstream_poller.register(self.downstream, zmq.POLLIN)

        # guess first public IP from socket
        self.location = socket.gethostbyname_ex(socket.gethostname())[-1][0]

    def __del__(self):
        # close all sockets and terminate the context on garbage collection
        self.downstream.close()
        self.upstream.close()
        if self.root:
            self.pub.close()
        else:
            self.sub.close()
        self._ctx.term()

    @property
    def info(self):
        """return the connection info for this object's sockets."""
        return (self.tree_url, self.location)

    def connect(self, peers, btree, pub_url, root_id=0):
        """Connect this node's sockets to its peers.

        Parameters
        ----------
        peers : dict of {id : (tree_url, location)}
            connection info for every engine, as returned by `info`.
        btree : dict of {child : parent}
            the tree layout, as built by `bintree`.
        pub_url : str
            url of the root's PUB socket.
        root_id :
            id of the root engine (default 0).
        """

        # count the number of children we have
        # NOTE(review): dict.values().count(...) assumes Python 2, where
        # values() returns a list — a Python 3 view has no .count
        self.nchildren = btree.values().count(self.id)

        if self.root:
            return # root only binds

        # subscribe to the root's broadcast socket
        root_location = peers[root_id][-1]
        self.sub.connect(disambiguate_dns_url(pub_url, root_location))

        parent = btree[self.id]

        # push our partial results to our parent's PULL socket
        tree_url, location = peers[parent]
        self.upstream.connect(disambiguate_dns_url(tree_url, location))

    def serialize(self, obj):
        """serialize objects.

        Must return list of sendable buffers.

        Can be extended for more efficient/noncopying serialization of numpy arrays, etc.
        """
        return [pickle.dumps(obj)]

    def unserialize(self, msg):
        """inverse of serialize"""
        return pickle.loads(msg[0])

    def publish(self, value):
        """Broadcast `value` to the non-root nodes (root only)."""
        assert self.root
        self.pub.send_multipart(self.serialize(value))

    def consume(self):
        """Receive a value broadcast by the root (non-root only)."""
        assert not self.root
        return self.unserialize(self.sub.recv_multipart())

    def send_upstream(self, value, flags=0):
        """Send a partial result to this node's parent (non-root only)."""
        assert not self.root
        # NOBLOCK so a send that would block raises instead of hanging
        self.upstream.send_multipart(self.serialize(value), flags=flags|zmq.NOBLOCK)

    def recv_downstream(self, flags=0, timeout=2000.):
        """Receive one partial result from a child.

        Polls for up to `timeout` milliseconds; if nothing arrives, the
        NOBLOCK recv below raises rather than blocking forever.
        """
        # wait for a message, so we won't block if there was a bug
        self.downstream_poller.poll(timeout)

        msg = self.downstream.recv_multipart(zmq.NOBLOCK|flags)
        return self.unserialize(msg)

    def reduce(self, f, value, flat=True, all=False):
        """parallel reduce on binary tree

        if flat:
            value is an entry in the sequence
        else:
            value is a list of entries in the sequence

        if all:
            broadcast final result to all nodes
        else:
            only root gets final result
        """
        if not flat:
            # fold the local list down to a single entry first
            # (Python 2 builtin `reduce`)
            value = reduce(f, value)

        # fold in each child's partial result
        for i in range(self.nchildren):
            value = f(value, self.recv_downstream())

        # pass our partial result up toward the root
        if not self.root:
            self.send_upstream(value)

        if all:
            # broadcast the final result back down from the root
            if self.root:
                self.publish(value)
            else:
                value = self.consume()
        return value

    def allreduce(self, f, value, flat=True):
        """parallel reduce followed by broadcast of the result"""
        return self.reduce(f, value, flat=flat, all=True)
|
218 |
@@ -0,0 +1,75 b'' | |||||
|
1 | """ | |||
|
2 | Script for setting up and using [all]reduce with a binary-tree engine interconnect. | |||
|
3 | ||||
|
4 | usage: `python bintree_script.py` | |||
|
5 | ||||
|
6 | """ | |||
|
7 | ||||
|
8 | from IPython.parallel import Client, Reference | |||
|
9 | ||||
|
10 | ||||
|
11 | # connect client and create views | |||
|
# connect client and create views
rc = Client()
# block=True makes view operations synchronous by default
rc.block=True
ids = rc.ids

# the first engine acts as the root of the tree
root_id = ids[0]
root = rc[root_id]

# view spanning all engines
view = rc[:]

# run bintree.py script defining bintree functions, etc.
# (executed locally, so the client itself can build and print the tree;
# execfile is Python-2-only)
execfile('bintree.py')

# generate binary tree of parents
btree = bintree(ids)

print "setting up binary tree interconnect:"
print_bintree(btree)

# run bintree.py on the engines too, then give each engine its own id
# and the root's id
view.run('bintree.py')
view.scatter('id', ids, flatten=True)
view['root_id'] = root_id

# create the Communicator objects on the engines
view.execute('com = BinaryTreeCommunicator(id, root = id==root_id )')
# only the root binds a PUB socket; fetch its url for the other engines
# (the lambda runs remotely, where `com` is defined)
pub_url = root.apply_sync(lambda : com.pub_url)

# gather the connection information into a dict
ar = view.apply_async(lambda : com.info)
peers = ar.get_dict()
# this is a dict, keyed by engine ID, of the connection info for the EngineCommunicators

# connect the engines to each other:
def connect(com, peers, tree, pub_url, root_id):
    """this function will be called on the engines"""
    com.connect(peers, tree, pub_url, root_id)

# Reference('com') resolves to each engine's own communicator object
view.apply_sync(connect, Reference('com'), peers, btree, pub_url, root_id)

# functions that can be used for reductions
# max and min builtins can be used as well
def add(a,b):
    """cumulative sum reduction"""
    return a+b

def mul(a,b):
    """cumulative product reduction"""
    return a*b

# push the reduction functions to the engines by name
view['add'] = add
view['mul'] = mul

# scatter some data
data = range(1000)
view.scatter('data', data)

# perform cumulative sum via allreduce
# flat=False because each engine holds a *list* of entries after scatter
view.execute("data_sum = com.allreduce(add, data, flat=False)")
print "allreduce sum of data on all engines:", view['data_sum']

# perform cumulative sum *without* final broadcast
# when not broadcasting with allreduce, the final result resides on the root node:
view.execute("ids_sum = com.reduce(add, id, flat=True)")
print "reduce sum of engine ids (not broadcast):", root['ids_sum']
print "partial result on each engine:", view['ids_sum']
General Comments 0
You need to be logged in to leave comments.
Login now