##// END OF EJS Templates
add message tracking to client, add/improve tests
MinRK -
Show More
@@ -34,13 +34,17 b' class AsyncResult(object):'
34 """
34 """
35
35
36 msg_ids = None
36 msg_ids = None
37 _targets = None
38 _tracker = None
37
39
38 def __init__(self, client, msg_ids, fname='unknown'):
40 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None):
39 self._client = client
41 self._client = client
40 if isinstance(msg_ids, basestring):
42 if isinstance(msg_ids, basestring):
41 msg_ids = [msg_ids]
43 msg_ids = [msg_ids]
42 self.msg_ids = msg_ids
44 self.msg_ids = msg_ids
43 self._fname=fname
45 self._fname=fname
46 self._targets = targets
47 self._tracker = tracker
44 self._ready = False
48 self._ready = False
45 self._success = None
49 self._success = None
46 self._single_result = len(msg_ids) == 1
50 self._single_result = len(msg_ids) == 1
@@ -169,6 +173,19 b' class AsyncResult(object):'
169
173
170 def __dict__(self):
174 def __dict__(self):
171 return self.get_dict(0)
175 return self.get_dict(0)
176
177 def abort(self):
178 """abort my tasks."""
179 assert not self.ready(), "Can't abort, I am already done!"
180 return self.client.abort(self.msg_ids, targets=self._targets, block=True)
181
182 @property
183 def sent(self):
184 """check whether my messages have been sent"""
185 if self._tracker is None:
186 return True
187 else:
188 return self._tracker.done
172
189
173 #-------------------------------------
190 #-------------------------------------
174 # dict-access
191 # dict-access
@@ -356,6 +356,9 b' class Client(HasTraits):'
356 'apply_reply' : self._handle_apply_reply}
356 'apply_reply' : self._handle_apply_reply}
357 self._connect(sshserver, ssh_kwargs)
357 self._connect(sshserver, ssh_kwargs)
358
358
359 def __del__(self):
360 """cleanup sockets, but _not_ context."""
361 self.close()
359
362
360 def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
363 def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
361 if ipython_dir is None:
364 if ipython_dir is None:
@@ -387,7 +390,8 b' class Client(HasTraits):'
387 return
390 return
388 snames = filter(lambda n: n.endswith('socket'), dir(self))
391 snames = filter(lambda n: n.endswith('socket'), dir(self))
389 for socket in map(lambda name: getattr(self, name), snames):
392 for socket in map(lambda name: getattr(self, name), snames):
390 socket.close()
393 if isinstance(socket, zmq.Socket) and not socket.closed:
394 socket.close()
391 self._closed = True
395 self._closed = True
392
396
393 def _update_engines(self, engines):
397 def _update_engines(self, engines):
@@ -550,7 +554,6 b' class Client(HasTraits):'
550 outstanding = self._outstanding_dict[uuid]
554 outstanding = self._outstanding_dict[uuid]
551
555
552 for msg_id in list(outstanding):
556 for msg_id in list(outstanding):
553 print msg_id
554 if msg_id in self.results:
557 if msg_id in self.results:
555 # we already
558 # we already
556 continue
559 continue
@@ -796,7 +799,7 b' class Client(HasTraits):'
796 if msg['content']['status'] != 'ok':
799 if msg['content']['status'] != 'ok':
797 error = self._unwrap_exception(msg['content'])
800 error = self._unwrap_exception(msg['content'])
798 if error:
801 if error:
799 return error
802 raise error
800
803
801
804
802 @spinfirst
805 @spinfirst
@@ -840,7 +843,7 b' class Client(HasTraits):'
840 if msg['content']['status'] != 'ok':
843 if msg['content']['status'] != 'ok':
841 error = self._unwrap_exception(msg['content'])
844 error = self._unwrap_exception(msg['content'])
842 if error:
845 if error:
843 return error
846 raise error
844
847
845 @spinfirst
848 @spinfirst
846 @defaultblock
849 @defaultblock
@@ -945,7 +948,8 b' class Client(HasTraits):'
945 @defaultblock
948 @defaultblock
946 def apply(self, f, args=None, kwargs=None, bound=True, block=None,
949 def apply(self, f, args=None, kwargs=None, bound=True, block=None,
947 targets=None, balanced=None,
950 targets=None, balanced=None,
948 after=None, follow=None, timeout=None):
951 after=None, follow=None, timeout=None,
952 track=False):
949 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
953 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
950
954
951 This is the central execution command for the client.
955 This is the central execution command for the client.
@@ -1003,6 +1007,9 b' class Client(HasTraits):'
1003 Specify an amount of time (in seconds) for the scheduler to
1007 Specify an amount of time (in seconds) for the scheduler to
1004 wait for dependencies to be met before failing with a
1008 wait for dependencies to be met before failing with a
1005 DependencyTimeout.
1009 DependencyTimeout.
1010 track : bool
1011 whether to track non-copying sends.
1012 [default False]
1006
1013
1007 after,follow,timeout only used if `balanced=True`.
1014 after,follow,timeout only used if `balanced=True`.
1008
1015
@@ -1044,7 +1051,7 b' class Client(HasTraits):'
1044 if not isinstance(kwargs, dict):
1051 if not isinstance(kwargs, dict):
1045 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1052 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1046
1053
1047 options = dict(bound=bound, block=block, targets=targets)
1054 options = dict(bound=bound, block=block, targets=targets, track=track)
1048
1055
1049 if balanced:
1056 if balanced:
1050 return self._apply_balanced(f, args, kwargs, timeout=timeout,
1057 return self._apply_balanced(f, args, kwargs, timeout=timeout,
@@ -1057,7 +1064,7 b' class Client(HasTraits):'
1057 return self._apply_direct(f, args, kwargs, **options)
1064 return self._apply_direct(f, args, kwargs, **options)
1058
1065
1059 def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
1066 def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
1060 after=None, follow=None, timeout=None):
1067 after=None, follow=None, timeout=None, track=None):
1061 """call f(*args, **kwargs) remotely in a load-balanced manner.
1068 """call f(*args, **kwargs) remotely in a load-balanced manner.
1062
1069
1063 This is a private method, see `apply` for details.
1070 This is a private method, see `apply` for details.
@@ -1065,7 +1072,7 b' class Client(HasTraits):'
1065 """
1072 """
1066
1073
1067 loc = locals()
1074 loc = locals()
1068 for name in ('bound', 'block'):
1075 for name in ('bound', 'block', 'track'):
1069 assert loc[name] is not None, "kwarg %r must be specified!"%name
1076 assert loc[name] is not None, "kwarg %r must be specified!"%name
1070
1077
1071 if self._task_socket is None:
1078 if self._task_socket is None:
@@ -1101,13 +1108,13 b' class Client(HasTraits):'
1101 content = dict(bound=bound)
1108 content = dict(bound=bound)
1102
1109
1103 msg = self.session.send(self._task_socket, "apply_request",
1110 msg = self.session.send(self._task_socket, "apply_request",
1104 content=content, buffers=bufs, subheader=subheader)
1111 content=content, buffers=bufs, subheader=subheader, track=track)
1105 msg_id = msg['msg_id']
1112 msg_id = msg['msg_id']
1106 self.outstanding.add(msg_id)
1113 self.outstanding.add(msg_id)
1107 self.history.append(msg_id)
1114 self.history.append(msg_id)
1108 self.metadata[msg_id]['submitted'] = datetime.now()
1115 self.metadata[msg_id]['submitted'] = datetime.now()
1109
1116 tracker = None if track is False else msg['tracker']
1110 ar = AsyncResult(self, [msg_id], fname=f.__name__)
1117 ar = AsyncResult(self, [msg_id], fname=f.__name__, targets=targets, tracker=tracker)
1111 if block:
1118 if block:
1112 try:
1119 try:
1113 return ar.get()
1120 return ar.get()
@@ -1116,7 +1123,8 b' class Client(HasTraits):'
1116 else:
1123 else:
1117 return ar
1124 return ar
1118
1125
1119 def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None):
1126 def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None,
1127 track=None):
1120 """Then underlying method for applying functions to specific engines
1128 """Then underlying method for applying functions to specific engines
1121 via the MUX queue.
1129 via the MUX queue.
1122
1130
@@ -1124,7 +1132,7 b' class Client(HasTraits):'
1124 Not to be called directly!
1132 Not to be called directly!
1125 """
1133 """
1126 loc = locals()
1134 loc = locals()
1127 for name in ('bound', 'block', 'targets'):
1135 for name in ('bound', 'block', 'targets', 'track'):
1128 assert loc[name] is not None, "kwarg %r must be specified!"%name
1136 assert loc[name] is not None, "kwarg %r must be specified!"%name
1129
1137
1130 idents,targets = self._build_targets(targets)
1138 idents,targets = self._build_targets(targets)
@@ -1134,15 +1142,22 b' class Client(HasTraits):'
1134 bufs = util.pack_apply_message(f,args,kwargs)
1142 bufs = util.pack_apply_message(f,args,kwargs)
1135
1143
1136 msg_ids = []
1144 msg_ids = []
1145 trackers = []
1137 for ident in idents:
1146 for ident in idents:
1138 msg = self.session.send(self._mux_socket, "apply_request",
1147 msg = self.session.send(self._mux_socket, "apply_request",
1139 content=content, buffers=bufs, ident=ident, subheader=subheader)
1148 content=content, buffers=bufs, ident=ident, subheader=subheader,
1149 track=track)
1150 if track:
1151 trackers.append(msg['tracker'])
1140 msg_id = msg['msg_id']
1152 msg_id = msg['msg_id']
1141 self.outstanding.add(msg_id)
1153 self.outstanding.add(msg_id)
1142 self._outstanding_dict[ident].add(msg_id)
1154 self._outstanding_dict[ident].add(msg_id)
1143 self.history.append(msg_id)
1155 self.history.append(msg_id)
1144 msg_ids.append(msg_id)
1156 msg_ids.append(msg_id)
1145 ar = AsyncResult(self, msg_ids, fname=f.__name__)
1157
1158 tracker = None if track is False else zmq.MessageTracker(*trackers)
1159 ar = AsyncResult(self, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
1160
1146 if block:
1161 if block:
1147 try:
1162 try:
1148 return ar.get()
1163 return ar.get()
@@ -1230,11 +1245,11 b' class Client(HasTraits):'
1230 #--------------------------------------------------------------------------
1245 #--------------------------------------------------------------------------
1231
1246
1232 @defaultblock
1247 @defaultblock
1233 def push(self, ns, targets='all', block=None):
1248 def push(self, ns, targets='all', block=None, track=False):
1234 """Push the contents of `ns` into the namespace on `target`"""
1249 """Push the contents of `ns` into the namespace on `target`"""
1235 if not isinstance(ns, dict):
1250 if not isinstance(ns, dict):
1236 raise TypeError("Must be a dict, not %s"%type(ns))
1251 raise TypeError("Must be a dict, not %s"%type(ns))
1237 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False)
1252 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False, track=track)
1238 if not block:
1253 if not block:
1239 return result
1254 return result
1240
1255
@@ -1251,7 +1266,7 b' class Client(HasTraits):'
1251 return result
1266 return result
1252
1267
1253 @defaultblock
1268 @defaultblock
1254 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
1269 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None, track=False):
1255 """
1270 """
1256 Partition a Python sequence and send the partitions to a set of engines.
1271 Partition a Python sequence and send the partitions to a set of engines.
1257 """
1272 """
@@ -1259,16 +1274,25 b' class Client(HasTraits):'
1259 mapObject = Map.dists[dist]()
1274 mapObject = Map.dists[dist]()
1260 nparts = len(targets)
1275 nparts = len(targets)
1261 msg_ids = []
1276 msg_ids = []
1277 trackers = []
1262 for index, engineid in enumerate(targets):
1278 for index, engineid in enumerate(targets):
1263 partition = mapObject.getPartition(seq, index, nparts)
1279 partition = mapObject.getPartition(seq, index, nparts)
1264 if flatten and len(partition) == 1:
1280 if flatten and len(partition) == 1:
1265 r = self.push({key: partition[0]}, targets=engineid, block=False)
1281 r = self.push({key: partition[0]}, targets=engineid, block=False, track=track)
1266 else:
1282 else:
1267 r = self.push({key: partition}, targets=engineid, block=False)
1283 r = self.push({key: partition}, targets=engineid, block=False, track=track)
1268 msg_ids.extend(r.msg_ids)
1284 msg_ids.extend(r.msg_ids)
1269 r = AsyncResult(self, msg_ids, fname='scatter')
1285 if track:
1286 trackers.append(r._tracker)
1287
1288 if track:
1289 tracker = zmq.MessageTracker(*trackers)
1290 else:
1291 tracker = None
1292
1293 r = AsyncResult(self, msg_ids, fname='scatter', targets=targets, tracker=tracker)
1270 if block:
1294 if block:
1271 r.get()
1295 r.wait()
1272 else:
1296 else:
1273 return r
1297 return r
1274
1298
@@ -179,7 +179,7 b' class StreamSession(object):'
179 return header.get('key', None) == self.key
179 return header.get('key', None) == self.key
180
180
181
181
182 def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
182 def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None, track=False):
183 """Build and send a message via stream or socket.
183 """Build and send a message via stream or socket.
184
184
185 Parameters
185 Parameters
@@ -191,13 +191,34 b' class StreamSession(object):'
191 Normally, msg_or_type will be a msg_type unless a message is being sent more
191 Normally, msg_or_type will be a msg_type unless a message is being sent more
192 than once.
192 than once.
193
193
194 content : dict or None
195 the content of the message (ignored if msg_or_type is a message)
196 buffers : list or None
197 the already-serialized buffers to be appended to the message
198 parent : Message or dict or None
199 the parent or parent header describing the parent of this message
200 subheader : dict or None
201 extra header keys for this message's header
202 ident : bytes or list of bytes
203 the zmq.IDENTITY routing path
204 track : bool
205 whether to track. Only for use with Sockets, because ZMQStream objects cannot track messages.
206
194 Returns
207 Returns
195 -------
208 -------
196 (msg,sent) : tuple
209 msg : message dict
197 msg : Message
210 the constructed message
198 the nice wrapped dict-like object containing the headers
211 (msg,tracker) : (message dict, MessageTracker)
212 if track=True, then a 2-tuple will be returned, the first element being the constructed
213 message, and the second being the MessageTracker
199
214
200 """
215 """
216
217 if not isinstance(stream, (zmq.Socket, ZMQStream)):
218 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
219 elif track and isinstance(stream, ZMQStream):
220 raise TypeError("ZMQStream cannot track messages")
221
201 if isinstance(msg_or_type, (Message, dict)):
222 if isinstance(msg_or_type, (Message, dict)):
202 # we got a Message, not a msg_type
223 # we got a Message, not a msg_type
203 # don't build a new Message
224 # don't build a new Message
@@ -205,6 +226,7 b' class StreamSession(object):'
205 content = msg['content']
226 content = msg['content']
206 else:
227 else:
207 msg = self.msg(msg_or_type, content, parent, subheader)
228 msg = self.msg(msg_or_type, content, parent, subheader)
229
208 buffers = [] if buffers is None else buffers
230 buffers = [] if buffers is None else buffers
209 to_send = []
231 to_send = []
210 if isinstance(ident, list):
232 if isinstance(ident, list):
@@ -222,7 +244,7 b' class StreamSession(object):'
222 content = self.none
244 content = self.none
223 elif isinstance(content, dict):
245 elif isinstance(content, dict):
224 content = self.pack(content)
246 content = self.pack(content)
225 elif isinstance(content, str):
247 elif isinstance(content, bytes):
226 # content is already packed, as in a relayed message
248 # content is already packed, as in a relayed message
227 pass
249 pass
228 else:
250 else:
@@ -231,16 +253,29 b' class StreamSession(object):'
231 flag = 0
253 flag = 0
232 if buffers:
254 if buffers:
233 flag = zmq.SNDMORE
255 flag = zmq.SNDMORE
234 stream.send_multipart(to_send, flag, copy=False)
256 _track = False
257 else:
258 _track=track
259 if track:
260 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
261 else:
262 tracker = stream.send_multipart(to_send, flag, copy=False)
235 for b in buffers[:-1]:
263 for b in buffers[:-1]:
236 stream.send(b, flag, copy=False)
264 stream.send(b, flag, copy=False)
237 if buffers:
265 if buffers:
238 stream.send(buffers[-1], copy=False)
266 if track:
267 tracker = stream.send(buffers[-1], copy=False, track=track)
268 else:
269 tracker = stream.send(buffers[-1], copy=False)
270
239 # omsg = Message(msg)
271 # omsg = Message(msg)
240 if self.debug:
272 if self.debug:
241 pprint.pprint(msg)
273 pprint.pprint(msg)
242 pprint.pprint(to_send)
274 pprint.pprint(to_send)
243 pprint.pprint(buffers)
275 pprint.pprint(buffers)
276
277 msg['tracker'] = tracker
278
244 return msg
279 return msg
245
280
246 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
281 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
@@ -250,7 +285,7 b' class StreamSession(object):'
250 ----------
285 ----------
251 msg : list of sendable buffers"""
286 msg : list of sendable buffers"""
252 to_send = []
287 to_send = []
253 if isinstance(ident, str):
288 if isinstance(ident, bytes):
254 ident = [ident]
289 ident = [ident]
255 if ident is not None:
290 if ident is not None:
256 to_send.extend(ident)
291 to_send.extend(ident)
@@ -1,24 +1,26 b''
1 """toplevel setup/teardown for parallel tests."""
1 """toplevel setup/teardown for parallel tests."""
2
2
3 import tempfile
3 import time
4 import time
4 from subprocess import Popen, PIPE
5 from subprocess import Popen, PIPE, STDOUT
5
6
6 from IPython.zmq.parallel.ipcluster import launch_process
7 from IPython.zmq.parallel.ipcluster import launch_process
7 from IPython.zmq.parallel.entry_point import select_random_ports
8 from IPython.zmq.parallel.entry_point import select_random_ports
8
9
9 processes = []
10 processes = []
11 blackhole = tempfile.TemporaryFile()
10
12
11 # nose setup/teardown
13 # nose setup/teardown
12
14
13 def setup():
15 def setup():
14 cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=PIPE, stdin=PIPE, stderr=PIPE)
16 cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=blackhole, stderr=STDOUT)
15 processes.append(cp)
17 processes.append(cp)
16 time.sleep(.5)
18 time.sleep(.5)
17 add_engine()
19 add_engine()
18 time.sleep(3)
20 time.sleep(2)
19
21
20 def add_engine(profile='iptest'):
22 def add_engine(profile='iptest'):
21 ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=PIPE, stdin=PIPE, stderr=PIPE)
23 ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=blackhole, stderr=STDOUT)
22 # ep.start()
24 # ep.start()
23 processes.append(ep)
25 processes.append(ep)
24 return ep
26 return ep
@@ -88,7 +88,9 b' class ClusterTestCase(BaseZMQTestCase):'
88 self.base_engine_count=len(self.client.ids)
88 self.base_engine_count=len(self.client.ids)
89 self.engines=[]
89 self.engines=[]
90
90
91 # def tearDown(self):
91 def tearDown(self):
92 self.client.close()
93 BaseZMQTestCase.tearDown(self)
92 # [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ]
94 # [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ]
93 # [ e.wait() for e in self.engines ]
95 # [ e.wait() for e in self.engines ]
94 # while len(self.client.ids) > self.base_engine_count:
96 # while len(self.client.ids) > self.base_engine_count:
@@ -2,6 +2,7 b' import time'
2 from tempfile import mktemp
2 from tempfile import mktemp
3
3
4 import nose.tools as nt
4 import nose.tools as nt
5 import zmq
5
6
6 from IPython.zmq.parallel import client as clientmod
7 from IPython.zmq.parallel import client as clientmod
7 from IPython.zmq.parallel import error
8 from IPython.zmq.parallel import error
@@ -18,10 +19,9 b' class TestClient(ClusterTestCase):'
18 self.assertEquals(len(self.client.ids), n+3)
19 self.assertEquals(len(self.client.ids), n+3)
19 self.assertTrue
20 self.assertTrue
20
21
21 def test_segfault(self):
22 def test_segfault_task(self):
22 """test graceful handling of engine death"""
23 """test graceful handling of engine death (balanced)"""
23 self.add_engines(1)
24 self.add_engines(1)
24 eid = self.client.ids[-1]
25 ar = self.client.apply(segfault, block=False)
25 ar = self.client.apply(segfault, block=False)
26 self.assertRaisesRemote(error.EngineError, ar.get)
26 self.assertRaisesRemote(error.EngineError, ar.get)
27 eid = ar.engine_id
27 eid = ar.engine_id
@@ -29,6 +29,17 b' class TestClient(ClusterTestCase):'
29 time.sleep(.01)
29 time.sleep(.01)
30 self.client.spin()
30 self.client.spin()
31
31
32 def test_segfault_mux(self):
33 """test graceful handling of engine death (direct)"""
34 self.add_engines(1)
35 eid = self.client.ids[-1]
36 ar = self.client[eid].apply_async(segfault)
37 self.assertRaisesRemote(error.EngineError, ar.get)
38 eid = ar.engine_id
39 while eid in self.client.ids:
40 time.sleep(.01)
41 self.client.spin()
42
32 def test_view_indexing(self):
43 def test_view_indexing(self):
33 """test index access for views"""
44 """test index access for views"""
34 self.add_engines(2)
45 self.add_engines(2)
@@ -91,13 +102,14 b' class TestClient(ClusterTestCase):'
91 def test_push_pull(self):
102 def test_push_pull(self):
92 """test pushing and pulling"""
103 """test pushing and pulling"""
93 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
104 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
105 t = self.client.ids[-1]
94 self.add_engines(2)
106 self.add_engines(2)
95 push = self.client.push
107 push = self.client.push
96 pull = self.client.pull
108 pull = self.client.pull
97 self.client.block=True
109 self.client.block=True
98 nengines = len(self.client)
110 nengines = len(self.client)
99 push({'data':data}, targets=0)
111 push({'data':data}, targets=t)
100 d = pull('data', targets=0)
112 d = pull('data', targets=t)
101 self.assertEquals(d, data)
113 self.assertEquals(d, data)
102 push({'data':data})
114 push({'data':data})
103 d = pull('data')
115 d = pull('data')
@@ -119,15 +131,16 b' class TestClient(ClusterTestCase):'
119 return 2.0*x
131 return 2.0*x
120
132
121 self.add_engines(4)
133 self.add_engines(4)
134 t = self.client.ids[-1]
122 self.client.block=True
135 self.client.block=True
123 push = self.client.push
136 push = self.client.push
124 pull = self.client.pull
137 pull = self.client.pull
125 execute = self.client.execute
138 execute = self.client.execute
126 push({'testf':testf}, targets=0)
139 push({'testf':testf}, targets=t)
127 r = pull('testf', targets=0)
140 r = pull('testf', targets=t)
128 self.assertEqual(r(1.0), testf(1.0))
141 self.assertEqual(r(1.0), testf(1.0))
129 execute('r = testf(10)', targets=0)
142 execute('r = testf(10)', targets=t)
130 r = pull('r', targets=0)
143 r = pull('r', targets=t)
131 self.assertEquals(r, testf(10))
144 self.assertEquals(r, testf(10))
132 ar = push({'testf':testf}, block=False)
145 ar = push({'testf':testf}, block=False)
133 ar.get()
146 ar.get()
@@ -135,8 +148,8 b' class TestClient(ClusterTestCase):'
135 rlist = ar.get()
148 rlist = ar.get()
136 for r in rlist:
149 for r in rlist:
137 self.assertEqual(r(1.0), testf(1.0))
150 self.assertEqual(r(1.0), testf(1.0))
138 execute("def g(x): return x*x", targets=0)
151 execute("def g(x): return x*x", targets=t)
139 r = pull(('testf','g'),targets=0)
152 r = pull(('testf','g'),targets=t)
140 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
153 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
141
154
142 def test_push_function_globals(self):
155 def test_push_function_globals(self):
@@ -173,7 +186,7 b' class TestClient(ClusterTestCase):'
173 ids.remove(ids[-1])
186 ids.remove(ids[-1])
174 self.assertNotEquals(ids, self.client._ids)
187 self.assertNotEquals(ids, self.client._ids)
175
188
176 def test_arun_newline(self):
189 def test_run_newline(self):
177 """test that run appends newline to files"""
190 """test that run appends newline to files"""
178 tmpfile = mktemp()
191 tmpfile = mktemp()
179 with open(tmpfile, 'w') as f:
192 with open(tmpfile, 'w') as f:
@@ -184,4 +197,56 b' class TestClient(ClusterTestCase):'
184 v.run(tmpfile, block=True)
197 v.run(tmpfile, block=True)
185 self.assertEquals(v.apply_sync_bound(lambda : g()), 5)
198 self.assertEquals(v.apply_sync_bound(lambda : g()), 5)
186
199
187 No newline at end of file
200 def test_apply_tracked(self):
201 """test tracking for apply"""
202 # self.add_engines(1)
203 t = self.client.ids[-1]
204 self.client.block=False
205 def echo(n=1024*1024, **kwargs):
206 return self.client.apply(lambda x: x, args=('x'*n,), targets=t, **kwargs)
207 ar = echo(1)
208 self.assertTrue(ar._tracker is None)
209 self.assertTrue(ar.sent)
210 ar = echo(track=True)
211 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
212 self.assertEquals(ar.sent, ar._tracker.done)
213 ar._tracker.wait()
214 self.assertTrue(ar.sent)
215
216 def test_push_tracked(self):
217 t = self.client.ids[-1]
218 ns = dict(x='x'*1024*1024)
219 ar = self.client.push(ns, targets=t, block=False)
220 self.assertTrue(ar._tracker is None)
221 self.assertTrue(ar.sent)
222
223 ar = self.client.push(ns, targets=t, block=False, track=True)
224 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
225 self.assertEquals(ar.sent, ar._tracker.done)
226 ar._tracker.wait()
227 self.assertTrue(ar.sent)
228 ar.get()
229
230 def test_scatter_tracked(self):
231 t = self.client.ids
232 x='x'*1024*1024
233 ar = self.client.scatter('x', x, targets=t, block=False)
234 self.assertTrue(ar._tracker is None)
235 self.assertTrue(ar.sent)
236
237 ar = self.client.scatter('x', x, targets=t, block=False, track=True)
238 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
239 self.assertEquals(ar.sent, ar._tracker.done)
240 ar._tracker.wait()
241 self.assertTrue(ar.sent)
242 ar.get()
243
244 def test_remote_reference(self):
245 v = self.client[-1]
246 v['a'] = 123
247 ra = clientmod.Reference('a')
248 b = v.apply_sync_bound(lambda x: x, ra)
249 self.assertEquals(b, 123)
250 self.assertRaisesRemote(NameError, v.apply_sync, lambda x: x, ra)
251
252
@@ -4,7 +4,7 b' import uuid'
4 import zmq
4 import zmq
5
5
6 from zmq.tests import BaseZMQTestCase
6 from zmq.tests import BaseZMQTestCase
7
7 from zmq.eventloop.zmqstream import ZMQStream
8 # from IPython.zmq.tests import SessionTestCase
8 # from IPython.zmq.tests import SessionTestCase
9 from IPython.zmq.parallel import streamsession as ss
9 from IPython.zmq.parallel import streamsession as ss
10
10
@@ -31,7 +31,7 b' class TestSession(SessionTestCase):'
31
31
32 def test_args(self):
32 def test_args(self):
33 """initialization arguments for StreamSession"""
33 """initialization arguments for StreamSession"""
34 s = ss.StreamSession()
34 s = self.session
35 self.assertTrue(s.pack is ss.default_packer)
35 self.assertTrue(s.pack is ss.default_packer)
36 self.assertTrue(s.unpack is ss.default_unpacker)
36 self.assertTrue(s.unpack is ss.default_unpacker)
37 self.assertEquals(s.username, os.environ.get('USER', 'username'))
37 self.assertEquals(s.username, os.environ.get('USER', 'username'))
@@ -46,7 +46,24 b' class TestSession(SessionTestCase):'
46 self.assertEquals(s.session, u)
46 self.assertEquals(s.session, u)
47 self.assertEquals(s.username, 'carrot')
47 self.assertEquals(s.username, 'carrot')
48
48
49
49 def test_tracking(self):
50 """test tracking messages"""
51 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
52 s = self.session
53 stream = ZMQStream(a)
54 msg = s.send(a, 'hello', track=False)
55 self.assertTrue(msg['tracker'] is None)
56 msg = s.send(a, 'hello', track=True)
57 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
58 M = zmq.Message(b'hi there', track=True)
59 msg = s.send(a, 'hello', buffers=[M], track=True)
60 t = msg['tracker']
61 self.assertTrue(isinstance(t, zmq.MessageTracker))
62 self.assertRaises(zmq.NotDone, t.wait, .1)
63 del M
64 t.wait(1) # this will raise
65
66
50 # def test_rekey(self):
67 # def test_rekey(self):
51 # """rekeying dict around json str keys"""
68 # """rekeying dict around json str keys"""
52 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
69 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
General Comments 0
You need to be logged in to leave comments. Login now