Show More
@@ -34,13 +34,17 class AsyncResult(object): | |||
|
34 | 34 | """ |
|
35 | 35 | |
|
36 | 36 | msg_ids = None |
|
37 | _targets = None | |
|
38 | _tracker = None | |
|
37 | 39 | |
|
38 | def __init__(self, client, msg_ids, fname='unknown'): | |
|
40 | def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None): | |
|
39 | 41 | self._client = client |
|
40 | 42 | if isinstance(msg_ids, basestring): |
|
41 | 43 | msg_ids = [msg_ids] |
|
42 | 44 | self.msg_ids = msg_ids |
|
43 | 45 | self._fname=fname |
|
46 | self._targets = targets | |
|
47 | self._tracker = tracker | |
|
44 | 48 | self._ready = False |
|
45 | 49 | self._success = None |
|
46 | 50 | self._single_result = len(msg_ids) == 1 |
@@ -169,6 +173,19 class AsyncResult(object): | |||
|
169 | 173 | |
|
170 | 174 | def __dict__(self): |
|
171 | 175 | return self.get_dict(0) |
|
176 | ||
|
177 | def abort(self): | |
|
178 | """abort my tasks.""" | |
|
179 | assert not self.ready(), "Can't abort, I am already done!" | |
|
180 | return self.client.abort(self.msg_ids, targets=self._targets, block=True) | |
|
181 | ||
|
182 | @property | |
|
183 | def sent(self): | |
|
184 | """check whether my messages have been sent""" | |
|
185 | if self._tracker is None: | |
|
186 | return True | |
|
187 | else: | |
|
188 | return self._tracker.done | |
|
172 | 189 | |
|
173 | 190 | #------------------------------------- |
|
174 | 191 | # dict-access |
@@ -356,6 +356,9 class Client(HasTraits): | |||
|
356 | 356 | 'apply_reply' : self._handle_apply_reply} |
|
357 | 357 | self._connect(sshserver, ssh_kwargs) |
|
358 | 358 | |
|
359 | def __del__(self): | |
|
360 | """cleanup sockets, but _not_ context.""" | |
|
361 | self.close() | |
|
359 | 362 | |
|
360 | 363 | def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir): |
|
361 | 364 | if ipython_dir is None: |
@@ -387,7 +390,8 class Client(HasTraits): | |||
|
387 | 390 | return |
|
388 | 391 | snames = filter(lambda n: n.endswith('socket'), dir(self)) |
|
389 | 392 | for socket in map(lambda name: getattr(self, name), snames): |
|
390 | socket.close() | |
|
393 | if isinstance(socket, zmq.Socket) and not socket.closed: | |
|
394 | socket.close() | |
|
391 | 395 | self._closed = True |
|
392 | 396 | |
|
393 | 397 | def _update_engines(self, engines): |
@@ -550,7 +554,6 class Client(HasTraits): | |||
|
550 | 554 | outstanding = self._outstanding_dict[uuid] |
|
551 | 555 | |
|
552 | 556 | for msg_id in list(outstanding): |
|
553 | print msg_id | |
|
554 | 557 | if msg_id in self.results: |
|
555 | 558 | # we already |
|
556 | 559 | continue |
@@ -796,7 +799,7 class Client(HasTraits): | |||
|
796 | 799 | if msg['content']['status'] != 'ok': |
|
797 | 800 | error = self._unwrap_exception(msg['content']) |
|
798 | 801 | if error: |
|
799 |
re |
|
|
802 | raise error | |
|
800 | 803 | |
|
801 | 804 | |
|
802 | 805 | @spinfirst |
@@ -840,7 +843,7 class Client(HasTraits): | |||
|
840 | 843 | if msg['content']['status'] != 'ok': |
|
841 | 844 | error = self._unwrap_exception(msg['content']) |
|
842 | 845 | if error: |
|
843 |
re |
|
|
846 | raise error | |
|
844 | 847 | |
|
845 | 848 | @spinfirst |
|
846 | 849 | @defaultblock |
@@ -945,7 +948,8 class Client(HasTraits): | |||
|
945 | 948 | @defaultblock |
|
946 | 949 | def apply(self, f, args=None, kwargs=None, bound=True, block=None, |
|
947 | 950 | targets=None, balanced=None, |
|
948 |
after=None, follow=None, timeout=None |
|
|
951 | after=None, follow=None, timeout=None, | |
|
952 | track=False): | |
|
949 | 953 | """Call `f(*args, **kwargs)` on a remote engine(s), returning the result. |
|
950 | 954 | |
|
951 | 955 | This is the central execution command for the client. |
@@ -1003,6 +1007,9 class Client(HasTraits): | |||
|
1003 | 1007 | Specify an amount of time (in seconds) for the scheduler to |
|
1004 | 1008 | wait for dependencies to be met before failing with a |
|
1005 | 1009 | DependencyTimeout. |
|
1010 | track : bool | |
|
1011 | whether to track non-copying sends. | |
|
1012 | [default False] | |
|
1006 | 1013 | |
|
1007 | 1014 | after,follow,timeout only used if `balanced=True`. |
|
1008 | 1015 | |
@@ -1044,7 +1051,7 class Client(HasTraits): | |||
|
1044 | 1051 | if not isinstance(kwargs, dict): |
|
1045 | 1052 | raise TypeError("kwargs must be dict, not %s"%type(kwargs)) |
|
1046 | 1053 | |
|
1047 | options = dict(bound=bound, block=block, targets=targets) | |
|
1054 | options = dict(bound=bound, block=block, targets=targets, track=track) | |
|
1048 | 1055 | |
|
1049 | 1056 | if balanced: |
|
1050 | 1057 | return self._apply_balanced(f, args, kwargs, timeout=timeout, |
@@ -1057,7 +1064,7 class Client(HasTraits): | |||
|
1057 | 1064 | return self._apply_direct(f, args, kwargs, **options) |
|
1058 | 1065 | |
|
1059 | 1066 | def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None, |
|
1060 | after=None, follow=None, timeout=None): | |
|
1067 | after=None, follow=None, timeout=None, track=None): | |
|
1061 | 1068 | """call f(*args, **kwargs) remotely in a load-balanced manner. |
|
1062 | 1069 | |
|
1063 | 1070 | This is a private method, see `apply` for details. |
@@ -1065,7 +1072,7 class Client(HasTraits): | |||
|
1065 | 1072 | """ |
|
1066 | 1073 | |
|
1067 | 1074 | loc = locals() |
|
1068 | for name in ('bound', 'block'): | |
|
1075 | for name in ('bound', 'block', 'track'): | |
|
1069 | 1076 | assert loc[name] is not None, "kwarg %r must be specified!"%name |
|
1070 | 1077 | |
|
1071 | 1078 | if self._task_socket is None: |
@@ -1101,13 +1108,13 class Client(HasTraits): | |||
|
1101 | 1108 | content = dict(bound=bound) |
|
1102 | 1109 | |
|
1103 | 1110 | msg = self.session.send(self._task_socket, "apply_request", |
|
1104 | content=content, buffers=bufs, subheader=subheader) | |
|
1111 | content=content, buffers=bufs, subheader=subheader, track=track) | |
|
1105 | 1112 | msg_id = msg['msg_id'] |
|
1106 | 1113 | self.outstanding.add(msg_id) |
|
1107 | 1114 | self.history.append(msg_id) |
|
1108 | 1115 | self.metadata[msg_id]['submitted'] = datetime.now() |
|
1109 | ||
|
1110 | ar = AsyncResult(self, [msg_id], fname=f.__name__) | |
|
1116 | tracker = None if track is False else msg['tracker'] | |
|
1117 | ar = AsyncResult(self, [msg_id], fname=f.__name__, targets=targets, tracker=tracker) | |
|
1111 | 1118 | if block: |
|
1112 | 1119 | try: |
|
1113 | 1120 | return ar.get() |
@@ -1116,7 +1123,8 class Client(HasTraits): | |||
|
1116 | 1123 | else: |
|
1117 | 1124 | return ar |
|
1118 | 1125 | |
|
1119 |
def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None |
|
|
1126 | def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None, | |
|
1127 | track=None): | |
|
1120 | 1128 | """Then underlying method for applying functions to specific engines |
|
1121 | 1129 | via the MUX queue. |
|
1122 | 1130 | |
@@ -1124,7 +1132,7 class Client(HasTraits): | |||
|
1124 | 1132 | Not to be called directly! |
|
1125 | 1133 | """ |
|
1126 | 1134 | loc = locals() |
|
1127 | for name in ('bound', 'block', 'targets'): | |
|
1135 | for name in ('bound', 'block', 'targets', 'track'): | |
|
1128 | 1136 | assert loc[name] is not None, "kwarg %r must be specified!"%name |
|
1129 | 1137 | |
|
1130 | 1138 | idents,targets = self._build_targets(targets) |
@@ -1134,15 +1142,22 class Client(HasTraits): | |||
|
1134 | 1142 | bufs = util.pack_apply_message(f,args,kwargs) |
|
1135 | 1143 | |
|
1136 | 1144 | msg_ids = [] |
|
1145 | trackers = [] | |
|
1137 | 1146 | for ident in idents: |
|
1138 | 1147 | msg = self.session.send(self._mux_socket, "apply_request", |
|
1139 |
content=content, buffers=bufs, ident=ident, subheader=subheader |
|
|
1148 | content=content, buffers=bufs, ident=ident, subheader=subheader, | |
|
1149 | track=track) | |
|
1150 | if track: | |
|
1151 | trackers.append(msg['tracker']) | |
|
1140 | 1152 | msg_id = msg['msg_id'] |
|
1141 | 1153 | self.outstanding.add(msg_id) |
|
1142 | 1154 | self._outstanding_dict[ident].add(msg_id) |
|
1143 | 1155 | self.history.append(msg_id) |
|
1144 | 1156 | msg_ids.append(msg_id) |
|
1145 | ar = AsyncResult(self, msg_ids, fname=f.__name__) | |
|
1157 | ||
|
1158 | tracker = None if track is False else zmq.MessageTracker(*trackers) | |
|
1159 | ar = AsyncResult(self, msg_ids, fname=f.__name__, targets=targets, tracker=tracker) | |
|
1160 | ||
|
1146 | 1161 | if block: |
|
1147 | 1162 | try: |
|
1148 | 1163 | return ar.get() |
@@ -1230,11 +1245,11 class Client(HasTraits): | |||
|
1230 | 1245 | #-------------------------------------------------------------------------- |
|
1231 | 1246 | |
|
1232 | 1247 | @defaultblock |
|
1233 | def push(self, ns, targets='all', block=None): | |
|
1248 | def push(self, ns, targets='all', block=None, track=False): | |
|
1234 | 1249 | """Push the contents of `ns` into the namespace on `target`""" |
|
1235 | 1250 | if not isinstance(ns, dict): |
|
1236 | 1251 | raise TypeError("Must be a dict, not %s"%type(ns)) |
|
1237 | result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False) | |
|
1252 | result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False, track=track) | |
|
1238 | 1253 | if not block: |
|
1239 | 1254 | return result |
|
1240 | 1255 | |
@@ -1251,7 +1266,7 class Client(HasTraits): | |||
|
1251 | 1266 | return result |
|
1252 | 1267 | |
|
1253 | 1268 | @defaultblock |
|
1254 | def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None): | |
|
1269 | def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None, track=False): | |
|
1255 | 1270 | """ |
|
1256 | 1271 | Partition a Python sequence and send the partitions to a set of engines. |
|
1257 | 1272 | """ |
@@ -1259,16 +1274,25 class Client(HasTraits): | |||
|
1259 | 1274 | mapObject = Map.dists[dist]() |
|
1260 | 1275 | nparts = len(targets) |
|
1261 | 1276 | msg_ids = [] |
|
1277 | trackers = [] | |
|
1262 | 1278 | for index, engineid in enumerate(targets): |
|
1263 | 1279 | partition = mapObject.getPartition(seq, index, nparts) |
|
1264 | 1280 | if flatten and len(partition) == 1: |
|
1265 | r = self.push({key: partition[0]}, targets=engineid, block=False) | |
|
1281 | r = self.push({key: partition[0]}, targets=engineid, block=False, track=track) | |
|
1266 | 1282 | else: |
|
1267 | r = self.push({key: partition}, targets=engineid, block=False) | |
|
1283 | r = self.push({key: partition}, targets=engineid, block=False, track=track) | |
|
1268 | 1284 | msg_ids.extend(r.msg_ids) |
|
1269 | r = AsyncResult(self, msg_ids, fname='scatter') | |
|
1285 | if track: | |
|
1286 | trackers.append(r._tracker) | |
|
1287 | ||
|
1288 | if track: | |
|
1289 | tracker = zmq.MessageTracker(*trackers) | |
|
1290 | else: | |
|
1291 | tracker = None | |
|
1292 | ||
|
1293 | r = AsyncResult(self, msg_ids, fname='scatter', targets=targets, tracker=tracker) | |
|
1270 | 1294 | if block: |
|
1271 |
r. |
|
|
1295 | r.wait() | |
|
1272 | 1296 | else: |
|
1273 | 1297 | return r |
|
1274 | 1298 |
@@ -179,7 +179,7 class StreamSession(object): | |||
|
179 | 179 | return header.get('key', None) == self.key |
|
180 | 180 | |
|
181 | 181 | |
|
182 | def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None): | |
|
182 | def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None, track=False): | |
|
183 | 183 | """Build and send a message via stream or socket. |
|
184 | 184 | |
|
185 | 185 | Parameters |
@@ -191,13 +191,34 class StreamSession(object): | |||
|
191 | 191 | Normally, msg_or_type will be a msg_type unless a message is being sent more |
|
192 | 192 | than once. |
|
193 | 193 | |
|
194 | content : dict or None | |
|
195 | the content of the message (ignored if msg_or_type is a message) | |
|
196 | buffers : list or None | |
|
197 | the already-serialized buffers to be appended to the message | |
|
198 | parent : Message or dict or None | |
|
199 | the parent or parent header describing the parent of this message | |
|
200 | subheader : dict or None | |
|
201 | extra header keys for this message's header | |
|
202 | ident : bytes or list of bytes | |
|
203 | the zmq.IDENTITY routing path | |
|
204 | track : bool | |
|
205 | whether to track. Only for use with Sockets, because ZMQStream objects cannot track messages. | |
|
206 | ||
|
194 | 207 | Returns |
|
195 | 208 | ------- |
|
196 | (msg,sent) : tuple | |
|
197 |
|
|
|
198 | the nice wrapped dict-like object containing the headers | |
|
209 | msg : message dict | |
|
210 | the constructed message | |
|
211 | (msg,tracker) : (message dict, MessageTracker) | |
|
212 | if track=True, then a 2-tuple will be returned, the first element being the constructed | |
|
213 | message, and the second being the MessageTracker | |
|
199 | 214 | |
|
200 | 215 | """ |
|
216 | ||
|
217 | if not isinstance(stream, (zmq.Socket, ZMQStream)): | |
|
218 | raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream)) | |
|
219 | elif track and isinstance(stream, ZMQStream): | |
|
220 | raise TypeError("ZMQStream cannot track messages") | |
|
221 | ||
|
201 | 222 | if isinstance(msg_or_type, (Message, dict)): |
|
202 | 223 | # we got a Message, not a msg_type |
|
203 | 224 | # don't build a new Message |
@@ -205,6 +226,7 class StreamSession(object): | |||
|
205 | 226 | content = msg['content'] |
|
206 | 227 | else: |
|
207 | 228 | msg = self.msg(msg_or_type, content, parent, subheader) |
|
229 | ||
|
208 | 230 | buffers = [] if buffers is None else buffers |
|
209 | 231 | to_send = [] |
|
210 | 232 | if isinstance(ident, list): |
@@ -222,7 +244,7 class StreamSession(object): | |||
|
222 | 244 | content = self.none |
|
223 | 245 | elif isinstance(content, dict): |
|
224 | 246 | content = self.pack(content) |
|
225 |
elif isinstance(content, |
|
|
247 | elif isinstance(content, bytes): | |
|
226 | 248 | # content is already packed, as in a relayed message |
|
227 | 249 | pass |
|
228 | 250 | else: |
@@ -231,16 +253,29 class StreamSession(object): | |||
|
231 | 253 | flag = 0 |
|
232 | 254 | if buffers: |
|
233 | 255 | flag = zmq.SNDMORE |
|
234 | stream.send_multipart(to_send, flag, copy=False) | |
|
256 | _track = False | |
|
257 | else: | |
|
258 | _track=track | |
|
259 | if track: | |
|
260 | tracker = stream.send_multipart(to_send, flag, copy=False, track=_track) | |
|
261 | else: | |
|
262 | tracker = stream.send_multipart(to_send, flag, copy=False) | |
|
235 | 263 | for b in buffers[:-1]: |
|
236 | 264 | stream.send(b, flag, copy=False) |
|
237 | 265 | if buffers: |
|
238 | stream.send(buffers[-1], copy=False) | |
|
266 | if track: | |
|
267 | tracker = stream.send(buffers[-1], copy=False, track=track) | |
|
268 | else: | |
|
269 | tracker = stream.send(buffers[-1], copy=False) | |
|
270 | ||
|
239 | 271 | # omsg = Message(msg) |
|
240 | 272 | if self.debug: |
|
241 | 273 | pprint.pprint(msg) |
|
242 | 274 | pprint.pprint(to_send) |
|
243 | 275 | pprint.pprint(buffers) |
|
276 | ||
|
277 | msg['tracker'] = tracker | |
|
278 | ||
|
244 | 279 | return msg |
|
245 | 280 | |
|
246 | 281 | def send_raw(self, stream, msg, flags=0, copy=True, ident=None): |
@@ -250,7 +285,7 class StreamSession(object): | |||
|
250 | 285 | ---------- |
|
251 | 286 | msg : list of sendable buffers""" |
|
252 | 287 | to_send = [] |
|
253 |
if isinstance(ident, |
|
|
288 | if isinstance(ident, bytes): | |
|
254 | 289 | ident = [ident] |
|
255 | 290 | if ident is not None: |
|
256 | 291 | to_send.extend(ident) |
@@ -1,24 +1,26 | |||
|
1 | 1 | """toplevel setup/teardown for parallel tests.""" |
|
2 | 2 | |
|
3 | import tempfile | |
|
3 | 4 | import time |
|
4 | from subprocess import Popen, PIPE | |
|
5 | from subprocess import Popen, PIPE, STDOUT | |
|
5 | 6 | |
|
6 | 7 | from IPython.zmq.parallel.ipcluster import launch_process |
|
7 | 8 | from IPython.zmq.parallel.entry_point import select_random_ports |
|
8 | 9 | |
|
9 | 10 | processes = [] |
|
11 | blackhole = tempfile.TemporaryFile() | |
|
10 | 12 | |
|
11 | 13 | # nose setup/teardown |
|
12 | 14 | |
|
13 | 15 | def setup(): |
|
14 |
cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout= |
|
|
16 | cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=blackhole, stderr=STDOUT) | |
|
15 | 17 | processes.append(cp) |
|
16 | 18 | time.sleep(.5) |
|
17 | 19 | add_engine() |
|
18 |
time.sleep( |
|
|
20 | time.sleep(2) | |
|
19 | 21 | |
|
20 | 22 | def add_engine(profile='iptest'): |
|
21 |
ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout= |
|
|
23 | ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=blackhole, stderr=STDOUT) | |
|
22 | 24 | # ep.start() |
|
23 | 25 | processes.append(ep) |
|
24 | 26 | return ep |
@@ -88,7 +88,9 class ClusterTestCase(BaseZMQTestCase): | |||
|
88 | 88 | self.base_engine_count=len(self.client.ids) |
|
89 | 89 | self.engines=[] |
|
90 | 90 | |
|
91 |
|
|
|
91 | def tearDown(self): | |
|
92 | self.client.close() | |
|
93 | BaseZMQTestCase.tearDown(self) | |
|
92 | 94 | # [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ] |
|
93 | 95 | # [ e.wait() for e in self.engines ] |
|
94 | 96 | # while len(self.client.ids) > self.base_engine_count: |
@@ -2,6 +2,7 import time | |||
|
2 | 2 | from tempfile import mktemp |
|
3 | 3 | |
|
4 | 4 | import nose.tools as nt |
|
5 | import zmq | |
|
5 | 6 | |
|
6 | 7 | from IPython.zmq.parallel import client as clientmod |
|
7 | 8 | from IPython.zmq.parallel import error |
@@ -18,10 +19,9 class TestClient(ClusterTestCase): | |||
|
18 | 19 | self.assertEquals(len(self.client.ids), n+3) |
|
19 | 20 | self.assertTrue |
|
20 | 21 | |
|
21 | def test_segfault(self): | |
|
22 | """test graceful handling of engine death""" | |
|
22 | def test_segfault_task(self): | |
|
23 | """test graceful handling of engine death (balanced)""" | |
|
23 | 24 | self.add_engines(1) |
|
24 | eid = self.client.ids[-1] | |
|
25 | 25 | ar = self.client.apply(segfault, block=False) |
|
26 | 26 | self.assertRaisesRemote(error.EngineError, ar.get) |
|
27 | 27 | eid = ar.engine_id |
@@ -29,6 +29,17 class TestClient(ClusterTestCase): | |||
|
29 | 29 | time.sleep(.01) |
|
30 | 30 | self.client.spin() |
|
31 | 31 | |
|
32 | def test_segfault_mux(self): | |
|
33 | """test graceful handling of engine death (direct)""" | |
|
34 | self.add_engines(1) | |
|
35 | eid = self.client.ids[-1] | |
|
36 | ar = self.client[eid].apply_async(segfault) | |
|
37 | self.assertRaisesRemote(error.EngineError, ar.get) | |
|
38 | eid = ar.engine_id | |
|
39 | while eid in self.client.ids: | |
|
40 | time.sleep(.01) | |
|
41 | self.client.spin() | |
|
42 | ||
|
32 | 43 | def test_view_indexing(self): |
|
33 | 44 | """test index access for views""" |
|
34 | 45 | self.add_engines(2) |
@@ -91,13 +102,14 class TestClient(ClusterTestCase): | |||
|
91 | 102 | def test_push_pull(self): |
|
92 | 103 | """test pushing and pulling""" |
|
93 | 104 | data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'}) |
|
105 | t = self.client.ids[-1] | |
|
94 | 106 | self.add_engines(2) |
|
95 | 107 | push = self.client.push |
|
96 | 108 | pull = self.client.pull |
|
97 | 109 | self.client.block=True |
|
98 | 110 | nengines = len(self.client) |
|
99 |
push({'data':data}, targets= |
|
|
100 |
d = pull('data', targets= |
|
|
111 | push({'data':data}, targets=t) | |
|
112 | d = pull('data', targets=t) | |
|
101 | 113 | self.assertEquals(d, data) |
|
102 | 114 | push({'data':data}) |
|
103 | 115 | d = pull('data') |
@@ -119,15 +131,16 class TestClient(ClusterTestCase): | |||
|
119 | 131 | return 2.0*x |
|
120 | 132 | |
|
121 | 133 | self.add_engines(4) |
|
134 | t = self.client.ids[-1] | |
|
122 | 135 | self.client.block=True |
|
123 | 136 | push = self.client.push |
|
124 | 137 | pull = self.client.pull |
|
125 | 138 | execute = self.client.execute |
|
126 |
push({'testf':testf}, targets= |
|
|
127 |
r = pull('testf', targets= |
|
|
139 | push({'testf':testf}, targets=t) | |
|
140 | r = pull('testf', targets=t) | |
|
128 | 141 | self.assertEqual(r(1.0), testf(1.0)) |
|
129 |
execute('r = testf(10)', targets= |
|
|
130 |
r = pull('r', targets= |
|
|
142 | execute('r = testf(10)', targets=t) | |
|
143 | r = pull('r', targets=t) | |
|
131 | 144 | self.assertEquals(r, testf(10)) |
|
132 | 145 | ar = push({'testf':testf}, block=False) |
|
133 | 146 | ar.get() |
@@ -135,8 +148,8 class TestClient(ClusterTestCase): | |||
|
135 | 148 | rlist = ar.get() |
|
136 | 149 | for r in rlist: |
|
137 | 150 | self.assertEqual(r(1.0), testf(1.0)) |
|
138 |
execute("def g(x): return x*x", targets= |
|
|
139 |
r = pull(('testf','g'),targets= |
|
|
151 | execute("def g(x): return x*x", targets=t) | |
|
152 | r = pull(('testf','g'),targets=t) | |
|
140 | 153 | self.assertEquals((r[0](10),r[1](10)), (testf(10), 100)) |
|
141 | 154 | |
|
142 | 155 | def test_push_function_globals(self): |
@@ -173,7 +186,7 class TestClient(ClusterTestCase): | |||
|
173 | 186 | ids.remove(ids[-1]) |
|
174 | 187 | self.assertNotEquals(ids, self.client._ids) |
|
175 | 188 | |
|
176 |
def test_ |
|
|
189 | def test_run_newline(self): | |
|
177 | 190 | """test that run appends newline to files""" |
|
178 | 191 | tmpfile = mktemp() |
|
179 | 192 | with open(tmpfile, 'w') as f: |
@@ -184,4 +197,56 class TestClient(ClusterTestCase): | |||
|
184 | 197 | v.run(tmpfile, block=True) |
|
185 | 198 | self.assertEquals(v.apply_sync_bound(lambda : g()), 5) |
|
186 | 199 | |
|
187 | No newline at end of file | |
|
200 | def test_apply_tracked(self): | |
|
201 | """test tracking for apply""" | |
|
202 | # self.add_engines(1) | |
|
203 | t = self.client.ids[-1] | |
|
204 | self.client.block=False | |
|
205 | def echo(n=1024*1024, **kwargs): | |
|
206 | return self.client.apply(lambda x: x, args=('x'*n,), targets=t, **kwargs) | |
|
207 | ar = echo(1) | |
|
208 | self.assertTrue(ar._tracker is None) | |
|
209 | self.assertTrue(ar.sent) | |
|
210 | ar = echo(track=True) | |
|
211 | self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker)) | |
|
212 | self.assertEquals(ar.sent, ar._tracker.done) | |
|
213 | ar._tracker.wait() | |
|
214 | self.assertTrue(ar.sent) | |
|
215 | ||
|
216 | def test_push_tracked(self): | |
|
217 | t = self.client.ids[-1] | |
|
218 | ns = dict(x='x'*1024*1024) | |
|
219 | ar = self.client.push(ns, targets=t, block=False) | |
|
220 | self.assertTrue(ar._tracker is None) | |
|
221 | self.assertTrue(ar.sent) | |
|
222 | ||
|
223 | ar = self.client.push(ns, targets=t, block=False, track=True) | |
|
224 | self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker)) | |
|
225 | self.assertEquals(ar.sent, ar._tracker.done) | |
|
226 | ar._tracker.wait() | |
|
227 | self.assertTrue(ar.sent) | |
|
228 | ar.get() | |
|
229 | ||
|
230 | def test_scatter_tracked(self): | |
|
231 | t = self.client.ids | |
|
232 | x='x'*1024*1024 | |
|
233 | ar = self.client.scatter('x', x, targets=t, block=False) | |
|
234 | self.assertTrue(ar._tracker is None) | |
|
235 | self.assertTrue(ar.sent) | |
|
236 | ||
|
237 | ar = self.client.scatter('x', x, targets=t, block=False, track=True) | |
|
238 | self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker)) | |
|
239 | self.assertEquals(ar.sent, ar._tracker.done) | |
|
240 | ar._tracker.wait() | |
|
241 | self.assertTrue(ar.sent) | |
|
242 | ar.get() | |
|
243 | ||
|
244 | def test_remote_reference(self): | |
|
245 | v = self.client[-1] | |
|
246 | v['a'] = 123 | |
|
247 | ra = clientmod.Reference('a') | |
|
248 | b = v.apply_sync_bound(lambda x: x, ra) | |
|
249 | self.assertEquals(b, 123) | |
|
250 | self.assertRaisesRemote(NameError, v.apply_sync, lambda x: x, ra) | |
|
251 | ||
|
252 |
@@ -4,7 +4,7 import uuid | |||
|
4 | 4 | import zmq |
|
5 | 5 | |
|
6 | 6 | from zmq.tests import BaseZMQTestCase |
|
7 | ||
|
7 | from zmq.eventloop.zmqstream import ZMQStream | |
|
8 | 8 | # from IPython.zmq.tests import SessionTestCase |
|
9 | 9 | from IPython.zmq.parallel import streamsession as ss |
|
10 | 10 | |
@@ -31,7 +31,7 class TestSession(SessionTestCase): | |||
|
31 | 31 | |
|
32 | 32 | def test_args(self): |
|
33 | 33 | """initialization arguments for StreamSession""" |
|
34 |
s = s |
|
|
34 | s = self.session | |
|
35 | 35 | self.assertTrue(s.pack is ss.default_packer) |
|
36 | 36 | self.assertTrue(s.unpack is ss.default_unpacker) |
|
37 | 37 | self.assertEquals(s.username, os.environ.get('USER', 'username')) |
@@ -46,7 +46,24 class TestSession(SessionTestCase): | |||
|
46 | 46 | self.assertEquals(s.session, u) |
|
47 | 47 | self.assertEquals(s.username, 'carrot') |
|
48 | 48 | |
|
49 | ||
|
49 | def test_tracking(self): | |
|
50 | """test tracking messages""" | |
|
51 | a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) | |
|
52 | s = self.session | |
|
53 | stream = ZMQStream(a) | |
|
54 | msg = s.send(a, 'hello', track=False) | |
|
55 | self.assertTrue(msg['tracker'] is None) | |
|
56 | msg = s.send(a, 'hello', track=True) | |
|
57 | self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker)) | |
|
58 | M = zmq.Message(b'hi there', track=True) | |
|
59 | msg = s.send(a, 'hello', buffers=[M], track=True) | |
|
60 | t = msg['tracker'] | |
|
61 | self.assertTrue(isinstance(t, zmq.MessageTracker)) | |
|
62 | self.assertRaises(zmq.NotDone, t.wait, .1) | |
|
63 | del M | |
|
64 | t.wait(1) # this will raise | |
|
65 | ||
|
66 | ||
|
50 | 67 | # def test_rekey(self): |
|
51 | 68 | # """rekeying dict around json str keys""" |
|
52 | 69 | # d = {'0': uuid.uuid4(), 0:uuid.uuid4()} |
General Comments 0
You need to be logged in to leave comments.
Login now