##// END OF EJS Templates
handle targets='all' in remotefunction...
MinRK -
Show More
@@ -1,1434 +1,1440 b''
1 """A semi-synchronous Client for the ZMQ cluster
1 """A semi-synchronous Client for the ZMQ cluster
2
2
3 Authors:
3 Authors:
4
4
5 * MinRK
5 * MinRK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import os
18 import os
19 import json
19 import json
20 import sys
20 import sys
21 import time
21 import time
22 import warnings
22 import warnings
23 from datetime import datetime
23 from datetime import datetime
24 from getpass import getpass
24 from getpass import getpass
25 from pprint import pprint
25 from pprint import pprint
26
26
27 pjoin = os.path.join
27 pjoin = os.path.join
28
28
29 import zmq
29 import zmq
30 # from zmq.eventloop import ioloop, zmqstream
30 # from zmq.eventloop import ioloop, zmqstream
31
31
32 from IPython.config.configurable import MultipleInstanceError
32 from IPython.config.configurable import MultipleInstanceError
33 from IPython.core.application import BaseIPythonApplication
33 from IPython.core.application import BaseIPythonApplication
34
34
35 from IPython.utils.jsonutil import rekey
35 from IPython.utils.jsonutil import rekey
36 from IPython.utils.localinterfaces import LOCAL_IPS
36 from IPython.utils.localinterfaces import LOCAL_IPS
37 from IPython.utils.path import get_ipython_dir
37 from IPython.utils.path import get_ipython_dir
38 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
38 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
39 Dict, List, Bool, Set)
39 Dict, List, Bool, Set)
40 from IPython.external.decorator import decorator
40 from IPython.external.decorator import decorator
41 from IPython.external.ssh import tunnel
41 from IPython.external.ssh import tunnel
42
42
43 from IPython.parallel import error
43 from IPython.parallel import error
44 from IPython.parallel import util
44 from IPython.parallel import util
45
45
46 from IPython.zmq.session import Session, Message
46 from IPython.zmq.session import Session, Message
47
47
48 from .asyncresult import AsyncResult, AsyncHubResult
48 from .asyncresult import AsyncResult, AsyncHubResult
49 from IPython.core.profiledir import ProfileDir, ProfileDirError
49 from IPython.core.profiledir import ProfileDir, ProfileDirError
50 from .view import DirectView, LoadBalancedView
50 from .view import DirectView, LoadBalancedView
51
51
if sys.version_info[0] >= 3:
    # Python 3 compatibility shim: this module uses `xrange` in a couple of
    # 'isinstance' tests (see _build_targets); on py3 there is no xrange,
    # so alias it to the builtin `range`.
    xrange = range
56
56
57 #--------------------------------------------------------------------------
57 #--------------------------------------------------------------------------
58 # Decorators for Client methods
58 # Decorators for Client methods
59 #--------------------------------------------------------------------------
59 #--------------------------------------------------------------------------
60
60
@decorator
def spin_first(f, self, *args, **kwargs):
    """Decorator: flush incoming messages via ``self.spin()`` before
    invoking the wrapped method, so it sees up-to-date client state."""
    self.spin()
    result = f(self, *args, **kwargs)
    return result
66
66
67
67
68 #--------------------------------------------------------------------------
68 #--------------------------------------------------------------------------
69 # Classes
69 # Classes
70 #--------------------------------------------------------------------------
70 #--------------------------------------------------------------------------
71
71
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # the complete, fixed set of permitted keys, with their defaults;
        # __setitem__ below rejects anything outside this set
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
            }
        self.update(md)
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # `key in self` is an O(1) dict-containment check, rather than the
        # linear scan of `key in self.iterkeys()`
        if key in self:
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        if key in self:
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self:
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
122
122
123
123
124 class Client(HasTraits):
124 class Client(HasTraits):
125 """A semi-synchronous client to the IPython ZMQ cluster
125 """A semi-synchronous client to the IPython ZMQ cluster
126
126
127 Parameters
127 Parameters
128 ----------
128 ----------
129
129
130 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
130 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
131 Connection information for the Hub's registration. If a json connector
131 Connection information for the Hub's registration. If a json connector
132 file is given, then likely no further configuration is necessary.
132 file is given, then likely no further configuration is necessary.
133 [Default: use profile]
133 [Default: use profile]
134 profile : bytes
134 profile : bytes
135 The name of the Cluster profile to be used to find connector information.
135 The name of the Cluster profile to be used to find connector information.
136 If run from an IPython application, the default profile will be the same
136 If run from an IPython application, the default profile will be the same
137 as the running application, otherwise it will be 'default'.
137 as the running application, otherwise it will be 'default'.
138 context : zmq.Context
138 context : zmq.Context
139 Pass an existing zmq.Context instance, otherwise the client will create its own.
139 Pass an existing zmq.Context instance, otherwise the client will create its own.
140 debug : bool
140 debug : bool
141 flag for lots of message printing for debug purposes
141 flag for lots of message printing for debug purposes
142 timeout : int/float
142 timeout : int/float
143 time (in seconds) to wait for connection replies from the Hub
143 time (in seconds) to wait for connection replies from the Hub
144 [Default: 10]
144 [Default: 10]
145
145
146 #-------------- session related args ----------------
146 #-------------- session related args ----------------
147
147
148 config : Config object
148 config : Config object
149 If specified, this will be relayed to the Session for configuration
149 If specified, this will be relayed to the Session for configuration
150 username : str
150 username : str
151 set username for the session object
151 set username for the session object
152 packer : str (import_string) or callable
152 packer : str (import_string) or callable
153 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
153 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
154 function to serialize messages. Must support same input as
154 function to serialize messages. Must support same input as
155 JSON, and output must be bytes.
155 JSON, and output must be bytes.
156 You can pass a callable directly as `pack`
156 You can pass a callable directly as `pack`
157 unpacker : str (import_string) or callable
157 unpacker : str (import_string) or callable
158 The inverse of packer. Only necessary if packer is specified as *not* one
158 The inverse of packer. Only necessary if packer is specified as *not* one
159 of 'json' or 'pickle'.
159 of 'json' or 'pickle'.
160
160
161 #-------------- ssh related args ----------------
161 #-------------- ssh related args ----------------
162 # These are args for configuring the ssh tunnel to be used
162 # These are args for configuring the ssh tunnel to be used
163 # credentials are used to forward connections over ssh to the Controller
163 # credentials are used to forward connections over ssh to the Controller
164 # Note that the ip given in `addr` needs to be relative to sshserver
164 # Note that the ip given in `addr` needs to be relative to sshserver
165 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
165 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
166 # and set sshserver as the same machine the Controller is on. However,
166 # and set sshserver as the same machine the Controller is on. However,
167 # the only requirement is that sshserver is able to see the Controller
167 # the only requirement is that sshserver is able to see the Controller
168 # (i.e. is within the same trusted network).
168 # (i.e. is within the same trusted network).
169
169
170 sshserver : str
170 sshserver : str
171 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
171 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
172 If keyfile or password is specified, and this is not, it will default to
172 If keyfile or password is specified, and this is not, it will default to
173 the ip given in addr.
173 the ip given in addr.
174 sshkey : str; path to ssh private key file
174 sshkey : str; path to ssh private key file
175 This specifies a key to be used in ssh login, default None.
175 This specifies a key to be used in ssh login, default None.
176 Regular default ssh keys will be used without specifying this argument.
176 Regular default ssh keys will be used without specifying this argument.
177 password : str
177 password : str
178 Your ssh password to sshserver. Note that if this is left None,
178 Your ssh password to sshserver. Note that if this is left None,
179 you will be prompted for it if passwordless key based login is unavailable.
179 you will be prompted for it if passwordless key based login is unavailable.
180 paramiko : bool
180 paramiko : bool
181 flag for whether to use paramiko instead of shell ssh for tunneling.
181 flag for whether to use paramiko instead of shell ssh for tunneling.
182 [default: True on win32, False else]
182 [default: True on win32, False else]
183
183
184 ------- exec authentication args -------
184 ------- exec authentication args -------
185 If even localhost is untrusted, you can have some protection against
185 If even localhost is untrusted, you can have some protection against
186 unauthorized execution by signing messages with HMAC digests.
186 unauthorized execution by signing messages with HMAC digests.
187 Messages are still sent as cleartext, so if someone can snoop your
187 Messages are still sent as cleartext, so if someone can snoop your
188 loopback traffic this will not protect your privacy, but will prevent
188 loopback traffic this will not protect your privacy, but will prevent
189 unauthorized execution.
189 unauthorized execution.
190
190
191 exec_key : str
191 exec_key : str
192 an authentication key or file containing a key
192 an authentication key or file containing a key
193 default: None
193 default: None
194
194
195
195
196 Attributes
196 Attributes
197 ----------
197 ----------
198
198
199 ids : list of int engine IDs
199 ids : list of int engine IDs
200 requesting the ids attribute always synchronizes
200 requesting the ids attribute always synchronizes
201 the registration state. To request ids without synchronization,
201 the registration state. To request ids without synchronization,
202 use semi-private _ids attributes.
202 use semi-private _ids attributes.
203
203
204 history : list of msg_ids
204 history : list of msg_ids
205 a list of msg_ids, keeping track of all the execution
205 a list of msg_ids, keeping track of all the execution
206 messages you have submitted in order.
206 messages you have submitted in order.
207
207
208 outstanding : set of msg_ids
208 outstanding : set of msg_ids
209 a set of msg_ids that have been submitted, but whose
209 a set of msg_ids that have been submitted, but whose
210 results have not yet been received.
210 results have not yet been received.
211
211
212 results : dict
212 results : dict
213 a dict of all our results, keyed by msg_id
213 a dict of all our results, keyed by msg_id
214
214
215 block : bool
215 block : bool
216 determines default behavior when block not specified
216 determines default behavior when block not specified
217 in execution methods
217 in execution methods
218
218
219 Methods
219 Methods
220 -------
220 -------
221
221
222 spin
222 spin
223 flushes incoming results and registration state changes
223 flushes incoming results and registration state changes
224 control methods spin, and requesting `ids` also ensures up to date
224 control methods spin, and requesting `ids` also ensures up to date
225
225
226 wait
226 wait
227 wait on one or more msg_ids
227 wait on one or more msg_ids
228
228
229 execution methods
229 execution methods
230 apply
230 apply
231 legacy: execute, run
231 legacy: execute, run
232
232
233 data movement
233 data movement
234 push, pull, scatter, gather
234 push, pull, scatter, gather
235
235
236 query methods
236 query methods
237 queue_status, get_result, purge, result_status
237 queue_status, get_result, purge, result_status
238
238
239 control methods
239 control methods
240 abort, shutdown
240 abort, shutdown
241
241
242 """
242 """
243
243
244
244
    # -- public traits -----------------------------------------------------
    block = Bool(False)     # default blocking behavior for execution methods
    outstanding = Set()     # msg_ids submitted whose results haven't arrived
    results = Instance('collections.defaultdict', (dict,))       # msg_id -> result
    metadata = Instance('collections.defaultdict', (Metadata,))  # msg_id -> Metadata
    history = List()        # msg_ids of all submissions, in order
    debug = Bool(False)     # verbose message printing for debugging

    # cluster profile name; default computed in _profile_default below
    profile=Unicode()
253 def _profile_default(self):
253 def _profile_default(self):
254 if BaseIPythonApplication.initialized():
254 if BaseIPythonApplication.initialized():
255 # an IPython app *might* be running, try to get its profile
255 # an IPython app *might* be running, try to get its profile
256 try:
256 try:
257 return BaseIPythonApplication.instance().profile
257 return BaseIPythonApplication.instance().profile
258 except (AttributeError, MultipleInstanceError):
258 except (AttributeError, MultipleInstanceError):
259 # could be a *different* subclass of config.Application,
259 # could be a *different* subclass of config.Application,
260 # which would raise one of these two errors.
260 # which would raise one of these two errors.
261 return u'default'
261 return u'default'
262 else:
262 else:
263 return u'default'
263 return u'default'
264
264
265
265
    # -- private traits / state --------------------------------------------
    _outstanding_dict = Instance('collections.defaultdict', (set,))  # engine uuid -> outstanding msg_ids
    _ids = List()               # registered engine ids (kept sorted)
    _connected=Bool(False)      # set True once _connect() has run
    _ssh=Bool(False)            # whether connections go through an ssh tunnel
    _context = Instance('zmq.Context')
    _config = Dict()            # connection info (from args/JSON connector file)
    _engines=Instance(util.ReverseDict, (), {})  # engine id <-> uuid mapping
    # _hub_socket=Instance('zmq.Socket')
    _query_socket=Instance('zmq.Socket')
    _control_socket=Instance('zmq.Socket')
    _iopub_socket=Instance('zmq.Socket')
    _notification_socket=Instance('zmq.Socket')
    _mux_socket=Instance('zmq.Socket')
    _task_socket=Instance('zmq.Socket')
    _task_scheme=Unicode()      # task scheduler scheme, e.g. 'pure' or 'lru'
    _closed = False             # set True by close(); plain attr, not a trait
    _ignored_control_replies=Int(0)  # replies to discard after no-block control requests
    _ignored_hub_replies=Int(0)      # replies to discard after no-block hub requests
