##// END OF EJS Templates
testing fixes
MinRK -
Show More
@@ -1,294 +1,294 b''
1 1 """AsyncResult objects for the client"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import time
14 14
15 15 from IPython.external.decorator import decorator
16 16 import error
17 17
18 18 #-----------------------------------------------------------------------------
19 19 # Classes
20 20 #-----------------------------------------------------------------------------
21 21
22 22 @decorator
23 23 def check_ready(f, self, *args, **kwargs):
24 24 """Call spin() to sync state prior to calling the method."""
25 25 self.wait(0)
26 26 if not self._ready:
27 27 raise error.TimeoutError("result not ready")
28 28 return f(self, *args, **kwargs)
29 29
30 30 class AsyncResult(object):
31 31 """Class for representing results of non-blocking calls.
32 32
33 33 Provides the same interface as :py:class:`multiprocessing.AsyncResult`.
34 34 """
35 35
36 36 msg_ids = None
37 37
38 def __init__(self, client, msg_ids, fname=''):
38 def __init__(self, client, msg_ids, fname='unknown'):
39 39 self._client = client
40 40 if isinstance(msg_ids, basestring):
41 41 msg_ids = [msg_ids]
42 42 self.msg_ids = msg_ids
43 43 self._fname=fname
44 44 self._ready = False
45 45 self._success = None
46 46 self._single_result = len(msg_ids) == 1
47 47
48 48 def __repr__(self):
49 49 if self._ready:
50 50 return "<%s: finished>"%(self.__class__.__name__)
51 51 else:
52 52 return "<%s: %s>"%(self.__class__.__name__,self._fname)
53 53
54 54
55 55 def _reconstruct_result(self, res):
56 56 """
57 57 Override me in subclasses for turning a list of results
58 58 into the expected form.
59 59 """
60 60 if self._single_result:
61 61 return res[0]
62 62 else:
63 63 return res
64 64
65 65 def get(self, timeout=-1):
66 66 """Return the result when it arrives.
67 67
68 68 If `timeout` is not ``None`` and the result does not arrive within
69 69 `timeout` seconds then ``TimeoutError`` is raised. If the
70 70 remote call raised an exception then that exception will be reraised
71 71 by get().
72 72 """
73 73 if not self.ready():
74 74 self.wait(timeout)
75 75
76 76 if self._ready:
77 77 if self._success:
78 78 return self._result
79 79 else:
80 80 raise self._exception
81 81 else:
82 82 raise error.TimeoutError("Result not ready.")
83 83
84 84 def ready(self):
85 85 """Return whether the call has completed."""
86 86 if not self._ready:
87 87 self.wait(0)
88 88 return self._ready
89 89
90 90 def wait(self, timeout=-1):
91 91 """Wait until the result is available or until `timeout` seconds pass.
92 92 """
93 93 if self._ready:
94 94 return
95 95 self._ready = self._client.barrier(self.msg_ids, timeout)
96 96 if self._ready:
97 97 try:
98 98 results = map(self._client.results.get, self.msg_ids)
99 99 self._result = results
100 100 if self._single_result:
101 101 r = results[0]
102 102 if isinstance(r, Exception):
103 103 raise r
104 104 else:
105 105 results = error.collect_exceptions(results, self._fname)
106 106 self._result = self._reconstruct_result(results)
107 107 except Exception, e:
108 108 self._exception = e
109 109 self._success = False
110 110 else:
111 111 self._success = True
112 112 finally:
113 113 self._metadata = map(self._client.metadata.get, self.msg_ids)
114 114
115 115
116 116 def successful(self):
117 117 """Return whether the call completed without raising an exception.
118 118
119 119 Will raise ``AssertionError`` if the result is not ready.
120 120 """
121 121 assert self._ready
122 122 return self._success
123 123
124 124 #----------------------------------------------------------------
125 125 # Extra methods not in mp.pool.AsyncResult
126 126 #----------------------------------------------------------------
127 127
128 128 def get_dict(self, timeout=-1):
129 129 """Get the results as a dict, keyed by engine_id."""
130 130 results = self.get(timeout)
131 131 engine_ids = [ md['engine_id'] for md in self._metadata ]
132 132 bycount = sorted(engine_ids, key=lambda k: engine_ids.count(k))
133 133 maxcount = bycount.count(bycount[-1])
134 134 if maxcount > 1:
135 135 raise ValueError("Cannot build dict, %i jobs ran on engine #%i"%(
136 136 maxcount, bycount[-1]))
137 137
138 138 return dict(zip(engine_ids,results))
139 139
140 140 @property
141 141 @check_ready
142 142 def result(self):
143 143 """result property."""
144 144 return self._result
145 145
146 146 # abbreviated alias:
147 147 r = result
148 148
149 149 @property
150 150 @check_ready
151 151 def metadata(self):
152 152 """metadata property."""
153 153 if self._single_result:
154 154 return self._metadata[0]
155 155 else:
156 156 return self._metadata
157 157
158 158 @property
159 159 def result_dict(self):
160 160 """result property as a dict."""
161 161 return self.get_dict(0)
162 162
163 163 def __dict__(self):
164 164 return self.get_dict(0)
165 165
166 166 #-------------------------------------
167 167 # dict-access
168 168 #-------------------------------------
169 169
170 170 @check_ready
171 171 def __getitem__(self, key):
172 172 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
173 173 """
174 174 if isinstance(key, int):
175 175 return error.collect_exceptions([self._result[key]], self._fname)[0]
176 176 elif isinstance(key, slice):
177 177 return error.collect_exceptions(self._result[key], self._fname)
178 178 elif isinstance(key, basestring):
179 179 values = [ md[key] for md in self._metadata ]
180 180 if self._single_result:
181 181 return values[0]
182 182 else:
183 183 return values
184 184 else:
185 185 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
186 186
187 187 @check_ready
188 188 def __getattr__(self, key):
189 189 """getattr maps to getitem for convenient access to metadata."""
190 190 if key not in self._metadata[0].keys():
191 191 raise AttributeError("%r object has no attribute %r"%(
192 192 self.__class__.__name__, key))
193 193 return self.__getitem__(key)
194 194
195 195 # asynchronous iterator:
196 196 def __iter__(self):
197 197 if self._single_result:
198 198 raise TypeError("AsyncResults with a single result are not iterable.")
199 199 try:
200 200 rlist = self.get(0)
201 201 except error.TimeoutError:
202 202 # wait for each result individually
203 203 for msg_id in self.msg_ids:
204 204 ar = AsyncResult(self._client, msg_id, self._fname)
205 205 yield ar.get()
206 206 else:
207 207 # already done
208 208 for r in rlist:
209 209 yield r
210 210
211 211
212 212
213 213 class AsyncMapResult(AsyncResult):
214 214 """Class for representing results of non-blocking gathers.
215 215
216 216 This will properly reconstruct the gather.
217 217 """
218 218
219 219 def __init__(self, client, msg_ids, mapObject, fname=''):
220 220 AsyncResult.__init__(self, client, msg_ids, fname=fname)
221 221 self._mapObject = mapObject
222 222 self._single_result = False
223 223
224 224 def _reconstruct_result(self, res):
225 225 """Perform the gather on the actual results."""
226 226 return self._mapObject.joinPartitions(res)
227 227
228 228 # asynchronous iterator:
229 229 def __iter__(self):
230 230 try:
231 231 rlist = self.get(0)
232 232 except error.TimeoutError:
233 233 # wait for each result individually
234 234 for msg_id in self.msg_ids:
235 235 ar = AsyncResult(self._client, msg_id, self._fname)
236 236 rlist = ar.get()
237 237 try:
238 238 for r in rlist:
239 239 yield r
240 240 except TypeError:
241 241 # flattened, not a list
242 242 # this could get broken by flattened data that returns iterables
243 243 # but most calls to map do not expose the `flatten` argument
244 244 yield rlist
245 245 else:
246 246 # already done
247 247 for r in rlist:
248 248 yield r
249 249
250 250
251 251 class AsyncHubResult(AsyncResult):
252 252 """Class to wrap pending results that must be requested from the Hub"""
253 253
254 254 def wait(self, timeout=-1):
255 255 """wait for result to complete."""
256 256 start = time.time()
257 257 if self._ready:
258 258 return
259 259 local_ids = filter(lambda msg_id: msg_id in self._client.outstanding, self.msg_ids)
260 260 local_ready = self._client.barrier(local_ids, timeout)
261 261 if local_ready:
262 262 remote_ids = filter(lambda msg_id: msg_id not in self._client.results, self.msg_ids)
263 263 if not remote_ids:
264 264 self._ready = True
265 265 else:
266 266 rdict = self._client.result_status(remote_ids, status_only=False)
267 267 pending = rdict['pending']
268 while pending and time.time() < start+timeout:
268 while pending and (timeout < 0 or time.time() < start+timeout):
269 269 rdict = self._client.result_status(remote_ids, status_only=False)
270 270 pending = rdict['pending']
271 271 if pending:
272 272 time.sleep(0.1)
273 273 if not pending:
274 274 self._ready = True
275 275 if self._ready:
276 276 try:
277 277 results = map(self._client.results.get, self.msg_ids)
278 278 self._result = results
279 279 if self._single_result:
280 280 r = results[0]
281 281 if isinstance(r, Exception):
282 282 raise r
283 283 else:
284 284 results = error.collect_exceptions(results, self._fname)
285 285 self._result = self._reconstruct_result(results)
286 286 except Exception, e:
287 287 self._exception = e
288 288 self._success = False
289 289 else:
290 290 self._success = True
291 291 finally:
292 292 self._metadata = map(self._client.metadata.get, self.msg_ids)
293 293
294 294 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult'] No newline at end of file
@@ -1,1497 +1,1498 b''
1 1 """A semi-synchronous Client for the ZMQ controller"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import os
14 14 import json
15 15 import time
16 16 import warnings
17 17 from datetime import datetime
18 18 from getpass import getpass
19 19 from pprint import pprint
20 20
21 21 pjoin = os.path.join
22 22
23 23 import zmq
24 24 # from zmq.eventloop import ioloop, zmqstream
25 25
26 26 from IPython.utils.path import get_ipython_dir
27 27 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
28 28 Dict, List, Bool, Str, Set)
29 29 from IPython.external.decorator import decorator
30 30 from IPython.external.ssh import tunnel
31 31
32 32 import error
33 33 import map as Map
34 34 import streamsession as ss
35 35 from asyncresult import AsyncResult, AsyncMapResult, AsyncHubResult
36 36 from clusterdir import ClusterDir, ClusterDirError
37 37 from dependency import Dependency, depend, require, dependent
38 38 from remotefunction import remote,parallel,ParallelFunction,RemoteFunction
39 39 from util import ReverseDict, disambiguate_url, validate_url
40 40 from view import DirectView, LoadBalancedView
41 41
42 42 #--------------------------------------------------------------------------
43 43 # helpers for implementing old MEC API via client.apply
44 44 #--------------------------------------------------------------------------
45 45
46 46 def _push(ns):
47 47 """helper method for implementing `client.push` via `client.apply`"""
48 48 globals().update(ns)
49 49
50 50 def _pull(keys):
51 51 """helper method for implementing `client.pull` via `client.apply`"""
52 52 g = globals()
53 53 if isinstance(keys, (list,tuple, set)):
54 54 for key in keys:
55 55 if not g.has_key(key):
56 56 raise NameError("name '%s' is not defined"%key)
57 57 return map(g.get, keys)
58 58 else:
59 59 if not g.has_key(keys):
60 60 raise NameError("name '%s' is not defined"%keys)
61 61 return g.get(keys)
62 62
63 63 def _clear():
64 64 """helper method for implementing `client.clear` via `client.apply`"""
65 65 globals().clear()
66 66
67 67 def _execute(code):
68 68 """helper method for implementing `client.execute` via `client.apply`"""
69 69 exec code in globals()
70 70
71 71
72 72 #--------------------------------------------------------------------------
73 73 # Decorators for Client methods
74 74 #--------------------------------------------------------------------------
75 75
76 76 @decorator
77 77 def spinfirst(f, self, *args, **kwargs):
78 78 """Call spin() to sync state prior to calling the method."""
79 79 self.spin()
80 80 return f(self, *args, **kwargs)
81 81
82 82 @decorator
83 83 def defaultblock(f, self, *args, **kwargs):
84 84 """Default to self.block; preserve self.block."""
85 85 block = kwargs.get('block',None)
86 86 block = self.block if block is None else block
87 87 saveblock = self.block
88 88 self.block = block
89 89 try:
90 90 ret = f(self, *args, **kwargs)
91 91 finally:
92 92 self.block = saveblock
93 93 return ret
94 94
95 95
96 96 #--------------------------------------------------------------------------
97 97 # Classes
98 98 #--------------------------------------------------------------------------
99 99
100 100 class Metadata(dict):
101 101 """Subclass of dict for initializing metadata values.
102 102
103 103 Attribute access works on keys.
104 104
105 105 These objects have a strict set of keys - errors will raise if you try
106 106 to add new keys.
107 107 """
108 108 def __init__(self, *args, **kwargs):
109 109 dict.__init__(self)
110 110 md = {'msg_id' : None,
111 111 'submitted' : None,
112 112 'started' : None,
113 113 'completed' : None,
114 114 'received' : None,
115 115 'engine_uuid' : None,
116 116 'engine_id' : None,
117 117 'follow' : None,
118 118 'after' : None,
119 119 'status' : None,
120 120
121 121 'pyin' : None,
122 122 'pyout' : None,
123 123 'pyerr' : None,
124 124 'stdout' : '',
125 125 'stderr' : '',
126 126 }
127 127 self.update(md)
128 128 self.update(dict(*args, **kwargs))
129 129
130 130 def __getattr__(self, key):
131 131 """getattr aliased to getitem"""
132 132 if key in self.iterkeys():
133 133 return self[key]
134 134 else:
135 135 raise AttributeError(key)
136 136
137 137 def __setattr__(self, key, value):
138 138 """setattr aliased to setitem, with strict"""
139 139 if key in self.iterkeys():
140 140 self[key] = value
141 141 else:
142 142 raise AttributeError(key)
143 143
144 144 def __setitem__(self, key, value):
145 145 """strict static key enforcement"""
146 146 if key in self.iterkeys():
147 147 dict.__setitem__(self, key, value)
148 148 else:
149 149 raise KeyError(key)
150 150
151 151
152 152 class Client(HasTraits):
153 153 """A semi-synchronous client to the IPython ZMQ controller
154 154
155 155 Parameters
156 156 ----------
157 157
158 158 url_or_file : bytes; zmq url or path to ipcontroller-client.json
159 159 Connection information for the Hub's registration. If a json connector
160 160 file is given, then likely no further configuration is necessary.
161 161 [Default: use profile]
162 162 profile : bytes
163 163 The name of the Cluster profile to be used to find connector information.
164 164 [Default: 'default']
165 165 context : zmq.Context
166 166 Pass an existing zmq.Context instance, otherwise the client will create its own.
167 167 username : bytes
168 168 set username to be passed to the Session object
169 169 debug : bool
170 170 flag for lots of message printing for debug purposes
171 171
172 172 #-------------- ssh related args ----------------
173 173 # These are args for configuring the ssh tunnel to be used
174 174 # credentials are used to forward connections over ssh to the Controller
175 175 # Note that the ip given in `addr` needs to be relative to sshserver
176 176 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
177 177 # and set sshserver as the same machine the Controller is on. However,
178 178 # the only requirement is that sshserver is able to see the Controller
179 179 # (i.e. is within the same trusted network).
180 180
181 181 sshserver : str
182 182 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
183 183 If keyfile or password is specified, and this is not, it will default to
184 184 the ip given in addr.
185 185 sshkey : str; path to public ssh key file
186 186 This specifies a key to be used in ssh login, default None.
187 187 Regular default ssh keys will be used without specifying this argument.
188 188 password : str
189 189 Your ssh password to sshserver. Note that if this is left None,
190 190 you will be prompted for it if passwordless key based login is unavailable.
191 191 paramiko : bool
192 192 flag for whether to use paramiko instead of shell ssh for tunneling.
193 193 [default: True on win32, False else]
194 194
195 195 #------- exec authentication args -------
196 196 # If even localhost is untrusted, you can have some protection against
197 197 # unauthorized execution by using a key. Messages are still sent
198 198 # as cleartext, so if someone can snoop your loopback traffic this will
199 199 # not help against malicious attacks.
200 200
201 201 exec_key : str
202 202 an authentication key or file containing a key
203 203 default: None
204 204
205 205
206 206 Attributes
207 207 ----------
208 208
209 209 ids : set of int engine IDs
210 210 requesting the ids attribute always synchronizes
211 211 the registration state. To request ids without synchronization,
212 212 use semi-private _ids attributes.
213 213
214 214 history : list of msg_ids
215 215 a list of msg_ids, keeping track of all the execution
216 216 messages you have submitted in order.
217 217
218 218 outstanding : set of msg_ids
219 219 a set of msg_ids that have been submitted, but whose
220 220 results have not yet been received.
221 221
222 222 results : dict
223 223 a dict of all our results, keyed by msg_id
224 224
225 225 block : bool
226 226 determines default behavior when block not specified
227 227 in execution methods
228 228
229 229 Methods
230 230 -------
231 231
232 232 spin
233 233 flushes incoming results and registration state changes
234 234 control methods spin, and requesting `ids` also ensures up to date
235 235
236 236 barrier
237 237 wait on one or more msg_ids
238 238
239 239 execution methods
240 240 apply
241 241 legacy: execute, run
242 242
243 243 query methods
244 244 queue_status, get_result, purge
245 245
246 246 control methods
247 247 abort, shutdown
248 248
249 249 """
250 250
251 251
252 252 block = Bool(False)
253 253 outstanding=Set()
254 254 results = Dict()
255 255 metadata = Dict()
256 256 history = List()
257 257 debug = Bool(False)
258 258 profile=CUnicode('default')
259 259
260 260 _ids = List()
261 261 _connected=Bool(False)
262 262 _ssh=Bool(False)
263 263 _context = Instance('zmq.Context')
264 264 _config = Dict()
265 265 _engines=Instance(ReverseDict, (), {})
266 266 _registration_socket=Instance('zmq.Socket')
267 267 _query_socket=Instance('zmq.Socket')
268 268 _control_socket=Instance('zmq.Socket')
269 269 _iopub_socket=Instance('zmq.Socket')
270 270 _notification_socket=Instance('zmq.Socket')
271 271 _mux_socket=Instance('zmq.Socket')
272 272 _task_socket=Instance('zmq.Socket')
273 273 _task_scheme=Str()
274 274 _balanced_views=Dict()
275 275 _direct_views=Dict()
276 276 _closed = False
277 277
278 278 def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
279 279 context=None, username=None, debug=False, exec_key=None,
280 280 sshserver=None, sshkey=None, password=None, paramiko=None,
281 281 ):
282 282 super(Client, self).__init__(debug=debug, profile=profile)
283 283 if context is None:
284 284 context = zmq.Context()
285 285 self._context = context
286 286
287 287
288 288 self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
289 289 if self._cd is not None:
290 290 if url_or_file is None:
291 291 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
292 292 assert url_or_file is not None, "I can't find enough information to connect to a controller!"\
293 293 " Please specify at least one of url_or_file or profile."
294 294
295 295 try:
296 296 validate_url(url_or_file)
297 297 except AssertionError:
298 298 if not os.path.exists(url_or_file):
299 299 if self._cd:
300 300 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
301 301 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
302 302 with open(url_or_file) as f:
303 303 cfg = json.loads(f.read())
304 304 else:
305 305 cfg = {'url':url_or_file}
306 306
307 307 # sync defaults from args, json:
308 308 if sshserver:
309 309 cfg['ssh'] = sshserver
310 310 if exec_key:
311 311 cfg['exec_key'] = exec_key
312 312 exec_key = cfg['exec_key']
313 313 sshserver=cfg['ssh']
314 314 url = cfg['url']
315 315 location = cfg.setdefault('location', None)
316 316 cfg['url'] = disambiguate_url(cfg['url'], location)
317 317 url = cfg['url']
318 318
319 319 self._config = cfg
320 320
321 321 self._ssh = bool(sshserver or sshkey or password)
322 322 if self._ssh and sshserver is None:
323 323 # default to ssh via localhost
324 324 sshserver = url.split('://')[1].split(':')[0]
325 325 if self._ssh and password is None:
326 326 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
327 327 password=False
328 328 else:
329 329 password = getpass("SSH Password for %s: "%sshserver)
330 330 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
331 331 if exec_key is not None and os.path.isfile(exec_key):
332 332 arg = 'keyfile'
333 333 else:
334 334 arg = 'key'
335 335 key_arg = {arg:exec_key}
336 336 if username is None:
337 337 self.session = ss.StreamSession(**key_arg)
338 338 else:
339 339 self.session = ss.StreamSession(username, **key_arg)
340 340 self._registration_socket = self._context.socket(zmq.XREQ)
341 341 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
342 342 if self._ssh:
343 343 tunnel.tunnel_connection(self._registration_socket, url, sshserver, **ssh_kwargs)
344 344 else:
345 345 self._registration_socket.connect(url)
346 346
347 347 self.session.debug = self.debug
348 348
349 349 self._notification_handlers = {'registration_notification' : self._register_engine,
350 350 'unregistration_notification' : self._unregister_engine,
351 351 }
352 352 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
353 353 'apply_reply' : self._handle_apply_reply}
354 354 self._connect(sshserver, ssh_kwargs)
355 355
356 356
357 357 def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
358 358 if ipython_dir is None:
359 359 ipython_dir = get_ipython_dir()
360 360 if cluster_dir is not None:
361 361 try:
362 362 self._cd = ClusterDir.find_cluster_dir(cluster_dir)
363 return
363 364 except ClusterDirError:
364 365 pass
365 366 elif profile is not None:
366 367 try:
367 368 self._cd = ClusterDir.find_cluster_dir_by_profile(
368 369 ipython_dir, profile)
370 return
369 371 except ClusterDirError:
370 372 pass
371 else:
372 self._cd = None
373 self._cd = None
373 374
374 375 @property
375 376 def ids(self):
376 377 """Always up-to-date ids property."""
377 378 self._flush_notifications()
378 379 return self._ids
379 380
380 381 def close(self):
381 382 if self._closed:
382 383 return
383 384 snames = filter(lambda n: n.endswith('socket'), dir(self))
384 385 for socket in map(lambda name: getattr(self, name), snames):
385 386 socket.close()
386 387 self._closed = True
387 388
388 389 def _update_engines(self, engines):
389 390 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
390 391 for k,v in engines.iteritems():
391 392 eid = int(k)
392 393 self._engines[eid] = bytes(v) # force not unicode
393 394 self._ids.append(eid)
394 395 self._ids = sorted(self._ids)
395 396 if sorted(self._engines.keys()) != range(len(self._engines)) and \
396 397 self._task_scheme == 'pure' and self._task_socket:
397 398 self._stop_scheduling_tasks()
398 399
399 400 def _stop_scheduling_tasks(self):
400 401 """Stop scheduling tasks because an engine has been unregistered
401 402 from a pure ZMQ scheduler.
402 403 """
403 404
404 405 self._task_socket.close()
405 406 self._task_socket = None
406 407 msg = "An engine has been unregistered, and we are using pure " +\
407 408 "ZMQ task scheduling. Task farming will be disabled."
408 409 if self.outstanding:
409 410 msg += " If you were running tasks when this happened, " +\
410 411 "some `outstanding` msg_ids may never resolve."
411 412 warnings.warn(msg, RuntimeWarning)
412 413
413 414 def _build_targets(self, targets):
414 415 """Turn valid target IDs or 'all' into two lists:
415 416 (int_ids, uuids).
416 417 """
417 418 if targets is None:
418 419 targets = self._ids
419 420 elif isinstance(targets, str):
420 421 if targets.lower() == 'all':
421 422 targets = self._ids
422 423 else:
423 424 raise TypeError("%r not valid str target, must be 'all'"%(targets))
424 425 elif isinstance(targets, int):
425 426 targets = [targets]
426 427 return [self._engines[t] for t in targets], list(targets)
427 428
428 429 def _connect(self, sshserver, ssh_kwargs):
429 430 """setup all our socket connections to the controller. This is called from
430 431 __init__."""
431 432
432 433 # Maybe allow reconnecting?
433 434 if self._connected:
434 435 return
435 436 self._connected=True
436 437
437 438 def connect_socket(s, url):
438 439 url = disambiguate_url(url, self._config['location'])
439 440 if self._ssh:
440 441 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
441 442 else:
442 443 return s.connect(url)
443 444
444 445 self.session.send(self._registration_socket, 'connection_request')
445 446 idents,msg = self.session.recv(self._registration_socket,mode=0)
446 447 if self.debug:
447 448 pprint(msg)
448 449 msg = ss.Message(msg)
449 450 content = msg.content
450 451 self._config['registration'] = dict(content)
451 452 if content.status == 'ok':
452 453 if content.mux:
453 454 self._mux_socket = self._context.socket(zmq.PAIR)
454 455 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
455 456 connect_socket(self._mux_socket, content.mux)
456 457 if content.task:
457 458 self._task_scheme, task_addr = content.task
458 459 self._task_socket = self._context.socket(zmq.PAIR)
459 460 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
460 461 connect_socket(self._task_socket, task_addr)
461 462 if content.notification:
462 463 self._notification_socket = self._context.socket(zmq.SUB)
463 464 connect_socket(self._notification_socket, content.notification)
464 465 self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
465 466 if content.query:
466 467 self._query_socket = self._context.socket(zmq.PAIR)
467 468 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
468 469 connect_socket(self._query_socket, content.query)
469 470 if content.control:
470 471 self._control_socket = self._context.socket(zmq.PAIR)
471 472 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
472 473 connect_socket(self._control_socket, content.control)
473 474 if content.iopub:
474 475 self._iopub_socket = self._context.socket(zmq.SUB)
475 476 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, '')
476 477 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
477 478 connect_socket(self._iopub_socket, content.iopub)
478 479 self._update_engines(dict(content.engines))
479 480
480 481 else:
481 482 self._connected = False
482 483 raise Exception("Failed to connect!")
483 484
484 485 #--------------------------------------------------------------------------
485 486 # handlers and callbacks for incoming messages
486 487 #--------------------------------------------------------------------------
487 488
488 489 def _unwrap_exception(self, content):
489 490 """unwrap exception, and remap engineid to int."""
490 491 e = ss.unwrap_exception(content)
491 492 if e.engine_info:
492 e_uuid = e.engine_info['engineid']
493 e_uuid = e.engine_info['engine_uuid']
493 494 eid = self._engines[e_uuid]
494 e.engine_info['engineid'] = eid
495 e.engine_info['engine_id'] = eid
495 496 return e
496 497
497 498 def _register_engine(self, msg):
498 499 """Register a new engine, and update our connection info."""
499 500 content = msg['content']
500 501 eid = content['id']
501 502 d = {eid : content['queue']}
502 503 self._update_engines(d)
503 504
504 505 def _unregister_engine(self, msg):
505 506 """Unregister an engine that has died."""
506 507 content = msg['content']
507 508 eid = int(content['id'])
508 509 if eid in self._ids:
509 510 self._ids.remove(eid)
510 511 self._engines.pop(eid)
511 512 if self._task_socket and self._task_scheme == 'pure':
512 513 self._stop_scheduling_tasks()
513 514
514 515 def _extract_metadata(self, header, parent, content):
515 516 md = {'msg_id' : parent['msg_id'],
516 517 'received' : datetime.now(),
517 518 'engine_uuid' : header.get('engine', None),
518 519 'follow' : parent.get('follow', []),
519 520 'after' : parent.get('after', []),
520 521 'status' : content['status'],
521 522 }
522 523
523 524 if md['engine_uuid'] is not None:
524 525 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
525 526
526 527 if 'date' in parent:
527 528 md['submitted'] = datetime.strptime(parent['date'], ss.ISO8601)
528 529 if 'started' in header:
529 530 md['started'] = datetime.strptime(header['started'], ss.ISO8601)
530 531 if 'date' in header:
531 532 md['completed'] = datetime.strptime(header['date'], ss.ISO8601)
532 533 return md
533 534
534 535 def _handle_execute_reply(self, msg):
535 536 """Save the reply to an execute_request into our results.
536 537
537 538 execute messages are never actually used. apply is used instead.
538 539 """
539 540
540 541 parent = msg['parent_header']
541 542 msg_id = parent['msg_id']
542 543 if msg_id not in self.outstanding:
543 544 if msg_id in self.history:
544 545 print ("got stale result: %s"%msg_id)
545 546 else:
546 547 print ("got unknown result: %s"%msg_id)
547 548 else:
548 549 self.outstanding.remove(msg_id)
549 550 self.results[msg_id] = self._unwrap_exception(msg['content'])
550 551
551 552 def _handle_apply_reply(self, msg):
552 553 """Save the reply to an apply_request into our results."""
553 554 parent = msg['parent_header']
554 555 msg_id = parent['msg_id']
555 556 if msg_id not in self.outstanding:
556 557 if msg_id in self.history:
557 558 print ("got stale result: %s"%msg_id)
558 559 print self.results[msg_id]
559 560 print msg
560 561 else:
561 562 print ("got unknown result: %s"%msg_id)
562 563 else:
563 564 self.outstanding.remove(msg_id)
564 565 content = msg['content']
565 566 header = msg['header']
566 567
567 568 # construct metadata:
568 569 md = self.metadata.setdefault(msg_id, Metadata())
569 570 md.update(self._extract_metadata(header, parent, content))
570 571 self.metadata[msg_id] = md
571 572
572 573 # construct result:
573 574 if content['status'] == 'ok':
574 575 self.results[msg_id] = ss.unserialize_object(msg['buffers'])[0]
575 576 elif content['status'] == 'aborted':
576 577 self.results[msg_id] = error.AbortedTask(msg_id)
577 578 elif content['status'] == 'resubmitted':
578 579 # TODO: handle resubmission
579 580 pass
580 581 else:
581 582 self.results[msg_id] = self._unwrap_exception(content)
582 583
583 584 def _flush_notifications(self):
584 585 """Flush notifications of engine registrations waiting
585 586 in ZMQ queue."""
586 587 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
587 588 while msg is not None:
588 589 if self.debug:
589 590 pprint(msg)
590 591 msg = msg[-1]
591 592 msg_type = msg['msg_type']
592 593 handler = self._notification_handlers.get(msg_type, None)
593 594 if handler is None:
594 595 raise Exception("Unhandled message type: %s"%msg.msg_type)
595 596 else:
596 597 handler(msg)
597 598 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
598 599
599 600 def _flush_results(self, sock):
600 601 """Flush task or queue results waiting in ZMQ queue."""
601 602 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
602 603 while msg is not None:
603 604 if self.debug:
604 605 pprint(msg)
605 606 msg = msg[-1]
606 607 msg_type = msg['msg_type']
607 608 handler = self._queue_handlers.get(msg_type, None)
608 609 if handler is None:
609 610 raise Exception("Unhandled message type: %s"%msg.msg_type)
610 611 else:
611 612 handler(msg)
612 613 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
613 614
614 615 def _flush_control(self, sock):
615 616 """Flush replies from the control channel waiting
616 617 in the ZMQ queue.
617 618
618 619 Currently: ignore them."""
619 620 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
620 621 while msg is not None:
621 622 if self.debug:
622 623 pprint(msg)
623 624 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
624 625
625 626 def _flush_iopub(self, sock):
626 627 """Flush replies from the iopub channel waiting
627 628 in the ZMQ queue.
628 629 """
629 630 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
630 631 while msg is not None:
631 632 if self.debug:
632 633 pprint(msg)
633 634 msg = msg[-1]
634 635 parent = msg['parent_header']
635 636 msg_id = parent['msg_id']
636 637 content = msg['content']
637 638 header = msg['header']
638 639 msg_type = msg['msg_type']
639 640
640 641 # init metadata:
641 642 md = self.metadata.setdefault(msg_id, Metadata())
642 643
643 644 if msg_type == 'stream':
644 645 name = content['name']
645 646 s = md[name] or ''
646 647 md[name] = s + content['data']
647 648 elif msg_type == 'pyerr':
648 649 md.update({'pyerr' : self._unwrap_exception(content)})
649 650 else:
650 651 md.update({msg_type : content['data']})
651 652
652 653 self.metadata[msg_id] = md
653 654
654 655 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
655 656
656 657 #--------------------------------------------------------------------------
657 658 # len, getitem
658 659 #--------------------------------------------------------------------------
659 660
660 661 def __len__(self):
661 662 """len(client) returns # of engines."""
662 663 return len(self.ids)
663 664
664 665 def __getitem__(self, key):
665 666 """index access returns DirectView multiplexer objects
666 667
667 668 Must be int, slice, or list/tuple/xrange of ints"""
668 669 if not isinstance(key, (int, slice, tuple, list, xrange)):
669 670 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
670 671 else:
671 672 return self.view(key, balanced=False)
672 673
673 674 #--------------------------------------------------------------------------
674 675 # Begin public methods
675 676 #--------------------------------------------------------------------------
676 677
677 678 def spin(self):
678 679 """Flush any registration notifications and execution results
679 680 waiting in the ZMQ queue.
680 681 """
681 682 if self._notification_socket:
682 683 self._flush_notifications()
683 684 if self._mux_socket:
684 685 self._flush_results(self._mux_socket)
685 686 if self._task_socket:
686 687 self._flush_results(self._task_socket)
687 688 if self._control_socket:
688 689 self._flush_control(self._control_socket)
689 690 if self._iopub_socket:
690 691 self._flush_iopub(self._iopub_socket)
691 692
692 693 def barrier(self, jobs=None, timeout=-1):
693 694 """waits on one or more `jobs`, for up to `timeout` seconds.
694 695
695 696 Parameters
696 697 ----------
697 698
698 699 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
699 700 ints are indices to self.history
700 701 strs are msg_ids
701 702 default: wait on all outstanding messages
702 703 timeout : float
703 704 a time in seconds, after which to give up.
704 705 default is -1, which means no timeout
705 706
706 707 Returns
707 708 -------
708 709
709 710 True : when all msg_ids are done
710 711 False : timeout reached, some msg_ids still outstanding
711 712 """
712 713 tic = time.time()
713 714 if jobs is None:
714 715 theids = self.outstanding
715 716 else:
716 717 if isinstance(jobs, (int, str, AsyncResult)):
717 718 jobs = [jobs]
718 719 theids = set()
719 720 for job in jobs:
720 721 if isinstance(job, int):
721 722 # index access
722 723 job = self.history[job]
723 724 elif isinstance(job, AsyncResult):
724 725 map(theids.add, job.msg_ids)
725 726 continue
726 727 theids.add(job)
727 728 if not theids.intersection(self.outstanding):
728 729 return True
729 730 self.spin()
730 731 while theids.intersection(self.outstanding):
731 732 if timeout >= 0 and ( time.time()-tic ) > timeout:
732 733 break
733 734 time.sleep(1e-3)
734 735 self.spin()
735 736 return len(theids.intersection(self.outstanding)) == 0
736 737
737 738 #--------------------------------------------------------------------------
738 739 # Control methods
739 740 #--------------------------------------------------------------------------
740 741
741 742 @spinfirst
742 743 @defaultblock
743 744 def clear(self, targets=None, block=None):
744 745 """Clear the namespace in target(s)."""
745 746 targets = self._build_targets(targets)[0]
746 747 for t in targets:
747 748 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
748 749 error = False
749 750 if self.block:
750 751 for i in range(len(targets)):
751 752 idents,msg = self.session.recv(self._control_socket,0)
752 753 if self.debug:
753 754 pprint(msg)
754 755 if msg['content']['status'] != 'ok':
755 756 error = self._unwrap_exception(msg['content'])
756 757 if error:
757 758 return error
758 759
759 760
760 761 @spinfirst
761 762 @defaultblock
762 763 def abort(self, jobs=None, targets=None, block=None):
763 764 """Abort specific jobs from the execution queues of target(s).
764 765
765 766 This is a mechanism to prevent jobs that have already been submitted
766 767 from executing.
767 768
768 769 Parameters
769 770 ----------
770 771
771 772 jobs : msg_id, list of msg_ids, or AsyncResult
772 773 The jobs to be aborted
773 774
774 775
775 776 """
776 777 targets = self._build_targets(targets)[0]
777 778 msg_ids = []
778 779 if isinstance(jobs, (basestring,AsyncResult)):
779 780 jobs = [jobs]
780 781 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
781 782 if bad_ids:
782 783 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
783 784 for j in jobs:
784 785 if isinstance(j, AsyncResult):
785 786 msg_ids.extend(j.msg_ids)
786 787 else:
787 788 msg_ids.append(j)
788 789 content = dict(msg_ids=msg_ids)
789 790 for t in targets:
790 791 self.session.send(self._control_socket, 'abort_request',
791 792 content=content, ident=t)
792 793 error = False
793 794 if self.block:
794 795 for i in range(len(targets)):
795 796 idents,msg = self.session.recv(self._control_socket,0)
796 797 if self.debug:
797 798 pprint(msg)
798 799 if msg['content']['status'] != 'ok':
799 800 error = self._unwrap_exception(msg['content'])
800 801 if error:
801 802 return error
802 803
803 804 @spinfirst
804 805 @defaultblock
805 806 def shutdown(self, targets=None, restart=False, controller=False, block=None):
806 807 """Terminates one or more engine processes, optionally including the controller."""
807 808 if controller:
808 809 targets = 'all'
809 810 targets = self._build_targets(targets)[0]
810 811 for t in targets:
811 812 self.session.send(self._control_socket, 'shutdown_request',
812 813 content={'restart':restart},ident=t)
813 814 error = False
814 815 if block or controller:
815 816 for i in range(len(targets)):
816 817 idents,msg = self.session.recv(self._control_socket,0)
817 818 if self.debug:
818 819 pprint(msg)
819 820 if msg['content']['status'] != 'ok':
820 821 error = self._unwrap_exception(msg['content'])
821 822
822 823 if controller:
823 824 time.sleep(0.25)
824 825 self.session.send(self._query_socket, 'shutdown_request')
825 826 idents,msg = self.session.recv(self._query_socket, 0)
826 827 if self.debug:
827 828 pprint(msg)
828 829 if msg['content']['status'] != 'ok':
829 830 error = self._unwrap_exception(msg['content'])
830 831
831 832 if error:
832 833 raise error
833 834
834 835 #--------------------------------------------------------------------------
835 836 # Execution methods
836 837 #--------------------------------------------------------------------------
837 838
838 839 @defaultblock
839 840 def execute(self, code, targets='all', block=None):
840 841 """Executes `code` on `targets` in blocking or nonblocking manner.
841 842
842 843 ``execute`` is always `bound` (affects engine namespace)
843 844
844 845 Parameters
845 846 ----------
846 847
847 848 code : str
848 849 the code string to be executed
849 850 targets : int/str/list of ints/strs
850 851 the engines on which to execute
851 852 default : all
852 853 block : bool
853 854 whether or not to wait until done to return
854 855 default: self.block
855 856 """
856 857 result = self.apply(_execute, (code,), targets=targets, block=block, bound=True, balanced=False)
857 858 if not block:
858 859 return result
859 860
860 861 def run(self, filename, targets='all', block=None):
861 862 """Execute contents of `filename` on engine(s).
862 863
863 864 This simply reads the contents of the file and calls `execute`.
864 865
865 866 Parameters
866 867 ----------
867 868
868 869 filename : str
869 870 The path to the file
870 871 targets : int/str/list of ints/strs
871 872 the engines on which to execute
872 873 default : all
873 874 block : bool
874 875 whether or not to wait until done
875 876 default: self.block
876 877
877 878 """
878 879 with open(filename, 'rb') as f:
879 880 code = f.read()
880 881 return self.execute(code, targets=targets, block=block)
881 882
882 883 def _maybe_raise(self, result):
883 884 """wrapper for maybe raising an exception if apply failed."""
884 885 if isinstance(result, error.RemoteError):
885 886 raise result
886 887
887 888 return result
888 889
889 890 def _build_dependency(self, dep):
890 891 """helper for building jsonable dependencies from various input forms"""
891 892 if isinstance(dep, Dependency):
892 893 return dep.as_dict()
893 894 elif isinstance(dep, AsyncResult):
894 895 return dep.msg_ids
895 896 elif dep is None:
896 897 return []
897 898 else:
898 899 # pass to Dependency constructor
899 900 return list(Dependency(dep))
900 901
901 902 @defaultblock
902 903 def apply(self, f, args=None, kwargs=None, bound=True, block=None,
903 904 targets=None, balanced=None,
904 905 after=None, follow=None, timeout=None):
905 906 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
906 907
907 908 This is the central execution command for the client.
908 909
909 910 Parameters
910 911 ----------
911 912
912 913 f : function
913 914 The fuction to be called remotely
914 915 args : tuple/list
915 916 The positional arguments passed to `f`
916 917 kwargs : dict
917 918 The keyword arguments passed to `f`
918 919 bound : bool (default: True)
919 920 Whether to execute in the Engine(s) namespace, or in a clean
920 921 namespace not affecting the engine.
921 922 block : bool (default: self.block)
922 923 Whether to wait for the result, or return immediately.
923 924 False:
924 925 returns AsyncResult
925 926 True:
926 927 returns actual result(s) of f(*args, **kwargs)
927 928 if multiple targets:
928 929 list of results, matching `targets`
929 930 targets : int,list of ints, 'all', None
930 931 Specify the destination of the job.
931 932 if None:
932 933 Submit via Task queue for load-balancing.
933 934 if 'all':
934 935 Run on all active engines
935 936 if list:
936 937 Run on each specified engine
937 938 if int:
938 939 Run on single engine
939 940
940 941 balanced : bool, default None
941 942 whether to load-balance. This will default to True
942 943 if targets is unspecified, or False if targets is specified.
943 944
944 945 The following arguments are only used when balanced is True:
945 946 after : Dependency or collection of msg_ids
946 947 Only for load-balanced execution (targets=None)
947 948 Specify a list of msg_ids as a time-based dependency.
948 949 This job will only be run *after* the dependencies
949 950 have been met.
950 951
951 952 follow : Dependency or collection of msg_ids
952 953 Only for load-balanced execution (targets=None)
953 954 Specify a list of msg_ids as a location-based dependency.
954 955 This job will only be run on an engine where this dependency
955 956 is met.
956 957
957 958 timeout : float/int or None
958 959 Only for load-balanced execution (targets=None)
959 960 Specify an amount of time (in seconds) for the scheduler to
960 961 wait for dependencies to be met before failing with a
961 962 DependencyTimeout.
962 963
963 964 after,follow,timeout only used if `balanced=True`.
964 965
965 966 Returns
966 967 -------
967 968
968 969 if block is False:
969 970 return AsyncResult wrapping msg_ids
970 971 output of AsyncResult.get() is identical to that of `apply(...block=True)`
971 972 else:
972 973 if single target:
973 974 return result of `f(*args, **kwargs)`
974 975 else:
975 976 return list of results, matching `targets`
976 977 """
977 978 assert not self._closed, "cannot use me anymore, I'm closed!"
978 979 # defaults:
979 980 block = block if block is not None else self.block
980 981 args = args if args is not None else []
981 982 kwargs = kwargs if kwargs is not None else {}
982 983
983 984 if balanced is None:
984 985 if targets is None:
985 986 # default to balanced if targets unspecified
986 987 balanced = True
987 988 else:
988 989 # otherwise default to multiplexing
989 990 balanced = False
990 991
991 992 if targets is None and balanced is False:
992 993 # default to all if *not* balanced, and targets is unspecified
993 994 targets = 'all'
994 995
995 996 # enforce types of f,args,kwrags
996 997 if not callable(f):
997 998 raise TypeError("f must be callable, not %s"%type(f))
998 999 if not isinstance(args, (tuple, list)):
999 1000 raise TypeError("args must be tuple or list, not %s"%type(args))
1000 1001 if not isinstance(kwargs, dict):
1001 1002 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1002 1003
1003 1004 options = dict(bound=bound, block=block, targets=targets)
1004 1005
1005 1006 if balanced:
1006 1007 return self._apply_balanced(f, args, kwargs, timeout=timeout,
1007 1008 after=after, follow=follow, **options)
1008 1009 elif follow or after or timeout:
1009 1010 msg = "follow, after, and timeout args are only used for"
1010 1011 msg += " load-balanced execution."
1011 1012 raise ValueError(msg)
1012 1013 else:
1013 1014 return self._apply_direct(f, args, kwargs, **options)
1014 1015
1015 1016 def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
1016 1017 after=None, follow=None, timeout=None):
1017 1018 """call f(*args, **kwargs) remotely in a load-balanced manner.
1018 1019
1019 1020 This is a private method, see `apply` for details.
1020 1021 Not to be called directly!
1021 1022 """
1022 1023
1023 1024 loc = locals()
1024 1025 for name in ('bound', 'block'):
1025 1026 assert loc[name] is not None, "kwarg %r must be specified!"%name
1026 1027
1027 1028 if self._task_socket is None:
1028 1029 msg = "Task farming is disabled"
1029 1030 if self._task_scheme == 'pure':
1030 1031 msg += " because the pure ZMQ scheduler cannot handle"
1031 1032 msg += " disappearing engines."
1032 1033 raise RuntimeError(msg)
1033 1034
1034 1035 if self._task_scheme == 'pure':
1035 1036 # pure zmq scheme doesn't support dependencies
1036 1037 msg = "Pure ZMQ scheduler doesn't support dependencies"
1037 1038 if (follow or after):
1038 1039 # hard fail on DAG dependencies
1039 1040 raise RuntimeError(msg)
1040 1041 if isinstance(f, dependent):
1041 1042 # soft warn on functional dependencies
1042 1043 warnings.warn(msg, RuntimeWarning)
1043 1044
1044 1045 # defaults:
1045 1046 args = args if args is not None else []
1046 1047 kwargs = kwargs if kwargs is not None else {}
1047 1048
1048 1049 if targets:
1049 1050 idents,_ = self._build_targets(targets)
1050 1051 else:
1051 1052 idents = []
1052 1053
1053 1054 after = self._build_dependency(after)
1054 1055 follow = self._build_dependency(follow)
1055 1056 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
1056 1057 bufs = ss.pack_apply_message(f,args,kwargs)
1057 1058 content = dict(bound=bound)
1058 1059
1059 1060 msg = self.session.send(self._task_socket, "apply_request",
1060 1061 content=content, buffers=bufs, subheader=subheader)
1061 1062 msg_id = msg['msg_id']
1062 1063 self.outstanding.add(msg_id)
1063 1064 self.history.append(msg_id)
1064 1065 ar = AsyncResult(self, [msg_id], fname=f.__name__)
1065 1066 if block:
1066 1067 try:
1067 1068 return ar.get()
1068 1069 except KeyboardInterrupt:
1069 1070 return ar
1070 1071 else:
1071 1072 return ar
1072 1073
1073 1074 def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None):
1074 1075 """Then underlying method for applying functions to specific engines
1075 1076 via the MUX queue.
1076 1077
1077 1078 This is a private method, see `apply` for details.
1078 1079 Not to be called directly!
1079 1080 """
1080 1081 loc = locals()
1081 1082 for name in ('bound', 'block', 'targets'):
1082 1083 assert loc[name] is not None, "kwarg %r must be specified!"%name
1083 1084
1084 1085 idents,targets = self._build_targets(targets)
1085 1086
1086 1087 subheader = {}
1087 1088 content = dict(bound=bound)
1088 1089 bufs = ss.pack_apply_message(f,args,kwargs)
1089 1090
1090 1091 msg_ids = []
1091 1092 for ident in idents:
1092 1093 msg = self.session.send(self._mux_socket, "apply_request",
1093 1094 content=content, buffers=bufs, ident=ident, subheader=subheader)
1094 1095 msg_id = msg['msg_id']
1095 1096 self.outstanding.add(msg_id)
1096 1097 self.history.append(msg_id)
1097 1098 msg_ids.append(msg_id)
1098 1099 ar = AsyncResult(self, msg_ids, fname=f.__name__)
1099 1100 if block:
1100 1101 try:
1101 1102 return ar.get()
1102 1103 except KeyboardInterrupt:
1103 1104 return ar
1104 1105 else:
1105 1106 return ar
1106 1107
1107 1108 #--------------------------------------------------------------------------
1108 1109 # construct a View object
1109 1110 #--------------------------------------------------------------------------
1110 1111
1111 1112 @defaultblock
1112 1113 def remote(self, bound=True, block=None, targets=None, balanced=None):
1113 1114 """Decorator for making a RemoteFunction"""
1114 1115 return remote(self, bound=bound, targets=targets, block=block, balanced=balanced)
1115 1116
1116 1117 @defaultblock
1117 1118 def parallel(self, dist='b', bound=True, block=None, targets=None, balanced=None):
1118 1119 """Decorator for making a ParallelFunction"""
1119 1120 return parallel(self, bound=bound, targets=targets, block=block, balanced=balanced)
1120 1121
1121 1122 def _cache_view(self, targets, balanced):
1122 1123 """save views, so subsequent requests don't create new objects."""
1123 1124 if balanced:
1124 1125 view_class = LoadBalancedView
1125 1126 view_cache = self._balanced_views
1126 1127 else:
1127 1128 view_class = DirectView
1128 1129 view_cache = self._direct_views
1129 1130
1130 1131 # use str, since often targets will be a list
1131 1132 key = str(targets)
1132 1133 if key not in view_cache:
1133 1134 view_cache[key] = view_class(client=self, targets=targets)
1134 1135
1135 1136 return view_cache[key]
1136 1137
1137 1138 def view(self, targets=None, balanced=None):
1138 1139 """Method for constructing View objects.
1139 1140
1140 1141 If no arguments are specified, create a LoadBalancedView
1141 1142 using all engines. If only `targets` specified, it will
1142 1143 be a DirectView. This method is the underlying implementation
1143 1144 of ``client.__getitem__``.
1144 1145
1145 1146 Parameters
1146 1147 ----------
1147 1148
1148 1149 targets: list,slice,int,etc. [default: use all engines]
1149 1150 The engines to use for the View
1150 1151 balanced : bool [default: False if targets specified, True else]
1151 1152 whether to build a LoadBalancedView or a DirectView
1152 1153
1153 1154 """
1154 1155
1155 1156 balanced = (targets is None) if balanced is None else balanced
1156 1157
1157 1158 if targets is None:
1158 1159 if balanced:
1159 1160 return self._cache_view(None,True)
1160 1161 else:
1161 1162 targets = slice(None)
1162 1163
1163 1164 if isinstance(targets, int):
1164 1165 if targets < 0:
1165 1166 targets = self.ids[targets]
1166 1167 if targets not in self.ids:
1167 1168 raise IndexError("No such engine: %i"%targets)
1168 1169 return self._cache_view(targets, balanced)
1169 1170
1170 1171 if isinstance(targets, slice):
1171 1172 indices = range(len(self.ids))[targets]
1172 1173 ids = sorted(self._ids)
1173 1174 targets = [ ids[i] for i in indices ]
1174 1175
1175 1176 if isinstance(targets, (tuple, list, xrange)):
1176 1177 _,targets = self._build_targets(list(targets))
1177 1178 return self._cache_view(targets, balanced)
1178 1179 else:
1179 1180 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
1180 1181
1181 1182 #--------------------------------------------------------------------------
1182 1183 # Data movement
1183 1184 #--------------------------------------------------------------------------
1184 1185
1185 1186 @defaultblock
1186 1187 def push(self, ns, targets='all', block=None):
1187 1188 """Push the contents of `ns` into the namespace on `target`"""
1188 1189 if not isinstance(ns, dict):
1189 1190 raise TypeError("Must be a dict, not %s"%type(ns))
1190 1191 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False)
1191 1192 if not block:
1192 1193 return result
1193 1194
1194 1195 @defaultblock
1195 1196 def pull(self, keys, targets='all', block=None):
1196 1197 """Pull objects from `target`'s namespace by `keys`"""
1197 1198 if isinstance(keys, str):
1198 1199 pass
1199 1200 elif isinstance(keys, (list,tuple,set)):
1200 1201 for key in keys:
1201 1202 if not isinstance(key, str):
1202 1203 raise TypeError
1203 1204 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True, balanced=False)
1204 1205 return result
1205 1206
1206 1207 @defaultblock
1207 1208 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
1208 1209 """
1209 1210 Partition a Python sequence and send the partitions to a set of engines.
1210 1211 """
1211 1212 targets = self._build_targets(targets)[-1]
1212 1213 mapObject = Map.dists[dist]()
1213 1214 nparts = len(targets)
1214 1215 msg_ids = []
1215 1216 for index, engineid in enumerate(targets):
1216 1217 partition = mapObject.getPartition(seq, index, nparts)
1217 1218 if flatten and len(partition) == 1:
1218 1219 r = self.push({key: partition[0]}, targets=engineid, block=False)
1219 1220 else:
1220 1221 r = self.push({key: partition}, targets=engineid, block=False)
1221 1222 msg_ids.extend(r.msg_ids)
1222 1223 r = AsyncResult(self, msg_ids, fname='scatter')
1223 1224 if block:
1224 1225 r.get()
1225 1226 else:
1226 1227 return r
1227 1228
1228 1229 @defaultblock
1229 1230 def gather(self, key, dist='b', targets='all', block=None):
1230 1231 """
1231 1232 Gather a partitioned sequence on a set of engines as a single local seq.
1232 1233 """
1233 1234
1234 1235 targets = self._build_targets(targets)[-1]
1235 1236 mapObject = Map.dists[dist]()
1236 1237 msg_ids = []
1237 1238 for index, engineid in enumerate(targets):
1238 1239 msg_ids.extend(self.pull(key, targets=engineid,block=False).msg_ids)
1239 1240
1240 1241 r = AsyncMapResult(self, msg_ids, mapObject, fname='gather')
1241 1242 if block:
1242 1243 return r.get()
1243 1244 else:
1244 1245 return r
1245 1246
1246 1247 #--------------------------------------------------------------------------
1247 1248 # Query methods
1248 1249 #--------------------------------------------------------------------------
1249 1250
1250 1251 @spinfirst
1251 1252 @defaultblock
1252 1253 def get_result(self, indices_or_msg_ids=None, block=None):
1253 1254 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1254 1255
1255 1256 If the client already has the results, no request to the Hub will be made.
1256 1257
1257 1258 This is a convenient way to construct AsyncResult objects, which are wrappers
1258 1259 that include metadata about execution, and allow for awaiting results that
1259 1260 were not submitted by this Client.
1260 1261
1261 1262 It can also be a convenient way to retrieve the metadata associated with
1262 1263 blocking execution, since it always retrieves
1263 1264
1264 1265 Examples
1265 1266 --------
1266 1267 ::
1267 1268
1268 1269 In [10]: r = client.apply()
1269 1270
1270 1271 Parameters
1271 1272 ----------
1272 1273
1273 1274 indices_or_msg_ids : integer history index, str msg_id, or list of either
1274 1275 The indices or msg_ids of indices to be retrieved
1275 1276
1276 1277 block : bool
1277 1278 Whether to wait for the result to be done
1278 1279
1279 1280 Returns
1280 1281 -------
1281 1282
1282 1283 AsyncResult
1283 1284 A single AsyncResult object will always be returned.
1284 1285
1285 1286 AsyncHubResult
1286 1287 A subclass of AsyncResult that retrieves results from the Hub
1287 1288
1288 1289 """
1289 1290 if indices_or_msg_ids is None:
1290 1291 indices_or_msg_ids = -1
1291 1292
1292 1293 if not isinstance(indices_or_msg_ids, (list,tuple)):
1293 1294 indices_or_msg_ids = [indices_or_msg_ids]
1294 1295
1295 1296 theids = []
1296 1297 for id in indices_or_msg_ids:
1297 1298 if isinstance(id, int):
1298 1299 id = self.history[id]
1299 1300 if not isinstance(id, str):
1300 1301 raise TypeError("indices must be str or int, not %r"%id)
1301 1302 theids.append(id)
1302 1303
1303 1304 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1304 1305 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1305 1306
1306 1307 if remote_ids:
1307 1308 ar = AsyncHubResult(self, msg_ids=theids)
1308 1309 else:
1309 1310 ar = AsyncResult(self, msg_ids=theids)
1310 1311
1311 1312 if block:
1312 1313 ar.wait()
1313 1314
1314 1315 return ar
1315 1316
1316 1317 @spinfirst
1317 1318 def result_status(self, msg_ids, status_only=True):
1318 1319 """Check on the status of the result(s) of the apply request with `msg_ids`.
1319 1320
1320 1321 If status_only is False, then the actual results will be retrieved, else
1321 1322 only the status of the results will be checked.
1322 1323
1323 1324 Parameters
1324 1325 ----------
1325 1326
1326 1327 msg_ids : list of msg_ids
1327 1328 if int:
1328 1329 Passed as index to self.history for convenience.
1329 1330 status_only : bool (default: True)
1330 1331 if False:
1331 1332 Retrieve the actual results of completed tasks.
1332 1333
1333 1334 Returns
1334 1335 -------
1335 1336
1336 1337 results : dict
1337 1338 There will always be the keys 'pending' and 'completed', which will
1338 1339 be lists of msg_ids that are incomplete or complete. If `status_only`
1339 1340 is False, then completed results will be keyed by their `msg_id`.
1340 1341 """
1341 if not isinstance(indices_or_msg_ids, (list,tuple)):
1342 indices_or_msg_ids = [indices_or_msg_ids]
1342 if not isinstance(msg_ids, (list,tuple)):
1343 indices_or_msg_ids = [msg_ids]
1343 1344
1344 1345 theids = []
1345 for msg_id in indices_or_msg_ids:
1346 for msg_id in msg_ids:
1346 1347 if isinstance(msg_id, int):
1347 1348 msg_id = self.history[msg_id]
1348 1349 if not isinstance(msg_id, basestring):
1349 1350 raise TypeError("msg_ids must be str, not %r"%msg_id)
1350 1351 theids.append(msg_id)
1351 1352
1352 1353 completed = []
1353 1354 local_results = {}
1354 1355
1355 1356 # comment this block out to temporarily disable local shortcut:
1356 1357 for msg_id in theids:
1357 1358 if msg_id in self.results:
1358 1359 completed.append(msg_id)
1359 1360 local_results[msg_id] = self.results[msg_id]
1360 1361 theids.remove(msg_id)
1361 1362
1362 1363 if theids: # some not locally cached
1363 1364 content = dict(msg_ids=theids, status_only=status_only)
1364 1365 msg = self.session.send(self._query_socket, "result_request", content=content)
1365 1366 zmq.select([self._query_socket], [], [])
1366 1367 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1367 1368 if self.debug:
1368 1369 pprint(msg)
1369 1370 content = msg['content']
1370 1371 if content['status'] != 'ok':
1371 1372 raise self._unwrap_exception(content)
1372 1373 buffers = msg['buffers']
1373 1374 else:
1374 1375 content = dict(completed=[],pending=[])
1375 1376
1376 1377 content['completed'].extend(completed)
1377 1378
1378 1379 if status_only:
1379 1380 return content
1380 1381
1381 1382 failures = []
1382 1383 # load cached results into result:
1383 1384 content.update(local_results)
1384 1385 # update cache with results:
1385 1386 for msg_id in sorted(theids):
1386 1387 if msg_id in content['completed']:
1387 1388 rec = content[msg_id]
1388 1389 parent = rec['header']
1389 1390 header = rec['result_header']
1390 1391 rcontent = rec['result_content']
1391 1392 iodict = rec['io']
1392 1393 if isinstance(rcontent, str):
1393 1394 rcontent = self.session.unpack(rcontent)
1394 1395
1395 1396 md = self.metadata.setdefault(msg_id, Metadata())
1396 1397 md.update(self._extract_metadata(header, parent, rcontent))
1397 1398 md.update(iodict)
1398 1399
1399 1400 if rcontent['status'] == 'ok':
1400 1401 res,buffers = ss.unserialize_object(buffers)
1401 1402 else:
1402 1403 print rcontent
1403 1404 res = self._unwrap_exception(rcontent)
1404 1405 failures.append(res)
1405 1406
1406 1407 self.results[msg_id] = res
1407 1408 content[msg_id] = res
1408 1409
1409 1410 if len(theids) == 1 and failures:
1410 1411 raise failures[0]
1411 1412
1412 1413 error.collect_exceptions(failures, "result_status")
1413 1414 return content
1414 1415
1415 1416 @spinfirst
1416 1417 def queue_status(self, targets='all', verbose=False):
1417 1418 """Fetch the status of engine queues.
1418 1419
1419 1420 Parameters
1420 1421 ----------
1421 1422
1422 1423 targets : int/str/list of ints/strs
1423 1424 the engines whose states are to be queried.
1424 1425 default : all
1425 1426 verbose : bool
1426 1427 Whether to return lengths only, or lists of ids for each element
1427 1428 """
1428 1429 targets = self._build_targets(targets)[1]
1429 1430 content = dict(targets=targets, verbose=verbose)
1430 1431 self.session.send(self._query_socket, "queue_request", content=content)
1431 1432 idents,msg = self.session.recv(self._query_socket, 0)
1432 1433 if self.debug:
1433 1434 pprint(msg)
1434 1435 content = msg['content']
1435 1436 status = content.pop('status')
1436 1437 if status != 'ok':
1437 1438 raise self._unwrap_exception(content)
1438 1439 return ss.rekey(content)
1439 1440
1440 1441 @spinfirst
1441 1442 def purge_results(self, jobs=[], targets=[]):
1442 1443 """Tell the controller to forget results.
1443 1444
1444 1445 Individual results can be purged by msg_id, or the entire
1445 1446 history of specific targets can be purged.
1446 1447
1447 1448 Parameters
1448 1449 ----------
1449 1450
1450 1451 jobs : str or list of strs or AsyncResult objects
1451 1452 the msg_ids whose results should be forgotten.
1452 1453 targets : int/str/list of ints/strs
1453 1454 The targets, by uuid or int_id, whose entire history is to be purged.
1454 1455 Use `targets='all'` to scrub everything from the controller's memory.
1455 1456
1456 1457 default : None
1457 1458 """
1458 1459 if not targets and not jobs:
1459 1460 raise ValueError("Must specify at least one of `targets` and `jobs`")
1460 1461 if targets:
1461 1462 targets = self._build_targets(targets)[1]
1462 1463
1463 1464 # construct msg_ids from jobs
1464 1465 msg_ids = []
1465 1466 if isinstance(jobs, (basestring,AsyncResult)):
1466 1467 jobs = [jobs]
1467 1468 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1468 1469 if bad_ids:
1469 1470 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1470 1471 for j in jobs:
1471 1472 if isinstance(j, AsyncResult):
1472 1473 msg_ids.extend(j.msg_ids)
1473 1474 else:
1474 1475 msg_ids.append(j)
1475 1476
1476 1477 content = dict(targets=targets, msg_ids=msg_ids)
1477 1478 self.session.send(self._query_socket, "purge_request", content=content)
1478 1479 idents, msg = self.session.recv(self._query_socket, 0)
1479 1480 if self.debug:
1480 1481 pprint(msg)
1481 1482 content = msg['content']
1482 1483 if content['status'] != 'ok':
1483 1484 raise self._unwrap_exception(content)
1484 1485
1485 1486
1486 1487 __all__ = [ 'Client',
1487 1488 'depend',
1488 1489 'require',
1489 1490 'remote',
1490 1491 'parallel',
1491 1492 'RemoteFunction',
1492 1493 'ParallelFunction',
1493 1494 'DirectView',
1494 1495 'LoadBalancedView',
1495 1496 'AsyncResult',
1496 1497 'AsyncMapResult'
1497 1498 ]
@@ -1,292 +1,292 b''
1 1 # encoding: utf-8
2 2
3 3 """Classes and functions for kernel related errors and exceptions."""
4 4 from __future__ import print_function
5 5
6 6 __docformat__ = "restructuredtext en"
7 7
8 8 # Tell nose to skip this module
9 9 __test__ = {}
10 10
11 11 #-------------------------------------------------------------------------------
12 12 # Copyright (C) 2008 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-------------------------------------------------------------------------------
17 17
18 18 #-------------------------------------------------------------------------------
19 19 # Error classes
20 20 #-------------------------------------------------------------------------------
21 21 class IPythonError(Exception):
22 22 """Base exception that all of our exceptions inherit from.
23 23
24 24 This can be raised by code that doesn't have any more specific
25 25 information."""
26 26
27 27 pass
28 28
29 29 # Exceptions associated with the controller objects
30 30 class ControllerError(IPythonError): pass
31 31
32 32 class ControllerCreationError(ControllerError): pass
33 33
34 34
35 35 # Exceptions associated with the Engines
36 36 class EngineError(IPythonError): pass
37 37
38 38 class EngineCreationError(EngineError): pass
39 39
40 40 class KernelError(IPythonError):
41 41 pass
42 42
43 43 class NotDefined(KernelError):
44 44 def __init__(self, name):
45 45 self.name = name
46 46 self.args = (name,)
47 47
48 48 def __repr__(self):
49 49 return '<NotDefined: %s>' % self.name
50 50
51 51 __str__ = __repr__
52 52
53 53
54 54 class QueueCleared(KernelError):
55 55 pass
56 56
57 57
58 58 class IdInUse(KernelError):
59 59 pass
60 60
61 61
62 62 class ProtocolError(KernelError):
63 63 pass
64 64
65 65
66 66 class ConnectionError(KernelError):
67 67 pass
68 68
69 69
70 70 class InvalidEngineID(KernelError):
71 71 pass
72 72
73 73
74 74 class NoEnginesRegistered(KernelError):
75 75 pass
76 76
77 77
78 78 class InvalidClientID(KernelError):
79 79 pass
80 80
81 81
82 82 class InvalidDeferredID(KernelError):
83 83 pass
84 84
85 85
86 86 class SerializationError(KernelError):
87 87 pass
88 88
89 89
90 90 class MessageSizeError(KernelError):
91 91 pass
92 92
93 93
94 94 class PBMessageSizeError(MessageSizeError):
95 95 pass
96 96
97 97
98 98 class ResultNotCompleted(KernelError):
99 99 pass
100 100
101 101
102 102 class ResultAlreadyRetrieved(KernelError):
103 103 pass
104 104
105 105 class ClientError(KernelError):
106 106 pass
107 107
108 108
109 109 class TaskAborted(KernelError):
110 110 pass
111 111
112 112
113 113 class TaskTimeout(KernelError):
114 114 pass
115 115
116 116
117 117 class NotAPendingResult(KernelError):
118 118 pass
119 119
120 120
121 121 class UnpickleableException(KernelError):
122 122 pass
123 123
124 124
125 125 class AbortedPendingDeferredError(KernelError):
126 126 pass
127 127
128 128
129 129 class InvalidProperty(KernelError):
130 130 pass
131 131
132 132
133 133 class MissingBlockArgument(KernelError):
134 134 pass
135 135
136 136
137 137 class StopLocalExecution(KernelError):
138 138 pass
139 139
140 140
141 141 class SecurityError(KernelError):
142 142 pass
143 143
144 144
145 145 class FileTimeoutError(KernelError):
146 146 pass
147 147
148 148 class TimeoutError(KernelError):
149 149 pass
150 150
151 151 class UnmetDependency(KernelError):
152 152 pass
153 153
154 154 class ImpossibleDependency(UnmetDependency):
155 155 pass
156 156
157 157 class DependencyTimeout(ImpossibleDependency):
158 158 pass
159 159
160 160 class InvalidDependency(ImpossibleDependency):
161 161 pass
162 162
163 163 class RemoteError(KernelError):
164 164 """Error raised elsewhere"""
165 165 ename=None
166 166 evalue=None
167 167 traceback=None
168 168 engine_info=None
169 169
170 170 def __init__(self, ename, evalue, traceback, engine_info=None):
171 171 self.ename=ename
172 172 self.evalue=evalue
173 173 self.traceback=traceback
174 174 self.engine_info=engine_info or {}
175 175 self.args=(ename, evalue)
176 176
177 177 def __repr__(self):
178 engineid = self.engine_info.get('engineid', ' ')
178 engineid = self.engine_info.get('engine_id', ' ')
179 179 return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)
180 180
181 181 def __str__(self):
182 182 sig = "%s(%s)"%(self.ename, self.evalue)
183 183 if self.traceback:
184 184 return sig + '\n' + self.traceback
185 185 else:
186 186 return sig
187 187
188 188
189 189 class TaskRejectError(KernelError):
190 190 """Exception to raise when a task should be rejected by an engine.
191 191
192 192 This exception can be used to allow a task running on an engine to test
193 193 if the engine (or the user's namespace on the engine) has the needed
194 194 task dependencies. If not, the task should raise this exception. For
195 195 the task to be retried on another engine, the task should be created
196 196 with the `retries` argument > 1.
197 197
198 198 The advantage of this approach over our older properties system is that
199 199 tasks have full access to the user's namespace on the engines and the
200 200 properties don't have to be managed or tested by the controller.
201 201 """
202 202
203 203
204 204 class CompositeError(RemoteError):
205 205 """Error for representing possibly multiple errors on engines"""
206 206 def __init__(self, message, elist):
207 207 Exception.__init__(self, *(message, elist))
208 208 # Don't use pack_exception because it will conflict with the .message
209 209 # attribute that is being deprecated in 2.6 and beyond.
210 210 self.msg = message
211 211 self.elist = elist
212 212 self.args = [ e[0] for e in elist ]
213 213
214 214 def _get_engine_str(self, ei):
215 215 if not ei:
216 216 return '[Engine Exception]'
217 217 else:
218 218 return '[%s:%s]: ' % (ei['engineid'], ei['method'])
219 219
220 220 def _get_traceback(self, ev):
221 221 try:
222 222 tb = ev._ipython_traceback_text
223 223 except AttributeError:
224 224 return 'No traceback available'
225 225 else:
226 226 return tb
227 227
228 228 def __str__(self):
229 229 s = str(self.msg)
230 230 for en, ev, etb, ei in self.elist:
231 231 engine_str = self._get_engine_str(ei)
232 232 s = s + '\n' + engine_str + en + ': ' + str(ev)
233 233 return s
234 234
235 235 def __repr__(self):
236 236 return "CompositeError(%i)"%len(self.elist)
237 237
238 238 def print_tracebacks(self, excid=None):
239 239 if excid is None:
240 240 for (en,ev,etb,ei) in self.elist:
241 241 print (self._get_engine_str(ei))
242 242 print (etb or 'No traceback available')
243 243 print ()
244 244 else:
245 245 try:
246 246 en,ev,etb,ei = self.elist[excid]
247 247 except:
248 248 raise IndexError("an exception with index %i does not exist"%excid)
249 249 else:
250 250 print (self._get_engine_str(ei))
251 251 print (etb or 'No traceback available')
252 252
253 253 def raise_exception(self, excid=0):
254 254 try:
255 255 en,ev,etb,ei = self.elist[excid]
256 256 except:
257 257 raise IndexError("an exception with index %i does not exist"%excid)
258 258 else:
259 259 raise RemoteError(en, ev, etb, ei)
260 260
261 261
262 262 def collect_exceptions(rdict_or_list, method='unspecified'):
263 263 """check a result dict for errors, and raise CompositeError if any exist.
264 264 Passthrough otherwise."""
265 265 elist = []
266 266 if isinstance(rdict_or_list, dict):
267 267 rlist = rdict_or_list.values()
268 268 else:
269 269 rlist = rdict_or_list
270 270 for r in rlist:
271 271 if isinstance(r, RemoteError):
272 272 en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
273 273 # Sometimes we could have CompositeError in our list. Just take
274 274 # the errors out of them and put them in our new list. This
275 275 # has the effect of flattening lists of CompositeErrors into one
276 276 # CompositeError
277 277 if en=='CompositeError':
278 278 for e in ev.elist:
279 279 elist.append(e)
280 280 else:
281 281 elist.append((en, ev, etb, ei))
282 282 if len(elist)==0:
283 283 return rdict_or_list
284 284 else:
285 285 msg = "one or more exceptions from call to method: %s" % (method)
286 286 # This silliness is needed so the debugger has access to the exception
287 287 # instance (e in this case)
288 288 try:
289 289 raise CompositeError(msg, elist)
290 290 except CompositeError as e:
291 291 raise e
292 292
@@ -1,1054 +1,1054 b''
1 1 #!/usr/bin/env python
2 2 """The IPython Controller Hub with 0MQ
3 3 This is the master object that handles connections from engines and clients,
4 4 and monitors traffic through the various queues.
5 5 """
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2010 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16 from __future__ import print_function
17 17
18 18 import logging
19 19 import sys
20 20 import time
21 21 from datetime import datetime
22 22
23 23 import zmq
24 24 from zmq.eventloop import ioloop
25 25 from zmq.eventloop.zmqstream import ZMQStream
26 26
27 27 # internal:
28 28 from IPython.config.configurable import Configurable
29 29 from IPython.utils.importstring import import_item
30 30 from IPython.utils.traitlets import HasTraits, Instance, Int, CStr, Str, Dict, Set, List, Bool
31 31
32 32 from entry_point import select_random_ports
33 33 from factory import RegistrationFactory, LoggingFactory
34 34
35 35 from heartmonitor import HeartMonitor
36 36 from streamsession import Message, wrap_exception, ISO8601
37 37 from util import validate_url_container
38 38
39 39 try:
40 40 from pymongo.binary import Binary
41 41 except ImportError:
42 42 MongoDB=None
43 43 else:
44 44 from mongodb import MongoDB
45 45
46 46 #-----------------------------------------------------------------------------
47 47 # Code
48 48 #-----------------------------------------------------------------------------
49 49
50 50 def _passer(*args, **kwargs):
51 51 return
52 52
53 53 def _printer(*args, **kwargs):
54 54 print (args)
55 55 print (kwargs)
56 56
57 57 def init_record(msg):
58 58 """Initialize a TaskRecord based on a request."""
59 59 header = msg['header']
60 60 return {
61 61 'msg_id' : header['msg_id'],
62 62 'header' : header,
63 63 'content': msg['content'],
64 64 'buffers': msg['buffers'],
65 65 'submitted': datetime.strptime(header['date'], ISO8601),
66 66 'client_uuid' : None,
67 67 'engine_uuid' : None,
68 68 'started': None,
69 69 'completed': None,
70 70 'resubmitted': None,
71 71 'result_header' : None,
72 72 'result_content' : None,
73 73 'result_buffers' : None,
74 74 'queue' : None,
75 75 'pyin' : None,
76 76 'pyout': None,
77 77 'pyerr': None,
78 78 'stdout': '',
79 79 'stderr': '',
80 80 }
81 81
82 82
83 83 class EngineConnector(HasTraits):
84 84 """A simple object for accessing the various zmq connections of an object.
85 85 Attributes are:
86 86 id (int): engine ID
87 87 uuid (str): uuid (unused?)
88 88 queue (str): identity of queue's XREQ socket
89 89 registration (str): identity of registration XREQ socket
90 90 heartbeat (str): identity of heartbeat XREQ socket
91 91 """
92 92 id=Int(0)
93 93 queue=Str()
94 94 control=Str()
95 95 registration=Str()
96 96 heartbeat=Str()
97 97 pending=Set()
98 98
99 99 class HubFactory(RegistrationFactory):
100 100 """The Configurable for setting up a Hub."""
101 101
102 102 # name of a scheduler scheme
103 103 scheme = Str('leastload', config=True)
104 104
105 105 # port-pairs for monitoredqueues:
106 106 hb = Instance(list, config=True)
107 107 def _hb_default(self):
108 108 return select_random_ports(2)
109 109
110 110 mux = Instance(list, config=True)
111 111 def _mux_default(self):
112 112 return select_random_ports(2)
113 113
114 114 task = Instance(list, config=True)
115 115 def _task_default(self):
116 116 return select_random_ports(2)
117 117
118 118 control = Instance(list, config=True)
119 119 def _control_default(self):
120 120 return select_random_ports(2)
121 121
122 122 iopub = Instance(list, config=True)
123 123 def _iopub_default(self):
124 124 return select_random_ports(2)
125 125
126 126 # single ports:
127 127 mon_port = Instance(int, config=True)
128 128 def _mon_port_default(self):
129 129 return select_random_ports(1)[0]
130 130
131 131 query_port = Instance(int, config=True)
132 132 def _query_port_default(self):
133 133 return select_random_ports(1)[0]
134 134
135 135 notifier_port = Instance(int, config=True)
136 136 def _notifier_port_default(self):
137 137 return select_random_ports(1)[0]
138 138
139 139 ping = Int(1000, config=True) # ping frequency
140 140
141 141 engine_ip = CStr('127.0.0.1', config=True)
142 142 engine_transport = CStr('tcp', config=True)
143 143
144 144 client_ip = CStr('127.0.0.1', config=True)
145 145 client_transport = CStr('tcp', config=True)
146 146
147 147 monitor_ip = CStr('127.0.0.1', config=True)
148 148 monitor_transport = CStr('tcp', config=True)
149 149
150 150 monitor_url = CStr('')
151 151
152 152 db_class = CStr('IPython.zmq.parallel.dictdb.DictDB', config=True)
153 153
154 154 # not configurable
155 155 db = Instance('IPython.zmq.parallel.dictdb.BaseDB')
156 156 heartmonitor = Instance('IPython.zmq.parallel.heartmonitor.HeartMonitor')
157 157 subconstructors = List()
158 158 _constructed = Bool(False)
159 159
160 160 def _ip_changed(self, name, old, new):
161 161 self.engine_ip = new
162 162 self.client_ip = new
163 163 self.monitor_ip = new
164 164 self._update_monitor_url()
165 165
166 166 def _update_monitor_url(self):
167 167 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
168 168
169 169 def _transport_changed(self, name, old, new):
170 170 self.engine_transport = new
171 171 self.client_transport = new
172 172 self.monitor_transport = new
173 173 self._update_monitor_url()
174 174
175 175 def __init__(self, **kwargs):
176 176 super(HubFactory, self).__init__(**kwargs)
177 177 self._update_monitor_url()
178 178 # self.on_trait_change(self._sync_ips, 'ip')
179 179 # self.on_trait_change(self._sync_transports, 'transport')
180 180 self.subconstructors.append(self.construct_hub)
181 181
182 182
183 183 def construct(self):
184 184 assert not self._constructed, "already constructed!"
185 185
186 186 for subc in self.subconstructors:
187 187 subc()
188 188
189 189 self._constructed = True
190 190
191 191
192 192 def start(self):
193 193 assert self._constructed, "must be constructed by self.construct() first!"
194 194 self.heartmonitor.start()
195 195 self.log.info("Heartmonitor started")
196 196
197 197 def construct_hub(self):
198 198 """construct"""
199 199 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
200 200 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
201 201
202 202 ctx = self.context
203 203 loop = self.loop
204 204
205 205 # Registrar socket
206 206 reg = ZMQStream(ctx.socket(zmq.XREP), loop)
207 207 reg.bind(client_iface % self.regport)
208 208 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
209 209 if self.client_ip != self.engine_ip:
210 210 reg.bind(engine_iface % self.regport)
211 211 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
212 212
213 213 ### Engine connections ###
214 214
215 215 # heartbeat
216 216 hpub = ctx.socket(zmq.PUB)
217 217 hpub.bind(engine_iface % self.hb[0])
218 218 hrep = ctx.socket(zmq.XREP)
219 219 hrep.bind(engine_iface % self.hb[1])
220 220 self.heartmonitor = HeartMonitor(loop=loop, pingstream=ZMQStream(hpub,loop), pongstream=ZMQStream(hrep,loop),
221 221 period=self.ping, logname=self.log.name)
222 222
223 223 ### Client connections ###
224 224 # Clientele socket
225 225 c = ZMQStream(ctx.socket(zmq.XREP), loop)
226 226 c.bind(client_iface%self.query_port)
227 227 # Notifier socket
228 228 n = ZMQStream(ctx.socket(zmq.PUB), loop)
229 229 n.bind(client_iface%self.notifier_port)
230 230
231 231 ### build and launch the queues ###
232 232
233 233 # monitor socket
234 234 sub = ctx.socket(zmq.SUB)
235 235 sub.setsockopt(zmq.SUBSCRIBE, "")
236 236 sub.bind(self.monitor_url)
237 237 sub.bind('inproc://monitor')
238 238 sub = ZMQStream(sub, loop)
239 239
240 240 # connect the db
241 241 self.db = import_item(self.db_class)(self.session.session)
242 242 time.sleep(.25)
243 243
244 244 # build connection dicts
245 245 self.engine_info = {
246 246 'control' : engine_iface%self.control[1],
247 247 'mux': engine_iface%self.mux[1],
248 248 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
249 249 'task' : engine_iface%self.task[1],
250 250 'iopub' : engine_iface%self.iopub[1],
251 251 # 'monitor' : engine_iface%self.mon_port,
252 252 }
253 253
254 254 self.client_info = {
255 255 'control' : client_iface%self.control[0],
256 256 'query': client_iface%self.query_port,
257 257 'mux': client_iface%self.mux[0],
258 258 'task' : (self.scheme, client_iface%self.task[0]),
259 259 'iopub' : client_iface%self.iopub[0],
260 260 'notification': client_iface%self.notifier_port
261 261 }
262 262 self.log.debug("hub::Hub engine addrs: %s"%self.engine_info)
263 263 self.log.debug("hub::Hub client addrs: %s"%self.client_info)
264 264 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
265 265 registrar=reg, clientele=c, notifier=n, db=self.db,
266 266 engine_info=self.engine_info, client_info=self.client_info,
267 267 logname=self.log.name)
268 268
269 269
270 270 class Hub(LoggingFactory):
271 271 """The IPython Controller Hub with 0MQ connections
272 272
273 273 Parameters
274 274 ==========
275 275 loop: zmq IOLoop instance
276 276 session: StreamSession object
277 277 <removed> context: zmq context for creating new connections (?)
278 278 queue: ZMQStream for monitoring the command queue (SUB)
279 279 registrar: ZMQStream for engine registration requests (XREP)
280 280 heartbeat: HeartMonitor object checking the pulse of the engines
281 281 clientele: ZMQStream for client connections (XREP)
282 282 not used for jobs, only query/control commands
283 283 notifier: ZMQStream for broadcasting engine registration changes (PUB)
284 284 db: connection to db for out of memory logging of commands
285 285 NotImplemented
286 286 engine_info: dict of zmq connection information for engines to connect
287 287 to the queues.
288 288 client_info: dict of zmq connection information for engines to connect
289 289 to the queues.
290 290 """
291 291 # internal data structures:
292 292 ids=Set() # engine IDs
293 293 keytable=Dict()
294 294 by_ident=Dict()
295 295 engines=Dict()
296 296 clients=Dict()
297 297 hearts=Dict()
298 298 pending=Set()
299 299 queues=Dict() # pending msg_ids keyed by engine_id
300 300 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
301 301 completed=Dict() # completed msg_ids keyed by engine_id
302 302 all_completed=Set() # completed msg_ids keyed by engine_id
303 303 # mia=None
304 304 incoming_registrations=Dict()
305 305 registration_timeout=Int()
306 306 _idcounter=Int(0)
307 307
308 308 # objects from constructor:
309 309 loop=Instance(ioloop.IOLoop)
310 310 registrar=Instance(ZMQStream)
311 311 clientele=Instance(ZMQStream)
312 312 monitor=Instance(ZMQStream)
313 313 heartmonitor=Instance(HeartMonitor)
314 314 notifier=Instance(ZMQStream)
315 315 db=Instance(object)
316 316 client_info=Dict()
317 317 engine_info=Dict()
318 318
319 319
320 320 def __init__(self, **kwargs):
321 321 """
322 322 # universal:
323 323 loop: IOLoop for creating future connections
324 324 session: streamsession for sending serialized data
325 325 # engine:
326 326 queue: ZMQStream for monitoring queue messages
327 327 registrar: ZMQStream for engine registration
328 328 heartbeat: HeartMonitor object for tracking engines
329 329 # client:
330 330 clientele: ZMQStream for client connections
331 331 # extra:
332 332 db: ZMQStream for db connection (NotImplemented)
333 333 engine_info: zmq address/protocol dict for engine connections
334 334 client_info: zmq address/protocol dict for client connections
335 335 """
336 336
337 337 super(Hub, self).__init__(**kwargs)
338 338 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
339 339
340 340 # validate connection dicts:
341 341 for k,v in self.client_info.iteritems():
342 342 if k == 'task':
343 343 validate_url_container(v[1])
344 344 else:
345 345 validate_url_container(v)
346 346 # validate_url_container(self.client_info)
347 347 validate_url_container(self.engine_info)
348 348
349 349 # register our callbacks
350 350 self.registrar.on_recv(self.dispatch_register_request)
351 351 self.clientele.on_recv(self.dispatch_client_msg)
352 352 self.monitor.on_recv(self.dispatch_monitor_traffic)
353 353
354 354 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
355 355 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
356 356
357 357 self.monitor_handlers = { 'in' : self.save_queue_request,
358 358 'out': self.save_queue_result,
359 359 'intask': self.save_task_request,
360 360 'outtask': self.save_task_result,
361 361 'tracktask': self.save_task_destination,
362 362 'incontrol': _passer,
363 363 'outcontrol': _passer,
364 364 'iopub': self.save_iopub_message,
365 365 }
366 366
367 367 self.client_handlers = {'queue_request': self.queue_status,
368 368 'result_request': self.get_results,
369 369 'purge_request': self.purge_results,
370 370 'load_request': self.check_load,
371 371 'resubmit_request': self.resubmit_task,
372 372 'shutdown_request': self.shutdown_request,
373 373 }
374 374
375 375 self.registrar_handlers = {'registration_request' : self.register_engine,
376 376 'unregistration_request' : self.unregister_engine,
377 377 'connection_request': self.connection_request,
378 378 }
379 379
380 380 self.log.info("hub::created hub")
381 381
382 382 @property
383 383 def _next_id(self):
384 384 """gemerate a new ID.
385 385
386 386 No longer reuse old ids, just count from 0."""
387 387 newid = self._idcounter
388 388 self._idcounter += 1
389 389 return newid
390 390 # newid = 0
391 391 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
392 392 # # print newid, self.ids, self.incoming_registrations
393 393 # while newid in self.ids or newid in incoming:
394 394 # newid += 1
395 395 # return newid
396 396
397 397 #-----------------------------------------------------------------------------
398 398 # message validation
399 399 #-----------------------------------------------------------------------------
400 400
401 401 def _validate_targets(self, targets):
402 402 """turn any valid targets argument into a list of integer ids"""
403 403 if targets is None:
404 404 # default to all
405 405 targets = self.ids
406 406
407 407 if isinstance(targets, (int,str,unicode)):
408 408 # only one target specified
409 409 targets = [targets]
410 410 _targets = []
411 411 for t in targets:
412 412 # map raw identities to ids
413 413 if isinstance(t, (str,unicode)):
414 414 t = self.by_ident.get(t, t)
415 415 _targets.append(t)
416 416 targets = _targets
417 417 bad_targets = [ t for t in targets if t not in self.ids ]
418 418 if bad_targets:
419 419 raise IndexError("No Such Engine: %r"%bad_targets)
420 420 if not targets:
421 421 raise IndexError("No Engines Registered")
422 422 return targets
423 423
424 424 def _validate_client_msg(self, msg):
425 425 """validates and unpacks headers of a message. Returns False if invalid,
426 426 (ident, header, parent, content)"""
427 427 client_id = msg[0]
428 428 try:
429 429 msg = self.session.unpack_message(msg[1:], content=True)
430 430 except:
431 431 self.log.error("client::Invalid Message %s"%msg, exc_info=True)
432 432 return False
433 433
434 434 msg_type = msg.get('msg_type', None)
435 435 if msg_type is None:
436 436 return False
437 437 header = msg.get('header')
438 438 # session doesn't handle split content for now:
439 439 return client_id, msg
440 440
441 441
442 442 #-----------------------------------------------------------------------------
443 443 # dispatch methods (1 per stream)
444 444 #-----------------------------------------------------------------------------
445 445
446 446 def dispatch_register_request(self, msg):
447 447 """"""
448 448 self.log.debug("registration::dispatch_register_request(%s)"%msg)
449 449 idents,msg = self.session.feed_identities(msg)
450 450 if not idents:
451 451 self.log.error("Bad Queue Message: %s"%msg, exc_info=True)
452 452 return
453 453 try:
454 454 msg = self.session.unpack_message(msg,content=True)
455 455 except:
456 456 self.log.error("registration::got bad registration message: %s"%msg, exc_info=True)
457 457 return
458 458
459 459 msg_type = msg['msg_type']
460 460 content = msg['content']
461 461
462 462 handler = self.registrar_handlers.get(msg_type, None)
463 463 if handler is None:
464 464 self.log.error("registration::got bad registration message: %s"%msg)
465 465 else:
466 466 handler(idents, msg)
467 467
468 468 def dispatch_monitor_traffic(self, msg):
469 469 """all ME and Task queue messages come through here, as well as
470 470 IOPub traffic."""
471 471 self.log.debug("monitor traffic: %s"%msg[:2])
472 472 switch = msg[0]
473 473 idents, msg = self.session.feed_identities(msg[1:])
474 474 if not idents:
475 475 self.log.error("Bad Monitor Message: %s"%msg)
476 476 return
477 477 handler = self.monitor_handlers.get(switch, None)
478 478 if handler is not None:
479 479 handler(idents, msg)
480 480 else:
481 481 self.log.error("Invalid monitor topic: %s"%switch)
482 482
483 483
484 484 def dispatch_client_msg(self, msg):
485 485 """Route messages from clients"""
486 486 idents, msg = self.session.feed_identities(msg)
487 487 if not idents:
488 488 self.log.error("Bad Client Message: %s"%msg)
489 489 return
490 490 client_id = idents[0]
491 491 try:
492 492 msg = self.session.unpack_message(msg, content=True)
493 493 except:
494 494 content = wrap_exception()
495 495 self.log.error("Bad Client Message: %s"%msg, exc_info=True)
496 496 self.session.send(self.clientele, "hub_error", ident=client_id,
497 497 content=content)
498 498 return
499 499
500 500 # print client_id, header, parent, content
501 501 #switch on message type:
502 502 msg_type = msg['msg_type']
503 503 self.log.info("client:: client %s requested %s"%(client_id, msg_type))
504 504 handler = self.client_handlers.get(msg_type, None)
505 505 try:
506 506 assert handler is not None, "Bad Message Type: %s"%msg_type
507 507 except:
508 508 content = wrap_exception()
509 509 self.log.error("Bad Message Type: %s"%msg_type, exc_info=True)
510 510 self.session.send(self.clientele, "hub_error", ident=client_id,
511 511 content=content)
512 512 return
513 513 else:
514 514 handler(client_id, msg)
515 515
516 516 def dispatch_db(self, msg):
517 517 """"""
518 518 raise NotImplementedError
519 519
520 520 #---------------------------------------------------------------------------
521 521 # handler methods (1 per event)
522 522 #---------------------------------------------------------------------------
523 523
524 524 #----------------------- Heartbeat --------------------------------------
525 525
526 526 def handle_new_heart(self, heart):
527 527 """handler to attach to heartbeater.
528 528 Called when a new heart starts to beat.
529 529 Triggers completion of registration."""
530 530 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
531 531 if heart not in self.incoming_registrations:
532 532 self.log.info("heartbeat::ignoring new heart: %r"%heart)
533 533 else:
534 534 self.finish_registration(heart)
535 535
536 536
537 537 def handle_heart_failure(self, heart):
538 538 """handler to attach to heartbeater.
539 539 called when a previously registered heart fails to respond to beat request.
540 540 triggers unregistration"""
541 541 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
542 542 eid = self.hearts.get(heart, None)
543 543 queue = self.engines[eid].queue
544 544 if eid is None:
545 545 self.log.info("heartbeat::ignoring heart failure %r"%heart)
546 546 else:
547 547 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
548 548
549 549 #----------------------- MUX Queue Traffic ------------------------------
550 550
551 551 def save_queue_request(self, idents, msg):
552 552 if len(idents) < 2:
553 553 self.log.error("invalid identity prefix: %s"%idents)
554 554 return
555 555 queue_id, client_id = idents[:2]
556 556 try:
557 557 msg = self.session.unpack_message(msg, content=False)
558 558 except:
559 559 self.log.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg), exc_info=True)
560 560 return
561 561
562 562 eid = self.by_ident.get(queue_id, None)
563 563 if eid is None:
564 564 self.log.error("queue::target %r not registered"%queue_id)
565 565 self.log.debug("queue:: valid are: %s"%(self.by_ident.keys()))
566 566 return
567 567
568 568 header = msg['header']
569 569 msg_id = header['msg_id']
570 570 record = init_record(msg)
571 571 record['engine_uuid'] = queue_id
572 572 record['client_uuid'] = client_id
573 573 record['queue'] = 'mux'
574 574 if MongoDB is not None and isinstance(self.db, MongoDB):
575 575 record['buffers'] = map(Binary, record['buffers'])
576 576 self.pending.add(msg_id)
577 577 self.queues[eid].append(msg_id)
578 578 self.db.add_record(msg_id, record)
579 579
580 580 def save_queue_result(self, idents, msg):
581 581 if len(idents) < 2:
582 582 self.log.error("invalid identity prefix: %s"%idents)
583 583 return
584 584
585 585 client_id, queue_id = idents[:2]
586 586 try:
587 587 msg = self.session.unpack_message(msg, content=False)
588 588 except:
589 589 self.log.error("queue::engine %r sent invalid message to %r: %s"%(
590 590 queue_id,client_id, msg), exc_info=True)
591 591 return
592 592
593 593 eid = self.by_ident.get(queue_id, None)
594 594 if eid is None:
595 595 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
596 596 self.log.debug("queue:: %s"%msg[2:])
597 597 return
598 598
599 599 parent = msg['parent_header']
600 600 if not parent:
601 601 return
602 602 msg_id = parent['msg_id']
603 603 if msg_id in self.pending:
604 604 self.pending.remove(msg_id)
605 605 self.all_completed.add(msg_id)
606 606 self.queues[eid].remove(msg_id)
607 607 self.completed[eid].append(msg_id)
608 608 rheader = msg['header']
609 609 completed = datetime.strptime(rheader['date'], ISO8601)
610 610 started = rheader.get('started', None)
611 611 if started is not None:
612 612 started = datetime.strptime(started, ISO8601)
613 613 result = {
614 614 'result_header' : rheader,
615 615 'result_content': msg['content'],
616 616 'started' : started,
617 617 'completed' : completed
618 618 }
619 619 if MongoDB is not None and isinstance(self.db, MongoDB):
620 620 result['result_buffers'] = map(Binary, msg['buffers'])
621 621 else:
622 622 result['result_buffers'] = msg['buffers']
623 623 self.db.update_record(msg_id, result)
624 624 else:
625 625 self.log.debug("queue:: unknown msg finished %s"%msg_id)
626 626
627 627 #--------------------- Task Queue Traffic ------------------------------
628 628
629 629 def save_task_request(self, idents, msg):
630 630 """Save the submission of a task."""
631 631 client_id = idents[0]
632 632
633 633 try:
634 634 msg = self.session.unpack_message(msg, content=False)
635 635 except:
636 636 self.log.error("task::client %r sent invalid task message: %s"%(
637 637 client_id, msg), exc_info=True)
638 638 return
639 639 record = init_record(msg)
640 640 if MongoDB is not None and isinstance(self.db, MongoDB):
641 641 record['buffers'] = map(Binary, record['buffers'])
642 642 record['client_uuid'] = client_id
643 643 record['queue'] = 'task'
644 644 header = msg['header']
645 645 msg_id = header['msg_id']
646 646 self.pending.add(msg_id)
647 647 self.db.add_record(msg_id, record)
648 648
649 649 def save_task_result(self, idents, msg):
650 650 """save the result of a completed task."""
651 651 client_id = idents[0]
652 652 try:
653 653 msg = self.session.unpack_message(msg, content=False)
654 654 except:
655 655 self.log.error("task::invalid task result message send to %r: %s"%(
656 656 client_id, msg), exc_info=True)
657 657 raise
658 658 return
659 659
660 660 parent = msg['parent_header']
661 661 if not parent:
662 662 # print msg
663 663 self.log.warn("Task %r had no parent!"%msg)
664 664 return
665 665 msg_id = parent['msg_id']
666 666
667 667 header = msg['header']
668 668 engine_uuid = header.get('engine', None)
669 669 eid = self.by_ident.get(engine_uuid, None)
670 670
671 671 if msg_id in self.pending:
672 672 self.pending.remove(msg_id)
673 673 self.all_completed.add(msg_id)
674 674 if eid is not None:
675 675 self.completed[eid].append(msg_id)
676 676 if msg_id in self.tasks[eid]:
677 677 self.tasks[eid].remove(msg_id)
678 678 completed = datetime.strptime(header['date'], ISO8601)
679 679 started = header.get('started', None)
680 680 if started is not None:
681 681 started = datetime.strptime(started, ISO8601)
682 682 result = {
683 683 'result_header' : header,
684 684 'result_content': msg['content'],
685 685 'started' : started,
686 686 'completed' : completed,
687 687 'engine_uuid': engine_uuid
688 688 }
689 689 if MongoDB is not None and isinstance(self.db, MongoDB):
690 690 result['result_buffers'] = map(Binary, msg['buffers'])
691 691 else:
692 692 result['result_buffers'] = msg['buffers']
693 693 self.db.update_record(msg_id, result)
694 694
695 695 else:
696 696 self.log.debug("task::unknown task %s finished"%msg_id)
697 697
698 698 def save_task_destination(self, idents, msg):
699 699 try:
700 700 msg = self.session.unpack_message(msg, content=True)
701 701 except:
702 702 self.log.error("task::invalid task tracking message", exc_info=True)
703 703 return
704 704 content = msg['content']
705 print (content)
705 # print (content)
706 706 msg_id = content['msg_id']
707 707 engine_uuid = content['engine_id']
708 708 eid = self.by_ident[engine_uuid]
709 709
710 710 self.log.info("task::task %s arrived on %s"%(msg_id, eid))
711 711 # if msg_id in self.mia:
712 712 # self.mia.remove(msg_id)
713 713 # else:
714 714 # self.log.debug("task::task %s not listed as MIA?!"%(msg_id))
715 715
716 716 self.tasks[eid].append(msg_id)
717 717 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
718 718 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
719 719
720 720 def mia_task_request(self, idents, msg):
721 721 raise NotImplementedError
722 722 client_id = idents[0]
723 723 # content = dict(mia=self.mia,status='ok')
724 724 # self.session.send('mia_reply', content=content, idents=client_id)
725 725
726 726
727 727 #--------------------- IOPub Traffic ------------------------------
728 728
729 729 def save_iopub_message(self, topics, msg):
730 730 """save an iopub message into the db"""
731 print (topics)
731 # print (topics)
732 732 try:
733 733 msg = self.session.unpack_message(msg, content=True)
734 734 except:
735 735 self.log.error("iopub::invalid IOPub message", exc_info=True)
736 736 return
737 737
738 738 parent = msg['parent_header']
739 739 if not parent:
740 740 self.log.error("iopub::invalid IOPub message: %s"%msg)
741 741 return
742 742 msg_id = parent['msg_id']
743 743 msg_type = msg['msg_type']
744 744 content = msg['content']
745 745
746 746 # ensure msg_id is in db
747 747 try:
748 748 rec = self.db.get_record(msg_id)
749 749 except:
750 750 self.log.error("iopub::IOPub message has invalid parent", exc_info=True)
751 751 return
752 752 # stream
753 753 d = {}
754 754 if msg_type == 'stream':
755 755 name = content['name']
756 756 s = rec[name] or ''
757 757 d[name] = s + content['data']
758 758
759 759 elif msg_type == 'pyerr':
760 760 d['pyerr'] = content
761 761 else:
762 762 d[msg_type] = content['data']
763 763
764 764 self.db.update_record(msg_id, d)
765 765
766 766
767 767
768 768 #-------------------------------------------------------------------------
769 769 # Registration requests
770 770 #-------------------------------------------------------------------------
771 771
772 772 def connection_request(self, client_id, msg):
773 773 """Reply with connection addresses for clients."""
774 774 self.log.info("client::client %s connected"%client_id)
775 775 content = dict(status='ok')
776 776 content.update(self.client_info)
777 777 jsonable = {}
778 778 for k,v in self.keytable.iteritems():
779 779 jsonable[str(k)] = v
780 780 content['engines'] = jsonable
781 781 self.session.send(self.registrar, 'connection_reply', content, parent=msg, ident=client_id)
782 782
783 783 def register_engine(self, reg, msg):
784 784 """Register a new engine."""
785 785 content = msg['content']
786 786 try:
787 787 queue = content['queue']
788 788 except KeyError:
789 789 self.log.error("registration::queue not specified", exc_info=True)
790 790 return
791 791 heart = content.get('heartbeat', None)
792 792 """register a new engine, and create the socket(s) necessary"""
793 793 eid = self._next_id
794 794 # print (eid, queue, reg, heart)
795 795
796 796 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
797 797
798 798 content = dict(id=eid,status='ok')
799 799 content.update(self.engine_info)
800 800 # check if requesting available IDs:
801 801 if queue in self.by_ident:
802 802 try:
803 803 raise KeyError("queue_id %r in use"%queue)
804 804 except:
805 805 content = wrap_exception()
806 806 self.log.error("queue_id %r in use"%queue, exc_info=True)
807 807 elif heart in self.hearts: # need to check unique hearts?
808 808 try:
809 809 raise KeyError("heart_id %r in use"%heart)
810 810 except:
811 811 self.log.error("heart_id %r in use"%heart, exc_info=True)
812 812 content = wrap_exception()
813 813 else:
814 814 for h, pack in self.incoming_registrations.iteritems():
815 815 if heart == h:
816 816 try:
817 817 raise KeyError("heart_id %r in use"%heart)
818 818 except:
819 819 self.log.error("heart_id %r in use"%heart, exc_info=True)
820 820 content = wrap_exception()
821 821 break
822 822 elif queue == pack[1]:
823 823 try:
824 824 raise KeyError("queue_id %r in use"%queue)
825 825 except:
826 826 self.log.error("queue_id %r in use"%queue, exc_info=True)
827 827 content = wrap_exception()
828 828 break
829 829
830 830 msg = self.session.send(self.registrar, "registration_reply",
831 831 content=content,
832 832 ident=reg)
833 833
834 834 if content['status'] == 'ok':
835 835 if heart in self.heartmonitor.hearts:
836 836 # already beating
837 837 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
838 838 self.finish_registration(heart)
839 839 else:
840 840 purge = lambda : self._purge_stalled_registration(heart)
841 841 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
842 842 dc.start()
843 843 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
844 844 else:
845 845 self.log.error("registration::registration %i failed: %s"%(eid, content['evalue']))
846 846 return eid
847 847
848 848 def unregister_engine(self, ident, msg):
849 849 """Unregister an engine that explicitly requested to leave."""
850 850 try:
851 851 eid = msg['content']['id']
852 852 except:
853 853 self.log.error("registration::bad engine id for unregistration: %s"%ident, exc_info=True)
854 854 return
855 855 self.log.info("registration::unregister_engine(%s)"%eid)
856 856 content=dict(id=eid, queue=self.engines[eid].queue)
857 857 self.ids.remove(eid)
858 858 self.keytable.pop(eid)
859 859 ec = self.engines.pop(eid)
860 860 self.hearts.pop(ec.heartbeat)
861 861 self.by_ident.pop(ec.queue)
862 862 self.completed.pop(eid)
863 863 for msg_id in self.queues.pop(eid):
864 864 msg = self.pending.remove(msg_id)
865 865 ############## TODO: HANDLE IT ################
866 866
867 867 if self.notifier:
868 868 self.session.send(self.notifier, "unregistration_notification", content=content)
869 869
870 870 def finish_registration(self, heart):
871 871 """Second half of engine registration, called after our HeartMonitor
872 872 has received a beat from the Engine's Heart."""
873 873 try:
874 874 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
875 875 except KeyError:
876 876 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
877 877 return
878 878 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
879 879 if purge is not None:
880 880 purge.stop()
881 881 control = queue
882 882 self.ids.add(eid)
883 883 self.keytable[eid] = queue
884 884 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
885 885 control=control, heartbeat=heart)
886 886 self.by_ident[queue] = eid
887 887 self.queues[eid] = list()
888 888 self.tasks[eid] = list()
889 889 self.completed[eid] = list()
890 890 self.hearts[heart] = eid
891 891 content = dict(id=eid, queue=self.engines[eid].queue)
892 892 if self.notifier:
893 893 self.session.send(self.notifier, "registration_notification", content=content)
894 894 self.log.info("engine::Engine Connected: %i"%eid)
895 895
896 896 def _purge_stalled_registration(self, heart):
897 897 if heart in self.incoming_registrations:
898 898 eid = self.incoming_registrations.pop(heart)[0]
899 899 self.log.info("registration::purging stalled registration: %i"%eid)
900 900 else:
901 901 pass
902 902
903 903 #-------------------------------------------------------------------------
904 904 # Client Requests
905 905 #-------------------------------------------------------------------------
906 906
907 907 def shutdown_request(self, client_id, msg):
908 908 """handle shutdown request."""
909 909 # s = self.context.socket(zmq.XREQ)
910 910 # s.connect(self.client_connections['mux'])
911 911 # time.sleep(0.1)
912 912 # for eid,ec in self.engines.iteritems():
913 913 # self.session.send(s, 'shutdown_request', content=dict(restart=False), ident=ec.queue)
914 914 # time.sleep(1)
915 915 self.session.send(self.clientele, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
916 916 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
917 917 dc.start()
918 918
919 919 def _shutdown(self):
920 920 self.log.info("hub::hub shutting down.")
921 921 time.sleep(0.1)
922 922 sys.exit(0)
923 923
924 924
925 925 def check_load(self, client_id, msg):
926 926 content = msg['content']
927 927 try:
928 928 targets = content['targets']
929 929 targets = self._validate_targets(targets)
930 930 except:
931 931 content = wrap_exception()
932 932 self.session.send(self.clientele, "hub_error",
933 933 content=content, ident=client_id)
934 934 return
935 935
936 936 content = dict(status='ok')
937 937 # loads = {}
938 938 for t in targets:
939 939 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
940 940 self.session.send(self.clientele, "load_reply", content=content, ident=client_id)
941 941
942 942
943 943 def queue_status(self, client_id, msg):
944 944 """Return the Queue status of one or more targets.
945 945 if verbose: return the msg_ids
946 946 else: return len of each type.
947 947 keys: queue (pending MUX jobs)
948 948 tasks (pending Task jobs)
949 949 completed (finished jobs from both queues)"""
950 950 content = msg['content']
951 951 targets = content['targets']
952 952 try:
953 953 targets = self._validate_targets(targets)
954 954 except:
955 955 content = wrap_exception()
956 956 self.session.send(self.clientele, "hub_error",
957 957 content=content, ident=client_id)
958 958 return
959 959 verbose = content.get('verbose', False)
960 960 content = dict(status='ok')
961 961 for t in targets:
962 962 queue = self.queues[t]
963 963 completed = self.completed[t]
964 964 tasks = self.tasks[t]
965 965 if not verbose:
966 966 queue = len(queue)
967 967 completed = len(completed)
968 968 tasks = len(tasks)
969 969 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
970 970 # pending
971 971 self.session.send(self.clientele, "queue_reply", content=content, ident=client_id)
972 972
973 973 def purge_results(self, client_id, msg):
974 974 """Purge results from memory. This method is more valuable before we move
975 975 to a DB based message storage mechanism."""
976 976 content = msg['content']
977 977 msg_ids = content.get('msg_ids', [])
978 978 reply = dict(status='ok')
979 979 if msg_ids == 'all':
980 980 self.db.drop_matching_records(dict(completed={'$ne':None}))
981 981 else:
982 982 for msg_id in msg_ids:
983 983 if msg_id in self.all_completed:
984 984 self.db.drop_record(msg_id)
985 985 else:
986 986 if msg_id in self.pending:
987 987 try:
988 988 raise IndexError("msg pending: %r"%msg_id)
989 989 except:
990 990 reply = wrap_exception()
991 991 else:
992 992 try:
993 993 raise IndexError("No such msg: %r"%msg_id)
994 994 except:
995 995 reply = wrap_exception()
996 996 break
997 997 eids = content.get('engine_ids', [])
998 998 for eid in eids:
999 999 if eid not in self.engines:
1000 1000 try:
1001 1001 raise IndexError("No such engine: %i"%eid)
1002 1002 except:
1003 1003 reply = wrap_exception()
1004 1004 break
1005 1005 msg_ids = self.completed.pop(eid)
1006 1006 uid = self.engines[eid].queue
1007 1007 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1008 1008
1009 1009 self.session.send(self.clientele, 'purge_reply', content=reply, ident=client_id)
1010 1010
1011 1011 def resubmit_task(self, client_id, msg, buffers):
1012 1012 """Resubmit a task."""
1013 1013 raise NotImplementedError
1014 1014
1015 1015 def get_results(self, client_id, msg):
1016 1016 """Get the result of 1 or more messages."""
1017 1017 content = msg['content']
1018 1018 msg_ids = sorted(set(content['msg_ids']))
1019 1019 statusonly = content.get('status_only', False)
1020 1020 pending = []
1021 1021 completed = []
1022 1022 content = dict(status='ok')
1023 1023 content['pending'] = pending
1024 1024 content['completed'] = completed
1025 1025 buffers = []
1026 1026 if not statusonly:
1027 1027 content['results'] = {}
1028 1028 records = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1029 1029 for msg_id in msg_ids:
1030 1030 if msg_id in self.pending:
1031 1031 pending.append(msg_id)
1032 1032 elif msg_id in self.all_completed:
1033 1033 completed.append(msg_id)
1034 1034 if not statusonly:
1035 1035 rec = records[msg_id]
1036 1036 io_dict = {}
1037 1037 for key in 'pyin pyout pyerr stdout stderr'.split():
1038 1038 io_dict[key] = rec[key]
1039 1039 content[msg_id] = { 'result_content': rec['result_content'],
1040 1040 'header': rec['header'],
1041 1041 'result_header' : rec['result_header'],
1042 1042 'io' : io_dict,
1043 1043 }
1044 1044 buffers.extend(map(str, rec['result_buffers']))
1045 1045 else:
1046 1046 try:
1047 1047 raise KeyError('No such message: '+msg_id)
1048 1048 except:
1049 1049 content = wrap_exception()
1050 1050 break
1051 1051 self.session.send(self.clientele, "result_reply", content=content,
1052 1052 parent=msg, ident=client_id,
1053 1053 buffers=buffers)
1054 1054
@@ -1,166 +1,174 b''
1 1 """Remote Functions and decorators for the client."""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import warnings
14 14
15 from IPython.testing import decorators as testdec
16
15 17 import map as Map
16 18 from asyncresult import AsyncMapResult
17 19
18 20 #-----------------------------------------------------------------------------
19 21 # Decorators
20 22 #-----------------------------------------------------------------------------
21 23
24 @testdec.skip_doctest
22 25 def remote(client, bound=True, block=None, targets=None, balanced=None):
23 26 """Turn a function into a remote function.
24 27
25 28 This method can be used for map:
26 29
27 >>> @remote(client,block=True)
28 def func(a)
30 In [1]: @remote(client,block=True)
31 ...: def func(a):
32 ...: pass
29 33 """
34
30 35 def remote_function(f):
31 36 return RemoteFunction(client, f, bound, block, targets, balanced)
32 37 return remote_function
33 38
39 @testdec.skip_doctest
34 40 def parallel(client, dist='b', bound=True, block=None, targets='all', balanced=None):
35 41 """Turn a function into a parallel remote function.
36 42
37 43 This method can be used for map:
38 44
39 >>> @parallel(client,block=True)
40 def func(a)
45 In [1]: @parallel(client,block=True)
46 ...: def func(a):
47 ...: pass
41 48 """
49
42 50 def parallel_function(f):
43 51 return ParallelFunction(client, f, dist, bound, block, targets, balanced)
44 52 return parallel_function
45 53
46 54 #--------------------------------------------------------------------------
47 55 # Classes
48 56 #--------------------------------------------------------------------------
49 57
50 58 class RemoteFunction(object):
51 59 """Turn an existing function into a remote function.
52 60
53 61 Parameters
54 62 ----------
55 63
56 64 client : Client instance
57 65 The client to be used to connect to engines
58 66 f : callable
59 67 The function to be wrapped into a remote function
60 68 bound : bool [default: False]
61 69 Whether the affect the remote namespace when called
62 70 block : bool [default: None]
63 71 Whether to wait for results or not. The default behavior is
64 72 to use the current `block` attribute of `client`
65 73 targets : valid target list [default: all]
66 74 The targets on which to execute.
67 75 balanced : bool
68 76 Whether to load-balance with the Task scheduler or not
69 77 """
70 78
71 79 client = None # the remote connection
72 80 func = None # the wrapped function
73 81 block = None # whether to block
74 82 bound = None # whether to affect the namespace
75 83 targets = None # where to execute
76 84 balanced = None # whether to load-balance
77 85
78 86 def __init__(self, client, f, bound=False, block=None, targets=None, balanced=None):
79 87 self.client = client
80 88 self.func = f
81 89 self.block=block
82 90 self.bound=bound
83 91 self.targets=targets
84 92 if balanced is None:
85 93 if targets is None:
86 94 balanced = True
87 95 else:
88 96 balanced = False
89 97 self.balanced = balanced
90 98
91 99 def __call__(self, *args, **kwargs):
92 100 return self.client.apply(self.func, args=args, kwargs=kwargs,
93 101 block=self.block, targets=self.targets, bound=self.bound, balanced=self.balanced)
94 102
95 103
96 104 class ParallelFunction(RemoteFunction):
97 105 """Class for mapping a function to sequences."""
98 106 def __init__(self, client, f, dist='b', bound=False, block=None, targets='all', balanced=None, chunk_size=None):
99 107 super(ParallelFunction, self).__init__(client,f,bound,block,targets,balanced)
100 108 self.chunk_size = chunk_size
101 109
102 110 mapClass = Map.dists[dist]
103 111 self.mapObject = mapClass()
104 112
105 113 def __call__(self, *sequences):
106 114 len_0 = len(sequences[0])
107 115 for s in sequences:
108 116 if len(s)!=len_0:
109 117 msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s))
110 118 raise ValueError(msg)
111 119
112 120 if self.balanced:
113 121 if self.chunk_size:
114 122 nparts = len_0/self.chunk_size + int(len_0%self.chunk_size > 0)
115 123 else:
116 124 nparts = len_0
117 125 targets = [self.targets]*nparts
118 126 else:
119 127 if self.chunk_size:
120 128 warnings.warn("`chunk_size` is ignored when `balanced=False", UserWarning)
121 129 # multiplexed:
122 130 targets = self.client._build_targets(self.targets)[-1]
123 131 nparts = len(targets)
124 132
125 133 msg_ids = []
126 134 # my_f = lambda *a: map(self.func, *a)
127 135 for index, t in enumerate(targets):
128 136 args = []
129 137 for seq in sequences:
130 138 part = self.mapObject.getPartition(seq, index, nparts)
131 139 if len(part) == 0:
132 140 continue
133 141 else:
134 142 args.append(part)
135 143 if not args:
136 144 continue
137 145
138 146 # print (args)
139 147 if hasattr(self, '_map'):
140 148 f = map
141 149 args = [self.func]+args
142 150 else:
143 151 f=self.func
144 152 ar = self.client.apply(f, args=args, block=False, bound=self.bound,
145 153 targets=t, balanced=self.balanced)
146 154
147 155 msg_ids.append(ar.msg_ids[0])
148 156
149 157 r = AsyncMapResult(self.client, msg_ids, self.mapObject, fname=self.func.__name__)
150 158 if self.block:
151 159 try:
152 160 return r.get()
153 161 except KeyboardInterrupt:
154 162 return r
155 163 else:
156 164 return r
157 165
158 166 def map(self, *sequences):
159 167 """call a function on each element of a sequence remotely."""
160 168 self._map = True
161 169 try:
162 170 ret = self.__call__(*sequences)
163 171 finally:
164 172 del self._map
165 173 return ret
166 174
@@ -1,487 +1,487 b''
1 1 #!/usr/bin/env python
2 2 """
3 3 Kernel adapted from kernel.py to use ZMQ Streams
4 4 """
5 5
6 6 #-----------------------------------------------------------------------------
7 7 # Imports
8 8 #-----------------------------------------------------------------------------
9 9
10 10 # Standard library imports.
11 11 from __future__ import print_function
12 12 import __builtin__
13 13
14 14 import logging
15 15 import os
16 16 import sys
17 17 import time
18 18 import traceback
19 19
20 20 from code import CommandCompiler
21 21 from datetime import datetime
22 22 from pprint import pprint
23 23 from signal import SIGTERM, SIGKILL
24 24
25 25 # System library imports.
26 26 import zmq
27 27 from zmq.eventloop import ioloop, zmqstream
28 28
29 29 # Local imports.
30 30 from IPython.core import ultratb
31 31 from IPython.utils.traitlets import HasTraits, Instance, List, Int, Dict, Set, Str
32 32 from IPython.zmq.completer import KernelCompleter
33 33 from IPython.zmq.iostream import OutStream
34 34 from IPython.zmq.displayhook import DisplayHook
35 35
36 36 import heartmonitor
37 37 from client import Client
38 38 from factory import SessionFactory
39 39 from streamsession import StreamSession, Message, extract_header, serialize_object,\
40 40 unpack_apply_message, ISO8601, wrap_exception
41 41
42 42 def printer(*args):
43 43 pprint(args, stream=sys.__stdout__)
44 44
45 45
46 46 class _Passer:
47 47 """Empty class that implements `send()` that does nothing."""
48 48 def send(self, *args, **kwargs):
49 49 pass
50 50 send_multipart = send
51 51
52 52
53 53 #-----------------------------------------------------------------------------
54 54 # Main kernel class
55 55 #-----------------------------------------------------------------------------
56 56
57 57 class Kernel(SessionFactory):
58 58
59 59 #---------------------------------------------------------------------------
60 60 # Kernel interface
61 61 #---------------------------------------------------------------------------
62 62
63 63 # kwargs:
64 64 int_id = Int(-1, config=True)
65 65 user_ns = Dict(config=True)
66 66 exec_lines = List(config=True)
67 67
68 68 control_stream = Instance(zmqstream.ZMQStream)
69 69 task_stream = Instance(zmqstream.ZMQStream)
70 70 iopub_stream = Instance(zmqstream.ZMQStream)
71 71 client = Instance('IPython.zmq.parallel.client.Client')
72 72
73 73 # internals
74 74 shell_streams = List()
75 75 compiler = Instance(CommandCompiler, (), {})
76 76 completer = Instance(KernelCompleter)
77 77
78 78 aborted = Set()
79 79 shell_handlers = Dict()
80 80 control_handlers = Dict()
81 81
82 82 def _set_prefix(self):
83 83 self.prefix = "engine.%s"%self.int_id
84 84
85 85 def _connect_completer(self):
86 86 self.completer = KernelCompleter(self.user_ns)
87 87
88 88 def __init__(self, **kwargs):
89 89 super(Kernel, self).__init__(**kwargs)
90 90 self._set_prefix()
91 91 self._connect_completer()
92 92
93 93 self.on_trait_change(self._set_prefix, 'id')
94 94 self.on_trait_change(self._connect_completer, 'user_ns')
95 95
96 96 # Build dict of handlers for message types
97 97 for msg_type in ['execute_request', 'complete_request', 'apply_request',
98 98 'clear_request']:
99 99 self.shell_handlers[msg_type] = getattr(self, msg_type)
100 100
101 101 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
102 102 self.control_handlers[msg_type] = getattr(self, msg_type)
103 103
104 104 self._initial_exec_lines()
105 105
106 106 def _wrap_exception(self, method=None):
107 e_info = dict(engineid=self.ident, method=method)
107 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
108 108 content=wrap_exception(e_info)
109 109 return content
110 110
111 111 def _initial_exec_lines(self):
112 112 s = _Passer()
113 113 content = dict(silent=True, user_variable=[],user_expressions=[])
114 114 for line in self.exec_lines:
115 115 self.log.debug("executing initialization: %s"%line)
116 116 content.update({'code':line})
117 117 msg = self.session.msg('execute_request', content)
118 118 self.execute_request(s, [], msg)
119 119
120 120
121 121 #-------------------- control handlers -----------------------------
122 122 def abort_queues(self):
123 123 for stream in self.shell_streams:
124 124 if stream:
125 125 self.abort_queue(stream)
126 126
127 127 def abort_queue(self, stream):
128 128 while True:
129 129 try:
130 130 msg = self.session.recv(stream, zmq.NOBLOCK,content=True)
131 131 except zmq.ZMQError as e:
132 132 if e.errno == zmq.EAGAIN:
133 133 break
134 134 else:
135 135 return
136 136 else:
137 137 if msg is None:
138 138 return
139 139 else:
140 140 idents,msg = msg
141 141
142 142 # assert self.reply_socketly_socket.rcvmore(), "Unexpected missing message part."
143 143 # msg = self.reply_socket.recv_json()
144 144 self.log.info("Aborting:")
145 145 self.log.info(str(msg))
146 146 msg_type = msg['msg_type']
147 147 reply_type = msg_type.split('_')[0] + '_reply'
148 148 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
149 149 # self.reply_socket.send(ident,zmq.SNDMORE)
150 150 # self.reply_socket.send_json(reply_msg)
151 151 reply_msg = self.session.send(stream, reply_type,
152 152 content={'status' : 'aborted'}, parent=msg, ident=idents)[0]
153 153 self.log.debug(str(reply_msg))
154 154 # We need to wait a bit for requests to come in. This can probably
155 155 # be set shorter for true asynchronous clients.
156 156 time.sleep(0.05)
157 157
158 158 def abort_request(self, stream, ident, parent):
159 159 """abort a specifig msg by id"""
160 160 msg_ids = parent['content'].get('msg_ids', None)
161 161 if isinstance(msg_ids, basestring):
162 162 msg_ids = [msg_ids]
163 163 if not msg_ids:
164 164 self.abort_queues()
165 165 for mid in msg_ids:
166 166 self.aborted.add(str(mid))
167 167
168 168 content = dict(status='ok')
169 169 reply_msg = self.session.send(stream, 'abort_reply', content=content,
170 170 parent=parent, ident=ident)[0]
171 171 self.log.debug(str(reply_msg))
172 172
173 173 def shutdown_request(self, stream, ident, parent):
174 174 """kill ourself. This should really be handled in an external process"""
175 175 try:
176 176 self.abort_queues()
177 177 except:
178 178 content = self._wrap_exception('shutdown')
179 179 else:
180 180 content = dict(parent['content'])
181 181 content['status'] = 'ok'
182 182 msg = self.session.send(stream, 'shutdown_reply',
183 183 content=content, parent=parent, ident=ident)
184 184 # msg = self.session.send(self.pub_socket, 'shutdown_reply',
185 185 # content, parent, ident)
186 186 # print >> sys.__stdout__, msg
187 187 # time.sleep(0.2)
188 188 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
189 189 dc.start()
190 190
191 191 def dispatch_control(self, msg):
192 192 idents,msg = self.session.feed_identities(msg, copy=False)
193 193 try:
194 194 msg = self.session.unpack_message(msg, content=True, copy=False)
195 195 except:
196 196 self.log.error("Invalid Message", exc_info=True)
197 197 return
198 198
199 199 header = msg['header']
200 200 msg_id = header['msg_id']
201 201
202 202 handler = self.control_handlers.get(msg['msg_type'], None)
203 203 if handler is None:
204 204 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
205 205 else:
206 206 handler(self.control_stream, idents, msg)
207 207
208 208
209 209 #-------------------- queue helpers ------------------------------
210 210
211 211 def check_dependencies(self, dependencies):
212 212 if not dependencies:
213 213 return True
214 214 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
215 215 anyorall = dependencies[0]
216 216 dependencies = dependencies[1]
217 217 else:
218 218 anyorall = 'all'
219 219 results = self.client.get_results(dependencies,status_only=True)
220 220 if results['status'] != 'ok':
221 221 return False
222 222
223 223 if anyorall == 'any':
224 224 if not results['completed']:
225 225 return False
226 226 else:
227 227 if results['pending']:
228 228 return False
229 229
230 230 return True
231 231
232 232 def check_aborted(self, msg_id):
233 233 return msg_id in self.aborted
234 234
235 235 #-------------------- queue handlers -----------------------------
236 236
237 237 def clear_request(self, stream, idents, parent):
238 238 """Clear our namespace."""
239 239 self.user_ns = {}
240 240 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
241 241 content = dict(status='ok'))
242 242 self._initial_exec_lines()
243 243
244 244 def execute_request(self, stream, ident, parent):
245 245 self.log.debug('execute request %s'%parent)
246 246 try:
247 247 code = parent[u'content'][u'code']
248 248 except:
249 249 self.log.error("Got bad msg: %s"%parent, exc_info=True)
250 250 return
251 251 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
252 252 ident='%s.pyin'%self.prefix)
253 253 started = datetime.now().strftime(ISO8601)
254 254 try:
255 255 comp_code = self.compiler(code, '<zmq-kernel>')
256 256 # allow for not overriding displayhook
257 257 if hasattr(sys.displayhook, 'set_parent'):
258 258 sys.displayhook.set_parent(parent)
259 259 sys.stdout.set_parent(parent)
260 260 sys.stderr.set_parent(parent)
261 261 exec comp_code in self.user_ns, self.user_ns
262 262 except:
263 263 exc_content = self._wrap_exception('execute')
264 264 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
265 265 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
266 266 ident='%s.pyerr'%self.prefix)
267 267 reply_content = exc_content
268 268 else:
269 269 reply_content = {'status' : 'ok'}
270 270
271 271 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
272 272 ident=ident, subheader = dict(started=started))
273 273 self.log.debug(str(reply_msg))
274 274 if reply_msg['content']['status'] == u'error':
275 275 self.abort_queues()
276 276
277 277 def complete_request(self, stream, ident, parent):
278 278 matches = {'matches' : self.complete(parent),
279 279 'status' : 'ok'}
280 280 completion_msg = self.session.send(stream, 'complete_reply',
281 281 matches, parent, ident)
282 282 # print >> sys.__stdout__, completion_msg
283 283
284 284 def complete(self, msg):
285 285 return self.completer.complete(msg.content.line, msg.content.text)
286 286
287 287 def apply_request(self, stream, ident, parent):
288 288 # flush previous reply, so this request won't block it
289 289 stream.flush(zmq.POLLOUT)
290 290
291 291 try:
292 292 content = parent[u'content']
293 293 bufs = parent[u'buffers']
294 294 msg_id = parent['header']['msg_id']
295 295 bound = content.get('bound', False)
296 296 except:
297 297 self.log.error("Got bad msg: %s"%parent, exc_info=True)
298 298 return
299 299 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
300 300 # self.iopub_stream.send(pyin_msg)
301 301 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
302 302 sub = {'dependencies_met' : True, 'engine' : self.ident,
303 303 'started': datetime.now().strftime(ISO8601)}
304 304 try:
305 305 # allow for not overriding displayhook
306 306 if hasattr(sys.displayhook, 'set_parent'):
307 307 sys.displayhook.set_parent(parent)
308 308 sys.stdout.set_parent(parent)
309 309 sys.stderr.set_parent(parent)
310 310 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
311 311 if bound:
312 312 working = self.user_ns
313 313 suffix = str(msg_id).replace("-","")
314 314 prefix = "_"
315 315
316 316 else:
317 317 working = dict()
318 318 suffix = prefix = "_" # prevent keyword collisions with lambda
319 319 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
320 320 # if f.fun
321 321 fname = getattr(f, '__name__', 'f')
322 322
323 323 fname = prefix+fname.strip('<>')+suffix
324 324 argname = prefix+"args"+suffix
325 325 kwargname = prefix+"kwargs"+suffix
326 326 resultname = prefix+"result"+suffix
327 327
328 328 ns = { fname : f, argname : args, kwargname : kwargs }
329 329 # print ns
330 330 working.update(ns)
331 331 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
332 332 exec code in working, working
333 333 result = working.get(resultname)
334 334 # clear the namespace
335 335 if bound:
336 336 for key in ns.iterkeys():
337 337 self.user_ns.pop(key)
338 338 else:
339 339 del working
340 340
341 341 packed_result,buf = serialize_object(result)
342 342 result_buf = [packed_result]+buf
343 343 except:
344 344 exc_content = self._wrap_exception('apply')
345 345 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
346 346 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
347 347 ident='%s.pyerr'%self.prefix)
348 348 reply_content = exc_content
349 349 result_buf = []
350 350
351 351 if exc_content['ename'] == 'UnmetDependency':
352 352 sub['dependencies_met'] = False
353 353 else:
354 354 reply_content = {'status' : 'ok'}
355 355
356 356 # put 'ok'/'error' status in header, for scheduler introspection:
357 357 sub['status'] = reply_content['status']
358 358
359 359 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
360 360 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
361 361
362 362 # if reply_msg['content']['status'] == u'error':
363 363 # self.abort_queues()
364 364
365 365 def dispatch_queue(self, stream, msg):
366 366 self.control_stream.flush()
367 367 idents,msg = self.session.feed_identities(msg, copy=False)
368 368 try:
369 369 msg = self.session.unpack_message(msg, content=True, copy=False)
370 370 except:
371 371 self.log.error("Invalid Message", exc_info=True)
372 372 return
373 373
374 374
375 375 header = msg['header']
376 376 msg_id = header['msg_id']
377 377 if self.check_aborted(msg_id):
378 378 self.aborted.remove(msg_id)
379 379 # is it safe to assume a msg_id will not be resubmitted?
380 380 reply_type = msg['msg_type'].split('_')[0] + '_reply'
381 381 reply_msg = self.session.send(stream, reply_type,
382 382 content={'status' : 'aborted'}, parent=msg, ident=idents)
383 383 return
384 384 handler = self.shell_handlers.get(msg['msg_type'], None)
385 385 if handler is None:
386 386 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
387 387 else:
388 388 handler(stream, idents, msg)
389 389
390 390 def start(self):
391 391 #### stream mode:
392 392 if self.control_stream:
393 393 self.control_stream.on_recv(self.dispatch_control, copy=False)
394 394 self.control_stream.on_err(printer)
395 395
396 396 def make_dispatcher(stream):
397 397 def dispatcher(msg):
398 398 return self.dispatch_queue(stream, msg)
399 399 return dispatcher
400 400
401 401 for s in self.shell_streams:
402 402 s.on_recv(make_dispatcher(s), copy=False)
403 403 s.on_err(printer)
404 404
405 405 if self.iopub_stream:
406 406 self.iopub_stream.on_err(printer)
407 407
408 408 #### while True mode:
409 409 # while True:
410 410 # idle = True
411 411 # try:
412 412 # msg = self.shell_stream.socket.recv_multipart(
413 413 # zmq.NOBLOCK, copy=False)
414 414 # except zmq.ZMQError, e:
415 415 # if e.errno != zmq.EAGAIN:
416 416 # raise e
417 417 # else:
418 418 # idle=False
419 419 # self.dispatch_queue(self.shell_stream, msg)
420 420 #
421 421 # if not self.task_stream.empty():
422 422 # idle=False
423 423 # msg = self.task_stream.recv_multipart()
424 424 # self.dispatch_queue(self.task_stream, msg)
425 425 # if idle:
426 426 # # don't busywait
427 427 # time.sleep(1e-3)
428 428
429 429 def make_kernel(int_id, identity, control_addr, shell_addrs, iopub_addr, hb_addrs,
430 430 client_addr=None, loop=None, context=None, key=None,
431 431 out_stream_factory=OutStream, display_hook_factory=DisplayHook):
432 432 """NO LONGER IN USE"""
433 433 # create loop, context, and session:
434 434 if loop is None:
435 435 loop = ioloop.IOLoop.instance()
436 436 if context is None:
437 437 context = zmq.Context()
438 438 c = context
439 439 session = StreamSession(key=key)
440 440 # print (session.key)
441 441 # print (control_addr, shell_addrs, iopub_addr, hb_addrs)
442 442
443 443 # create Control Stream
444 444 control_stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
445 445 control_stream.setsockopt(zmq.IDENTITY, identity)
446 446 control_stream.connect(control_addr)
447 447
448 448 # create Shell Streams (MUX, Task, etc.):
449 449 shell_streams = []
450 450 for addr in shell_addrs:
451 451 stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
452 452 stream.setsockopt(zmq.IDENTITY, identity)
453 453 stream.connect(addr)
454 454 shell_streams.append(stream)
455 455
456 456 # create iopub stream:
457 457 iopub_stream = zmqstream.ZMQStream(c.socket(zmq.PUB), loop)
458 458 iopub_stream.setsockopt(zmq.IDENTITY, identity)
459 459 iopub_stream.connect(iopub_addr)
460 460
461 461 # Redirect input streams and set a display hook.
462 462 if out_stream_factory:
463 463 sys.stdout = out_stream_factory(session, iopub_stream, u'stdout')
464 464 sys.stdout.topic = 'engine.%i.stdout'%int_id
465 465 sys.stderr = out_stream_factory(session, iopub_stream, u'stderr')
466 466 sys.stderr.topic = 'engine.%i.stderr'%int_id
467 467 if display_hook_factory:
468 468 sys.displayhook = display_hook_factory(session, iopub_stream)
469 469 sys.displayhook.topic = 'engine.%i.pyout'%int_id
470 470
471 471
472 472 # launch heartbeat
473 473 heart = heartmonitor.Heart(*map(str, hb_addrs), heart_id=identity)
474 474 heart.start()
475 475
476 476 # create (optional) Client
477 477 if client_addr:
478 478 client = Client(client_addr, username=identity)
479 479 else:
480 480 client = None
481 481
482 482 kernel = Kernel(id=int_id, session=session, control_stream=control_stream,
483 483 shell_streams=shell_streams, iopub_stream=iopub_stream,
484 484 client=client, loop=loop)
485 485 kernel.start()
486 486 return loop, c, kernel
487 487
@@ -1,45 +1,44 b''
1 1 """toplevel setup/teardown for parallel tests."""
2 2
3 3 import time
4 4 from subprocess import Popen, PIPE
5 5
6 6 from IPython.zmq.parallel.ipcluster import launch_process
7 7 from IPython.zmq.parallel.entry_point import select_random_ports
8 8
9 9 processes = []
10 10
11 11 # nose setup/teardown
12 12
13 13 def setup():
14 14 cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=PIPE, stdin=PIPE, stderr=PIPE)
15 15 processes.append(cp)
16 16 time.sleep(.5)
17 17 add_engine()
18 18 time.sleep(3)
19 19
20 20 def add_engine(profile='iptest'):
21 21 ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=PIPE, stdin=PIPE, stderr=PIPE)
22 22 # ep.start()
23 23 processes.append(ep)
24 24 return ep
25 25
26 26 def teardown():
27 27 time.sleep(1)
28 28 while processes:
29 29 p = processes.pop()
30 30 if p.poll() is None:
31 31 try:
32 print 'terminating'
33 32 p.terminate()
34 33 except Exception, e:
35 34 print e
36 35 pass
37 36 if p.poll() is None:
38 37 time.sleep(.25)
39 38 if p.poll() is None:
40 39 try:
41 40 print 'killing'
42 41 p.kill()
43 42 except:
44 43 print "couldn't shutdown process: ",p
45 44
@@ -1,96 +1,98 b''
1 1 import time
2 2 from signal import SIGINT
3 3 from multiprocessing import Process
4 4
5 5 from nose import SkipTest
6 6
7 7 from zmq.tests import BaseZMQTestCase
8 8
9 9 from IPython.external.decorator import decorator
10 10
11 11 from IPython.zmq.parallel import error
12 12 from IPython.zmq.parallel.client import Client
13 13 from IPython.zmq.parallel.ipcluster import launch_process
14 14 from IPython.zmq.parallel.entry_point import select_random_ports
15 15 from IPython.zmq.parallel.tests import processes,add_engine
16 16
17 17 # simple tasks for use in apply tests
18 18
19 19 def segfault():
20 """"""
20 """this will segfault"""
21 21 import ctypes
22 22 ctypes.memset(-1,0,1)
23 23
24 24 def wait(n):
25 25 """sleep for a time"""
26 26 import time
27 27 time.sleep(n)
28 28 return n
29 29
30 30 def raiser(eclass):
31 31 """raise an exception"""
32 32 raise eclass()
33 33
34 34 # test decorator for skipping tests when libraries are unavailable
35 35 def skip_without(*names):
36 36 """skip a test if some names are not importable"""
37 37 @decorator
38 38 def skip_without_names(f, *args, **kwargs):
39 39 """decorator to skip tests in the absence of numpy."""
40 40 for name in names:
41 41 try:
42 42 __import__(name)
43 43 except ImportError:
44 44 raise SkipTest
45 45 return f(*args, **kwargs)
46 46 return skip_without_names
47 47
48 48
49 49 class ClusterTestCase(BaseZMQTestCase):
50 50
51 51 def add_engines(self, n=1, block=True):
52 52 """add multiple engines to our cluster"""
53 53 for i in range(n):
54 54 self.engines.append(add_engine())
55 55 if block:
56 56 self.wait_on_engines()
57 57
58 58 def wait_on_engines(self, timeout=5):
59 59 """wait for our engines to connect."""
60 60 n = len(self.engines)+self.base_engine_count
61 61 tic = time.time()
62 62 while time.time()-tic < timeout and len(self.client.ids) < n:
63 63 time.sleep(0.1)
64 64
65 65 assert not self.client.ids < n, "waiting for engines timed out"
66 66
67 67 def connect_client(self):
68 68 """connect a client with my Context, and track its sockets for cleanup"""
69 69 c = Client(profile='iptest',context=self.context)
70 70 for name in filter(lambda n:n.endswith('socket'), dir(c)):
71 71 self.sockets.append(getattr(c, name))
72 72 return c
73 73
74 74 def assertRaisesRemote(self, etype, f, *args, **kwargs):
75 75 try:
76 f(*args, **kwargs)
77 except error.CompositeError as e:
78 e.raise_exception()
76 try:
77 f(*args, **kwargs)
78 except error.CompositeError as e:
79 e.raise_exception()
79 80 except error.RemoteError as e:
80 81 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(e.ename, etype.__name__))
81 82 else:
82 83 self.fail("should have raised a RemoteError")
83 84
84 85 def setUp(self):
85 86 BaseZMQTestCase.setUp(self)
86 87 self.client = self.connect_client()
87 88 self.base_engine_count=len(self.client.ids)
88 89 self.engines=[]
89 90
90 def tearDown(self):
91 [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ]
92 # while len(self.client.ids) > self.base_engine_count:
93 # time.sleep(.1)
94 del self.engines
95 BaseZMQTestCase.tearDown(self)
91 # def tearDown(self):
92 # [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ]
93 # [ e.wait() for e in self.engines ]
94 # while len(self.client.ids) > self.base_engine_count:
95 # time.sleep(.1)
96 # del self.engines
97 # BaseZMQTestCase.tearDown(self)
96 98 No newline at end of file
@@ -1,134 +1,165 b''
1 1 import time
2 2
3 3 import nose.tools as nt
4 4
5 from IPython.zmq.parallel.asyncresult import AsyncResult
5 from IPython.zmq.parallel import client as clientmod
6 from IPython.zmq.parallel import error
7 from IPython.zmq.parallel.asyncresult import AsyncResult, AsyncHubResult
6 8 from IPython.zmq.parallel.view import LoadBalancedView, DirectView
7 9
8 from clienttest import ClusterTestCase, segfault
10 from clienttest import ClusterTestCase, segfault, wait
9 11
10 12 class TestClient(ClusterTestCase):
11 13
12 14 def test_ids(self):
13 self.assertEquals(len(self.client.ids), 1)
15 n = len(self.client.ids)
14 16 self.add_engines(3)
15 self.assertEquals(len(self.client.ids), 4)
17 self.assertEquals(len(self.client.ids), n+3)
18 self.assertTrue
16 19
17 20 def test_segfault(self):
21 """test graceful handling of engine death"""
18 22 self.add_engines(1)
19 23 eid = self.client.ids[-1]
20 self.client[eid].apply(segfault)
24 ar = self.client.apply(segfault, block=False)
25 self.assertRaisesRemote(error.EngineError, ar.get)
26 eid = ar.engine_id
21 27 while eid in self.client.ids:
22 28 time.sleep(.01)
23 29 self.client.spin()
24 30
25 31 def test_view_indexing(self):
26 self.add_engines(4)
32 """test index access for views"""
33 self.add_engines(2)
27 34 targets = self.client._build_targets('all')[-1]
28 35 v = self.client[:]
29 36 self.assertEquals(v.targets, targets)
30 37 t = self.client.ids[2]
31 38 v = self.client[t]
32 39 self.assert_(isinstance(v, DirectView))
33 40 self.assertEquals(v.targets, t)
34 41 t = self.client.ids[2:4]
35 42 v = self.client[t]
36 43 self.assert_(isinstance(v, DirectView))
37 44 self.assertEquals(v.targets, t)
38 45 v = self.client[::2]
39 46 self.assert_(isinstance(v, DirectView))
40 47 self.assertEquals(v.targets, targets[::2])
41 48 v = self.client[1::3]
42 49 self.assert_(isinstance(v, DirectView))
43 50 self.assertEquals(v.targets, targets[1::3])
44 51 v = self.client[:-3]
45 52 self.assert_(isinstance(v, DirectView))
46 53 self.assertEquals(v.targets, targets[:-3])
47 54 v = self.client[-1]
48 55 self.assert_(isinstance(v, DirectView))
49 56 self.assertEquals(v.targets, targets[-1])
50 57 nt.assert_raises(TypeError, lambda : self.client[None])
51 58
52 59 def test_view_cache(self):
53 60 """test that multiple view requests return the same object"""
54 61 v = self.client[:2]
55 62 v2 =self.client[:2]
56 63 self.assertTrue(v is v2)
57 64 v = self.client.view()
58 65 v2 = self.client.view(balanced=True)
59 66 self.assertTrue(v is v2)
60 67
61 68 def test_targets(self):
62 69 """test various valid targets arguments"""
63 pass
70 build = self.client._build_targets
71 ids = self.client.ids
72 idents,targets = build(None)
73 self.assertEquals(ids, targets)
64 74
65 75 def test_clear(self):
66 76 """test clear behavior"""
67 # self.add_engines(4)
68 # self.client.push()
77 self.add_engines(2)
78 self.client.block=True
79 self.client.push(dict(a=5))
80 self.client.pull('a')
81 id0 = self.client.ids[-1]
82 self.client.clear(targets=id0)
83 self.client.pull('a', targets=self.client.ids[:-1])
84 self.assertRaisesRemote(NameError, self.client.pull, 'a')
85 self.client.clear()
86 for i in self.client.ids:
87 self.assertRaisesRemote(NameError, self.client.pull, 'a', targets=i)
88
69 89
70 90 def test_push_pull(self):
71 91 """test pushing and pulling"""
72 92 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
73 self.add_engines(4)
93 self.add_engines(2)
74 94 push = self.client.push
75 95 pull = self.client.pull
76 96 self.client.block=True
77 97 nengines = len(self.client)
78 98 push({'data':data}, targets=0)
79 99 d = pull('data', targets=0)
80 100 self.assertEquals(d, data)
81 101 push({'data':data})
82 102 d = pull('data')
83 103 self.assertEquals(d, nengines*[data])
84 104 ar = push({'data':data}, block=False)
85 105 self.assertTrue(isinstance(ar, AsyncResult))
86 106 r = ar.get()
87 107 ar = pull('data', block=False)
88 108 self.assertTrue(isinstance(ar, AsyncResult))
89 109 r = ar.get()
90 110 self.assertEquals(r, nengines*[data])
91 111 push(dict(a=10,b=20))
92 112 r = pull(('a','b'))
93 113 self.assertEquals(r, nengines*[[10,20]])
94 114
95 115 def test_push_pull_function(self):
96 116 "test pushing and pulling functions"
97 117 def testf(x):
98 118 return 2.0*x
99 119
100 120 self.add_engines(4)
101 121 self.client.block=True
102 122 push = self.client.push
103 123 pull = self.client.pull
104 124 execute = self.client.execute
105 125 push({'testf':testf}, targets=0)
106 126 r = pull('testf', targets=0)
107 127 self.assertEqual(r(1.0), testf(1.0))
108 128 execute('r = testf(10)', targets=0)
109 129 r = pull('r', targets=0)
110 130 self.assertEquals(r, testf(10))
111 131 ar = push({'testf':testf}, block=False)
112 132 ar.get()
113 133 ar = pull('testf', block=False)
114 134 rlist = ar.get()
115 135 for r in rlist:
116 136 self.assertEqual(r(1.0), testf(1.0))
117 137 execute("def g(x): return x*x", targets=0)
118 138 r = pull(('testf','g'),targets=0)
119 139 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
120 140
121 141 def test_push_function_globals(self):
122 142 """test that pushed functions have access to globals"""
123 143 def geta():
124 144 return a
125 145 self.add_engines(1)
126 146 v = self.client[-1]
127 147 v.block=True
128 148 v['f'] = geta
129 149 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
130 150 v.execute('a=5')
131 151 v.execute('b=f()')
132 152 self.assertEquals(v['b'], 5)
133 153
154 def test_get_result(self):
155 """test getting results from the Hub."""
156 c = clientmod.Client(profile='iptest')
157 t = self.client.ids[-1]
158 ar = c.apply(wait, (1,), block=False, targets=t)
159 time.sleep(.25)
160 ahr = self.client.get_result(ar.msg_ids)
161 self.assertTrue(isinstance(ahr, AsyncHubResult))
162 self.assertEquals(ahr.get(), ar.get())
163 ar2 = self.client.get_result(ar.msg_ids)
164 self.assertFalse(isinstance(ar2, AsyncHubResult))
134 165 No newline at end of file
@@ -1,639 +1,639 b''
1 1 """Views of remote engines"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 from IPython.testing import decorators as testdec
13 14 from IPython.utils.traitlets import HasTraits, Bool, List, Dict, Set, Int, Instance
14 15
15 16 from IPython.external.decorator import decorator
16 17 from IPython.zmq.parallel.asyncresult import AsyncResult
17 18 from IPython.zmq.parallel.dependency import Dependency
18 19 from IPython.zmq.parallel.remotefunction import ParallelFunction, parallel, remote
19 20
20 21 #-----------------------------------------------------------------------------
21 22 # Decorators
22 23 #-----------------------------------------------------------------------------
23 24
24 25 @decorator
25 26 def myblock(f, self, *args, **kwargs):
26 27 """override client.block with self.block during a call"""
27 28 block = self.client.block
28 29 self.client.block = self.block
29 30 try:
30 31 ret = f(self, *args, **kwargs)
31 32 finally:
32 33 self.client.block = block
33 34 return ret
34 35
35 36 @decorator
36 37 def save_ids(f, self, *args, **kwargs):
37 38 """Keep our history and outstanding attributes up to date after a method call."""
38 39 n_previous = len(self.client.history)
39 40 ret = f(self, *args, **kwargs)
40 41 nmsgs = len(self.client.history) - n_previous
41 42 msg_ids = self.client.history[-nmsgs:]
42 43 self.history.extend(msg_ids)
43 44 map(self.outstanding.add, msg_ids)
44 45 return ret
45 46
46 47 @decorator
47 48 def sync_results(f, self, *args, **kwargs):
48 49 """sync relevant results from self.client to our results attribute."""
49 50 ret = f(self, *args, **kwargs)
50 51 delta = self.outstanding.difference(self.client.outstanding)
51 52 completed = self.outstanding.intersection(delta)
52 53 self.outstanding = self.outstanding.difference(completed)
53 54 for msg_id in completed:
54 55 self.results[msg_id] = self.client.results[msg_id]
55 56 return ret
56 57
57 58 @decorator
58 59 def spin_after(f, self, *args, **kwargs):
59 60 """call spin after the method."""
60 61 ret = f(self, *args, **kwargs)
61 62 self.spin()
62 63 return ret
63 64
64 65 #-----------------------------------------------------------------------------
65 66 # Classes
66 67 #-----------------------------------------------------------------------------
67 68
68 69 class View(HasTraits):
69 70 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
70 71
71 72 Don't use this class, use subclasses.
72 73 """
73 74 block=Bool(False)
74 75 bound=Bool(False)
75 76 history=List()
76 77 outstanding = Set()
77 78 results = Dict()
78 79 client = Instance('IPython.zmq.parallel.client.Client')
79 80
80 81 _ntargets = Int(1)
81 82 _balanced = Bool(False)
82 83 _default_names = List(['block', 'bound'])
83 84 _targets = None
84 85
85 86 def __init__(self, client=None, targets=None):
86 87 super(View, self).__init__(client=client)
87 88 self._targets = targets
88 89 self._ntargets = 1 if isinstance(targets, (int,type(None))) else len(targets)
89 90 self.block = client.block
90 91
91 92 for name in self._default_names:
92 93 setattr(self, name, getattr(self, name, None))
93 94
94 95 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
95 96
96 97
97 98 def __repr__(self):
98 99 strtargets = str(self._targets)
99 100 if len(strtargets) > 16:
100 101 strtargets = strtargets[:12]+'...]'
101 102 return "<%s %s>"%(self.__class__.__name__, strtargets)
102 103
103 104 @property
104 105 def targets(self):
105 106 return self._targets
106 107
107 108 @targets.setter
108 109 def targets(self, value):
109 110 raise AttributeError("Cannot set View `targets` after construction!")
110 111
111 112 @property
112 113 def balanced(self):
113 114 return self._balanced
114 115
115 116 @balanced.setter
116 117 def balanced(self, value):
117 118 raise AttributeError("Cannot set View `balanced` after construction!")
118 119
119 120 def _defaults(self, *excludes):
120 121 """return dict of our default attributes, excluding names given."""
121 122 d = dict(balanced=self._balanced, targets=self._targets)
122 123 for name in self._default_names:
123 124 if name not in excludes:
124 125 d[name] = getattr(self, name)
125 126 return d
126 127
127 128 def set_flags(self, **kwargs):
128 129 """set my attribute flags by keyword.
129 130
130 131 A View is a wrapper for the Client's apply method, but
131 132 with attributes that specify keyword arguments, those attributes
132 133 can be set by keyword argument with this method.
133 134
134 135 Parameters
135 136 ----------
136 137
137 138 block : bool
138 139 whether to wait for results
139 140 bound : bool
140 141 whether to use the client's namespace
141 142 """
142 143 for key in kwargs:
143 144 if key not in self._default_names:
144 145 raise KeyError("Invalid name: %r"%key)
145 146 for name in ('block', 'bound'):
146 147 if name in kwargs:
147 148 setattr(self, name, kwargs[name])
148 149
149 150 #----------------------------------------------------------------
150 151 # wrappers for client methods:
151 152 #----------------------------------------------------------------
152 153 @sync_results
153 154 def spin(self):
154 155 """spin the client, and sync"""
155 156 self.client.spin()
156 157
157 158 @sync_results
158 159 @save_ids
159 160 def apply(self, f, *args, **kwargs):
160 161 """calls f(*args, **kwargs) on remote engines, returning the result.
161 162
162 163 This method does not involve the engine's namespace.
163 164
164 165 if self.block is False:
165 166 returns msg_id
166 167 else:
167 168 returns actual result of f(*args, **kwargs)
168 169 """
169 170 return self.client.apply(f, args, kwargs, **self._defaults())
170 171
171 172 @save_ids
172 173 def apply_async(self, f, *args, **kwargs):
173 174 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
174 175
175 176 This method does not involve the engine's namespace.
176 177
177 178 returns msg_id
178 179 """
179 180 d = self._defaults('block', 'bound')
180 181 return self.client.apply(f,args,kwargs, block=False, bound=False, **d)
181 182
182 183 @spin_after
183 184 @save_ids
184 185 def apply_sync(self, f, *args, **kwargs):
185 186 """calls f(*args, **kwargs) on remote engines in a blocking manner,
186 187 returning the result.
187 188
188 189 This method does not involve the engine's namespace.
189 190
190 191 returns: actual result of f(*args, **kwargs)
191 192 """
192 193 d = self._defaults('block', 'bound')
193 194 return self.client.apply(f,args,kwargs, block=True, bound=False, **d)
194 195
195 196 # @sync_results
196 197 # @save_ids
197 198 # def apply_bound(self, f, *args, **kwargs):
198 199 # """calls f(*args, **kwargs) bound to engine namespace(s).
199 200 #
200 201 # if self.block is False:
201 202 # returns msg_id
202 203 # else:
203 204 # returns actual result of f(*args, **kwargs)
204 205 #
205 206 # This method has access to the targets' namespace via globals()
206 207 #
207 208 # """
208 209 # d = self._defaults('bound')
209 210 # return self.client.apply(f, args, kwargs, bound=True, **d)
210 211 #
211 212 @sync_results
212 213 @save_ids
213 214 def apply_async_bound(self, f, *args, **kwargs):
214 215 """calls f(*args, **kwargs) bound to engine namespace(s)
215 216 in a nonblocking manner.
216 217
217 218 returns: msg_id
218 219
219 220 This method has access to the targets' namespace via globals()
220 221
221 222 """
222 223 d = self._defaults('block', 'bound')
223 224 return self.client.apply(f, args, kwargs, block=False, bound=True, **d)
224 225
225 226 @spin_after
226 227 @save_ids
227 228 def apply_sync_bound(self, f, *args, **kwargs):
228 229 """calls f(*args, **kwargs) bound to engine namespace(s), waiting for the result.
229 230
230 231 returns: actual result of f(*args, **kwargs)
231 232
232 233 This method has access to the targets' namespace via globals()
233 234
234 235 """
235 236 d = self._defaults('block', 'bound')
236 237 return self.client.apply(f, args, kwargs, block=True, bound=True, **d)
237 238
238 239 def abort(self, jobs=None, block=None):
239 240 """Abort jobs on my engines.
240 241
241 242 Parameters
242 243 ----------
243 244
244 245 jobs : None, str, list of strs, optional
245 246 if None: abort all jobs.
246 247 else: abort specific msg_id(s).
247 248 """
248 249 block = block if block is not None else self.block
249 250 return self.client.abort(jobs=jobs, targets=self._targets, block=block)
250 251
251 252 def queue_status(self, verbose=False):
252 253 """Fetch the Queue status of my engines"""
253 254 return self.client.queue_status(targets=self._targets, verbose=verbose)
254 255
255 256 def purge_results(self, jobs=[], targets=[]):
256 257 """Instruct the controller to forget specific results."""
257 258 if targets is None or targets == 'all':
258 259 targets = self._targets
259 260 return self.client.purge_results(jobs=jobs, targets=targets)
260 261
261 262 @spin_after
262 263 def get_result(self, indices_or_msg_ids=None):
263 264 """return one or more results, specified by history index or msg_id.
264 265
265 266 See client.get_result for details.
266 267
267 268 """
268 269
269 270 if indices_or_msg_ids is None:
270 271 indices_or_msg_ids = -1
271 272 if isinstance(indices_or_msg_ids, int):
272 273 indices_or_msg_ids = self.history[indices_or_msg_ids]
273 274 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
274 275 indices_or_msg_ids = list(indices_or_msg_ids)
275 276 for i,index in enumerate(indices_or_msg_ids):
276 277 if isinstance(index, int):
277 278 indices_or_msg_ids[i] = self.history[index]
278 279 return self.client.get_result(indices_or_msg_ids)
279 280
280 281 #-------------------------------------------------------------------
281 282 # Map
282 283 #-------------------------------------------------------------------
283 284
284 285 def map(self, f, *sequences, **kwargs):
285 286 """override in subclasses"""
286 287 raise NotImplementedError
287 288
288 289 def map_async(self, f, *sequences, **kwargs):
289 290 """Parallel version of builtin `map`, using this view's engines.
290 291
291 292 This is equivalent to map(...block=False)
292 293
293 294 See `self.map` for details.
294 295 """
295 296 if 'block' in kwargs:
296 297 raise TypeError("map_async doesn't take a `block` keyword argument.")
297 298 kwargs['block'] = False
298 299 return self.map(f,*sequences,**kwargs)
299 300
300 301 def map_sync(self, f, *sequences, **kwargs):
301 302 """Parallel version of builtin `map`, using this view's engines.
302 303
303 304 This is equivalent to map(...block=True)
304 305
305 306 See `self.map` for details.
306 307 """
307 308 if 'block' in kwargs:
308 309 raise TypeError("map_sync doesn't take a `block` keyword argument.")
309 310 kwargs['block'] = True
310 311 return self.map(f,*sequences,**kwargs)
311 312
312 313 def imap(self, f, *sequences, **kwargs):
313 314 """Parallel version of `itertools.imap`.
314 315
315 316 See `self.map` for details.
316 317 """
317 318
318 319 return iter(self.map_async(f,*sequences, **kwargs))
319 320
320 321 #-------------------------------------------------------------------
321 322 # Decorators
322 323 #-------------------------------------------------------------------
323 324
324 325 def remote(self, bound=True, block=True):
325 326 """Decorator for making a RemoteFunction"""
326 327 return remote(self.client, bound=bound, targets=self._targets, block=block, balanced=self._balanced)
327 328
328 329 def parallel(self, dist='b', bound=True, block=None):
329 330 """Decorator for making a ParallelFunction"""
330 331 block = self.block if block is None else block
331 332 return parallel(self.client, bound=bound, targets=self._targets, block=block, balanced=self._balanced)
332 333
333
334 @testdec.skip_doctest
334 335 class DirectView(View):
335 336 """Direct Multiplexer View of one or more engines.
336 337
337 338 These are created via indexed access to a client:
338 339
339 340 >>> dv_1 = client[1]
340 341 >>> dv_all = client[:]
341 342 >>> dv_even = client[::2]
342 343 >>> dv_some = client[1:3]
343 344
344 345 This object provides dictionary access to engine namespaces:
345 346
346 347 # push a=5:
347 348 >>> dv['a'] = 5
348 349 # pull 'foo':
349 350 >>> db['foo']
350 351
351 352 """
352 353
353 354 def __init__(self, client=None, targets=None):
354 355 super(DirectView, self).__init__(client=client, targets=targets)
355 356 self._balanced = False
356 357
357 358 @spin_after
358 359 @save_ids
359 360 def map(self, f, *sequences, **kwargs):
360 361 """view.map(f, *sequences, block=self.block, bound=self.bound) => list|AsyncMapResult
361 362
362 363 Parallel version of builtin `map`, using this View's `targets`.
363 364
364 365 There will be one task per target, so work will be chunked
365 366 if the sequences are longer than `targets`.
366 367
367 368 Results can be iterated as they are ready, but will become available in chunks.
368 369
369 370 Parameters
370 371 ----------
371 372
372 373 f : callable
373 374 function to be mapped
374 375 *sequences: one or more sequences of matching length
375 376 the sequences to be distributed and passed to `f`
376 377 block : bool
377 378 whether to wait for the result or not [default self.block]
378 379 bound : bool
379 380 whether to have access to the engines' namespaces [default self.bound]
380 381
381 382 Returns
382 383 -------
383 384
384 385 if block=False:
385 386 AsyncMapResult
386 387 An object like AsyncResult, but which reassembles the sequence of results
387 388 into a single list. AsyncMapResults can be iterated through before all
388 389 results are complete.
389 390 else:
390 391 list
391 392 the result of map(f,*sequences)
392 393 """
393 394
394 395 block = kwargs.get('block', self.block)
395 396 bound = kwargs.get('bound', self.bound)
396 397 for k in kwargs.keys():
397 398 if k not in ['block', 'bound']:
398 399 raise TypeError("invalid keyword arg, %r"%k)
399 400
400 401 assert len(sequences) > 0, "must have some sequences to map onto!"
401 402 pf = ParallelFunction(self.client, f, block=block, bound=bound,
402 403 targets=self._targets, balanced=False)
403 404 return pf.map(*sequences)
404 405
405 406 @sync_results
406 407 @save_ids
407 408 def execute(self, code, block=True):
408 409 """execute some code on my targets."""
409 410 return self.client.execute(code, block=block, targets=self._targets)
410 411
411 412 def update(self, ns):
412 413 """update remote namespace with dict `ns`"""
413 414 return self.client.push(ns, targets=self._targets, block=self.block)
414 415
415 416 push = update
416
417
417 418 def get(self, key_s):
418 419 """get object(s) by `key_s` from remote namespace
419 420 will return one object if it is a key.
420 421 It also takes a list of keys, and will return a list of objects."""
421 422 # block = block if block is not None else self.block
422 423 return self.client.pull(key_s, block=True, targets=self._targets)
423 424
424 425 @sync_results
425 426 @save_ids
426 427 def pull(self, key_s, block=True):
427 428 """get object(s) by `key_s` from remote namespace
428 429 will return one object if it is a key.
429 430 It also takes a list of keys, and will return a list of objects."""
430 431 block = block if block is not None else self.block
431 432 return self.client.pull(key_s, block=block, targets=self._targets)
432 433
433 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None):
434 def scatter(self, key, seq, dist='b', flatten=False, block=None):
434 435 """
435 436 Partition a Python sequence and send the partitions to a set of engines.
436 437 """
437 438 block = block if block is not None else self.block
438 targets = targets if targets is not None else self._targets
439 439
440 440 return self.client.scatter(key, seq, dist=dist, flatten=flatten,
441 targets=targets, block=block)
441 targets=self._targets, block=block)
442 442
443 443 @sync_results
444 444 @save_ids
445 def gather(self, key, dist='b', targets=None, block=None):
445 def gather(self, key, dist='b', block=None):
446 446 """
447 447 Gather a partitioned sequence on a set of engines as a single local seq.
448 448 """
449 449 block = block if block is not None else self.block
450 targets = targets if targets is not None else self._targets
451 450
452 return self.client.gather(key, dist=dist, targets=targets, block=block)
451 return self.client.gather(key, dist=dist, targets=self._targets, block=block)
453 452
454 453 def __getitem__(self, key):
455 454 return self.get(key)
456 455
457 456 def __setitem__(self,key, value):
458 457 self.update({key:value})
459 458
460 459 def clear(self, block=False):
461 460 """Clear the remote namespaces on my engines."""
462 461 block = block if block is not None else self.block
463 462 return self.client.clear(targets=self._targets, block=block)
464 463
465 464 def kill(self, block=True):
466 465 """Kill my engines."""
467 466 block = block if block is not None else self.block
468 467 return self.client.kill(targets=self._targets, block=block)
469 468
470 469 #----------------------------------------
471 470 # activate for %px,%autopx magics
472 471 #----------------------------------------
473 472 def activate(self):
474 473 """Make this `View` active for parallel magic commands.
475 474
476 475 IPython has a magic command syntax to work with `MultiEngineClient` objects.
477 476 In a given IPython session there is a single active one. While
478 477 there can be many `Views` created and used by the user,
479 478 there is only one active one. The active `View` is used whenever
480 479 the magic commands %px and %autopx are used.
481 480
482 481 The activate() method is called on a given `View` to make it
483 482 active. Once this has been done, the magic commands can be used.
484 483 """
485 484
486 485 try:
487 486 # This is injected into __builtins__.
488 487 ip = get_ipython()
489 488 except NameError:
490 489 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
491 490 else:
492 491 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
493 492 if pmagic is not None:
494 493 pmagic.active_multiengine_client = self
495 494 else:
496 495 print "You must first load the parallelmagic extension " \
497 496 "by doing '%load_ext parallelmagic'"
498 497
499
498
499 @testdec.skip_doctest
500 500 class LoadBalancedView(View):
501 501 """An load-balancing View that only executes via the Task scheduler.
502 502
503 503 Load-balanced views can be created with the client's `view` method:
504 504
505 505 >>> v = client.view(balanced=True)
506 506
507 507 or targets can be specified, to restrict the potential destinations:
508 508
509 509 >>> v = client.view([1,3],balanced=True)
510 510
511 511 which would restrict loadbalancing to between engines 1 and 3.
512 512
513 513 """
514 514
515 515 _default_names = ['block', 'bound', 'follow', 'after', 'timeout']
516 516
517 517 def __init__(self, client=None, targets=None):
518 518 super(LoadBalancedView, self).__init__(client=client, targets=targets)
519 519 self._ntargets = 1
520 520 self._balanced = True
521 521
522 522 def _validate_dependency(self, dep):
523 523 """validate a dependency.
524 524
525 525 For use in `set_flags`.
526 526 """
527 527 if dep is None or isinstance(dep, (str, AsyncResult, Dependency)):
528 528 return True
529 529 elif isinstance(dep, (list,set, tuple)):
530 530 for d in dep:
531 531 if not isinstance(d, str, AsyncResult):
532 532 return False
533 533 elif isinstance(dep, dict):
534 534 if set(dep.keys()) != set(Dependency().as_dict().keys()):
535 535 return False
536 536 if not isinstance(dep['msg_ids'], list):
537 537 return False
538 538 for d in dep['msg_ids']:
539 539 if not isinstance(d, str):
540 540 return False
541 541 else:
542 542 return False
543 543
544 544 def set_flags(self, **kwargs):
545 545 """set my attribute flags by keyword.
546 546
547 547 A View is a wrapper for the Client's apply method, but with attributes
548 548 that specify keyword arguments, those attributes can be set by keyword
549 549 argument with this method.
550 550
551 551 Parameters
552 552 ----------
553 553
554 554 block : bool
555 555 whether to wait for results
556 556 bound : bool
557 557 whether to use the engine's namespace
558 558 follow : Dependency, list, msg_id, AsyncResult
559 559 the location dependencies of tasks
560 560 after : Dependency, list, msg_id, AsyncResult
561 561 the time dependencies of tasks
562 562 timeout : int,None
563 563 the timeout to be used for tasks
564 564 """
565 565
566 566 super(LoadBalancedView, self).set_flags(**kwargs)
567 567 for name in ('follow', 'after'):
568 568 if name in kwargs:
569 569 value = kwargs[name]
570 570 if self._validate_dependency(value):
571 571 setattr(self, name, value)
572 572 else:
573 573 raise ValueError("Invalid dependency: %r"%value)
574 574 if 'timeout' in kwargs:
575 575 t = kwargs['timeout']
576 576 if not isinstance(t, (int, long, float, None)):
577 577 raise TypeError("Invalid type for timeout: %r"%type(t))
578 578 if t is not None:
579 579 if t < 0:
580 580 raise ValueError("Invalid timeout: %s"%t)
581 581 self.timeout = t
582 582
583 583 @spin_after
584 584 @save_ids
585 585 def map(self, f, *sequences, **kwargs):
586 586 """view.map(f, *sequences, block=self.block, bound=self.bound, chunk_size=1) => list|AsyncMapResult
587 587
588 588 Parallel version of builtin `map`, load-balanced by this View.
589 589
590 590 `block`, `bound`, and `chunk_size` can be specified by keyword only.
591 591
592 592 Each `chunk_size` elements will be a separate task, and will be
593 593 load-balanced. This lets individual elements be available for iteration
594 594 as soon as they arrive.
595 595
596 596 Parameters
597 597 ----------
598 598
599 599 f : callable
600 600 function to be mapped
601 601 *sequences: one or more sequences of matching length
602 602 the sequences to be distributed and passed to `f`
603 603 block : bool
604 604 whether to wait for the result or not [default self.block]
605 605 bound : bool
606 606 whether to use the engine's namespace [default self.bound]
607 607 chunk_size : int
608 608 how many elements should be in each task [default 1]
609 609
610 610 Returns
611 611 -------
612 612
613 613 if block=False:
614 614 AsyncMapResult
615 615 An object like AsyncResult, but which reassembles the sequence of results
616 616 into a single list. AsyncMapResults can be iterated through before all
617 617 results are complete.
618 618 else:
619 619 the result of map(f,*sequences)
620 620
621 621 """
622 622
623 623 # default
624 624 block = kwargs.get('block', self.block)
625 625 bound = kwargs.get('bound', self.bound)
626 626 chunk_size = kwargs.get('chunk_size', 1)
627 627
628 628 keyset = set(kwargs.keys())
629 629 extra_keys = keyset.difference_update(set(['block', 'bound', 'chunk_size']))
630 630 if extra_keys:
631 631 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
632 632
633 633 assert len(sequences) > 0, "must have some sequences to map onto!"
634 634
635 635 pf = ParallelFunction(self.client, f, block=block, bound=bound,
636 636 targets=self._targets, balanced=True,
637 637 chunk_size=chunk_size)
638 638 return pf.map(*sequences)
639 639
General Comments 0
You need to be logged in to leave comments. Login now