##// END OF EJS Templates
add binary-tree engine interconnect example...
MinRK -
Show More
@@ -0,0 +1,218 b''
1 """
2 BinaryTree inter-engine communication class
3
4 use from bintree_script.py
5
6 Provides parallel [all]reduce functionality
7
8 """
9
10 import cPickle as pickle
11 import re
12 import socket
13 import uuid
14
15 import zmq
16
17 from IPython.parallel.util import disambiguate_url
18
19
20 #----------------------------------------------------------------------------
21 # bintree-related construction/printing helpers
22 #----------------------------------------------------------------------------
23
def bintree(ids, parent=None):
    """Construct a {child: parent} dict representation of a binary tree.

    The first element of `ids` becomes the root of this (sub)tree; the
    remaining elements are split in half, and each half is built recursively
    as a subtree hanging off that root.

    Parameters
    ----------
    ids : sequence
        Node identifiers. Must support `len()`, indexing and slicing.
    parent : object, optional
        Parent recorded for the root of this (sub)tree. The default, None,
        marks the root of the whole tree.

    Returns
    -------
    dict
        Maps every element of `ids` to its parent (the first element maps
        to `parent`). Empty when `ids` is empty.
    """
    parents = {}
    if len(ids) == 0:
        return parents

    root = ids[0]
    parents[root] = parent

    rest = ids[1:]
    # Floor division keeps the split point an int on Python 3 as well as 2.
    half = len(rest) // 2
    # When `rest` is empty both recursive calls return {} and nothing is added.
    parents.update(bintree(rest[:half], parent=root))
    parents.update(bintree(rest[half:], parent=root))
    return parents
42
def reverse_bintree(parents):
    """Construct a {parent: [children]} dict from a {child: parent} dict.

    The entry whose parent is None (the root) is handled specially: it is
    stored under the key None as the bare child id, not wrapped in a list.
    Every other parent maps to a list of its children, in the iteration
    order of `parents`.
    """
    children = {}
    # .items() rather than .iteritems() so this also runs on Python 3.
    for child, parent in parents.items():
        if parent is None:
            children[None] = child
        else:
            children.setdefault(parent, []).append(child)
    return children
55
def depth(n, tree):
    """Return the depth of element `n` in a {child: parent} tree.

    The depth is the number of parent links followed from `n` until an
    element whose parent is None is reached, so the root has depth 0.

    Raises
    ------
    KeyError
        If `n` or one of its ancestors is not a key of `tree`.
    ValueError
        If the parent links starting at `n` form a cycle.
    """
    d = 0
    parent = tree[n]
    while parent is not None:
        d += 1
        # A well-formed tree cannot be deeper than its node count, so going
        # past that means the parent links loop back on themselves.
        if d > len(tree):
            raise ValueError("cycle detected in tree at element %r" % (n,))
        parent = tree[parent]
    return d
64
def print_bintree(tree, indent=' '):
    """Print a {child: parent} tree, one element per line.

    Elements are printed in sorted key order, each prefixed with `indent`
    repeated once per level of depth (as computed by `depth`).
    """
    for n in sorted(tree):
        # Single-argument print(...) produces the same output on Python 2 and 3.
        print("%s%s" % (indent * depth(n, tree), n))
69
70 #----------------------------------------------------------------------------
71 # Communicator class for a binary-tree map
72 #----------------------------------------------------------------------------
73
# Matches a dotted-quad of decimal digit groups (octet ranges are not checked).
ip_pat = re.compile(r'^\d+\.\d+\.\d+\.\d+$')


def disambiguate_dns_url(url, location):
    """Accept either an IP address or a DNS name, and disambiguate with the IP.

    If `location` does not look like a dotted-quad IP address it is resolved
    with `socket.gethostbyname` first; the result is then handed to
    `disambiguate_url` together with `url`, and that call's return value is
    returned unchanged.
    """
    if not ip_pat.match(location):
        location = socket.gethostbyname(location)
    return disambiguate_url(url, location)
