##// END OF EJS Templates
add message tracking to client, add/improve tests
MinRK -
Show More
@@ -1,305 +1,322 b''
1 """AsyncResult objects for the client"""
1 """AsyncResult objects for the client"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
3 # Copyright (C) 2010 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 import time
13 import time
14
14
15 from IPython.external.decorator import decorator
15 from IPython.external.decorator import decorator
16 from . import error
16 from . import error
17
17
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 # Classes
19 # Classes
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21
21
22 @decorator
22 @decorator
23 def check_ready(f, self, *args, **kwargs):
23 def check_ready(f, self, *args, **kwargs):
24 """Call spin() to sync state prior to calling the method."""
24 """Call spin() to sync state prior to calling the method."""
25 self.wait(0)
25 self.wait(0)
26 if not self._ready:
26 if not self._ready:
27 raise error.TimeoutError("result not ready")
27 raise error.TimeoutError("result not ready")
28 return f(self, *args, **kwargs)
28 return f(self, *args, **kwargs)
29
29
30 class AsyncResult(object):
30 class AsyncResult(object):
31 """Class for representing results of non-blocking calls.
31 """Class for representing results of non-blocking calls.
32
32
33 Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
33 Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
34 """
34 """
35
35
36 msg_ids = None
36 msg_ids = None
37 _targets = None
38 _tracker = None
37
39
38 def __init__(self, client, msg_ids, fname='unknown'):
40 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None):
39 self._client = client
41 self._client = client
40 if isinstance(msg_ids, basestring):
42 if isinstance(msg_ids, basestring):
41 msg_ids = [msg_ids]
43 msg_ids = [msg_ids]
42 self.msg_ids = msg_ids
44 self.msg_ids = msg_ids
43 self._fname=fname
45 self._fname=fname
46 self._targets = targets
47 self._tracker = tracker
44 self._ready = False
48 self._ready = False
45 self._success = None
49 self._success = None
46 self._single_result = len(msg_ids) == 1
50 self._single_result = len(msg_ids) == 1
47
51
48 def __repr__(self):
52 def __repr__(self):
49 if self._ready:
53 if self._ready:
50 return "<%s: finished>"%(self.__class__.__name__)
54 return "<%s: finished>"%(self.__class__.__name__)
51 else:
55 else:
52 return "<%s: %s>"%(self.__class__.__name__,self._fname)
56 return "<%s: %s>"%(self.__class__.__name__,self._fname)
53
57
54
58
55 def _reconstruct_result(self, res):
59 def _reconstruct_result(self, res):
56 """Reconstruct our result from actual result list (always a list)
60 """Reconstruct our result from actual result list (always a list)
57
61
58 Override me in subclasses for turning a list of results
62 Override me in subclasses for turning a list of results
59 into the expected form.
63 into the expected form.
60 """
64 """
61 if self._single_result:
65 if self._single_result:
62 return res[0]
66 return res[0]
63 else:
67 else:
64 return res
68 return res
65
69
66 def get(self, timeout=-1):
70 def get(self, timeout=-1):
67 """Return the result when it arrives.
71 """Return the result when it arrives.
68
72
69 If `timeout` is not ``None`` and the result does not arrive within
73 If `timeout` is not ``None`` and the result does not arrive within
70 `timeout` seconds then ``TimeoutError`` is raised. If the
74 `timeout` seconds then ``TimeoutError`` is raised. If the
71 remote call raised an exception then that exception will be reraised
75 remote call raised an exception then that exception will be reraised
72 by get() inside a `RemoteError`.
76 by get() inside a `RemoteError`.
73 """
77 """
74 if not self.ready():
78 if not self.ready():
75 self.wait(timeout)
79 self.wait(timeout)
76
80
77 if self._ready:
81 if self._ready:
78 if self._success:
82 if self._success:
79 return self._result
83 return self._result
80 else:
84 else:
81 raise self._exception
85 raise self._exception
82 else:
86 else:
83 raise error.TimeoutError("Result not ready.")
87 raise error.TimeoutError("Result not ready.")
84
88
85 def ready(self):
89 def ready(self):
86 """Return whether the call has completed."""
90 """Return whether the call has completed."""
87 if not self._ready:
91 if not self._ready:
88 self.wait(0)
92 self.wait(0)
89 return self._ready
93 return self._ready
90
94
91 def wait(self, timeout=-1):
95 def wait(self, timeout=-1):
92 """Wait until the result is available or until `timeout` seconds pass.
96 """Wait until the result is available or until `timeout` seconds pass.
93
97
94 This method always returns None.
98 This method always returns None.
95 """
99 """
96 if self._ready:
100 if self._ready:
97 return
101 return
98 self._ready = self._client.barrier(self.msg_ids, timeout)
102 self._ready = self._client.barrier(self.msg_ids, timeout)
99 if self._ready:
103 if self._ready:
100 try:
104 try:
101 results = map(self._client.results.get, self.msg_ids)
105 results = map(self._client.results.get, self.msg_ids)
102 self._result = results
106 self._result = results
103 if self._single_result:
107 if self._single_result:
104 r = results[0]
108 r = results[0]
105 if isinstance(r, Exception):
109 if isinstance(r, Exception):
106 raise r
110 raise r
107 else:
111 else:
108 results = error.collect_exceptions(results, self._fname)
112 results = error.collect_exceptions(results, self._fname)
109 self._result = self._reconstruct_result(results)
113 self._result = self._reconstruct_result(results)
110 except Exception, e:
114 except Exception, e:
111 self._exception = e
115 self._exception = e
112 self._success = False
116 self._success = False
113 else:
117 else:
114 self._success = True
118 self._success = True
115 finally:
119 finally:
116 self._metadata = map(self._client.metadata.get, self.msg_ids)
120 self._metadata = map(self._client.metadata.get, self.msg_ids)
117
121
118
122
119 def successful(self):
123 def successful(self):
120 """Return whether the call completed without raising an exception.
124 """Return whether the call completed without raising an exception.
121
125
122 Will raise ``AssertionError`` if the result is not ready.
126 Will raise ``AssertionError`` if the result is not ready.
123 """
127 """
124 assert self.ready()
128 assert self.ready()
125 return self._success
129 return self._success
126
130
127 #----------------------------------------------------------------
131 #----------------------------------------------------------------
128 # Extra methods not in mp.pool.AsyncResult
132 # Extra methods not in mp.pool.AsyncResult
129 #----------------------------------------------------------------
133 #----------------------------------------------------------------
130
134
131 def get_dict(self, timeout=-1):
135 def get_dict(self, timeout=-1):
132 """Get the results as a dict, keyed by engine_id.
136 """Get the results as a dict, keyed by engine_id.
133
137
134 timeout behavior is described in `get()`.
138 timeout behavior is described in `get()`.
135 """
139 """
136
140
137 results = self.get(timeout)
141 results = self.get(timeout)
138 engine_ids = [ md['engine_id'] for md in self._metadata ]
142 engine_ids = [ md['engine_id'] for md in self._metadata ]
139 bycount = sorted(engine_ids, key=lambda k: engine_ids.count(k))
143 bycount = sorted(engine_ids, key=lambda k: engine_ids.count(k))
140 maxcount = bycount.count(bycount[-1])
144 maxcount = bycount.count(bycount[-1])
141 if maxcount > 1:
145 if maxcount > 1:
142 raise ValueError("Cannot build dict, %i jobs ran on engine #%i"%(
146 raise ValueError("Cannot build dict, %i jobs ran on engine #%i"%(
143 maxcount, bycount[-1]))
147 maxcount, bycount[-1]))
144
148
145 return dict(zip(engine_ids,results))
149 return dict(zip(engine_ids,results))
146
150
147 @property
151 @property
148 @check_ready
152 @check_ready
149 def result(self):
153 def result(self):
150 """result property wrapper for `get(timeout=0)`."""
154 """result property wrapper for `get(timeout=0)`."""
151 return self._result
155 return self._result
152
156
153 # abbreviated alias:
157 # abbreviated alias:
154 r = result
158 r = result
155
159
156 @property
160 @property
157 @check_ready
161 @check_ready
158 def metadata(self):
162 def metadata(self):
159 """property for accessing execution metadata."""
163 """property for accessing execution metadata."""
160 if self._single_result:
164 if self._single_result:
161 return self._metadata[0]
165 return self._metadata[0]
162 else:
166 else:
163 return self._metadata
167 return self._metadata
164
168
165 @property
169 @property
166 def result_dict(self):
170 def result_dict(self):
167 """result property as a dict."""
171 """result property as a dict."""
168 return self.get_dict(0)
172 return self.get_dict(0)
169
173
170 def __dict__(self):
174 def __dict__(self):
171 return self.get_dict(0)
175 return self.get_dict(0)
176
177 def abort(self):
178 """abort my tasks."""
179 assert not self.ready(), "Can't abort, I am already done!"
180 return self.client.abort(self.msg_ids, targets=self._targets, block=True)
181
182 @property
183 def sent(self):
184 """check whether my messages have been sent"""
185 if self._tracker is None:
186 return True
187 else:
188 return self._tracker.done
172
189
173 #-------------------------------------
190 #-------------------------------------
174 # dict-access
191 # dict-access
175 #-------------------------------------
192 #-------------------------------------
176
193
177 @check_ready
194 @check_ready
178 def __getitem__(self, key):
195 def __getitem__(self, key):
179 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
196 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
180 """
197 """
181 if isinstance(key, int):
198 if isinstance(key, int):
182 return error.collect_exceptions([self._result[key]], self._fname)[0]
199 return error.collect_exceptions([self._result[key]], self._fname)[0]
183 elif isinstance(key, slice):
200 elif isinstance(key, slice):
184 return error.collect_exceptions(self._result[key], self._fname)
201 return error.collect_exceptions(self._result[key], self._fname)
185 elif isinstance(key, basestring):
202 elif isinstance(key, basestring):
186 values = [ md[key] for md in self._metadata ]
203 values = [ md[key] for md in self._metadata ]
187 if self._single_result:
204 if self._single_result:
188 return values[0]
205 return values[0]
189 else:
206 else:
190 return values
207 return values
191 else:
208 else:
192 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
209 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
193
210
194 @check_ready
211 @check_ready
195 def __getattr__(self, key):
212 def __getattr__(self, key):
196 """getattr maps to getitem for convenient attr access to metadata."""
213 """getattr maps to getitem for convenient attr access to metadata."""
197 if key not in self._metadata[0].keys():
214 if key not in self._metadata[0].keys():
198 raise AttributeError("%r object has no attribute %r"%(
215 raise AttributeError("%r object has no attribute %r"%(
199 self.__class__.__name__, key))
216 self.__class__.__name__, key))
200 return self.__getitem__(key)
217 return self.__getitem__(key)
201
218
202 # asynchronous iterator:
219 # asynchronous iterator:
203 def __iter__(self):
220 def __iter__(self):
204 if self._single_result:
221 if self._single_result:
205 raise TypeError("AsyncResults with a single result are not iterable.")
222 raise TypeError("AsyncResults with a single result are not iterable.")
206 try:
223 try:
207 rlist = self.get(0)
224 rlist = self.get(0)
208 except error.TimeoutError:
225 except error.TimeoutError:
209 # wait for each result individually
226 # wait for each result individually
210 for msg_id in self.msg_ids:
227 for msg_id in self.msg_ids:
211 ar = AsyncResult(self._client, msg_id, self._fname)
228 ar = AsyncResult(self._client, msg_id, self._fname)
212 yield ar.get()
229 yield ar.get()
213 else:
230 else:
214 # already done
231 # already done
215 for r in rlist:
232 for r in rlist:
216 yield r
233 yield r
217
234
218
235
219
236
220 class AsyncMapResult(AsyncResult):
237 class AsyncMapResult(AsyncResult):
221 """Class for representing results of non-blocking gathers.
238 """Class for representing results of non-blocking gathers.
222
239
223 This will properly reconstruct the gather.
240 This will properly reconstruct the gather.
224 """
241 """
225
242
226 def __init__(self, client, msg_ids, mapObject, fname=''):
243 def __init__(self, client, msg_ids, mapObject, fname=''):
227 AsyncResult.__init__(self, client, msg_ids, fname=fname)
244 AsyncResult.__init__(self, client, msg_ids, fname=fname)
228 self._mapObject = mapObject
245 self._mapObject = mapObject
229 self._single_result = False
246 self._single_result = False
230
247
231 def _reconstruct_result(self, res):
248 def _reconstruct_result(self, res):
232 """Perform the gather on the actual results."""
249 """Perform the gather on the actual results."""
233 return self._mapObject.joinPartitions(res)
250 return self._mapObject.joinPartitions(res)
234
251
235 # asynchronous iterator:
252 # asynchronous iterator:
236 def __iter__(self):
253 def __iter__(self):
237 try:
254 try:
238 rlist = self.get(0)
255 rlist = self.get(0)
239 except error.TimeoutError:
256 except error.TimeoutError:
240 # wait for each result individually
257 # wait for each result individually
241 for msg_id in self.msg_ids:
258 for msg_id in self.msg_ids:
242 ar = AsyncResult(self._client, msg_id, self._fname)
259 ar = AsyncResult(self._client, msg_id, self._fname)
243 rlist = ar.get()
260 rlist = ar.get()
244 try:
261 try:
245 for r in rlist:
262 for r in rlist:
246 yield r
263 yield r
247 except TypeError:
264 except TypeError:
248 # flattened, not a list
265 # flattened, not a list
249 # this could get broken by flattened data that returns iterables
266 # this could get broken by flattened data that returns iterables
250 # but most calls to map do not expose the `flatten` argument
267 # but most calls to map do not expose the `flatten` argument
251 yield rlist
268 yield rlist
252 else:
269 else:
253 # already done
270 # already done
254 for r in rlist:
271 for r in rlist:
255 yield r
272 yield r
256
273
257
274
258 class AsyncHubResult(AsyncResult):
275 class AsyncHubResult(AsyncResult):
259 """Class to wrap pending results that must be requested from the Hub.
276 """Class to wrap pending results that must be requested from the Hub.
260
277
 261 Note that waiting/polling on these objects requires polling the Hub over the network,
 278 Note that waiting/polling on these objects requires polling the Hub over the network,
262 so use `AsyncHubResult.wait()` sparingly.
279 so use `AsyncHubResult.wait()` sparingly.
263 """
280 """
264
281
265 def wait(self, timeout=-1):
282 def wait(self, timeout=-1):
266 """wait for result to complete."""
283 """wait for result to complete."""
267 start = time.time()
284 start = time.time()
268 if self._ready:
285 if self._ready:
269 return
286 return
270 local_ids = filter(lambda msg_id: msg_id in self._client.outstanding, self.msg_ids)
287 local_ids = filter(lambda msg_id: msg_id in self._client.outstanding, self.msg_ids)
271 local_ready = self._client.barrier(local_ids, timeout)
288 local_ready = self._client.barrier(local_ids, timeout)
272 if local_ready:
289 if local_ready:
273 remote_ids = filter(lambda msg_id: msg_id not in self._client.results, self.msg_ids)
290 remote_ids = filter(lambda msg_id: msg_id not in self._client.results, self.msg_ids)
274 if not remote_ids:
291 if not remote_ids:
275 self._ready = True
292 self._ready = True
276 else:
293 else:
277 rdict = self._client.result_status(remote_ids, status_only=False)
294 rdict = self._client.result_status(remote_ids, status_only=False)
278 pending = rdict['pending']
295 pending = rdict['pending']
279 while pending and (timeout < 0 or time.time() < start+timeout):
296 while pending and (timeout < 0 or time.time() < start+timeout):
280 rdict = self._client.result_status(remote_ids, status_only=False)
297 rdict = self._client.result_status(remote_ids, status_only=False)
281 pending = rdict['pending']
298 pending = rdict['pending']
282 if pending:
299 if pending:
283 time.sleep(0.1)
300 time.sleep(0.1)
284 if not pending:
301 if not pending:
285 self._ready = True
302 self._ready = True
286 if self._ready:
303 if self._ready:
287 try:
304 try:
288 results = map(self._client.results.get, self.msg_ids)
305 results = map(self._client.results.get, self.msg_ids)
289 self._result = results
306 self._result = results
290 if self._single_result:
307 if self._single_result:
291 r = results[0]
308 r = results[0]
292 if isinstance(r, Exception):
309 if isinstance(r, Exception):
293 raise r
310 raise r
294 else:
311 else:
295 results = error.collect_exceptions(results, self._fname)
312 results = error.collect_exceptions(results, self._fname)
296 self._result = self._reconstruct_result(results)
313 self._result = self._reconstruct_result(results)
297 except Exception, e:
314 except Exception, e:
298 self._exception = e
315 self._exception = e
299 self._success = False
316 self._success = False
300 else:
317 else:
301 self._success = True
318 self._success = True
302 finally:
319 finally:
303 self._metadata = map(self._client.metadata.get, self.msg_ids)
320 self._metadata = map(self._client.metadata.get, self.msg_ids)
304
321
305 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult'] No newline at end of file
322 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult']
@@ -1,1545 +1,1569 b''
1 """A semi-synchronous Client for the ZMQ controller"""
1 """A semi-synchronous Client for the ZMQ controller"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
3 # Copyright (C) 2010 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 import os
13 import os
14 import json
14 import json
15 import time
15 import time
16 import warnings
16 import warnings
17 from datetime import datetime
17 from datetime import datetime
18 from getpass import getpass
18 from getpass import getpass
19 from pprint import pprint
19 from pprint import pprint
20
20
21 pjoin = os.path.join
21 pjoin = os.path.join
22
22
23 import zmq
23 import zmq
24 # from zmq.eventloop import ioloop, zmqstream
24 # from zmq.eventloop import ioloop, zmqstream
25
25
26 from IPython.utils.path import get_ipython_dir
26 from IPython.utils.path import get_ipython_dir
27 from IPython.utils.pickleutil import Reference
27 from IPython.utils.pickleutil import Reference
28 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
28 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
29 Dict, List, Bool, Str, Set)
29 Dict, List, Bool, Str, Set)
30 from IPython.external.decorator import decorator
30 from IPython.external.decorator import decorator
31 from IPython.external.ssh import tunnel
31 from IPython.external.ssh import tunnel
32
32
33 from . import error
33 from . import error
34 from . import map as Map
34 from . import map as Map
35 from . import util
35 from . import util
36 from . import streamsession as ss
36 from . import streamsession as ss
37 from .asyncresult import AsyncResult, AsyncMapResult, AsyncHubResult
37 from .asyncresult import AsyncResult, AsyncMapResult, AsyncHubResult
38 from .clusterdir import ClusterDir, ClusterDirError
38 from .clusterdir import ClusterDir, ClusterDirError
39 from .dependency import Dependency, depend, require, dependent
39 from .dependency import Dependency, depend, require, dependent
40 from .remotefunction import remote, parallel, ParallelFunction, RemoteFunction
40 from .remotefunction import remote, parallel, ParallelFunction, RemoteFunction
41 from .util import ReverseDict, validate_url, disambiguate_url
41 from .util import ReverseDict, validate_url, disambiguate_url
42 from .view import DirectView, LoadBalancedView
42 from .view import DirectView, LoadBalancedView
43
43
44 #--------------------------------------------------------------------------
44 #--------------------------------------------------------------------------
45 # helpers for implementing old MEC API via client.apply
45 # helpers for implementing old MEC API via client.apply
46 #--------------------------------------------------------------------------
46 #--------------------------------------------------------------------------
47
47
48 def _push(ns):
48 def _push(ns):
49 """helper method for implementing `client.push` via `client.apply`"""
49 """helper method for implementing `client.push` via `client.apply`"""
50 globals().update(ns)
50 globals().update(ns)
51
51
52 def _pull(keys):
52 def _pull(keys):
53 """helper method for implementing `client.pull` via `client.apply`"""
53 """helper method for implementing `client.pull` via `client.apply`"""
54 g = globals()
54 g = globals()
55 if isinstance(keys, (list,tuple, set)):
55 if isinstance(keys, (list,tuple, set)):
56 for key in keys:
56 for key in keys:
57 if not g.has_key(key):
57 if not g.has_key(key):
58 raise NameError("name '%s' is not defined"%key)
58 raise NameError("name '%s' is not defined"%key)
59 return map(g.get, keys)
59 return map(g.get, keys)
60 else:
60 else:
61 if not g.has_key(keys):
61 if not g.has_key(keys):
62 raise NameError("name '%s' is not defined"%keys)
62 raise NameError("name '%s' is not defined"%keys)
63 return g.get(keys)
63 return g.get(keys)
64
64
65 def _clear():
65 def _clear():
66 """helper method for implementing `client.clear` via `client.apply`"""
66 """helper method for implementing `client.clear` via `client.apply`"""
67 globals().clear()
67 globals().clear()
68
68
69 def _execute(code):
69 def _execute(code):
70 """helper method for implementing `client.execute` via `client.apply`"""
70 """helper method for implementing `client.execute` via `client.apply`"""
71 exec code in globals()
71 exec code in globals()
72
72
73
73
74 #--------------------------------------------------------------------------
74 #--------------------------------------------------------------------------
75 # Decorators for Client methods
75 # Decorators for Client methods
76 #--------------------------------------------------------------------------
76 #--------------------------------------------------------------------------
77
77
78 @decorator
78 @decorator
79 def spinfirst(f, self, *args, **kwargs):
79 def spinfirst(f, self, *args, **kwargs):
80 """Call spin() to sync state prior to calling the method."""
80 """Call spin() to sync state prior to calling the method."""
81 self.spin()
81 self.spin()
82 return f(self, *args, **kwargs)
82 return f(self, *args, **kwargs)
83
83
84 @decorator
84 @decorator
85 def defaultblock(f, self, *args, **kwargs):
85 def defaultblock(f, self, *args, **kwargs):
86 """Default to self.block; preserve self.block."""
86 """Default to self.block; preserve self.block."""
87 block = kwargs.get('block',None)
87 block = kwargs.get('block',None)
88 block = self.block if block is None else block
88 block = self.block if block is None else block
89 saveblock = self.block
89 saveblock = self.block
90 self.block = block
90 self.block = block
91 try:
91 try:
92 ret = f(self, *args, **kwargs)
92 ret = f(self, *args, **kwargs)
93 finally:
93 finally:
94 self.block = saveblock
94 self.block = saveblock
95 return ret
95 return ret
96
96
97
97
98 #--------------------------------------------------------------------------
98 #--------------------------------------------------------------------------
99 # Classes
99 # Classes
100 #--------------------------------------------------------------------------
100 #--------------------------------------------------------------------------
101
101
102 class Metadata(dict):
102 class Metadata(dict):
103 """Subclass of dict for initializing metadata values.
103 """Subclass of dict for initializing metadata values.
104
104
105 Attribute access works on keys.
105 Attribute access works on keys.
106
106
107 These objects have a strict set of keys - errors will raise if you try
107 These objects have a strict set of keys - errors will raise if you try
108 to add new keys.
108 to add new keys.
109 """
109 """
110 def __init__(self, *args, **kwargs):
110 def __init__(self, *args, **kwargs):
111 dict.__init__(self)
111 dict.__init__(self)
112 md = {'msg_id' : None,
112 md = {'msg_id' : None,
113 'submitted' : None,
113 'submitted' : None,
114 'started' : None,
114 'started' : None,
115 'completed' : None,
115 'completed' : None,
116 'received' : None,
116 'received' : None,
117 'engine_uuid' : None,
117 'engine_uuid' : None,
118 'engine_id' : None,
118 'engine_id' : None,
119 'follow' : None,
119 'follow' : None,
120 'after' : None,
120 'after' : None,
121 'status' : None,
121 'status' : None,
122
122
123 'pyin' : None,
123 'pyin' : None,
124 'pyout' : None,
124 'pyout' : None,
125 'pyerr' : None,
125 'pyerr' : None,
126 'stdout' : '',
126 'stdout' : '',
127 'stderr' : '',
127 'stderr' : '',
128 }
128 }
129 self.update(md)
129 self.update(md)
130 self.update(dict(*args, **kwargs))
130 self.update(dict(*args, **kwargs))
131
131
132 def __getattr__(self, key):
132 def __getattr__(self, key):
133 """getattr aliased to getitem"""
133 """getattr aliased to getitem"""
134 if key in self.iterkeys():
134 if key in self.iterkeys():
135 return self[key]
135 return self[key]
136 else:
136 else:
137 raise AttributeError(key)
137 raise AttributeError(key)
138
138
139 def __setattr__(self, key, value):
139 def __setattr__(self, key, value):
140 """setattr aliased to setitem, with strict"""
140 """setattr aliased to setitem, with strict"""
141 if key in self.iterkeys():
141 if key in self.iterkeys():
142 self[key] = value
142 self[key] = value
143 else:
143 else:
144 raise AttributeError(key)
144 raise AttributeError(key)
145
145
146 def __setitem__(self, key, value):
146 def __setitem__(self, key, value):
147 """strict static key enforcement"""
147 """strict static key enforcement"""
148 if key in self.iterkeys():
148 if key in self.iterkeys():
149 dict.__setitem__(self, key, value)
149 dict.__setitem__(self, key, value)
150 else:
150 else:
151 raise KeyError(key)
151 raise KeyError(key)
152
152
153
153
154 class Client(HasTraits):
154 class Client(HasTraits):
155 """A semi-synchronous client to the IPython ZMQ controller
155 """A semi-synchronous client to the IPython ZMQ controller
156
156
157 Parameters
157 Parameters
158 ----------
158 ----------
159
159
160 url_or_file : bytes; zmq url or path to ipcontroller-client.json
160 url_or_file : bytes; zmq url or path to ipcontroller-client.json
161 Connection information for the Hub's registration. If a json connector
161 Connection information for the Hub's registration. If a json connector
162 file is given, then likely no further configuration is necessary.
162 file is given, then likely no further configuration is necessary.
163 [Default: use profile]
163 [Default: use profile]
164 profile : bytes
164 profile : bytes
165 The name of the Cluster profile to be used to find connector information.
165 The name of the Cluster profile to be used to find connector information.
166 [Default: 'default']
166 [Default: 'default']
167 context : zmq.Context
167 context : zmq.Context
168 Pass an existing zmq.Context instance, otherwise the client will create its own.
168 Pass an existing zmq.Context instance, otherwise the client will create its own.
169 username : bytes
169 username : bytes
170 set username to be passed to the Session object
170 set username to be passed to the Session object
171 debug : bool
171 debug : bool
172 flag for lots of message printing for debug purposes
172 flag for lots of message printing for debug purposes
173
173
174 #-------------- ssh related args ----------------
174 #-------------- ssh related args ----------------
175 # These are args for configuring the ssh tunnel to be used
175 # These are args for configuring the ssh tunnel to be used
176 # credentials are used to forward connections over ssh to the Controller
176 # credentials are used to forward connections over ssh to the Controller
177 # Note that the ip given in `addr` needs to be relative to sshserver
177 # Note that the ip given in `addr` needs to be relative to sshserver
178 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
178 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
179 # and set sshserver as the same machine the Controller is on. However,
179 # and set sshserver as the same machine the Controller is on. However,
180 # the only requirement is that sshserver is able to see the Controller
180 # the only requirement is that sshserver is able to see the Controller
181 # (i.e. is within the same trusted network).
181 # (i.e. is within the same trusted network).
182
182
183 sshserver : str
183 sshserver : str
184 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
184 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
185 If keyfile or password is specified, and this is not, it will default to
185 If keyfile or password is specified, and this is not, it will default to
186 the ip given in addr.
186 the ip given in addr.
187 sshkey : str; path to public ssh key file
187 sshkey : str; path to public ssh key file
188 This specifies a key to be used in ssh login, default None.
188 This specifies a key to be used in ssh login, default None.
189 Regular default ssh keys will be used without specifying this argument.
189 Regular default ssh keys will be used without specifying this argument.
190 password : str
190 password : str
191 Your ssh password to sshserver. Note that if this is left None,
191 Your ssh password to sshserver. Note that if this is left None,
192 you will be prompted for it if passwordless key based login is unavailable.
192 you will be prompted for it if passwordless key based login is unavailable.
193 paramiko : bool
193 paramiko : bool
194 flag for whether to use paramiko instead of shell ssh for tunneling.
194 flag for whether to use paramiko instead of shell ssh for tunneling.
195 [default: True on win32, False else]
195 [default: True on win32, False else]
196
196
197 #------- exec authentication args -------
197 #------- exec authentication args -------
198 # If even localhost is untrusted, you can have some protection against
198 # If even localhost is untrusted, you can have some protection against
199 # unauthorized execution by using a key. Messages are still sent
199 # unauthorized execution by using a key. Messages are still sent
200 # as cleartext, so if someone can snoop your loopback traffic this will
200 # as cleartext, so if someone can snoop your loopback traffic this will
201 # not help against malicious attacks.
201 # not help against malicious attacks.
202
202
203 exec_key : str
203 exec_key : str
204 an authentication key or file containing a key
204 an authentication key or file containing a key
205 default: None
205 default: None
206
206
207
207
208 Attributes
208 Attributes
209 ----------
209 ----------
210
210
211 ids : set of int engine IDs
211 ids : set of int engine IDs
212 requesting the ids attribute always synchronizes
212 requesting the ids attribute always synchronizes
213 the registration state. To request ids without synchronization,
213 the registration state. To request ids without synchronization,
214 use semi-private _ids attributes.
214 use semi-private _ids attributes.
215
215
216 history : list of msg_ids
216 history : list of msg_ids
217 a list of msg_ids, keeping track of all the execution
217 a list of msg_ids, keeping track of all the execution
218 messages you have submitted in order.
218 messages you have submitted in order.
219
219
220 outstanding : set of msg_ids
220 outstanding : set of msg_ids
221 a set of msg_ids that have been submitted, but whose
221 a set of msg_ids that have been submitted, but whose
222 results have not yet been received.
222 results have not yet been received.
223
223
224 results : dict
224 results : dict
225 a dict of all our results, keyed by msg_id
225 a dict of all our results, keyed by msg_id
226
226
227 block : bool
227 block : bool
228 determines default behavior when block not specified
228 determines default behavior when block not specified
229 in execution methods
229 in execution methods
230
230
231 Methods
231 Methods
232 -------
232 -------
233
233
234 spin
234 spin
235 flushes incoming results and registration state changes
235 flushes incoming results and registration state changes
236 control methods spin, and requesting `ids` also ensures up to date
236 control methods spin, and requesting `ids` also ensures up to date
237
237
238 barrier
238 barrier
239 wait on one or more msg_ids
239 wait on one or more msg_ids
240
240
241 execution methods
241 execution methods
242 apply
242 apply
243 legacy: execute, run
243 legacy: execute, run
244
244
245 query methods
245 query methods
246 queue_status, get_result, purge
246 queue_status, get_result, purge
247
247
248 control methods
248 control methods
249 abort, shutdown
249 abort, shutdown
250
250
251 """
251 """
252
252
253
253
    # --- public state traits ---
    block = Bool(False)       # default blocking behavior for execution methods
    outstanding = Set()       # msg_ids submitted but whose results haven't arrived
    results = Instance('collections.defaultdict', (dict,))       # msg_id -> result
    metadata = Instance('collections.defaultdict', (Metadata,))  # msg_id -> execution metadata
    history = List()          # all msg_ids ever submitted, in order
    debug = Bool(False)       # pprint every message sent/received when True
    profile=CUnicode('default')

    # --- private state ---
    _outstanding_dict = Instance('collections.defaultdict', (set,))  # engine uuid -> outstanding msg_ids on that engine
    _ids = List()             # registered engine ids (ints); see `ids` property
    _connected=Bool(False)    # set by _connect(); guards against reconnecting
    _ssh=Bool(False)          # True when connections are tunneled over ssh
    _context = Instance('zmq.Context')
    _config = Dict()          # connection info loaded from url/json file
    _engines=Instance(ReverseDict, (), {})   # two-way engine id <-> uuid mapping
    _registration_socket=Instance('zmq.Socket')
    _query_socket=Instance('zmq.Socket')
    _control_socket=Instance('zmq.Socket')
    _iopub_socket=Instance('zmq.Socket')
    _notification_socket=Instance('zmq.Socket')
    _mux_socket=Instance('zmq.Socket')
    _task_socket=Instance('zmq.Socket')
    _task_scheme=Str()        # task scheduler scheme, e.g. 'pure' for pure-ZMQ
    _balanced_views=Dict()
    _direct_views=Dict()
    _closed = False           # set by close(); makes close() idempotent