284
284
    def __new__(self, *args, **kw):
        # HasTraits.__new__ would reject positional arguments, but
        # Client.__init__ accepts them (url_or_file, ...);
        # forward only the keyword args so construction doesn't raise.
        # don't raise on positional args
        return HasTraits.__new__(self, **kw)
288
288
289 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
289 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
290 context=None, debug=False, exec_key=None,
290 context=None, debug=False, exec_key=None,
291 sshserver=None, sshkey=None, password=None, paramiko=None,
291 sshserver=None, sshkey=None, password=None, paramiko=None,
292 timeout=10, **extra_args
292 timeout=10, **extra_args
293 ):
293 ):
294 if profile:
294 if profile:
295 super(Client, self).__init__(debug=debug, profile=profile)
295 super(Client, self).__init__(debug=debug, profile=profile)
296 else:
296 else:
297 super(Client, self).__init__(debug=debug)
297 super(Client, self).__init__(debug=debug)
298 if context is None:
298 if context is None:
299 context = zmq.Context.instance()
299 context = zmq.Context.instance()
300 self._context = context
300 self._context = context
301
301
302 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
302 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
303 if self._cd is not None:
303 if self._cd is not None:
304 if url_or_file is None:
304 if url_or_file is None:
305 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
305 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
306 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
306 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
307 " Please specify at least one of url_or_file or profile."
307 " Please specify at least one of url_or_file or profile."
308
308
309 if not util.is_url(url_or_file):
309 if not util.is_url(url_or_file):
310 # it's not a url, try for a file
310 # it's not a url, try for a file
311 if not os.path.exists(url_or_file):
311 if not os.path.exists(url_or_file):
312 if self._cd:
312 if self._cd:
313 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
313 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
314 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
314 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
315 with open(url_or_file) as f:
315 with open(url_or_file) as f:
316 cfg = json.loads(f.read())
316 cfg = json.loads(f.read())
317 else:
317 else:
318 cfg = {'url':url_or_file}
318 cfg = {'url':url_or_file}
319
319
320 # sync defaults from args, json:
320 # sync defaults from args, json:
321 if sshserver:
321 if sshserver:
322 cfg['ssh'] = sshserver
322 cfg['ssh'] = sshserver
323 if exec_key:
323 if exec_key:
324 cfg['exec_key'] = exec_key
324 cfg['exec_key'] = exec_key
325 exec_key = cfg['exec_key']
325 exec_key = cfg['exec_key']
326 location = cfg.setdefault('location', None)
326 location = cfg.setdefault('location', None)
327 cfg['url'] = util.disambiguate_url(cfg['url'], location)
327 cfg['url'] = util.disambiguate_url(cfg['url'], location)
328 url = cfg['url']
328 url = cfg['url']
329 proto,addr,port = util.split_url(url)
329 proto,addr,port = util.split_url(url)
330 if location is not None and addr == '127.0.0.1':
330 if location is not None and addr == '127.0.0.1':
331 # location specified, and connection is expected to be local
331 # location specified, and connection is expected to be local
332 if location not in LOCAL_IPS and not sshserver:
332 if location not in LOCAL_IPS and not sshserver:
333 # load ssh from JSON *only* if the controller is not on
333 # load ssh from JSON *only* if the controller is not on
334 # this machine
334 # this machine
335 sshserver=cfg['ssh']
335 sshserver=cfg['ssh']
336 if location not in LOCAL_IPS and not sshserver:
336 if location not in LOCAL_IPS and not sshserver:
337 # warn if no ssh specified, but SSH is probably needed
337 # warn if no ssh specified, but SSH is probably needed
338 # This is only a warning, because the most likely cause
338 # This is only a warning, because the most likely cause
339 # is a local Controller on a laptop whose IP is dynamic
339 # is a local Controller on a laptop whose IP is dynamic
340 warnings.warn("""
340 warnings.warn("""
341 Controller appears to be listening on localhost, but not on this machine.
341 Controller appears to be listening on localhost, but not on this machine.
342 If this is true, you should specify Client(...,sshserver='you@%s')
342 If this is true, you should specify Client(...,sshserver='you@%s')
343 or instruct your controller to listen on an external IP."""%location,
343 or instruct your controller to listen on an external IP."""%location,
344 RuntimeWarning)
344 RuntimeWarning)
345 elif not sshserver:
345 elif not sshserver:
346 # otherwise sync with cfg
346 # otherwise sync with cfg
347 sshserver = cfg['ssh']
347 sshserver = cfg['ssh']
348
348
349 self._config = cfg
349 self._config = cfg
350
350
351 self._ssh = bool(sshserver or sshkey or password)
351 self._ssh = bool(sshserver or sshkey or password)
352 if self._ssh and sshserver is None:
352 if self._ssh and sshserver is None:
353 # default to ssh via localhost
353 # default to ssh via localhost
354 sshserver = url.split('://')[1].split(':')[0]
354 sshserver = url.split('://')[1].split(':')[0]
355 if self._ssh and password is None:
355 if self._ssh and password is None:
356 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
356 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
357 password=False
357 password=False
358 else:
358 else:
359 password = getpass("SSH Password for %s: "%sshserver)
359 password = getpass("SSH Password for %s: "%sshserver)
360 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
360 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
361
361
362 # configure and construct the session
362 # configure and construct the session
363 if exec_key is not None:
363 if exec_key is not None:
364 if os.path.isfile(exec_key):
364 if os.path.isfile(exec_key):
365 extra_args['keyfile'] = exec_key
365 extra_args['keyfile'] = exec_key
366 else:
366 else:
367 exec_key = util.asbytes(exec_key)
367 exec_key = util.asbytes(exec_key)
368 extra_args['key'] = exec_key
368 extra_args['key'] = exec_key
369 self.session = Session(**extra_args)
369 self.session = Session(**extra_args)
370
370
371 self._query_socket = self._context.socket(zmq.DEALER)
371 self._query_socket = self._context.socket(zmq.DEALER)
372 self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
372 self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
373 if self._ssh:
373 if self._ssh:
374 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
374 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
375 else:
375 else:
376 self._query_socket.connect(url)
376 self._query_socket.connect(url)
377
377
378 self.session.debug = self.debug
378 self.session.debug = self.debug
379
379
380 self._notification_handlers = {'registration_notification' : self._register_engine,
380 self._notification_handlers = {'registration_notification' : self._register_engine,
381 'unregistration_notification' : self._unregister_engine,
381 'unregistration_notification' : self._unregister_engine,
382 'shutdown_notification' : lambda msg: self.close(),
382 'shutdown_notification' : lambda msg: self.close(),
383 }
383 }
384 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
384 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
385 'apply_reply' : self._handle_apply_reply}
385 'apply_reply' : self._handle_apply_reply}
386 self._connect(sshserver, ssh_kwargs, timeout)
386 self._connect(sshserver, ssh_kwargs, timeout)
387
387
    def __del__(self):
        """cleanup sockets, but _not_ context."""
        # the zmq context may be shared (zmq.Context.instance()),
        # so close() only tears down this client's sockets
        self.close()