81
class BinaryTreeCommunicator(object):
    """Communicator providing [all]reduce over a binary tree of peers.

    Every instance owns a PULL socket (`downstream`), bound to a random port,
    on which it receives values, and a PUSH socket (`upstream`) that
    `connect` attaches to its parent. The root instance additionally binds a
    PUB socket (`pub`); every non-root instance holds a SUB socket (`sub`)
    that `connect` attaches to the root's PUB url, which is how the final
    result is broadcast in an allreduce.
    """

    # Class-level defaults so attributes exist even if __init__ did not
    # complete (relied upon by __del__).
    id = None
    pub = None
    sub = None
    downstream = None
    upstream = None
    pub_url = None
    tree_url = None
    _ctx = None

    def __init__(self, id, interface='tcp://*', root=False):
        """Create the sockets and bind the listening ones to random ports.

        Parameters
        ----------
        id : object
            Identifier of this node; looked up in the tree passed to `connect`.
        interface : str
            Interface url (without port) handed to `bind_to_random_port`.
        root : bool
            Whether this node is the root of the tree.
        """
        self.id = id
        self.root = root

        # create context and sockets
        self._ctx = zmq.Context()
        if root:
            self.pub = self._ctx.socket(zmq.PUB)
        else:
            self.sub = self._ctx.socket(zmq.SUB)
            # empty prefix: subscribe to every message
            self.sub.setsockopt(zmq.SUBSCRIBE, b'')
        self.downstream = self._ctx.socket(zmq.PULL)
        self.upstream = self._ctx.socket(zmq.PUSH)

        # bind to ports
        interface_f = interface + ":%i"
        if self.root:
            pub_port = self.pub.bind_to_random_port(interface)
            self.pub_url = interface_f % pub_port

        tree_port = self.downstream.bind_to_random_port(interface)
        self.tree_url = interface_f % tree_port
        self.downstream_poller = zmq.Poller()
        self.downstream_poller.register(self.downstream, zmq.POLLIN)

        # guess first public IP from socket
        self.location = socket.gethostbyname_ex(socket.gethostname())[-1][0]

    def __del__(self):
        """Close whichever sockets were created, then terminate the context."""
        # Any of these may still be None if __init__ failed part-way (or was
        # bypassed), so only close what actually exists instead of raising
        # AttributeError during finalisation.
        for sock in (self.downstream, self.upstream, self.pub, self.sub):
            if sock is not None:
                sock.close()
        if self._ctx is not None:
            self._ctx.term()

    @property
    def info(self):
        """return the connection info for this object's sockets.

        This is the 2-tuple `(tree_url, location)`.
        """
        return (self.tree_url, self.location)

    def connect(self, peers, btree, pub_url, root_id=0):
        """connect to peers.

        Parameters
        ----------
        peers : dict
            Keyed by node id; each value is a `(tree_url, location)` tuple as
            returned by the `info` property of that node's communicator.
        btree : dict
            `{child: parent}` mapping of node ids describing the tree.
        pub_url : str
            Url of the root's PUB socket.
        root_id : object
            Key of the root node in `peers`.

        Sets `self.nchildren` to the number of entries in `btree` whose parent
        is this node. The root makes no outgoing connections; every other node
        connects its SUB socket to the root and its PUSH socket to its parent.
        """
        # count the number of children we have
        # (list(...) because dict views have no .count() on Python 3)
        self.nchildren = list(btree.values()).count(self.id)

        if self.root:
            return  # root only binds

        root_location = peers[root_id][-1]
        self.sub.connect(disambiguate_dns_url(pub_url, root_location))

        parent = btree[self.id]

        tree_url, location = peers[parent]
        self.upstream.connect(disambiguate_dns_url(tree_url, location))

    def serialize(self, obj):
        """serialize objects.

        Must return list of sendable buffers.

        Can be extended for more efficient/noncopying serialization of numpy arrays, etc.
        """
        return [pickle.dumps(obj)]

    def unserialize(self, msg):
        """inverse of serialize"""
        # NOTE(security): pickle.loads executes arbitrary code from the
        # payload; this is only acceptable while every connected peer is
        # trusted.
        return pickle.loads(msg[0])

    def publish(self, value):
        """Broadcast `value` on the PUB socket (root only)."""
        assert self.root
        self.pub.send_multipart(self.serialize(value))

    def consume(self):
        """Receive and return one value from the SUB socket (non-root only)."""
        assert not self.root
        return self.unserialize(self.sub.recv_multipart())

    def send_upstream(self, value, flags=0):
        """Send `value` to the parent without blocking (non-root only)."""
        assert not self.root
        self.upstream.send_multipart(self.serialize(value), flags=flags|zmq.NOBLOCK)

    def recv_downstream(self, flags=0, timeout=2000.):
        """Receive and return one value from the `downstream` socket.

        `timeout` is passed straight to `Poller.poll` (pyzmq documents it as
        milliseconds). The receive itself is non-blocking, so if nothing
        arrived within the timeout it is expected to raise instead of hanging.
        """
        # wait for a message, so we won't block if there was a bug
        self.downstream_poller.poll(timeout)

        msg = self.downstream.recv_multipart(zmq.NOBLOCK|flags)
        return self.unserialize(msg)

    def reduce(self, f, value, flat=True, all=False):
        """parallel reduce on binary tree

        if flat:
            value is an entry in the sequence
        else:
            value is a list of entries in the sequence

        if all:
            broadcast final result to all nodes
        else:
            only root gets final result

        Returns the broadcast result when `all` is true; otherwise the value
        accumulated at this node (the complete result only on the root).
        """
        if not flat:
            # The builtin reduce is gone on Python 3; functools.reduce exists
            # on both. Aliased because this method is itself named reduce.
            from functools import reduce as fold
            value = fold(f, value)

        # fold in one partial result per child
        for _ in range(self.nchildren):
            value = f(value, self.recv_downstream())

        if not self.root:
            self.send_upstream(value)

        if all:
            if self.root:
                self.publish(value)
            else:
                value = self.consume()
        return value

    def allreduce(self, f, value, flat=True):
        """parallel reduce followed by broadcast of the result"""
        return self.reduce(f, value, flat=flat, all=True)