280
280
    def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
            context=None, username=None, debug=False, exec_key=None,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            ):
        """Connect to a controller, resolving connection info from a url,
        a JSON connection file, or a cluster profile (see class docstring)."""
        super(Client, self).__init__(debug=debug, profile=profile)
        if context is None:
            context = zmq.Context()
        self._context = context


        # resolve the connection file via the cluster directory if one exists
        self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a controller!"\
            " Please specify at least one of url_or_file or profile."

        # url_or_file may be either a zmq url or a path to a JSON connection file
        try:
            validate_url(url_or_file)
        except AssertionError:
            # not a url: treat it as a connection file (relative paths are
            # resolved against the cluster's security dir)
            if not os.path.exists(url_or_file):
                if self._cd:
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}

        # sync defaults from args, json:
        # explicit keyword args override what the connection file said
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg['exec_key']
        sshserver=cfg['ssh']
        url = cfg['url']
        location = cfg.setdefault('location', None)
        # rewrite 0.0.0.0 / '*' style urls into something connectable
        cfg['url'] = disambiguate_url(cfg['url'], location)
        url = cfg['url']

        self._config = cfg

        # any ssh-related argument implies tunneling
        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                # passwordless login works; no password needed
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
        # exec_key may be the key itself or a path to a file containing it
        if exec_key is not None and os.path.isfile(exec_key):
            arg = 'keyfile'
        else:
            arg = 'key'
        key_arg = {arg:exec_key}
        if username is None:
            self.session = ss.StreamSession(**key_arg)
        else:
            self.session = ss.StreamSession(username, **key_arg)
        # the registration socket identifies us to the Hub by session id
        self._registration_socket = self._context.socket(zmq.XREQ)
        self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
        if self._ssh:
            tunnel.tunnel_connection(self._registration_socket, url, sshserver, **ssh_kwargs)
        else:
            self._registration_socket.connect(url)

        self.session.debug = self.debug

        # dispatch tables for incoming messages, keyed by msg_type
        self._notification_handlers = {'registration_notification' : self._register_engine,
                                    'unregistration_notification' : self._unregister_engine,
                                    }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs)
358
358
    def __del__(self):
        """cleanup sockets, but _not_ context."""
        # only this Client's sockets are released on garbage collection;
        # the zmq Context may be shared with other objects, so leave it alone
        self.close()
359
362
360 def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
363 def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
361 if ipython_dir is None:
364 if ipython_dir is None:
362 ipython_dir = get_ipython_dir()
365 ipython_dir = get_ipython_dir()
363 if cluster_dir is not None:
366 if cluster_dir is not None:
364 try:
367 try:
365 self._cd = ClusterDir.find_cluster_dir(cluster_dir)
368 self._cd = ClusterDir.find_cluster_dir(cluster_dir)
366 return
369 return
367 except ClusterDirError:
370 except ClusterDirError:
368 pass
371 pass
369 elif profile is not None:
372 elif profile is not None:
370 try:
373 try:
371 self._cd = ClusterDir.find_cluster_dir_by_profile(
374 self._cd = ClusterDir.find_cluster_dir_by_profile(
372 ipython_dir, profile)
375 ipython_dir, profile)
373 return
376 return
374 except ClusterDirError:
377 except ClusterDirError:
375 pass
378 pass
376 self._cd = None
379 self._cd = None
377
380
378 @property
381 @property
379 def ids(self):
382 def ids(self):
380 """Always up-to-date ids property."""
383 """Always up-to-date ids property."""
381 self._flush_notifications()
384 self._flush_notifications()
382 # always copy:
385 # always copy:
383 return list(self._ids)
386 return list(self._ids)
384
387
385 def close(self):
388 def close(self):
386 if self._closed:
389 if self._closed:
387 return
390 return
388 snames = filter(lambda n: n.endswith('socket'), dir(self))
391 snames = filter(lambda n: n.endswith('socket'), dir(self))
389 for socket in map(lambda name: getattr(self, name), snames):
392 for socket in map(lambda name: getattr(self, name), snames):
390 socket.close()
393 if isinstance(socket, zmq.Socket) and not socket.closed:
394 socket.close()
391 self._closed = True
395 self._closed = True
392
396
393 def _update_engines(self, engines):
397 def _update_engines(self, engines):
394 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
398 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
395 for k,v in engines.iteritems():
399 for k,v in engines.iteritems():
396 eid = int(k)
400 eid = int(k)
397 self._engines[eid] = bytes(v) # force not unicode
401 self._engines[eid] = bytes(v) # force not unicode
398 self._ids.append(eid)
402 self._ids.append(eid)
399 self._ids = sorted(self._ids)
403 self._ids = sorted(self._ids)
400 if sorted(self._engines.keys()) != range(len(self._engines)) and \
404 if sorted(self._engines.keys()) != range(len(self._engines)) and \
401 self._task_scheme == 'pure' and self._task_socket:
405 self._task_scheme == 'pure' and self._task_socket:
402 self._stop_scheduling_tasks()
406 self._stop_scheduling_tasks()
403
407
404 def _stop_scheduling_tasks(self):
408 def _stop_scheduling_tasks(self):
405 """Stop scheduling tasks because an engine has been unregistered
409 """Stop scheduling tasks because an engine has been unregistered
406 from a pure ZMQ scheduler.
410 from a pure ZMQ scheduler.
407 """
411 """
408
412
409 self._task_socket.close()
413 self._task_socket.close()
410 self._task_socket = None
414 self._task_socket = None
411 msg = "An engine has been unregistered, and we are using pure " +\
415 msg = "An engine has been unregistered, and we are using pure " +\
412 "ZMQ task scheduling. Task farming will be disabled."
416 "ZMQ task scheduling. Task farming will be disabled."
413 if self.outstanding:
417 if self.outstanding:
414 msg += " If you were running tasks when this happened, " +\
418 msg += " If you were running tasks when this happened, " +\
415 "some `outstanding` msg_ids may never resolve."
419 "some `outstanding` msg_ids may never resolve."
416 warnings.warn(msg, RuntimeWarning)
420 warnings.warn(msg, RuntimeWarning)
417
421
418 def _build_targets(self, targets):
422 def _build_targets(self, targets):
419 """Turn valid target IDs or 'all' into two lists:
423 """Turn valid target IDs or 'all' into two lists:
420 (int_ids, uuids).
424 (int_ids, uuids).
421 """
425 """
422 if targets is None:
426 if targets is None:
423 targets = self._ids
427 targets = self._ids
424 elif isinstance(targets, str):
428 elif isinstance(targets, str):
425 if targets.lower() == 'all':
429 if targets.lower() == 'all':
426 targets = self._ids
430 targets = self._ids
427 else:
431 else:
428 raise TypeError("%r not valid str target, must be 'all'"%(targets))
432 raise TypeError("%r not valid str target, must be 'all'"%(targets))
429 elif isinstance(targets, int):
433 elif isinstance(targets, int):
430 targets = [targets]
434 targets = [targets]
431 return [self._engines[t] for t in targets], list(targets)
435 return [self._engines[t] for t in targets], list(targets)
432
436
    def _connect(self, sshserver, ssh_kwargs):
        """setup all our socket connections to the controller. This is called from
        __init__."""

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            # resolve interface/IP ambiguity relative to the controller's location,
            # and tunnel over ssh when configured
            url = disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        # ask the Hub for the connection info of all its channels (blocking recv)
        self.session.send(self._registration_socket, 'connection_request')
        idents,msg = self.session.recv(self._registration_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = ss.Message(msg)
        content = msg.content
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            # connect one socket per channel the controller advertised
            if content.mux:
                self._mux_socket = self._context.socket(zmq.PAIR)
                self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._mux_socket, content.mux)
            if content.task:
                # content.task carries (scheme, address)
                self._task_scheme, task_addr = content.task
                self._task_socket = self._context.socket(zmq.PAIR)
                self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._task_socket, task_addr)
            if content.notification:
                self._notification_socket = self._context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                # subscribe to everything on the notification channel
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
            if content.query:
                self._query_socket = self._context.socket(zmq.PAIR)
                self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self._context.socket(zmq.PAIR)
                self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self._context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, '')
                self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._iopub_socket, content.iopub)
            # seed our registry with the engines already registered at the Hub
            self._update_engines(dict(content.engines))

        else:
            self._connected = False
            raise Exception("Failed to connect!")