391
391
392 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
392 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
393 if ipython_dir is None:
393 if ipython_dir is None:
394 ipython_dir = get_ipython_dir()
394 ipython_dir = get_ipython_dir()
395 if profile_dir is not None:
395 if profile_dir is not None:
396 try:
396 try:
397 self._cd = ProfileDir.find_profile_dir(profile_dir)
397 self._cd = ProfileDir.find_profile_dir(profile_dir)
398 return
398 return
399 except ProfileDirError:
399 except ProfileDirError:
400 pass
400 pass
401 elif profile is not None:
401 elif profile is not None:
402 try:
402 try:
403 self._cd = ProfileDir.find_profile_dir_by_name(
403 self._cd = ProfileDir.find_profile_dir_by_name(
404 ipython_dir, profile)
404 ipython_dir, profile)
405 return
405 return
406 except ProfileDirError:
406 except ProfileDirError:
407 pass
407 pass
408 self._cd = None
408 self._cd = None
409
409
410 def _update_engines(self, engines):
410 def _update_engines(self, engines):
411 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
411 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
412 for k,v in engines.iteritems():
412 for k,v in engines.iteritems():
413 eid = int(k)
413 eid = int(k)
414 self._engines[eid] = v
414 self._engines[eid] = v
415 self._ids.append(eid)
415 self._ids.append(eid)
416 self._ids = sorted(self._ids)
416 self._ids = sorted(self._ids)
417 if sorted(self._engines.keys()) != range(len(self._engines)) and \
417 if sorted(self._engines.keys()) != range(len(self._engines)) and \
418 self._task_scheme == 'pure' and self._task_socket:
418 self._task_scheme == 'pure' and self._task_socket:
419 self._stop_scheduling_tasks()
419 self._stop_scheduling_tasks()
420
420
421 def _stop_scheduling_tasks(self):
421 def _stop_scheduling_tasks(self):
422 """Stop scheduling tasks because an engine has been unregistered
422 """Stop scheduling tasks because an engine has been unregistered
423 from a pure ZMQ scheduler.
423 from a pure ZMQ scheduler.
424 """
424 """
425 self._task_socket.close()
425 self._task_socket.close()
426 self._task_socket = None
426 self._task_socket = None
427 msg = "An engine has been unregistered, and we are using pure " +\
427 msg = "An engine has been unregistered, and we are using pure " +\
428 "ZMQ task scheduling. Task farming will be disabled."
428 "ZMQ task scheduling. Task farming will be disabled."
429 if self.outstanding:
429 if self.outstanding:
430 msg += " If you were running tasks when this happened, " +\
430 msg += " If you were running tasks when this happened, " +\
431 "some `outstanding` msg_ids may never resolve."
431 "some `outstanding` msg_ids may never resolve."
432 warnings.warn(msg, RuntimeWarning)
432 warnings.warn(msg, RuntimeWarning)
433
433
434 def _build_targets(self, targets):
434 def _build_targets(self, targets):
435 """Turn valid target IDs or 'all' into two lists:
435 """Turn valid target IDs or 'all' into two lists:
436 (int_ids, uuids).
436 (int_ids, uuids).
437 """
437 """
438 if not self._ids:
438 if not self._ids:
439 # flush notification socket if no engines yet, just in case
439 # flush notification socket if no engines yet, just in case
440 if not self.ids:
440 if not self.ids:
441 raise error.NoEnginesRegistered("Can't build targets without any engines")
441 raise error.NoEnginesRegistered("Can't build targets without any engines")
442
442
443 if targets is None:
443 if targets is None:
444 targets = self._ids
444 targets = self._ids
445 elif isinstance(targets, basestring):
445 elif isinstance(targets, basestring):
446 if targets.lower() == 'all':
446 if targets.lower() == 'all':
447 targets = self._ids
447 targets = self._ids
448 else:
448 else:
449 raise TypeError("%r not valid str target, must be 'all'"%(targets))
449 raise TypeError("%r not valid str target, must be 'all'"%(targets))
450 elif isinstance(targets, int):
450 elif isinstance(targets, int):
451 if targets < 0:
451 if targets < 0:
452 targets = self.ids[targets]
452 targets = self.ids[targets]
453 if targets not in self._ids:
453 if targets not in self._ids:
454 raise IndexError("No such engine: %i"%targets)
454 raise IndexError("No such engine: %i"%targets)
455 targets = [targets]
455 targets = [targets]
456
456
457 if isinstance(targets, slice):
457 if isinstance(targets, slice):
458 indices = range(len(self._ids))[targets]
458 indices = range(len(self._ids))[targets]
459 ids = self.ids
459 ids = self.ids
460 targets = [ ids[i] for i in indices ]
460 targets = [ ids[i] for i in indices ]
461
461
462 if not isinstance(targets, (tuple, list, xrange)):
462 if not isinstance(targets, (tuple, list, xrange)):
463 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
463 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
464
464
465 return [util.asbytes(self._engines[t]) for t in targets], list(targets)
465 return [util.asbytes(self._engines[t]) for t in targets], list(targets)
466
466
467 def _connect(self, sshserver, ssh_kwargs, timeout):
467 def _connect(self, sshserver, ssh_kwargs, timeout):
468 """setup all our socket connections to the cluster. This is called from
468 """setup all our socket connections to the cluster. This is called from
469 __init__."""
469 __init__."""
470
470
471 # Maybe allow reconnecting?
471 # Maybe allow reconnecting?
472 if self._connected:
472 if self._connected:
473 return
473 return
474 self._connected=True
474 self._connected=True
475
475
476 def connect_socket(s, url):
476 def connect_socket(s, url):
477 url = util.disambiguate_url(url, self._config['location'])
477 url = util.disambiguate_url(url, self._config['location'])
478 if self._ssh:
478 if self._ssh:
479 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
479 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
480 else:
480 else:
481 return s.connect(url)
481 return s.connect(url)
482
482
483 self.session.send(self._query_socket, 'connection_request')
483 self.session.send(self._query_socket, 'connection_request')
484 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
484 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
485 poller = zmq.Poller()
485 poller = zmq.Poller()
486 poller.register(self._query_socket, zmq.POLLIN)
486 poller.register(self._query_socket, zmq.POLLIN)
487 # poll expects milliseconds, timeout is seconds
487 # poll expects milliseconds, timeout is seconds
488 evts = poller.poll(timeout*1000)
488 evts = poller.poll(timeout*1000)
489 if not evts:
489 if not evts:
490 raise error.TimeoutError("Hub connection request timed out")
490 raise error.TimeoutError("Hub connection request timed out")
491 idents,msg = self.session.recv(self._query_socket,mode=0)
491 idents,msg = self.session.recv(self._query_socket,mode=0)
492 if self.debug:
492 if self.debug:
493 pprint(msg)
493 pprint(msg)
494 msg = Message(msg)
494 msg = Message(msg)
495 content = msg.content
495 content = msg.content
496 self._config['registration'] = dict(content)
496 self._config['registration'] = dict(content)
497 if content.status == 'ok':
497 if content.status == 'ok':
498 ident = self.session.bsession
498 ident = self.session.bsession
499 if content.mux:
499 if content.mux:
500 self._mux_socket = self._context.socket(zmq.DEALER)
500 self._mux_socket = self._context.socket(zmq.DEALER)
501 self._mux_socket.setsockopt(zmq.IDENTITY, ident)
501 self._mux_socket.setsockopt(zmq.IDENTITY, ident)
502 connect_socket(self._mux_socket, content.mux)
502 connect_socket(self._mux_socket, content.mux)
503 if content.task:
503 if content.task:
504 self._task_scheme, task_addr = content.task
504 self._task_scheme, task_addr = content.task
505 self._task_socket = self._context.socket(zmq.DEALER)
505 self._task_socket = self._context.socket(zmq.DEALER)
506 self._task_socket.setsockopt(zmq.IDENTITY, ident)
506 self._task_socket.setsockopt(zmq.IDENTITY, ident)
507 connect_socket(self._task_socket, task_addr)
507 connect_socket(self._task_socket, task_addr)
508 if content.notification:
508 if content.notification:
509 self._notification_socket = self._context.socket(zmq.SUB)
509 self._notification_socket = self._context.socket(zmq.SUB)
510 connect_socket(self._notification_socket, content.notification)
510 connect_socket(self._notification_socket, content.notification)
511 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
511 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
512 # if content.query:
512 # if content.query:
513 # self._query_socket = self._context.socket(zmq.DEALER)
513 # self._query_socket = self._context.socket(zmq.DEALER)
514 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
514 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
515 # connect_socket(self._query_socket, content.query)
515 # connect_socket(self._query_socket, content.query)
516 if content.control:
516 if content.control:
517 self._control_socket = self._context.socket(zmq.DEALER)
517 self._control_socket = self._context.socket(zmq.DEALER)
518 self._control_socket.setsockopt(zmq.IDENTITY, ident)
518 self._control_socket.setsockopt(zmq.IDENTITY, ident)
519 connect_socket(self._control_socket, content.control)
519 connect_socket(self._control_socket, content.control)
520 if content.iopub:
520 if content.iopub:
521 self._iopub_socket = self._context.socket(zmq.SUB)
521 self._iopub_socket = self._context.socket(zmq.SUB)
522 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
522 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
523 self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
523 self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
524 connect_socket(self._iopub_socket, content.iopub)
524 connect_socket(self._iopub_socket, content.iopub)
525 self._update_engines(dict(content.engines))
525 self._update_engines(dict(content.engines))
526 else:
526 else:
527 self._connected = False
527 self._connected = False
528 raise Exception("Failed to connect!")
528 raise Exception("Failed to connect!")
529
529
530 #--------------------------------------------------------------------------
530 #--------------------------------------------------------------------------
531 # handlers and callbacks for incoming messages
531 # handlers and callbacks for incoming messages
532 #--------------------------------------------------------------------------
532 #--------------------------------------------------------------------------
533
533
534 def _unwrap_exception(self, content):
534 def _unwrap_exception(self, content):
535 """unwrap exception, and remap engine_id to int."""
535 """unwrap exception, and remap engine_id to int."""
536 e = error.unwrap_exception(content)
536 e = error.unwrap_exception(content)
537 # print e.traceback
537 # print e.traceback
538 if e.engine_info:
538 if e.engine_info:
539 e_uuid = e.engine_info['engine_uuid']
539 e_uuid = e.engine_info['engine_uuid']
540 eid = self._engines[e_uuid]
540 eid = self._engines[e_uuid]
541 e.engine_info['engine_id'] = eid
541 e.engine_info['engine_id'] = eid
542 return e
542 return e
543
543
544 def _extract_metadata(self, header, parent, content):
544 def _extract_metadata(self, header, parent, content):
545 md = {'msg_id' : parent['msg_id'],
545 md = {'msg_id' : parent['msg_id'],
546 'received' : datetime.now(),
546 'received' : datetime.now(),
547 'engine_uuid' : header.get('engine', None),
547 'engine_uuid' : header.get('engine', None),
548 'follow' : parent.get('follow', []),
548 'follow' : parent.get('follow', []),
549 'after' : parent.get('after', []),
549 'after' : parent.get('after', []),
550 'status' : content['status'],
550 'status' : content['status'],
551 }
551 }
552
552
553 if md['engine_uuid'] is not None:
553 if md['engine_uuid'] is not None:
554 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
554 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
555
555
556 if 'date' in parent:
556 if 'date' in parent:
557 md['submitted'] = parent['date']
557 md['submitted'] = parent['date']
558 if 'started' in header:
558 if 'started' in header:
559 md['started'] = header['started']
559 md['started'] = header['started']
560 if 'date' in header:
560 if 'date' in header:
561 md['completed'] = header['date']
561 md['completed'] = header['date']
562 return md
562 return md
563
563
564 def _register_engine(self, msg):
564 def _register_engine(self, msg):
565 """Register a new engine, and update our connection info."""
565 """Register a new engine, and update our connection info."""
566 content = msg['content']
566 content = msg['content']
567 eid = content['id']
567 eid = content['id']
568 d = {eid : content['queue']}
568 d = {eid : content['queue']}
569 self._update_engines(d)
569 self._update_engines(d)
570
570
571 def _unregister_engine(self, msg):
571 def _unregister_engine(self, msg):
572 """Unregister an engine that has died."""
572 """Unregister an engine that has died."""
573 content = msg['content']
573 content = msg['content']
574 eid = int(content['id'])
574 eid = int(content['id'])
575 if eid in self._ids:
575 if eid in self._ids:
576 self._ids.remove(eid)
576 self._ids.remove(eid)
577 uuid = self._engines.pop(eid)
577 uuid = self._engines.pop(eid)
578
578
579 self._handle_stranded_msgs(eid, uuid)
579 self._handle_stranded_msgs(eid, uuid)
580
580
581 if self._task_socket and self._task_scheme == 'pure':
581 if self._task_socket and self._task_scheme == 'pure':
582 self._stop_scheduling_tasks()
582 self._stop_scheduling_tasks()
583
583
584 def _handle_stranded_msgs(self, eid, uuid):
584 def _handle_stranded_msgs(self, eid, uuid):
585 """Handle messages known to be on an engine when the engine unregisters.
585 """Handle messages known to be on an engine when the engine unregisters.
586
586
587 It is possible that this will fire prematurely - that is, an engine will
587 It is possible that this will fire prematurely - that is, an engine will
588 go down after completing a result, and the client will be notified
588 go down after completing a result, and the client will be notified
589 of the unregistration and later receive the successful result.
589 of the unregistration and later receive the successful result.
590 """
590 """
591
591
592 outstanding = self._outstanding_dict[uuid]
592 outstanding = self._outstanding_dict[uuid]
593
593
594 for msg_id in list(outstanding):
594 for msg_id in list(outstanding):
595 if msg_id in self.results:
595 if msg_id in self.results:
596 # we already
596 # we already
597 continue
597 continue
598 try:
598 try:
599 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
599 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
600 except:
600 except:
601 content = error.wrap_exception()
601 content = error.wrap_exception()
602 # build a fake message:
602 # build a fake message:
603 parent = {}
603 parent = {}
604 header = {}
604 header = {}
605 parent['msg_id'] = msg_id
605 parent['msg_id'] = msg_id
606 header['engine'] = uuid
606 header['engine'] = uuid
607 header['date'] = datetime.now()
607 header['date'] = datetime.now()
608 msg = dict(parent_header=parent, header=header, content=content)
608 msg = dict(parent_header=parent, header=header, content=content)
609 self._handle_apply_reply(msg)
609 self._handle_apply_reply(msg)
610
610
611 def _handle_execute_reply(self, msg):
611 def _handle_execute_reply(self, msg):
612 """Save the reply to an execute_request into our results.
612 """Save the reply to an execute_request into our results.
613
613
614 execute messages are never actually used. apply is used instead.
614 execute messages are never actually used. apply is used instead.
615 """
615 """
616
616
617 parent = msg['parent_header']
617 parent = msg['parent_header']
618 msg_id = parent['msg_id']
618 msg_id = parent['msg_id']
619 if msg_id not in self.outstanding:
619 if msg_id not in self.outstanding:
620 if msg_id in self.history:
620 if msg_id in self.history:
621 print ("got stale result: %s"%msg_id)
621 print ("got stale result: %s"%msg_id)
622 else:
622 else:
623 print ("got unknown result: %s"%msg_id)
623 print ("got unknown result: %s"%msg_id)
624 else:
624 else:
625 self.outstanding.remove(msg_id)
625 self.outstanding.remove(msg_id)
626 self.results[msg_id] = self._unwrap_exception(msg['content'])
626 self.results[msg_id] = self._unwrap_exception(msg['content'])
627
627
628 def _handle_apply_reply(self, msg):
628 def _handle_apply_reply(self, msg):
629 """Save the reply to an apply_request into our results."""
629 """Save the reply to an apply_request into our results."""
630 parent = msg['parent_header']
630 parent = msg['parent_header']
631 msg_id = parent['msg_id']
631 msg_id = parent['msg_id']
632 if msg_id not in self.outstanding:
632 if msg_id not in self.outstanding:
633 if msg_id in self.history:
633 if msg_id in self.history:
634 print ("got stale result: %s"%msg_id)
634 print ("got stale result: %s"%msg_id)
635 print self.results[msg_id]
635 print self.results[msg_id]
636 print msg
636 print msg
637 else:
637 else:
638 print ("got unknown result: %s"%msg_id)
638 print ("got unknown result: %s"%msg_id)
639 else:
639 else:
640 self.outstanding.remove(msg_id)
640 self.outstanding.remove(msg_id)
641 content = msg['content']
641 content = msg['content']
642 header = msg['header']
642 header = msg['header']
643
643
644 # construct metadata:
644 # construct metadata:
645 md = self.metadata[msg_id]
645 md = self.metadata[msg_id]
646 md.update(self._extract_metadata(header, parent, content))
646 md.update(self._extract_metadata(header, parent, content))
647 # is this redundant?
647 # is this redundant?
648 self.metadata[msg_id] = md
648 self.metadata[msg_id] = md
649
649
650 e_outstanding = self._outstanding_dict[md['engine_uuid']]
650 e_outstanding = self._outstanding_dict[md['engine_uuid']]
651 if msg_id in e_outstanding:
651 if msg_id in e_outstanding:
652 e_outstanding.remove(msg_id)
652 e_outstanding.remove(msg_id)
653
653
654 # construct result:
654 # construct result:
655 if content['status'] == 'ok':
655 if content['status'] == 'ok':
656 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
656 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
657 elif content['status'] == 'aborted':
657 elif content['status'] == 'aborted':
658 self.results[msg_id] = error.TaskAborted(msg_id)
658 self.results[msg_id] = error.TaskAborted(msg_id)
659 elif content['status'] == 'resubmitted':
659 elif content['status'] == 'resubmitted':
660 # TODO: handle resubmission
660 # TODO: handle resubmission
661 pass
661 pass
662 else:
662 else:
663 self.results[msg_id] = self._unwrap_exception(content)
663 self.results[msg_id] = self._unwrap_exception(content)
664
664
665 def _flush_notifications(self):
665 def _flush_notifications(self):
666 """Flush notifications of engine registrations waiting
666 """Flush notifications of engine registrations waiting
667 in ZMQ queue."""
667 in ZMQ queue."""
668 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
668 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
669 while msg is not None:
669 while msg is not None:
670 if self.debug:
670 if self.debug:
671 pprint(msg)
671 pprint(msg)
672 msg_type = msg['header']['msg_type']
672 msg_type = msg['header']['msg_type']
673 handler = self._notification_handlers.get(msg_type, None)
673 handler = self._notification_handlers.get(msg_type, None)
674 if handler is None:
674 if handler is None:
675 raise Exception("Unhandled message type: %s"%msg.msg_type)
675 raise Exception("Unhandled message type: %s"%msg.msg_type)
676 else:
676 else:
677 handler(msg)
677 handler(msg)
678 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
678 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
679
679
680 def _flush_results(self, sock):
680 def _flush_results(self, sock):
681 """Flush task or queue results waiting in ZMQ queue."""
681 """Flush task or queue results waiting in ZMQ queue."""
682 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
682 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
683 while msg is not None:
683 while msg is not None:
684 if self.debug:
684 if self.debug:
685 pprint(msg)
685 pprint(msg)
686 msg_type = msg['header']['msg_type']
686 msg_type = msg['header']['msg_type']
687 handler = self._queue_handlers.get(msg_type, None)
687 handler = self._queue_handlers.get(msg_type, None)
688 if handler is None:
688 if handler is None:
689 raise Exception("Unhandled message type: %s"%msg.msg_type)
689 raise Exception("Unhandled message type: %s"%msg.msg_type)
690 else:
690 else:
691 handler(msg)
691 handler(msg)
692 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
692 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
693
693
694 def _flush_control(self, sock):
694 def _flush_control(self, sock):
695 """Flush replies from the control channel waiting
695 """Flush replies from the control channel waiting
696 in the ZMQ queue.
696 in the ZMQ queue.
697
697
698 Currently: ignore them."""
698 Currently: ignore them."""
699 if self._ignored_control_replies <= 0:
699 if self._ignored_control_replies <= 0:
700 return
700 return
701 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
701 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
702 while msg is not None:
702 while msg is not None:
703 self._ignored_control_replies -= 1
703 self._ignored_control_replies -= 1
704 if self.debug:
704 if self.debug:
705 pprint(msg)
705 pprint(msg)
706 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
706 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
707
707
708 def _flush_ignored_control(self):
708 def _flush_ignored_control(self):
709 """flush ignored control replies"""
709 """flush ignored control replies"""
710 while self._ignored_control_replies > 0:
710 while self._ignored_control_replies > 0:
711 self.session.recv(self._control_socket)
711 self.session.recv(self._control_socket)
712 self._ignored_control_replies -= 1
712 self._ignored_control_replies -= 1
713
713
714 def _flush_ignored_hub_replies(self):
714 def _flush_ignored_hub_replies(self):
715 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
715 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
716 while msg is not None:
716 while msg is not None:
717 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
717 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
718
718
719 def _flush_iopub(self, sock):
719 def _flush_iopub(self, sock):
720 """Flush replies from the iopub channel waiting
720 """Flush replies from the iopub channel waiting
721 in the ZMQ queue.
721 in the ZMQ queue.
722 """
722 """
723 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
723 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
724 while msg is not None:
724 while msg is not None:
725 if self.debug:
725 if self.debug:
726 pprint(msg)
726 pprint(msg)
727 parent = msg['parent_header']
727 parent = msg['parent_header']
728 msg_id = parent['msg_id']
728 msg_id = parent['msg_id']
729 content = msg['content']
729 content = msg['content']
730 header = msg['header']
730 header = msg['header']
731 msg_type = msg['header']['msg_type']
731 msg_type = msg['header']['msg_type']
732
732
733 # init metadata:
733 # init metadata:
734 md = self.metadata[msg_id]
734 md = self.metadata[msg_id]
735
735
736 if msg_type == 'stream':
736 if msg_type == 'stream':
737 name = content['name']
737 name = content['name']
738 s = md[name] or ''
738 s = md[name] or ''
739 md[name] = s + content['data']
739 md[name] = s + content['data']
740 elif msg_type == 'pyerr':
740 elif msg_type == 'pyerr':
741 md.update({'pyerr' : self._unwrap_exception(content)})
741 md.update({'pyerr' : self._unwrap_exception(content)})
742 elif msg_type == 'pyin':
742 elif msg_type == 'pyin':
743 md.update({'pyin' : content['code']})
743 md.update({'pyin' : content['code']})
744 else:
744 else:
745 md.update({msg_type : content.get('data', '')})
745 md.update({msg_type : content.get('data', '')})
746
746
747 # reduntant?
747 # reduntant?
748 self.metadata[msg_id] = md
748 self.metadata[msg_id] = md
749
749
750 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
750 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
751
751
752 #--------------------------------------------------------------------------
752 #--------------------------------------------------------------------------
753 # len, getitem
753 # len, getitem
754 #--------------------------------------------------------------------------
754 #--------------------------------------------------------------------------
755
755
756 def __len__(self):
756 def __len__(self):
757 """len(client) returns # of engines."""
757 """len(client) returns # of engines."""
758 return len(self.ids)
758 return len(self.ids)
759
759
760 def __getitem__(self, key):
760 def __getitem__(self, key):
761 """index access returns DirectView multiplexer objects
761 """index access returns DirectView multiplexer objects
762
762
763 Must be int, slice, or list/tuple/xrange of ints"""
763 Must be int, slice, or list/tuple/xrange of ints"""
764 if not isinstance(key, (int, slice, tuple, list, xrange)):
764 if not isinstance(key, (int, slice, tuple, list, xrange)):
765 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
765 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
766 else:
766 else:
767 return self.direct_view(key)
767 return self.direct_view(key)
768
768
769 #--------------------------------------------------------------------------
769 #--------------------------------------------------------------------------
770 # Begin public methods
770 # Begin public methods
771 #--------------------------------------------------------------------------
771 #--------------------------------------------------------------------------
772
772
773 @property
773 @property
774 def ids(self):
774 def ids(self):
775 """Always up-to-date ids property."""
775 """Always up-to-date ids property."""
776 self._flush_notifications()
776 self._flush_notifications()
777 # always copy:
777 # always copy:
778 return list(self._ids)
778 return list(self._ids)
779
779
780 def close(self):
780 def close(self):
781 if self._closed:
781 if self._closed:
782 return
782 return
783 snames = filter(lambda n: n.endswith('socket'), dir(self))
783 snames = filter(lambda n: n.endswith('socket'), dir(self))
784 for socket in map(lambda name: getattr(self, name), snames):
784 for socket in map(lambda name: getattr(self, name), snames):
785 if isinstance(socket, zmq.Socket) and not socket.closed:
785 if isinstance(socket, zmq.Socket) and not socket.closed:
786 socket.close()
786 socket.close()
787 self._closed = True
787 self._closed = True
788
788
789 def spin(self):
789 def spin(self):
790 """Flush any registration notifications and execution results
790 """Flush any registration notifications and execution results
791 waiting in the ZMQ queue.
791 waiting in the ZMQ queue.
792 """
792 """
793 if self._notification_socket:
793 if self._notification_socket:
794 self._flush_notifications()
794 self._flush_notifications()
795 if self._mux_socket:
795 if self._mux_socket:
796 self._flush_results(self._mux_socket)
796 self._flush_results(self._mux_socket)
797 if self._task_socket:
797 if self._task_socket:
798 self._flush_results(self._task_socket)
798 self._flush_results(self._task_socket)
799 if self._control_socket:
799 if self._control_socket:
800 self._flush_control(self._control_socket)
800 self._flush_control(self._control_socket)
801 if self._iopub_socket:
801 if self._iopub_socket:
802 self._flush_iopub(self._iopub_socket)
802 self._flush_iopub(self._iopub_socket)
803 if self._query_socket:
803 if self._query_socket:
804 self._flush_ignored_hub_replies()
804 self._flush_ignored_hub_replies()
805
805
806 def wait(self, jobs=None, timeout=-1):
806 def wait(self, jobs=None, timeout=-1):
807 """waits on one or more `jobs`, for up to `timeout` seconds.
807 """waits on one or more `jobs`, for up to `timeout` seconds.
808
808
809 Parameters
809 Parameters
810 ----------
810 ----------
811
811
812 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
812 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
813 ints are indices to self.history
813 ints are indices to self.history
814 strs are msg_ids
814 strs are msg_ids
815 default: wait on all outstanding messages
815 default: wait on all outstanding messages
816 timeout : float
816 timeout : float
817 a time in seconds, after which to give up.
817 a time in seconds, after which to give up.
818 default is -1, which means no timeout
818 default is -1, which means no timeout
819
819
820 Returns
820 Returns
821 -------
821 -------
822
822
823 True : when all msg_ids are done
823 True : when all msg_ids are done
824 False : timeout reached, some msg_ids still outstanding
824 False : timeout reached, some msg_ids still outstanding
825 """
825 """
826 tic = time.time()
826 tic = time.time()
827 if jobs is None:
827 if jobs is None:
828 theids = self.outstanding
828 theids = self.outstanding
829 else:
829 else:
830 if isinstance(jobs, (int, basestring, AsyncResult)):
830 if isinstance(jobs, (int, basestring, AsyncResult)):
831 jobs = [jobs]
831 jobs = [jobs]
832 theids = set()
832 theids = set()
833 for job in jobs:
833 for job in jobs:
834 if isinstance(job, int):
834 if isinstance(job, int):
835 # index access
835 # index access
836 job = self.history[job]
836 job = self.history[job]
837 elif isinstance(job, AsyncResult):
837 elif isinstance(job, AsyncResult):
838 map(theids.add, job.msg_ids)
838 map(theids.add, job.msg_ids)
839 continue
839 continue
840 theids.add(job)
840 theids.add(job)
841 if not theids.intersection(self.outstanding):
841 if not theids.intersection(self.outstanding):
842 return True
842 return True
843 self.spin()
843 self.spin()
844 while theids.intersection(self.outstanding):
844 while theids.intersection(self.outstanding):
845 if timeout >= 0 and ( time.time()-tic ) > timeout:
845 if timeout >= 0 and ( time.time()-tic ) > timeout:
846 break
846 break
847 time.sleep(1e-3)
847 time.sleep(1e-3)
848 self.spin()
848 self.spin()
849 return len(theids.intersection(self.outstanding)) == 0
849 return len(theids.intersection(self.outstanding)) == 0
850
850
851 #--------------------------------------------------------------------------
851 #--------------------------------------------------------------------------
852 # Control methods
852 # Control methods
853 #--------------------------------------------------------------------------
853 #--------------------------------------------------------------------------
854
854
855 @spin_first
855 @spin_first
856 def clear(self, targets=None, block=None):
856 def clear(self, targets=None, block=None):
857 """Clear the namespace in target(s)."""
857 """Clear the namespace in target(s)."""
858 block = self.block if block is None else block
858 block = self.block if block is None else block
859 targets = self._build_targets(targets)[0]
859 targets = self._build_targets(targets)[0]
860 for t in targets:
860 for t in targets:
861 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
861 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
862 error = False
862 error = False
863 if block:
863 if block:
864 self._flush_ignored_control()
864 self._flush_ignored_control()
865 for i in range(len(targets)):
865 for i in range(len(targets)):
866 idents,msg = self.session.recv(self._control_socket,0)
866 idents,msg = self.session.recv(self._control_socket,0)
867 if self.debug:
867 if self.debug:
868 pprint(msg)
868 pprint(msg)
869 if msg['content']['status'] != 'ok':
869 if msg['content']['status'] != 'ok':
870 error = self._unwrap_exception(msg['content'])
870 error = self._unwrap_exception(msg['content'])
871 else:
871 else:
872 self._ignored_control_replies += len(targets)
872 self._ignored_control_replies += len(targets)
873 if error:
873 if error:
874 raise error
874 raise error
875
875
876
876
877 @spin_first
877 @spin_first
878 def abort(self, jobs=None, targets=None, block=None):
878 def abort(self, jobs=None, targets=None, block=None):
879 """Abort specific jobs from the execution queues of target(s).
879 """Abort specific jobs from the execution queues of target(s).
880
880
881 This is a mechanism to prevent jobs that have already been submitted
881 This is a mechanism to prevent jobs that have already been submitted
882 from executing.
882 from executing.
883
883
884 Parameters
884 Parameters
885 ----------
885 ----------
886
886
887 jobs : msg_id, list of msg_ids, or AsyncResult
887 jobs : msg_id, list of msg_ids, or AsyncResult
888 The jobs to be aborted
888 The jobs to be aborted
889
889
890
890
891 """
891 """
892 block = self.block if block is None else block
892 block = self.block if block is None else block
893 targets = self._build_targets(targets)[0]
893 targets = self._build_targets(targets)[0]
894 msg_ids = []
894 msg_ids = []
895 if isinstance(jobs, (basestring,AsyncResult)):
895 if isinstance(jobs, (basestring,AsyncResult)):
896 jobs = [jobs]
896 jobs = [jobs]
897 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
897 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
898 if bad_ids:
898 if bad_ids:
899 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
899 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
900 for j in jobs:
900 for j in jobs:
901 if isinstance(j, AsyncResult):
901 if isinstance(j, AsyncResult):
902 msg_ids.extend(j.msg_ids)
902 msg_ids.extend(j.msg_ids)
903 else:
903 else:
904 msg_ids.append(j)
904 msg_ids.append(j)
905 content = dict(msg_ids=msg_ids)
905 content = dict(msg_ids=msg_ids)
906 for t in targets:
906 for t in targets:
907 self.session.send(self._control_socket, 'abort_request',
907 self.session.send(self._control_socket, 'abort_request',
908 content=content, ident=t)
908 content=content, ident=t)
909 error = False
909 error = False
910 if block:
910 if block:
911 self._flush_ignored_control()
911 self._flush_ignored_control()
912 for i in range(len(targets)):
912 for i in range(len(targets)):
913 idents,msg = self.session.recv(self._control_socket,0)
913 idents,msg = self.session.recv(self._control_socket,0)
914 if self.debug:
914 if self.debug:
915 pprint(msg)
915 pprint(msg)
916 if msg['content']['status'] != 'ok':
916 if msg['content']['status'] != 'ok':
917 error = self._unwrap_exception(msg['content'])
917 error = self._unwrap_exception(msg['content'])
918 else:
918 else:
919 self._ignored_control_replies += len(targets)
919 self._ignored_control_replies += len(targets)
920 if error:
920 if error:
921 raise error
921 raise error
922
922
923 @spin_first
923 @spin_first
924 def shutdown(self, targets=None, restart=False, hub=False, block=None):
def shutdown(self, targets=None, restart=False, hub=False, block=None):
    """Terminates one or more engine processes, optionally including the hub.

    Parameters
    ----------
    targets : int/str/list of ints/strs [default: None]
        The engines to shut down.  Forced to 'all' when `hub` is True,
        since stopping the hub implies stopping every engine.
    restart : bool [default: False]
        Passed through in the request content; asks the engine to restart
        rather than terminate.
    hub : bool
        Whether to also shut down the Hub (via the query socket) after
        the engines.
    block : bool [default: self.block]
        Whether to wait for the shutdown replies.
    """
    block = self.block if block is None else block
    if hub:
        # shutting down the hub makes no sense with engines left running
        targets = 'all'
    # [0] selects the ident form of the targets (engine uuids)
    targets = self._build_targets(targets)[0]
    for t in targets:
        self.session.send(self._control_socket, 'shutdown_request',
                    content={'restart':restart},ident=t)
    error = False
    if block or hub:
        # drain replies owed from earlier non-blocking control requests
        # before reading our own, so we don't mis-match replies
        self._flush_ignored_control()
        for i in range(len(targets)):
            idents,msg = self.session.recv(self._control_socket, 0)
            if self.debug:
                pprint(msg)
            if msg['content']['status'] != 'ok':
                # remember the failure, but keep collecting replies
                error = self._unwrap_exception(msg['content'])
    else:
        # not blocking: record how many control replies to discard later
        self._ignored_control_replies += len(targets)

    if hub:
        # brief pause so engine shutdown requests land before the hub goes away
        time.sleep(0.25)
        self.session.send(self._query_socket, 'shutdown_request')
        idents,msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        if msg['content']['status'] != 'ok':
            error = self._unwrap_exception(msg['content'])

    if error:
        # raise only the last error seen, after all replies were consumed
        raise error
