##// END OF EJS Templates
add message tracking to client, add/improve tests
MinRK -
Show More
@@ -34,13 +34,17 class AsyncResult(object):
34 34 """
35 35
36 36 msg_ids = None
37 _targets = None
38 _tracker = None
37 39
38 def __init__(self, client, msg_ids, fname='unknown'):
40 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None):
39 41 self._client = client
40 42 if isinstance(msg_ids, basestring):
41 43 msg_ids = [msg_ids]
42 44 self.msg_ids = msg_ids
43 45 self._fname=fname
46 self._targets = targets
47 self._tracker = tracker
44 48 self._ready = False
45 49 self._success = None
46 50 self._single_result = len(msg_ids) == 1
@@ -169,6 +173,19 class AsyncResult(object):
169 173
170 174 def __dict__(self):
171 175 return self.get_dict(0)
176
177 def abort(self):
178 """abort my tasks."""
179 assert not self.ready(), "Can't abort, I am already done!"
180 return self.client.abort(self.msg_ids, targets=self._targets, block=True)
181
182 @property
183 def sent(self):
184 """check whether my messages have been sent"""
185 if self._tracker is None:
186 return True
187 else:
188 return self._tracker.done
172 189
173 190 #-------------------------------------
174 191 # dict-access
@@ -356,6 +356,9 class Client(HasTraits):
356 356 'apply_reply' : self._handle_apply_reply}
357 357 self._connect(sshserver, ssh_kwargs)
358 358
359 def __del__(self):
360 """cleanup sockets, but _not_ context."""
361 self.close()
359 362
360 363 def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
361 364 if ipython_dir is None:
@@ -387,7 +390,8 class Client(HasTraits):
387 390 return
388 391 snames = filter(lambda n: n.endswith('socket'), dir(self))
389 392 for socket in map(lambda name: getattr(self, name), snames):
390 socket.close()
393 if isinstance(socket, zmq.Socket) and not socket.closed:
394 socket.close()
391 395 self._closed = True
392 396
393 397 def _update_engines(self, engines):
@@ -550,7 +554,6 class Client(HasTraits):
550 554 outstanding = self._outstanding_dict[uuid]
551 555
552 556 for msg_id in list(outstanding):
553 print msg_id
554 557 if msg_id in self.results:
555 558 # we already
556 559 continue
@@ -796,7 +799,7 class Client(HasTraits):
796 799 if msg['content']['status'] != 'ok':
797 800 error = self._unwrap_exception(msg['content'])
798 801 if error:
799 return error
802 raise error
800 803
801 804
802 805 @spinfirst
@@ -840,7 +843,7 class Client(HasTraits):
840 843 if msg['content']['status'] != 'ok':
841 844 error = self._unwrap_exception(msg['content'])
842 845 if error:
843 return error
846 raise error
844 847
845 848 @spinfirst
846 849 @defaultblock
@@ -945,7 +948,8 class Client(HasTraits):
945 948 @defaultblock
946 949 def apply(self, f, args=None, kwargs=None, bound=True, block=None,
947 950 targets=None, balanced=None,
948 after=None, follow=None, timeout=None):
951 after=None, follow=None, timeout=None,
952 track=False):
949 953 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
950 954
951 955 This is the central execution command for the client.
@@ -1003,6 +1007,9 class Client(HasTraits):
1003 1007 Specify an amount of time (in seconds) for the scheduler to
1004 1008 wait for dependencies to be met before failing with a
1005 1009 DependencyTimeout.
1010 track : bool
1011 whether to track non-copying sends.
1012 [default False]
1006 1013
1007 1014 after,follow,timeout only used if `balanced=True`.
1008 1015
@@ -1044,7 +1051,7 class Client(HasTraits):
1044 1051 if not isinstance(kwargs, dict):
1045 1052 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1046 1053
1047 options = dict(bound=bound, block=block, targets=targets)
1054 options = dict(bound=bound, block=block, targets=targets, track=track)
1048 1055
1049 1056 if balanced:
1050 1057 return self._apply_balanced(f, args, kwargs, timeout=timeout,
@@ -1057,7 +1064,7 class Client(HasTraits):
1057 1064 return self._apply_direct(f, args, kwargs, **options)
1058 1065
1059 1066 def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
1060 after=None, follow=None, timeout=None):
1067 after=None, follow=None, timeout=None, track=None):
1061 1068 """call f(*args, **kwargs) remotely in a load-balanced manner.
1062 1069
1063 1070 This is a private method, see `apply` for details.
@@ -1065,7 +1072,7 class Client(HasTraits):
1065 1072 """
1066 1073
1067 1074 loc = locals()
1068 for name in ('bound', 'block'):
1075 for name in ('bound', 'block', 'track'):
1069 1076 assert loc[name] is not None, "kwarg %r must be specified!"%name
1070 1077
1071 1078 if self._task_socket is None:
@@ -1101,13 +1108,13 class Client(HasTraits):
1101 1108 content = dict(bound=bound)
1102 1109
1103 1110 msg = self.session.send(self._task_socket, "apply_request",
1104 content=content, buffers=bufs, subheader=subheader)
1111 content=content, buffers=bufs, subheader=subheader, track=track)
1105 1112 msg_id = msg['msg_id']
1106 1113 self.outstanding.add(msg_id)
1107 1114 self.history.append(msg_id)
1108 1115 self.metadata[msg_id]['submitted'] = datetime.now()
1109
1110 ar = AsyncResult(self, [msg_id], fname=f.__name__)
1116 tracker = None if track is False else msg['tracker']
1117 ar = AsyncResult(self, [msg_id], fname=f.__name__, targets=targets, tracker=tracker)
1111 1118 if block:
1112 1119 try:
1113 1120 return ar.get()
@@ -1116,7 +1123,8 class Client(HasTraits):
1116 1123 else:
1117 1124 return ar
1118 1125
1119 def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None):
1126 def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None,
1127 track=None):
1120 1128 """Then underlying method for applying functions to specific engines
1121 1129 via the MUX queue.
1122 1130
@@ -1124,7 +1132,7 class Client(HasTraits):
1124 1132 Not to be called directly!
1125 1133 """
1126 1134 loc = locals()
1127 for name in ('bound', 'block', 'targets'):
1135 for name in ('bound', 'block', 'targets', 'track'):
1128 1136 assert loc[name] is not None, "kwarg %r must be specified!"%name
1129 1137
1130 1138 idents,targets = self._build_targets(targets)
@@ -1134,15 +1142,22 class Client(HasTraits):
1134 1142 bufs = util.pack_apply_message(f,args,kwargs)
1135 1143
1136 1144 msg_ids = []
1145 trackers = []
1137 1146 for ident in idents:
1138 1147 msg = self.session.send(self._mux_socket, "apply_request",
1139 content=content, buffers=bufs, ident=ident, subheader=subheader)
1148 content=content, buffers=bufs, ident=ident, subheader=subheader,
1149 track=track)
1150 if track:
1151 trackers.append(msg['tracker'])
1140 1152 msg_id = msg['msg_id']
1141 1153 self.outstanding.add(msg_id)
1142 1154 self._outstanding_dict[ident].add(msg_id)
1143 1155 self.history.append(msg_id)
1144 1156 msg_ids.append(msg_id)
1145 ar = AsyncResult(self, msg_ids, fname=f.__name__)
1157
1158 tracker = None if track is False else zmq.MessageTracker(*trackers)
1159 ar = AsyncResult(self, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
1160
1146 1161 if block:
1147 1162 try:
1148 1163 return ar.get()
@@ -1230,11 +1245,11 class Client(HasTraits):
1230 1245 #--------------------------------------------------------------------------
1231 1246
1232 1247 @defaultblock
1233 def push(self, ns, targets='all', block=None):
1248 def push(self, ns, targets='all', block=None, track=False):
1234 1249 """Push the contents of `ns` into the namespace on `target`"""
1235 1250 if not isinstance(ns, dict):
1236 1251 raise TypeError("Must be a dict, not %s"%type(ns))
1237 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False)
1252 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False, track=track)
1238 1253 if not block:
1239 1254 return result
1240 1255
@@ -1251,7 +1266,7 class Client(HasTraits):
1251 1266 return result
1252 1267
1253 1268 @defaultblock
1254 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
1269 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None, track=False):
1255 1270 """
1256 1271 Partition a Python sequence and send the partitions to a set of engines.
1257 1272 """
@@ -1259,16 +1274,25 class Client(HasTraits):
1259 1274 mapObject = Map.dists[dist]()
1260 1275 nparts = len(targets)
1261 1276 msg_ids = []
1277 trackers = []
1262 1278 for index, engineid in enumerate(targets):
1263 1279 partition = mapObject.getPartition(seq, index, nparts)
1264 1280 if flatten and len(partition) == 1:
1265 r = self.push({key: partition[0]}, targets=engineid, block=False)
1281 r = self.push({key: partition[0]}, targets=engineid, block=False, track=track)
1266 1282 else:
1267 r = self.push({key: partition}, targets=engineid, block=False)
1283 r = self.push({key: partition}, targets=engineid, block=False, track=track)
1268 1284 msg_ids.extend(r.msg_ids)
1269 r = AsyncResult(self, msg_ids, fname='scatter')
1285 if track:
1286 trackers.append(r._tracker)
1287
1288 if track:
1289 tracker = zmq.MessageTracker(*trackers)
1290 else:
1291 tracker = None
1292
1293 r = AsyncResult(self, msg_ids, fname='scatter', targets=targets, tracker=tracker)
1270 1294 if block:
1271 r.get()
1295 r.wait()
1272 1296 else:
1273 1297 return r
1274 1298
@@ -179,7 +179,7 class StreamSession(object):
179 179 return header.get('key', None) == self.key
180 180
181 181
182 def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
182 def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None, track=False):
183 183 """Build and send a message via stream or socket.
184 184
185 185 Parameters
@@ -191,13 +191,34 class StreamSession(object):
191 191 Normally, msg_or_type will be a msg_type unless a message is being sent more
192 192 than once.
193 193
194 content : dict or None
195 the content of the message (ignored if msg_or_type is a message)
196 buffers : list or None
197 the already-serialized buffers to be appended to the message
198 parent : Message or dict or None
199 the parent or parent header describing the parent of this message
200 subheader : dict or None
201 extra header keys for this message's header
202 ident : bytes or list of bytes
203 the zmq.IDENTITY routing path
204 track : bool
205 whether to track. Only for use with Sockets, because ZMQStream objects cannot track messages.
206
194 207 Returns
195 208 -------
196 (msg,sent) : tuple
197 msg : Message
198 the nice wrapped dict-like object containing the headers
209 msg : message dict
210 the constructed message
211 (msg,tracker) : (message dict, MessageTracker)
212 if track=True, then a 2-tuple will be returned, the first element being the constructed
213 message, and the second being the MessageTracker
199 214
200 215 """
216
217 if not isinstance(stream, (zmq.Socket, ZMQStream)):
218 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
219 elif track and isinstance(stream, ZMQStream):
220 raise TypeError("ZMQStream cannot track messages")
221
201 222 if isinstance(msg_or_type, (Message, dict)):
202 223 # we got a Message, not a msg_type
203 224 # don't build a new Message
@@ -205,6 +226,7 class StreamSession(object):
205 226 content = msg['content']
206 227 else:
207 228 msg = self.msg(msg_or_type, content, parent, subheader)
229
208 230 buffers = [] if buffers is None else buffers
209 231 to_send = []
210 232 if isinstance(ident, list):
@@ -222,7 +244,7 class StreamSession(object):
222 244 content = self.none
223 245 elif isinstance(content, dict):
224 246 content = self.pack(content)
225 elif isinstance(content, str):
247 elif isinstance(content, bytes):
226 248 # content is already packed, as in a relayed message
227 249 pass
228 250 else:
@@ -231,16 +253,29 class StreamSession(object):
231 253 flag = 0
232 254 if buffers:
233 255 flag = zmq.SNDMORE
234 stream.send_multipart(to_send, flag, copy=False)
256 _track = False
257 else:
258 _track=track
259 if track:
260 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
261 else:
262 tracker = stream.send_multipart(to_send, flag, copy=False)
235 263 for b in buffers[:-1]:
236 264 stream.send(b, flag, copy=False)
237 265 if buffers:
238 stream.send(buffers[-1], copy=False)
266 if track:
267 tracker = stream.send(buffers[-1], copy=False, track=track)
268 else:
269 tracker = stream.send(buffers[-1], copy=False)
270
239 271 # omsg = Message(msg)
240 272 if self.debug:
241 273 pprint.pprint(msg)
242 274 pprint.pprint(to_send)
243 275 pprint.pprint(buffers)
276
277 msg['tracker'] = tracker
278
244 279 return msg
245 280
246 281 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
@@ -250,7 +285,7 class StreamSession(object):
250 285 ----------
251 286 msg : list of sendable buffers"""
252 287 to_send = []
253 if isinstance(ident, str):
288 if isinstance(ident, bytes):
254 289 ident = [ident]
255 290 if ident is not None:
256 291 to_send.extend(ident)
@@ -1,24 +1,26
1 1 """toplevel setup/teardown for parallel tests."""
2 2
3 import tempfile
3 4 import time
4 from subprocess import Popen, PIPE
5 from subprocess import Popen, PIPE, STDOUT
5 6
6 7 from IPython.zmq.parallel.ipcluster import launch_process
7 8 from IPython.zmq.parallel.entry_point import select_random_ports
8 9
9 10 processes = []
11 blackhole = tempfile.TemporaryFile()
10 12
11 13 # nose setup/teardown
12 14
13 15 def setup():
14 cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=PIPE, stdin=PIPE, stderr=PIPE)
16 cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=blackhole, stderr=STDOUT)
15 17 processes.append(cp)
16 18 time.sleep(.5)
17 19 add_engine()
18 time.sleep(3)
20 time.sleep(2)
19 21
20 22 def add_engine(profile='iptest'):
21 ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=PIPE, stdin=PIPE, stderr=PIPE)
23 ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=blackhole, stderr=STDOUT)
22 24 # ep.start()
23 25 processes.append(ep)
24 26 return ep
@@ -88,7 +88,9 class ClusterTestCase(BaseZMQTestCase):
88 88 self.base_engine_count=len(self.client.ids)
89 89 self.engines=[]
90 90
91 # def tearDown(self):
91 def tearDown(self):
92 self.client.close()
93 BaseZMQTestCase.tearDown(self)
92 94 # [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ]
93 95 # [ e.wait() for e in self.engines ]
94 96 # while len(self.client.ids) > self.base_engine_count:
@@ -2,6 +2,7 import time
2 2 from tempfile import mktemp
3 3
4 4 import nose.tools as nt
5 import zmq
5 6
6 7 from IPython.zmq.parallel import client as clientmod
7 8 from IPython.zmq.parallel import error
@@ -18,10 +19,9 class TestClient(ClusterTestCase):
18 19 self.assertEquals(len(self.client.ids), n+3)
19 20 self.assertTrue
20 21
21 def test_segfault(self):
22 """test graceful handling of engine death"""
22 def test_segfault_task(self):
23 """test graceful handling of engine death (balanced)"""
23 24 self.add_engines(1)
24 eid = self.client.ids[-1]
25 25 ar = self.client.apply(segfault, block=False)
26 26 self.assertRaisesRemote(error.EngineError, ar.get)
27 27 eid = ar.engine_id
@@ -29,6 +29,17 class TestClient(ClusterTestCase):
29 29 time.sleep(.01)
30 30 self.client.spin()
31 31
32 def test_segfault_mux(self):
33 """test graceful handling of engine death (direct)"""
34 self.add_engines(1)
35 eid = self.client.ids[-1]
36 ar = self.client[eid].apply_async(segfault)
37 self.assertRaisesRemote(error.EngineError, ar.get)
38 eid = ar.engine_id
39 while eid in self.client.ids:
40 time.sleep(.01)
41 self.client.spin()
42
32 43 def test_view_indexing(self):
33 44 """test index access for views"""
34 45 self.add_engines(2)
@@ -91,13 +102,14 class TestClient(ClusterTestCase):
91 102 def test_push_pull(self):
92 103 """test pushing and pulling"""
93 104 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
105 t = self.client.ids[-1]
94 106 self.add_engines(2)
95 107 push = self.client.push
96 108 pull = self.client.pull
97 109 self.client.block=True
98 110 nengines = len(self.client)
99 push({'data':data}, targets=0)
100 d = pull('data', targets=0)
111 push({'data':data}, targets=t)
112 d = pull('data', targets=t)
101 113 self.assertEquals(d, data)
102 114 push({'data':data})
103 115 d = pull('data')
@@ -119,15 +131,16 class TestClient(ClusterTestCase):
119 131 return 2.0*x
120 132
121 133 self.add_engines(4)
134 t = self.client.ids[-1]
122 135 self.client.block=True
123 136 push = self.client.push
124 137 pull = self.client.pull
125 138 execute = self.client.execute
126 push({'testf':testf}, targets=0)
127 r = pull('testf', targets=0)
139 push({'testf':testf}, targets=t)
140 r = pull('testf', targets=t)
128 141 self.assertEqual(r(1.0), testf(1.0))
129 execute('r = testf(10)', targets=0)
130 r = pull('r', targets=0)
142 execute('r = testf(10)', targets=t)
143 r = pull('r', targets=t)
131 144 self.assertEquals(r, testf(10))
132 145 ar = push({'testf':testf}, block=False)
133 146 ar.get()
@@ -135,8 +148,8 class TestClient(ClusterTestCase):
135 148 rlist = ar.get()
136 149 for r in rlist:
137 150 self.assertEqual(r(1.0), testf(1.0))
138 execute("def g(x): return x*x", targets=0)
139 r = pull(('testf','g'),targets=0)
151 execute("def g(x): return x*x", targets=t)
152 r = pull(('testf','g'),targets=t)
140 153 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
141 154
142 155 def test_push_function_globals(self):
@@ -173,7 +186,7 class TestClient(ClusterTestCase):
173 186 ids.remove(ids[-1])
174 187 self.assertNotEquals(ids, self.client._ids)
175 188
176 def test_arun_newline(self):
189 def test_run_newline(self):
177 190 """test that run appends newline to files"""
178 191 tmpfile = mktemp()
179 192 with open(tmpfile, 'w') as f:
@@ -184,4 +197,56 class TestClient(ClusterTestCase):
184 197 v.run(tmpfile, block=True)
185 198 self.assertEquals(v.apply_sync_bound(lambda : g()), 5)
186 199
187 No newline at end of file
200 def test_apply_tracked(self):
201 """test tracking for apply"""
202 # self.add_engines(1)
203 t = self.client.ids[-1]
204 self.client.block=False
205 def echo(n=1024*1024, **kwargs):
206 return self.client.apply(lambda x: x, args=('x'*n,), targets=t, **kwargs)
207 ar = echo(1)
208 self.assertTrue(ar._tracker is None)
209 self.assertTrue(ar.sent)
210 ar = echo(track=True)
211 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
212 self.assertEquals(ar.sent, ar._tracker.done)
213 ar._tracker.wait()
214 self.assertTrue(ar.sent)
215
216 def test_push_tracked(self):
217 t = self.client.ids[-1]
218 ns = dict(x='x'*1024*1024)
219 ar = self.client.push(ns, targets=t, block=False)
220 self.assertTrue(ar._tracker is None)
221 self.assertTrue(ar.sent)
222
223 ar = self.client.push(ns, targets=t, block=False, track=True)
224 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
225 self.assertEquals(ar.sent, ar._tracker.done)
226 ar._tracker.wait()
227 self.assertTrue(ar.sent)
228 ar.get()
229
230 def test_scatter_tracked(self):
231 t = self.client.ids
232 x='x'*1024*1024
233 ar = self.client.scatter('x', x, targets=t, block=False)
234 self.assertTrue(ar._tracker is None)
235 self.assertTrue(ar.sent)
236
237 ar = self.client.scatter('x', x, targets=t, block=False, track=True)
238 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
239 self.assertEquals(ar.sent, ar._tracker.done)
240 ar._tracker.wait()
241 self.assertTrue(ar.sent)
242 ar.get()
243
244 def test_remote_reference(self):
245 v = self.client[-1]
246 v['a'] = 123
247 ra = clientmod.Reference('a')
248 b = v.apply_sync_bound(lambda x: x, ra)
249 self.assertEquals(b, 123)
250 self.assertRaisesRemote(NameError, v.apply_sync, lambda x: x, ra)
251
252
@@ -4,7 +4,7 import uuid
4 4 import zmq
5 5
6 6 from zmq.tests import BaseZMQTestCase
7
7 from zmq.eventloop.zmqstream import ZMQStream
8 8 # from IPython.zmq.tests import SessionTestCase
9 9 from IPython.zmq.parallel import streamsession as ss
10 10
@@ -31,7 +31,7 class TestSession(SessionTestCase):
31 31
32 32 def test_args(self):
33 33 """initialization arguments for StreamSession"""
34 s = ss.StreamSession()
34 s = self.session
35 35 self.assertTrue(s.pack is ss.default_packer)
36 36 self.assertTrue(s.unpack is ss.default_unpacker)
37 37 self.assertEquals(s.username, os.environ.get('USER', 'username'))
@@ -46,7 +46,24 class TestSession(SessionTestCase):
46 46 self.assertEquals(s.session, u)
47 47 self.assertEquals(s.username, 'carrot')
48 48
49
49 def test_tracking(self):
50 """test tracking messages"""
51 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
52 s = self.session
53 stream = ZMQStream(a)
54 msg = s.send(a, 'hello', track=False)
55 self.assertTrue(msg['tracker'] is None)
56 msg = s.send(a, 'hello', track=True)
57 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
58 M = zmq.Message(b'hi there', track=True)
59 msg = s.send(a, 'hello', buffers=[M], track=True)
60 t = msg['tracker']
61 self.assertTrue(isinstance(t, zmq.MessageTracker))
62 self.assertRaises(zmq.NotDone, t.wait, .1)
63 del M
64 t.wait(1) # this will raise
65
66
50 67 # def test_rekey(self):
51 68 # """rekeying dict around json str keys"""
52 69 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
General Comments 0
You need to be logged in to leave comments. Login now