add message tracking to client, add/improve tests
MinRK
@@ -1,305 +1,322 @@
1 1 """AsyncResult objects for the client"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import time
14 14
15 15 from IPython.external.decorator import decorator
16 16 from . import error
17 17
18 18 #-----------------------------------------------------------------------------
19 19 # Classes
20 20 #-----------------------------------------------------------------------------
21 21
22 22 @decorator
23 23 def check_ready(f, self, *args, **kwargs):
24 24 """Call spin() to sync state prior to calling the method."""
25 25 self.wait(0)
26 26 if not self._ready:
27 27 raise error.TimeoutError("result not ready")
28 28 return f(self, *args, **kwargs)
29 29
30 30 class AsyncResult(object):
31 31 """Class for representing results of non-blocking calls.
32 32
33 33 Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
34 34 """
35 35
36 36 msg_ids = None
37 _targets = None
38 _tracker = None
37 39
38 def __init__(self, client, msg_ids, fname='unknown'):
40 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None):
39 41 self._client = client
40 42 if isinstance(msg_ids, basestring):
41 43 msg_ids = [msg_ids]
42 44 self.msg_ids = msg_ids
43 45 self._fname=fname
46 self._targets = targets
47 self._tracker = tracker
44 48 self._ready = False
45 49 self._success = None
46 50 self._single_result = len(msg_ids) == 1
47 51
48 52 def __repr__(self):
49 53 if self._ready:
50 54 return "<%s: finished>"%(self.__class__.__name__)
51 55 else:
52 56 return "<%s: %s>"%(self.__class__.__name__,self._fname)
53 57
54 58
55 59 def _reconstruct_result(self, res):
56 60 """Reconstruct our result from actual result list (always a list)
57 61
58 62 Override me in subclasses for turning a list of results
59 63 into the expected form.
60 64 """
61 65 if self._single_result:
62 66 return res[0]
63 67 else:
64 68 return res
65 69
66 70 def get(self, timeout=-1):
67 71 """Return the result when it arrives.
68 72
69 73 If `timeout` is not ``None`` and the result does not arrive within
70 74 `timeout` seconds then ``TimeoutError`` is raised. If the
71 75 remote call raised an exception then that exception will be reraised
72 76 by get() inside a `RemoteError`.
73 77 """
74 78 if not self.ready():
75 79 self.wait(timeout)
76 80
77 81 if self._ready:
78 82 if self._success:
79 83 return self._result
80 84 else:
81 85 raise self._exception
82 86 else:
83 87 raise error.TimeoutError("Result not ready.")
84 88
85 89 def ready(self):
86 90 """Return whether the call has completed."""
87 91 if not self._ready:
88 92 self.wait(0)
89 93 return self._ready
90 94
91 95 def wait(self, timeout=-1):
92 96 """Wait until the result is available or until `timeout` seconds pass.
93 97
94 98 This method always returns None.
95 99 """
96 100 if self._ready:
97 101 return
98 102 self._ready = self._client.barrier(self.msg_ids, timeout)
99 103 if self._ready:
100 104 try:
101 105 results = map(self._client.results.get, self.msg_ids)
102 106 self._result = results
103 107 if self._single_result:
104 108 r = results[0]
105 109 if isinstance(r, Exception):
106 110 raise r
107 111 else:
108 112 results = error.collect_exceptions(results, self._fname)
109 113 self._result = self._reconstruct_result(results)
110 114 except Exception, e:
111 115 self._exception = e
112 116 self._success = False
113 117 else:
114 118 self._success = True
115 119 finally:
116 120 self._metadata = map(self._client.metadata.get, self.msg_ids)
117 121
118 122
119 123 def successful(self):
120 124 """Return whether the call completed without raising an exception.
121 125
122 126 Will raise ``AssertionError`` if the result is not ready.
123 127 """
124 128 assert self.ready()
125 129 return self._success
126 130
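For orientation, a minimal sketch of the polling surface above, assuming a connected Client `rc` from the companion client.py (names and values illustrative):

    ar = rc.apply(lambda x: 2*x, (21,), targets=0, balanced=False, block=False)
    ar.wait(1)             # poll for up to 1s; always returns None
    if ar.ready():         # has the reply arrived?
        ar.successful()    # True unless the remote call raised
        ar.get()           # 42, or reraises the remote error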
127 131 #----------------------------------------------------------------
128 132 # Extra methods not in mp.pool.AsyncResult
129 133 #----------------------------------------------------------------
130 134
131 135 def get_dict(self, timeout=-1):
132 136 """Get the results as a dict, keyed by engine_id.
133 137
134 138 timeout behavior is described in `get()`.
135 139 """
136 140
137 141 results = self.get(timeout)
138 142 engine_ids = [ md['engine_id'] for md in self._metadata ]
139 143 bycount = sorted(engine_ids, key=lambda k: engine_ids.count(k))
140 144 maxcount = bycount.count(bycount[-1])
141 145 if maxcount > 1:
142 146 raise ValueError("Cannot build dict, %i jobs ran on engine #%i"%(
143 147 maxcount, bycount[-1]))
144 148
145 149 return dict(zip(engine_ids,results))
146 150
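A sketch of get_dict, assuming one job per engine (two engines here, pids illustrative):

    import os
    ar = rc.apply(os.getpid, targets=[0, 1], balanced=False, block=False)
    ar.get_dict()          # {0: 1234, 1: 1235}, keyed by engine_id
    # raises ValueError if any engine ran more than one of the jobs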
147 151 @property
148 152 @check_ready
149 153 def result(self):
150 154 """result property wrapper for `get(timeout=0)`."""
151 155 return self._result
152 156
153 157 # abbreviated alias:
154 158 r = result
155 159
156 160 @property
157 161 @check_ready
158 162 def metadata(self):
159 163 """property for accessing execution metadata."""
160 164 if self._single_result:
161 165 return self._metadata[0]
162 166 else:
163 167 return self._metadata
164 168
165 169 @property
166 170 def result_dict(self):
167 171 """result property as a dict."""
168 172 return self.get_dict(0)
169 173
170 174 def __dict__(self):
171 175 return self.get_dict(0)
176
177 def abort(self):
178 """abort my tasks."""
179 assert not self.ready(), "Can't abort, I am already done!"
180 return self._client.abort(self.msg_ids, targets=self._targets, block=True)
181
182 @property
183 def sent(self):
184 """check whether my messages have been sent"""
185 if self._tracker is None:
186 return True
187 else:
188 return self._tracker.done
172 189
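The new _tracker/sent pair exists for ZMQ's non-copying sends: with track=True threaded through Client.apply (see client.py below), the zmq.MessageTracker reports when ZMQ has finished with the outgoing buffer, so it is safe to reuse it. A hedged sketch, assuming the session sends the payload without copying:

    import time
    buf = bytearray(10**7)                   # large, mutable payload
    ar = rc.apply(len, (buf,), targets=0, balanced=False,
                  block=False, track=True)
    while not ar.sent:                       # proxies tracker.done
        time.sleep(0.01)
    buf[0] = 1                               # safe to mutate only now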
173 190 #-------------------------------------
174 191 # dict-access
175 192 #-------------------------------------
176 193
177 194 @check_ready
178 195 def __getitem__(self, key):
179 196 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
180 197 """
181 198 if isinstance(key, int):
182 199 return error.collect_exceptions([self._result[key]], self._fname)[0]
183 200 elif isinstance(key, slice):
184 201 return error.collect_exceptions(self._result[key], self._fname)
185 202 elif isinstance(key, basestring):
186 203 values = [ md[key] for md in self._metadata ]
187 204 if self._single_result:
188 205 return values[0]
189 206 else:
190 207 return values
191 208 else:
192 209 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
193 210
194 211 @check_ready
195 212 def __getattr__(self, key):
196 213 """getattr maps to getitem for convenient attr access to metadata."""
197 214 if key not in self._metadata[0].keys():
198 215 raise AttributeError("%r object has no attribute %r"%(
199 216 self.__class__.__name__, key))
200 217 return self.__getitem__(key)
201 218
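Once ready, indexing is overloaded: int/slice index into the result list, str keys into the per-message metadata, and attribute access forwards to the same lookup. Sketch:

    import time
    ar = rc.apply(time.time, targets=[0, 1], balanced=False, block=False)
    ar.get()
    ar[0]               # engine 0's result
    ar['engine_id']     # [0, 1], one metadata value per message
    ar.engine_id        # same values, via __getattr__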
202 219 # asynchronous iterator:
203 220 def __iter__(self):
204 221 if self._single_result:
205 222 raise TypeError("AsyncResults with a single result are not iterable.")
206 223 try:
207 224 rlist = self.get(0)
208 225 except error.TimeoutError:
209 226 # wait for each result individually
210 227 for msg_id in self.msg_ids:
211 228 ar = AsyncResult(self._client, msg_id, self._fname)
212 229 yield ar.get()
213 230 else:
214 231 # already done
215 232 for r in rlist:
216 233 yield r
217 234
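The iterator yields results in msg_id order as they complete, so a multi-target call can be consumed incrementally. Sketch:

    ar = rc.apply(time.time, targets=[0, 1, 2], balanced=False, block=False)
    for r in ar:        # blocks per message until each result lands
        print r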
218 235
219 236
220 237 class AsyncMapResult(AsyncResult):
221 238 """Class for representing results of non-blocking gathers.
222 239
223 240 This will properly reconstruct the gather.
224 241 """
225 242
226 243 def __init__(self, client, msg_ids, mapObject, fname=''):
227 244 AsyncResult.__init__(self, client, msg_ids, fname=fname)
228 245 self._mapObject = mapObject
229 246 self._single_result = False
230 247
231 248 def _reconstruct_result(self, res):
232 249 """Perform the gather on the actual results."""
233 250 return self._mapObject.joinPartitions(res)
234 251
235 252 # asynchronous iterator:
236 253 def __iter__(self):
237 254 try:
238 255 rlist = self.get(0)
239 256 except error.TimeoutError:
240 257 # wait for each result individually
241 258 for msg_id in self.msg_ids:
242 259 ar = AsyncResult(self._client, msg_id, self._fname)
243 260 rlist = ar.get()
244 261 try:
245 262 for r in rlist:
246 263 yield r
247 264 except TypeError:
248 265 # flattened, not a list
249 266 # this could get broken by flattened data that returns iterables
250 267 # but most calls to map do not expose the `flatten` argument
251 268 yield rlist
252 269 else:
253 270 # already done
254 271 for r in rlist:
255 272 yield r
256 273
257 274
258 275 class AsyncHubResult(AsyncResult):
259 276 """Class to wrap pending results that must be requested from the Hub.
260 277
261 278 Note that waiting/polling on these objects requires polling the Hub over the network,
262 279 so use `AsyncHubResult.wait()` sparingly.
263 280 """
264 281
265 282 def wait(self, timeout=-1):
266 283 """wait for result to complete."""
267 284 start = time.time()
268 285 if self._ready:
269 286 return
270 287 local_ids = filter(lambda msg_id: msg_id in self._client.outstanding, self.msg_ids)
271 288 local_ready = self._client.barrier(local_ids, timeout)
272 289 if local_ready:
273 290 remote_ids = filter(lambda msg_id: msg_id not in self._client.results, self.msg_ids)
274 291 if not remote_ids:
275 292 self._ready = True
276 293 else:
277 294 rdict = self._client.result_status(remote_ids, status_only=False)
278 295 pending = rdict['pending']
279 296 while pending and (timeout < 0 or time.time() < start+timeout):
280 297 rdict = self._client.result_status(remote_ids, status_only=False)
281 298 pending = rdict['pending']
282 299 if pending:
283 300 time.sleep(0.1)
284 301 if not pending:
285 302 self._ready = True
286 303 if self._ready:
287 304 try:
288 305 results = map(self._client.results.get, self.msg_ids)
289 306 self._result = results
290 307 if self._single_result:
291 308 r = results[0]
292 309 if isinstance(r, Exception):
293 310 raise r
294 311 else:
295 312 results = error.collect_exceptions(results, self._fname)
296 313 self._result = self._reconstruct_result(results)
297 314 except Exception, e:
298 315 self._exception = e
299 316 self._success = False
300 317 else:
301 318 self._success = True
302 319 finally:
303 320 self._metadata = map(self._client.metadata.get, self.msg_ids)
304 321
305 322 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult']
\ No newline at end of file
@@ -1,1545 +1,1569 @@
1 1 """A semi-synchronous Client for the ZMQ controller"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import os
14 14 import json
15 15 import time
16 16 import warnings
17 17 from datetime import datetime
18 18 from getpass import getpass
19 19 from pprint import pprint
20 20
21 21 pjoin = os.path.join
22 22
23 23 import zmq
24 24 # from zmq.eventloop import ioloop, zmqstream
25 25
26 26 from IPython.utils.path import get_ipython_dir
27 27 from IPython.utils.pickleutil import Reference
28 28 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
29 29 Dict, List, Bool, Str, Set)
30 30 from IPython.external.decorator import decorator
31 31 from IPython.external.ssh import tunnel
32 32
33 33 from . import error
34 34 from . import map as Map
35 35 from . import util
36 36 from . import streamsession as ss
37 37 from .asyncresult import AsyncResult, AsyncMapResult, AsyncHubResult
38 38 from .clusterdir import ClusterDir, ClusterDirError
39 39 from .dependency import Dependency, depend, require, dependent
40 40 from .remotefunction import remote, parallel, ParallelFunction, RemoteFunction
41 41 from .util import ReverseDict, validate_url, disambiguate_url
42 42 from .view import DirectView, LoadBalancedView
43 43
44 44 #--------------------------------------------------------------------------
45 45 # helpers for implementing old MEC API via client.apply
46 46 #--------------------------------------------------------------------------
47 47
48 48 def _push(ns):
49 49 """helper method for implementing `client.push` via `client.apply`"""
50 50 globals().update(ns)
51 51
52 52 def _pull(keys):
53 53 """helper method for implementing `client.pull` via `client.apply`"""
54 54 g = globals()
55 55 if isinstance(keys, (list,tuple, set)):
56 56 for key in keys:
57 57 if not g.has_key(key):
58 58 raise NameError("name '%s' is not defined"%key)
59 59 return map(g.get, keys)
60 60 else:
61 61 if not g.has_key(keys):
62 62 raise NameError("name '%s' is not defined"%keys)
63 63 return g.get(keys)
64 64
65 65 def _clear():
66 66 """helper method for implementing `client.clear` via `client.apply`"""
67 67 globals().clear()
68 68
69 69 def _execute(code):
70 70 """helper method for implementing `client.execute` via `client.apply`"""
71 71 exec code in globals()
72 72
73 73
74 74 #--------------------------------------------------------------------------
75 75 # Decorators for Client methods
76 76 #--------------------------------------------------------------------------
77 77
78 78 @decorator
79 79 def spinfirst(f, self, *args, **kwargs):
80 80 """Call spin() to sync state prior to calling the method."""
81 81 self.spin()
82 82 return f(self, *args, **kwargs)
83 83
84 84 @decorator
85 85 def defaultblock(f, self, *args, **kwargs):
86 86 """Default to self.block; preserve self.block."""
87 87 block = kwargs.get('block',None)
88 88 block = self.block if block is None else block
89 89 saveblock = self.block
90 90 self.block = block
91 91 try:
92 92 ret = f(self, *args, **kwargs)
93 93 finally:
94 94 self.block = saveblock
95 95 return ret
96 96
97 97
98 98 #--------------------------------------------------------------------------
99 99 # Classes
100 100 #--------------------------------------------------------------------------
101 101
102 102 class Metadata(dict):
103 103 """Subclass of dict for initializing metadata values.
104 104
105 105 Attribute access works on keys.
106 106
107 107 These objects have a strict set of keys - errors will raise if you try
108 108 to add new keys.
109 109 """
110 110 def __init__(self, *args, **kwargs):
111 111 dict.__init__(self)
112 112 md = {'msg_id' : None,
113 113 'submitted' : None,
114 114 'started' : None,
115 115 'completed' : None,
116 116 'received' : None,
117 117 'engine_uuid' : None,
118 118 'engine_id' : None,
119 119 'follow' : None,
120 120 'after' : None,
121 121 'status' : None,
122 122
123 123 'pyin' : None,
124 124 'pyout' : None,
125 125 'pyerr' : None,
126 126 'stdout' : '',
127 127 'stderr' : '',
128 128 }
129 129 self.update(md)
130 130 self.update(dict(*args, **kwargs))
131 131
132 132 def __getattr__(self, key):
133 133 """getattr aliased to getitem"""
134 134 if key in self.iterkeys():
135 135 return self[key]
136 136 else:
137 137 raise AttributeError(key)
138 138
139 139 def __setattr__(self, key, value):
140 140 """setattr aliased to setitem, with strict"""
141 141 if key in self.iterkeys():
142 142 self[key] = value
143 143 else:
144 144 raise AttributeError(key)
145 145
146 146 def __setitem__(self, key, value):
147 147 """strict static key enforcement"""
148 148 if key in self.iterkeys():
149 149 dict.__setitem__(self, key, value)
150 150 else:
151 151 raise KeyError(key)
152 152
153 153
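The strict-key contract in practice (a sketch):

    from datetime import datetime
    md = Metadata(status='ok')
    md.stdout                          # '' -- attribute access on a known key
    md['received'] = datetime.now()    # known key: allowed
    md['bogus'] = 1                    # unknown key: raises KeyError, not added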
154 154 class Client(HasTraits):
155 155 """A semi-synchronous client to the IPython ZMQ controller
156 156
157 157 Parameters
158 158 ----------
159 159
160 160 url_or_file : bytes; zmq url or path to ipcontroller-client.json
161 161 Connection information for the Hub's registration. If a json connector
162 162 file is given, then likely no further configuration is necessary.
163 163 [Default: use profile]
164 164 profile : bytes
165 165 The name of the Cluster profile to be used to find connector information.
166 166 [Default: 'default']
167 167 context : zmq.Context
168 168 Pass an existing zmq.Context instance, otherwise the client will create its own.
169 169 username : bytes
170 170 set username to be passed to the Session object
171 171 debug : bool
172 172 flag for lots of message printing for debug purposes
173 173
174 174 #-------------- ssh related args ----------------
175 175 # These are args for configuring the ssh tunnel to be used
176 176 # credentials are used to forward connections over ssh to the Controller
177 177 # Note that the ip given in `addr` needs to be relative to sshserver
178 178 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
179 179 # and set sshserver as the same machine the Controller is on. However,
180 180 # the only requirement is that sshserver is able to see the Controller
181 181 # (i.e. is within the same trusted network).
182 182
183 183 sshserver : str
184 184 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
185 185 If keyfile or password is specified, and this is not, it will default to
186 186 the ip given in addr.
187 187 sshkey : str; path to private ssh key file
188 188 This specifies a key to be used in ssh login, default None.
189 189 Regular default ssh keys will be used without specifying this argument.
190 190 password : str
191 191 Your ssh password to sshserver. Note that if this is left None,
192 192 you will be prompted for it if passwordless key based login is unavailable.
193 193 paramiko : bool
194 194 flag for whether to use paramiko instead of shell ssh for tunneling.
195 195 [default: True on win32, False otherwise]
196 196
197 197 #------- exec authentication args -------
198 198 # If even localhost is untrusted, you can have some protection against
199 199 # unauthorized execution by using a key. Messages are still sent
200 200 # as cleartext, so if someone can snoop your loopback traffic this will
201 201 # not help against malicious attacks.
202 202
203 203 exec_key : str
204 204 an authentication key or file containing a key
205 205 default: None
206 206
207 207
208 208 Attributes
209 209 ----------
210 210
211 211 ids : set of int engine IDs
212 212 requesting the ids attribute always synchronizes
213 213 the registration state. To request ids without synchronization,
214 214 use the semi-private _ids attribute.
215 215
216 216 history : list of msg_ids
217 217 a list of msg_ids, keeping track of all the execution
218 218 messages you have submitted in order.
219 219
220 220 outstanding : set of msg_ids
221 221 a set of msg_ids that have been submitted, but whose
222 222 results have not yet been received.
223 223
224 224 results : dict
225 225 a dict of all our results, keyed by msg_id
226 226
227 227 block : bool
228 228 determines default behavior when block not specified
229 229 in execution methods
230 230
231 231 Methods
232 232 -------
233 233
234 234 spin
235 235 flushes incoming results and registration state changes
236 236 control methods spin, and requesting `ids` also ensures the state is up to date
237 237
238 238 barrier
239 239 wait on one or more msg_ids
240 240
241 241 execution methods
242 242 apply
243 243 legacy: execute, run
244 244
245 245 query methods
246 246 queue_status, get_result, purge
247 247
248 248 control methods
249 249 abort, shutdown
250 250
251 251 """
252 252
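Construction per the parameters above (paths and addresses illustrative):

    rc = Client()                                      # default profile
    rc = Client('/path/to/ipcontroller-client.json')   # explicit connector file
    rc = Client('tcp://127.0.0.1:10101', sshserver='user@gateway')  # tunneled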
253 253
254 254 block = Bool(False)
255 255 outstanding = Set()
256 256 results = Instance('collections.defaultdict', (dict,))
257 257 metadata = Instance('collections.defaultdict', (Metadata,))
258 258 history = List()
259 259 debug = Bool(False)
260 260 profile=CUnicode('default')
261 261
262 262 _outstanding_dict = Instance('collections.defaultdict', (set,))
263 263 _ids = List()
264 264 _connected=Bool(False)
265 265 _ssh=Bool(False)
266 266 _context = Instance('zmq.Context')
267 267 _config = Dict()
268 268 _engines=Instance(ReverseDict, (), {})
269 269 _registration_socket=Instance('zmq.Socket')
270 270 _query_socket=Instance('zmq.Socket')
271 271 _control_socket=Instance('zmq.Socket')
272 272 _iopub_socket=Instance('zmq.Socket')
273 273 _notification_socket=Instance('zmq.Socket')
274 274 _mux_socket=Instance('zmq.Socket')
275 275 _task_socket=Instance('zmq.Socket')
276 276 _task_scheme=Str()
277 277 _balanced_views=Dict()
278 278 _direct_views=Dict()
279 279 _closed = False
280 280
281 281 def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
282 282 context=None, username=None, debug=False, exec_key=None,
283 283 sshserver=None, sshkey=None, password=None, paramiko=None,
284 284 ):
285 285 super(Client, self).__init__(debug=debug, profile=profile)
286 286 if context is None:
287 287 context = zmq.Context()
288 288 self._context = context
289 289
290 290
291 291 self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
292 292 if self._cd is not None:
293 293 if url_or_file is None:
294 294 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
295 295 assert url_or_file is not None, "I can't find enough information to connect to a controller!"\
296 296 " Please specify at least one of url_or_file or profile."
297 297
298 298 try:
299 299 validate_url(url_or_file)
300 300 except AssertionError:
301 301 if not os.path.exists(url_or_file):
302 302 if self._cd:
303 303 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
304 304 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
305 305 with open(url_or_file) as f:
306 306 cfg = json.loads(f.read())
307 307 else:
308 308 cfg = {'url':url_or_file}
309 309
310 310 # sync defaults from args, json:
311 311 if sshserver:
312 312 cfg['ssh'] = sshserver
313 313 if exec_key:
314 314 cfg['exec_key'] = exec_key
315 315 exec_key = cfg['exec_key']
316 316 sshserver=cfg['ssh']
317 317 url = cfg['url']
318 318 location = cfg.setdefault('location', None)
319 319 cfg['url'] = disambiguate_url(cfg['url'], location)
320 320 url = cfg['url']
321 321
322 322 self._config = cfg
323 323
324 324 self._ssh = bool(sshserver or sshkey or password)
325 325 if self._ssh and sshserver is None:
326 326 # default to ssh via localhost
327 327 sshserver = url.split('://')[1].split(':')[0]
328 328 if self._ssh and password is None:
329 329 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
330 330 password=False
331 331 else:
332 332 password = getpass("SSH Password for %s: "%sshserver)
333 333 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
334 334 if exec_key is not None and os.path.isfile(exec_key):
335 335 arg = 'keyfile'
336 336 else:
337 337 arg = 'key'
338 338 key_arg = {arg:exec_key}
339 339 if username is None:
340 340 self.session = ss.StreamSession(**key_arg)
341 341 else:
342 342 self.session = ss.StreamSession(username, **key_arg)
343 343 self._registration_socket = self._context.socket(zmq.XREQ)
344 344 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
345 345 if self._ssh:
346 346 tunnel.tunnel_connection(self._registration_socket, url, sshserver, **ssh_kwargs)
347 347 else:
348 348 self._registration_socket.connect(url)
349 349
350 350 self.session.debug = self.debug
351 351
352 352 self._notification_handlers = {'registration_notification' : self._register_engine,
353 353 'unregistration_notification' : self._unregister_engine,
354 354 }
355 355 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
356 356 'apply_reply' : self._handle_apply_reply}
357 357 self._connect(sshserver, ssh_kwargs)
358 358
359 def __del__(self):
360 """cleanup sockets, but _not_ context."""
361 self.close()
359 362
360 363 def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
361 364 if ipython_dir is None:
362 365 ipython_dir = get_ipython_dir()
363 366 if cluster_dir is not None:
364 367 try:
365 368 self._cd = ClusterDir.find_cluster_dir(cluster_dir)
366 369 return
367 370 except ClusterDirError:
368 371 pass
369 372 elif profile is not None:
370 373 try:
371 374 self._cd = ClusterDir.find_cluster_dir_by_profile(
372 375 ipython_dir, profile)
373 376 return
374 377 except ClusterDirError:
375 378 pass
376 379 self._cd = None
377 380
378 381 @property
379 382 def ids(self):
380 383 """Always up-to-date ids property."""
381 384 self._flush_notifications()
382 385 # always copy:
383 386 return list(self._ids)
384 387
385 388 def close(self):
386 389 if self._closed:
387 390 return
388 391 snames = filter(lambda n: n.endswith('socket'), dir(self))
389 392 for socket in map(lambda name: getattr(self, name), snames):
390 socket.close()
393 if isinstance(socket, zmq.Socket) and not socket.closed:
394 socket.close()
391 395 self._closed = True
392 396
393 397 def _update_engines(self, engines):
394 398 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
395 399 for k,v in engines.iteritems():
396 400 eid = int(k)
397 401 self._engines[eid] = bytes(v) # force not unicode
398 402 self._ids.append(eid)
399 403 self._ids = sorted(self._ids)
400 404 if sorted(self._engines.keys()) != range(len(self._engines)) and \
401 405 self._task_scheme == 'pure' and self._task_socket:
402 406 self._stop_scheduling_tasks()
403 407
404 408 def _stop_scheduling_tasks(self):
405 409 """Stop scheduling tasks because an engine has been unregistered
406 410 from a pure ZMQ scheduler.
407 411 """
408 412
409 413 self._task_socket.close()
410 414 self._task_socket = None
411 415 msg = "An engine has been unregistered, and we are using pure " +\
412 416 "ZMQ task scheduling. Task farming will be disabled."
413 417 if self.outstanding:
414 418 msg += " If you were running tasks when this happened, " +\
415 419 "some `outstanding` msg_ids may never resolve."
416 420 warnings.warn(msg, RuntimeWarning)
417 421
418 422 def _build_targets(self, targets):
419 423 """Turn valid target IDs or 'all' into two lists:
420 424 (int_ids, uuids).
421 425 """
422 426 if targets is None:
423 427 targets = self._ids
424 428 elif isinstance(targets, str):
425 429 if targets.lower() == 'all':
426 430 targets = self._ids
427 431 else:
428 432 raise TypeError("%r not valid str target, must be 'all'"%(targets))
429 433 elif isinstance(targets, int):
430 434 targets = [targets]
431 435 return [self._engines[t] for t in targets], list(targets)
432 436
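The normalization this helper performs, illustratively:

    rc._build_targets('all')    # -> (uuids of all engines, [0, 1, ...])
    rc._build_targets(1)        # -> ([uuid_of_1], [1])
    rc._build_targets([0, 2])   # -> ([uuid_of_0, uuid_of_2], [0, 2])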
433 437 def _connect(self, sshserver, ssh_kwargs):
434 438 """setup all our socket connections to the controller. This is called from
435 439 __init__."""
436 440
437 441 # Maybe allow reconnecting?
438 442 if self._connected:
439 443 return
440 444 self._connected=True
441 445
442 446 def connect_socket(s, url):
443 447 url = disambiguate_url(url, self._config['location'])
444 448 if self._ssh:
445 449 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
446 450 else:
447 451 return s.connect(url)
448 452
449 453 self.session.send(self._registration_socket, 'connection_request')
450 454 idents,msg = self.session.recv(self._registration_socket,mode=0)
451 455 if self.debug:
452 456 pprint(msg)
453 457 msg = ss.Message(msg)
454 458 content = msg.content
455 459 self._config['registration'] = dict(content)
456 460 if content.status == 'ok':
457 461 if content.mux:
458 462 self._mux_socket = self._context.socket(zmq.PAIR)
459 463 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
460 464 connect_socket(self._mux_socket, content.mux)
461 465 if content.task:
462 466 self._task_scheme, task_addr = content.task
463 467 self._task_socket = self._context.socket(zmq.PAIR)
464 468 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
465 469 connect_socket(self._task_socket, task_addr)
466 470 if content.notification:
467 471 self._notification_socket = self._context.socket(zmq.SUB)
468 472 connect_socket(self._notification_socket, content.notification)
469 473 self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
470 474 if content.query:
471 475 self._query_socket = self._context.socket(zmq.PAIR)
472 476 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
473 477 connect_socket(self._query_socket, content.query)
474 478 if content.control:
475 479 self._control_socket = self._context.socket(zmq.PAIR)
476 480 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
477 481 connect_socket(self._control_socket, content.control)
478 482 if content.iopub:
479 483 self._iopub_socket = self._context.socket(zmq.SUB)
480 484 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, '')
481 485 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
482 486 connect_socket(self._iopub_socket, content.iopub)
483 487 self._update_engines(dict(content.engines))
484 488
485 489 else:
486 490 self._connected = False
487 491 raise Exception("Failed to connect!")
488 492
489 493 #--------------------------------------------------------------------------
490 494 # handlers and callbacks for incoming messages
491 495 #--------------------------------------------------------------------------
492 496
493 497 def _unwrap_exception(self, content):
494 498 """unwrap exception, and remap engineid to int."""
495 499 e = error.unwrap_exception(content)
496 500 if e.engine_info:
497 501 e_uuid = e.engine_info['engine_uuid']
498 502 eid = self._engines[e_uuid]
499 503 e.engine_info['engine_id'] = eid
500 504 return e
501 505
502 506 def _extract_metadata(self, header, parent, content):
503 507 md = {'msg_id' : parent['msg_id'],
504 508 'received' : datetime.now(),
505 509 'engine_uuid' : header.get('engine', None),
506 510 'follow' : parent.get('follow', []),
507 511 'after' : parent.get('after', []),
508 512 'status' : content['status'],
509 513 }
510 514
511 515 if md['engine_uuid'] is not None:
512 516 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
513 517
514 518 if 'date' in parent:
515 519 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
516 520 if 'started' in header:
517 521 md['started'] = datetime.strptime(header['started'], util.ISO8601)
518 522 if 'date' in header:
519 523 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
520 524 return md
521 525
522 526 def _register_engine(self, msg):
523 527 """Register a new engine, and update our connection info."""
524 528 content = msg['content']
525 529 eid = content['id']
526 530 d = {eid : content['queue']}
527 531 self._update_engines(d)
528 532
529 533 def _unregister_engine(self, msg):
530 534 """Unregister an engine that has died."""
531 535 content = msg['content']
532 536 eid = int(content['id'])
533 537 if eid in self._ids:
534 538 self._ids.remove(eid)
535 539 uuid = self._engines.pop(eid)
536 540
537 541 self._handle_stranded_msgs(eid, uuid)
538 542
539 543 if self._task_socket and self._task_scheme == 'pure':
540 544 self._stop_scheduling_tasks()
541 545
542 546 def _handle_stranded_msgs(self, eid, uuid):
543 547 """Handle messages known to be on an engine when the engine unregisters.
544 548
545 549 It is possible that this will fire prematurely - that is, an engine will
546 550 go down after completing a result, and the client will be notified
547 551 of the unregistration and later receive the successful result.
548 552 """
549 553
550 554 outstanding = self._outstanding_dict[uuid]
551 555
552 556 for msg_id in list(outstanding):
553 print msg_id
554 557 if msg_id in self.results:
555 558 # we already have this result
556 559 continue
557 560 try:
558 561 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
559 562 except:
560 563 content = error.wrap_exception()
561 564 # build a fake message:
562 565 parent = {}
563 566 header = {}
564 567 parent['msg_id'] = msg_id
565 568 header['engine'] = uuid
566 569 header['date'] = datetime.now().strftime(util.ISO8601)
567 570 msg = dict(parent_header=parent, header=header, content=content)
568 571 self._handle_apply_reply(msg)
569 572
570 573 def _handle_execute_reply(self, msg):
571 574 """Save the reply to an execute_request into our results.
572 575
573 576 execute messages are never actually used. apply is used instead.
574 577 """
575 578
576 579 parent = msg['parent_header']
577 580 msg_id = parent['msg_id']
578 581 if msg_id not in self.outstanding:
579 582 if msg_id in self.history:
580 583 print ("got stale result: %s"%msg_id)
581 584 else:
582 585 print ("got unknown result: %s"%msg_id)
583 586 else:
584 587 self.outstanding.remove(msg_id)
585 588 self.results[msg_id] = self._unwrap_exception(msg['content'])
586 589
587 590 def _handle_apply_reply(self, msg):
588 591 """Save the reply to an apply_request into our results."""
589 592 parent = msg['parent_header']
590 593 msg_id = parent['msg_id']
591 594 if msg_id not in self.outstanding:
592 595 if msg_id in self.history:
593 596 print ("got stale result: %s"%msg_id)
594 597 print self.results[msg_id]
595 598 print msg
596 599 else:
597 600 print ("got unknown result: %s"%msg_id)
598 601 else:
599 602 self.outstanding.remove(msg_id)
600 603 content = msg['content']
601 604 header = msg['header']
602 605
603 606 # construct metadata:
604 607 md = self.metadata[msg_id]
605 608 md.update(self._extract_metadata(header, parent, content))
606 609 # is this redundant?
607 610 self.metadata[msg_id] = md
608 611
609 612 e_outstanding = self._outstanding_dict[md['engine_uuid']]
610 613 if msg_id in e_outstanding:
611 614 e_outstanding.remove(msg_id)
612 615
613 616 # construct result:
614 617 if content['status'] == 'ok':
615 618 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
616 619 elif content['status'] == 'aborted':
617 620 self.results[msg_id] = error.AbortedTask(msg_id)
618 621 elif content['status'] == 'resubmitted':
619 622 # TODO: handle resubmission
620 623 pass
621 624 else:
622 625 self.results[msg_id] = self._unwrap_exception(content)
623 626
624 627 def _flush_notifications(self):
625 628 """Flush notifications of engine registrations waiting
626 629 in ZMQ queue."""
627 630 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
628 631 while msg is not None:
629 632 if self.debug:
630 633 pprint(msg)
631 634 msg = msg[-1]
632 635 msg_type = msg['msg_type']
633 636 handler = self._notification_handlers.get(msg_type, None)
634 637 if handler is None:
635 638 raise Exception("Unhandled message type: %s"%msg_type)
636 639 else:
637 640 handler(msg)
638 641 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
639 642
640 643 def _flush_results(self, sock):
641 644 """Flush task or queue results waiting in ZMQ queue."""
642 645 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
643 646 while msg is not None:
644 647 if self.debug:
645 648 pprint(msg)
646 649 msg = msg[-1]
647 650 msg_type = msg['msg_type']
648 651 handler = self._queue_handlers.get(msg_type, None)
649 652 if handler is None:
650 653 raise Exception("Unhandled message type: %s"%msg_type)
651 654 else:
652 655 handler(msg)
653 656 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
654 657
655 658 def _flush_control(self, sock):
656 659 """Flush replies from the control channel waiting
657 660 in the ZMQ queue.
658 661
659 662 Currently: ignore them."""
660 663 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
661 664 while msg is not None:
662 665 if self.debug:
663 666 pprint(msg)
664 667 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
665 668
666 669 def _flush_iopub(self, sock):
667 670 """Flush replies from the iopub channel waiting
668 671 in the ZMQ queue.
669 672 """
670 673 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
671 674 while msg is not None:
672 675 if self.debug:
673 676 pprint(msg)
674 677 msg = msg[-1]
675 678 parent = msg['parent_header']
676 679 msg_id = parent['msg_id']
677 680 content = msg['content']
678 681 header = msg['header']
679 682 msg_type = msg['msg_type']
680 683
681 684 # init metadata:
682 685 md = self.metadata[msg_id]
683 686
684 687 if msg_type == 'stream':
685 688 name = content['name']
686 689 s = md[name] or ''
687 690 md[name] = s + content['data']
688 691 elif msg_type == 'pyerr':
689 692 md.update({'pyerr' : self._unwrap_exception(content)})
690 693 else:
691 694 md.update({msg_type : content['data']})
692 695
693 696 # redundant?
694 697 self.metadata[msg_id] = md
695 698
696 699 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
697 700
698 701 #--------------------------------------------------------------------------
699 702 # len, getitem
700 703 #--------------------------------------------------------------------------
701 704
702 705 def __len__(self):
703 706 """len(client) returns # of engines."""
704 707 return len(self.ids)
705 708
706 709 def __getitem__(self, key):
707 710 """index access returns DirectView multiplexer objects
708 711
709 712 Must be int, slice, or list/tuple/xrange of ints"""
710 713 if not isinstance(key, (int, slice, tuple, list, xrange)):
711 714 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
712 715 else:
713 716 return self.view(key, balanced=False)
714 717
715 718 #--------------------------------------------------------------------------
716 719 # Begin public methods
717 720 #--------------------------------------------------------------------------
718 721
719 722 def spin(self):
720 723 """Flush any registration notifications and execution results
721 724 waiting in the ZMQ queue.
722 725 """
723 726 if self._notification_socket:
724 727 self._flush_notifications()
725 728 if self._mux_socket:
726 729 self._flush_results(self._mux_socket)
727 730 if self._task_socket:
728 731 self._flush_results(self._task_socket)
729 732 if self._control_socket:
730 733 self._flush_control(self._control_socket)
731 734 if self._iopub_socket:
732 735 self._flush_iopub(self._iopub_socket)
733 736
734 737 def barrier(self, jobs=None, timeout=-1):
735 738 """waits on one or more `jobs`, for up to `timeout` seconds.
736 739
737 740 Parameters
738 741 ----------
739 742
740 743 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
741 744 ints are indices to self.history
742 745 strs are msg_ids
743 746 default: wait on all outstanding messages
744 747 timeout : float
745 748 a time in seconds, after which to give up.
746 749 default is -1, which means no timeout
747 750
748 751 Returns
749 752 -------
750 753
751 754 True : when all msg_ids are done
752 755 False : timeout reached, some msg_ids still outstanding
753 756 """
754 757 tic = time.time()
755 758 if jobs is None:
756 759 theids = self.outstanding
757 760 else:
758 761 if isinstance(jobs, (int, str, AsyncResult)):
759 762 jobs = [jobs]
760 763 theids = set()
761 764 for job in jobs:
762 765 if isinstance(job, int):
763 766 # index access
764 767 job = self.history[job]
765 768 elif isinstance(job, AsyncResult):
766 769 map(theids.add, job.msg_ids)
767 770 continue
768 771 theids.add(job)
769 772 if not theids.intersection(self.outstanding):
770 773 return True
771 774 self.spin()
772 775 while theids.intersection(self.outstanding):
773 776 if timeout >= 0 and ( time.time()-tic ) > timeout:
774 777 break
775 778 time.sleep(1e-3)
776 779 self.spin()
777 780 return len(theids.intersection(self.outstanding)) == 0
778 781
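The accepted job forms, sketched (assuming a connected `rc`):

    import time
    ar = rc.barrier                                 # not this; barrier waits:
    ar = rc.apply(time.sleep, (1,), block=False)    # load-balanced sleep
    rc.barrier(ar, timeout=0.1)     # False: still outstanding
    rc.barrier(ar.msg_ids)          # wait by explicit msg_id list
    rc.barrier(-1)                  # same job, by index into rc.history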
779 782 #--------------------------------------------------------------------------
780 783 # Control methods
781 784 #--------------------------------------------------------------------------
782 785
783 786 @spinfirst
784 787 @defaultblock
785 788 def clear(self, targets=None, block=None):
786 789 """Clear the namespace in target(s)."""
787 790 targets = self._build_targets(targets)[0]
788 791 for t in targets:
789 792 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
790 793 error = False
791 794 if self.block:
792 795 for i in range(len(targets)):
793 796 idents,msg = self.session.recv(self._control_socket,0)
794 797 if self.debug:
795 798 pprint(msg)
796 799 if msg['content']['status'] != 'ok':
797 800 error = self._unwrap_exception(msg['content'])
798 801 if error:
799 return error
802 raise error
800 803
801 804
802 805 @spinfirst
803 806 @defaultblock
804 807 def abort(self, jobs=None, targets=None, block=None):
805 808 """Abort specific jobs from the execution queues of target(s).
806 809
807 810 This is a mechanism to prevent jobs that have already been submitted
808 811 from executing.
809 812
810 813 Parameters
811 814 ----------
812 815
813 816 jobs : msg_id, list of msg_ids, or AsyncResult
814 817 The jobs to be aborted
815 818
816 819
817 820 """
818 821 targets = self._build_targets(targets)[0]
819 822 msg_ids = []
820 823 if isinstance(jobs, (basestring,AsyncResult)):
821 824 jobs = [jobs]
822 825 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
823 826 if bad_ids:
824 827 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
825 828 for j in jobs:
826 829 if isinstance(j, AsyncResult):
827 830 msg_ids.extend(j.msg_ids)
828 831 else:
829 832 msg_ids.append(j)
830 833 content = dict(msg_ids=msg_ids)
831 834 for t in targets:
832 835 self.session.send(self._control_socket, 'abort_request',
833 836 content=content, ident=t)
834 837 error = False
835 838 if self.block:
836 839 for i in range(len(targets)):
837 840 idents,msg = self.session.recv(self._control_socket,0)
838 841 if self.debug:
839 842 pprint(msg)
840 843 if msg['content']['status'] != 'ok':
841 844 error = self._unwrap_exception(msg['content'])
842 845 if error:
843 return error
846 raise error
844 847
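Abort in action against a long-running direct job (sketch):

    import time
    ar = rc.apply(time.sleep, (100,), targets=0, balanced=False, block=False)
    rc.abort(ar, targets=0, block=True)   # raises here only if the abort failed
    ar.get()                              # raises: the task was aborted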
845 848 @spinfirst
846 849 @defaultblock
847 850 def shutdown(self, targets=None, restart=False, controller=False, block=None):
848 851 """Terminates one or more engine processes, optionally including the controller."""
849 852 if controller:
850 853 targets = 'all'
851 854 targets = self._build_targets(targets)[0]
852 855 for t in targets:
853 856 self.session.send(self._control_socket, 'shutdown_request',
854 857 content={'restart':restart},ident=t)
855 858 error = False
856 859 if block or controller:
857 860 for i in range(len(targets)):
858 861 idents,msg = self.session.recv(self._control_socket,0)
859 862 if self.debug:
860 863 pprint(msg)
861 864 if msg['content']['status'] != 'ok':
862 865 error = self._unwrap_exception(msg['content'])
863 866
864 867 if controller:
865 868 time.sleep(0.25)
866 869 self.session.send(self._query_socket, 'shutdown_request')
867 870 idents,msg = self.session.recv(self._query_socket, 0)
868 871 if self.debug:
869 872 pprint(msg)
870 873 if msg['content']['status'] != 'ok':
871 874 error = self._unwrap_exception(msg['content'])
872 875
873 876 if error:
874 877 raise error
875 878
876 879 #--------------------------------------------------------------------------
877 880 # Execution methods
878 881 #--------------------------------------------------------------------------
879 882
880 883 @defaultblock
881 884 def execute(self, code, targets='all', block=None):
882 885 """Executes `code` on `targets` in blocking or nonblocking manner.
883 886
884 887 ``execute`` is always `bound` (affects engine namespace)
885 888
886 889 Parameters
887 890 ----------
888 891
889 892 code : str
890 893 the code string to be executed
891 894 targets : int/str/list of ints/strs
892 895 the engines on which to execute
893 896 default : all
894 897 block : bool
895 898 whether or not to wait until done to return
896 899 default: self.block
897 900 """
898 901 result = self.apply(_execute, (code,), targets=targets, block=block, bound=True, balanced=False)
899 902 if not block:
900 903 return result
901 904
902 905 def run(self, filename, targets='all', block=None):
903 906 """Execute contents of `filename` on engine(s).
904 907
905 908 This simply reads the contents of the file and calls `execute`.
906 909
907 910 Parameters
908 911 ----------
909 912
910 913 filename : str
911 914 The path to the file
912 915 targets : int/str/list of ints/strs
913 916 the engines on which to execute
914 917 default : all
915 918 block : bool
916 919 whether or not to wait until done
917 920 default: self.block
918 921
919 922 """
920 923 with open(filename, 'r') as f:
921 924 # add newline in case of trailing indented whitespace
922 925 # which will cause SyntaxError
923 926 code = f.read()+'\n'
924 927 return self.execute(code, targets=targets, block=block)
925 928
926 929 def _maybe_raise(self, result):
927 930 """wrapper for maybe raising an exception if apply failed."""
928 931 if isinstance(result, error.RemoteError):
929 932 raise result
930 933
931 934 return result
932 935
933 936 def _build_dependency(self, dep):
934 937 """helper for building jsonable dependencies from various input forms"""
935 938 if isinstance(dep, Dependency):
936 939 return dep.as_dict()
937 940 elif isinstance(dep, AsyncResult):
938 941 return dep.msg_ids
939 942 elif dep is None:
940 943 return []
941 944 else:
942 945 # pass to Dependency constructor
943 946 return list(Dependency(dep))
944 947
945 948 @defaultblock
946 949 def apply(self, f, args=None, kwargs=None, bound=True, block=None,
947 950 targets=None, balanced=None,
948 after=None, follow=None, timeout=None):
951 after=None, follow=None, timeout=None,
952 track=False):
949 953 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
950 954
951 955 This is the central execution command for the client.
952 956
953 957 Parameters
954 958 ----------
955 959
956 960 f : function
957 961 The function to be called remotely
958 962 args : tuple/list
959 963 The positional arguments passed to `f`
960 964 kwargs : dict
961 965 The keyword arguments passed to `f`
962 966 bound : bool (default: True)
963 967 Whether to execute in the Engine(s) namespace, or in a clean
964 968 namespace not affecting the engine.
965 969 block : bool (default: self.block)
966 970 Whether to wait for the result, or return immediately.
967 971 False:
968 972 returns AsyncResult
969 973 True:
970 974 returns actual result(s) of f(*args, **kwargs)
971 975 if multiple targets:
972 976 list of results, matching `targets`
973 977 targets : int,list of ints, 'all', None
974 978 Specify the destination of the job.
975 979 if None:
976 980 Submit via Task queue for load-balancing.
977 981 if 'all':
978 982 Run on all active engines
979 983 if list:
980 984 Run on each specified engine
981 985 if int:
982 986 Run on single engine
983 987
984 988 balanced : bool, default None
985 989 whether to load-balance. This will default to True
986 990 if targets is unspecified, or False if targets is specified.
987 991
988 992 The following arguments are only used when balanced is True:
989 993 after : Dependency or collection of msg_ids
990 994 Only for load-balanced execution (targets=None)
991 995 Specify a list of msg_ids as a time-based dependency.
992 996 This job will only be run *after* the dependencies
993 997 have been met.
994 998
995 999 follow : Dependency or collection of msg_ids
996 1000 Only for load-balanced execution (targets=None)
997 1001 Specify a list of msg_ids as a location-based dependency.
998 1002 This job will only be run on an engine where this dependency
999 1003 is met.
1000 1004
1001 1005 timeout : float/int or None
1002 1006 Only for load-balanced execution (targets=None)
1003 1007 Specify an amount of time (in seconds) for the scheduler to
1004 1008 wait for dependencies to be met before failing with a
1005 1009 DependencyTimeout.
1010 track : bool
1011 whether to track non-copying sends.
1012 [default False]
1006 1013
1007 1014 after, follow, and timeout are only used if `balanced=True`.
1008 1015
1009 1016 Returns
1010 1017 -------
1011 1018
1012 1019 if block is False:
1013 1020 return AsyncResult wrapping msg_ids
1014 1021 output of AsyncResult.get() is identical to that of `apply(...block=True)`
1015 1022 else:
1016 1023 if single target:
1017 1024 return result of `f(*args, **kwargs)`
1018 1025 else:
1019 1026 return list of results, matching `targets`
1020 1027 """
1021 1028 assert not self._closed, "cannot use me anymore, I'm closed!"
1022 1029 # defaults:
1023 1030 block = block if block is not None else self.block
1024 1031 args = args if args is not None else []
1025 1032 kwargs = kwargs if kwargs is not None else {}
1026 1033
1027 1034 if balanced is None:
1028 1035 if targets is None:
1029 1036 # default to balanced if targets unspecified
1030 1037 balanced = True
1031 1038 else:
1032 1039 # otherwise default to multiplexing
1033 1040 balanced = False
1034 1041
1035 1042 if targets is None and balanced is False:
1036 1043 # default to all if *not* balanced, and targets is unspecified
1037 1044 targets = 'all'
1038 1045
1039 1046 # enforce types of f, args, kwargs
1040 1047 if not callable(f):
1041 1048 raise TypeError("f must be callable, not %s"%type(f))
1042 1049 if not isinstance(args, (tuple, list)):
1043 1050 raise TypeError("args must be tuple or list, not %s"%type(args))
1044 1051 if not isinstance(kwargs, dict):
1045 1052 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1046 1053
1047 options = dict(bound=bound, block=block, targets=targets)
1054 options = dict(bound=bound, block=block, targets=targets, track=track)
1048 1055
1049 1056 if balanced:
1050 1057 return self._apply_balanced(f, args, kwargs, timeout=timeout,
1051 1058 after=after, follow=follow, **options)
1052 1059 elif follow or after or timeout:
1053 1060 msg = "follow, after, and timeout args are only used for"
1054 1061 msg += " load-balanced execution."
1055 1062 raise ValueError(msg)
1056 1063 else:
1057 1064 return self._apply_direct(f, args, kwargs, **options)
1058 1065
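Pulling the threads together: `track` rides in `options` and, via the two private paths below, surfaces as the MessageTracker on the returned AsyncResult. Sketch:

    # balanced (task) path: one message, one tracker
    ar1 = rc.apply(sum, ([1, 2, 3],), block=False, track=True)
    # direct (MUX) path: one message per engine, trackers combined
    ar2 = rc.apply(sum, ([1, 2, 3],), targets='all', balanced=False,
                   block=False, track=True)
    ar1.sent, ar2.sent    # True once ZMQ has handed off the buffers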
1059 1066 def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
1060 after=None, follow=None, timeout=None):
1067 after=None, follow=None, timeout=None, track=None):
1061 1068 """call f(*args, **kwargs) remotely in a load-balanced manner.
1062 1069
1063 1070 This is a private method, see `apply` for details.
1064 1071 Not to be called directly!
1065 1072 """
1066 1073
1067 1074 loc = locals()
1068 for name in ('bound', 'block'):
1075 for name in ('bound', 'block', 'track'):
1069 1076 assert loc[name] is not None, "kwarg %r must be specified!"%name
1070 1077
1071 1078 if self._task_socket is None:
1072 1079 msg = "Task farming is disabled"
1073 1080 if self._task_scheme == 'pure':
1074 1081 msg += " because the pure ZMQ scheduler cannot handle"
1075 1082 msg += " disappearing engines."
1076 1083 raise RuntimeError(msg)
1077 1084
1078 1085 if self._task_scheme == 'pure':
1079 1086 # pure zmq scheme doesn't support dependencies
1080 1087 msg = "Pure ZMQ scheduler doesn't support dependencies"
1081 1088 if (follow or after):
1082 1089 # hard fail on DAG dependencies
1083 1090 raise RuntimeError(msg)
1084 1091 if isinstance(f, dependent):
1085 1092 # soft warn on functional dependencies
1086 1093 warnings.warn(msg, RuntimeWarning)
1087 1094
1088 1095 # defaults:
1089 1096 args = args if args is not None else []
1090 1097 kwargs = kwargs if kwargs is not None else {}
1091 1098
1092 1099 if targets:
1093 1100 idents,_ = self._build_targets(targets)
1094 1101 else:
1095 1102 idents = []
1096 1103
1097 1104 after = self._build_dependency(after)
1098 1105 follow = self._build_dependency(follow)
1099 1106 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
1100 1107 bufs = util.pack_apply_message(f,args,kwargs)
1101 1108 content = dict(bound=bound)
1102 1109
1103 1110 msg = self.session.send(self._task_socket, "apply_request",
1104 content=content, buffers=bufs, subheader=subheader)
1111 content=content, buffers=bufs, subheader=subheader, track=track)
1105 1112 msg_id = msg['msg_id']
1106 1113 self.outstanding.add(msg_id)
1107 1114 self.history.append(msg_id)
1108 1115 self.metadata[msg_id]['submitted'] = datetime.now()
1109
1110 ar = AsyncResult(self, [msg_id], fname=f.__name__)
1116 tracker = None if track is False else msg['tracker']
1117 ar = AsyncResult(self, [msg_id], fname=f.__name__, targets=targets, tracker=tracker)
1111 1118 if block:
1112 1119 try:
1113 1120 return ar.get()
1114 1121 except KeyboardInterrupt:
1115 1122 return ar
1116 1123 else:
1117 1124 return ar
1118 1125
1119 def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None):
1126 def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None,
1127 track=None):
1120 1128 """The underlying method for applying functions to specific engines
1121 1129 via the MUX queue.
1122 1130
1123 1131 This is a private method, see `apply` for details.
1124 1132 Not to be called directly!
1125 1133 """
1126 1134 loc = locals()
1127 for name in ('bound', 'block', 'targets'):
1135 for name in ('bound', 'block', 'targets', 'track'):
1128 1136 assert loc[name] is not None, "kwarg %r must be specified!"%name
1129 1137
1130 1138 idents,targets = self._build_targets(targets)
1131 1139
1132 1140 subheader = {}
1133 1141 content = dict(bound=bound)
1134 1142 bufs = util.pack_apply_message(f,args,kwargs)
1135 1143
1136 1144 msg_ids = []
1145 trackers = []
1137 1146 for ident in idents:
1138 1147 msg = self.session.send(self._mux_socket, "apply_request",
1139 content=content, buffers=bufs, ident=ident, subheader=subheader)
1148 content=content, buffers=bufs, ident=ident, subheader=subheader,
1149 track=track)
1150 if track:
1151 trackers.append(msg['tracker'])
1140 1152 msg_id = msg['msg_id']
1141 1153 self.outstanding.add(msg_id)
1142 1154 self._outstanding_dict[ident].add(msg_id)
1143 1155 self.history.append(msg_id)
1144 1156 msg_ids.append(msg_id)
1145 ar = AsyncResult(self, msg_ids, fname=f.__name__)
1157
1158 tracker = None if track is False else zmq.MessageTracker(*trackers)
1159 ar = AsyncResult(self, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
1160
1146 1161 if block:
1147 1162 try:
1148 1163 return ar.get()
1149 1164 except KeyboardInterrupt:
1150 1165 return ar
1151 1166 else:
1152 1167 return ar
1153 1168
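On the direct path one message goes to each engine, so the per-send trackers are folded into a single zmq.MessageTracker; the combined tracker is done only when every constituent send is done. Sketch of the invariant (t0..t2 being the trackers returned by each send):

    combined = zmq.MessageTracker(t0, t1, t2)
    assert combined.done == (t0.done and t1.done and t2.done)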
1154 1169 #--------------------------------------------------------------------------
1155 1170 # construct a View object
1156 1171 #--------------------------------------------------------------------------
1157 1172
1158 1173 @defaultblock
1159 1174 def remote(self, bound=True, block=None, targets=None, balanced=None):
1160 1175 """Decorator for making a RemoteFunction"""
1161 1176 return remote(self, bound=bound, targets=targets, block=block, balanced=balanced)
1162 1177
1163 1178 @defaultblock
1164 1179 def parallel(self, dist='b', bound=True, block=None, targets=None, balanced=None):
1165 1180 """Decorator for making a ParallelFunction"""
1166 1181 return parallel(self, bound=bound, targets=targets, block=block, balanced=balanced)
1167 1182
1168 1183 def _cache_view(self, targets, balanced):
1169 1184 """save views, so subsequent requests don't create new objects."""
1170 1185 if balanced:
1171 1186 view_class = LoadBalancedView
1172 1187 view_cache = self._balanced_views
1173 1188 else:
1174 1189 view_class = DirectView
1175 1190 view_cache = self._direct_views
1176 1191
1177 1192 # use str, since often targets will be a list
1178 1193 key = str(targets)
1179 1194 if key not in view_cache:
1180 1195 view_cache[key] = view_class(client=self, targets=targets)
1181 1196
1182 1197 return view_cache[key]
1183 1198
1184 1199 def view(self, targets=None, balanced=None):
1185 1200 """Method for constructing View objects.
1186 1201
1187 1202 If no arguments are specified, create a LoadBalancedView
1188 1203 using all engines. If only `targets` specified, it will
1189 1204 be a DirectView. This method is the underlying implementation
1190 1205 of ``client.__getitem__``.
1191 1206
1192 1207 Parameters
1193 1208 ----------
1194 1209
1195 1210 targets: list,slice,int,etc. [default: use all engines]
1196 1211 The engines to use for the View
1197 1212 balanced : bool [default: False if targets specified, True otherwise]
1198 1213 whether to build a LoadBalancedView or a DirectView
1199 1214
1200 1215 """
1201 1216
1202 1217 balanced = (targets is None) if balanced is None else balanced
1203 1218
1204 1219 if targets is None:
1205 1220 if balanced:
1206 1221 return self._cache_view(None,True)
1207 1222 else:
1208 1223 targets = slice(None)
1209 1224
1210 1225 if isinstance(targets, int):
1211 1226 if targets < 0:
1212 1227 targets = self.ids[targets]
1213 1228 if targets not in self.ids:
1214 1229 raise IndexError("No such engine: %i"%targets)
1215 1230 return self._cache_view(targets, balanced)
1216 1231
1217 1232 if isinstance(targets, slice):
1218 1233 indices = range(len(self.ids))[targets]
1219 1234 ids = sorted(self._ids)
1220 1235 targets = [ ids[i] for i in indices ]
1221 1236
1222 1237 if isinstance(targets, (tuple, list, xrange)):
1223 1238 _,targets = self._build_targets(list(targets))
1224 1239 return self._cache_view(targets, balanced)
1225 1240 else:
1226 1241 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
1227 1242
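A short sketch of the lookup rules implemented above, again assuming a connected Client rc:

    lview = rc.view()                    # no targets: LoadBalancedView on all engines
    dview = rc.view(targets=slice(None)) # explicit targets: DirectView on all engines
    v0 = rc.view(targets=0)              # single-engine DirectView
    assert rc.view(targets=0) is v0      # views are cached per (targets, balanced)
    assert rc[0] is v0                   # __getitem__ is implemented via view()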
1228 1243 #--------------------------------------------------------------------------
1229 1244 # Data movement
1230 1245 #--------------------------------------------------------------------------
1231 1246
1232 1247 @defaultblock
1233 def push(self, ns, targets='all', block=None):
1248 def push(self, ns, targets='all', block=None, track=False):
1234 1249 """Push the contents of `ns` into the namespace on `target`"""
1235 1250 if not isinstance(ns, dict):
1236 1251 raise TypeError("Must be a dict, not %s"%type(ns))
1237 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False)
1252 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False, track=track)
1238 1253 if not block:
1239 1254 return result
1240 1255
1241 1256 @defaultblock
1242 1257 def pull(self, keys, targets='all', block=None):
1243 1258 """Pull objects from `target`'s namespace by `keys`"""
1244 1259 if isinstance(keys, str):
1245 1260 pass
1246 1261 elif isinstance(keys, (list,tuple,set)):
1247 1262 for key in keys:
1248 1263 if not isinstance(key, str):
1249 1264 raise TypeError
1250 1265 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True, balanced=False)
1251 1266 return result
1252 1267
1253 1268 @defaultblock
1254 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
1269 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None, track=False):
1255 1270 """
1256 1271 Partition a Python sequence and send the partitions to a set of engines.
1257 1272 """
1258 1273 targets = self._build_targets(targets)[-1]
1259 1274 mapObject = Map.dists[dist]()
1260 1275 nparts = len(targets)
1261 1276 msg_ids = []
1277 trackers = []
1262 1278 for index, engineid in enumerate(targets):
1263 1279 partition = mapObject.getPartition(seq, index, nparts)
1264 1280 if flatten and len(partition) == 1:
1265 r = self.push({key: partition[0]}, targets=engineid, block=False)
1281 r = self.push({key: partition[0]}, targets=engineid, block=False, track=track)
1266 1282 else:
1267 r = self.push({key: partition}, targets=engineid, block=False)
1283 r = self.push({key: partition}, targets=engineid, block=False, track=track)
1268 1284 msg_ids.extend(r.msg_ids)
1269 r = AsyncResult(self, msg_ids, fname='scatter')
1285 if track:
1286 trackers.append(r._tracker)
1287
1288 if track:
1289 tracker = zmq.MessageTracker(*trackers)
1290 else:
1291 tracker = None
1292
1293 r = AsyncResult(self, msg_ids, fname='scatter', targets=targets, tracker=tracker)
1270 1294 if block:
1271 r.get()
1295 r.wait()
1272 1296 else:
1273 1297 return r
1274 1298
1275 1299 @defaultblock
1276 1300 def gather(self, key, dist='b', targets='all', block=None):
1277 1301 """
1278 1302 Gather a partitioned sequence on a set of engines as a single local seq.
1279 1303 """
1280 1304
1281 1305 targets = self._build_targets(targets)[-1]
1282 1306 mapObject = Map.dists[dist]()
1283 1307 msg_ids = []
1284 1308 for index, engineid in enumerate(targets):
1285 1309 msg_ids.extend(self.pull(key, targets=engineid,block=False).msg_ids)
1286 1310
1287 1311 r = AsyncMapResult(self, msg_ids, mapObject, fname='gather')
1288 1312 if block:
1289 1313 return r.get()
1290 1314 else:
1291 1315 return r
1292 1316
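A round-trip sketch for the two data-movement methods above (same assumed rc):

    rc.scatter('a', range(16), block=True)   # partition range(16) across all engines
    parts = rc.gather('a', block=True)       # reassembled: parts == range(16)

    # with track=True, the AsyncResult carries a MessageTracker for the pushes
    ar = rc.scatter('a', range(16), block=False, track=True)
    ar._tracker.wait()                       # buffers handed off to zmq
    ar.get()                                 # wait for the engines to finish storing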
1293 1317 #--------------------------------------------------------------------------
1294 1318 # Query methods
1295 1319 #--------------------------------------------------------------------------
1296 1320
1297 1321 @spinfirst
1298 1322 @defaultblock
1299 1323 def get_result(self, indices_or_msg_ids=None, block=None):
1300 1324 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1301 1325
1302 1326 If the client already has the results, no request to the Hub will be made.
1303 1327
1304 1328 This is a convenient way to construct AsyncResult objects, which are wrappers
1305 1329 that include metadata about execution, and allow for awaiting results that
1306 1330 were not submitted by this Client.
1307 1331
1308 1332 It can also be a convenient way to retrieve the metadata associated with
1310 1334 blocking execution, since it always retrieves the metadata as well as the result.
1310 1334
1311 1335 Examples
1312 1336 --------
1313 1337 ::
1314 1338
1315 1339 In [10]: r = client.apply()
1316 1340
1317 1341 Parameters
1318 1342 ----------
1319 1343
1320 1344 indices_or_msg_ids : integer history index, str msg_id, or list of either
1320 1344 The history indices or msg_ids of the results to be retrieved
1322 1346
1323 1347 block : bool
1324 1348 Whether to wait for the result to be done
1325 1349
1326 1350 Returns
1327 1351 -------
1328 1352
1329 1353 AsyncResult
1330 1354 A single AsyncResult object will always be returned.
1331 1355
1332 1356 AsyncHubResult
1333 1357 A subclass of AsyncResult that retrieves results from the Hub
1334 1358
1335 1359 """
1336 1360 if indices_or_msg_ids is None:
1337 1361 indices_or_msg_ids = -1
1338 1362
1339 1363 if not isinstance(indices_or_msg_ids, (list,tuple)):
1340 1364 indices_or_msg_ids = [indices_or_msg_ids]
1341 1365
1342 1366 theids = []
1343 1367 for id in indices_or_msg_ids:
1344 1368 if isinstance(id, int):
1345 1369 id = self.history[id]
1346 1370 if not isinstance(id, str):
1347 1371 raise TypeError("indices must be str or int, not %r"%id)
1348 1372 theids.append(id)
1349 1373
1350 1374 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1351 1375 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1352 1376
1353 1377 if remote_ids:
1354 1378 ar = AsyncHubResult(self, msg_ids=theids)
1355 1379 else:
1356 1380 ar = AsyncResult(self, msg_ids=theids)
1357 1381
1358 1382 if block:
1359 1383 ar.wait()
1360 1384
1361 1385 return ar
1362 1386
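A sketch of the two lookup styles, assuming rc as in the earlier sketches:

    ar = rc.apply(lambda : 5, targets=0, block=False)
    ar.get()                          # run to completion
    ar2 = rc.get_result(ar.msg_ids)   # wrap the same result, no resubmission
    ar3 = rc.get_result(-1)           # or by index into rc.history
    print ar2.get()                   # 5, served from the local cache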
1363 1387 @spinfirst
1364 1388 def result_status(self, msg_ids, status_only=True):
1365 1389 """Check on the status of the result(s) of the apply request with `msg_ids`.
1366 1390
1367 1391 If status_only is False, then the actual results will be retrieved, else
1368 1392 only the status of the results will be checked.
1369 1393
1370 1394 Parameters
1371 1395 ----------
1372 1396
1373 1397 msg_ids : list of msg_ids
1374 1398 if int:
1375 1399 Passed as index to self.history for convenience.
1376 1400 status_only : bool (default: True)
1377 1401 if False:
1378 1402 Retrieve the actual results of completed tasks.
1379 1403
1380 1404 Returns
1381 1405 -------
1382 1406
1383 1407 results : dict
1384 1408 There will always be the keys 'pending' and 'completed', which will
1385 1409 be lists of msg_ids that are incomplete or complete. If `status_only`
1386 1410 is False, then completed results will be keyed by their `msg_id`.
1387 1411 """
1388 1412 if not isinstance(msg_ids, (list,tuple)):
1389 1413 msg_ids = [msg_ids]
1390 1414
1391 1415 theids = []
1392 1416 for msg_id in msg_ids:
1393 1417 if isinstance(msg_id, int):
1394 1418 msg_id = self.history[msg_id]
1395 1419 if not isinstance(msg_id, basestring):
1396 1420 raise TypeError("msg_ids must be str, not %r"%msg_id)
1397 1421 theids.append(msg_id)
1398 1422
1399 1423 completed = []
1400 1424 local_results = {}
1401 1425
1402 1426 # comment this block out to temporarily disable local shortcut:
1403 1427 for msg_id in theids:
1404 1428 if msg_id in self.results:
1405 1429 completed.append(msg_id)
1406 1430 local_results[msg_id] = self.results[msg_id]
1407 1431 theids.remove(msg_id)
1408 1432
1409 1433 if theids: # some not locally cached
1410 1434 content = dict(msg_ids=theids, status_only=status_only)
1411 1435 msg = self.session.send(self._query_socket, "result_request", content=content)
1412 1436 zmq.select([self._query_socket], [], [])
1413 1437 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1414 1438 if self.debug:
1415 1439 pprint(msg)
1416 1440 content = msg['content']
1417 1441 if content['status'] != 'ok':
1418 1442 raise self._unwrap_exception(content)
1419 1443 buffers = msg['buffers']
1420 1444 else:
1421 1445 content = dict(completed=[],pending=[])
1422 1446
1423 1447 content['completed'].extend(completed)
1424 1448
1425 1449 if status_only:
1426 1450 return content
1427 1451
1428 1452 failures = []
1429 1453 # load cached results into result:
1430 1454 content.update(local_results)
1431 1455 # update cache with results:
1432 1456 for msg_id in sorted(theids):
1433 1457 if msg_id in content['completed']:
1434 1458 rec = content[msg_id]
1435 1459 parent = rec['header']
1436 1460 header = rec['result_header']
1437 1461 rcontent = rec['result_content']
1438 1462 iodict = rec['io']
1439 1463 if isinstance(rcontent, str):
1440 1464 rcontent = self.session.unpack(rcontent)
1441 1465
1442 1466 md = self.metadata[msg_id]
1443 1467 md.update(self._extract_metadata(header, parent, rcontent))
1444 1468 md.update(iodict)
1445 1469
1446 1470 if rcontent['status'] == 'ok':
1447 1471 res,buffers = util.unserialize_object(buffers)
1448 1472 else:
1449 1473 print rcontent
1450 1474 res = self._unwrap_exception(rcontent)
1451 1475 failures.append(res)
1452 1476
1453 1477 self.results[msg_id] = res
1454 1478 content[msg_id] = res
1455 1479
1456 1480 if len(theids) == 1 and failures:
1457 1481 raise failures[0]
1458 1482
1459 1483 error.collect_exceptions(failures, "result_status")
1460 1484 return content
1461 1485
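For example, checking on an in-flight request (same assumed rc):

    ar = rc.apply(lambda : 5, targets=0, block=False)
    status = rc.result_status(ar.msg_ids)
    # e.g. {'pending': [msg_id], 'completed': []}
    full = rc.result_status(ar.msg_ids, status_only=False)
    # once done, full[msg_id] holds the actual result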
1462 1486 @spinfirst
1463 1487 def queue_status(self, targets='all', verbose=False):
1464 1488 """Fetch the status of engine queues.
1465 1489
1466 1490 Parameters
1467 1491 ----------
1468 1492
1469 1493 targets : int/str/list of ints/strs
1470 1494 the engines whose states are to be queried.
1471 1495 default : all
1472 1496 verbose : bool
1473 1497 Whether to return lengths only, or lists of ids for each element
1474 1498 """
1475 1499 targets = self._build_targets(targets)[1]
1476 1500 content = dict(targets=targets, verbose=verbose)
1477 1501 self.session.send(self._query_socket, "queue_request", content=content)
1478 1502 idents,msg = self.session.recv(self._query_socket, 0)
1479 1503 if self.debug:
1480 1504 pprint(msg)
1481 1505 content = msg['content']
1482 1506 status = content.pop('status')
1483 1507 if status != 'ok':
1484 1508 raise self._unwrap_exception(content)
1485 1509 return util.rekey(content)
1486 1510
1487 1511 @spinfirst
1488 1512 def purge_results(self, jobs=[], targets=[]):
1489 1513 """Tell the controller to forget results.
1490 1514
1491 1515 Individual results can be purged by msg_id, or the entire
1492 1516 history of specific targets can be purged.
1493 1517
1494 1518 Parameters
1495 1519 ----------
1496 1520
1497 1521 jobs : str or list of strs or AsyncResult objects
1498 1522 the msg_ids whose results should be forgotten.
1499 1523 targets : int/str/list of ints/strs
1500 1524 The targets, by uuid or int_id, whose entire history is to be purged.
1501 1525 Use `targets='all'` to scrub everything from the controller's memory.
1502 1526
1503 1527 default : None
1504 1528 """
1505 1529 if not targets and not jobs:
1506 1530 raise ValueError("Must specify at least one of `targets` and `jobs`")
1507 1531 if targets:
1508 1532 targets = self._build_targets(targets)[1]
1509 1533
1510 1534 # construct msg_ids from jobs
1511 1535 msg_ids = []
1512 1536 if isinstance(jobs, (basestring,AsyncResult)):
1513 1537 jobs = [jobs]
1514 1538 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1515 1539 if bad_ids:
1516 1540 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1517 1541 for j in jobs:
1518 1542 if isinstance(j, AsyncResult):
1519 1543 msg_ids.extend(j.msg_ids)
1520 1544 else:
1521 1545 msg_ids.append(j)
1522 1546
1523 1547 content = dict(targets=targets, msg_ids=msg_ids)
1524 1548 self.session.send(self._query_socket, "purge_request", content=content)
1525 1549 idents, msg = self.session.recv(self._query_socket, 0)
1526 1550 if self.debug:
1527 1551 pprint(msg)
1528 1552 content = msg['content']
1529 1553 if content['status'] != 'ok':
1530 1554 raise self._unwrap_exception(content)
1531 1555
1532 1556
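Typical calls, assuming rc and a finished AsyncResult ar:

    rc.purge_results(jobs=ar)                 # forget one job's results on the Hub
    rc.purge_results(targets=[0, 1])          # or whole engine histories
    rc.purge_results(jobs=[], targets='all')  # scrub everything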
1533 1557 __all__ = [ 'Client',
1534 1558 'depend',
1535 1559 'require',
1536 1560 'remote',
1537 1561 'parallel',
1538 1562 'RemoteFunction',
1539 1563 'ParallelFunction',
1540 1564 'DirectView',
1541 1565 'LoadBalancedView',
1542 1566 'AsyncResult',
1543 1567 'AsyncMapResult',
1544 1568 'Reference'
1545 1569 ]
@@ -1,377 +1,412 b''
1 1 #!/usr/bin/env python
2 2 """edited session.py to work with streams, and move msg_type to the header
3 3 """
4 4
5 5
6 6 import os
7 7 import pprint
8 8 import uuid
9 9 from datetime import datetime
10 10
11 11 try:
12 12 import cPickle
13 13 pickle = cPickle
14 14 except:
15 15 cPickle = None
16 16 import pickle
17 17
18 18 import zmq
19 19 from zmq.utils import jsonapi
20 20 from zmq.eventloop.zmqstream import ZMQStream
21 21
22 22 from .util import ISO8601
23 23
24 24 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
25 25 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
26 26 if json_name in ('jsonlib', 'jsonlib2'):
27 27 use_json = True
28 28 elif json_name:
29 29 if cPickle is None:
30 30 use_json = True
31 31 else:
32 32 use_json = False
33 33 else:
34 34 use_json = False
35 35
36 36 def squash_unicode(obj):
37 37 if isinstance(obj,dict):
38 38 for key in obj.keys():
39 39 obj[key] = squash_unicode(obj[key])
40 40 if isinstance(key, unicode):
41 41 obj[squash_unicode(key)] = obj.pop(key)
42 42 elif isinstance(obj, list):
43 43 for i,v in enumerate(obj):
44 44 obj[i] = squash_unicode(v)
45 45 elif isinstance(obj, unicode):
46 46 obj = obj.encode('utf8')
47 47 return obj
48 48
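A quick illustration of squash_unicode, which normalizes json-unpacked messages to the str keys and values the rest of the code expects:

    d = squash_unicode({u'a': [u'b', {u'c': u'd'}]})
    assert type(d.keys()[0]) is str     # keys re-inserted as utf8 bytes
    assert type(d['a'][0]) is str       # values encoded in place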
49 49 json_packer = jsonapi.dumps
50 50 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
51 51
52 52 pickle_packer = lambda o: pickle.dumps(o,-1)
53 53 pickle_unpacker = pickle.loads
54 54
55 55 if use_json:
56 56 default_packer = json_packer
57 57 default_unpacker = json_unpacker
58 58 else:
59 59 default_packer = pickle_packer
60 60 default_unpacker = pickle_unpacker
61 61
62 62
63 63 DELIM="<IDS|MSG>"
64 64
65 65 class Message(object):
66 66 """A simple message object that maps dict keys to attributes.
67 67
68 68 A Message can be created from a dict and a dict from a Message instance
69 69 simply by calling dict(msg_obj)."""
70 70
71 71 def __init__(self, msg_dict):
72 72 dct = self.__dict__
73 73 for k, v in dict(msg_dict).iteritems():
74 74 if isinstance(v, dict):
75 75 v = Message(v)
76 76 dct[k] = v
77 77
78 78 # Having this iterator lets dict(msg_obj) work out of the box.
79 79 def __iter__(self):
80 80 return iter(self.__dict__.iteritems())
81 81
82 82 def __repr__(self):
83 83 return repr(self.__dict__)
84 84
85 85 def __str__(self):
86 86 return pprint.pformat(self.__dict__)
87 87
88 88 def __contains__(self, k):
89 89 return k in self.__dict__
90 90
91 91 def __getitem__(self, k):
92 92 return self.__dict__[k]
93 93
94 94
95 95 def msg_header(msg_id, msg_type, username, session):
96 96 date=datetime.now().strftime(ISO8601)
97 97 return locals()
98 98
99 99 def extract_header(msg_or_header):
100 100 """Given a message or header, return the header."""
101 101 if not msg_or_header:
102 102 return {}
103 103 try:
104 104 # See if msg_or_header is the entire message.
105 105 h = msg_or_header['header']
106 106 except KeyError:
107 107 try:
108 108 # See if msg_or_header is just the header
109 109 h = msg_or_header['msg_id']
110 110 except KeyError:
111 111 raise
112 112 else:
113 113 h = msg_or_header
114 114 if not isinstance(h, dict):
115 115 h = dict(h)
116 116 return h
117 117
118 118 class StreamSession(object):
119 119 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
120 120 debug=False
121 121 key=None
122 122
123 123 def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
124 124 if username is None:
125 125 username = os.environ.get('USER','username')
126 126 self.username = username
127 127 if session is None:
128 128 self.session = str(uuid.uuid4())
129 129 else:
130 130 self.session = session
131 131 self.msg_id = str(uuid.uuid4())
132 132 if packer is None:
133 133 self.pack = default_packer
134 134 else:
135 135 if not callable(packer):
136 136 raise TypeError("packer must be callable, not %s"%type(packer))
137 137 self.pack = packer
138 138
139 139 if unpacker is None:
140 140 self.unpack = default_unpacker
141 141 else:
142 142 if not callable(unpacker):
143 143 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
144 144 self.unpack = unpacker
145 145
146 146 if key is not None and keyfile is not None:
147 147 raise TypeError("Must specify key OR keyfile, not both")
148 148 if keyfile is not None:
149 149 with open(keyfile) as f:
150 150 self.key = f.read().strip()
151 151 else:
152 152 self.key = key
153 153 if isinstance(self.key, unicode):
154 154 self.key = self.key.encode('utf8')
155 155 # print key, keyfile, self.key
156 156 self.none = self.pack({})
157 157
158 158 def msg_header(self, msg_type):
159 159 h = msg_header(self.msg_id, msg_type, self.username, self.session)
160 160 self.msg_id = str(uuid.uuid4())
161 161 return h
162 162
163 163 def msg(self, msg_type, content=None, parent=None, subheader=None):
164 164 msg = {}
165 165 msg['header'] = self.msg_header(msg_type)
166 166 msg['msg_id'] = msg['header']['msg_id']
167 167 msg['parent_header'] = {} if parent is None else extract_header(parent)
168 168 msg['msg_type'] = msg_type
169 169 msg['content'] = {} if content is None else content
170 170 sub = {} if subheader is None else subheader
171 171 msg['header'].update(sub)
172 172 return msg
173 173
174 174 def check_key(self, msg_or_header):
175 175 """Check that a message's header has the right key"""
176 176 if self.key is None:
177 177 return True
178 178 header = extract_header(msg_or_header)
179 179 return header.get('key', None) == self.key
180 180
181 181
182 def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
182 def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None, track=False):
183 183 """Build and send a message via stream or socket.
184 184
185 185 Parameters
186 186 ----------
187 187
188 188 stream : zmq.Socket or ZMQStream
189 189 the socket-like object used to send the data
190 190 msg_or_type : str or Message/dict
191 191 normally a msg_type string; pass a full Message/dict only when
192 192 resending an existing message.
193 193
194 content : dict or None
195 the content of the message (ignored if msg_or_type is a message)
196 buffers : list or None
197 the already-serialized buffers to be appended to the message
198 parent : Message or dict or None
199 the parent or parent header describing the parent of this message
200 subheader : dict or None
201 extra header keys for this message's header
202 ident : bytes or list of bytes
203 the zmq.IDENTITY routing path
204 track : bool
205 whether to track. Only for use with Sockets, because ZMQStream objects cannot track messages.
206
194 207 Returns
195 208 -------
196 (msg,sent) : tuple
197 msg : Message
198 the nice wrapped dict-like object containing the headers
209 msg : message dict
210 the constructed message
211 msg['tracker'] : MessageTracker or None
212 if track=True, msg['tracker'] is the MessageTracker returned by the
213 final send; otherwise it is None
199 214
200 215 """
216
217 if not isinstance(stream, (zmq.Socket, ZMQStream)):
218 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
219 elif track and isinstance(stream, ZMQStream):
220 raise TypeError("ZMQStream cannot track messages")
221
201 222 if isinstance(msg_or_type, (Message, dict)):
202 223 # we got a Message, not a msg_type
203 224 # don't build a new Message
204 225 msg = msg_or_type
205 226 content = msg['content']
206 227 else:
207 228 msg = self.msg(msg_or_type, content, parent, subheader)
229
208 230 buffers = [] if buffers is None else buffers
209 231 to_send = []
210 232 if isinstance(ident, list):
211 233 # accept list of idents
212 234 to_send.extend(ident)
213 235 elif ident is not None:
214 236 to_send.append(ident)
215 237 to_send.append(DELIM)
216 238 if self.key is not None:
217 239 to_send.append(self.key)
218 240 to_send.append(self.pack(msg['header']))
219 241 to_send.append(self.pack(msg['parent_header']))
220 242
221 243 if content is None:
222 244 content = self.none
223 245 elif isinstance(content, dict):
224 246 content = self.pack(content)
225 elif isinstance(content, str):
247 elif isinstance(content, bytes):
226 248 # content is already packed, as in a relayed message
227 249 pass
228 250 else:
229 251 raise TypeError("Content incorrect type: %s"%type(content))
230 252 to_send.append(content)
231 253 flag = 0
232 254 if buffers:
233 255 flag = zmq.SNDMORE
234 stream.send_multipart(to_send, flag, copy=False)
256 _track = False
257 else:
258 _track=track
259 if track:
260 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
261 else:
262 tracker = stream.send_multipart(to_send, flag, copy=False)
235 263 for b in buffers[:-1]:
236 264 stream.send(b, flag, copy=False)
237 265 if buffers:
238 stream.send(buffers[-1], copy=False)
266 if track:
267 tracker = stream.send(buffers[-1], copy=False, track=track)
268 else:
269 tracker = stream.send(buffers[-1], copy=False)
270
239 271 # omsg = Message(msg)
240 272 if self.debug:
241 273 pprint.pprint(msg)
242 274 pprint.pprint(to_send)
243 275 pprint.pprint(buffers)
276
277 msg['tracker'] = tracker
278
244 279 return msg
245 280
246 281 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
247 282 """Send a raw message via ident path.
248 283
249 284 Parameters
250 285 ----------
251 286 msg : list of sendable buffers"""
252 287 to_send = []
253 if isinstance(ident, str):
288 if isinstance(ident, bytes):
254 289 ident = [ident]
255 290 if ident is not None:
256 291 to_send.extend(ident)
257 292 to_send.append(DELIM)
258 293 if self.key is not None:
259 294 to_send.append(self.key)
260 295 to_send.extend(msg)
261 296 stream.send_multipart(to_send, flags, copy=copy)  # send the framed list, not the bare msg
262 297
263 298 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
264 299 """receives and unpacks a message
265 300 returns [idents], msg"""
266 301 if isinstance(socket, ZMQStream):
267 302 socket = socket.socket
268 303 try:
269 304 msg = socket.recv_multipart(mode)
270 305 except zmq.ZMQError as e:
271 306 if e.errno == zmq.EAGAIN:
272 307 # We can convert EAGAIN to None as we know in this case
273 308 # recv_multipart won't return None.
274 309 return None
275 310 else:
276 311 raise
277 312 # return an actual Message object
278 313 # determine the number of idents by trying to unpack them.
279 314 # this is terrible:
280 315 idents, msg = self.feed_identities(msg, copy)
281 316 try:
282 317 return idents, self.unpack_message(msg, content=content, copy=copy)
283 318 except Exception as e:
284 319 print (idents, msg)
285 320 # TODO: handle it
286 321 raise e
287 322
288 323 def feed_identities(self, msg, copy=True):
289 324 """feed until DELIM is reached, then return the prefix as idents and remainder as
290 325 msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.
291 326
292 327 Parameters
293 328 ----------
294 329 msg : a list of Message or bytes objects
295 330 the message to be split
296 331 copy : bool
297 332 flag determining whether the arguments are bytes or Messages
298 333
299 334 Returns
300 335 -------
301 336 (idents,msg) : two lists
302 337 idents will always be a list of bytes - the identity prefix
303 338 msg will be a list of bytes or Messages, unchanged from input
304 339 msg should be unpackable via self.unpack_message at this point.
305 340 """
306 341 ikey = int(self.key is not None)
307 342 minlen = 3 + ikey
308 343 msg = list(msg)
309 344 idents = []
310 345 while len(msg) > minlen:
311 346 if copy:
312 347 s = msg[0]
313 348 else:
314 349 s = msg[0].bytes
315 350 if s == DELIM:
316 351 msg.pop(0)
317 352 break
318 353 else:
319 354 idents.append(s)
320 355 msg.pop(0)
321 356
322 357 return idents, msg
323 358
324 359 def unpack_message(self, msg, content=True, copy=True):
325 360 """Return a message object from the format
326 361 sent by self.send.
327 362
328 363 Parameters:
329 364 -----------
330 365
331 366 content : bool (True)
332 367 whether to unpack the content dict (True),
333 368 or leave it serialized (False)
334 369
335 370 copy : bool (True)
336 371 whether to return the bytes (True),
337 372 or the non-copying Message object in each place (False)
338 373
339 374 """
340 375 ikey = int(self.key is not None)
341 376 minlen = 3 + ikey
342 377 message = {}
343 378 if not copy:
344 379 for i in range(minlen):
345 380 msg[i] = msg[i].bytes
346 381 if ikey:
347 382 if not self.key == msg[0]:
348 383 raise KeyError("Invalid Session Key: %s"%msg[0])
349 384 if not len(msg) >= minlen:
350 385 raise TypeError("malformed message, must have at least %i elements"%minlen)
351 386 message['header'] = self.unpack(msg[ikey+0])
352 387 message['msg_type'] = message['header']['msg_type']
353 388 message['parent_header'] = self.unpack(msg[ikey+1])
354 389 if content:
355 390 message['content'] = self.unpack(msg[ikey+2])
356 391 else:
357 392 message['content'] = msg[ikey+2]
358 393
359 394 message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
360 395 return message
361 396
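Continuing the sketch after send() above (same assumed sockets and session), the receiving side undoes the framing; recv combines feed_identities and unpack_message:

    idents, reply = s.recv(b, mode=0)   # mode=0: blocking recv
    assert reply['msg_type'] == 'hello'
    assert reply['content'] == {}       # None content is packed as an empty dict
    # reply['buffers'] holds the appended raw buffers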
362 397
363 398 def test_msg2obj():
364 399 am = dict(x=1)
365 400 ao = Message(am)
366 401 assert ao.x == am['x']
367 402
368 403 am['y'] = dict(z=1)
369 404 ao = Message(am)
370 405 assert ao.y.z == am['y']['z']
371 406
372 407 k1, k2 = 'y', 'z'
373 408 assert ao[k1][k2] == am[k1][k2]
374 409
375 410 am2 = dict(ao)
376 411 assert am['x'] == am2['x']
377 412 assert am['y']['z'] == am2['y']['z']
@@ -1,44 +1,46 b''
1 1 """toplevel setup/teardown for parallel tests."""
2 2
3 import tempfile
3 4 import time
4 from subprocess import Popen, PIPE
5 from subprocess import Popen, PIPE, STDOUT
5 6
6 7 from IPython.zmq.parallel.ipcluster import launch_process
7 8 from IPython.zmq.parallel.entry_point import select_random_ports
8 9
9 10 processes = []
11 blackhole = tempfile.TemporaryFile()
10 12
11 13 # nose setup/teardown
12 14
13 15 def setup():
14 cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=PIPE, stdin=PIPE, stderr=PIPE)
16 cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=blackhole, stderr=STDOUT)
15 17 processes.append(cp)
16 18 time.sleep(.5)
17 19 add_engine()
18 time.sleep(3)
20 time.sleep(2)
19 21
20 22 def add_engine(profile='iptest'):
21 ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=PIPE, stdin=PIPE, stderr=PIPE)
23 ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=blackhole, stderr=STDOUT)
22 24 # ep.start()
23 25 processes.append(ep)
24 26 return ep
25 27
26 28 def teardown():
27 29 time.sleep(1)
28 30 while processes:
29 31 p = processes.pop()
30 32 if p.poll() is None:
31 33 try:
32 34 p.terminate()
33 35 except Exception, e:
34 36 print e
35 37 pass
36 38 if p.poll() is None:
37 39 time.sleep(.25)
38 40 if p.poll() is None:
39 41 try:
40 42 print 'killing'
41 43 p.kill()
42 44 except:
43 45 print "couldn't shutdown process: ",p
44 46
@@ -1,98 +1,100 b''
1 1 import time
2 2 from signal import SIGINT
3 3 from multiprocessing import Process
4 4
5 5 from nose import SkipTest
6 6
7 7 from zmq.tests import BaseZMQTestCase
8 8
9 9 from IPython.external.decorator import decorator
10 10
11 11 from IPython.zmq.parallel import error
12 12 from IPython.zmq.parallel.client import Client
13 13 from IPython.zmq.parallel.ipcluster import launch_process
14 14 from IPython.zmq.parallel.entry_point import select_random_ports
15 15 from IPython.zmq.parallel.tests import processes,add_engine
16 16
17 17 # simple tasks for use in apply tests
18 18
19 19 def segfault():
20 20 """this will segfault"""
21 21 import ctypes
22 22 ctypes.memset(-1,0,1)
23 23
24 24 def wait(n):
25 25 """sleep for a time"""
26 26 import time
27 27 time.sleep(n)
28 28 return n
29 29
30 30 def raiser(eclass):
31 31 """raise an exception"""
32 32 raise eclass()
33 33
34 34 # test decorator for skipping tests when libraries are unavailable
35 35 def skip_without(*names):
36 36 """skip a test if some names are not importable"""
37 37 @decorator
38 38 def skip_without_names(f, *args, **kwargs):
39 39 """decorator to skip tests in the absence of numpy."""
40 40 for name in names:
41 41 try:
42 42 __import__(name)
43 43 except ImportError:
44 44 raise SkipTest
45 45 return f(*args, **kwargs)
46 46 return skip_without_names
47 47
48 48
49 49 class ClusterTestCase(BaseZMQTestCase):
50 50
51 51 def add_engines(self, n=1, block=True):
52 52 """add multiple engines to our cluster"""
53 53 for i in range(n):
54 54 self.engines.append(add_engine())
55 55 if block:
56 56 self.wait_on_engines()
57 57
58 58 def wait_on_engines(self, timeout=5):
59 59 """wait for our engines to connect."""
60 60 n = len(self.engines)+self.base_engine_count
61 61 tic = time.time()
62 62 while time.time()-tic < timeout and len(self.client.ids) < n:
63 63 time.sleep(0.1)
64 64
65 65 assert not len(self.client.ids) < n, "waiting for engines timed out"
66 66
67 67 def connect_client(self):
68 68 """connect a client with my Context, and track its sockets for cleanup"""
69 69 c = Client(profile='iptest',context=self.context)
70 70 for name in filter(lambda n:n.endswith('socket'), dir(c)):
71 71 self.sockets.append(getattr(c, name))
72 72 return c
73 73
74 74 def assertRaisesRemote(self, etype, f, *args, **kwargs):
75 75 try:
76 76 try:
77 77 f(*args, **kwargs)
78 78 except error.CompositeError as e:
79 79 e.raise_exception()
80 80 except error.RemoteError as e:
81 81 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
82 82 else:
83 83 self.fail("should have raised a RemoteError")
84 84
85 85 def setUp(self):
86 86 BaseZMQTestCase.setUp(self)
87 87 self.client = self.connect_client()
88 88 self.base_engine_count=len(self.client.ids)
89 89 self.engines=[]
90 90
91 # def tearDown(self):
91 def tearDown(self):
92 self.client.close()
93 BaseZMQTestCase.tearDown(self)
92 94 # [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ]
93 95 # [ e.wait() for e in self.engines ]
94 96 # while len(self.client.ids) > self.base_engine_count:
95 97 # time.sleep(.1)
96 98 # del self.engines
97 99 # BaseZMQTestCase.tearDown(self)
98 100 No newline at end of file
@@ -1,187 +1,252 b''
1 1 import time
2 2 from tempfile import mktemp
3 3
4 4 import nose.tools as nt
5 import zmq
5 6
6 7 from IPython.zmq.parallel import client as clientmod
7 8 from IPython.zmq.parallel import error
8 9 from IPython.zmq.parallel.asyncresult import AsyncResult, AsyncHubResult
9 10 from IPython.zmq.parallel.view import LoadBalancedView, DirectView
10 11
11 12 from clienttest import ClusterTestCase, segfault, wait
12 13
13 14 class TestClient(ClusterTestCase):
14 15
15 16 def test_ids(self):
16 17 n = len(self.client.ids)
17 18 self.add_engines(3)
18 19 self.assertEquals(len(self.client.ids), n+3)
19 20 self.assertTrue
20 21
21 def test_segfault(self):
22 """test graceful handling of engine death"""
22 def test_segfault_task(self):
23 """test graceful handling of engine death (balanced)"""
23 24 self.add_engines(1)
24 eid = self.client.ids[-1]
25 25 ar = self.client.apply(segfault, block=False)
26 26 self.assertRaisesRemote(error.EngineError, ar.get)
27 27 eid = ar.engine_id
28 28 while eid in self.client.ids:
29 29 time.sleep(.01)
30 30 self.client.spin()
31 31
32 def test_segfault_mux(self):
33 """test graceful handling of engine death (direct)"""
34 self.add_engines(1)
35 eid = self.client.ids[-1]
36 ar = self.client[eid].apply_async(segfault)
37 self.assertRaisesRemote(error.EngineError, ar.get)
38 eid = ar.engine_id
39 while eid in self.client.ids:
40 time.sleep(.01)
41 self.client.spin()
42
32 43 def test_view_indexing(self):
33 44 """test index access for views"""
34 45 self.add_engines(2)
35 46 targets = self.client._build_targets('all')[-1]
36 47 v = self.client[:]
37 48 self.assertEquals(v.targets, targets)
38 49 t = self.client.ids[2]
39 50 v = self.client[t]
40 51 self.assert_(isinstance(v, DirectView))
41 52 self.assertEquals(v.targets, t)
42 53 t = self.client.ids[2:4]
43 54 v = self.client[t]
44 55 self.assert_(isinstance(v, DirectView))
45 56 self.assertEquals(v.targets, t)
46 57 v = self.client[::2]
47 58 self.assert_(isinstance(v, DirectView))
48 59 self.assertEquals(v.targets, targets[::2])
49 60 v = self.client[1::3]
50 61 self.assert_(isinstance(v, DirectView))
51 62 self.assertEquals(v.targets, targets[1::3])
52 63 v = self.client[:-3]
53 64 self.assert_(isinstance(v, DirectView))
54 65 self.assertEquals(v.targets, targets[:-3])
55 66 v = self.client[-1]
56 67 self.assert_(isinstance(v, DirectView))
57 68 self.assertEquals(v.targets, targets[-1])
58 69 nt.assert_raises(TypeError, lambda : self.client[None])
59 70
60 71 def test_view_cache(self):
61 72 """test that multiple view requests return the same object"""
62 73 v = self.client[:2]
63 74 v2 =self.client[:2]
64 75 self.assertTrue(v is v2)
65 76 v = self.client.view()
66 77 v2 = self.client.view(balanced=True)
67 78 self.assertTrue(v is v2)
68 79
69 80 def test_targets(self):
70 81 """test various valid targets arguments"""
71 82 build = self.client._build_targets
72 83 ids = self.client.ids
73 84 idents,targets = build(None)
74 85 self.assertEquals(ids, targets)
75 86
76 87 def test_clear(self):
77 88 """test clear behavior"""
78 89 self.add_engines(2)
79 90 self.client.block=True
80 91 self.client.push(dict(a=5))
81 92 self.client.pull('a')
82 93 id0 = self.client.ids[-1]
83 94 self.client.clear(targets=id0)
84 95 self.client.pull('a', targets=self.client.ids[:-1])
85 96 self.assertRaisesRemote(NameError, self.client.pull, 'a')
86 97 self.client.clear()
87 98 for i in self.client.ids:
88 99 self.assertRaisesRemote(NameError, self.client.pull, 'a', targets=i)
89 100
90 101
91 102 def test_push_pull(self):
92 103 """test pushing and pulling"""
93 104 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
105 t = self.client.ids[-1]
94 106 self.add_engines(2)
95 107 push = self.client.push
96 108 pull = self.client.pull
97 109 self.client.block=True
98 110 nengines = len(self.client)
99 push({'data':data}, targets=0)
100 d = pull('data', targets=0)
111 push({'data':data}, targets=t)
112 d = pull('data', targets=t)
101 113 self.assertEquals(d, data)
102 114 push({'data':data})
103 115 d = pull('data')
104 116 self.assertEquals(d, nengines*[data])
105 117 ar = push({'data':data}, block=False)
106 118 self.assertTrue(isinstance(ar, AsyncResult))
107 119 r = ar.get()
108 120 ar = pull('data', block=False)
109 121 self.assertTrue(isinstance(ar, AsyncResult))
110 122 r = ar.get()
111 123 self.assertEquals(r, nengines*[data])
112 124 push(dict(a=10,b=20))
113 125 r = pull(('a','b'))
114 126 self.assertEquals(r, nengines*[[10,20]])
115 127
116 128 def test_push_pull_function(self):
117 129 "test pushing and pulling functions"
118 130 def testf(x):
119 131 return 2.0*x
120 132
121 133 self.add_engines(4)
134 t = self.client.ids[-1]
122 135 self.client.block=True
123 136 push = self.client.push
124 137 pull = self.client.pull
125 138 execute = self.client.execute
126 push({'testf':testf}, targets=0)
127 r = pull('testf', targets=0)
139 push({'testf':testf}, targets=t)
140 r = pull('testf', targets=t)
128 141 self.assertEqual(r(1.0), testf(1.0))
129 execute('r = testf(10)', targets=0)
130 r = pull('r', targets=0)
142 execute('r = testf(10)', targets=t)
143 r = pull('r', targets=t)
131 144 self.assertEquals(r, testf(10))
132 145 ar = push({'testf':testf}, block=False)
133 146 ar.get()
134 147 ar = pull('testf', block=False)
135 148 rlist = ar.get()
136 149 for r in rlist:
137 150 self.assertEqual(r(1.0), testf(1.0))
138 execute("def g(x): return x*x", targets=0)
139 r = pull(('testf','g'),targets=0)
151 execute("def g(x): return x*x", targets=t)
152 r = pull(('testf','g'),targets=t)
140 153 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
141 154
142 155 def test_push_function_globals(self):
143 156 """test that pushed functions have access to globals"""
144 157 def geta():
145 158 return a
146 159 self.add_engines(1)
147 160 v = self.client[-1]
148 161 v.block=True
149 162 v['f'] = geta
150 163 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
151 164 v.execute('a=5')
152 165 v.execute('b=f()')
153 166 self.assertEquals(v['b'], 5)
154 167
155 168 def test_get_result(self):
156 169 """test getting results from the Hub."""
157 170 c = clientmod.Client(profile='iptest')
158 171 t = self.client.ids[-1]
159 172 ar = c.apply(wait, (1,), block=False, targets=t)
160 173 time.sleep(.25)
161 174 ahr = self.client.get_result(ar.msg_ids)
162 175 self.assertTrue(isinstance(ahr, AsyncHubResult))
163 176 self.assertEquals(ahr.get(), ar.get())
164 177 ar2 = self.client.get_result(ar.msg_ids)
165 178 self.assertFalse(isinstance(ar2, AsyncHubResult))
166 179
167 180 def test_ids_list(self):
168 181 """test client.ids"""
169 182 self.add_engines(2)
170 183 ids = self.client.ids
171 184 self.assertEquals(ids, self.client._ids)
172 185 self.assertFalse(ids is self.client._ids)
173 186 ids.remove(ids[-1])
174 187 self.assertNotEquals(ids, self.client._ids)
175 188
176 def test_arun_newline(self):
189 def test_run_newline(self):
177 190 """test that run appends newline to files"""
178 191 tmpfile = mktemp()
179 192 with open(tmpfile, 'w') as f:
180 193 f.write("""def g():
181 194 return 5
182 195 """)
183 196 v = self.client[-1]
184 197 v.run(tmpfile, block=True)
185 198 self.assertEquals(v.apply_sync_bound(lambda : g()), 5)
186 199
187 No newline at end of file
200 def test_apply_tracked(self):
201 """test tracking for apply"""
202 # self.add_engines(1)
203 t = self.client.ids[-1]
204 self.client.block=False
205 def echo(n=1024*1024, **kwargs):
206 return self.client.apply(lambda x: x, args=('x'*n,), targets=t, **kwargs)
207 ar = echo(1)
208 self.assertTrue(ar._tracker is None)
209 self.assertTrue(ar.sent)
210 ar = echo(track=True)
211 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
212 self.assertEquals(ar.sent, ar._tracker.done)
213 ar._tracker.wait()
214 self.assertTrue(ar.sent)
215
216 def test_push_tracked(self):
217 t = self.client.ids[-1]
218 ns = dict(x='x'*1024*1024)
219 ar = self.client.push(ns, targets=t, block=False)
220 self.assertTrue(ar._tracker is None)
221 self.assertTrue(ar.sent)
222
223 ar = self.client.push(ns, targets=t, block=False, track=True)
224 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
225 self.assertEquals(ar.sent, ar._tracker.done)
226 ar._tracker.wait()
227 self.assertTrue(ar.sent)
228 ar.get()
229
230 def test_scatter_tracked(self):
231 t = self.client.ids
232 x='x'*1024*1024
233 ar = self.client.scatter('x', x, targets=t, block=False)
234 self.assertTrue(ar._tracker is None)
235 self.assertTrue(ar.sent)
236
237 ar = self.client.scatter('x', x, targets=t, block=False, track=True)
238 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
239 self.assertEquals(ar.sent, ar._tracker.done)
240 ar._tracker.wait()
241 self.assertTrue(ar.sent)
242 ar.get()
243
244 def test_remote_reference(self):
245 v = self.client[-1]
246 v['a'] = 123
247 ra = clientmod.Reference('a')
248 b = v.apply_sync_bound(lambda x: x, ra)
249 self.assertEquals(b, 123)
250 self.assertRaisesRemote(NameError, v.apply_sync, lambda x: x, ra)
251
252
@@ -1,82 +1,99 b''
1 1
2 2 import os
3 3 import uuid
4 4 import zmq
5 5
6 6 from zmq.tests import BaseZMQTestCase
7
7 from zmq.eventloop.zmqstream import ZMQStream
8 8 # from IPython.zmq.tests import SessionTestCase
9 9 from IPython.zmq.parallel import streamsession as ss
10 10
11 11 class SessionTestCase(BaseZMQTestCase):
12 12
13 13 def setUp(self):
14 14 BaseZMQTestCase.setUp(self)
15 15 self.session = ss.StreamSession()
16 16
17 17 class TestSession(SessionTestCase):
18 18
19 19 def test_msg(self):
20 20 """message format"""
21 21 msg = self.session.msg('execute')
22 22 thekeys = set('header msg_id parent_header msg_type content'.split())
23 23 s = set(msg.keys())
24 24 self.assertEquals(s, thekeys)
25 25 self.assertTrue(isinstance(msg['content'],dict))
26 26 self.assertTrue(isinstance(msg['header'],dict))
27 27 self.assertTrue(isinstance(msg['parent_header'],dict))
28 28 self.assertEquals(msg['msg_type'], 'execute')
29 29
30 30
31 31
32 32 def test_args(self):
33 33 """initialization arguments for StreamSession"""
34 s = ss.StreamSession()
34 s = self.session
35 35 self.assertTrue(s.pack is ss.default_packer)
36 36 self.assertTrue(s.unpack is ss.default_unpacker)
37 37 self.assertEquals(s.username, os.environ.get('USER', 'username'))
38 38
39 39 s = ss.StreamSession(username=None)
40 40 self.assertEquals(s.username, os.environ.get('USER', 'username'))
41 41
42 42 self.assertRaises(TypeError, ss.StreamSession, packer='hi')
43 43 self.assertRaises(TypeError, ss.StreamSession, unpacker='hi')
44 44 u = str(uuid.uuid4())
45 45 s = ss.StreamSession(username='carrot', session=u)
46 46 self.assertEquals(s.session, u)
47 47 self.assertEquals(s.username, 'carrot')
48 48
49
49 def test_tracking(self):
50 """test tracking messages"""
51 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
52 s = self.session
53 stream = ZMQStream(a)
54 msg = s.send(a, 'hello', track=False)
55 self.assertTrue(msg['tracker'] is None)
56 msg = s.send(a, 'hello', track=True)
57 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
58 M = zmq.Message(b'hi there', track=True)
59 msg = s.send(a, 'hello', buffers=[M], track=True)
60 t = msg['tracker']
61 self.assertTrue(isinstance(t, zmq.MessageTracker))
62 self.assertRaises(zmq.NotDone, t.wait, .1)
63 del M
64 t.wait(1) # should complete now; raises zmq.NotDone only on timeout
65
66
50 67 # def test_rekey(self):
51 68 # """rekeying dict around json str keys"""
52 69 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
53 70 # self.assertRaises(KeyError, ss.rekey, d)
54 71 #
55 72 # d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
56 73 # d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
57 74 # rd = ss.rekey(d)
58 75 # self.assertEquals(d2,rd)
59 76 #
60 77 # d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
61 78 # d2 = {1.5:d['1.5'],1:d['1']}
62 79 # rd = ss.rekey(d)
63 80 # self.assertEquals(d2,rd)
64 81 #
65 82 # d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
66 83 # self.assertRaises(KeyError, ss.rekey, d)
67 84 #
68 85 def test_unique_msg_ids(self):
69 86 """test that messages receive unique ids"""
70 87 ids = set()
71 88 for i in range(2**12):
72 89 h = self.session.msg_header('test')
73 90 msg_id = h['msg_id']
74 91 self.assertTrue(msg_id not in ids)
75 92 ids.add(msg_id)
76 93
77 94 def test_feed_identities(self):
78 95 """scrub the front for zmq IDENTITIES"""
79 96 theids = "engine client other".split()
80 97 content = dict(code='whoda',stuff=object())
81 98 themsg = self.session.msg('execute',content=content)
82 99 pmsg = theids