956
956
957 #--------------------------------------------------------------------------
957 #--------------------------------------------------------------------------
958 # Execution related methods
958 # Execution related methods
959 #--------------------------------------------------------------------------
959 #--------------------------------------------------------------------------
960
960
def _maybe_raise(self, result):
    """Pass *result* through, unless it is a RemoteError — then raise it.

    A failed apply call comes back as an error.RemoteError instance rather
    than raising on the wire, so it is surfaced to the caller here.
    """
    if not isinstance(result, error.RemoteError):
        return result
    raise result
967
967
def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
                        ident=None):
    """construct and send an apply message via a socket.

    This is the principal method with which all engine execution is performed by views.

    Parameters
    ----------
    socket : zmq socket
        The socket on which to send the request (mux or task).
    f : callable
        The function to be executed remotely.
    args : tuple or list [default: []]
        Positional arguments for ``f``.
    kwargs : dict [default: {}]
        Keyword arguments for ``f``.
    subheader : dict [default: {}]
        Extra header entries (e.g. scheduler hints) merged into the request.
    track : bool
        Whether the send should be tracked by zmq (for completion of sends).
    ident : str or list of str
        Routing identity/identities, when targeting a specific engine.

    Returns
    -------
    msg : dict
        The sent message, as returned by ``session.send``.
    """

    assert not self._closed, "cannot use me anymore, I'm closed!"
    # defaults:
    args = args if args is not None else []
    kwargs = kwargs if kwargs is not None else {}
    subheader = subheader if subheader is not None else {}

    # validate arguments
    if not callable(f):
        raise TypeError("f must be callable, not %s"%type(f))
    if not isinstance(args, (tuple, list)):
        raise TypeError("args must be tuple or list, not %s"%type(args))
    if not isinstance(kwargs, dict):
        raise TypeError("kwargs must be dict, not %s"%type(kwargs))
    if not isinstance(subheader, dict):
        raise TypeError("subheader must be dict, not %s"%type(subheader))

    # serialize f/args/kwargs into zmq message buffers
    bufs = util.pack_apply_message(f,args,kwargs)

    msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
                        subheader=subheader, track=track)

    msg_id = msg['header']['msg_id']
    # book-keeping: this request is now outstanding
    self.outstanding.add(msg_id)
    if ident:
        # possibly routed to a specific engine
        if isinstance(ident, list):
            # the last element is the destination engine's identity
            ident = ident[-1]
        if ident in self._engines.values():
            # save for later, in case of engine death
            self._outstanding_dict[ident].add(msg_id)
    self.history.append(msg_id)
    self.metadata[msg_id]['submitted'] = datetime.now()

    return msg
