Show More
@@ -34,13 +34,17 b' class AsyncResult(object):' | |||||
34 | """ |
|
34 | """ | |
35 |
|
35 | |||
36 | msg_ids = None |
|
36 | msg_ids = None | |
|
37 | _targets = None | |||
|
38 | _tracker = None | |||
37 |
|
39 | |||
38 | def __init__(self, client, msg_ids, fname='unknown'): |
|
40 | def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None): | |
39 | self._client = client |
|
41 | self._client = client | |
40 | if isinstance(msg_ids, basestring): |
|
42 | if isinstance(msg_ids, basestring): | |
41 | msg_ids = [msg_ids] |
|
43 | msg_ids = [msg_ids] | |
42 | self.msg_ids = msg_ids |
|
44 | self.msg_ids = msg_ids | |
43 | self._fname=fname |
|
45 | self._fname=fname | |
|
46 | self._targets = targets | |||
|
47 | self._tracker = tracker | |||
44 | self._ready = False |
|
48 | self._ready = False | |
45 | self._success = None |
|
49 | self._success = None | |
46 | self._single_result = len(msg_ids) == 1 |
|
50 | self._single_result = len(msg_ids) == 1 | |
@@ -169,6 +173,19 b' class AsyncResult(object):' | |||||
169 |
|
173 | |||
170 | def __dict__(self): |
|
174 | def __dict__(self): | |
171 | return self.get_dict(0) |
|
175 | return self.get_dict(0) | |
|
176 | ||||
|
177 | def abort(self): | |||
|
178 | """abort my tasks.""" | |||
|
179 | assert not self.ready(), "Can't abort, I am already done!" | |||
|
180 | return self.client.abort(self.msg_ids, targets=self._targets, block=True) | |||
|
181 | ||||
|
182 | @property | |||
|
183 | def sent(self): | |||
|
184 | """check whether my messages have been sent""" | |||
|
185 | if self._tracker is None: | |||
|
186 | return True | |||
|
187 | else: | |||
|
188 | return self._tracker.done | |||
172 |
|
189 | |||
173 | #------------------------------------- |
|
190 | #------------------------------------- | |
174 | # dict-access |
|
191 | # dict-access |
@@ -356,6 +356,9 b' class Client(HasTraits):' | |||||
356 | 'apply_reply' : self._handle_apply_reply} |
|
356 | 'apply_reply' : self._handle_apply_reply} | |
357 | self._connect(sshserver, ssh_kwargs) |
|
357 | self._connect(sshserver, ssh_kwargs) | |
358 |
|
358 | |||
|
359 | def __del__(self): | |||
|
360 | """cleanup sockets, but _not_ context.""" | |||
|
361 | self.close() | |||
359 |
|
362 | |||
360 | def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir): |
|
363 | def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir): | |
361 | if ipython_dir is None: |
|
364 | if ipython_dir is None: | |
@@ -387,7 +390,8 b' class Client(HasTraits):' | |||||
387 | return |
|
390 | return | |
388 | snames = filter(lambda n: n.endswith('socket'), dir(self)) |
|
391 | snames = filter(lambda n: n.endswith('socket'), dir(self)) | |
389 | for socket in map(lambda name: getattr(self, name), snames): |
|
392 | for socket in map(lambda name: getattr(self, name), snames): | |
390 | socket.close() |
|
393 | if isinstance(socket, zmq.Socket) and not socket.closed: | |
|
394 | socket.close() | |||
391 | self._closed = True |
|
395 | self._closed = True | |
392 |
|
396 | |||
393 | def _update_engines(self, engines): |
|
397 | def _update_engines(self, engines): | |
@@ -550,7 +554,6 b' class Client(HasTraits):' | |||||
550 | outstanding = self._outstanding_dict[uuid] |
|
554 | outstanding = self._outstanding_dict[uuid] | |
551 |
|
555 | |||
552 | for msg_id in list(outstanding): |
|
556 | for msg_id in list(outstanding): | |
553 | print msg_id |
|
|||
554 | if msg_id in self.results: |
|
557 | if msg_id in self.results: | |
555 | # we already |
|
558 | # we already | |
556 | continue |
|
559 | continue | |
@@ -796,7 +799,7 b' class Client(HasTraits):' | |||||
796 | if msg['content']['status'] != 'ok': |
|
799 | if msg['content']['status'] != 'ok': | |
797 | error = self._unwrap_exception(msg['content']) |
|
800 | error = self._unwrap_exception(msg['content']) | |
798 | if error: |
|
801 | if error: | |
799 |
re |
|
802 | raise error | |
800 |
|
803 | |||
801 |
|
804 | |||
802 | @spinfirst |
|
805 | @spinfirst | |
@@ -840,7 +843,7 b' class Client(HasTraits):' | |||||
840 | if msg['content']['status'] != 'ok': |
|
843 | if msg['content']['status'] != 'ok': | |
841 | error = self._unwrap_exception(msg['content']) |
|
844 | error = self._unwrap_exception(msg['content']) | |
842 | if error: |
|
845 | if error: | |
843 |
re |
|
846 | raise error | |
844 |
|
847 | |||
845 | @spinfirst |
|
848 | @spinfirst | |
846 | @defaultblock |
|
849 | @defaultblock | |
@@ -945,7 +948,8 b' class Client(HasTraits):' | |||||
945 | @defaultblock |
|
948 | @defaultblock | |
946 | def apply(self, f, args=None, kwargs=None, bound=True, block=None, |
|
949 | def apply(self, f, args=None, kwargs=None, bound=True, block=None, | |
947 | targets=None, balanced=None, |
|
950 | targets=None, balanced=None, | |
948 |
after=None, follow=None, timeout=None |
|
951 | after=None, follow=None, timeout=None, | |
|
952 | track=False): | |||
949 | """Call `f(*args, **kwargs)` on a remote engine(s), returning the result. |
|
953 | """Call `f(*args, **kwargs)` on a remote engine(s), returning the result. | |
950 |
|
954 | |||
951 | This is the central execution command for the client. |
|
955 | This is the central execution command for the client. | |
@@ -1003,6 +1007,9 b' class Client(HasTraits):' | |||||
1003 | Specify an amount of time (in seconds) for the scheduler to |
|
1007 | Specify an amount of time (in seconds) for the scheduler to | |
1004 | wait for dependencies to be met before failing with a |
|
1008 | wait for dependencies to be met before failing with a | |
1005 | DependencyTimeout. |
|
1009 | DependencyTimeout. | |
|
1010 | track : bool | |||
|
1011 | whether to track non-copying sends. | |||
|
1012 | [default False] | |||
1006 |
|
1013 | |||
1007 | after,follow,timeout only used if `balanced=True`. |
|
1014 | after,follow,timeout only used if `balanced=True`. | |
1008 |
|
1015 | |||
@@ -1044,7 +1051,7 b' class Client(HasTraits):' | |||||
1044 | if not isinstance(kwargs, dict): |
|
1051 | if not isinstance(kwargs, dict): | |
1045 | raise TypeError("kwargs must be dict, not %s"%type(kwargs)) |
|
1052 | raise TypeError("kwargs must be dict, not %s"%type(kwargs)) | |
1046 |
|
1053 | |||
1047 | options = dict(bound=bound, block=block, targets=targets) |
|
1054 | options = dict(bound=bound, block=block, targets=targets, track=track) | |
1048 |
|
1055 | |||
1049 | if balanced: |
|
1056 | if balanced: | |
1050 | return self._apply_balanced(f, args, kwargs, timeout=timeout, |
|
1057 | return self._apply_balanced(f, args, kwargs, timeout=timeout, | |
@@ -1057,7 +1064,7 b' class Client(HasTraits):' | |||||
1057 | return self._apply_direct(f, args, kwargs, **options) |
|
1064 | return self._apply_direct(f, args, kwargs, **options) | |
1058 |
|
1065 | |||
1059 | def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None, |
|
1066 | def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None, | |
1060 | after=None, follow=None, timeout=None): |
|
1067 | after=None, follow=None, timeout=None, track=None): | |
1061 | """call f(*args, **kwargs) remotely in a load-balanced manner. |
|
1068 | """call f(*args, **kwargs) remotely in a load-balanced manner. | |
1062 |
|
1069 | |||
1063 | This is a private method, see `apply` for details. |
|
1070 | This is a private method, see `apply` for details. | |
@@ -1065,7 +1072,7 b' class Client(HasTraits):' | |||||
1065 | """ |
|
1072 | """ | |
1066 |
|
1073 | |||
1067 | loc = locals() |
|
1074 | loc = locals() | |
1068 | for name in ('bound', 'block'): |
|
1075 | for name in ('bound', 'block', 'track'): | |
1069 | assert loc[name] is not None, "kwarg %r must be specified!"%name |
|
1076 | assert loc[name] is not None, "kwarg %r must be specified!"%name | |
1070 |
|
1077 | |||
1071 | if self._task_socket is None: |
|
1078 | if self._task_socket is None: | |
@@ -1101,13 +1108,13 b' class Client(HasTraits):' | |||||
1101 | content = dict(bound=bound) |
|
1108 | content = dict(bound=bound) | |
1102 |
|
1109 | |||
1103 | msg = self.session.send(self._task_socket, "apply_request", |
|
1110 | msg = self.session.send(self._task_socket, "apply_request", | |
1104 | content=content, buffers=bufs, subheader=subheader) |
|
1111 | content=content, buffers=bufs, subheader=subheader, track=track) | |
1105 | msg_id = msg['msg_id'] |
|
1112 | msg_id = msg['msg_id'] | |
1106 | self.outstanding.add(msg_id) |
|
1113 | self.outstanding.add(msg_id) | |
1107 | self.history.append(msg_id) |
|
1114 | self.history.append(msg_id) | |
1108 | self.metadata[msg_id]['submitted'] = datetime.now() |
|
1115 | self.metadata[msg_id]['submitted'] = datetime.now() | |
1109 |
|
1116 | tracker = None if track is False else msg['tracker'] | ||
1110 | ar = AsyncResult(self, [msg_id], fname=f.__name__) |
|
1117 | ar = AsyncResult(self, [msg_id], fname=f.__name__, targets=targets, tracker=tracker) | |
1111 | if block: |
|
1118 | if block: | |
1112 | try: |
|
1119 | try: | |
1113 | return ar.get() |
|
1120 | return ar.get() | |
@@ -1116,7 +1123,8 b' class Client(HasTraits):' | |||||
1116 | else: |
|
1123 | else: | |
1117 | return ar |
|
1124 | return ar | |
1118 |
|
1125 | |||
1119 |
def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None |
|
1126 | def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None, | |
|
1127 | track=None): | |||
1120 | """Then underlying method for applying functions to specific engines |
|
1128 | """Then underlying method for applying functions to specific engines | |
1121 | via the MUX queue. |
|
1129 | via the MUX queue. | |
1122 |
|
1130 | |||
@@ -1124,7 +1132,7 b' class Client(HasTraits):' | |||||
1124 | Not to be called directly! |
|
1132 | Not to be called directly! | |
1125 | """ |
|
1133 | """ | |
1126 | loc = locals() |
|
1134 | loc = locals() | |
1127 | for name in ('bound', 'block', 'targets'): |
|
1135 | for name in ('bound', 'block', 'targets', 'track'): | |
1128 | assert loc[name] is not None, "kwarg %r must be specified!"%name |
|
1136 | assert loc[name] is not None, "kwarg %r must be specified!"%name | |
1129 |
|
1137 | |||
1130 | idents,targets = self._build_targets(targets) |
|
1138 | idents,targets = self._build_targets(targets) | |
@@ -1134,15 +1142,22 b' class Client(HasTraits):' | |||||
1134 | bufs = util.pack_apply_message(f,args,kwargs) |
|
1142 | bufs = util.pack_apply_message(f,args,kwargs) | |
1135 |
|
1143 | |||
1136 | msg_ids = [] |
|
1144 | msg_ids = [] | |
|
1145 | trackers = [] | |||
1137 | for ident in idents: |
|
1146 | for ident in idents: | |
1138 | msg = self.session.send(self._mux_socket, "apply_request", |
|
1147 | msg = self.session.send(self._mux_socket, "apply_request", | |
1139 |
content=content, buffers=bufs, ident=ident, subheader=subheader |
|
1148 | content=content, buffers=bufs, ident=ident, subheader=subheader, | |
|
1149 | track=track) | |||
|
1150 | if track: | |||
|
1151 | trackers.append(msg['tracker']) | |||
1140 | msg_id = msg['msg_id'] |
|
1152 | msg_id = msg['msg_id'] | |
1141 | self.outstanding.add(msg_id) |
|
1153 | self.outstanding.add(msg_id) | |
1142 | self._outstanding_dict[ident].add(msg_id) |
|
1154 | self._outstanding_dict[ident].add(msg_id) | |
1143 | self.history.append(msg_id) |
|
1155 | self.history.append(msg_id) | |
1144 | msg_ids.append(msg_id) |
|
1156 | msg_ids.append(msg_id) | |
1145 | ar = AsyncResult(self, msg_ids, fname=f.__name__) |
|
1157 | ||
|
1158 | tracker = None if track is False else zmq.MessageTracker(*trackers) | |||
|
1159 | ar = AsyncResult(self, msg_ids, fname=f.__name__, targets=targets, tracker=tracker) | |||
|
1160 | ||||
1146 | if block: |
|
1161 | if block: | |
1147 | try: |
|
1162 | try: | |
1148 | return ar.get() |
|
1163 | return ar.get() | |
@@ -1230,11 +1245,11 b' class Client(HasTraits):' | |||||
1230 | #-------------------------------------------------------------------------- |
|
1245 | #-------------------------------------------------------------------------- | |
1231 |
|
1246 | |||
1232 | @defaultblock |
|
1247 | @defaultblock | |
1233 | def push(self, ns, targets='all', block=None): |
|
1248 | def push(self, ns, targets='all', block=None, track=False): | |
1234 | """Push the contents of `ns` into the namespace on `target`""" |
|
1249 | """Push the contents of `ns` into the namespace on `target`""" | |
1235 | if not isinstance(ns, dict): |
|
1250 | if not isinstance(ns, dict): | |
1236 | raise TypeError("Must be a dict, not %s"%type(ns)) |
|
1251 | raise TypeError("Must be a dict, not %s"%type(ns)) | |
1237 | result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False) |
|
1252 | result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False, track=track) | |
1238 | if not block: |
|
1253 | if not block: | |
1239 | return result |
|
1254 | return result | |
1240 |
|
1255 | |||
@@ -1251,7 +1266,7 b' class Client(HasTraits):' | |||||
1251 | return result |
|
1266 | return result | |
1252 |
|
1267 | |||
1253 | @defaultblock |
|
1268 | @defaultblock | |
1254 | def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None): |
|
1269 | def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None, track=False): | |
1255 | """ |
|
1270 | """ | |
1256 | Partition a Python sequence and send the partitions to a set of engines. |
|
1271 | Partition a Python sequence and send the partitions to a set of engines. | |
1257 | """ |
|
1272 | """ | |
@@ -1259,16 +1274,25 b' class Client(HasTraits):' | |||||
1259 | mapObject = Map.dists[dist]() |
|
1274 | mapObject = Map.dists[dist]() | |
1260 | nparts = len(targets) |
|
1275 | nparts = len(targets) | |
1261 | msg_ids = [] |
|
1276 | msg_ids = [] | |
|
1277 | trackers = [] | |||
1262 | for index, engineid in enumerate(targets): |
|
1278 | for index, engineid in enumerate(targets): | |
1263 | partition = mapObject.getPartition(seq, index, nparts) |
|
1279 | partition = mapObject.getPartition(seq, index, nparts) | |
1264 | if flatten and len(partition) == 1: |
|
1280 | if flatten and len(partition) == 1: | |
1265 | r = self.push({key: partition[0]}, targets=engineid, block=False) |
|
1281 | r = self.push({key: partition[0]}, targets=engineid, block=False, track=track) | |
1266 | else: |
|
1282 | else: | |
1267 | r = self.push({key: partition}, targets=engineid, block=False) |
|
1283 | r = self.push({key: partition}, targets=engineid, block=False, track=track) | |
1268 | msg_ids.extend(r.msg_ids) |
|
1284 | msg_ids.extend(r.msg_ids) | |
1269 | r = AsyncResult(self, msg_ids, fname='scatter') |
|
1285 | if track: | |
|
1286 | trackers.append(r._tracker) | |||
|
1287 | ||||
|
1288 | if track: | |||
|
1289 | tracker = zmq.MessageTracker(*trackers) | |||
|
1290 | else: | |||
|
1291 | tracker = None | |||
|
1292 | ||||
|
1293 | r = AsyncResult(self, msg_ids, fname='scatter', targets=targets, tracker=tracker) | |||
1270 | if block: |
|
1294 | if block: | |
1271 |
r. |
|
1295 | r.wait() | |
1272 | else: |
|
1296 | else: | |
1273 | return r |
|
1297 | return r | |
1274 |
|
1298 |
@@ -179,7 +179,7 b' class StreamSession(object):' | |||||
179 | return header.get('key', None) == self.key |
|
179 | return header.get('key', None) == self.key | |
180 |
|
180 | |||
181 |
|
181 | |||
182 | def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None): |
|
182 | def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None, track=False): | |
183 | """Build and send a message via stream or socket. |
|
183 | """Build and send a message via stream or socket. | |
184 |
|
184 | |||
185 | Parameters |
|
185 | Parameters | |
@@ -191,13 +191,34 b' class StreamSession(object):' | |||||
191 | Normally, msg_or_type will be a msg_type unless a message is being sent more |
|
191 | Normally, msg_or_type will be a msg_type unless a message is being sent more | |
192 | than once. |
|
192 | than once. | |
193 |
|
193 | |||
|
194 | content : dict or None | |||
|
195 | the content of the message (ignored if msg_or_type is a message) | |||
|
196 | buffers : list or None | |||
|
197 | the already-serialized buffers to be appended to the message | |||
|
198 | parent : Message or dict or None | |||
|
199 | the parent or parent header describing the parent of this message | |||
|
200 | subheader : dict or None | |||
|
201 | extra header keys for this message's header | |||
|
202 | ident : bytes or list of bytes | |||
|
203 | the zmq.IDENTITY routing path | |||
|
204 | track : bool | |||
|
205 | whether to track. Only for use with Sockets, because ZMQStream objects cannot track messages. | |||
|
206 | ||||
194 | Returns |
|
207 | Returns | |
195 | ------- |
|
208 | ------- | |
196 | (msg,sent) : tuple |
|
209 | msg : message dict | |
197 |
|
|
210 | the constructed message | |
198 | the nice wrapped dict-like object containing the headers |
|
211 | (msg,tracker) : (message dict, MessageTracker) | |
|
212 | if track=True, then a 2-tuple will be returned, the first element being the constructed | |||
|
213 | message, and the second being the MessageTracker | |||
199 |
|
214 | |||
200 | """ |
|
215 | """ | |
|
216 | ||||
|
217 | if not isinstance(stream, (zmq.Socket, ZMQStream)): | |||
|
218 | raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream)) | |||
|
219 | elif track and isinstance(stream, ZMQStream): | |||
|
220 | raise TypeError("ZMQStream cannot track messages") | |||
|
221 | ||||
201 | if isinstance(msg_or_type, (Message, dict)): |
|
222 | if isinstance(msg_or_type, (Message, dict)): | |
202 | # we got a Message, not a msg_type |
|
223 | # we got a Message, not a msg_type | |
203 | # don't build a new Message |
|
224 | # don't build a new Message | |
@@ -205,6 +226,7 b' class StreamSession(object):' | |||||
205 | content = msg['content'] |
|
226 | content = msg['content'] | |
206 | else: |
|
227 | else: | |
207 | msg = self.msg(msg_or_type, content, parent, subheader) |
|
228 | msg = self.msg(msg_or_type, content, parent, subheader) | |
|
229 | ||||
208 | buffers = [] if buffers is None else buffers |
|
230 | buffers = [] if buffers is None else buffers | |
209 | to_send = [] |
|
231 | to_send = [] | |
210 | if isinstance(ident, list): |
|
232 | if isinstance(ident, list): | |
@@ -222,7 +244,7 b' class StreamSession(object):' | |||||
222 | content = self.none |
|
244 | content = self.none | |
223 | elif isinstance(content, dict): |
|
245 | elif isinstance(content, dict): | |
224 | content = self.pack(content) |
|
246 | content = self.pack(content) | |
225 |
elif isinstance(content, |
|
247 | elif isinstance(content, bytes): | |
226 | # content is already packed, as in a relayed message |
|
248 | # content is already packed, as in a relayed message | |
227 | pass |
|
249 | pass | |
228 | else: |
|
250 | else: | |
@@ -231,16 +253,29 b' class StreamSession(object):' | |||||
231 | flag = 0 |
|
253 | flag = 0 | |
232 | if buffers: |
|
254 | if buffers: | |
233 | flag = zmq.SNDMORE |
|
255 | flag = zmq.SNDMORE | |
234 | stream.send_multipart(to_send, flag, copy=False) |
|
256 | _track = False | |
|
257 | else: | |||
|
258 | _track=track | |||
|
259 | if track: | |||
|
260 | tracker = stream.send_multipart(to_send, flag, copy=False, track=_track) | |||
|
261 | else: | |||
|
262 | tracker = stream.send_multipart(to_send, flag, copy=False) | |||
235 | for b in buffers[:-1]: |
|
263 | for b in buffers[:-1]: | |
236 | stream.send(b, flag, copy=False) |
|
264 | stream.send(b, flag, copy=False) | |
237 | if buffers: |
|
265 | if buffers: | |
238 | stream.send(buffers[-1], copy=False) |
|
266 | if track: | |
|
267 | tracker = stream.send(buffers[-1], copy=False, track=track) | |||
|
268 | else: | |||
|
269 | tracker = stream.send(buffers[-1], copy=False) | |||
|
270 | ||||
239 | # omsg = Message(msg) |
|
271 | # omsg = Message(msg) | |
240 | if self.debug: |
|
272 | if self.debug: | |
241 | pprint.pprint(msg) |
|
273 | pprint.pprint(msg) | |
242 | pprint.pprint(to_send) |
|
274 | pprint.pprint(to_send) | |
243 | pprint.pprint(buffers) |
|
275 | pprint.pprint(buffers) | |
|
276 | ||||
|
277 | msg['tracker'] = tracker | |||
|
278 | ||||
244 | return msg |
|
279 | return msg | |
245 |
|
280 | |||
246 | def send_raw(self, stream, msg, flags=0, copy=True, ident=None): |
|
281 | def send_raw(self, stream, msg, flags=0, copy=True, ident=None): | |
@@ -250,7 +285,7 b' class StreamSession(object):' | |||||
250 | ---------- |
|
285 | ---------- | |
251 | msg : list of sendable buffers""" |
|
286 | msg : list of sendable buffers""" | |
252 | to_send = [] |
|
287 | to_send = [] | |
253 |
if isinstance(ident, |
|
288 | if isinstance(ident, bytes): | |
254 | ident = [ident] |
|
289 | ident = [ident] | |
255 | if ident is not None: |
|
290 | if ident is not None: | |
256 | to_send.extend(ident) |
|
291 | to_send.extend(ident) |
@@ -1,24 +1,26 b'' | |||||
1 | """toplevel setup/teardown for parallel tests.""" |
|
1 | """toplevel setup/teardown for parallel tests.""" | |
2 |
|
2 | |||
|
3 | import tempfile | |||
3 | import time |
|
4 | import time | |
4 | from subprocess import Popen, PIPE |
|
5 | from subprocess import Popen, PIPE, STDOUT | |
5 |
|
6 | |||
6 | from IPython.zmq.parallel.ipcluster import launch_process |
|
7 | from IPython.zmq.parallel.ipcluster import launch_process | |
7 | from IPython.zmq.parallel.entry_point import select_random_ports |
|
8 | from IPython.zmq.parallel.entry_point import select_random_ports | |
8 |
|
9 | |||
9 | processes = [] |
|
10 | processes = [] | |
|
11 | blackhole = tempfile.TemporaryFile() | |||
10 |
|
12 | |||
11 | # nose setup/teardown |
|
13 | # nose setup/teardown | |
12 |
|
14 | |||
13 | def setup(): |
|
15 | def setup(): | |
14 |
cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout= |
|
16 | cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=blackhole, stderr=STDOUT) | |
15 | processes.append(cp) |
|
17 | processes.append(cp) | |
16 | time.sleep(.5) |
|
18 | time.sleep(.5) | |
17 | add_engine() |
|
19 | add_engine() | |
18 |
time.sleep( |
|
20 | time.sleep(2) | |
19 |
|
21 | |||
20 | def add_engine(profile='iptest'): |
|
22 | def add_engine(profile='iptest'): | |
21 |
ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout= |
|
23 | ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=blackhole, stderr=STDOUT) | |
22 | # ep.start() |
|
24 | # ep.start() | |
23 | processes.append(ep) |
|
25 | processes.append(ep) | |
24 | return ep |
|
26 | return ep |
@@ -88,7 +88,9 b' class ClusterTestCase(BaseZMQTestCase):' | |||||
88 | self.base_engine_count=len(self.client.ids) |
|
88 | self.base_engine_count=len(self.client.ids) | |
89 | self.engines=[] |
|
89 | self.engines=[] | |
90 |
|
90 | |||
91 |
|
|
91 | def tearDown(self): | |
|
92 | self.client.close() | |||
|
93 | BaseZMQTestCase.tearDown(self) | |||
92 | # [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ] |
|
94 | # [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ] | |
93 | # [ e.wait() for e in self.engines ] |
|
95 | # [ e.wait() for e in self.engines ] | |
94 | # while len(self.client.ids) > self.base_engine_count: |
|
96 | # while len(self.client.ids) > self.base_engine_count: |
@@ -2,6 +2,7 b' import time' | |||||
2 | from tempfile import mktemp |
|
2 | from tempfile import mktemp | |
3 |
|
3 | |||
4 | import nose.tools as nt |
|
4 | import nose.tools as nt | |
|
5 | import zmq | |||
5 |
|
6 | |||
6 | from IPython.zmq.parallel import client as clientmod |
|
7 | from IPython.zmq.parallel import client as clientmod | |
7 | from IPython.zmq.parallel import error |
|
8 | from IPython.zmq.parallel import error | |
@@ -18,10 +19,9 b' class TestClient(ClusterTestCase):' | |||||
18 | self.assertEquals(len(self.client.ids), n+3) |
|
19 | self.assertEquals(len(self.client.ids), n+3) | |
19 | self.assertTrue |
|
20 | self.assertTrue | |
20 |
|
21 | |||
21 | def test_segfault(self): |
|
22 | def test_segfault_task(self): | |
22 | """test graceful handling of engine death""" |
|
23 | """test graceful handling of engine death (balanced)""" | |
23 | self.add_engines(1) |
|
24 | self.add_engines(1) | |
24 | eid = self.client.ids[-1] |
|
|||
25 | ar = self.client.apply(segfault, block=False) |
|
25 | ar = self.client.apply(segfault, block=False) | |
26 | self.assertRaisesRemote(error.EngineError, ar.get) |
|
26 | self.assertRaisesRemote(error.EngineError, ar.get) | |
27 | eid = ar.engine_id |
|
27 | eid = ar.engine_id | |
@@ -29,6 +29,17 b' class TestClient(ClusterTestCase):' | |||||
29 | time.sleep(.01) |
|
29 | time.sleep(.01) | |
30 | self.client.spin() |
|
30 | self.client.spin() | |
31 |
|
31 | |||
|
32 | def test_segfault_mux(self): | |||
|
33 | """test graceful handling of engine death (direct)""" | |||
|
34 | self.add_engines(1) | |||
|
35 | eid = self.client.ids[-1] | |||
|
36 | ar = self.client[eid].apply_async(segfault) | |||
|
37 | self.assertRaisesRemote(error.EngineError, ar.get) | |||
|
38 | eid = ar.engine_id | |||
|
39 | while eid in self.client.ids: | |||
|
40 | time.sleep(.01) | |||
|
41 | self.client.spin() | |||
|
42 | ||||
32 | def test_view_indexing(self): |
|
43 | def test_view_indexing(self): | |
33 | """test index access for views""" |
|
44 | """test index access for views""" | |
34 | self.add_engines(2) |
|
45 | self.add_engines(2) | |
@@ -91,13 +102,14 b' class TestClient(ClusterTestCase):' | |||||
91 | def test_push_pull(self): |
|
102 | def test_push_pull(self): | |
92 | """test pushing and pulling""" |
|
103 | """test pushing and pulling""" | |
93 | data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'}) |
|
104 | data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'}) | |
|
105 | t = self.client.ids[-1] | |||
94 | self.add_engines(2) |
|
106 | self.add_engines(2) | |
95 | push = self.client.push |
|
107 | push = self.client.push | |
96 | pull = self.client.pull |
|
108 | pull = self.client.pull | |
97 | self.client.block=True |
|
109 | self.client.block=True | |
98 | nengines = len(self.client) |
|
110 | nengines = len(self.client) | |
99 |
push({'data':data}, targets= |
|
111 | push({'data':data}, targets=t) | |
100 |
d = pull('data', targets= |
|
112 | d = pull('data', targets=t) | |
101 | self.assertEquals(d, data) |
|
113 | self.assertEquals(d, data) | |
102 | push({'data':data}) |
|
114 | push({'data':data}) | |
103 | d = pull('data') |
|
115 | d = pull('data') | |
@@ -119,15 +131,16 b' class TestClient(ClusterTestCase):' | |||||
119 | return 2.0*x |
|
131 | return 2.0*x | |
120 |
|
132 | |||
121 | self.add_engines(4) |
|
133 | self.add_engines(4) | |
|
134 | t = self.client.ids[-1] | |||
122 | self.client.block=True |
|
135 | self.client.block=True | |
123 | push = self.client.push |
|
136 | push = self.client.push | |
124 | pull = self.client.pull |
|
137 | pull = self.client.pull | |
125 | execute = self.client.execute |
|
138 | execute = self.client.execute | |
126 |
push({'testf':testf}, targets= |
|
139 | push({'testf':testf}, targets=t) | |
127 |
r = pull('testf', targets= |
|
140 | r = pull('testf', targets=t) | |
128 | self.assertEqual(r(1.0), testf(1.0)) |
|
141 | self.assertEqual(r(1.0), testf(1.0)) | |
129 |
execute('r = testf(10)', targets= |
|
142 | execute('r = testf(10)', targets=t) | |
130 |
r = pull('r', targets= |
|
143 | r = pull('r', targets=t) | |
131 | self.assertEquals(r, testf(10)) |
|
144 | self.assertEquals(r, testf(10)) | |
132 | ar = push({'testf':testf}, block=False) |
|
145 | ar = push({'testf':testf}, block=False) | |
133 | ar.get() |
|
146 | ar.get() | |
@@ -135,8 +148,8 b' class TestClient(ClusterTestCase):' | |||||
135 | rlist = ar.get() |
|
148 | rlist = ar.get() | |
136 | for r in rlist: |
|
149 | for r in rlist: | |
137 | self.assertEqual(r(1.0), testf(1.0)) |
|
150 | self.assertEqual(r(1.0), testf(1.0)) | |
138 |
execute("def g(x): return x*x", targets= |
|
151 | execute("def g(x): return x*x", targets=t) | |
139 |
r = pull(('testf','g'),targets= |
|
152 | r = pull(('testf','g'),targets=t) | |
140 | self.assertEquals((r[0](10),r[1](10)), (testf(10), 100)) |
|
153 | self.assertEquals((r[0](10),r[1](10)), (testf(10), 100)) | |
141 |
|
154 | |||
142 | def test_push_function_globals(self): |
|
155 | def test_push_function_globals(self): | |
@@ -173,7 +186,7 b' class TestClient(ClusterTestCase):' | |||||
173 | ids.remove(ids[-1]) |
|
186 | ids.remove(ids[-1]) | |
174 | self.assertNotEquals(ids, self.client._ids) |
|
187 | self.assertNotEquals(ids, self.client._ids) | |
175 |
|
188 | |||
176 |
def test_ |
|
189 | def test_run_newline(self): | |
177 | """test that run appends newline to files""" |
|
190 | """test that run appends newline to files""" | |
178 | tmpfile = mktemp() |
|
191 | tmpfile = mktemp() | |
179 | with open(tmpfile, 'w') as f: |
|
192 | with open(tmpfile, 'w') as f: | |
@@ -184,4 +197,56 b' class TestClient(ClusterTestCase):' | |||||
184 | v.run(tmpfile, block=True) |
|
197 | v.run(tmpfile, block=True) | |
185 | self.assertEquals(v.apply_sync_bound(lambda : g()), 5) |
|
198 | self.assertEquals(v.apply_sync_bound(lambda : g()), 5) | |
186 |
|
199 | |||
187 | No newline at end of file |
|
200 | def test_apply_tracked(self): | |
|
201 | """test tracking for apply""" | |||
|
202 | # self.add_engines(1) | |||
|
203 | t = self.client.ids[-1] | |||
|
204 | self.client.block=False | |||
|
205 | def echo(n=1024*1024, **kwargs): | |||
|
206 | return self.client.apply(lambda x: x, args=('x'*n,), targets=t, **kwargs) | |||
|
207 | ar = echo(1) | |||
|
208 | self.assertTrue(ar._tracker is None) | |||
|
209 | self.assertTrue(ar.sent) | |||
|
210 | ar = echo(track=True) | |||
|
211 | self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker)) | |||
|
212 | self.assertEquals(ar.sent, ar._tracker.done) | |||
|
213 | ar._tracker.wait() | |||
|
214 | self.assertTrue(ar.sent) | |||
|
215 | ||||
|
216 | def test_push_tracked(self): | |||
|
217 | t = self.client.ids[-1] | |||
|
218 | ns = dict(x='x'*1024*1024) | |||
|
219 | ar = self.client.push(ns, targets=t, block=False) | |||
|
220 | self.assertTrue(ar._tracker is None) | |||
|
221 | self.assertTrue(ar.sent) | |||
|
222 | ||||
|
223 | ar = self.client.push(ns, targets=t, block=False, track=True) | |||
|
224 | self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker)) | |||
|
225 | self.assertEquals(ar.sent, ar._tracker.done) | |||
|
226 | ar._tracker.wait() | |||
|
227 | self.assertTrue(ar.sent) | |||
|
228 | ar.get() | |||
|
229 | ||||
|
230 | def test_scatter_tracked(self): | |||
|
231 | t = self.client.ids | |||
|
232 | x='x'*1024*1024 | |||
|
233 | ar = self.client.scatter('x', x, targets=t, block=False) | |||
|
234 | self.assertTrue(ar._tracker is None) | |||
|
235 | self.assertTrue(ar.sent) | |||
|
236 | ||||
|
237 | ar = self.client.scatter('x', x, targets=t, block=False, track=True) | |||
|
238 | self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker)) | |||
|
239 | self.assertEquals(ar.sent, ar._tracker.done) | |||
|
240 | ar._tracker.wait() | |||
|
241 | self.assertTrue(ar.sent) | |||
|
242 | ar.get() | |||
|
243 | ||||
|
244 | def test_remote_reference(self): | |||
|
245 | v = self.client[-1] | |||
|
246 | v['a'] = 123 | |||
|
247 | ra = clientmod.Reference('a') | |||
|
248 | b = v.apply_sync_bound(lambda x: x, ra) | |||
|
249 | self.assertEquals(b, 123) | |||
|
250 | self.assertRaisesRemote(NameError, v.apply_sync, lambda x: x, ra) | |||
|
251 | ||||
|
252 |
@@ -4,7 +4,7 b' import uuid' | |||||
4 | import zmq |
|
4 | import zmq | |
5 |
|
5 | |||
6 | from zmq.tests import BaseZMQTestCase |
|
6 | from zmq.tests import BaseZMQTestCase | |
7 |
|
7 | from zmq.eventloop.zmqstream import ZMQStream | ||
8 | # from IPython.zmq.tests import SessionTestCase |
|
8 | # from IPython.zmq.tests import SessionTestCase | |
9 | from IPython.zmq.parallel import streamsession as ss |
|
9 | from IPython.zmq.parallel import streamsession as ss | |
10 |
|
10 | |||
@@ -31,7 +31,7 b' class TestSession(SessionTestCase):' | |||||
31 |
|
31 | |||
32 | def test_args(self): |
|
32 | def test_args(self): | |
33 | """initialization arguments for StreamSession""" |
|
33 | """initialization arguments for StreamSession""" | |
34 |
s = s |
|
34 | s = self.session | |
35 | self.assertTrue(s.pack is ss.default_packer) |
|
35 | self.assertTrue(s.pack is ss.default_packer) | |
36 | self.assertTrue(s.unpack is ss.default_unpacker) |
|
36 | self.assertTrue(s.unpack is ss.default_unpacker) | |
37 | self.assertEquals(s.username, os.environ.get('USER', 'username')) |
|
37 | self.assertEquals(s.username, os.environ.get('USER', 'username')) | |
@@ -46,7 +46,24 b' class TestSession(SessionTestCase):' | |||||
46 | self.assertEquals(s.session, u) |
|
46 | self.assertEquals(s.session, u) | |
47 | self.assertEquals(s.username, 'carrot') |
|
47 | self.assertEquals(s.username, 'carrot') | |
48 |
|
48 | |||
49 |
|
49 | def test_tracking(self): | ||
|
50 | """test tracking messages""" | |||
|
51 | a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) | |||
|
52 | s = self.session | |||
|
53 | stream = ZMQStream(a) | |||
|
54 | msg = s.send(a, 'hello', track=False) | |||
|
55 | self.assertTrue(msg['tracker'] is None) | |||
|
56 | msg = s.send(a, 'hello', track=True) | |||
|
57 | self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker)) | |||
|
58 | M = zmq.Message(b'hi there', track=True) | |||
|
59 | msg = s.send(a, 'hello', buffers=[M], track=True) | |||
|
60 | t = msg['tracker'] | |||
|
61 | self.assertTrue(isinstance(t, zmq.MessageTracker)) | |||
|
62 | self.assertRaises(zmq.NotDone, t.wait, .1) | |||
|
63 | del M | |||
|
64 | t.wait(1) # this will raise | |||
|
65 | ||||
|
66 | ||||
50 | # def test_rekey(self): |
|
67 | # def test_rekey(self): | |
51 | # """rekeying dict around json str keys""" |
|
68 | # """rekeying dict around json str keys""" | |
52 | # d = {'0': uuid.uuid4(), 0:uuid.uuid4()} |
|
69 | # d = {'0': uuid.uuid4(), 0:uuid.uuid4()} |
General Comments 0
You need to be logged in to leave comments.
Login now