218
@@ -0,0 +1,75 b''
1 """
2 Script for setting up and using [all]reduce with a binary-tree engine interconnect.
3
4 usage: `python bintree_script.py`
5
6 """
7
8 from IPython.parallel import Client, Reference
9
10
# connect client and create views
rc = Client()
rc.block = True
ids = rc.ids

root_id = ids[0]
root = rc[root_id]

view = rc[:]

# run bintree.py script defining bintree functions, etc.
# (compile + exec instead of execfile, which only exists on Python 2)
with open('bintree.py') as script_file:
    exec(compile(script_file.read(), 'bintree.py', 'exec'))

# generate binary tree of parents
btree = bintree(ids)

print("setting up binary tree interconnect:")
print_bintree(btree)

# define the same helpers on every engine and tell each engine its own id
view.run('bintree.py')
view.scatter('id', ids, flatten=True)
view['root_id'] = root_id

# create the Communicator objects on the engines
view.execute('com = BinaryTreeCommunicator(id, root = id==root_id )')
pub_url = root.apply_sync(lambda : com.pub_url)

# gather the connection information into a dict
ar = view.apply_async(lambda : com.info)
peers = ar.get_dict()
# this is a dict, keyed by engine ID, of the connection info for the EngineCommunicators
42
# connect the engines to each other:
def connect(com, peers, tree, pub_url, root_id):
    """this function will be called on the engines

    It forwards its arguments to `com.connect`.
    """
    com.connect(peers, tree, pub_url, root_id)
47
# Reference('com') is passed in place of the engine-side `com` object
view.apply_sync(connect, Reference('com'), peers, btree, pub_url, root_id)
49
# functions that can be used for reductions
# max and min builtins can be used as well
def add(a, b):
    """cumulative sum reduction: return `a + b`"""
    return a + b
55
def mul(a, b):
    """cumulative product reduction: return `a * b`"""
    return a * b
59
# make the reduction functions available on the engines
view['add'] = add
view['mul'] = mul

# scatter some data
# (materialised as a list so the same sequence type is scattered on Python 3)
data = list(range(1000))
view.scatter('data', data)

# perform cumulative sum via allreduce
view.execute("data_sum = com.allreduce(add, data, flat=False)")
# "%s" formatting gives the same output as the Python 2 `print a, b` form
print("allreduce sum of data on all engines: %s" % (view['data_sum'],))

# perform cumulative sum *without* final broadcast
# when not broadcasting with allreduce, the final result resides on the root node:
view.execute("ids_sum = com.reduce(add, id, flat=True)")
print("reduce sum of engine ids (not broadcast): %s" % (root['ids_sum'],))
print("partial result on each engine: %s" % (view['ids_sum'],))
General Comments 0
You need to be logged in to leave comments. Login now