1009
1009
1010 #--------------------------------------------------------------------------
1010 #--------------------------------------------------------------------------
1011 # construct a View object
1011 # construct a View object
1012 #--------------------------------------------------------------------------
1012 #--------------------------------------------------------------------------
1013
1013
def load_balanced_view(self, targets=None):
    """construct a LoadBalancedView object.

    If no arguments are specified, create a LoadBalancedView
    using all engines.

    Parameters
    ----------

    targets: list,slice,int,etc. [default: use all engines]
        The subset of engines across which to load-balance
    """
    # 'all' is treated the same as the default (None): balance across
    # whatever engines are available at execution time
    if targets == 'all':
        targets = None
    if targets is not None:
        # [1] selects the integer engine ids form of the targets
        targets = self._build_targets(targets)[1]
    return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1031
1031
def direct_view(self, targets='all'):
    """construct a DirectView object.

    If no targets are specified, create a DirectView using all engines.

    rc.direct_view('all') is distinguished from rc[:] in that 'all' will
    evaluate the target engines at each execution, whereas rc[:] will connect to
    all *current* engines, and that list will not change.

    That is, 'all' will always use all engines, whereas rc[:] will not use
    engines added after the DirectView is constructed.

    Parameters
    ----------

    targets: list,slice,int,etc. [default: use all engines]
        The engines to use for the View
    """
    # leave 'all' unresolved, so the target set is lazily re-evaluated
    # at each execution; anything else is resolved to engine ids now
    if targets == 'all':
        resolved = targets
    else:
        want_single = isinstance(targets, int)
        resolved = self._build_targets(targets)[1]
        if want_single:
            # a single int target yields a scalar, not a one-element list
            resolved = resolved[0]
    return DirectView(client=self, socket=self._mux_socket, targets=resolved)
1051
1057
1052 #--------------------------------------------------------------------------
1058 #--------------------------------------------------------------------------
1053 # Query methods
1059 # Query methods
1054 #--------------------------------------------------------------------------
1060 #--------------------------------------------------------------------------
1055
1061
@spin_first
def get_result(self, indices_or_msg_ids=None, block=None):
    """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.

    If the client already has the results, no request to the Hub will be made.

    This is a convenient way to construct AsyncResult objects, which are wrappers
    that include metadata about execution, and allow for awaiting results that
    were not submitted by this Client.

    It can also be a convenient way to retrieve the metadata associated with
    blocking execution, since it always retrieves

    Examples
    --------
    ::

        In [10]: r = client.apply()

    Parameters
    ----------

    indices_or_msg_ids : integer history index, str msg_id, or list of either
        The indices or msg_ids of indices to be retrieved

    block : bool
        Whether to wait for the result to be done

    Returns
    -------

    AsyncResult
        A single AsyncResult object will always be returned.

    AsyncHubResult
        A subclass of AsyncResult that retrieves results from the Hub

    """
    block = self.block if block is None else block
    if indices_or_msg_ids is None:
        # default: the most recently submitted task
        indices_or_msg_ids = -1
    if not isinstance(indices_or_msg_ids, (list, tuple)):
        indices_or_msg_ids = [indices_or_msg_ids]

    # normalize: ints index into local history, strings are msg_ids
    theids = []
    for index_or_id in indices_or_msg_ids:
        if isinstance(index_or_id, int):
            index_or_id = self.history[index_or_id]
        if not isinstance(index_or_id, basestring):
            raise TypeError("indices must be str or int, not %r"%index_or_id)
        theids.append(index_or_id)

    # anything not known locally must be fetched from the Hub
    local_ids = [mid for mid in theids if mid in self.history or mid in self.results]
    remote_ids = [mid for mid in theids if mid not in local_ids]

    result_cls = AsyncHubResult if remote_ids else AsyncResult
    ar = result_cls(self, msg_ids=theids)

    if block:
        ar.wait()

    return ar
1121
1127
@spin_first
def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
    """Resubmit one or more tasks.

    in-flight tasks may not be resubmitted.

    Parameters
    ----------

    indices_or_msg_ids : integer history index, str msg_id, or list of either
        The indices or msg_ids of indices to be retrieved

    subheader : dict
        NOTE(review): accepted but never used in this body — presumably
        intended to be forwarded with the resubmit request; confirm.

    block : bool
        Whether to wait for the result to be done

    Returns
    -------

    AsyncHubResult
        A subclass of AsyncResult that retrieves results from the Hub

    """
    block = self.block if block is None else block
    if indices_or_msg_ids is None:
        # default: resubmit the most recently submitted task
        indices_or_msg_ids = -1

    if not isinstance(indices_or_msg_ids, (list,tuple)):
        indices_or_msg_ids = [indices_or_msg_ids]

    # normalize: ints index into local history, strings are msg_ids
    theids = []
    for id in indices_or_msg_ids:
        if isinstance(id, int):
            id = self.history[id]
        if not isinstance(id, basestring):
            raise TypeError("indices must be str or int, not %r"%id)
        theids.append(id)

    # forget local state for these tasks; the Hub will re-run them
    for msg_id in theids:
        self.outstanding.discard(msg_id)
        if msg_id in self.history:
            self.history.remove(msg_id)
        self.results.pop(msg_id, None)
        self.metadata.pop(msg_id, None)
    content = dict(msg_ids = theids)

    self.session.send(self._query_socket, 'resubmit_request', content)

    # block until the Hub's reply is available, then read it
    zmq.select([self._query_socket], [], [])
    idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
    if self.debug:
        pprint(msg)
    content = msg['content']
    if content['status'] != 'ok':
        raise self._unwrap_exception(content)

    # results will arrive via the Hub, not directly from engines
    ar = AsyncHubResult(self, msg_ids=theids)

    if block:
        ar.wait()

    return ar