488
492
489 #--------------------------------------------------------------------------
493 #--------------------------------------------------------------------------
490 # handlers and callbacks for incoming messages
494 # handlers and callbacks for incoming messages
491 #--------------------------------------------------------------------------
495 #--------------------------------------------------------------------------
492
496
493 def _unwrap_exception(self, content):
497 def _unwrap_exception(self, content):
494 """unwrap exception, and remap engineid to int."""
498 """unwrap exception, and remap engineid to int."""
495 e = error.unwrap_exception(content)
499 e = error.unwrap_exception(content)
496 if e.engine_info:
500 if e.engine_info:
497 e_uuid = e.engine_info['engine_uuid']
501 e_uuid = e.engine_info['engine_uuid']
498 eid = self._engines[e_uuid]
502 eid = self._engines[e_uuid]
499 e.engine_info['engine_id'] = eid
503 e.engine_info['engine_id'] = eid
500 return e
504 return e
501
505
502 def _extract_metadata(self, header, parent, content):
506 def _extract_metadata(self, header, parent, content):
503 md = {'msg_id' : parent['msg_id'],
507 md = {'msg_id' : parent['msg_id'],
504 'received' : datetime.now(),
508 'received' : datetime.now(),
505 'engine_uuid' : header.get('engine', None),
509 'engine_uuid' : header.get('engine', None),
506 'follow' : parent.get('follow', []),
510 'follow' : parent.get('follow', []),
507 'after' : parent.get('after', []),
511 'after' : parent.get('after', []),
508 'status' : content['status'],
512 'status' : content['status'],
509 }
513 }
510
514
511 if md['engine_uuid'] is not None:
515 if md['engine_uuid'] is not None:
512 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
516 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
513
517
514 if 'date' in parent:
518 if 'date' in parent:
515 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
519 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
516 if 'started' in header:
520 if 'started' in header:
517 md['started'] = datetime.strptime(header['started'], util.ISO8601)
521 md['started'] = datetime.strptime(header['started'], util.ISO8601)
518 if 'date' in header:
522 if 'date' in header:
519 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
523 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
520 return md
524 return md
521
525
522 def _register_engine(self, msg):
526 def _register_engine(self, msg):
523 """Register a new engine, and update our connection info."""
527 """Register a new engine, and update our connection info."""
524 content = msg['content']
528 content = msg['content']
525 eid = content['id']
529 eid = content['id']
526 d = {eid : content['queue']}
530 d = {eid : content['queue']}
527 self._update_engines(d)
531 self._update_engines(d)
528
532
529 def _unregister_engine(self, msg):
533 def _unregister_engine(self, msg):
530 """Unregister an engine that has died."""
534 """Unregister an engine that has died."""
531 content = msg['content']
535 content = msg['content']
532 eid = int(content['id'])
536 eid = int(content['id'])
533 if eid in self._ids:
537 if eid in self._ids:
534 self._ids.remove(eid)
538 self._ids.remove(eid)
535 uuid = self._engines.pop(eid)
539 uuid = self._engines.pop(eid)
536
540
537 self._handle_stranded_msgs(eid, uuid)
541 self._handle_stranded_msgs(eid, uuid)
538
542
539 if self._task_socket and self._task_scheme == 'pure':
543 if self._task_socket and self._task_scheme == 'pure':
540 self._stop_scheduling_tasks()
544 self._stop_scheduling_tasks()
541
545
    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.
        """

        outstanding = self._outstanding_dict[uuid]

        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already have the result for this message; nothing to do
                continue
            # raise-and-catch so wrap_exception captures a real traceback
            try:
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake message:
            # minimal parent/header so _handle_apply_reply can process it
            parent = {}
            header = {}
            parent['msg_id'] = msg_id
            header['engine'] = uuid
            header['date'] = datetime.now().strftime(util.ISO8601)
            msg = dict(parent_header=parent, header=header, content=content)
            self._handle_apply_reply(msg)
569
572
570 def _handle_execute_reply(self, msg):
573 def _handle_execute_reply(self, msg):
571 """Save the reply to an execute_request into our results.
574 """Save the reply to an execute_request into our results.
572
575
573 execute messages are never actually used. apply is used instead.
576 execute messages are never actually used. apply is used instead.
574 """
577 """
575
578
576 parent = msg['parent_header']
579 parent = msg['parent_header']
577 msg_id = parent['msg_id']
580 msg_id = parent['msg_id']
578 if msg_id not in self.outstanding:
581 if msg_id not in self.outstanding:
579 if msg_id in self.history:
582 if msg_id in self.history:
580 print ("got stale result: %s"%msg_id)
583 print ("got stale result: %s"%msg_id)
581 else:
584 else:
582 print ("got unknown result: %s"%msg_id)
585 print ("got unknown result: %s"%msg_id)
583 else:
586 else:
584 self.outstanding.remove(msg_id)
587 self.outstanding.remove(msg_id)
585 self.results[msg_id] = self._unwrap_exception(msg['content'])
588 self.results[msg_id] = self._unwrap_exception(msg['content'])
586
589
    def _handle_apply_reply(self, msg):
        """Save the reply to an apply_request into our results."""
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            # duplicate or unexpected reply: report it, but still record metadata below
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
                print self.results[msg_id]
                print msg
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
        content = msg['content']
        header = msg['header']

        # construct metadata:
        md = self.metadata[msg_id]
        md.update(self._extract_metadata(header, parent, content))
        # is this redundant? (md is the same object already stored in the dict)
        self.metadata[msg_id] = md

        # clear the per-engine bookkeeping for this msg_id
        e_outstanding = self._outstanding_dict[md['engine_uuid']]
        if msg_id in e_outstanding:
            e_outstanding.remove(msg_id)

        # construct result:
        if content['status'] == 'ok':
            self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
        elif content['status'] == 'aborted':
            self.results[msg_id] = error.AbortedTask(msg_id)
        elif content['status'] == 'resubmitted':
            # TODO: handle resubmission
            pass
        else:
            # remote error: store the reconstructed exception as the result
            self.results[msg_id] = self._unwrap_exception(content)
623
626
624 def _flush_notifications(self):
627 def _flush_notifications(self):
625 """Flush notifications of engine registrations waiting
628 """Flush notifications of engine registrations waiting
626 in ZMQ queue."""
629 in ZMQ queue."""
627 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
630 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
628 while msg is not None:
631 while msg is not None:
629 if self.debug:
632 if self.debug:
630 pprint(msg)
633 pprint(msg)
631 msg = msg[-1]
634 msg = msg[-1]
632 msg_type = msg['msg_type']
635 msg_type = msg['msg_type']
633 handler = self._notification_handlers.get(msg_type, None)
636 handler = self._notification_handlers.get(msg_type, None)
634 if handler is None:
637 if handler is None:
635 raise Exception("Unhandled message type: %s"%msg.msg_type)
638 raise Exception("Unhandled message type: %s"%msg.msg_type)
636 else:
639 else:
637 handler(msg)
640 handler(msg)
638 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
641 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
639
642
640 def _flush_results(self, sock):
643 def _flush_results(self, sock):
641 """Flush task or queue results waiting in ZMQ queue."""
644 """Flush task or queue results waiting in ZMQ queue."""
642 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
645 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
643 while msg is not None:
646 while msg is not None:
644 if self.debug:
647 if self.debug:
645 pprint(msg)
648 pprint(msg)
646 msg = msg[-1]
649 msg = msg[-1]
647 msg_type = msg['msg_type']
650 msg_type = msg['msg_type']
648 handler = self._queue_handlers.get(msg_type, None)
651 handler = self._queue_handlers.get(msg_type, None)
649 if handler is None:
652 if handler is None:
650 raise Exception("Unhandled message type: %s"%msg.msg_type)
653 raise Exception("Unhandled message type: %s"%msg.msg_type)
651 else:
654 else:
652 handler(msg)
655 handler(msg)
653 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
656 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
654
657
655 def _flush_control(self, sock):
658 def _flush_control(self, sock):
656 """Flush replies from the control channel waiting
659 """Flush replies from the control channel waiting
657 in the ZMQ queue.
660 in the ZMQ queue.
658
661
659 Currently: ignore them."""
662 Currently: ignore them."""
660 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
663 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
661 while msg is not None:
664 while msg is not None:
662 if self.debug:
665 if self.debug:
663 pprint(msg)
666 pprint(msg)
664 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
667 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
665
668
666 def _flush_iopub(self, sock):
669 def _flush_iopub(self, sock):
667 """Flush replies from the iopub channel waiting
670 """Flush replies from the iopub channel waiting
668 in the ZMQ queue.
671 in the ZMQ queue.
669 """
672 """
670 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
673 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
671 while msg is not None:
674 while msg is not None:
672 if self.debug:
675 if self.debug:
673 pprint(msg)
676 pprint(msg)
674 msg = msg[-1]
677 msg = msg[-1]
675 parent = msg['parent_header']
678 parent = msg['parent_header']
676 msg_id = parent['msg_id']
679 msg_id = parent['msg_id']
677 content = msg['content']
680 content = msg['content']
678 header = msg['header']
681 header = msg['header']
679 msg_type = msg['msg_type']
682 msg_type = msg['msg_type']
680
683
681 # init metadata:
684 # init metadata:
682 md = self.metadata[msg_id]
685 md = self.metadata[msg_id]
683
686
684 if msg_type == 'stream':
687 if msg_type == 'stream':
685 name = content['name']
688 name = content['name']
686 s = md[name] or ''
689 s = md[name] or ''
687 md[name] = s + content['data']
690 md[name] = s + content['data']
688 elif msg_type == 'pyerr':
691 elif msg_type == 'pyerr':
689 md.update({'pyerr' : self._unwrap_exception(content)})
692 md.update({'pyerr' : self._unwrap_exception(content)})
690 else:
693 else:
691 md.update({msg_type : content['data']})
694 md.update({msg_type : content['data']})
692
695
693 # reduntant?
696 # reduntant?
694 self.metadata[msg_id] = md
697 self.metadata[msg_id] = md
695
698
696 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
699 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
697
700
698 #--------------------------------------------------------------------------
701 #--------------------------------------------------------------------------
699 # len, getitem
702 # len, getitem
700 #--------------------------------------------------------------------------
703 #--------------------------------------------------------------------------
701
704
702 def __len__(self):
705 def __len__(self):
703 """len(client) returns # of engines."""
706 """len(client) returns # of engines."""
704 return len(self.ids)
707 return len(self.ids)
705
708
706 def __getitem__(self, key):
709 def __getitem__(self, key):
707 """index access returns DirectView multiplexer objects
710 """index access returns DirectView multiplexer objects
708
711
709 Must be int, slice, or list/tuple/xrange of ints"""
712 Must be int, slice, or list/tuple/xrange of ints"""
710 if not isinstance(key, (int, slice, tuple, list, xrange)):
713 if not isinstance(key, (int, slice, tuple, list, xrange)):
711 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
714 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
712 else:
715 else:
713 return self.view(key, balanced=False)
716 return self.view(key, balanced=False)
714
717
715 #--------------------------------------------------------------------------
718 #--------------------------------------------------------------------------
716 # Begin public methods
719 # Begin public methods
717 #--------------------------------------------------------------------------
720 #--------------------------------------------------------------------------
718
721
719 def spin(self):
722 def spin(self):
720 """Flush any registration notifications and execution results
723 """Flush any registration notifications and execution results
721 waiting in the ZMQ queue.
724 waiting in the ZMQ queue.
722 """
725 """
723 if self._notification_socket:
726 if self._notification_socket:
724 self._flush_notifications()
727 self._flush_notifications()
725 if self._mux_socket:
728 if self._mux_socket:
726 self._flush_results(self._mux_socket)
729 self._flush_results(self._mux_socket)
727 if self._task_socket:
730 if self._task_socket:
728 self._flush_results(self._task_socket)
731 self._flush_results(self._task_socket)
729 if self._control_socket:
732 if self._control_socket:
730 self._flush_control(self._control_socket)
733 self._flush_control(self._control_socket)
731 if self._iopub_socket:
734 if self._iopub_socket:
732 self._flush_iopub(self._iopub_socket)
735 self._flush_iopub(self._iopub_socket)
733
736
734 def barrier(self, jobs=None, timeout=-1):
737 def barrier(self, jobs=None, timeout=-1):
735 """waits on one or more `jobs`, for up to `timeout` seconds.
738 """waits on one or more `jobs`, for up to `timeout` seconds.
736
739
737 Parameters
740 Parameters
738 ----------
741 ----------
739
742
740 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
743 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
741 ints are indices to self.history
744 ints are indices to self.history
742 strs are msg_ids
745 strs are msg_ids
743 default: wait on all outstanding messages
746 default: wait on all outstanding messages
744 timeout : float
747 timeout : float
745 a time in seconds, after which to give up.
748 a time in seconds, after which to give up.
746 default is -1, which means no timeout
749 default is -1, which means no timeout
747
750
748 Returns
751 Returns
749 -------
752 -------
750
753
751 True : when all msg_ids are done
754 True : when all msg_ids are done
752 False : timeout reached, some msg_ids still outstanding
755 False : timeout reached, some msg_ids still outstanding
753 """
756 """
754 tic = time.time()
757 tic = time.time()
755 if jobs is None:
758 if jobs is None:
756 theids = self.outstanding
759 theids = self.outstanding
757 else:
760 else:
758 if isinstance(jobs, (int, str, AsyncResult)):
761 if isinstance(jobs, (int, str, AsyncResult)):
759 jobs = [jobs]
762 jobs = [jobs]
760 theids = set()
763 theids = set()
761 for job in jobs:
764 for job in jobs:
762 if isinstance(job, int):
765 if isinstance(job, int):
763 # index access
766 # index access
764 job = self.history[job]
767 job = self.history[job]
765 elif isinstance(job, AsyncResult):
768 elif isinstance(job, AsyncResult):
766 map(theids.add, job.msg_ids)
769 map(theids.add, job.msg_ids)
767 continue
770 continue
768 theids.add(job)
771 theids.add(job)
769 if not theids.intersection(self.outstanding):
772 if not theids.intersection(self.outstanding):
770 return True
773 return True
771 self.spin()
774 self.spin()
772 while theids.intersection(self.outstanding):
775 while theids.intersection(self.outstanding):
773 if timeout >= 0 and ( time.time()-tic ) > timeout:
776 if timeout >= 0 and ( time.time()-tic ) > timeout:
774 break
777 break
775 time.sleep(1e-3)
778 time.sleep(1e-3)
776 self.spin()
779 self.spin()
777 return len(theids.intersection(self.outstanding)) == 0
780 return len(theids.intersection(self.outstanding)) == 0
778
781
779 #--------------------------------------------------------------------------
782 #--------------------------------------------------------------------------
780 # Control methods
783 # Control methods
781 #--------------------------------------------------------------------------
784 #--------------------------------------------------------------------------
782
785
783 @spinfirst
786 @spinfirst
784 @defaultblock
787 @defaultblock
785 def clear(self, targets=None, block=None):
788 def clear(self, targets=None, block=None):
786 """Clear the namespace in target(s)."""
789 """Clear the namespace in target(s)."""
787 targets = self._build_targets(targets)[0]
790 targets = self._build_targets(targets)[0]
788 for t in targets:
791 for t in targets:
789 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
792 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
790 error = False
793 error = False
791 if self.block:
794 if self.block:
792 for i in range(len(targets)):
795 for i in range(len(targets)):
793 idents,msg = self.session.recv(self._control_socket,0)
796 idents,msg = self.session.recv(self._control_socket,0)
794 if self.debug:
797 if self.debug:
795 pprint(msg)
798 pprint(msg)
796 if msg['content']['status'] != 'ok':
799 if msg['content']['status'] != 'ok':
797 error = self._unwrap_exception(msg['content'])
800 error = self._unwrap_exception(msg['content'])
798 if error:
801 if error:
799 return error
802 raise error
800
803
801
804
802 @spinfirst
805 @spinfirst
803 @defaultblock
806 @defaultblock
804 def abort(self, jobs=None, targets=None, block=None):
807 def abort(self, jobs=None, targets=None, block=None):
805 """Abort specific jobs from the execution queues of target(s).
808 """Abort specific jobs from the execution queues of target(s).
806
809
807 This is a mechanism to prevent jobs that have already been submitted
810 This is a mechanism to prevent jobs that have already been submitted
808 from executing.
811 from executing.
809
812
810 Parameters
813 Parameters
811 ----------
814 ----------
812
815
813 jobs : msg_id, list of msg_ids, or AsyncResult
816 jobs : msg_id, list of msg_ids, or AsyncResult
814 The jobs to be aborted
817 The jobs to be aborted
815
818
816
819
817 """
820 """
818 targets = self._build_targets(targets)[0]
821 targets = self._build_targets(targets)[0]
819 msg_ids = []
822 msg_ids = []
820 if isinstance(jobs, (basestring,AsyncResult)):
823 if isinstance(jobs, (basestring,AsyncResult)):
821 jobs = [jobs]
824 jobs = [jobs]
822 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
825 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
823 if bad_ids:
826 if bad_ids:
824 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
827 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
825 for j in jobs:
828 for j in jobs:
826 if isinstance(j, AsyncResult):
829 if isinstance(j, AsyncResult):
827 msg_ids.extend(j.msg_ids)
830 msg_ids.extend(j.msg_ids)
828 else:
831 else:
829 msg_ids.append(j)
832 msg_ids.append(j)
830 content = dict(msg_ids=msg_ids)
833 content = dict(msg_ids=msg_ids)
831 for t in targets:
834 for t in targets:
832 self.session.send(self._control_socket, 'abort_request',
835 self.session.send(self._control_socket, 'abort_request',
833 content=content, ident=t)
836 content=content, ident=t)
834 error = False
837 error = False
835 if self.block:
838 if self.block:
836 for i in range(len(targets)):
839 for i in range(len(targets)):
837 idents,msg = self.session.recv(self._control_socket,0)
840 idents,msg = self.session.recv(self._control_socket,0)
838 if self.debug:
841 if self.debug:
839 pprint(msg)
842 pprint(msg)
840 if msg['content']['status'] != 'ok':
843 if msg['content']['status'] != 'ok':
841 error = self._unwrap_exception(msg['content'])
844 error = self._unwrap_exception(msg['content'])
842 if error:
845 if error:
843 return error
846 raise error
844
847
845 @spinfirst
848 @spinfirst
846 @defaultblock
849 @defaultblock
847 def shutdown(self, targets=None, restart=False, controller=False, block=None):
850 def shutdown(self, targets=None, restart=False, controller=False, block=None):
848 """Terminates one or more engine processes, optionally including the controller."""
851 """Terminates one or more engine processes, optionally including the controller."""
849 if controller:
852 if controller:
850 targets = 'all'
853 targets = 'all'
851 targets = self._build_targets(targets)[0]
854 targets = self._build_targets(targets)[0]
852 for t in targets:
855 for t in targets:
853 self.session.send(self._control_socket, 'shutdown_request',
856 self.session.send(self._control_socket, 'shutdown_request',
854 content={'restart':restart},ident=t)
857 content={'restart':restart},ident=t)
855 error = False
858 error = False
856 if block or controller:
859 if block or controller:
857 for i in range(len(targets)):
860 for i in range(len(targets)):
858 idents,msg = self.session.recv(self._control_socket,0)
861 idents,msg = self.session.recv(self._control_socket,0)
859 if self.debug:
862 if self.debug:
860 pprint(msg)
863 pprint(msg)
861 if msg['content']['status'] != 'ok':
864 if msg['content']['status'] != 'ok':
862 error = self._unwrap_exception(msg['content'])
865 error = self._unwrap_exception(msg['content'])
863
866
864 if controller:
867 if controller:
865 time.sleep(0.25)
868 time.sleep(0.25)
866 self.session.send(self._query_socket, 'shutdown_request')
869 self.session.send(self._query_socket, 'shutdown_request')
867 idents,msg = self.session.recv(self._query_socket, 0)
870 idents,msg = self.session.recv(self._query_socket, 0)
868 if self.debug:
871 if self.debug:
869 pprint(msg)
872 pprint(msg)
870 if msg['content']['status'] != 'ok':
873 if msg['content']['status'] != 'ok':
871 error = self._unwrap_exception(msg['content'])
874 error = self._unwrap_exception(msg['content'])
872
875
873 if error:
876 if error:
874 raise error
877 raise error
875
878
876 #--------------------------------------------------------------------------
879 #--------------------------------------------------------------------------
877 # Execution methods
880 # Execution methods
878 #--------------------------------------------------------------------------
881 #--------------------------------------------------------------------------
879
882
880 @defaultblock
883 @defaultblock
881 def execute(self, code, targets='all', block=None):
884 def execute(self, code, targets='all', block=None):
882 """Executes `code` on `targets` in blocking or nonblocking manner.
885 """Executes `code` on `targets` in blocking or nonblocking manner.
883
886
884 ``execute`` is always `bound` (affects engine namespace)
887 ``execute`` is always `bound` (affects engine namespace)
885
888
886 Parameters
889 Parameters
887 ----------
890 ----------
888
891
889 code : str
892 code : str
890 the code string to be executed
893 the code string to be executed
891 targets : int/str/list of ints/strs
894 targets : int/str/list of ints/strs
892 the engines on which to execute
895 the engines on which to execute
893 default : all
896 default : all
894 block : bool
897 block : bool
895 whether or not to wait until done to return
898 whether or not to wait until done to return
896 default: self.block
899 default: self.block
897 """
900 """
898 result = self.apply(_execute, (code,), targets=targets, block=block, bound=True, balanced=False)
901 result = self.apply(_execute, (code,), targets=targets, block=block, bound=True, balanced=False)
899 if not block:
902 if not block:
900 return result
903 return result
901
904
902 def run(self, filename, targets='all', block=None):
905 def run(self, filename, targets='all', block=None):
903 """Execute contents of `filename` on engine(s).
906 """Execute contents of `filename` on engine(s).
904
907
905 This simply reads the contents of the file and calls `execute`.
908 This simply reads the contents of the file and calls `execute`.
906
909
907 Parameters
910 Parameters
908 ----------
911 ----------
909
912
910 filename : str
913 filename : str
911 The path to the file
914 The path to the file
912 targets : int/str/list of ints/strs
915 targets : int/str/list of ints/strs
913 the engines on which to execute
916 the engines on which to execute
914 default : all
917 default : all
915 block : bool
918 block : bool
916 whether or not to wait until done
919 whether or not to wait until done
917 default: self.block
920 default: self.block
918
921
919 """
922 """
920 with open(filename, 'r') as f:
923 with open(filename, 'r') as f:
921 # add newline in case of trailing indented whitespace
924 # add newline in case of trailing indented whitespace
922 # which will cause SyntaxError
925 # which will cause SyntaxError
923 code = f.read()+'\n'
926 code = f.read()+'\n'
924 return self.execute(code, targets=targets, block=block)
927 return self.execute(code, targets=targets, block=block)
925
928
926 def _maybe_raise(self, result):
929 def _maybe_raise(self, result):
927 """wrapper for maybe raising an exception if apply failed."""
930 """wrapper for maybe raising an exception if apply failed."""
928 if isinstance(result, error.RemoteError):
931 if isinstance(result, error.RemoteError):
929 raise result
932 raise result
930
933
931 return result
934 return result
932
935
933 def _build_dependency(self, dep):
936 def _build_dependency(self, dep):
934 """helper for building jsonable dependencies from various input forms"""
937 """helper for building jsonable dependencies from various input forms"""
935 if isinstance(dep, Dependency):
938 if isinstance(dep, Dependency):
936 return dep.as_dict()
939 return dep.as_dict()
937 elif isinstance(dep, AsyncResult):
940 elif isinstance(dep, AsyncResult):
938 return dep.msg_ids
941 return dep.msg_ids
939 elif dep is None:
942 elif dep is None:
940 return []
943 return []
941 else:
944 else:
942 # pass to Dependency constructor
945 # pass to Dependency constructor
943 return list(Dependency(dep))
946 return list(Dependency(dep))
944
947
945 @defaultblock
948 @defaultblock
946 def apply(self, f, args=None, kwargs=None, bound=True, block=None,
949 def apply(self, f, args=None, kwargs=None, bound=True, block=None,
947 targets=None, balanced=None,
950 targets=None, balanced=None,
948 after=None, follow=None, timeout=None):
951 after=None, follow=None, timeout=None,
952 track=False):
949 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
953 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
950
954
951 This is the central execution command for the client.
955 This is the central execution command for the client.
952
956
953 Parameters
957 Parameters
954 ----------
958 ----------
955
959
956 f : function
960 f : function
957 The fuction to be called remotely
961 The fuction to be called remotely
958 args : tuple/list
962 args : tuple/list
959 The positional arguments passed to `f`
963 The positional arguments passed to `f`
960 kwargs : dict
964 kwargs : dict
961 The keyword arguments passed to `f`
965 The keyword arguments passed to `f`
962 bound : bool (default: True)
966 bound : bool (default: True)
963 Whether to execute in the Engine(s) namespace, or in a clean
967 Whether to execute in the Engine(s) namespace, or in a clean
964 namespace not affecting the engine.
968 namespace not affecting the engine.
965 block : bool (default: self.block)
969 block : bool (default: self.block)
966 Whether to wait for the result, or return immediately.
970 Whether to wait for the result, or return immediately.
967 False:
971 False:
968 returns AsyncResult
972 returns AsyncResult
969 True:
973 True:
970 returns actual result(s) of f(*args, **kwargs)
974 returns actual result(s) of f(*args, **kwargs)
971 if multiple targets:
975 if multiple targets:
972 list of results, matching `targets`
976 list of results, matching `targets`
973 targets : int,list of ints, 'all', None
977 targets : int,list of ints, 'all', None
974 Specify the destination of the job.
978 Specify the destination of the job.
975 if None:
979 if None:
976 Submit via Task queue for load-balancing.
980 Submit via Task queue for load-balancing.
977 if 'all':
981 if 'all':
978 Run on all active engines
982 Run on all active engines
979 if list:
983 if list:
980 Run on each specified engine
984 Run on each specified engine
981 if int:
985 if int:
982 Run on single engine
986 Run on single engine
983
987
984 balanced : bool, default None
988 balanced : bool, default None
985 whether to load-balance. This will default to True
989 whether to load-balance. This will default to True
986 if targets is unspecified, or False if targets is specified.
990 if targets is unspecified, or False if targets is specified.
987
991
988 The following arguments are only used when balanced is True:
992 The following arguments are only used when balanced is True:
989 after : Dependency or collection of msg_ids
993 after : Dependency or collection of msg_ids
990 Only for load-balanced execution (targets=None)
994 Only for load-balanced execution (targets=None)
991 Specify a list of msg_ids as a time-based dependency.
995 Specify a list of msg_ids as a time-based dependency.
992 This job will only be run *after* the dependencies
996 This job will only be run *after* the dependencies
993 have been met.
997 have been met.
994
998
995 follow : Dependency or collection of msg_ids
999 follow : Dependency or collection of msg_ids
996 Only for load-balanced execution (targets=None)
1000 Only for load-balanced execution (targets=None)
997 Specify a list of msg_ids as a location-based dependency.
1001 Specify a list of msg_ids as a location-based dependency.
998 This job will only be run on an engine where this dependency
1002 This job will only be run on an engine where this dependency
999 is met.
1003 is met.
1000
1004
1001 timeout : float/int or None
1005 timeout : float/int or None
1002 Only for load-balanced execution (targets=None)
1006 Only for load-balanced execution (targets=None)
1003 Specify an amount of time (in seconds) for the scheduler to
1007 Specify an amount of time (in seconds) for the scheduler to
1004 wait for dependencies to be met before failing with a
1008 wait for dependencies to be met before failing with a
1005 DependencyTimeout.
1009 DependencyTimeout.
1010 track : bool
1011 whether to track non-copying sends.
1012 [default False]
1006
1013
1007 after,follow,timeout only used if `balanced=True`.
1014 after,follow,timeout only used if `balanced=True`.
1008
1015
1009 Returns
1016 Returns
1010 -------
1017 -------
1011
1018
1012 if block is False:
1019 if block is False:
1013 return AsyncResult wrapping msg_ids
1020 return AsyncResult wrapping msg_ids
1014 output of AsyncResult.get() is identical to that of `apply(...block=True)`
1021 output of AsyncResult.get() is identical to that of `apply(...block=True)`
1015 else:
1022 else:
1016 if single target:
1023 if single target:
1017 return result of `f(*args, **kwargs)`
1024 return result of `f(*args, **kwargs)`
1018 else:
1025 else:
1019 return list of results, matching `targets`
1026 return list of results, matching `targets`
1020 """
1027 """
1021 assert not self._closed, "cannot use me anymore, I'm closed!"
1028 assert not self._closed, "cannot use me anymore, I'm closed!"
1022 # defaults:
1029 # defaults:
1023 block = block if block is not None else self.block
1030 block = block if block is not None else self.block
1024 args = args if args is not None else []
1031 args = args if args is not None else []
1025 kwargs = kwargs if kwargs is not None else {}
1032 kwargs = kwargs if kwargs is not None else {}
1026
1033
1027 if balanced is None:
1034 if balanced is None:
1028 if targets is None:
1035 if targets is None:
1029 # default to balanced if targets unspecified
1036 # default to balanced if targets unspecified
1030 balanced = True
1037 balanced = True
1031 else:
1038 else:
1032 # otherwise default to multiplexing
1039 # otherwise default to multiplexing
1033 balanced = False
1040 balanced = False
1034
1041
1035 if targets is None and balanced is False:
1042 if targets is None and balanced is False:
1036 # default to all if *not* balanced, and targets is unspecified
1043 # default to all if *not* balanced, and targets is unspecified
1037 targets = 'all'
1044 targets = 'all'
1038
1045
1039 # enforce types of f,args,kwrags
1046 # enforce types of f,args,kwrags
1040 if not callable(f):
1047 if not callable(f):
1041 raise TypeError("f must be callable, not %s"%type(f))
1048 raise TypeError("f must be callable, not %s"%type(f))
1042 if not isinstance(args, (tuple, list)):
1049 if not isinstance(args, (tuple, list)):
1043 raise TypeError("args must be tuple or list, not %s"%type(args))
1050 raise TypeError("args must be tuple or list, not %s"%type(args))
1044 if not isinstance(kwargs, dict):
1051 if not isinstance(kwargs, dict):
1045 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1052 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1046
1053
1047 options = dict(bound=bound, block=block, targets=targets)
1054 options = dict(bound=bound, block=block, targets=targets, track=track)
1048
1055
1049 if balanced:
1056 if balanced:
1050 return self._apply_balanced(f, args, kwargs, timeout=timeout,
1057 return self._apply_balanced(f, args, kwargs, timeout=timeout,
1051 after=after, follow=follow, **options)
1058 after=after, follow=follow, **options)
1052 elif follow or after or timeout:
1059 elif follow or after or timeout:
1053 msg = "follow, after, and timeout args are only used for"
1060 msg = "follow, after, and timeout args are only used for"
1054 msg += " load-balanced execution."
1061 msg += " load-balanced execution."
1055 raise ValueError(msg)
1062 raise ValueError(msg)
1056 else:
1063 else:
1057 return self._apply_direct(f, args, kwargs, **options)
1064 return self._apply_direct(f, args, kwargs, **options)
1058
1065
1059 def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
1066 def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
1060 after=None, follow=None, timeout=None):
1067 after=None, follow=None, timeout=None, track=None):
1061 """call f(*args, **kwargs) remotely in a load-balanced manner.
1068 """call f(*args, **kwargs) remotely in a load-balanced manner.
1062
1069
1063 This is a private method, see `apply` for details.
1070 This is a private method, see `apply` for details.
1064 Not to be called directly!
1071 Not to be called directly!
1065 """
1072 """
1066
1073
1067 loc = locals()
1074 loc = locals()
1068 for name in ('bound', 'block'):
1075 for name in ('bound', 'block', 'track'):
1069 assert loc[name] is not None, "kwarg %r must be specified!"%name
1076 assert loc[name] is not None, "kwarg %r must be specified!"%name
1070
1077
1071 if self._task_socket is None:
1078 if self._task_socket is None:
1072 msg = "Task farming is disabled"
1079 msg = "Task farming is disabled"
1073 if self._task_scheme == 'pure':
1080 if self._task_scheme == 'pure':
1074 msg += " because the pure ZMQ scheduler cannot handle"
1081 msg += " because the pure ZMQ scheduler cannot handle"
1075 msg += " disappearing engines."
1082 msg += " disappearing engines."
1076 raise RuntimeError(msg)
1083 raise RuntimeError(msg)
1077
1084
1078 if self._task_scheme == 'pure':
1085 if self._task_scheme == 'pure':
1079 # pure zmq scheme doesn't support dependencies
1086 # pure zmq scheme doesn't support dependencies
1080 msg = "Pure ZMQ scheduler doesn't support dependencies"
1087 msg = "Pure ZMQ scheduler doesn't support dependencies"
1081 if (follow or after):
1088 if (follow or after):
1082 # hard fail on DAG dependencies
1089 # hard fail on DAG dependencies
1083 raise RuntimeError(msg)
1090 raise RuntimeError(msg)
1084 if isinstance(f, dependent):
1091 if isinstance(f, dependent):
1085 # soft warn on functional dependencies
1092 # soft warn on functional dependencies
1086 warnings.warn(msg, RuntimeWarning)
1093 warnings.warn(msg, RuntimeWarning)
1087
1094
1088 # defaults:
1095 # defaults:
1089 args = args if args is not None else []
1096 args = args if args is not None else []
1090 kwargs = kwargs if kwargs is not None else {}
1097 kwargs = kwargs if kwargs is not None else {}
1091
1098
1092 if targets:
1099 if targets:
1093 idents,_ = self._build_targets(targets)
1100 idents,_ = self._build_targets(targets)
1094 else:
1101 else:
1095 idents = []
1102 idents = []
1096
1103
1097 after = self._build_dependency(after)
1104 after = self._build_dependency(after)
1098 follow = self._build_dependency(follow)
1105 follow = self._build_dependency(follow)
1099 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
1106 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
1100 bufs = util.pack_apply_message(f,args,kwargs)
1107 bufs = util.pack_apply_message(f,args,kwargs)
1101 content = dict(bound=bound)
1108 content = dict(bound=bound)
1102
1109
1103 msg = self.session.send(self._task_socket, "apply_request",
1110 msg = self.session.send(self._task_socket, "apply_request",
1104 content=content, buffers=bufs, subheader=subheader)
1111 content=content, buffers=bufs, subheader=subheader, track=track)
1105 msg_id = msg['msg_id']
1112 msg_id = msg['msg_id']
1106 self.outstanding.add(msg_id)
1113 self.outstanding.add(msg_id)
1107 self.history.append(msg_id)
1114 self.history.append(msg_id)
1108 self.metadata[msg_id]['submitted'] = datetime.now()
1115 self.metadata[msg_id]['submitted'] = datetime.now()
1109
1116 tracker = None if track is False else msg['tracker']
1110 ar = AsyncResult(self, [msg_id], fname=f.__name__)
1117 ar = AsyncResult(self, [msg_id], fname=f.__name__, targets=targets, tracker=tracker)
1111 if block:
1118 if block:
1112 try:
1119 try:
1113 return ar.get()
1120 return ar.get()
1114 except KeyboardInterrupt:
1121 except KeyboardInterrupt:
1115 return ar
1122 return ar
1116 else:
1123 else:
1117 return ar
1124 return ar
1118
1125
1119 def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None):
1126 def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None,
1127 track=None):
1120 """Then underlying method for applying functions to specific engines
1128 """Then underlying method for applying functions to specific engines
1121 via the MUX queue.
1129 via the MUX queue.
1122
1130
1123 This is a private method, see `apply` for details.
1131 This is a private method, see `apply` for details.
1124 Not to be called directly!
1132 Not to be called directly!
1125 """
1133 """
1126 loc = locals()
1134 loc = locals()
1127 for name in ('bound', 'block', 'targets'):
1135 for name in ('bound', 'block', 'targets', 'track'):
1128 assert loc[name] is not None, "kwarg %r must be specified!"%name
1136 assert loc[name] is not None, "kwarg %r must be specified!"%name
1129
1137
1130 idents,targets = self._build_targets(targets)
1138 idents,targets = self._build_targets(targets)
1131
1139
1132 subheader = {}
1140 subheader = {}
1133 content = dict(bound=bound)
1141 content = dict(bound=bound)
1134 bufs = util.pack_apply_message(f,args,kwargs)
1142 bufs = util.pack_apply_message(f,args,kwargs)
1135
1143
1136 msg_ids = []
1144 msg_ids = []
1145 trackers = []
1137 for ident in idents:
1146 for ident in idents:
1138 msg = self.session.send(self._mux_socket, "apply_request",
1147 msg = self.session.send(self._mux_socket, "apply_request",
1139 content=content, buffers=bufs, ident=ident, subheader=subheader)
1148 content=content, buffers=bufs, ident=ident, subheader=subheader,
1149 track=track)
1150 if track:
1151 trackers.append(msg['tracker'])
1140 msg_id = msg['msg_id']
1152 msg_id = msg['msg_id']
1141 self.outstanding.add(msg_id)
1153 self.outstanding.add(msg_id)
1142 self._outstanding_dict[ident].add(msg_id)
1154 self._outstanding_dict[ident].add(msg_id)
1143 self.history.append(msg_id)
1155 self.history.append(msg_id)
1144 msg_ids.append(msg_id)
1156 msg_ids.append(msg_id)
1145 ar = AsyncResult(self, msg_ids, fname=f.__name__)
1157
1158 tracker = None if track is False else zmq.MessageTracker(*trackers)
1159 ar = AsyncResult(self, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
1160
1146 if block:
1161 if block:
1147 try:
1162 try:
1148 return ar.get()
1163 return ar.get()
1149 except KeyboardInterrupt:
1164 except KeyboardInterrupt:
1150 return ar
1165 return ar
1151 else:
1166 else:
1152 return ar
1167 return ar
1153
1168
1154 #--------------------------------------------------------------------------
1169 #--------------------------------------------------------------------------
1155 # construct a View object
1170 # construct a View object
1156 #--------------------------------------------------------------------------
1171 #--------------------------------------------------------------------------
1157
1172
1158 @defaultblock
1173 @defaultblock
1159 def remote(self, bound=True, block=None, targets=None, balanced=None):
1174 def remote(self, bound=True, block=None, targets=None, balanced=None):
1160 """Decorator for making a RemoteFunction"""
1175 """Decorator for making a RemoteFunction"""
1161 return remote(self, bound=bound, targets=targets, block=block, balanced=balanced)
1176 return remote(self, bound=bound, targets=targets, block=block, balanced=balanced)
1162
1177
1163 @defaultblock
1178 @defaultblock
1164 def parallel(self, dist='b', bound=True, block=None, targets=None, balanced=None):
1179 def parallel(self, dist='b', bound=True, block=None, targets=None, balanced=None):
1165 """Decorator for making a ParallelFunction"""
1180 """Decorator for making a ParallelFunction"""
1166 return parallel(self, bound=bound, targets=targets, block=block, balanced=balanced)
1181 return parallel(self, bound=bound, targets=targets, block=block, balanced=balanced)
1167
1182
1168 def _cache_view(self, targets, balanced):
1183 def _cache_view(self, targets, balanced):
1169 """save views, so subsequent requests don't create new objects."""
1184 """save views, so subsequent requests don't create new objects."""
1170 if balanced:
1185 if balanced:
1171 view_class = LoadBalancedView
1186 view_class = LoadBalancedView
1172 view_cache = self._balanced_views
1187 view_cache = self._balanced_views
1173 else:
1188 else:
1174 view_class = DirectView
1189 view_class = DirectView
1175 view_cache = self._direct_views
1190 view_cache = self._direct_views
1176
1191
1177 # use str, since often targets will be a list
1192 # use str, since often targets will be a list
1178 key = str(targets)
1193 key = str(targets)
1179 if key not in view_cache:
1194 if key not in view_cache:
1180 view_cache[key] = view_class(client=self, targets=targets)
1195 view_cache[key] = view_class(client=self, targets=targets)
1181
1196
1182 return view_cache[key]
1197 return view_cache[key]
1183
1198
1184 def view(self, targets=None, balanced=None):
1199 def view(self, targets=None, balanced=None):
1185 """Method for constructing View objects.
1200 """Method for constructing View objects.
1186
1201
1187 If no arguments are specified, create a LoadBalancedView
1202 If no arguments are specified, create a LoadBalancedView
1188 using all engines. If only `targets` specified, it will
1203 using all engines. If only `targets` specified, it will
1189 be a DirectView. This method is the underlying implementation
1204 be a DirectView. This method is the underlying implementation
1190 of ``client.__getitem__``.
1205 of ``client.__getitem__``.
1191
1206
1192 Parameters
1207 Parameters
1193 ----------
1208 ----------
1194
1209
1195 targets: list,slice,int,etc. [default: use all engines]
1210 targets: list,slice,int,etc. [default: use all engines]
1196 The engines to use for the View
1211 The engines to use for the View
1197 balanced : bool [default: False if targets specified, True else]
1212 balanced : bool [default: False if targets specified, True else]
1198 whether to build a LoadBalancedView or a DirectView
1213 whether to build a LoadBalancedView or a DirectView
1199
1214
1200 """
1215 """
1201
1216
1202 balanced = (targets is None) if balanced is None else balanced
1217 balanced = (targets is None) if balanced is None else balanced
1203
1218
1204 if targets is None:
1219 if targets is None:
1205 if balanced:
1220 if balanced:
1206 return self._cache_view(None,True)
1221 return self._cache_view(None,True)
1207 else:
1222 else:
1208 targets = slice(None)
1223 targets = slice(None)
1209
1224
1210 if isinstance(targets, int):
1225 if isinstance(targets, int):
1211 if targets < 0:
1226 if targets < 0:
1212 targets = self.ids[targets]
1227 targets = self.ids[targets]
1213 if targets not in self.ids:
1228 if targets not in self.ids:
1214 raise IndexError("No such engine: %i"%targets)
1229 raise IndexError("No such engine: %i"%targets)
1215 return self._cache_view(targets, balanced)
1230 return self._cache_view(targets, balanced)
1216
1231
1217 if isinstance(targets, slice):
1232 if isinstance(targets, slice):
1218 indices = range(len(self.ids))[targets]
1233 indices = range(len(self.ids))[targets]
1219 ids = sorted(self._ids)
1234 ids = sorted(self._ids)
1220 targets = [ ids[i] for i in indices ]
1235 targets = [ ids[i] for i in indices ]
1221
1236
1222 if isinstance(targets, (tuple, list, xrange)):
1237 if isinstance(targets, (tuple, list, xrange)):
1223 _,targets = self._build_targets(list(targets))
1238 _,targets = self._build_targets(list(targets))
1224 return self._cache_view(targets, balanced)
1239 return self._cache_view(targets, balanced)
1225 else:
1240 else:
1226 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
1241 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
1227
1242
1228 #--------------------------------------------------------------------------
1243 #--------------------------------------------------------------------------
1229 # Data movement
1244 # Data movement
1230 #--------------------------------------------------------------------------
1245 #--------------------------------------------------------------------------
1231
1246
1232 @defaultblock
1247 @defaultblock
1233 def push(self, ns, targets='all', block=None):
1248 def push(self, ns, targets='all', block=None, track=False):
1234 """Push the contents of `ns` into the namespace on `target`"""
1249 """Push the contents of `ns` into the namespace on `target`"""
1235 if not isinstance(ns, dict):
1250 if not isinstance(ns, dict):
1236 raise TypeError("Must be a dict, not %s"%type(ns))
1251 raise TypeError("Must be a dict, not %s"%type(ns))
1237 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False)
1252 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False, track=track)
1238 if not block:
1253 if not block:
1239 return result
1254 return result
1240
1255
1241 @defaultblock
1256 @defaultblock
1242 def pull(self, keys, targets='all', block=None):
1257 def pull(self, keys, targets='all', block=None):
1243 """Pull objects from `target`'s namespace by `keys`"""
1258 """Pull objects from `target`'s namespace by `keys`"""
1244 if isinstance(keys, str):
1259 if isinstance(keys, str):
1245 pass
1260 pass
1246 elif isinstance(keys, (list,tuple,set)):
1261 elif isinstance(keys, (list,tuple,set)):
1247 for key in keys:
1262 for key in keys:
1248 if not isinstance(key, str):
1263 if not isinstance(key, str):
1249 raise TypeError
1264 raise TypeError
1250 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True, balanced=False)
1265 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True, balanced=False)
1251 return result
1266 return result
1252
1267
1253 @defaultblock
1268 @defaultblock
1254 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
1269 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None, track=False):
1255 """
1270 """
1256 Partition a Python sequence and send the partitions to a set of engines.
1271 Partition a Python sequence and send the partitions to a set of engines.
1257 """
1272 """
1258 targets = self._build_targets(targets)[-1]
1273 targets = self._build_targets(targets)[-1]
1259 mapObject = Map.dists[dist]()
1274 mapObject = Map.dists[dist]()
1260 nparts = len(targets)
1275 nparts = len(targets)
1261 msg_ids = []
1276 msg_ids = []
1277 trackers = []
1262 for index, engineid in enumerate(targets):
1278 for index, engineid in enumerate(targets):
1263 partition = mapObject.getPartition(seq, index, nparts)
1279 partition = mapObject.getPartition(seq, index, nparts)
1264 if flatten and len(partition) == 1:
1280 if flatten and len(partition) == 1:
1265 r = self.push({key: partition[0]}, targets=engineid, block=False)
1281 r = self.push({key: partition[0]}, targets=engineid, block=False, track=track)
1266 else:
1282 else:
1267 r = self.push({key: partition}, targets=engineid, block=False)
1283 r = self.push({key: partition}, targets=engineid, block=False, track=track)
1268 msg_ids.extend(r.msg_ids)
1284 msg_ids.extend(r.msg_ids)
1269 r = AsyncResult(self, msg_ids, fname='scatter')
1285 if track:
1286 trackers.append(r._tracker)
1287
1288 if track:
1289 tracker = zmq.MessageTracker(*trackers)
1290 else:
1291 tracker = None
1292
1293 r = AsyncResult(self, msg_ids, fname='scatter', targets=targets, tracker=tracker)
1270 if block:
1294 if block:
1271 r.get()
1295 r.wait()
1272 else:
1296 else:
1273 return r
1297 return r
1274
1298
1275 @defaultblock
1299 @defaultblock
1276 def gather(self, key, dist='b', targets='all', block=None):
1300 def gather(self, key, dist='b', targets='all', block=None):
1277 """
1301 """
1278 Gather a partitioned sequence on a set of engines as a single local seq.
1302 Gather a partitioned sequence on a set of engines as a single local seq.
1279 """
1303 """
1280
1304
1281 targets = self._build_targets(targets)[-1]
1305 targets = self._build_targets(targets)[-1]
1282 mapObject = Map.dists[dist]()
1306 mapObject = Map.dists[dist]()
1283 msg_ids = []
1307 msg_ids = []
1284 for index, engineid in enumerate(targets):
1308 for index, engineid in enumerate(targets):
1285 msg_ids.extend(self.pull(key, targets=engineid,block=False).msg_ids)
1309 msg_ids.extend(self.pull(key, targets=engineid,block=False).msg_ids)
1286
1310
1287 r = AsyncMapResult(self, msg_ids, mapObject, fname='gather')
1311 r = AsyncMapResult(self, msg_ids, mapObject, fname='gather')
1288 if block:
1312 if block:
1289 return r.get()
1313 return r.get()
1290 else:
1314 else:
1291 return r
1315 return r
1292
1316
1293 #--------------------------------------------------------------------------
1317 #--------------------------------------------------------------------------
1294 # Query methods
1318 # Query methods
1295 #--------------------------------------------------------------------------
1319 #--------------------------------------------------------------------------
1296
1320
1297 @spinfirst
1321 @spinfirst
1298 @defaultblock
1322 @defaultblock
1299 def get_result(self, indices_or_msg_ids=None, block=None):
1323 def get_result(self, indices_or_msg_ids=None, block=None):
1300 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1324 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1301
1325
1302 If the client already has the results, no request to the Hub will be made.
1326 If the client already has the results, no request to the Hub will be made.
1303
1327
1304 This is a convenient way to construct AsyncResult objects, which are wrappers
1328 This is a convenient way to construct AsyncResult objects, which are wrappers
1305 that include metadata about execution, and allow for awaiting results that
1329 that include metadata about execution, and allow for awaiting results that
1306 were not submitted by this Client.
1330 were not submitted by this Client.
1307
1331
1308 It can also be a convenient way to retrieve the metadata associated with
1332 It can also be a convenient way to retrieve the metadata associated with
1309 blocking execution, since it always retrieves
1333 blocking execution, since it always retrieves
1310
1334
1311 Examples
1335 Examples
1312 --------
1336 --------
1313 ::
1337 ::
1314
1338
1315 In [10]: r = client.apply()
1339 In [10]: r = client.apply()
1316
1340
1317 Parameters
1341 Parameters
1318 ----------
1342 ----------
1319
1343
1320 indices_or_msg_ids : integer history index, str msg_id, or list of either
1344 indices_or_msg_ids : integer history index, str msg_id, or list of either
1321 The indices or msg_ids of indices to be retrieved
1345 The indices or msg_ids of indices to be retrieved
1322
1346
1323 block : bool
1347 block : bool
1324 Whether to wait for the result to be done
1348 Whether to wait for the result to be done
1325
1349
1326 Returns
1350 Returns
1327 -------
1351 -------
1328
1352
1329 AsyncResult
1353 AsyncResult
1330 A single AsyncResult object will always be returned.
1354 A single AsyncResult object will always be returned.
1331
1355
1332 AsyncHubResult
1356 AsyncHubResult
1333 A subclass of AsyncResult that retrieves results from the Hub
1357 A subclass of AsyncResult that retrieves results from the Hub
1334
1358
1335 """
1359 """
1336 if indices_or_msg_ids is None:
1360 if indices_or_msg_ids is None:
1337 indices_or_msg_ids = -1
1361 indices_or_msg_ids = -1
1338
1362
1339 if not isinstance(indices_or_msg_ids, (list,tuple)):
1363 if not isinstance(indices_or_msg_ids, (list,tuple)):
1340 indices_or_msg_ids = [indices_or_msg_ids]
1364 indices_or_msg_ids = [indices_or_msg_ids]
1341
1365
1342 theids = []
1366 theids = []
1343 for id in indices_or_msg_ids:
1367 for id in indices_or_msg_ids:
1344 if isinstance(id, int):
1368 if isinstance(id, int):
1345 id = self.history[id]
1369 id = self.history[id]
1346 if not isinstance(id, str):
1370 if not isinstance(id, str):
1347 raise TypeError("indices must be str or int, not %r"%id)
1371 raise TypeError("indices must be str or int, not %r"%id)
1348 theids.append(id)
1372 theids.append(id)
1349
1373
1350 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1374 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1351 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1375 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1352
1376
1353 if remote_ids:
1377 if remote_ids:
1354 ar = AsyncHubResult(self, msg_ids=theids)
1378 ar = AsyncHubResult(self, msg_ids=theids)
1355 else:
1379 else:
1356 ar = AsyncResult(self, msg_ids=theids)
1380 ar = AsyncResult(self, msg_ids=theids)
1357
1381
1358 if block:
1382 if block:
1359 ar.wait()
1383 ar.wait()
1360
1384
1361 return ar
1385 return ar
1362
1386
1363 @spinfirst
1387 @spinfirst
1364 def result_status(self, msg_ids, status_only=True):
1388 def result_status(self, msg_ids, status_only=True):
1365 """Check on the status of the result(s) of the apply request with `msg_ids`.
1389 """Check on the status of the result(s) of the apply request with `msg_ids`.
1366
1390
1367 If status_only is False, then the actual results will be retrieved, else
1391 If status_only is False, then the actual results will be retrieved, else
1368 only the status of the results will be checked.
1392 only the status of the results will be checked.
1369
1393
1370 Parameters
1394 Parameters
1371 ----------
1395 ----------
1372
1396
1373 msg_ids : list of msg_ids
1397 msg_ids : list of msg_ids
1374 if int:
1398 if int:
1375 Passed as index to self.history for convenience.
1399 Passed as index to self.history for convenience.
1376 status_only : bool (default: True)
1400 status_only : bool (default: True)
1377 if False:
1401 if False:
1378 Retrieve the actual results of completed tasks.
1402 Retrieve the actual results of completed tasks.
1379
1403
1380 Returns
1404 Returns
1381 -------
1405 -------
1382
1406
1383 results : dict
1407 results : dict
1384 There will always be the keys 'pending' and 'completed', which will
1408 There will always be the keys 'pending' and 'completed', which will
1385 be lists of msg_ids that are incomplete or complete. If `status_only`
1409 be lists of msg_ids that are incomplete or complete. If `status_only`
1386 is False, then completed results will be keyed by their `msg_id`.
1410 is False, then completed results will be keyed by their `msg_id`.
1387 """
1411 """
1388 if not isinstance(msg_ids, (list,tuple)):
1412 if not isinstance(msg_ids, (list,tuple)):
1389 msg_ids = [msg_ids]
1413 msg_ids = [msg_ids]
1390
1414
1391 theids = []
1415 theids = []
1392 for msg_id in msg_ids:
1416 for msg_id in msg_ids:
1393 if isinstance(msg_id, int):
1417 if isinstance(msg_id, int):
1394 msg_id = self.history[msg_id]
1418 msg_id = self.history[msg_id]
1395 if not isinstance(msg_id, basestring):
1419 if not isinstance(msg_id, basestring):
1396 raise TypeError("msg_ids must be str, not %r"%msg_id)
1420 raise TypeError("msg_ids must be str, not %r"%msg_id)
1397 theids.append(msg_id)
1421 theids.append(msg_id)
1398
1422
1399 completed = []
1423 completed = []
1400 local_results = {}
1424 local_results = {}
1401
1425
1402 # comment this block out to temporarily disable local shortcut:
1426 # comment this block out to temporarily disable local shortcut:
1403 for msg_id in theids:
1427 for msg_id in theids:
1404 if msg_id in self.results:
1428 if msg_id in self.results:
1405 completed.append(msg_id)
1429 completed.append(msg_id)
1406 local_results[msg_id] = self.results[msg_id]
1430 local_results[msg_id] = self.results[msg_id]
1407 theids.remove(msg_id)
1431 theids.remove(msg_id)
1408
1432
1409 if theids: # some not locally cached
1433 if theids: # some not locally cached
1410 content = dict(msg_ids=theids, status_only=status_only)
1434 content = dict(msg_ids=theids, status_only=status_only)
1411 msg = self.session.send(self._query_socket, "result_request", content=content)
1435 msg = self.session.send(self._query_socket, "result_request", content=content)
1412 zmq.select([self._query_socket], [], [])
1436 zmq.select([self._query_socket], [], [])
1413 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1437 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1414 if self.debug:
1438 if self.debug:
1415 pprint(msg)
1439 pprint(msg)
1416 content = msg['content']
1440 content = msg['content']
1417 if content['status'] != 'ok':
1441 if content['status'] != 'ok':
1418 raise self._unwrap_exception(content)
1442 raise self._unwrap_exception(content)
1419 buffers = msg['buffers']
1443 buffers = msg['buffers']
1420 else:
1444 else:
1421 content = dict(completed=[],pending=[])
1445 content = dict(completed=[],pending=[])
1422
1446
1423 content['completed'].extend(completed)
1447 content['completed'].extend(completed)
1424
1448
1425 if status_only:
1449 if status_only:
1426 return content
1450 return content
1427
1451
1428 failures = []
1452 failures = []
1429 # load cached results into result:
1453 # load cached results into result:
1430 content.update(local_results)
1454 content.update(local_results)
1431 # update cache with results:
1455 # update cache with results:
1432 for msg_id in sorted(theids):
1456 for msg_id in sorted(theids):
1433 if msg_id in content['completed']:
1457 if msg_id in content['completed']:
1434 rec = content[msg_id]
1458 rec = content[msg_id]
1435 parent = rec['header']
1459 parent = rec['header']
1436 header = rec['result_header']
1460 header = rec['result_header']
1437 rcontent = rec['result_content']
1461 rcontent = rec['result_content']
1438 iodict = rec['io']
1462 iodict = rec['io']
1439 if isinstance(rcontent, str):
1463 if isinstance(rcontent, str):
1440 rcontent = self.session.unpack(rcontent)
1464 rcontent = self.session.unpack(rcontent)
1441
1465
1442 md = self.metadata[msg_id]
1466 md = self.metadata[msg_id]
1443 md.update(self._extract_metadata(header, parent, rcontent))
1467 md.update(self._extract_metadata(header, parent, rcontent))
1444 md.update(iodict)
1468 md.update(iodict)
1445
1469
1446 if rcontent['status'] == 'ok':
1470 if rcontent['status'] == 'ok':
1447 res,buffers = util.unserialize_object(buffers)
1471 res,buffers = util.unserialize_object(buffers)
1448 else:
1472 else:
1449 print rcontent
1473 print rcontent
1450 res = self._unwrap_exception(rcontent)
1474 res = self._unwrap_exception(rcontent)
1451 failures.append(res)
1475 failures.append(res)
1452
1476
1453 self.results[msg_id] = res
1477 self.results[msg_id] = res
1454 content[msg_id] = res
1478 content[msg_id] = res
1455
1479
1456 if len(theids) == 1 and failures:
1480 if len(theids) == 1 and failures:
1457 raise failures[0]
1481 raise failures[0]
1458
1482
1459 error.collect_exceptions(failures, "result_status")
1483 error.collect_exceptions(failures, "result_status")
1460 return content
1484 return content
1461
1485
1462 @spinfirst
1486 @spinfirst
1463 def queue_status(self, targets='all', verbose=False):
1487 def queue_status(self, targets='all', verbose=False):
1464 """Fetch the status of engine queues.
1488 """Fetch the status of engine queues.
1465
1489
1466 Parameters
1490 Parameters
1467 ----------
1491 ----------
1468
1492
1469 targets : int/str/list of ints/strs
1493 targets : int/str/list of ints/strs
1470 the engines whose states are to be queried.
1494 the engines whose states are to be queried.
1471 default : all
1495 default : all
1472 verbose : bool
1496 verbose : bool
1473 Whether to return lengths only, or lists of ids for each element
1497 Whether to return lengths only, or lists of ids for each element
1474 """
1498 """
1475 targets = self._build_targets(targets)[1]
1499 targets = self._build_targets(targets)[1]
1476 content = dict(targets=targets, verbose=verbose)
1500 content = dict(targets=targets, verbose=verbose)
1477 self.session.send(self._query_socket, "queue_request", content=content)
1501 self.session.send(self._query_socket, "queue_request", content=content)
1478 idents,msg = self.session.recv(self._query_socket, 0)
1502 idents,msg = self.session.recv(self._query_socket, 0)
1479 if self.debug:
1503 if self.debug:
1480 pprint(msg)
1504 pprint(msg)
1481 content = msg['content']
1505 content = msg['content']
1482 status = content.pop('status')
1506 status = content.pop('status')
1483 if status != 'ok':
1507 if status != 'ok':
1484 raise self._unwrap_exception(content)
1508 raise self._unwrap_exception(content)
1485 return util.rekey(content)
1509 return util.rekey(content)
1486
1510
1487 @spinfirst
1511 @spinfirst
1488 def purge_results(self, jobs=[], targets=[]):
1512 def purge_results(self, jobs=[], targets=[]):
1489 """Tell the controller to forget results.
1513 """Tell the controller to forget results.
1490
1514
1491 Individual results can be purged by msg_id, or the entire
1515 Individual results can be purged by msg_id, or the entire
1492 history of specific targets can be purged.
1516 history of specific targets can be purged.
1493
1517
1494 Parameters
1518 Parameters
1495 ----------
1519 ----------
1496
1520
1497 jobs : str or list of strs or AsyncResult objects
1521 jobs : str or list of strs or AsyncResult objects
1498 the msg_ids whose results should be forgotten.
1522 the msg_ids whose results should be forgotten.
1499 targets : int/str/list of ints/strs
1523 targets : int/str/list of ints/strs
1500 The targets, by uuid or int_id, whose entire history is to be purged.
1524 The targets, by uuid or int_id, whose entire history is to be purged.
1501 Use `targets='all'` to scrub everything from the controller's memory.
1525 Use `targets='all'` to scrub everything from the controller's memory.
1502
1526
1503 default : None
1527 default : None
1504 """
1528 """
1505 if not targets and not jobs:
1529 if not targets and not jobs:
1506 raise ValueError("Must specify at least one of `targets` and `jobs`")
1530 raise ValueError("Must specify at least one of `targets` and `jobs`")
1507 if targets:
1531 if targets:
1508 targets = self._build_targets(targets)[1]
1532 targets = self._build_targets(targets)[1]
1509
1533
1510 # construct msg_ids from jobs
1534 # construct msg_ids from jobs
1511 msg_ids = []
1535 msg_ids = []
1512 if isinstance(jobs, (basestring,AsyncResult)):
1536 if isinstance(jobs, (basestring,AsyncResult)):
1513 jobs = [jobs]
1537 jobs = [jobs]
1514 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1538 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1515 if bad_ids:
1539 if bad_ids:
1516 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1540 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1517 for j in jobs:
1541 for j in jobs:
1518 if isinstance(j, AsyncResult):
1542 if isinstance(j, AsyncResult):
1519 msg_ids.extend(j.msg_ids)
1543 msg_ids.extend(j.msg_ids)
1520 else:
1544 else:
1521 msg_ids.append(j)
1545 msg_ids.append(j)
1522
1546
1523 content = dict(targets=targets, msg_ids=msg_ids)
1547 content = dict(targets=targets, msg_ids=msg_ids)
1524 self.session.send(self._query_socket, "purge_request", content=content)
1548 self.session.send(self._query_socket, "purge_request", content=content)
1525 idents, msg = self.session.recv(self._query_socket, 0)
1549 idents, msg = self.session.recv(self._query_socket, 0)
1526 if self.debug:
1550 if self.debug:
1527 pprint(msg)
1551 pprint(msg)
1528 content = msg['content']
1552 content = msg['content']
1529 if content['status'] != 'ok':
1553 if content['status'] != 'ok':
1530 raise self._unwrap_exception(content)
1554 raise self._unwrap_exception(content)
1531
1555
1532
1556
1533 __all__ = [ 'Client',
1557 __all__ = [ 'Client',
1534 'depend',
1558 'depend',
1535 'require',
1559 'require',
1536 'remote',
1560 'remote',
1537 'parallel',
1561 'parallel',
1538 'RemoteFunction',
1562 'RemoteFunction',
1539 'ParallelFunction',
1563 'ParallelFunction',
1540 'DirectView',
1564 'DirectView',
1541 'LoadBalancedView',
1565 'LoadBalancedView',
1542 'AsyncResult',
1566 'AsyncResult',
1543 'AsyncMapResult',
1567 'AsyncMapResult',
1544 'Reference'
1568 'Reference'
1545 ]
1569 ]
@@ -1,377 +1,412 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """edited session.py to work with streams, and move msg_type to the header
2 """edited session.py to work with streams, and move msg_type to the header
3 """
3 """
4
4
5
5
6 import os
6 import os
7 import pprint
7 import pprint
8 import uuid
8 import uuid
9 from datetime import datetime
9 from datetime import datetime
10
10
11 try:
11 try:
12 import cPickle
12 import cPickle
13 pickle = cPickle
13 pickle = cPickle
14 except:
14 except:
15 cPickle = None
15 cPickle = None
16 import pickle
16 import pickle
17
17
18 import zmq
18 import zmq
19 from zmq.utils import jsonapi
19 from zmq.utils import jsonapi
20 from zmq.eventloop.zmqstream import ZMQStream
20 from zmq.eventloop.zmqstream import ZMQStream
21
21
22 from .util import ISO8601
22 from .util import ISO8601
23
23
24 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
24 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
25 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
25 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
26 if json_name in ('jsonlib', 'jsonlib2'):
26 if json_name in ('jsonlib', 'jsonlib2'):
27 use_json = True
27 use_json = True
28 elif json_name:
28 elif json_name:
29 if cPickle is None:
29 if cPickle is None:
30 use_json = True
30 use_json = True
31 else:
31 else:
32 use_json = False
32 use_json = False
33 else:
33 else:
34 use_json = False
34 use_json = False
35
35
36 def squash_unicode(obj):
36 def squash_unicode(obj):
37 if isinstance(obj,dict):
37 if isinstance(obj,dict):
38 for key in obj.keys():
38 for key in obj.keys():
39 obj[key] = squash_unicode(obj[key])
39 obj[key] = squash_unicode(obj[key])
40 if isinstance(key, unicode):
40 if isinstance(key, unicode):
41 obj[squash_unicode(key)] = obj.pop(key)
41 obj[squash_unicode(key)] = obj.pop(key)
42 elif isinstance(obj, list):
42 elif isinstance(obj, list):
43 for i,v in enumerate(obj):
43 for i,v in enumerate(obj):
44 obj[i] = squash_unicode(v)
44 obj[i] = squash_unicode(v)
45 elif isinstance(obj, unicode):
45 elif isinstance(obj, unicode):
46 obj = obj.encode('utf8')
46 obj = obj.encode('utf8')
47 return obj
47 return obj
48
48
49 json_packer = jsonapi.dumps
49 json_packer = jsonapi.dumps
50 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
50 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
51
51
52 pickle_packer = lambda o: pickle.dumps(o,-1)
52 pickle_packer = lambda o: pickle.dumps(o,-1)
53 pickle_unpacker = pickle.loads
53 pickle_unpacker = pickle.loads
54
54
55 if use_json:
55 if use_json:
56 default_packer = json_packer
56 default_packer = json_packer
57 default_unpacker = json_unpacker
57 default_unpacker = json_unpacker
58 else:
58 else:
59 default_packer = pickle_packer
59 default_packer = pickle_packer
60 default_unpacker = pickle_unpacker
60 default_unpacker = pickle_unpacker
61
61
62
62
63 DELIM="<IDS|MSG>"
63 DELIM="<IDS|MSG>"
64
64
65 class Message(object):
65 class Message(object):
66 """A simple message object that maps dict keys to attributes.
66 """A simple message object that maps dict keys to attributes.
67
67
68 A Message can be created from a dict and a dict from a Message instance
68 A Message can be created from a dict and a dict from a Message instance
69 simply by calling dict(msg_obj)."""
69 simply by calling dict(msg_obj)."""
70
70
71 def __init__(self, msg_dict):
71 def __init__(self, msg_dict):
72 dct = self.__dict__
72 dct = self.__dict__
73 for k, v in dict(msg_dict).iteritems():
73 for k, v in dict(msg_dict).iteritems():
74 if isinstance(v, dict):
74 if isinstance(v, dict):
75 v = Message(v)
75 v = Message(v)
76 dct[k] = v
76 dct[k] = v
77
77
78 # Having this iterator lets dict(msg_obj) work out of the box.
78 # Having this iterator lets dict(msg_obj) work out of the box.
79 def __iter__(self):
79 def __iter__(self):
80 return iter(self.__dict__.iteritems())
80 return iter(self.__dict__.iteritems())
81
81
82 def __repr__(self):
82 def __repr__(self):
83 return repr(self.__dict__)
83 return repr(self.__dict__)
84
84
85 def __str__(self):
85 def __str__(self):
86 return pprint.pformat(self.__dict__)
86 return pprint.pformat(self.__dict__)
87
87
88 def __contains__(self, k):
88 def __contains__(self, k):
89 return k in self.__dict__
89 return k in self.__dict__
90
90
91 def __getitem__(self, k):
91 def __getitem__(self, k):
92 return self.__dict__[k]
92 return self.__dict__[k]
93
93
94
94
95 def msg_header(msg_id, msg_type, username, session):
95 def msg_header(msg_id, msg_type, username, session):
96 date=datetime.now().strftime(ISO8601)
96 date=datetime.now().strftime(ISO8601)
97 return locals()
97 return locals()
98
98
99 def extract_header(msg_or_header):
99 def extract_header(msg_or_header):
100 """Given a message or header, return the header."""
100 """Given a message or header, return the header."""
101 if not msg_or_header:
101 if not msg_or_header:
102 return {}
102 return {}
103 try:
103 try:
104 # See if msg_or_header is the entire message.
104 # See if msg_or_header is the entire message.
105 h = msg_or_header['header']
105 h = msg_or_header['header']
106 except KeyError:
106 except KeyError:
107 try:
107 try:
108 # See if msg_or_header is just the header
108 # See if msg_or_header is just the header
109 h = msg_or_header['msg_id']
109 h = msg_or_header['msg_id']
110 except KeyError:
110 except KeyError:
111 raise
111 raise
112 else:
112 else:
113 h = msg_or_header
113 h = msg_or_header
114 if not isinstance(h, dict):
114 if not isinstance(h, dict):
115 h = dict(h)
115 h = dict(h)
116 return h
116 return h
117
117
118 class StreamSession(object):
118 class StreamSession(object):
119 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
119 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
120 debug=False
120 debug=False
121 key=None
121 key=None
122
122
123 def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
123 def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
124 if username is None:
124 if username is None:
125 username = os.environ.get('USER','username')
125 username = os.environ.get('USER','username')
126 self.username = username
126 self.username = username
127 if session is None:
127 if session is None:
128 self.session = str(uuid.uuid4())
128 self.session = str(uuid.uuid4())
129 else:
129 else:
130 self.session = session
130 self.session = session
131 self.msg_id = str(uuid.uuid4())
131 self.msg_id = str(uuid.uuid4())
132 if packer is None:
132 if packer is None:
133 self.pack = default_packer
133 self.pack = default_packer
134 else:
134 else:
135 if not callable(packer):
135 if not callable(packer):
136 raise TypeError("packer must be callable, not %s"%type(packer))
136 raise TypeError("packer must be callable, not %s"%type(packer))
137 self.pack = packer
137 self.pack = packer
138
138
139 if unpacker is None:
139 if unpacker is None:
140 self.unpack = default_unpacker
140 self.unpack = default_unpacker
141 else:
141 else:
142 if not callable(unpacker):
142 if not callable(unpacker):
143 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
143 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
144 self.unpack = unpacker
144 self.unpack = unpacker
145
145
146 if key is not None and keyfile is not None:
146 if key is not None and keyfile is not None:
147 raise TypeError("Must specify key OR keyfile, not both")
147 raise TypeError("Must specify key OR keyfile, not both")
148 if keyfile is not None:
148 if keyfile is not None:
149 with open(keyfile) as f:
149 with open(keyfile) as f:
150 self.key = f.read().strip()
150 self.key = f.read().strip()
151 else:
151 else:
152 self.key = key
152 self.key = key
153 if isinstance(self.key, unicode):
153 if isinstance(self.key, unicode):
154 self.key = self.key.encode('utf8')
154 self.key = self.key.encode('utf8')
155 # print key, keyfile, self.key
155 # print key, keyfile, self.key
156 self.none = self.pack({})
156 self.none = self.pack({})
157
157
158 def msg_header(self, msg_type):
158 def msg_header(self, msg_type):
159 h = msg_header(self.msg_id, msg_type, self.username, self.session)
159 h = msg_header(self.msg_id, msg_type, self.username, self.session)
160 self.msg_id = str(uuid.uuid4())
160 self.msg_id = str(uuid.uuid4())
161 return h
161 return h
162
162
163 def msg(self, msg_type, content=None, parent=None, subheader=None):
163 def msg(self, msg_type, content=None, parent=None, subheader=None):
164 msg = {}
164 msg = {}
165 msg['header'] = self.msg_header(msg_type)
165 msg['header'] = self.msg_header(msg_type)
166 msg['msg_id'] = msg['header']['msg_id']
166 msg['msg_id'] = msg['header']['msg_id']
167 msg['parent_header'] = {} if parent is None else extract_header(parent)
167 msg['parent_header'] = {} if parent is None else extract_header(parent)
168 msg['msg_type'] = msg_type
168 msg['msg_type'] = msg_type
169 msg['content'] = {} if content is None else content
169 msg['content'] = {} if content is None else content
170 sub = {} if subheader is None else subheader
170 sub = {} if subheader is None else subheader
171 msg['header'].update(sub)
171 msg['header'].update(sub)
172 return msg
172 return msg
173
173
174 def check_key(self, msg_or_header):
174 def check_key(self, msg_or_header):
175 """Check that a message's header has the right key"""
175 """Check that a message's header has the right key"""
176 if self.key is None:
176 if self.key is None:
177 return True
177 return True
178 header = extract_header(msg_or_header)
178 header = extract_header(msg_or_header)
179 return header.get('key', None) == self.key
179 return header.get('key', None) == self.key
180
180
181
181
182 def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
182 def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None, track=False):
183 """Build and send a message via stream or socket.
183 """Build and send a message via stream or socket.
184
184
185 Parameters
185 Parameters
186 ----------
186 ----------
187
187
188 stream : zmq.Socket or ZMQStream
188 stream : zmq.Socket or ZMQStream
189 the socket-like object used to send the data
189 the socket-like object used to send the data
190 msg_or_type : str or Message/dict
190 msg_or_type : str or Message/dict
191 Normally, msg_or_type will be a msg_type unless a message is being sent more
191 Normally, msg_or_type will be a msg_type unless a message is being sent more
192 than once.
192 than once.
193
193
194 content : dict or None
195 the content of the message (ignored if msg_or_type is a message)
196 buffers : list or None
197 the already-serialized buffers to be appended to the message
198 parent : Message or dict or None
199 the parent or parent header describing the parent of this message
200 subheader : dict or None
201 extra header keys for this message's header
202 ident : bytes or list of bytes
203 the zmq.IDENTITY routing path
204 track : bool
205 whether to track. Only for use with Sockets, because ZMQStream objects cannot track messages.
206
194 Returns
207 Returns
195 -------
208 -------
196 (msg,sent) : tuple
209 msg : message dict
197 msg : Message
210 the constructed message
198 the nice wrapped dict-like object containing the headers
211 (msg,tracker) : (message dict, MessageTracker)
212 if track=True, then a 2-tuple will be returned, the first element being the constructed
213 message, and the second being the MessageTracker
199
214
200 """
215 """
216
217 if not isinstance(stream, (zmq.Socket, ZMQStream)):
218 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
219 elif track and isinstance(stream, ZMQStream):
220 raise TypeError("ZMQStream cannot track messages")
221
201 if isinstance(msg_or_type, (Message, dict)):
222 if isinstance(msg_or_type, (Message, dict)):
202 # we got a Message, not a msg_type
223 # we got a Message, not a msg_type
203 # don't build a new Message
224 # don't build a new Message
204 msg = msg_or_type
225 msg = msg_or_type
205 content = msg['content']
226 content = msg['content']
206 else:
227 else:
207 msg = self.msg(msg_or_type, content, parent, subheader)
228 msg = self.msg(msg_or_type, content, parent, subheader)
229
208 buffers = [] if buffers is None else buffers
230 buffers = [] if buffers is None else buffers
209 to_send = []
231 to_send = []
210 if isinstance(ident, list):
232 if isinstance(ident, list):
211 # accept list of idents
233 # accept list of idents
212 to_send.extend(ident)
234 to_send.extend(ident)
213 elif ident is not None:
235 elif ident is not None:
214 to_send.append(ident)
236 to_send.append(ident)
215 to_send.append(DELIM)
237 to_send.append(DELIM)
216 if self.key is not None:
238 if self.key is not None:
217 to_send.append(self.key)
239 to_send.append(self.key)
218 to_send.append(self.pack(msg['header']))
240 to_send.append(self.pack(msg['header']))
219 to_send.append(self.pack(msg['parent_header']))
241 to_send.append(self.pack(msg['parent_header']))
220
242
221 if content is None:
243 if content is None:
222 content = self.none
244 content = self.none
223 elif isinstance(content, dict):
245 elif isinstance(content, dict):
224 content = self.pack(content)
246 content = self.pack(content)
225 elif isinstance(content, str):
247 elif isinstance(content, bytes):
226 # content is already packed, as in a relayed message
248 # content is already packed, as in a relayed message
227 pass
249 pass
228 else:
250 else:
229 raise TypeError("Content incorrect type: %s"%type(content))
251 raise TypeError("Content incorrect type: %s"%type(content))
230 to_send.append(content)
252 to_send.append(content)
231 flag = 0
253 flag = 0
232 if buffers:
254 if buffers:
233 flag = zmq.SNDMORE
255 flag = zmq.SNDMORE
234 stream.send_multipart(to_send, flag, copy=False)
256 _track = False
257 else:
258 _track=track
259 if track:
260 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
261 else:
262 tracker = stream.send_multipart(to_send, flag, copy=False)
235 for b in buffers[:-1]:
263 for b in buffers[:-1]:
236 stream.send(b, flag, copy=False)
264 stream.send(b, flag, copy=False)
237 if buffers:
265 if buffers:
238 stream.send(buffers[-1], copy=False)
266 if track:
267 tracker = stream.send(buffers[-1], copy=False, track=track)
268 else:
269 tracker = stream.send(buffers[-1], copy=False)
270
239 # omsg = Message(msg)
271 # omsg = Message(msg)
240 if self.debug:
272 if self.debug:
241 pprint.pprint(msg)
273 pprint.pprint(msg)
242 pprint.pprint(to_send)
274 pprint.pprint(to_send)
243 pprint.pprint(buffers)
275 pprint.pprint(buffers)
276
277 msg['tracker'] = tracker
278
244 return msg
279 return msg
245
280
246 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
281 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
247 """Send a raw message via ident path.
282 """Send a raw message via ident path.
248
283
249 Parameters
284 Parameters
250 ----------
285 ----------
251 msg : list of sendable buffers"""
286 msg : list of sendable buffers"""
252 to_send = []
287 to_send = []
253 if isinstance(ident, str):
288 if isinstance(ident, bytes):
254 ident = [ident]
289 ident = [ident]
255 if ident is not None:
290 if ident is not None:
256 to_send.extend(ident)
291 to_send.extend(ident)
257 to_send.append(DELIM)
292 to_send.append(DELIM)
258 if self.key is not None:
293 if self.key is not None:
259 to_send.append(self.key)
294 to_send.append(self.key)
260 to_send.extend(msg)
295 to_send.extend(msg)
261 stream.send_multipart(msg, flags, copy=copy)
296 stream.send_multipart(msg, flags, copy=copy)
262
297
263 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
298 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
264 """receives and unpacks a message
299 """receives and unpacks a message
265 returns [idents], msg"""
300 returns [idents], msg"""
266 if isinstance(socket, ZMQStream):
301 if isinstance(socket, ZMQStream):
267 socket = socket.socket
302 socket = socket.socket
268 try:
303 try:
269 msg = socket.recv_multipart(mode)
304 msg = socket.recv_multipart(mode)
270 except zmq.ZMQError as e:
305 except zmq.ZMQError as e:
271 if e.errno == zmq.EAGAIN:
306 if e.errno == zmq.EAGAIN:
272 # We can convert EAGAIN to None as we know in this case
307 # We can convert EAGAIN to None as we know in this case
273 # recv_multipart won't return None.
308 # recv_multipart won't return None.
274 return None
309 return None
275 else:
310 else:
276 raise
311 raise
277 # return an actual Message object
312 # return an actual Message object
278 # determine the number of idents by trying to unpack them.
313 # determine the number of idents by trying to unpack them.
279 # this is terrible:
314 # this is terrible:
280 idents, msg = self.feed_identities(msg, copy)
315 idents, msg = self.feed_identities(msg, copy)
281 try:
316 try:
282 return idents, self.unpack_message(msg, content=content, copy=copy)
317 return idents, self.unpack_message(msg, content=content, copy=copy)
283 except Exception as e:
318 except Exception as e:
284 print (idents, msg)
319 print (idents, msg)
285 # TODO: handle it
320 # TODO: handle it
286 raise e
321 raise e
287
322
288 def feed_identities(self, msg, copy=True):
323 def feed_identities(self, msg, copy=True):
289 """feed until DELIM is reached, then return the prefix as idents and remainder as
324 """feed until DELIM is reached, then return the prefix as idents and remainder as
290 msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.
325 msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.
291
326
292 Parameters
327 Parameters
293 ----------
328 ----------
294 msg : a list of Message or bytes objects
329 msg : a list of Message or bytes objects
295 the message to be split
330 the message to be split
296 copy : bool
331 copy : bool
297 flag determining whether the arguments are bytes or Messages
332 flag determining whether the arguments are bytes or Messages
298
333
299 Returns
334 Returns
300 -------
335 -------
301 (idents,msg) : two lists
336 (idents,msg) : two lists
302 idents will always be a list of bytes - the indentity prefix
337 idents will always be a list of bytes - the indentity prefix
303 msg will be a list of bytes or Messages, unchanged from input
338 msg will be a list of bytes or Messages, unchanged from input
304 msg should be unpackable via self.unpack_message at this point.
339 msg should be unpackable via self.unpack_message at this point.
305 """
340 """
306 ikey = int(self.key is not None)
341 ikey = int(self.key is not None)
307 minlen = 3 + ikey
342 minlen = 3 + ikey
308 msg = list(msg)
343 msg = list(msg)
309 idents = []
344 idents = []
310 while len(msg) > minlen:
345 while len(msg) > minlen:
311 if copy:
346 if copy:
312 s = msg[0]
347 s = msg[0]
313 else:
348 else:
314 s = msg[0].bytes
349 s = msg[0].bytes
315 if s == DELIM:
350 if s == DELIM:
316 msg.pop(0)
351 msg.pop(0)
317 break
352 break
318 else:
353 else:
319 idents.append(s)
354 idents.append(s)
320 msg.pop(0)
355 msg.pop(0)
321
356
322 return idents, msg
357 return idents, msg
323
358
324 def unpack_message(self, msg, content=True, copy=True):
359 def unpack_message(self, msg, content=True, copy=True):
325 """Return a message object from the format
360 """Return a message object from the format
326 sent by self.send.
361 sent by self.send.
327
362
328 Parameters:
363 Parameters:
329 -----------
364 -----------
330
365
331 content : bool (True)
366 content : bool (True)
332 whether to unpack the content dict (True),
367 whether to unpack the content dict (True),
333 or leave it serialized (False)
368 or leave it serialized (False)
334
369
335 copy : bool (True)
370 copy : bool (True)
336 whether to return the bytes (True),
371 whether to return the bytes (True),
337 or the non-copying Message object in each place (False)
372 or the non-copying Message object in each place (False)
338
373
339 """
374 """
340 ikey = int(self.key is not None)
375 ikey = int(self.key is not None)
341 minlen = 3 + ikey
376 minlen = 3 + ikey
342 message = {}
377 message = {}
343 if not copy:
378 if not copy:
344 for i in range(minlen):
379 for i in range(minlen):
345 msg[i] = msg[i].bytes
380 msg[i] = msg[i].bytes
346 if ikey:
381 if ikey:
347 if not self.key == msg[0]:
382 if not self.key == msg[0]:
348 raise KeyError("Invalid Session Key: %s"%msg[0])
383 raise KeyError("Invalid Session Key: %s"%msg[0])
349 if not len(msg) >= minlen:
384 if not len(msg) >= minlen:
350 raise TypeError("malformed message, must have at least %i elements"%minlen)
385 raise TypeError("malformed message, must have at least %i elements"%minlen)
351 message['header'] = self.unpack(msg[ikey+0])
386 message['header'] = self.unpack(msg[ikey+0])
352 message['msg_type'] = message['header']['msg_type']
387 message['msg_type'] = message['header']['msg_type']
353 message['parent_header'] = self.unpack(msg[ikey+1])
388 message['parent_header'] = self.unpack(msg[ikey+1])
354 if content:
389 if content:
355 message['content'] = self.unpack(msg[ikey+2])
390 message['content'] = self.unpack(msg[ikey+2])
356 else:
391 else:
357 message['content'] = msg[ikey+2]
392 message['content'] = msg[ikey+2]
358
393
359 message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
394 message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
360 return message
395 return message
361
396
362
397
363 def test_msg2obj():
398 def test_msg2obj():
364 am = dict(x=1)
399 am = dict(x=1)
365 ao = Message(am)
400 ao = Message(am)
366 assert ao.x == am['x']
401 assert ao.x == am['x']
367
402
368 am['y'] = dict(z=1)
403 am['y'] = dict(z=1)
369 ao = Message(am)
404 ao = Message(am)
370 assert ao.y.z == am['y']['z']
405 assert ao.y.z == am['y']['z']
371
406
372 k1, k2 = 'y', 'z'
407 k1, k2 = 'y', 'z'
373 assert ao[k1][k2] == am[k1][k2]
408 assert ao[k1][k2] == am[k1][k2]
374
409
375 am2 = dict(ao)
410 am2 = dict(ao)
376 assert am['x'] == am2['x']
411 assert am['x'] == am2['x']
377 assert am['y']['z'] == am2['y']['z']
412 assert am['y']['z'] == am2['y']['z']
@@ -1,44 +1,46 b''
1 """toplevel setup/teardown for parallel tests."""
1 """toplevel setup/teardown for parallel tests."""
2
2
3 import tempfile
3 import time
4 import time
4 from subprocess import Popen, PIPE
5 from subprocess import Popen, PIPE, STDOUT
5
6
6 from IPython.zmq.parallel.ipcluster import launch_process
7 from IPython.zmq.parallel.ipcluster import launch_process
7 from IPython.zmq.parallel.entry_point import select_random_ports
8 from IPython.zmq.parallel.entry_point import select_random_ports
8
9
9 processes = []
10 processes = []
11 blackhole = tempfile.TemporaryFile()
10
12
11 # nose setup/teardown
13 # nose setup/teardown
12
14
13 def setup():
15 def setup():
14 cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=PIPE, stdin=PIPE, stderr=PIPE)
16 cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=blackhole, stderr=STDOUT)
15 processes.append(cp)
17 processes.append(cp)
16 time.sleep(.5)
18 time.sleep(.5)
17 add_engine()
19 add_engine()
18 time.sleep(3)
20 time.sleep(2)
19
21
20 def add_engine(profile='iptest'):
22 def add_engine(profile='iptest'):
21 ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=PIPE, stdin=PIPE, stderr=PIPE)
23 ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=blackhole, stderr=STDOUT)
22 # ep.start()
24 # ep.start()
23 processes.append(ep)
25 processes.append(ep)
24 return ep
26 return ep
25
27
26 def teardown():
28 def teardown():
27 time.sleep(1)
29 time.sleep(1)
28 while processes:
30 while processes:
29 p = processes.pop()
31 p = processes.pop()
30 if p.poll() is None:
32 if p.poll() is None:
31 try:
33 try:
32 p.terminate()
34 p.terminate()
33 except Exception, e:
35 except Exception, e:
34 print e
36 print e
35 pass
37 pass
36 if p.poll() is None:
38 if p.poll() is None:
37 time.sleep(.25)
39 time.sleep(.25)
38 if p.poll() is None:
40 if p.poll() is None:
39 try:
41 try:
40 print 'killing'
42 print 'killing'
41 p.kill()
43 p.kill()
42 except:
44 except:
43 print "couldn't shutdown process: ",p
45 print "couldn't shutdown process: ",p
44
46
@@ -1,98 +1,100 b''
1 import time
1 import time
2 from signal import SIGINT
2 from signal import SIGINT
3 from multiprocessing import Process
3 from multiprocessing import Process
4
4
5 from nose import SkipTest
5 from nose import SkipTest
6
6
7 from zmq.tests import BaseZMQTestCase
7 from zmq.tests import BaseZMQTestCase
8
8
9 from IPython.external.decorator import decorator
9 from IPython.external.decorator import decorator
10
10
11 from IPython.zmq.parallel import error
11 from IPython.zmq.parallel import error
12 from IPython.zmq.parallel.client import Client
12 from IPython.zmq.parallel.client import Client
13 from IPython.zmq.parallel.ipcluster import launch_process
13 from IPython.zmq.parallel.ipcluster import launch_process
14 from IPython.zmq.parallel.entry_point import select_random_ports
14 from IPython.zmq.parallel.entry_point import select_random_ports
15 from IPython.zmq.parallel.tests import processes,add_engine
15 from IPython.zmq.parallel.tests import processes,add_engine
16
16
17 # simple tasks for use in apply tests
17 # simple tasks for use in apply tests
18
18
19 def segfault():
19 def segfault():
20 """this will segfault"""
20 """this will segfault"""
21 import ctypes
21 import ctypes
22 ctypes.memset(-1,0,1)
22 ctypes.memset(-1,0,1)
23
23
24 def wait(n):
24 def wait(n):
25 """sleep for a time"""
25 """sleep for a time"""
26 import time
26 import time
27 time.sleep(n)
27 time.sleep(n)
28 return n
28 return n
29
29
30 def raiser(eclass):
30 def raiser(eclass):
31 """raise an exception"""
31 """raise an exception"""
32 raise eclass()
32 raise eclass()
33
33
34 # test decorator for skipping tests when libraries are unavailable
34 # test decorator for skipping tests when libraries are unavailable
35 def skip_without(*names):
35 def skip_without(*names):
36 """skip a test if some names are not importable"""
36 """skip a test if some names are not importable"""
37 @decorator
37 @decorator
38 def skip_without_names(f, *args, **kwargs):
38 def skip_without_names(f, *args, **kwargs):
39 """decorator to skip tests in the absence of numpy."""
39 """decorator to skip tests in the absence of numpy."""
40 for name in names:
40 for name in names:
41 try:
41 try:
42 __import__(name)
42 __import__(name)
43 except ImportError:
43 except ImportError:
44 raise SkipTest
44 raise SkipTest
45 return f(*args, **kwargs)
45 return f(*args, **kwargs)
46 return skip_without_names
46 return skip_without_names
47
47
48
48
49 class ClusterTestCase(BaseZMQTestCase):
49 class ClusterTestCase(BaseZMQTestCase):
50
50
51 def add_engines(self, n=1, block=True):
51 def add_engines(self, n=1, block=True):
52 """add multiple engines to our cluster"""
52 """add multiple engines to our cluster"""
53 for i in range(n):
53 for i in range(n):
54 self.engines.append(add_engine())
54 self.engines.append(add_engine())
55 if block:
55 if block:
56 self.wait_on_engines()
56 self.wait_on_engines()
57
57
58 def wait_on_engines(self, timeout=5):
58 def wait_on_engines(self, timeout=5):
59 """wait for our engines to connect."""
59 """wait for our engines to connect."""
60 n = len(self.engines)+self.base_engine_count
60 n = len(self.engines)+self.base_engine_count
61 tic = time.time()
61 tic = time.time()
62 while time.time()-tic < timeout and len(self.client.ids) < n:
62 while time.time()-tic < timeout and len(self.client.ids) < n:
63 time.sleep(0.1)
63 time.sleep(0.1)
64
64
65 assert not self.client.ids < n, "waiting for engines timed out"
65 assert not self.client.ids < n, "waiting for engines timed out"
66
66
67 def connect_client(self):
67 def connect_client(self):
68 """connect a client with my Context, and track its sockets for cleanup"""
68 """connect a client with my Context, and track its sockets for cleanup"""
69 c = Client(profile='iptest',context=self.context)
69 c = Client(profile='iptest',context=self.context)
70 for name in filter(lambda n:n.endswith('socket'), dir(c)):
70 for name in filter(lambda n:n.endswith('socket'), dir(c)):
71 self.sockets.append(getattr(c, name))
71 self.sockets.append(getattr(c, name))
72 return c
72 return c
73
73
74 def assertRaisesRemote(self, etype, f, *args, **kwargs):
74 def assertRaisesRemote(self, etype, f, *args, **kwargs):
75 try:
75 try:
76 try:
76 try:
77 f(*args, **kwargs)
77 f(*args, **kwargs)
78 except error.CompositeError as e:
78 except error.CompositeError as e:
79 e.raise_exception()
79 e.raise_exception()
80 except error.RemoteError as e:
80 except error.RemoteError as e:
81 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(e.ename, etype.__name__))
81 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(e.ename, etype.__name__))
82 else:
82 else:
83 self.fail("should have raised a RemoteError")
83 self.fail("should have raised a RemoteError")
84
84
85 def setUp(self):
85 def setUp(self):
86 BaseZMQTestCase.setUp(self)
86 BaseZMQTestCase.setUp(self)
87 self.client = self.connect_client()
87 self.client = self.connect_client()
88 self.base_engine_count=len(self.client.ids)
88 self.base_engine_count=len(self.client.ids)
89 self.engines=[]
89 self.engines=[]
90
90
91 # def tearDown(self):
91 def tearDown(self):
92 self.client.close()
93 BaseZMQTestCase.tearDown(self)
92 # [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ]
94 # [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ]
93 # [ e.wait() for e in self.engines ]
95 # [ e.wait() for e in self.engines ]
94 # while len(self.client.ids) > self.base_engine_count:
96 # while len(self.client.ids) > self.base_engine_count:
95 # time.sleep(.1)
97 # time.sleep(.1)
96 # del self.engines
98 # del self.engines
97 # BaseZMQTestCase.tearDown(self)
99 # BaseZMQTestCase.tearDown(self)
98 No newline at end of file
100
@@ -1,187 +1,252 b''
1 import time
1 import time
2 from tempfile import mktemp
2 from tempfile import mktemp
3
3
4 import nose.tools as nt
4 import nose.tools as nt
5 import zmq
5
6
6 from IPython.zmq.parallel import client as clientmod
7 from IPython.zmq.parallel import client as clientmod
7 from IPython.zmq.parallel import error
8 from IPython.zmq.parallel import error
8 from IPython.zmq.parallel.asyncresult import AsyncResult, AsyncHubResult
9 from IPython.zmq.parallel.asyncresult import AsyncResult, AsyncHubResult
9 from IPython.zmq.parallel.view import LoadBalancedView, DirectView
10 from IPython.zmq.parallel.view import LoadBalancedView, DirectView
10
11
11 from clienttest import ClusterTestCase, segfault, wait
12 from clienttest import ClusterTestCase, segfault, wait
12
13
13 class TestClient(ClusterTestCase):
14 class TestClient(ClusterTestCase):
14
15
15 def test_ids(self):
16 def test_ids(self):
16 n = len(self.client.ids)
17 n = len(self.client.ids)
17 self.add_engines(3)
18 self.add_engines(3)
18 self.assertEquals(len(self.client.ids), n+3)
19 self.assertEquals(len(self.client.ids), n+3)
19 self.assertTrue
20 self.assertTrue
20
21
21 def test_segfault(self):
22 def test_segfault_task(self):
22 """test graceful handling of engine death"""
23 """test graceful handling of engine death (balanced)"""
23 self.add_engines(1)
24 self.add_engines(1)
24 eid = self.client.ids[-1]
25 ar = self.client.apply(segfault, block=False)
25 ar = self.client.apply(segfault, block=False)
26 self.assertRaisesRemote(error.EngineError, ar.get)
26 self.assertRaisesRemote(error.EngineError, ar.get)
27 eid = ar.engine_id
27 eid = ar.engine_id
28 while eid in self.client.ids:
28 while eid in self.client.ids:
29 time.sleep(.01)
29 time.sleep(.01)
30 self.client.spin()
30 self.client.spin()
31
31
32 def test_segfault_mux(self):
33 """test graceful handling of engine death (direct)"""
34 self.add_engines(1)
35 eid = self.client.ids[-1]
36 ar = self.client[eid].apply_async(segfault)
37 self.assertRaisesRemote(error.EngineError, ar.get)
38 eid = ar.engine_id
39 while eid in self.client.ids:
40 time.sleep(.01)
41 self.client.spin()
42
32 def test_view_indexing(self):
43 def test_view_indexing(self):
33 """test index access for views"""
44 """test index access for views"""
34 self.add_engines(2)
45 self.add_engines(2)
35 targets = self.client._build_targets('all')[-1]
46 targets = self.client._build_targets('all')[-1]
36 v = self.client[:]
47 v = self.client[:]
37 self.assertEquals(v.targets, targets)
48 self.assertEquals(v.targets, targets)
38 t = self.client.ids[2]
49 t = self.client.ids[2]
39 v = self.client[t]
50 v = self.client[t]
40 self.assert_(isinstance(v, DirectView))
51 self.assert_(isinstance(v, DirectView))
41 self.assertEquals(v.targets, t)
52 self.assertEquals(v.targets, t)
42 t = self.client.ids[2:4]
53 t = self.client.ids[2:4]
43 v = self.client[t]
54 v = self.client[t]
44 self.assert_(isinstance(v, DirectView))
55 self.assert_(isinstance(v, DirectView))
45 self.assertEquals(v.targets, t)
56 self.assertEquals(v.targets, t)
46 v = self.client[::2]
57 v = self.client[::2]
47 self.assert_(isinstance(v, DirectView))
58 self.assert_(isinstance(v, DirectView))
48 self.assertEquals(v.targets, targets[::2])
59 self.assertEquals(v.targets, targets[::2])
49 v = self.client[1::3]
60 v = self.client[1::3]
50 self.assert_(isinstance(v, DirectView))
61 self.assert_(isinstance(v, DirectView))
51 self.assertEquals(v.targets, targets[1::3])
62 self.assertEquals(v.targets, targets[1::3])
52 v = self.client[:-3]
63 v = self.client[:-3]
53 self.assert_(isinstance(v, DirectView))
64 self.assert_(isinstance(v, DirectView))
54 self.assertEquals(v.targets, targets[:-3])
65 self.assertEquals(v.targets, targets[:-3])
55 v = self.client[-1]
66 v = self.client[-1]
56 self.assert_(isinstance(v, DirectView))
67 self.assert_(isinstance(v, DirectView))
57 self.assertEquals(v.targets, targets[-1])
68 self.assertEquals(v.targets, targets[-1])
58 nt.assert_raises(TypeError, lambda : self.client[None])
69 nt.assert_raises(TypeError, lambda : self.client[None])
59
70
60 def test_view_cache(self):
71 def test_view_cache(self):
61 """test that multiple view requests return the same object"""
72 """test that multiple view requests return the same object"""
62 v = self.client[:2]
73 v = self.client[:2]
63 v2 =self.client[:2]
74 v2 =self.client[:2]
64 self.assertTrue(v is v2)
75 self.assertTrue(v is v2)
65 v = self.client.view()
76 v = self.client.view()
66 v2 = self.client.view(balanced=True)
77 v2 = self.client.view(balanced=True)
67 self.assertTrue(v is v2)
78 self.assertTrue(v is v2)
68
79
69 def test_targets(self):
80 def test_targets(self):
70 """test various valid targets arguments"""
81 """test various valid targets arguments"""
71 build = self.client._build_targets
82 build = self.client._build_targets
72 ids = self.client.ids
83 ids = self.client.ids
73 idents,targets = build(None)
84 idents,targets = build(None)
74 self.assertEquals(ids, targets)
85 self.assertEquals(ids, targets)
75
86
76 def test_clear(self):
87 def test_clear(self):
77 """test clear behavior"""
88 """test clear behavior"""
78 self.add_engines(2)
89 self.add_engines(2)
79 self.client.block=True
90 self.client.block=True
80 self.client.push(dict(a=5))
91 self.client.push(dict(a=5))
81 self.client.pull('a')
92 self.client.pull('a')
82 id0 = self.client.ids[-1]
93 id0 = self.client.ids[-1]
83 self.client.clear(targets=id0)
94 self.client.clear(targets=id0)
84 self.client.pull('a', targets=self.client.ids[:-1])
95 self.client.pull('a', targets=self.client.ids[:-1])
85 self.assertRaisesRemote(NameError, self.client.pull, 'a')
96 self.assertRaisesRemote(NameError, self.client.pull, 'a')
86 self.client.clear()
97 self.client.clear()
87 for i in self.client.ids:
98 for i in self.client.ids:
88 self.assertRaisesRemote(NameError, self.client.pull, 'a', targets=i)
99 self.assertRaisesRemote(NameError, self.client.pull, 'a', targets=i)
89
100
90
101
91 def test_push_pull(self):
102 def test_push_pull(self):
92 """test pushing and pulling"""
103 """test pushing and pulling"""
93 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
104 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
105 t = self.client.ids[-1]
94 self.add_engines(2)
106 self.add_engines(2)
95 push = self.client.push
107 push = self.client.push
96 pull = self.client.pull
108 pull = self.client.pull
97 self.client.block=True
109 self.client.block=True
98 nengines = len(self.client)
110 nengines = len(self.client)
99 push({'data':data}, targets=0)
111 push({'data':data}, targets=t)
100 d = pull('data', targets=0)
112 d = pull('data', targets=t)
101 self.assertEquals(d, data)
113 self.assertEquals(d, data)
102 push({'data':data})
114 push({'data':data})
103 d = pull('data')
115 d = pull('data')
104 self.assertEquals(d, nengines*[data])
116 self.assertEquals(d, nengines*[data])
105 ar = push({'data':data}, block=False)
117 ar = push({'data':data}, block=False)
106 self.assertTrue(isinstance(ar, AsyncResult))
118 self.assertTrue(isinstance(ar, AsyncResult))
107 r = ar.get()
119 r = ar.get()
108 ar = pull('data', block=False)
120 ar = pull('data', block=False)
109 self.assertTrue(isinstance(ar, AsyncResult))
121 self.assertTrue(isinstance(ar, AsyncResult))
110 r = ar.get()
122 r = ar.get()
111 self.assertEquals(r, nengines*[data])
123 self.assertEquals(r, nengines*[data])
112 push(dict(a=10,b=20))
124 push(dict(a=10,b=20))
113 r = pull(('a','b'))
125 r = pull(('a','b'))
114 self.assertEquals(r, nengines*[[10,20]])
126 self.assertEquals(r, nengines*[[10,20]])
115
127
116 def test_push_pull_function(self):
128 def test_push_pull_function(self):
117 "test pushing and pulling functions"
129 "test pushing and pulling functions"
118 def testf(x):
130 def testf(x):
119 return 2.0*x
131 return 2.0*x
120
132
121 self.add_engines(4)
133 self.add_engines(4)
134 t = self.client.ids[-1]
122 self.client.block=True
135 self.client.block=True
123 push = self.client.push
136 push = self.client.push
124 pull = self.client.pull
137 pull = self.client.pull
125 execute = self.client.execute
138 execute = self.client.execute
126 push({'testf':testf}, targets=0)
139 push({'testf':testf}, targets=t)
127 r = pull('testf', targets=0)
140 r = pull('testf', targets=t)
128 self.assertEqual(r(1.0), testf(1.0))
141 self.assertEqual(r(1.0), testf(1.0))
129 execute('r = testf(10)', targets=0)
142 execute('r = testf(10)', targets=t)
130 r = pull('r', targets=0)
143 r = pull('r', targets=t)
131 self.assertEquals(r, testf(10))
144 self.assertEquals(r, testf(10))
132 ar = push({'testf':testf}, block=False)
145 ar = push({'testf':testf}, block=False)
133 ar.get()
146 ar.get()
134 ar = pull('testf', block=False)
147 ar = pull('testf', block=False)
135 rlist = ar.get()
148 rlist = ar.get()
136 for r in rlist:
149 for r in rlist:
137 self.assertEqual(r(1.0), testf(1.0))
150 self.assertEqual(r(1.0), testf(1.0))
138 execute("def g(x): return x*x", targets=0)
151 execute("def g(x): return x*x", targets=t)
139 r = pull(('testf','g'),targets=0)
152 r = pull(('testf','g'),targets=t)
140 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
153 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
141
154
142 def test_push_function_globals(self):
155 def test_push_function_globals(self):
143 """test that pushed functions have access to globals"""
156 """test that pushed functions have access to globals"""
144 def geta():
157 def geta():
145 return a
158 return a
146 self.add_engines(1)
159 self.add_engines(1)
147 v = self.client[-1]
160 v = self.client[-1]
148 v.block=True
161 v.block=True
149 v['f'] = geta
162 v['f'] = geta
150 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
163 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
151 v.execute('a=5')
164 v.execute('a=5')
152 v.execute('b=f()')
165 v.execute('b=f()')
153 self.assertEquals(v['b'], 5)
166 self.assertEquals(v['b'], 5)
154
167
155 def test_get_result(self):
168 def test_get_result(self):
156 """test getting results from the Hub."""
169 """test getting results from the Hub."""
157 c = clientmod.Client(profile='iptest')
170 c = clientmod.Client(profile='iptest')
158 t = self.client.ids[-1]
171 t = self.client.ids[-1]
159 ar = c.apply(wait, (1,), block=False, targets=t)
172 ar = c.apply(wait, (1,), block=False, targets=t)
160 time.sleep(.25)
173 time.sleep(.25)
161 ahr = self.client.get_result(ar.msg_ids)
174 ahr = self.client.get_result(ar.msg_ids)
162 self.assertTrue(isinstance(ahr, AsyncHubResult))
175 self.assertTrue(isinstance(ahr, AsyncHubResult))
163 self.assertEquals(ahr.get(), ar.get())
176 self.assertEquals(ahr.get(), ar.get())
164 ar2 = self.client.get_result(ar.msg_ids)
177 ar2 = self.client.get_result(ar.msg_ids)
165 self.assertFalse(isinstance(ar2, AsyncHubResult))
178 self.assertFalse(isinstance(ar2, AsyncHubResult))
166
179
167 def test_ids_list(self):
180 def test_ids_list(self):
168 """test client.ids"""
181 """test client.ids"""
169 self.add_engines(2)
182 self.add_engines(2)
170 ids = self.client.ids
183 ids = self.client.ids
171 self.assertEquals(ids, self.client._ids)
184 self.assertEquals(ids, self.client._ids)
172 self.assertFalse(ids is self.client._ids)
185 self.assertFalse(ids is self.client._ids)
173 ids.remove(ids[-1])
186 ids.remove(ids[-1])
174 self.assertNotEquals(ids, self.client._ids)
187 self.assertNotEquals(ids, self.client._ids)
175
188
176 def test_arun_newline(self):
189 def test_run_newline(self):
177 """test that run appends newline to files"""
190 """test that run appends newline to files"""
178 tmpfile = mktemp()
191 tmpfile = mktemp()
179 with open(tmpfile, 'w') as f:
192 with open(tmpfile, 'w') as f:
180 f.write("""def g():
193 f.write("""def g():
181 return 5
194 return 5
182 """)
195 """)
183 v = self.client[-1]
196 v = self.client[-1]
184 v.run(tmpfile, block=True)
197 v.run(tmpfile, block=True)
185 self.assertEquals(v.apply_sync_bound(lambda : g()), 5)
198 self.assertEquals(v.apply_sync_bound(lambda : g()), 5)
186
199
187 No newline at end of file
200 def test_apply_tracked(self):
201 """test tracking for apply"""
202 # self.add_engines(1)
203 t = self.client.ids[-1]
204 self.client.block=False
205 def echo(n=1024*1024, **kwargs):
206 return self.client.apply(lambda x: x, args=('x'*n,), targets=t, **kwargs)
207 ar = echo(1)
208 self.assertTrue(ar._tracker is None)
209 self.assertTrue(ar.sent)
210 ar = echo(track=True)
211 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
212 self.assertEquals(ar.sent, ar._tracker.done)
213 ar._tracker.wait()
214 self.assertTrue(ar.sent)
215
216 def test_push_tracked(self):
217 t = self.client.ids[-1]
218 ns = dict(x='x'*1024*1024)
219 ar = self.client.push(ns, targets=t, block=False)
220 self.assertTrue(ar._tracker is None)
221 self.assertTrue(ar.sent)
222
223 ar = self.client.push(ns, targets=t, block=False, track=True)
224 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
225 self.assertEquals(ar.sent, ar._tracker.done)
226 ar._tracker.wait()
227 self.assertTrue(ar.sent)
228 ar.get()
229
230 def test_scatter_tracked(self):
231 t = self.client.ids
232 x='x'*1024*1024
233 ar = self.client.scatter('x', x, targets=t, block=False)
234 self.assertTrue(ar._tracker is None)
235 self.assertTrue(ar.sent)
236
237 ar = self.client.scatter('x', x, targets=t, block=False, track=True)
238 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
239 self.assertEquals(ar.sent, ar._tracker.done)
240 ar._tracker.wait()
241 self.assertTrue(ar.sent)
242 ar.get()
243
244 def test_remote_reference(self):
245 v = self.client[-1]
246 v['a'] = 123
247 ra = clientmod.Reference('a')
248 b = v.apply_sync_bound(lambda x: x, ra)
249 self.assertEquals(b, 123)
250 self.assertRaisesRemote(NameError, v.apply_sync, lambda x: x, ra)
251
252
@@ -1,82 +1,99 b''
1
1
2 import os
2 import os
3 import uuid
3 import uuid
4 import zmq
4 import zmq
5
5
6 from zmq.tests import BaseZMQTestCase
6 from zmq.tests import BaseZMQTestCase
7
7 from zmq.eventloop.zmqstream import ZMQStream
8 # from IPython.zmq.tests import SessionTestCase
8 # from IPython.zmq.tests import SessionTestCase
9 from IPython.zmq.parallel import streamsession as ss
9 from IPython.zmq.parallel import streamsession as ss
10
10
11 class SessionTestCase(BaseZMQTestCase):
11 class SessionTestCase(BaseZMQTestCase):
12
12
13 def setUp(self):
13 def setUp(self):
14 BaseZMQTestCase.setUp(self)
14 BaseZMQTestCase.setUp(self)
15 self.session = ss.StreamSession()
15 self.session = ss.StreamSession()
16
16
17 class TestSession(SessionTestCase):
17 class TestSession(SessionTestCase):
18
18
19 def test_msg(self):
19 def test_msg(self):
20 """message format"""
20 """message format"""
21 msg = self.session.msg('execute')
21 msg = self.session.msg('execute')
22 thekeys = set('header msg_id parent_header msg_type content'.split())
22 thekeys = set('header msg_id parent_header msg_type content'.split())
23 s = set(msg.keys())
23 s = set(msg.keys())
24 self.assertEquals(s, thekeys)
24 self.assertEquals(s, thekeys)
25 self.assertTrue(isinstance(msg['content'],dict))
25 self.assertTrue(isinstance(msg['content'],dict))
26 self.assertTrue(isinstance(msg['header'],dict))
26 self.assertTrue(isinstance(msg['header'],dict))
27 self.assertTrue(isinstance(msg['parent_header'],dict))
27 self.assertTrue(isinstance(msg['parent_header'],dict))
28 self.assertEquals(msg['msg_type'], 'execute')
28 self.assertEquals(msg['msg_type'], 'execute')
29
29
30
30
31
31
32 def test_args(self):
32 def test_args(self):
33 """initialization arguments for StreamSession"""
33 """initialization arguments for StreamSession"""
34 s = ss.StreamSession()
34 s = self.session
35 self.assertTrue(s.pack is ss.default_packer)
35 self.assertTrue(s.pack is ss.default_packer)
36 self.assertTrue(s.unpack is ss.default_unpacker)
36 self.assertTrue(s.unpack is ss.default_unpacker)
37 self.assertEquals(s.username, os.environ.get('USER', 'username'))
37 self.assertEquals(s.username, os.environ.get('USER', 'username'))
38
38
39 s = ss.StreamSession(username=None)
39 s = ss.StreamSession(username=None)
40 self.assertEquals(s.username, os.environ.get('USER', 'username'))
40 self.assertEquals(s.username, os.environ.get('USER', 'username'))
41
41
42 self.assertRaises(TypeError, ss.StreamSession, packer='hi')
42 self.assertRaises(TypeError, ss.StreamSession, packer='hi')
43 self.assertRaises(TypeError, ss.StreamSession, unpacker='hi')
43 self.assertRaises(TypeError, ss.StreamSession, unpacker='hi')
44 u = str(uuid.uuid4())
44 u = str(uuid.uuid4())
45 s = ss.StreamSession(username='carrot', session=u)
45 s = ss.StreamSession(username='carrot', session=u)
46 self.assertEquals(s.session, u)
46 self.assertEquals(s.session, u)
47 self.assertEquals(s.username, 'carrot')
47 self.assertEquals(s.username, 'carrot')
48
48
49
49 def test_tracking(self):
50 """test tracking messages"""
51 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
52 s = self.session
53 stream = ZMQStream(a)
54 msg = s.send(a, 'hello', track=False)
55 self.assertTrue(msg['tracker'] is None)
56 msg = s.send(a, 'hello', track=True)
57 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
58 M = zmq.Message(b'hi there', track=True)
59 msg = s.send(a, 'hello', buffers=[M], track=True)
60 t = msg['tracker']
61 self.assertTrue(isinstance(t, zmq.MessageTracker))
62 self.assertRaises(zmq.NotDone, t.wait, .1)
63 del M
64 t.wait(1) # this will raise
65
66
50 # def test_rekey(self):
67 # def test_rekey(self):
51 # """rekeying dict around json str keys"""
68 # """rekeying dict around json str keys"""
52 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
69 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
53 # self.assertRaises(KeyError, ss.rekey, d)
70 # self.assertRaises(KeyError, ss.rekey, d)
54 #
71 #
55 # d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
72 # d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
56 # d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
73 # d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
57 # rd = ss.rekey(d)
74 # rd = ss.rekey(d)
58 # self.assertEquals(d2,rd)
75 # self.assertEquals(d2,rd)
59 #
76 #
60 # d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
77 # d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
61 # d2 = {1.5:d['1.5'],1:d['1']}
78 # d2 = {1.5:d['1.5'],1:d['1']}
62 # rd = ss.rekey(d)
79 # rd = ss.rekey(d)
63 # self.assertEquals(d2,rd)
80 # self.assertEquals(d2,rd)
64 #
81 #
65 # d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
82 # d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
66 # self.assertRaises(KeyError, ss.rekey, d)
83 # self.assertRaises(KeyError, ss.rekey, d)
67 #
84 #
68 def test_unique_msg_ids(self):
85 def test_unique_msg_ids(self):
69 """test that messages receive unique ids"""
86 """test that messages receive unique ids"""
70 ids = set()
87 ids = set()
71 for i in range(2**12):
88 for i in range(2**12):
72 h = self.session.msg_header('test')
89 h = self.session.msg_header('test')
73 msg_id = h['msg_id']
90 msg_id = h['msg_id']
74 self.assertTrue(msg_id not in ids)
91 self.assertTrue(msg_id not in ids)
75 ids.add(msg_id)
92 ids.add(msg_id)
76
93
77 def test_feed_identities(self):
94 def test_feed_identities(self):
78 """scrub the front for zmq IDENTITIES"""
95 """scrub the front for zmq IDENTITIES"""
79 theids = "engine client other".split()
96 theids = "engine client other".split()
80 content = dict(code='whoda',stuff=object())
97 content = dict(code='whoda',stuff=object())
81 themsg = self.session.msg('execute',content=content)
98 themsg = self.session.msg('execute',content=content)
82 pmsg = theids
99 pmsg = theids
General Comments 0
You need to be logged in to leave comments. Login now