1183
1189
@spin_first
def result_status(self, msg_ids, status_only=True):
    """Check on the status of the result(s) of the apply request with `msg_ids`.

    If status_only is False, then the actual results will be retrieved, else
    only the status of the results will be checked.

    Parameters
    ----------

    msg_ids : list of msg_ids
        if int:
            Passed as index to self.history for convenience.
    status_only : bool (default: True)
        if False:
            Retrieve the actual results of completed tasks.

    Returns
    -------

    results : dict
        There will always be the keys 'pending' and 'completed', which will
        be lists of msg_ids that are incomplete or complete. If `status_only`
        is False, then completed results will be keyed by their `msg_id`.
    """
    if not isinstance(msg_ids, (list,tuple)):
        msg_ids = [msg_ids]

    # normalize: ints index into local history, strings are msg_ids
    theids = []
    for msg_id in msg_ids:
        if isinstance(msg_id, int):
            msg_id = self.history[msg_id]
        if not isinstance(msg_id, basestring):
            raise TypeError("msg_ids must be str, not %r"%msg_id)
        theids.append(msg_id)

    completed = []
    local_results = {}

    # local shortcut: serve anything already cached without asking the Hub.
    # FIX: iterate over a copy — removing from `theids` while iterating it
    # skipped the element following each cached id.
    for msg_id in list(theids):
        if msg_id in self.results:
            completed.append(msg_id)
            local_results[msg_id] = self.results[msg_id]
            theids.remove(msg_id)

    if theids: # some not locally cached
        content = dict(msg_ids=theids, status_only=status_only)
        msg = self.session.send(self._query_socket, "result_request", content=content)
        # block until the Hub's reply is available, then read it
        zmq.select([self._query_socket], [], [])
        idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)
        buffers = msg['buffers']
    else:
        content = dict(completed=[],pending=[])

    content['completed'].extend(completed)

    if status_only:
        return content

    failures = []
    # load cached results into result:
    content.update(local_results)

    # update cache with results:
    for msg_id in sorted(theids):
        if msg_id in content['completed']:
            rec = content[msg_id]
            parent = rec['header']
            header = rec['result_header']
            rcontent = rec['result_content']
            iodict = rec['io']
            if isinstance(rcontent, str):
                rcontent = self.session.unpack(rcontent)

            md = self.metadata[msg_id]
            md.update(self._extract_metadata(header, parent, rcontent))
            md.update(iodict)

            if rcontent['status'] == 'ok':
                # each successful result consumes a prefix of `buffers`;
                # unserialize_object returns (object, remaining buffers)
                res,buffers = util.unserialize_object(buffers)
            else:
                # FIX: removed stray un-gated debug `print rcontent` here
                res = self._unwrap_exception(rcontent)
                failures.append(res)

            self.results[msg_id] = res
            content[msg_id] = res

    # a single failed task is raised directly for convenience
    if len(theids) == 1 and failures:
        raise failures[0]

    error.collect_exceptions(failures, "result_status")
    return content
1283
1289
@spin_first
def queue_status(self, targets='all', verbose=False):
    """Fetch the status of engine queues.

    Parameters
    ----------

    targets : int/str/list of ints/strs
        the engines whose states are to be queried.
        default : all
    verbose : bool
        Whether to return lengths only, or lists of ids for each element

    Returns
    -------
    dict keyed by engine id (or a single engine's status dict, when
    `targets` was a single int).
    """
    # [1] selects the integer engine ids form of the targets
    engine_ids = self._build_targets(targets)[1]
    content = dict(targets=engine_ids, verbose=verbose)
    self.session.send(self._query_socket, "queue_request", content=content)
    idents,msg = self.session.recv(self._query_socket, 0)
    if self.debug:
        pprint(msg)
    content = msg['content']
    status = content.pop('status')
    if status != 'ok':
        raise self._unwrap_exception(content)
    # rekey: convert the reply's string keys back to engine-id ints
    content = rekey(content)
    if isinstance(targets, int):
        # single engine requested: unwrap its entry
        return content[targets]
    else:
        return content
1312
1318
@spin_first
def purge_results(self, jobs=None, targets=None):
    """Tell the Hub to forget results.

    Individual results can be purged by msg_id, or the entire
    history of specific targets can be purged.

    Use `purge_results('all')` to scrub everything from the Hub's db.

    Parameters
    ----------

    jobs : str or list of str or AsyncResult objects
        the msg_ids whose results should be forgotten.
    targets : int/str/list of ints/strs
        The targets, by int_id, whose entire history is to be purged.

    default : None

    Raises
    ------
    ValueError
        If neither `jobs` nor `targets` is given.
    """
    # FIX: mutable default arguments ([]) replaced with None sentinels;
    # behavior is unchanged (both are falsy and never mutated in place).
    jobs = [] if jobs is None else jobs
    targets = [] if targets is None else targets
    if not targets and not jobs:
        raise ValueError("Must specify at least one of `targets` and `jobs`")
    if targets:
        # [1] selects the integer engine ids form of the targets
        targets = self._build_targets(targets)[1]

    # construct msg_ids from jobs
    if jobs == 'all':
        # special-case: pass 'all' straight through to the Hub
        msg_ids = jobs
    else:
        msg_ids = []
        if isinstance(jobs, (basestring,AsyncResult)):
            jobs = [jobs]
        bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
        if bad_ids:
            raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
        for j in jobs:
            if isinstance(j, AsyncResult):
                # an AsyncResult may wrap several msg_ids
                msg_ids.extend(j.msg_ids)
            else:
                msg_ids.append(j)

    content = dict(engine_ids=targets, msg_ids=msg_ids)
    self.session.send(self._query_socket, "purge_request", content=content)
    idents, msg = self.session.recv(self._query_socket, 0)
    if self.debug:
        pprint(msg)
    content = msg['content']
    if content['status'] != 'ok':
        raise self._unwrap_exception(content)
1361
1367
@spin_first
def hub_history(self):
    """Get the Hub's history

    Just like the Client, the Hub has a history, which is a list of msg_ids.
    This will contain the history of all clients, and, depending on configuration,
    may contain history across multiple cluster sessions.

    Any msg_id returned here is a valid argument to `get_result`.

    Returns
    -------

    msg_ids : list of strs
        list of all msg_ids, ordered by task submission time.
    """

    # simple request/reply round-trip on the query socket
    self.session.send(self._query_socket, "history_request", content={})
    idents, msg = self.session.recv(self._query_socket, 0)

    if self.debug:
        pprint(msg)
    content = msg['content']
    if content['status'] != 'ok':
        raise self._unwrap_exception(content)
    else:
        return content['history']
1389
1395
@spin_first
def db_query(self, query, keys=None):
    """Query the Hub's TaskRecord database

    This will return a list of task record dicts that match `query`

    Parameters
    ----------

    query : mongodb query dict
        The search dict. See mongodb query docs for details.
    keys : list of strs [optional]
        The subset of keys to be returned. The default is to fetch everything but buffers.
        'msg_id' will *always* be included.

    Returns
    -------
    records : list of dicts
        Matching task records, with binary buffers re-attached.
    """
    # a single key name is accepted as a convenience
    if isinstance(keys, basestring):
        keys = [keys]
    content = dict(query=query, keys=keys)
    self.session.send(self._query_socket, "db_request", content=content)
    idents, msg = self.session.recv(self._query_socket, 0)
    if self.debug:
        pprint(msg)
    content = msg['content']
    if content['status'] != 'ok':
        raise self._unwrap_exception(content)

    records = content['records']

    # buffers for all records arrive as one flat list; the per-record
    # lengths tell us how to slice them back out, in order
    buffer_lens = content['buffer_lens']
    result_buffer_lens = content['result_buffer_lens']
    buffers = msg['buffers']
    has_bufs = buffer_lens is not None
    has_rbufs = result_buffer_lens is not None
    for i,rec in enumerate(records):
        # relink buffers
        if has_bufs:
            blen = buffer_lens[i]
            # take this record's prefix, keep the remainder for the next
            rec['buffers'], buffers = buffers[:blen],buffers[blen:]
        if has_rbufs:
            blen = result_buffer_lens[i]
            rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]

    return records
1433
1439
# Explicit public API of this module: only the Client class is exported.
__all__ = [ 'Client' ]
@@ -1,219 +1,222 b''
1 """Remote Functions and decorators for Views.
1 """Remote Functions and decorators for Views.
2
2
3 Authors:
3 Authors:
4
4
5 * Brian Granger
5 * Brian Granger
6 * Min RK
6 * Min RK
7 """
7 """
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2010-2011 The IPython Development Team
9 # Copyright (C) 2010-2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import sys
21 import sys
22 import warnings
22 import warnings
23
23
24 from IPython.testing.skipdoctest import skip_doctest
24 from IPython.testing.skipdoctest import skip_doctest
25
25
26 from . import map as Map
26 from . import map as Map
27 from .asyncresult import AsyncMapResult
27 from .asyncresult import AsyncMapResult
28
28
29 #-----------------------------------------------------------------------------
29 #-----------------------------------------------------------------------------
30 # Decorators
30 # Decorators
31 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
32
32
33 @skip_doctest
33 @skip_doctest
34 def remote(view, block=None, **flags):
34 def remote(view, block=None, **flags):
35 """Turn a function into a remote function.
35 """Turn a function into a remote function.
36
36
37 This method can be used for map:
37 This method can be used for map:
38
38
39 In [1]: @remote(view,block=True)
39 In [1]: @remote(view,block=True)
40 ...: def func(a):
40 ...: def func(a):
41 ...: pass
41 ...: pass
42 """
42 """
43
43
44 def remote_function(f):
44 def remote_function(f):
45 return RemoteFunction(view, f, block=block, **flags)
45 return RemoteFunction(view, f, block=block, **flags)
46 return remote_function
46 return remote_function
47
47
48 @skip_doctest
48 @skip_doctest
49 def parallel(view, dist='b', block=None, ordered=True, **flags):
49 def parallel(view, dist='b', block=None, ordered=True, **flags):
50 """Turn a function into a parallel remote function.
50 """Turn a function into a parallel remote function.
51
51
52 This method can be used for map:
52 This method can be used for map:
53
53
54 In [1]: @parallel(view, block=True)
54 In [1]: @parallel(view, block=True)
55 ...: def func(a):
55 ...: def func(a):
56 ...: pass
56 ...: pass
57 """
57 """
58
58
59 def parallel_function(f):
59 def parallel_function(f):
60 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
60 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
61 return parallel_function
61 return parallel_function
62
62
63 #--------------------------------------------------------------------------
63 #--------------------------------------------------------------------------
64 # Classes
64 # Classes
65 #--------------------------------------------------------------------------
65 #--------------------------------------------------------------------------
66
66
67 class RemoteFunction(object):
67 class RemoteFunction(object):
68 """Turn an existing function into a remote function.
68 """Turn an existing function into a remote function.
69
69
70 Parameters
70 Parameters
71 ----------
71 ----------
72
72
73 view : View instance
73 view : View instance
74 The view to be used for execution
74 The view to be used for execution
75 f : callable
75 f : callable
76 The function to be wrapped into a remote function
76 The function to be wrapped into a remote function
77 block : bool [default: None]
77 block : bool [default: None]
78 Whether to wait for results or not. The default behavior is
78 Whether to wait for results or not. The default behavior is
79 to use the current `block` attribute of `view`
79 to use the current `block` attribute of `view`
80
80
81 **flags : remaining kwargs are passed to View.temp_flags
81 **flags : remaining kwargs are passed to View.temp_flags
82 """
82 """
83
83
84 view = None # the remote connection
84 view = None # the remote connection
85 func = None # the wrapped function
85 func = None # the wrapped function
86 block = None # whether to block
86 block = None # whether to block
87 flags = None # dict of extra kwargs for temp_flags
87 flags = None # dict of extra kwargs for temp_flags
88
88
89 def __init__(self, view, f, block=None, **flags):
89 def __init__(self, view, f, block=None, **flags):
90 self.view = view
90 self.view = view
91 self.func = f
91 self.func = f
92 self.block=block
92 self.block=block
93 self.flags=flags
93 self.flags=flags
94
94
95 def __call__(self, *args, **kwargs):
95 def __call__(self, *args, **kwargs):
96 block = self.view.block if self.block is None else self.block
96 block = self.view.block if self.block is None else self.block
97 with self.view.temp_flags(block=block, **self.flags):
97 with self.view.temp_flags(block=block, **self.flags):
98 return self.view.apply(self.func, *args, **kwargs)
98 return self.view.apply(self.func, *args, **kwargs)
99
99
100
100
101 class ParallelFunction(RemoteFunction):
101 class ParallelFunction(RemoteFunction):
102 """Class for mapping a function to sequences.
102 """Class for mapping a function to sequences.
103
103
104 This will distribute the sequences according to a mapper, and call
104 This will distribute the sequences according to a mapper, and call
105 the function on each sub-sequence. If called via map, then the function
105 the function on each sub-sequence. If called via map, then the function
106 will be called once on each element, rather than each sub-sequence.
106 will be called once on each element, rather than each sub-sequence.
107
107
108 Parameters
108 Parameters
109 ----------
109 ----------
110
110
111 view : View instance
111 view : View instance
112 The view to be used for execution
112 The view to be used for execution
113 f : callable
113 f : callable
114 The function to be wrapped into a remote function
114 The function to be wrapped into a remote function
115 dist : str [default: 'b']
115 dist : str [default: 'b']
116 The key for which mapObject to use to distribute sequences
116 The key for which mapObject to use to distribute sequences
117 options are:
117 options are:
118 * 'b' : use contiguous chunks in order
118 * 'b' : use contiguous chunks in order
119 * 'r' : use round-robin striping
119 * 'r' : use round-robin striping
120 block : bool [default: None]
120 block : bool [default: None]
121 Whether to wait for results or not. The default behavior is
121 Whether to wait for results or not. The default behavior is
122 to use the current `block` attribute of `view`
122 to use the current `block` attribute of `view`
123 chunksize : int or None
123 chunksize : int or None
124 The size of chunk to use when breaking up sequences in a load-balanced manner
124 The size of chunk to use when breaking up sequences in a load-balanced manner
125 ordered : bool [default: True]
125 ordered : bool [default: True]
126 Whether the results should be returned in the order of submission (True) or as they arrive (False)
126 Whether the results should be returned in the order of submission (True) or as they arrive (False)
127 **flags : remaining kwargs are passed to View.temp_flags
127 **flags : remaining kwargs are passed to View.temp_flags
128 """
128 """
129
129
130 chunksize=None
130 chunksize=None
131 ordered=None
131 ordered=None
132 mapObject=None
132 mapObject=None
133
133
134 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
134 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
135 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
135 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
136 self.chunksize = chunksize
136 self.chunksize = chunksize
137 self.ordered = ordered
137 self.ordered = ordered
138
138
139 mapClass = Map.dists[dist]
139 mapClass = Map.dists[dist]
140 self.mapObject = mapClass()
140 self.mapObject = mapClass()
141
141
142 def __call__(self, *sequences):
142 def __call__(self, *sequences):
143 client = self.view.client
144
143 # check that the length of sequences match
145 # check that the length of sequences match
144 len_0 = len(sequences[0])
146 len_0 = len(sequences[0])
145 for s in sequences:
147 for s in sequences:
146 if len(s)!=len_0:
148 if len(s)!=len_0:
147 msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s))
149 msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s))
148 raise ValueError(msg)
150 raise ValueError(msg)
149 balanced = 'Balanced' in self.view.__class__.__name__
151 balanced = 'Balanced' in self.view.__class__.__name__
150 if balanced:
152 if balanced:
151 if self.chunksize:
153 if self.chunksize:
152 nparts = len_0//self.chunksize + int(len_0%self.chunksize > 0)
154 nparts = len_0//self.chunksize + int(len_0%self.chunksize > 0)
153 else:
155 else:
154 nparts = len_0
156 nparts = len_0
155 targets = [None]*nparts
157 targets = [None]*nparts
156 else:
158 else:
157 if self.chunksize:
159 if self.chunksize:
158 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
160 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
159 # multiplexed:
161 # multiplexed:
160 targets = self.view.targets
162 targets = self.view.targets
163 # 'all' is lazily evaluated at execution time, which is now:
164 if targets == 'all':
165 targets = client._build_targets(targets)[1]
161 nparts = len(targets)
166 nparts = len(targets)
162
167
163 msg_ids = []
168 msg_ids = []
164 # my_f = lambda *a: map(self.func, *a)
165 client = self.view.client
166 for index, t in enumerate(targets):
169 for index, t in enumerate(targets):
167 args = []
170 args = []
168 for seq in sequences:
171 for seq in sequences:
169 part = self.mapObject.getPartition(seq, index, nparts)
172 part = self.mapObject.getPartition(seq, index, nparts)
170 if len(part) == 0:
173 if len(part) == 0:
171 continue
174 continue
172 else:
175 else:
173 args.append(part)
176 args.append(part)
174 if not args:
177 if not args:
175 continue
178 continue
176
179
177 # print (args)
180 # print (args)
178 if hasattr(self, '_map'):
181 if hasattr(self, '_map'):
179 if sys.version_info[0] >= 3:
182 if sys.version_info[0] >= 3:
180 f = lambda f, *sequences: list(map(f, *sequences))
183 f = lambda f, *sequences: list(map(f, *sequences))
181 else:
184 else:
182 f = map
185 f = map
183 args = [self.func]+args
186 args = [self.func]+args
184 else:
187 else:
185 f=self.func
188 f=self.func
186
189
187 view = self.view if balanced else client[t]
190 view = self.view if balanced else client[t]
188 with view.temp_flags(block=False, **self.flags):
191 with view.temp_flags(block=False, **self.flags):
189 ar = view.apply(f, *args)
192 ar = view.apply(f, *args)
190
193
191 msg_ids.append(ar.msg_ids[0])
194 msg_ids.append(ar.msg_ids[0])
192
195
193 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
196 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
194 fname=self.func.__name__,
197 fname=self.func.__name__,
195 ordered=self.ordered
198 ordered=self.ordered
196 )
199 )
197
200
198 if self.block:
201 if self.block:
199 try:
202 try:
200 return r.get()
203 return r.get()
201 except KeyboardInterrupt:
204 except KeyboardInterrupt:
202 return r
205 return r
203 else:
206 else:
204 return r
207 return r
205
208
206 def map(self, *sequences):
209 def map(self, *sequences):
207 """call a function on each element of a sequence remotely.
210 """call a function on each element of a sequence remotely.
208 This should behave very much like the builtin map, but return an AsyncMapResult
211 This should behave very much like the builtin map, but return an AsyncMapResult
209 if self.block is False.
212 if self.block is False.
210 """
213 """
211 # set _map as a flag for use inside self.__call__
214 # set _map as a flag for use inside self.__call__
212 self._map = True
215 self._map = True
213 try:
216 try:
214 ret = self.__call__(*sequences)
217 ret = self.__call__(*sequences)
215 finally:
218 finally:
216 del self._map
219 del self._map
217 return ret
220 return ret
218
221
219 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
222 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,281 +1,316 b''
1 """Tests for parallel client.py
1 """Tests for parallel client.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import time
21 import time
22 from datetime import datetime
22 from datetime import datetime
23 from tempfile import mktemp
23 from tempfile import mktemp
24
24
25 import zmq
25 import zmq
26
26
27 from IPython.parallel.client import client as clientmod
27 from IPython.parallel.client import client as clientmod
28 from IPython.parallel import error
28 from IPython.parallel import error
29 from IPython.parallel import AsyncResult, AsyncHubResult
29 from IPython.parallel import AsyncResult, AsyncHubResult
30 from IPython.parallel import LoadBalancedView, DirectView
30 from IPython.parallel import LoadBalancedView, DirectView
31
31
32 from clienttest import ClusterTestCase, segfault, wait, add_engines
32 from clienttest import ClusterTestCase, segfault, wait, add_engines
33
33
34 def setup():
34 def setup():
35 add_engines(4)
35 add_engines(4)
36
36
37 class TestClient(ClusterTestCase):
37 class TestClient(ClusterTestCase):
38
38
39 def test_ids(self):
39 def test_ids(self):
40 n = len(self.client.ids)
40 n = len(self.client.ids)
41 self.add_engines(3)
41 self.add_engines(3)
42 self.assertEquals(len(self.client.ids), n+3)
42 self.assertEquals(len(self.client.ids), n+3)
43
43
44 def test_view_indexing(self):
44 def test_view_indexing(self):
45 """test index access for views"""
45 """test index access for views"""
46 self.add_engines(2)
46 self.add_engines(2)
47 targets = self.client._build_targets('all')[-1]
47 targets = self.client._build_targets('all')[-1]
48 v = self.client[:]
48 v = self.client[:]
49 self.assertEquals(v.targets, targets)
49 self.assertEquals(v.targets, targets)
50 t = self.client.ids[2]
50 t = self.client.ids[2]
51 v = self.client[t]
51 v = self.client[t]
52 self.assert_(isinstance(v, DirectView))
52 self.assert_(isinstance(v, DirectView))
53 self.assertEquals(v.targets, t)
53 self.assertEquals(v.targets, t)
54 t = self.client.ids[2:4]
54 t = self.client.ids[2:4]
55 v = self.client[t]
55 v = self.client[t]
56 self.assert_(isinstance(v, DirectView))
56 self.assert_(isinstance(v, DirectView))
57 self.assertEquals(v.targets, t)
57 self.assertEquals(v.targets, t)
58 v = self.client[::2]
58 v = self.client[::2]
59 self.assert_(isinstance(v, DirectView))
59 self.assert_(isinstance(v, DirectView))
60 self.assertEquals(v.targets, targets[::2])
60 self.assertEquals(v.targets, targets[::2])
61 v = self.client[1::3]
61 v = self.client[1::3]
62 self.assert_(isinstance(v, DirectView))
62 self.assert_(isinstance(v, DirectView))
63 self.assertEquals(v.targets, targets[1::3])
63 self.assertEquals(v.targets, targets[1::3])
64 v = self.client[:-3]
64 v = self.client[:-3]
65 self.assert_(isinstance(v, DirectView))
65 self.assert_(isinstance(v, DirectView))
66 self.assertEquals(v.targets, targets[:-3])
66 self.assertEquals(v.targets, targets[:-3])
67 v = self.client[-1]
67 v = self.client[-1]
68 self.assert_(isinstance(v, DirectView))
68 self.assert_(isinstance(v, DirectView))
69 self.assertEquals(v.targets, targets[-1])
69 self.assertEquals(v.targets, targets[-1])
70 self.assertRaises(TypeError, lambda : self.client[None])
70 self.assertRaises(TypeError, lambda : self.client[None])
71
71
72 def test_lbview_targets(self):
72 def test_lbview_targets(self):
73 """test load_balanced_view targets"""
73 """test load_balanced_view targets"""
74 v = self.client.load_balanced_view()
74 v = self.client.load_balanced_view()
75 self.assertEquals(v.targets, None)
75 self.assertEquals(v.targets, None)
76 v = self.client.load_balanced_view(-1)
76 v = self.client.load_balanced_view(-1)
77 self.assertEquals(v.targets, [self.client.ids[-1]])
77 self.assertEquals(v.targets, [self.client.ids[-1]])
78 v = self.client.load_balanced_view('all')
78 v = self.client.load_balanced_view('all')
79 self.assertEquals(v.targets, None)
79 self.assertEquals(v.targets, None)
80
80
81 def test_dview_targets(self):
81 def test_dview_targets(self):
82 """test load_balanced_view targets"""
82 """test direct_view targets"""
83 v = self.client.direct_view()
83 v = self.client.direct_view()
84 self.assertEquals(v.targets, 'all')
84 self.assertEquals(v.targets, 'all')
85 v = self.client.direct_view('all')
85 v = self.client.direct_view('all')
86 self.assertEquals(v.targets, 'all')
86 self.assertEquals(v.targets, 'all')
87 v = self.client.direct_view(-1)
87 v = self.client.direct_view(-1)
88 self.assertEquals(v.targets, self.client.ids[-1])
88 self.assertEquals(v.targets, self.client.ids[-1])
89
89
90 def test_lazy_all_targets(self):
91 """test lazy evaluation of rc.direct_view('all')"""
92 v = self.client.direct_view()
93 self.assertEquals(v.targets, 'all')
94
95 def double(x):
96 return x*2
97 seq = range(100)
98 ref = [ double(x) for x in seq ]
99
100 # add some engines, which should be used
101 self.add_engines(2)
102 n1 = len(self.client.ids)
103
104 # simple apply
105 r = v.apply_sync(lambda : 1)
106 self.assertEquals(r, [1] * n1)
107
108 # map goes through remotefunction
109 r = v.map_sync(double, seq)
110 self.assertEquals(r, ref)
111
112 # add a couple more engines, and try again
113 self.add_engines(2)
114 n2 = len(self.client.ids)
115 self.assertNotEquals(n2, n1)
116
117 # apply
118 r = v.apply_sync(lambda : 1)
119 self.assertEquals(r, [1] * n2)
120
121 # map
122 r = v.map_sync(double, seq)
123 self.assertEquals(r, ref)
124
90 def test_targets(self):
125 def test_targets(self):
91 """test various valid targets arguments"""
126 """test various valid targets arguments"""
92 build = self.client._build_targets
127 build = self.client._build_targets
93 ids = self.client.ids
128 ids = self.client.ids
94 idents,targets = build(None)
129 idents,targets = build(None)
95 self.assertEquals(ids, targets)
130 self.assertEquals(ids, targets)
96
131
97 def test_clear(self):
132 def test_clear(self):
98 """test clear behavior"""
133 """test clear behavior"""
99 # self.add_engines(2)
134 # self.add_engines(2)
100 v = self.client[:]
135 v = self.client[:]
101 v.block=True
136 v.block=True
102 v.push(dict(a=5))
137 v.push(dict(a=5))
103 v.pull('a')
138 v.pull('a')
104 id0 = self.client.ids[-1]
139 id0 = self.client.ids[-1]
105 self.client.clear(targets=id0, block=True)
140 self.client.clear(targets=id0, block=True)
106 a = self.client[:-1].get('a')
141 a = self.client[:-1].get('a')
107 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
142 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
108 self.client.clear(block=True)
143 self.client.clear(block=True)
109 for i in self.client.ids:
144 for i in self.client.ids:
110 # print i
145 # print i
111 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
146 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
112
147
113 def test_get_result(self):
148 def test_get_result(self):
114 """test getting results from the Hub."""
149 """test getting results from the Hub."""
115 c = clientmod.Client(profile='iptest')
150 c = clientmod.Client(profile='iptest')
116 # self.add_engines(1)
151 # self.add_engines(1)
117 t = c.ids[-1]
152 t = c.ids[-1]
118 ar = c[t].apply_async(wait, 1)
153 ar = c[t].apply_async(wait, 1)
119 # give the monitor time to notice the message
154 # give the monitor time to notice the message
120 time.sleep(.25)
155 time.sleep(.25)
121 ahr = self.client.get_result(ar.msg_ids)
156 ahr = self.client.get_result(ar.msg_ids)
122 self.assertTrue(isinstance(ahr, AsyncHubResult))
157 self.assertTrue(isinstance(ahr, AsyncHubResult))
123 self.assertEquals(ahr.get(), ar.get())
158 self.assertEquals(ahr.get(), ar.get())
124 ar2 = self.client.get_result(ar.msg_ids)
159 ar2 = self.client.get_result(ar.msg_ids)
125 self.assertFalse(isinstance(ar2, AsyncHubResult))
160 self.assertFalse(isinstance(ar2, AsyncHubResult))
126 c.close()
161 c.close()
127
162
128 def test_ids_list(self):
163 def test_ids_list(self):
129 """test client.ids"""
164 """test client.ids"""
130 # self.add_engines(2)
165 # self.add_engines(2)
131 ids = self.client.ids
166 ids = self.client.ids
132 self.assertEquals(ids, self.client._ids)
167 self.assertEquals(ids, self.client._ids)
133 self.assertFalse(ids is self.client._ids)
168 self.assertFalse(ids is self.client._ids)
134 ids.remove(ids[-1])
169 ids.remove(ids[-1])
135 self.assertNotEquals(ids, self.client._ids)
170 self.assertNotEquals(ids, self.client._ids)
136
171
137 def test_queue_status(self):
172 def test_queue_status(self):
138 # self.addEngine(4)
173 # self.addEngine(4)
139 ids = self.client.ids
174 ids = self.client.ids
140 id0 = ids[0]
175 id0 = ids[0]
141 qs = self.client.queue_status(targets=id0)
176 qs = self.client.queue_status(targets=id0)
142 self.assertTrue(isinstance(qs, dict))
177 self.assertTrue(isinstance(qs, dict))
143 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
178 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
144 allqs = self.client.queue_status()
179 allqs = self.client.queue_status()
145 self.assertTrue(isinstance(allqs, dict))
180 self.assertTrue(isinstance(allqs, dict))
146 intkeys = list(allqs.keys())
181 intkeys = list(allqs.keys())
147 intkeys.remove('unassigned')
182 intkeys.remove('unassigned')
148 self.assertEquals(sorted(intkeys), sorted(self.client.ids))
183 self.assertEquals(sorted(intkeys), sorted(self.client.ids))
149 unassigned = allqs.pop('unassigned')
184 unassigned = allqs.pop('unassigned')
150 for eid,qs in allqs.items():
185 for eid,qs in allqs.items():
151 self.assertTrue(isinstance(qs, dict))
186 self.assertTrue(isinstance(qs, dict))
152 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
187 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
153
188
154 def test_shutdown(self):
189 def test_shutdown(self):
155 # self.addEngine(4)
190 # self.addEngine(4)
156 ids = self.client.ids
191 ids = self.client.ids
157 id0 = ids[0]
192 id0 = ids[0]
158 self.client.shutdown(id0, block=True)
193 self.client.shutdown(id0, block=True)
159 while id0 in self.client.ids:
194 while id0 in self.client.ids:
160 time.sleep(0.1)
195 time.sleep(0.1)
161 self.client.spin()
196 self.client.spin()
162
197
163 self.assertRaises(IndexError, lambda : self.client[id0])
198 self.assertRaises(IndexError, lambda : self.client[id0])
164
199
165 def test_result_status(self):
200 def test_result_status(self):
166 pass
201 pass
167 # to be written
202 # to be written
168
203
169 def test_db_query_dt(self):
204 def test_db_query_dt(self):
170 """test db query by date"""
205 """test db query by date"""
171 hist = self.client.hub_history()
206 hist = self.client.hub_history()
172 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
207 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
173 tic = middle['submitted']
208 tic = middle['submitted']
174 before = self.client.db_query({'submitted' : {'$lt' : tic}})
209 before = self.client.db_query({'submitted' : {'$lt' : tic}})
175 after = self.client.db_query({'submitted' : {'$gte' : tic}})
210 after = self.client.db_query({'submitted' : {'$gte' : tic}})
176 self.assertEquals(len(before)+len(after),len(hist))
211 self.assertEquals(len(before)+len(after),len(hist))
177 for b in before:
212 for b in before:
178 self.assertTrue(b['submitted'] < tic)
213 self.assertTrue(b['submitted'] < tic)
179 for a in after:
214 for a in after:
180 self.assertTrue(a['submitted'] >= tic)
215 self.assertTrue(a['submitted'] >= tic)
181 same = self.client.db_query({'submitted' : tic})
216 same = self.client.db_query({'submitted' : tic})
182 for s in same:
217 for s in same:
183 self.assertTrue(s['submitted'] == tic)
218 self.assertTrue(s['submitted'] == tic)
184
219
185 def test_db_query_keys(self):
220 def test_db_query_keys(self):
186 """test extracting subset of record keys"""
221 """test extracting subset of record keys"""
187 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
222 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
188 for rec in found:
223 for rec in found:
189 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
224 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
190
225
191 def test_db_query_msg_id(self):
226 def test_db_query_msg_id(self):
192 """ensure msg_id is always in db queries"""
227 """ensure msg_id is always in db queries"""
193 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
228 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
194 for rec in found:
229 for rec in found:
195 self.assertTrue('msg_id' in rec.keys())
230 self.assertTrue('msg_id' in rec.keys())
196 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
231 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
197 for rec in found:
232 for rec in found:
198 self.assertTrue('msg_id' in rec.keys())
233 self.assertTrue('msg_id' in rec.keys())
199 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
234 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
200 for rec in found:
235 for rec in found:
201 self.assertTrue('msg_id' in rec.keys())
236 self.assertTrue('msg_id' in rec.keys())
202
237
203 def test_db_query_in(self):
238 def test_db_query_in(self):
204 """test db query with '$in','$nin' operators"""
239 """test db query with '$in','$nin' operators"""
205 hist = self.client.hub_history()
240 hist = self.client.hub_history()
206 even = hist[::2]
241 even = hist[::2]
207 odd = hist[1::2]
242 odd = hist[1::2]
208 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
243 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
209 found = [ r['msg_id'] for r in recs ]
244 found = [ r['msg_id'] for r in recs ]
210 self.assertEquals(set(even), set(found))
245 self.assertEquals(set(even), set(found))
211 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
246 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
212 found = [ r['msg_id'] for r in recs ]
247 found = [ r['msg_id'] for r in recs ]
213 self.assertEquals(set(odd), set(found))
248 self.assertEquals(set(odd), set(found))
214
249
215 def test_hub_history(self):
250 def test_hub_history(self):
216 hist = self.client.hub_history()
251 hist = self.client.hub_history()
217 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
252 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
218 recdict = {}
253 recdict = {}
219 for rec in recs:
254 for rec in recs:
220 recdict[rec['msg_id']] = rec
255 recdict[rec['msg_id']] = rec
221
256
222 latest = datetime(1984,1,1)
257 latest = datetime(1984,1,1)
223 for msg_id in hist:
258 for msg_id in hist:
224 rec = recdict[msg_id]
259 rec = recdict[msg_id]
225 newt = rec['submitted']
260 newt = rec['submitted']
226 self.assertTrue(newt >= latest)
261 self.assertTrue(newt >= latest)
227 latest = newt
262 latest = newt
228 ar = self.client[-1].apply_async(lambda : 1)
263 ar = self.client[-1].apply_async(lambda : 1)
229 ar.get()
264 ar.get()
230 time.sleep(0.25)
265 time.sleep(0.25)
231 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
266 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
232
267
233 def test_resubmit(self):
268 def test_resubmit(self):
234 def f():
269 def f():
235 import random
270 import random
236 return random.random()
271 return random.random()
237 v = self.client.load_balanced_view()
272 v = self.client.load_balanced_view()
238 ar = v.apply_async(f)
273 ar = v.apply_async(f)
239 r1 = ar.get(1)
274 r1 = ar.get(1)
240 # give the Hub a chance to notice:
275 # give the Hub a chance to notice:
241 time.sleep(0.5)
276 time.sleep(0.5)
242 ahr = self.client.resubmit(ar.msg_ids)
277 ahr = self.client.resubmit(ar.msg_ids)
243 r2 = ahr.get(1)
278 r2 = ahr.get(1)
244 self.assertFalse(r1 == r2)
279 self.assertFalse(r1 == r2)
245
280
246 def test_resubmit_inflight(self):
281 def test_resubmit_inflight(self):
247 """ensure ValueError on resubmit of inflight task"""
282 """ensure ValueError on resubmit of inflight task"""
248 v = self.client.load_balanced_view()
283 v = self.client.load_balanced_view()
249 ar = v.apply_async(time.sleep,1)
284 ar = v.apply_async(time.sleep,1)
250 # give the message a chance to arrive
285 # give the message a chance to arrive
251 time.sleep(0.2)
286 time.sleep(0.2)
252 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
287 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
253 ar.get(2)
288 ar.get(2)
254
289
255 def test_resubmit_badkey(self):
290 def test_resubmit_badkey(self):
256 """ensure KeyError on resubmit of nonexistent task"""
291 """ensure KeyError on resubmit of nonexistent task"""
257 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
292 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
258
293
259 def test_purge_results(self):
294 def test_purge_results(self):
260 # ensure there are some tasks
295 # ensure there are some tasks
261 for i in range(5):
296 for i in range(5):
262 self.client[:].apply_sync(lambda : 1)
297 self.client[:].apply_sync(lambda : 1)
263 # Wait for the Hub to realise the result is done:
298 # Wait for the Hub to realise the result is done:
264 # This prevents a race condition, where we
299 # This prevents a race condition, where we
265 # might purge a result the Hub still thinks is pending.
300 # might purge a result the Hub still thinks is pending.
266 time.sleep(0.1)
301 time.sleep(0.1)
267 rc2 = clientmod.Client(profile='iptest')
302 rc2 = clientmod.Client(profile='iptest')
268 hist = self.client.hub_history()
303 hist = self.client.hub_history()
269 ahr = rc2.get_result([hist[-1]])
304 ahr = rc2.get_result([hist[-1]])
270 ahr.wait(10)
305 ahr.wait(10)
271 self.client.purge_results(hist[-1])
306 self.client.purge_results(hist[-1])
272 newhist = self.client.hub_history()
307 newhist = self.client.hub_history()
273 self.assertEquals(len(newhist)+1,len(hist))
308 self.assertEquals(len(newhist)+1,len(hist))
274 rc2.spin()
309 rc2.spin()
275 rc2.close()
310 rc2.close()
276
311
277 def test_purge_all_results(self):
312 def test_purge_all_results(self):
278 self.client.purge_results('all')
313 self.client.purge_results('all')
279 hist = self.client.hub_history()
314 hist = self.client.hub_history()
280 self.assertEquals(len(hist), 0)
315 self.assertEquals(len(hist), 0)
281
316
General Comments 0
You need to be logged in to leave comments. Login now