##// END OF EJS Templates
allow rc.direct_view('all') to be lazily-evaluated...
MinRK -
Show More
@@ -1,1431 +1,1435 b''
1 """A semi-synchronous Client for the ZMQ cluster
1 """A semi-synchronous Client for the ZMQ cluster
2
2
3 Authors:
3 Authors:
4
4
5 * MinRK
5 * MinRK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import os
18 import os
19 import json
19 import json
20 import sys
20 import sys
21 import time
21 import time
22 import warnings
22 import warnings
23 from datetime import datetime
23 from datetime import datetime
24 from getpass import getpass
24 from getpass import getpass
25 from pprint import pprint
25 from pprint import pprint
26
26
27 pjoin = os.path.join
27 pjoin = os.path.join
28
28
29 import zmq
29 import zmq
30 # from zmq.eventloop import ioloop, zmqstream
30 # from zmq.eventloop import ioloop, zmqstream
31
31
32 from IPython.config.configurable import MultipleInstanceError
32 from IPython.config.configurable import MultipleInstanceError
33 from IPython.core.application import BaseIPythonApplication
33 from IPython.core.application import BaseIPythonApplication
34
34
35 from IPython.utils.jsonutil import rekey
35 from IPython.utils.jsonutil import rekey
36 from IPython.utils.localinterfaces import LOCAL_IPS
36 from IPython.utils.localinterfaces import LOCAL_IPS
37 from IPython.utils.path import get_ipython_dir
37 from IPython.utils.path import get_ipython_dir
38 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
38 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
39 Dict, List, Bool, Set)
39 Dict, List, Bool, Set)
40 from IPython.external.decorator import decorator
40 from IPython.external.decorator import decorator
41 from IPython.external.ssh import tunnel
41 from IPython.external.ssh import tunnel
42
42
43 from IPython.parallel import error
43 from IPython.parallel import error
44 from IPython.parallel import util
44 from IPython.parallel import util
45
45
46 from IPython.zmq.session import Session, Message
46 from IPython.zmq.session import Session, Message
47
47
48 from .asyncresult import AsyncResult, AsyncHubResult
48 from .asyncresult import AsyncResult, AsyncHubResult
49 from IPython.core.profiledir import ProfileDir, ProfileDirError
49 from IPython.core.profiledir import ProfileDir, ProfileDirError
50 from .view import DirectView, LoadBalancedView
50 from .view import DirectView, LoadBalancedView
51
51
52 if sys.version_info[0] >= 3:
52 if sys.version_info[0] >= 3:
53 # xrange is used in a couple 'isinstance' tests in py2
53 # xrange is used in a couple 'isinstance' tests in py2
54 # should be just 'range' in 3k
54 # should be just 'range' in 3k
55 xrange = range
55 xrange = range
56
56
57 #--------------------------------------------------------------------------
57 #--------------------------------------------------------------------------
58 # Decorators for Client methods
58 # Decorators for Client methods
59 #--------------------------------------------------------------------------
59 #--------------------------------------------------------------------------
60
60
@decorator
def spin_first(f, self, *args, **kwargs):
    """Call spin() to sync state prior to calling the method.

    Applied to Client methods so that pending results and engine
    registration changes are flushed before the wrapped method runs.
    """
    self.spin()
    return f(self, *args, **kwargs)
66
66
67
67
68 #--------------------------------------------------------------------------
68 #--------------------------------------------------------------------------
69 # Classes
69 # Classes
70 #--------------------------------------------------------------------------
70 #--------------------------------------------------------------------------
71
71
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # the fixed set of allowed keys, with their default values
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
            }
        self.update(md)
        # caller-supplied values; dict.update bypasses __setitem__, so
        # this intentionally does not re-validate keys (same as defaults)
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # `key in self` is an O(1) hash lookup instead of iterating
        # iterkeys(); it also avoids infinite __getattr__ recursion on
        # Python 3, where dict has no `iterkeys` attribute.
        if key in self:
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict key enforcement"""
        if key in self:
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self:
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
122
122
123
123
124 class Client(HasTraits):
124 class Client(HasTraits):
125 """A semi-synchronous client to the IPython ZMQ cluster
125 """A semi-synchronous client to the IPython ZMQ cluster
126
126
127 Parameters
127 Parameters
128 ----------
128 ----------
129
129
130 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
130 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
131 Connection information for the Hub's registration. If a json connector
131 Connection information for the Hub's registration. If a json connector
132 file is given, then likely no further configuration is necessary.
132 file is given, then likely no further configuration is necessary.
133 [Default: use profile]
133 [Default: use profile]
134 profile : bytes
134 profile : bytes
135 The name of the Cluster profile to be used to find connector information.
135 The name of the Cluster profile to be used to find connector information.
136 If run from an IPython application, the default profile will be the same
136 If run from an IPython application, the default profile will be the same
137 as the running application, otherwise it will be 'default'.
137 as the running application, otherwise it will be 'default'.
138 context : zmq.Context
138 context : zmq.Context
139 Pass an existing zmq.Context instance, otherwise the client will create its own.
139 Pass an existing zmq.Context instance, otherwise the client will create its own.
140 debug : bool
140 debug : bool
141 flag for lots of message printing for debug purposes
141 flag for lots of message printing for debug purposes
142 timeout : int/float
142 timeout : int/float
143 time (in seconds) to wait for connection replies from the Hub
143 time (in seconds) to wait for connection replies from the Hub
144 [Default: 10]
144 [Default: 10]
145
145
146 #-------------- session related args ----------------
146 #-------------- session related args ----------------
147
147
148 config : Config object
148 config : Config object
149 If specified, this will be relayed to the Session for configuration
149 If specified, this will be relayed to the Session for configuration
150 username : str
150 username : str
151 set username for the session object
151 set username for the session object
152 packer : str (import_string) or callable
152 packer : str (import_string) or callable
153 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
153 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
154 function to serialize messages. Must support same input as
154 function to serialize messages. Must support same input as
155 JSON, and output must be bytes.
155 JSON, and output must be bytes.
156 You can pass a callable directly as `pack`
156 You can pass a callable directly as `pack`
157 unpacker : str (import_string) or callable
157 unpacker : str (import_string) or callable
158 The inverse of packer. Only necessary if packer is specified as *not* one
158 The inverse of packer. Only necessary if packer is specified as *not* one
159 of 'json' or 'pickle'.
159 of 'json' or 'pickle'.
160
160
161 #-------------- ssh related args ----------------
161 #-------------- ssh related args ----------------
162 # These are args for configuring the ssh tunnel to be used
162 # These are args for configuring the ssh tunnel to be used
163 # credentials are used to forward connections over ssh to the Controller
163 # credentials are used to forward connections over ssh to the Controller
164 # Note that the ip given in `addr` needs to be relative to sshserver
164 # Note that the ip given in `addr` needs to be relative to sshserver
165 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
165 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
166 # and set sshserver as the same machine the Controller is on. However,
166 # and set sshserver as the same machine the Controller is on. However,
167 # the only requirement is that sshserver is able to see the Controller
167 # the only requirement is that sshserver is able to see the Controller
168 # (i.e. is within the same trusted network).
168 # (i.e. is within the same trusted network).
169
169
170 sshserver : str
170 sshserver : str
171 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
171 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
172 If keyfile or password is specified, and this is not, it will default to
172 If keyfile or password is specified, and this is not, it will default to
173 the ip given in addr.
173 the ip given in addr.
174 sshkey : str; path to public ssh key file
174 sshkey : str; path to public ssh key file
175 This specifies a key to be used in ssh login, default None.
175 This specifies a key to be used in ssh login, default None.
176 Regular default ssh keys will be used without specifying this argument.
176 Regular default ssh keys will be used without specifying this argument.
177 password : str
177 password : str
178 Your ssh password to sshserver. Note that if this is left None,
178 Your ssh password to sshserver. Note that if this is left None,
179 you will be prompted for it if passwordless key based login is unavailable.
179 you will be prompted for it if passwordless key based login is unavailable.
180 paramiko : bool
180 paramiko : bool
181 flag for whether to use paramiko instead of shell ssh for tunneling.
181 flag for whether to use paramiko instead of shell ssh for tunneling.
182 [default: True on win32, False else]
182 [default: True on win32, False else]
183
183
184 ------- exec authentication args -------
184 ------- exec authentication args -------
185 If even localhost is untrusted, you can have some protection against
185 If even localhost is untrusted, you can have some protection against
186 unauthorized execution by signing messages with HMAC digests.
186 unauthorized execution by signing messages with HMAC digests.
187 Messages are still sent as cleartext, so if someone can snoop your
187 Messages are still sent as cleartext, so if someone can snoop your
188 loopback traffic this will not protect your privacy, but will prevent
188 loopback traffic this will not protect your privacy, but will prevent
189 unauthorized execution.
189 unauthorized execution.
190
190
191 exec_key : str
191 exec_key : str
192 an authentication key or file containing a key
192 an authentication key or file containing a key
193 default: None
193 default: None
194
194
195
195
196 Attributes
196 Attributes
197 ----------
197 ----------
198
198
199 ids : list of int engine IDs
199 ids : list of int engine IDs
200 requesting the ids attribute always synchronizes
200 requesting the ids attribute always synchronizes
201 the registration state. To request ids without synchronization,
201 the registration state. To request ids without synchronization,
202 use semi-private _ids attributes.
202 use semi-private _ids attributes.
203
203
204 history : list of msg_ids
204 history : list of msg_ids
205 a list of msg_ids, keeping track of all the execution
205 a list of msg_ids, keeping track of all the execution
206 messages you have submitted in order.
206 messages you have submitted in order.
207
207
208 outstanding : set of msg_ids
208 outstanding : set of msg_ids
209 a set of msg_ids that have been submitted, but whose
209 a set of msg_ids that have been submitted, but whose
210 results have not yet been received.
210 results have not yet been received.
211
211
212 results : dict
212 results : dict
213 a dict of all our results, keyed by msg_id
213 a dict of all our results, keyed by msg_id
214
214
215 block : bool
215 block : bool
216 determines default behavior when block not specified
216 determines default behavior when block not specified
217 in execution methods
217 in execution methods
218
218
219 Methods
219 Methods
220 -------
220 -------
221
221
222 spin
222 spin
223 flushes incoming results and registration state changes
223 flushes incoming results and registration state changes
224 control methods spin, and requesting `ids` also ensures up to date
224 control methods spin, and requesting `ids` also ensures up to date
225
225
226 wait
226 wait
227 wait on one or more msg_ids
227 wait on one or more msg_ids
228
228
229 execution methods
229 execution methods
230 apply
230 apply
231 legacy: execute, run
231 legacy: execute, run
232
232
233 data movement
233 data movement
234 push, pull, scatter, gather
234 push, pull, scatter, gather
235
235
236 query methods
236 query methods
237 queue_status, get_result, purge, result_status
237 queue_status, get_result, purge, result_status
238
238
239 control methods
239 control methods
240 abort, shutdown
240 abort, shutdown
241
241
242 """
242 """
243
243
244
244
245 block = Bool(False)
245 block = Bool(False)
246 outstanding = Set()
246 outstanding = Set()
247 results = Instance('collections.defaultdict', (dict,))
247 results = Instance('collections.defaultdict', (dict,))
248 metadata = Instance('collections.defaultdict', (Metadata,))
248 metadata = Instance('collections.defaultdict', (Metadata,))
249 history = List()
249 history = List()
250 debug = Bool(False)
250 debug = Bool(False)
251
251
252 profile=Unicode()
252 profile=Unicode()
253 def _profile_default(self):
253 def _profile_default(self):
254 if BaseIPythonApplication.initialized():
254 if BaseIPythonApplication.initialized():
255 # an IPython app *might* be running, try to get its profile
255 # an IPython app *might* be running, try to get its profile
256 try:
256 try:
257 return BaseIPythonApplication.instance().profile
257 return BaseIPythonApplication.instance().profile
258 except (AttributeError, MultipleInstanceError):
258 except (AttributeError, MultipleInstanceError):
259 # could be a *different* subclass of config.Application,
259 # could be a *different* subclass of config.Application,
260 # which would raise one of these two errors.
260 # which would raise one of these two errors.
261 return u'default'
261 return u'default'
262 else:
262 else:
263 return u'default'
263 return u'default'
264
264
265
265
266 _outstanding_dict = Instance('collections.defaultdict', (set,))
266 _outstanding_dict = Instance('collections.defaultdict', (set,))
267 _ids = List()
267 _ids = List()
268 _connected=Bool(False)
268 _connected=Bool(False)
269 _ssh=Bool(False)
269 _ssh=Bool(False)
270 _context = Instance('zmq.Context')
270 _context = Instance('zmq.Context')
271 _config = Dict()
271 _config = Dict()
272 _engines=Instance(util.ReverseDict, (), {})
272 _engines=Instance(util.ReverseDict, (), {})
273 # _hub_socket=Instance('zmq.Socket')
273 # _hub_socket=Instance('zmq.Socket')
274 _query_socket=Instance('zmq.Socket')
274 _query_socket=Instance('zmq.Socket')
275 _control_socket=Instance('zmq.Socket')
275 _control_socket=Instance('zmq.Socket')
276 _iopub_socket=Instance('zmq.Socket')
276 _iopub_socket=Instance('zmq.Socket')
277 _notification_socket=Instance('zmq.Socket')
277 _notification_socket=Instance('zmq.Socket')
278 _mux_socket=Instance('zmq.Socket')
278 _mux_socket=Instance('zmq.Socket')
279 _task_socket=Instance('zmq.Socket')
279 _task_socket=Instance('zmq.Socket')
280 _task_scheme=Unicode()
280 _task_scheme=Unicode()
281 _closed = False
281 _closed = False
282 _ignored_control_replies=Int(0)
282 _ignored_control_replies=Int(0)
283 _ignored_hub_replies=Int(0)
283 _ignored_hub_replies=Int(0)
284
284
    def __new__(self, *args, **kw):
        # don't raise on positional args: HasTraits.__new__ rejects them,
        # so only keyword args are forwarded; positional args are consumed
        # by __init__ instead.
        return HasTraits.__new__(self, **kw)
288
288
    def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
            context=None, debug=False, exec_key=None,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            timeout=10, **extra_args
            ):
        """Connect to the Hub described by `url_or_file` (or the profile's
        connection file) and set up the query socket and message handlers.

        See the class docstring for full parameter documentation.
        """
        # only pass `profile` to the trait machinery when explicitly given,
        # so the _profile_default logic can run otherwise
        if profile:
            super(Client, self).__init__(debug=debug, profile=profile)
        else:
            super(Client, self).__init__(debug=debug)
        if context is None:
            context = zmq.Context.instance()
        self._context = context

        # resolve the profile directory; sets self._cd (possibly None)
        self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
            " Please specify at least one of url_or_file or profile."

        # `url_or_file` may be a raw zmq url or a path to a JSON connection
        # file; validate_url raises AssertionError for non-urls
        try:
            util.validate_url(url_or_file)
        except AssertionError:
            if not os.path.exists(url_or_file):
                # bare filename: look it up in the profile's security dir
                if self._cd:
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}

        # sync defaults from args, json:
        # explicit arguments override values loaded from the JSON file
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg['exec_key']
        location = cfg.setdefault('location', None)
        cfg['url'] = util.disambiguate_url(cfg['url'], location)
        url = cfg['url']
        proto,addr,port = util.split_url(url)
        if location is not None and addr == '127.0.0.1':
            # location specified, and connection is expected to be local
            if location not in LOCAL_IPS and not sshserver:
                # load ssh from JSON *only* if the controller is not on
                # this machine
                sshserver=cfg['ssh']
            if location not in LOCAL_IPS and not sshserver:
                # warn if no ssh specified, but SSH is probably needed
                # This is only a warning, because the most likely cause
                # is a local Controller on a laptop whose IP is dynamic
                warnings.warn("""
            Controller appears to be listening on localhost, but not on this machine.
            If this is true, you should specify Client(...,sshserver='you@%s')
            or instruct your controller to listen on an external IP."""%location,
                RuntimeWarning)
        elif not sshserver:
            # otherwise sync with cfg
            sshserver = cfg['ssh']

        self._config = cfg

        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                # False (not None) marks "no password needed"
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

        # configure and construct the session
        if exec_key is not None:
            if os.path.isfile(exec_key):
                # a path: let the Session read the key from the file
                extra_args['keyfile'] = exec_key
            else:
                # a literal key: must be bytes for HMAC signing
                exec_key = util.asbytes(exec_key)
                extra_args['key'] = exec_key
        self.session = Session(**extra_args)

        self._query_socket = self._context.socket(zmq.XREQ)
        # identify ourselves to the Hub by session id
        self._query_socket.setsockopt(zmq.IDENTITY, util.asbytes(self.session.session))
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        # dispatch tables for notification and queue messages
        self._notification_handlers = {'registration_notification' : self._register_engine,
                                    'unregistration_notification' : self._unregister_engine,
                                    'shutdown_notification' : lambda msg: self.close(),
                                    }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs, timeout)
388
388
    def __del__(self):
        """cleanup sockets, but _not_ context."""
        # NOTE(review): if __init__ raised before sockets were created,
        # close() may itself raise here during teardown — confirm close()
        # tolerates a partially-initialized instance.
        self.close()
392
392
393 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
393 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
394 if ipython_dir is None:
394 if ipython_dir is None:
395 ipython_dir = get_ipython_dir()
395 ipython_dir = get_ipython_dir()
396 if profile_dir is not None:
396 if profile_dir is not None:
397 try:
397 try:
398 self._cd = ProfileDir.find_profile_dir(profile_dir)
398 self._cd = ProfileDir.find_profile_dir(profile_dir)
399 return
399 return
400 except ProfileDirError:
400 except ProfileDirError:
401 pass
401 pass
402 elif profile is not None:
402 elif profile is not None:
403 try:
403 try:
404 self._cd = ProfileDir.find_profile_dir_by_name(
404 self._cd = ProfileDir.find_profile_dir_by_name(
405 ipython_dir, profile)
405 ipython_dir, profile)
406 return
406 return
407 except ProfileDirError:
407 except ProfileDirError:
408 pass
408 pass
409 self._cd = None
409 self._cd = None
410
410
411 def _update_engines(self, engines):
411 def _update_engines(self, engines):
412 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
412 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
413 for k,v in engines.iteritems():
413 for k,v in engines.iteritems():
414 eid = int(k)
414 eid = int(k)
415 self._engines[eid] = v
415 self._engines[eid] = v
416 self._ids.append(eid)
416 self._ids.append(eid)
417 self._ids = sorted(self._ids)
417 self._ids = sorted(self._ids)
418 if sorted(self._engines.keys()) != range(len(self._engines)) and \
418 if sorted(self._engines.keys()) != range(len(self._engines)) and \
419 self._task_scheme == 'pure' and self._task_socket:
419 self._task_scheme == 'pure' and self._task_socket:
420 self._stop_scheduling_tasks()
420 self._stop_scheduling_tasks()
421
421
422 def _stop_scheduling_tasks(self):
422 def _stop_scheduling_tasks(self):
423 """Stop scheduling tasks because an engine has been unregistered
423 """Stop scheduling tasks because an engine has been unregistered
424 from a pure ZMQ scheduler.
424 from a pure ZMQ scheduler.
425 """
425 """
426 self._task_socket.close()
426 self._task_socket.close()
427 self._task_socket = None
427 self._task_socket = None
428 msg = "An engine has been unregistered, and we are using pure " +\
428 msg = "An engine has been unregistered, and we are using pure " +\
429 "ZMQ task scheduling. Task farming will be disabled."
429 "ZMQ task scheduling. Task farming will be disabled."
430 if self.outstanding:
430 if self.outstanding:
431 msg += " If you were running tasks when this happened, " +\
431 msg += " If you were running tasks when this happened, " +\
432 "some `outstanding` msg_ids may never resolve."
432 "some `outstanding` msg_ids may never resolve."
433 warnings.warn(msg, RuntimeWarning)
433 warnings.warn(msg, RuntimeWarning)
434
434
435 def _build_targets(self, targets):
435 def _build_targets(self, targets):
436 """Turn valid target IDs or 'all' into two lists:
436 """Turn valid target IDs or 'all' into two lists:
437 (int_ids, uuids).
437 (int_ids, uuids).
438 """
438 """
439 if not self._ids:
439 if not self._ids:
440 # flush notification socket if no engines yet, just in case
440 # flush notification socket if no engines yet, just in case
441 if not self.ids:
441 if not self.ids:
442 raise error.NoEnginesRegistered("Can't build targets without any engines")
442 raise error.NoEnginesRegistered("Can't build targets without any engines")
443
443
444 if targets is None:
444 if targets is None:
445 targets = self._ids
445 targets = self._ids
446 elif isinstance(targets, basestring):
446 elif isinstance(targets, basestring):
447 if targets.lower() == 'all':
447 if targets.lower() == 'all':
448 targets = self._ids
448 targets = self._ids
449 else:
449 else:
450 raise TypeError("%r not valid str target, must be 'all'"%(targets))
450 raise TypeError("%r not valid str target, must be 'all'"%(targets))
451 elif isinstance(targets, int):
451 elif isinstance(targets, int):
452 if targets < 0:
452 if targets < 0:
453 targets = self.ids[targets]
453 targets = self.ids[targets]
454 if targets not in self._ids:
454 if targets not in self._ids:
455 raise IndexError("No such engine: %i"%targets)
455 raise IndexError("No such engine: %i"%targets)
456 targets = [targets]
456 targets = [targets]
457
457
458 if isinstance(targets, slice):
458 if isinstance(targets, slice):
459 indices = range(len(self._ids))[targets]
459 indices = range(len(self._ids))[targets]
460 ids = self.ids
460 ids = self.ids
461 targets = [ ids[i] for i in indices ]
461 targets = [ ids[i] for i in indices ]
462
462
463 if not isinstance(targets, (tuple, list, xrange)):
463 if not isinstance(targets, (tuple, list, xrange)):
464 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
464 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
465
465
466 return [util.asbytes(self._engines[t]) for t in targets], list(targets)
466 return [util.asbytes(self._engines[t]) for t in targets], list(targets)
467
467
    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__.

        Parameters
        ----------
        sshserver : str
            server through which connections are tunneled when self._ssh is set
        ssh_kwargs : dict
            extra keyword arguments forwarded to tunnel.tunnel_connection
        timeout : float
            seconds to wait for the Hub to answer the connection_request
        """

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            # normalize the url relative to the Hub's location, then connect,
            # tunneling over ssh when the client was constructed with ssh info
            url = util.disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        # ask the Hub for our connection info, waiting at most `timeout` seconds
        self.session.send(self._query_socket, 'connection_request')
        # use Poller because zmq.select has wrong units in pyzmq 2.1.7
        poller = zmq.Poller()
        poller.register(self._query_socket, zmq.POLLIN)
        # poll expects milliseconds, timeout is seconds
        evts = poller.poll(timeout*1000)
        if not evts:
            raise error.TimeoutError("Hub connection request timed out")
        idents,msg = self.session.recv(self._query_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = Message(msg)
        content = msg.content
        # remember the raw registration reply for later inspection
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            # connect one socket per service the Hub advertised,
            # identifying ourselves by our session id
            ident = util.asbytes(self.session.session)
            if content.mux:
                self._mux_socket = self._context.socket(zmq.XREQ)
                self._mux_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._mux_socket, content.mux)
            if content.task:
                # content.task also carries the scheduler scheme in use
                self._task_scheme, task_addr = content.task
                self._task_socket = self._context.socket(zmq.XREQ)
                self._task_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._task_socket, task_addr)
            if content.notification:
                # SUB socket subscribed to everything: engine registration events
                self._notification_socket = self._context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            # if content.query:
            #     self._query_socket = self._context.socket(zmq.XREQ)
            #     self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
            #     connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self._context.socket(zmq.XREQ)
                self._control_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self._context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
                self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._iopub_socket, content.iopub)
            # seed our engine table from the Hub's current registry
            self._update_engines(dict(content.engines))
        else:
            self._connected = False
            raise Exception("Failed to connect!")
530
530
531 #--------------------------------------------------------------------------
531 #--------------------------------------------------------------------------
532 # handlers and callbacks for incoming messages
532 # handlers and callbacks for incoming messages
533 #--------------------------------------------------------------------------
533 #--------------------------------------------------------------------------
534
534
535 def _unwrap_exception(self, content):
535 def _unwrap_exception(self, content):
536 """unwrap exception, and remap engine_id to int."""
536 """unwrap exception, and remap engine_id to int."""
537 e = error.unwrap_exception(content)
537 e = error.unwrap_exception(content)
538 # print e.traceback
538 # print e.traceback
539 if e.engine_info:
539 if e.engine_info:
540 e_uuid = e.engine_info['engine_uuid']
540 e_uuid = e.engine_info['engine_uuid']
541 eid = self._engines[e_uuid]
541 eid = self._engines[e_uuid]
542 e.engine_info['engine_id'] = eid
542 e.engine_info['engine_id'] = eid
543 return e
543 return e
544
544
545 def _extract_metadata(self, header, parent, content):
545 def _extract_metadata(self, header, parent, content):
546 md = {'msg_id' : parent['msg_id'],
546 md = {'msg_id' : parent['msg_id'],
547 'received' : datetime.now(),
547 'received' : datetime.now(),
548 'engine_uuid' : header.get('engine', None),
548 'engine_uuid' : header.get('engine', None),
549 'follow' : parent.get('follow', []),
549 'follow' : parent.get('follow', []),
550 'after' : parent.get('after', []),
550 'after' : parent.get('after', []),
551 'status' : content['status'],
551 'status' : content['status'],
552 }
552 }
553
553
554 if md['engine_uuid'] is not None:
554 if md['engine_uuid'] is not None:
555 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
555 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
556
556
557 if 'date' in parent:
557 if 'date' in parent:
558 md['submitted'] = parent['date']
558 md['submitted'] = parent['date']
559 if 'started' in header:
559 if 'started' in header:
560 md['started'] = header['started']
560 md['started'] = header['started']
561 if 'date' in header:
561 if 'date' in header:
562 md['completed'] = header['date']
562 md['completed'] = header['date']
563 return md
563 return md
564
564
565 def _register_engine(self, msg):
565 def _register_engine(self, msg):
566 """Register a new engine, and update our connection info."""
566 """Register a new engine, and update our connection info."""
567 content = msg['content']
567 content = msg['content']
568 eid = content['id']
568 eid = content['id']
569 d = {eid : content['queue']}
569 d = {eid : content['queue']}
570 self._update_engines(d)
570 self._update_engines(d)
571
571
572 def _unregister_engine(self, msg):
572 def _unregister_engine(self, msg):
573 """Unregister an engine that has died."""
573 """Unregister an engine that has died."""
574 content = msg['content']
574 content = msg['content']
575 eid = int(content['id'])
575 eid = int(content['id'])
576 if eid in self._ids:
576 if eid in self._ids:
577 self._ids.remove(eid)
577 self._ids.remove(eid)
578 uuid = self._engines.pop(eid)
578 uuid = self._engines.pop(eid)
579
579
580 self._handle_stranded_msgs(eid, uuid)
580 self._handle_stranded_msgs(eid, uuid)
581
581
582 if self._task_socket and self._task_scheme == 'pure':
582 if self._task_socket and self._task_scheme == 'pure':
583 self._stop_scheduling_tasks()
583 self._stop_scheduling_tasks()
584
584
    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.

        Parameters
        ----------
        eid : int
            the id of the engine that went down
        uuid : str
            the engine's queue identity, used to key self._outstanding_dict
        """

        outstanding = self._outstanding_dict[uuid]

        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already have the result: the premature case noted above
                continue
            try:
                # raise, rather than construct, so wrap_exception captures a traceback
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake message:
            parent = {}
            header = {}
            parent['msg_id'] = msg_id
            header['engine'] = uuid
            header['date'] = datetime.now()
            msg = dict(parent_header=parent, header=header, content=content)
            # feed it through the normal reply path so bookkeeping stays consistent
            self._handle_apply_reply(msg)
611
611
612 def _handle_execute_reply(self, msg):
612 def _handle_execute_reply(self, msg):
613 """Save the reply to an execute_request into our results.
613 """Save the reply to an execute_request into our results.
614
614
615 execute messages are never actually used. apply is used instead.
615 execute messages are never actually used. apply is used instead.
616 """
616 """
617
617
618 parent = msg['parent_header']
618 parent = msg['parent_header']
619 msg_id = parent['msg_id']
619 msg_id = parent['msg_id']
620 if msg_id not in self.outstanding:
620 if msg_id not in self.outstanding:
621 if msg_id in self.history:
621 if msg_id in self.history:
622 print ("got stale result: %s"%msg_id)
622 print ("got stale result: %s"%msg_id)
623 else:
623 else:
624 print ("got unknown result: %s"%msg_id)
624 print ("got unknown result: %s"%msg_id)
625 else:
625 else:
626 self.outstanding.remove(msg_id)
626 self.outstanding.remove(msg_id)
627 self.results[msg_id] = self._unwrap_exception(msg['content'])
627 self.results[msg_id] = self._unwrap_exception(msg['content'])
628
628
    def _handle_apply_reply(self, msg):
        """Save the reply to an apply_request into our results.

        Also updates per-message metadata, clears the message from the
        global and per-engine outstanding sets, and stores the deserialized
        result (or wrapped failure) according to the reply's status.
        """
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                # a reply for a message we already resolved
                print ("got stale result: %s"%msg_id)
                print self.results[msg_id]
                print msg
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
        content = msg['content']
        header = msg['header']

        # construct metadata:
        md = self.metadata[msg_id]
        md.update(self._extract_metadata(header, parent, content))
        # is this redundant?
        self.metadata[msg_id] = md

        # also clear the message from the engine's own outstanding set
        e_outstanding = self._outstanding_dict[md['engine_uuid']]
        if msg_id in e_outstanding:
            e_outstanding.remove(msg_id)

        # construct result:
        if content['status'] == 'ok':
            self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
        elif content['status'] == 'aborted':
            self.results[msg_id] = error.TaskAborted(msg_id)
        elif content['status'] == 'resubmitted':
            # TODO: handle resubmission
            pass
        else:
            # remote failure: store the remapped exception as the result
            self.results[msg_id] = self._unwrap_exception(content)
665
665
666 def _flush_notifications(self):
666 def _flush_notifications(self):
667 """Flush notifications of engine registrations waiting
667 """Flush notifications of engine registrations waiting
668 in ZMQ queue."""
668 in ZMQ queue."""
669 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
669 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
670 while msg is not None:
670 while msg is not None:
671 if self.debug:
671 if self.debug:
672 pprint(msg)
672 pprint(msg)
673 msg_type = msg['msg_type']
673 msg_type = msg['msg_type']
674 handler = self._notification_handlers.get(msg_type, None)
674 handler = self._notification_handlers.get(msg_type, None)
675 if handler is None:
675 if handler is None:
676 raise Exception("Unhandled message type: %s"%msg.msg_type)
676 raise Exception("Unhandled message type: %s"%msg.msg_type)
677 else:
677 else:
678 handler(msg)
678 handler(msg)
679 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
679 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
680
680
681 def _flush_results(self, sock):
681 def _flush_results(self, sock):
682 """Flush task or queue results waiting in ZMQ queue."""
682 """Flush task or queue results waiting in ZMQ queue."""
683 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
683 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
684 while msg is not None:
684 while msg is not None:
685 if self.debug:
685 if self.debug:
686 pprint(msg)
686 pprint(msg)
687 msg_type = msg['msg_type']
687 msg_type = msg['msg_type']
688 handler = self._queue_handlers.get(msg_type, None)
688 handler = self._queue_handlers.get(msg_type, None)
689 if handler is None:
689 if handler is None:
690 raise Exception("Unhandled message type: %s"%msg.msg_type)
690 raise Exception("Unhandled message type: %s"%msg.msg_type)
691 else:
691 else:
692 handler(msg)
692 handler(msg)
693 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
693 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
694
694
695 def _flush_control(self, sock):
695 def _flush_control(self, sock):
696 """Flush replies from the control channel waiting
696 """Flush replies from the control channel waiting
697 in the ZMQ queue.
697 in the ZMQ queue.
698
698
699 Currently: ignore them."""
699 Currently: ignore them."""
700 if self._ignored_control_replies <= 0:
700 if self._ignored_control_replies <= 0:
701 return
701 return
702 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
702 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
703 while msg is not None:
703 while msg is not None:
704 self._ignored_control_replies -= 1
704 self._ignored_control_replies -= 1
705 if self.debug:
705 if self.debug:
706 pprint(msg)
706 pprint(msg)
707 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
707 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
708
708
    def _flush_ignored_control(self):
        """flush ignored control replies

        Blocking-receives (and discards) one control reply for each request
        whose reply we previously chose not to wait for.
        """
        while self._ignored_control_replies > 0:
            self.session.recv(self._control_socket)
            # decrement only after a successful recv
            self._ignored_control_replies -= 1
714
714
715 def _flush_ignored_hub_replies(self):
715 def _flush_ignored_hub_replies(self):
716 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
716 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
717 while msg is not None:
717 while msg is not None:
718 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
718 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
719
719
    def _flush_iopub(self, sock):
        """Flush replies from the iopub channel waiting
        in the ZMQ queue.

        Accumulates stream output, pyin/pyerr, and any other iopub content
        into self.metadata, keyed by the parent msg_id.
        """
        idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            parent = msg['parent_header']
            msg_id = parent['msg_id']
            content = msg['content']
            header = msg['header']
            msg_type = msg['msg_type']

            # init metadata:
            md = self.metadata[msg_id]

            if msg_type == 'stream':
                # append this chunk to the stream accumulated so far
                # (name is presumably 'stdout'/'stderr' — keyed directly into md)
                name = content['name']
                s = md[name] or ''
                md[name] = s + content['data']
            elif msg_type == 'pyerr':
                md.update({'pyerr' : self._unwrap_exception(content)})
            elif msg_type == 'pyin':
                md.update({'pyin' : content['code']})
            else:
                # anything else is stored under its msg_type
                md.update({msg_type : content.get('data', '')})

            # reduntant?
            self.metadata[msg_id] = md

            idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
752
752
753 #--------------------------------------------------------------------------
753 #--------------------------------------------------------------------------
754 # len, getitem
754 # len, getitem
755 #--------------------------------------------------------------------------
755 #--------------------------------------------------------------------------
756
756
    def __len__(self):
        """len(client) returns # of engines."""
        # use the `ids` property (not self._ids) so pending notifications
        # are flushed first and the count is current
        return len(self.ids)
760
760
761 def __getitem__(self, key):
761 def __getitem__(self, key):
762 """index access returns DirectView multiplexer objects
762 """index access returns DirectView multiplexer objects
763
763
764 Must be int, slice, or list/tuple/xrange of ints"""
764 Must be int, slice, or list/tuple/xrange of ints"""
765 if not isinstance(key, (int, slice, tuple, list, xrange)):
765 if not isinstance(key, (int, slice, tuple, list, xrange)):
766 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
766 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
767 else:
767 else:
768 return self.direct_view(key)
768 return self.direct_view(key)
769
769
770 #--------------------------------------------------------------------------
770 #--------------------------------------------------------------------------
771 # Begin public methods
771 # Begin public methods
772 #--------------------------------------------------------------------------
772 #--------------------------------------------------------------------------
773
773
    @property
    def ids(self):
        """Always up-to-date ids property."""
        # process any pending engine (un)registration notifications first
        self._flush_notifications()
        # always copy:
        return list(self._ids)
780
780
781 def close(self):
781 def close(self):
782 if self._closed:
782 if self._closed:
783 return
783 return
784 snames = filter(lambda n: n.endswith('socket'), dir(self))
784 snames = filter(lambda n: n.endswith('socket'), dir(self))
785 for socket in map(lambda name: getattr(self, name), snames):
785 for socket in map(lambda name: getattr(self, name), snames):
786 if isinstance(socket, zmq.Socket) and not socket.closed:
786 if isinstance(socket, zmq.Socket) and not socket.closed:
787 socket.close()
787 socket.close()
788 self._closed = True
788 self._closed = True
789
789
790 def spin(self):
790 def spin(self):
791 """Flush any registration notifications and execution results
791 """Flush any registration notifications and execution results
792 waiting in the ZMQ queue.
792 waiting in the ZMQ queue.
793 """
793 """
794 if self._notification_socket:
794 if self._notification_socket:
795 self._flush_notifications()
795 self._flush_notifications()
796 if self._mux_socket:
796 if self._mux_socket:
797 self._flush_results(self._mux_socket)
797 self._flush_results(self._mux_socket)
798 if self._task_socket:
798 if self._task_socket:
799 self._flush_results(self._task_socket)
799 self._flush_results(self._task_socket)
800 if self._control_socket:
800 if self._control_socket:
801 self._flush_control(self._control_socket)
801 self._flush_control(self._control_socket)
802 if self._iopub_socket:
802 if self._iopub_socket:
803 self._flush_iopub(self._iopub_socket)
803 self._flush_iopub(self._iopub_socket)
804 if self._query_socket:
804 if self._query_socket:
805 self._flush_ignored_hub_replies()
805 self._flush_ignored_hub_replies()
806
806
    def wait(self, jobs=None, timeout=-1):
        """waits on one or more `jobs`, for up to `timeout` seconds.

        Parameters
        ----------

        jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
            ints are indices to self.history
            strs are msg_ids
            default: wait on all outstanding messages
        timeout : float
            a time in seconds, after which to give up.
            default is -1, which means no timeout

        Returns
        -------

        True : when all msg_ids are done
        False : timeout reached, some msg_ids still outstanding
        """
        tic = time.time()
        if jobs is None:
            # wait on everything currently in flight
            theids = self.outstanding
        else:
            if isinstance(jobs, (int, basestring, AsyncResult)):
                jobs = [jobs]
            # normalize each job spec down to a set of msg_ids
            theids = set()
            for job in jobs:
                if isinstance(job, int):
                    # index access
                    job = self.history[job]
                elif isinstance(job, AsyncResult):
                    map(theids.add, job.msg_ids)
                    continue
                theids.add(job)
        if not theids.intersection(self.outstanding):
            # nothing requested is still pending
            return True
        self.spin()
        # poll with 1ms sleeps, flushing the queues on each pass
        while theids.intersection(self.outstanding):
            if timeout >= 0 and ( time.time()-tic ) > timeout:
                break
            time.sleep(1e-3)
            self.spin()
        return len(theids.intersection(self.outstanding)) == 0
851
851
852 #--------------------------------------------------------------------------
852 #--------------------------------------------------------------------------
853 # Control methods
853 # Control methods
854 #--------------------------------------------------------------------------
854 #--------------------------------------------------------------------------
855
855
856 @spin_first
856 @spin_first
857 def clear(self, targets=None, block=None):
857 def clear(self, targets=None, block=None):
858 """Clear the namespace in target(s)."""
858 """Clear the namespace in target(s)."""
859 block = self.block if block is None else block
859 block = self.block if block is None else block
860 targets = self._build_targets(targets)[0]
860 targets = self._build_targets(targets)[0]
861 for t in targets:
861 for t in targets:
862 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
862 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
863 error = False
863 error = False
864 if block:
864 if block:
865 self._flush_ignored_control()
865 self._flush_ignored_control()
866 for i in range(len(targets)):
866 for i in range(len(targets)):
867 idents,msg = self.session.recv(self._control_socket,0)
867 idents,msg = self.session.recv(self._control_socket,0)
868 if self.debug:
868 if self.debug:
869 pprint(msg)
869 pprint(msg)
870 if msg['content']['status'] != 'ok':
870 if msg['content']['status'] != 'ok':
871 error = self._unwrap_exception(msg['content'])
871 error = self._unwrap_exception(msg['content'])
872 else:
872 else:
873 self._ignored_control_replies += len(targets)
873 self._ignored_control_replies += len(targets)
874 if error:
874 if error:
875 raise error
875 raise error
876
876
877
877
878 @spin_first
878 @spin_first
879 def abort(self, jobs=None, targets=None, block=None):
879 def abort(self, jobs=None, targets=None, block=None):
880 """Abort specific jobs from the execution queues of target(s).
880 """Abort specific jobs from the execution queues of target(s).
881
881
882 This is a mechanism to prevent jobs that have already been submitted
882 This is a mechanism to prevent jobs that have already been submitted
883 from executing.
883 from executing.
884
884
885 Parameters
885 Parameters
886 ----------
886 ----------
887
887
888 jobs : msg_id, list of msg_ids, or AsyncResult
888 jobs : msg_id, list of msg_ids, or AsyncResult
889 The jobs to be aborted
889 The jobs to be aborted
890
890
891
891
892 """
892 """
893 block = self.block if block is None else block
893 block = self.block if block is None else block
894 targets = self._build_targets(targets)[0]
894 targets = self._build_targets(targets)[0]
895 msg_ids = []
895 msg_ids = []
896 if isinstance(jobs, (basestring,AsyncResult)):
896 if isinstance(jobs, (basestring,AsyncResult)):
897 jobs = [jobs]
897 jobs = [jobs]
898 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
898 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
899 if bad_ids:
899 if bad_ids:
900 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
900 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
901 for j in jobs:
901 for j in jobs:
902 if isinstance(j, AsyncResult):
902 if isinstance(j, AsyncResult):
903 msg_ids.extend(j.msg_ids)
903 msg_ids.extend(j.msg_ids)
904 else:
904 else:
905 msg_ids.append(j)
905 msg_ids.append(j)
906 content = dict(msg_ids=msg_ids)
906 content = dict(msg_ids=msg_ids)
907 for t in targets:
907 for t in targets:
908 self.session.send(self._control_socket, 'abort_request',
908 self.session.send(self._control_socket, 'abort_request',
909 content=content, ident=t)
909 content=content, ident=t)
910 error = False
910 error = False
911 if block:
911 if block:
912 self._flush_ignored_control()
912 self._flush_ignored_control()
913 for i in range(len(targets)):
913 for i in range(len(targets)):
914 idents,msg = self.session.recv(self._control_socket,0)
914 idents,msg = self.session.recv(self._control_socket,0)
915 if self.debug:
915 if self.debug:
916 pprint(msg)
916 pprint(msg)
917 if msg['content']['status'] != 'ok':
917 if msg['content']['status'] != 'ok':
918 error = self._unwrap_exception(msg['content'])
918 error = self._unwrap_exception(msg['content'])
919 else:
919 else:
920 self._ignored_control_replies += len(targets)
920 self._ignored_control_replies += len(targets)
921 if error:
921 if error:
922 raise error
922 raise error
923
923
924 @spin_first
924 @spin_first
925 def shutdown(self, targets=None, restart=False, hub=False, block=None):
    def shutdown(self, targets=None, restart=False, hub=False, block=None):
        """Terminates one or more engine processes, optionally including the hub.

        Parameters
        ----------

        targets : list of ints, or 'all' [default: None]
            The engines to shut down.  Requesting a hub shutdown implies 'all'.
        restart : bool [default: False]
            Passed through to the engines in the shutdown_request content.
        hub : bool [default: False]
            Whether to also send a shutdown_request to the Hub afterwards.
        block : bool [default: self.block]
            Whether to wait for the engines' replies before returning.

        Raises
        ------
        The unwrapped remote exception, if any engine (or the Hub) replies
        with a non-ok status.
        """
        block = self.block if block is None else block
        if hub:
            # the Hub cannot outlive its engines: shutting it down shuts down all
            targets = 'all'
        targets = self._build_targets(targets)[0]
        for t in targets:
            self.session.send(self._control_socket, 'shutdown_request',
                        content={'restart':restart},ident=t)
        error = False
        if block or hub:
            self._flush_ignored_control()
            # collect exactly one reply per targeted engine
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket, 0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    # remember the failure, but keep draining replies
                    error = self._unwrap_exception(msg['content'])
        else:
            # not waiting: mark the pending replies so they get discarded later
            self._ignored_control_replies += len(targets)

        if hub:
            # brief pause so engines can go down before the Hub is stopped
            time.sleep(0.25)
            self.session.send(self._query_socket, 'shutdown_request')
            idents,msg = self.session.recv(self._query_socket, 0)
            if self.debug:
                pprint(msg)
            if msg['content']['status'] != 'ok':
                error = self._unwrap_exception(msg['content'])

        if error:
            raise error
957
957
958 #--------------------------------------------------------------------------
958 #--------------------------------------------------------------------------
959 # Execution related methods
959 # Execution related methods
960 #--------------------------------------------------------------------------
960 #--------------------------------------------------------------------------
961
961
962 def _maybe_raise(self, result):
962 def _maybe_raise(self, result):
963 """wrapper for maybe raising an exception if apply failed."""
963 """wrapper for maybe raising an exception if apply failed."""
964 if isinstance(result, error.RemoteError):
964 if isinstance(result, error.RemoteError):
965 raise result
965 raise result
966
966
967 return result
967 return result
968
968
    def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
                            ident=None):
        """construct and send an apply message via a socket.

        This is the principal method with which all engine execution is performed by views.

        Parameters
        ----------

        socket : zmq socket
            The socket on which to send the apply_request.
        f : callable
            The function to be applied remotely.
        args : tuple or list [default: empty]
            Positional arguments for `f`.
        kwargs : dict [default: empty]
            Keyword arguments for `f`.
        subheader : dict [default: empty]
            Extra keys merged into the message header by the session.
        track : bool [default: False]
            Passed through to session.send (message-tracking flag).
        ident : str or list [default: None]
            zmq routing identity; a list means a routing path, whose last
            element is the final destination engine.

        Returns
        -------
        msg : dict
            The sent message, as returned by session.send; includes 'msg_id'.
        """

        assert not self._closed, "cannot use me anymore, I'm closed!"
        # defaults:
        args = args if args is not None else []
        kwargs = kwargs if kwargs is not None else {}
        subheader = subheader if subheader is not None else {}

        # validate arguments
        if not callable(f):
            raise TypeError("f must be callable, not %s"%type(f))
        if not isinstance(args, (tuple, list)):
            raise TypeError("args must be tuple or list, not %s"%type(args))
        if not isinstance(kwargs, dict):
            raise TypeError("kwargs must be dict, not %s"%type(kwargs))
        if not isinstance(subheader, dict):
            raise TypeError("subheader must be dict, not %s"%type(subheader))

        # serialize f/args/kwargs into zmq message buffers
        bufs = util.pack_apply_message(f,args,kwargs)

        msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
                            subheader=subheader, track=track)

        msg_id = msg['msg_id']
        self.outstanding.add(msg_id)
        if ident:
            # possibly routed to a specific engine
            if isinstance(ident, list):
                ident = ident[-1]
            if ident in self._engines.values():
                # save for later, in case of engine death
                self._outstanding_dict[ident].add(msg_id)
        self.history.append(msg_id)
        self.metadata[msg_id]['submitted'] = datetime.now()

        return msg
1010
1010
1011 #--------------------------------------------------------------------------
1011 #--------------------------------------------------------------------------
1012 # construct a View object
1012 # construct a View object
1013 #--------------------------------------------------------------------------
1013 #--------------------------------------------------------------------------
1014
1014
1015 def load_balanced_view(self, targets=None):
1015 def load_balanced_view(self, targets=None):
1016 """construct a DirectView object.
1016 """construct a DirectView object.
1017
1017
1018 If no arguments are specified, create a LoadBalancedView
1018 If no arguments are specified, create a LoadBalancedView
1019 using all engines.
1019 using all engines.
1020
1020
1021 Parameters
1021 Parameters
1022 ----------
1022 ----------
1023
1023
1024 targets: list,slice,int,etc. [default: use all engines]
1024 targets: list,slice,int,etc. [default: use all engines]
1025 The subset of engines across which to load-balance
1025 The subset of engines across which to load-balance
1026 """
1026 """
1027 if targets == 'all':
1028 targets = None
1027 if targets is not None:
1029 if targets is not None:
1028 targets = self._build_targets(targets)[1]
1030 targets = self._build_targets(targets)[1]
1029 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1031 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1030
1032
    def direct_view(self, targets='all'):
        """construct a DirectView object.

        If no targets are specified, create a DirectView using all engines.

        Passing 'all' (the default) leaves the targets unresolved, so the
        resulting view re-evaluates the engine list lazily at each execution,
        rather than fixing the engine list at view-creation time.

        Parameters
        ----------

        targets: list,slice,int,etc. [default: use all engines]
            The engines to use for the View
        """
        # remember whether a single engine was requested, so a scalar
        # target (not a one-element list) is handed to the view
        single = isinstance(targets, int)
        # allow 'all' to be lazily evaluated at each execution
        if targets != 'all':
            targets = self._build_targets(targets)[1]
        if single:
            targets = targets[0]
        return DirectView(client=self, socket=self._mux_socket, targets=targets)
1048
1052
1049 #--------------------------------------------------------------------------
1053 #--------------------------------------------------------------------------
1050 # Query methods
1054 # Query methods
1051 #--------------------------------------------------------------------------
1055 #--------------------------------------------------------------------------
1052
1056
    @spin_first
    def get_result(self, indices_or_msg_ids=None, block=None):
        """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.

        If the client already has the results, no request to the Hub will be made.

        This is a convenient way to construct AsyncResult objects, which are wrappers
        that include metadata about execution, and allow for awaiting results that
        were not submitted by this Client.

        It can also be a convenient way to retrieve the metadata associated with
        blocking execution, since it always retrieves

        Examples
        --------
        ::

            In [10]: r = client.apply()

        Parameters
        ----------

        indices_or_msg_ids : integer history index, str msg_id, or list of either
            The indices or msg_ids of indices to be retrieved

        block : bool
            Whether to wait for the result to be done

        Returns
        -------

        AsyncResult
            A single AsyncResult object will always be returned.

        AsyncHubResult
            A subclass of AsyncResult that retrieves results from the Hub

        """
        block = self.block if block is None else block
        if indices_or_msg_ids is None:
            # default to the most recently submitted task
            indices_or_msg_ids = -1

        if not isinstance(indices_or_msg_ids, (list,tuple)):
            indices_or_msg_ids = [indices_or_msg_ids]

        # normalize everything to a list of msg_id strings
        theids = []
        for id in indices_or_msg_ids:
            if isinstance(id, int):
                # integers index into this client's local history
                id = self.history[id]
            if not isinstance(id, basestring):
                raise TypeError("indices must be str or int, not %r"%id)
            theids.append(id)

        # split ids this client already knows about from ones only the Hub has
        local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
        remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)

        if remote_ids:
            # at least one result must come from the Hub
            ar = AsyncHubResult(self, msg_ids=theids)
        else:
            ar = AsyncResult(self, msg_ids=theids)

        if block:
            ar.wait()

        return ar
1118
1122
    @spin_first
    def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
        """Resubmit one or more tasks.

        in-flight tasks may not be resubmitted.

        Parameters
        ----------

        indices_or_msg_ids : integer history index, str msg_id, or list of either
            The indices or msg_ids of indices to be retrieved

        block : bool
            Whether to wait for the result to be done

        Returns
        -------

        AsyncHubResult
            A subclass of AsyncResult that retrieves results from the Hub

        """
        block = self.block if block is None else block
        if indices_or_msg_ids is None:
            # default to the most recently submitted task
            indices_or_msg_ids = -1

        if not isinstance(indices_or_msg_ids, (list,tuple)):
            indices_or_msg_ids = [indices_or_msg_ids]

        # normalize everything to a list of msg_id strings
        theids = []
        for id in indices_or_msg_ids:
            if isinstance(id, int):
                # integers index into this client's local history
                id = self.history[id]
            if not isinstance(id, basestring):
                raise TypeError("indices must be str or int, not %r"%id)
            theids.append(id)

        # drop all local state for these msg_ids before asking the Hub
        # to resubmit them
        for msg_id in theids:
            self.outstanding.discard(msg_id)
            if msg_id in self.history:
                self.history.remove(msg_id)
            self.results.pop(msg_id, None)
            self.metadata.pop(msg_id, None)
        content = dict(msg_ids = theids)

        self.session.send(self._query_socket, 'resubmit_request', content)

        # block until the Hub's reply arrives, then read it non-blocking
        zmq.select([self._query_socket], [], [])
        idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)

        ar = AsyncHubResult(self, msg_ids=theids)

        if block:
            ar.wait()

        return ar
1180
1184
1181 @spin_first
1185 @spin_first
1182 def result_status(self, msg_ids, status_only=True):
1186 def result_status(self, msg_ids, status_only=True):
1183 """Check on the status of the result(s) of the apply request with `msg_ids`.
1187 """Check on the status of the result(s) of the apply request with `msg_ids`.
1184
1188
1185 If status_only is False, then the actual results will be retrieved, else
1189 If status_only is False, then the actual results will be retrieved, else
1186 only the status of the results will be checked.
1190 only the status of the results will be checked.
1187
1191
1188 Parameters
1192 Parameters
1189 ----------
1193 ----------
1190
1194
1191 msg_ids : list of msg_ids
1195 msg_ids : list of msg_ids
1192 if int:
1196 if int:
1193 Passed as index to self.history for convenience.
1197 Passed as index to self.history for convenience.
1194 status_only : bool (default: True)
1198 status_only : bool (default: True)
1195 if False:
1199 if False:
1196 Retrieve the actual results of completed tasks.
1200 Retrieve the actual results of completed tasks.
1197
1201
1198 Returns
1202 Returns
1199 -------
1203 -------
1200
1204
1201 results : dict
1205 results : dict
1202 There will always be the keys 'pending' and 'completed', which will
1206 There will always be the keys 'pending' and 'completed', which will
1203 be lists of msg_ids that are incomplete or complete. If `status_only`
1207 be lists of msg_ids that are incomplete or complete. If `status_only`
1204 is False, then completed results will be keyed by their `msg_id`.
1208 is False, then completed results will be keyed by their `msg_id`.
1205 """
1209 """
1206 if not isinstance(msg_ids, (list,tuple)):
1210 if not isinstance(msg_ids, (list,tuple)):
1207 msg_ids = [msg_ids]
1211 msg_ids = [msg_ids]
1208
1212
1209 theids = []
1213 theids = []
1210 for msg_id in msg_ids:
1214 for msg_id in msg_ids:
1211 if isinstance(msg_id, int):
1215 if isinstance(msg_id, int):
1212 msg_id = self.history[msg_id]
1216 msg_id = self.history[msg_id]
1213 if not isinstance(msg_id, basestring):
1217 if not isinstance(msg_id, basestring):
1214 raise TypeError("msg_ids must be str, not %r"%msg_id)
1218 raise TypeError("msg_ids must be str, not %r"%msg_id)
1215 theids.append(msg_id)
1219 theids.append(msg_id)
1216
1220
1217 completed = []
1221 completed = []
1218 local_results = {}
1222 local_results = {}
1219
1223
1220 # comment this block out to temporarily disable local shortcut:
1224 # comment this block out to temporarily disable local shortcut:
1221 for msg_id in theids:
1225 for msg_id in theids:
1222 if msg_id in self.results:
1226 if msg_id in self.results:
1223 completed.append(msg_id)
1227 completed.append(msg_id)
1224 local_results[msg_id] = self.results[msg_id]
1228 local_results[msg_id] = self.results[msg_id]
1225 theids.remove(msg_id)
1229 theids.remove(msg_id)
1226
1230
1227 if theids: # some not locally cached
1231 if theids: # some not locally cached
1228 content = dict(msg_ids=theids, status_only=status_only)
1232 content = dict(msg_ids=theids, status_only=status_only)
1229 msg = self.session.send(self._query_socket, "result_request", content=content)
1233 msg = self.session.send(self._query_socket, "result_request", content=content)
1230 zmq.select([self._query_socket], [], [])
1234 zmq.select([self._query_socket], [], [])
1231 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1235 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1232 if self.debug:
1236 if self.debug:
1233 pprint(msg)
1237 pprint(msg)
1234 content = msg['content']
1238 content = msg['content']
1235 if content['status'] != 'ok':
1239 if content['status'] != 'ok':
1236 raise self._unwrap_exception(content)
1240 raise self._unwrap_exception(content)
1237 buffers = msg['buffers']
1241 buffers = msg['buffers']
1238 else:
1242 else:
1239 content = dict(completed=[],pending=[])
1243 content = dict(completed=[],pending=[])
1240
1244
1241 content['completed'].extend(completed)
1245 content['completed'].extend(completed)
1242
1246
1243 if status_only:
1247 if status_only:
1244 return content
1248 return content
1245
1249
1246 failures = []
1250 failures = []
1247 # load cached results into result:
1251 # load cached results into result:
1248 content.update(local_results)
1252 content.update(local_results)
1249
1253
1250 # update cache with results:
1254 # update cache with results:
1251 for msg_id in sorted(theids):
1255 for msg_id in sorted(theids):
1252 if msg_id in content['completed']:
1256 if msg_id in content['completed']:
1253 rec = content[msg_id]
1257 rec = content[msg_id]
1254 parent = rec['header']
1258 parent = rec['header']
1255 header = rec['result_header']
1259 header = rec['result_header']
1256 rcontent = rec['result_content']
1260 rcontent = rec['result_content']
1257 iodict = rec['io']
1261 iodict = rec['io']
1258 if isinstance(rcontent, str):
1262 if isinstance(rcontent, str):
1259 rcontent = self.session.unpack(rcontent)
1263 rcontent = self.session.unpack(rcontent)
1260
1264
1261 md = self.metadata[msg_id]
1265 md = self.metadata[msg_id]
1262 md.update(self._extract_metadata(header, parent, rcontent))
1266 md.update(self._extract_metadata(header, parent, rcontent))
1263 md.update(iodict)
1267 md.update(iodict)
1264
1268
1265 if rcontent['status'] == 'ok':
1269 if rcontent['status'] == 'ok':
1266 res,buffers = util.unserialize_object(buffers)
1270 res,buffers = util.unserialize_object(buffers)
1267 else:
1271 else:
1268 print rcontent
1272 print rcontent
1269 res = self._unwrap_exception(rcontent)
1273 res = self._unwrap_exception(rcontent)
1270 failures.append(res)
1274 failures.append(res)
1271
1275
1272 self.results[msg_id] = res
1276 self.results[msg_id] = res
1273 content[msg_id] = res
1277 content[msg_id] = res
1274
1278
1275 if len(theids) == 1 and failures:
1279 if len(theids) == 1 and failures:
1276 raise failures[0]
1280 raise failures[0]
1277
1281
1278 error.collect_exceptions(failures, "result_status")
1282 error.collect_exceptions(failures, "result_status")
1279 return content
1283 return content
1280
1284
    @spin_first
    def queue_status(self, targets='all', verbose=False):
        """Fetch the status of engine queues.

        Parameters
        ----------

        targets : int/str/list of ints/strs
            the engines whose states are to be queried.
            default : all
        verbose : bool
            Whether to return lengths only, or lists of ids for each element

        Returns
        -------

        The per-engine queue info dict, keyed by engine id — or, when
        `targets` is a single int, just that engine's entry.
        """
        engine_ids = self._build_targets(targets)[1]
        content = dict(targets=engine_ids, verbose=verbose)
        self.session.send(self._query_socket, "queue_request", content=content)
        idents,msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        status = content.pop('status')
        if status != 'ok':
            raise self._unwrap_exception(content)
        # rekey the reply dict (presumably restoring int engine-id keys after
        # the wire round-trip — see util.rekey)
        content = rekey(content)
        if isinstance(targets, int):
            # single engine requested: unwrap its entry
            return content[targets]
        else:
            return content
1309
1313
1310 @spin_first
1314 @spin_first
1311 def purge_results(self, jobs=[], targets=[]):
1315 def purge_results(self, jobs=[], targets=[]):
1312 """Tell the Hub to forget results.
1316 """Tell the Hub to forget results.
1313
1317
1314 Individual results can be purged by msg_id, or the entire
1318 Individual results can be purged by msg_id, or the entire
1315 history of specific targets can be purged.
1319 history of specific targets can be purged.
1316
1320
1317 Use `purge_results('all')` to scrub everything from the Hub's db.
1321 Use `purge_results('all')` to scrub everything from the Hub's db.
1318
1322
1319 Parameters
1323 Parameters
1320 ----------
1324 ----------
1321
1325
1322 jobs : str or list of str or AsyncResult objects
1326 jobs : str or list of str or AsyncResult objects
1323 the msg_ids whose results should be forgotten.
1327 the msg_ids whose results should be forgotten.
1324 targets : int/str/list of ints/strs
1328 targets : int/str/list of ints/strs
1325 The targets, by int_id, whose entire history is to be purged.
1329 The targets, by int_id, whose entire history is to be purged.
1326
1330
1327 default : None
1331 default : None
1328 """
1332 """
1329 if not targets and not jobs:
1333 if not targets and not jobs:
1330 raise ValueError("Must specify at least one of `targets` and `jobs`")
1334 raise ValueError("Must specify at least one of `targets` and `jobs`")
1331 if targets:
1335 if targets:
1332 targets = self._build_targets(targets)[1]
1336 targets = self._build_targets(targets)[1]
1333
1337
1334 # construct msg_ids from jobs
1338 # construct msg_ids from jobs
1335 if jobs == 'all':
1339 if jobs == 'all':
1336 msg_ids = jobs
1340 msg_ids = jobs
1337 else:
1341 else:
1338 msg_ids = []
1342 msg_ids = []
1339 if isinstance(jobs, (basestring,AsyncResult)):
1343 if isinstance(jobs, (basestring,AsyncResult)):
1340 jobs = [jobs]
1344 jobs = [jobs]
1341 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1345 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1342 if bad_ids:
1346 if bad_ids:
1343 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1347 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1344 for j in jobs:
1348 for j in jobs:
1345 if isinstance(j, AsyncResult):
1349 if isinstance(j, AsyncResult):
1346 msg_ids.extend(j.msg_ids)
1350 msg_ids.extend(j.msg_ids)
1347 else:
1351 else:
1348 msg_ids.append(j)
1352 msg_ids.append(j)
1349
1353
1350 content = dict(engine_ids=targets, msg_ids=msg_ids)
1354 content = dict(engine_ids=targets, msg_ids=msg_ids)
1351 self.session.send(self._query_socket, "purge_request", content=content)
1355 self.session.send(self._query_socket, "purge_request", content=content)
1352 idents, msg = self.session.recv(self._query_socket, 0)
1356 idents, msg = self.session.recv(self._query_socket, 0)
1353 if self.debug:
1357 if self.debug:
1354 pprint(msg)
1358 pprint(msg)
1355 content = msg['content']
1359 content = msg['content']
1356 if content['status'] != 'ok':
1360 if content['status'] != 'ok':
1357 raise self._unwrap_exception(content)
1361 raise self._unwrap_exception(content)
1358
1362
    @spin_first
    def hub_history(self):
        """Get the Hub's history

        Just like the Client, the Hub has a history, which is a list of msg_ids.
        This will contain the history of all clients, and, depending on configuration,
        may contain history across multiple cluster sessions.

        Any msg_id returned here is a valid argument to `get_result`.

        Returns
        -------

        msg_ids : list of strs
            list of all msg_ids, ordered by task submission time.
        """

        self.session.send(self._query_socket, "history_request", content={})
        # blocking receive of the Hub's reply
        idents, msg = self.session.recv(self._query_socket, 0)

        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)
        else:
            return content['history']
1386
1390
    @spin_first
    def db_query(self, query, keys=None):
        """Query the Hub's TaskRecord database

        This will return a list of task record dicts that match `query`

        Parameters
        ----------

        query : mongodb query dict
            The search dict. See mongodb query docs for details.
        keys : list of strs [optional]
            The subset of keys to be returned. The default is to fetch everything but buffers.
            'msg_id' will *always* be included.

        Returns
        -------

        records : list of dicts
            The matching task records, with any request/result buffers
            re-attached under 'buffers' / 'result_buffers'.
        """
        if isinstance(keys, basestring):
            # allow a single key name as shorthand for a one-element list
            keys = [keys]
        content = dict(query=query, keys=keys)
        self.session.send(self._query_socket, "db_request", content=content)
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)

        records = content['records']

        # buffers for all records travel as one flat list; the per-record
        # lengths tell us how to slice them back out
        buffer_lens = content['buffer_lens']
        result_buffer_lens = content['result_buffer_lens']
        buffers = msg['buffers']
        has_bufs = buffer_lens is not None
        has_rbufs = result_buffer_lens is not None
        for i,rec in enumerate(records):
            # relink buffers
            if has_bufs:
                blen = buffer_lens[i]
                # take this record's buffers off the front of the flat list
                rec['buffers'], buffers = buffers[:blen],buffers[blen:]
            if has_rbufs:
                blen = result_buffer_lens[i]
                rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]

        return records
1430
1434
# Explicit public API: star-imports of this module expose only the Client class.
__all__ = [ 'Client' ]
@@ -1,270 +1,279 b''
1 """Tests for parallel client.py
1 """Tests for parallel client.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import time
21 import time
22 from datetime import datetime
22 from datetime import datetime
23 from tempfile import mktemp
23 from tempfile import mktemp
24
24
25 import zmq
25 import zmq
26
26
27 from IPython.parallel.client import client as clientmod
27 from IPython.parallel.client import client as clientmod
28 from IPython.parallel import error
28 from IPython.parallel import error
29 from IPython.parallel import AsyncResult, AsyncHubResult
29 from IPython.parallel import AsyncResult, AsyncHubResult
30 from IPython.parallel import LoadBalancedView, DirectView
30 from IPython.parallel import LoadBalancedView, DirectView
31
31
32 from clienttest import ClusterTestCase, segfault, wait, add_engines
32 from clienttest import ClusterTestCase, segfault, wait, add_engines
33
33
34 def setup():
34 def setup():
35 add_engines(4)
35 add_engines(4)
36
36
37 class TestClient(ClusterTestCase):
37 class TestClient(ClusterTestCase):
38
38
39 def test_ids(self):
39 def test_ids(self):
40 n = len(self.client.ids)
40 n = len(self.client.ids)
41 self.add_engines(3)
41 self.add_engines(3)
42 self.assertEquals(len(self.client.ids), n+3)
42 self.assertEquals(len(self.client.ids), n+3)
43
43
44 def test_view_indexing(self):
44 def test_view_indexing(self):
45 """test index access for views"""
45 """test index access for views"""
46 self.add_engines(2)
46 self.add_engines(2)
47 targets = self.client._build_targets('all')[-1]
47 targets = self.client._build_targets('all')[-1]
48 v = self.client[:]
48 v = self.client[:]
49 self.assertEquals(v.targets, targets)
49 self.assertEquals(v.targets, targets)
50 t = self.client.ids[2]
50 t = self.client.ids[2]
51 v = self.client[t]
51 v = self.client[t]
52 self.assert_(isinstance(v, DirectView))
52 self.assert_(isinstance(v, DirectView))
53 self.assertEquals(v.targets, t)
53 self.assertEquals(v.targets, t)
54 t = self.client.ids[2:4]
54 t = self.client.ids[2:4]
55 v = self.client[t]
55 v = self.client[t]
56 self.assert_(isinstance(v, DirectView))
56 self.assert_(isinstance(v, DirectView))
57 self.assertEquals(v.targets, t)
57 self.assertEquals(v.targets, t)
58 v = self.client[::2]
58 v = self.client[::2]
59 self.assert_(isinstance(v, DirectView))
59 self.assert_(isinstance(v, DirectView))
60 self.assertEquals(v.targets, targets[::2])
60 self.assertEquals(v.targets, targets[::2])
61 v = self.client[1::3]
61 v = self.client[1::3]
62 self.assert_(isinstance(v, DirectView))
62 self.assert_(isinstance(v, DirectView))
63 self.assertEquals(v.targets, targets[1::3])
63 self.assertEquals(v.targets, targets[1::3])
64 v = self.client[:-3]
64 v = self.client[:-3]
65 self.assert_(isinstance(v, DirectView))
65 self.assert_(isinstance(v, DirectView))
66 self.assertEquals(v.targets, targets[:-3])
66 self.assertEquals(v.targets, targets[:-3])
67 v = self.client[-1]
67 v = self.client[-1]
68 self.assert_(isinstance(v, DirectView))
68 self.assert_(isinstance(v, DirectView))
69 self.assertEquals(v.targets, targets[-1])
69 self.assertEquals(v.targets, targets[-1])
70 self.assertRaises(TypeError, lambda : self.client[None])
70 self.assertRaises(TypeError, lambda : self.client[None])
71
71
72 def test_lbview_targets(self):
72 def test_lbview_targets(self):
73 """test load_balanced_view targets"""
73 """test load_balanced_view targets"""
74 v = self.client.load_balanced_view()
74 v = self.client.load_balanced_view()
75 self.assertEquals(v.targets, None)
75 self.assertEquals(v.targets, None)
76 v = self.client.load_balanced_view(-1)
76 v = self.client.load_balanced_view(-1)
77 self.assertEquals(v.targets, [self.client.ids[-1]])
77 self.assertEquals(v.targets, [self.client.ids[-1]])
78 v = self.client.load_balanced_view('all')
78 v = self.client.load_balanced_view('all')
79 self.assertEquals(v.targets, self.client.ids)
79 self.assertEquals(v.targets, None)
80
81 def test_dview_targets(self):
82 """test load_balanced_view targets"""
83 v = self.client.direct_view()
84 self.assertEquals(v.targets, 'all')
85 v = self.client.direct_view('all')
86 self.assertEquals(v.targets, 'all')
87 v = self.client.direct_view(-1)
88 self.assertEquals(v.targets, self.client.ids[-1])
80
89
81 def test_targets(self):
90 def test_targets(self):
82 """test various valid targets arguments"""
91 """test various valid targets arguments"""
83 build = self.client._build_targets
92 build = self.client._build_targets
84 ids = self.client.ids
93 ids = self.client.ids
85 idents,targets = build(None)
94 idents,targets = build(None)
86 self.assertEquals(ids, targets)
95 self.assertEquals(ids, targets)
87
96
88 def test_clear(self):
97 def test_clear(self):
89 """test clear behavior"""
98 """test clear behavior"""
90 # self.add_engines(2)
99 # self.add_engines(2)
91 v = self.client[:]
100 v = self.client[:]
92 v.block=True
101 v.block=True
93 v.push(dict(a=5))
102 v.push(dict(a=5))
94 v.pull('a')
103 v.pull('a')
95 id0 = self.client.ids[-1]
104 id0 = self.client.ids[-1]
96 self.client.clear(targets=id0, block=True)
105 self.client.clear(targets=id0, block=True)
97 a = self.client[:-1].get('a')
106 a = self.client[:-1].get('a')
98 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
107 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
99 self.client.clear(block=True)
108 self.client.clear(block=True)
100 for i in self.client.ids:
109 for i in self.client.ids:
101 # print i
110 # print i
102 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
111 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
103
112
104 def test_get_result(self):
113 def test_get_result(self):
105 """test getting results from the Hub."""
114 """test getting results from the Hub."""
106 c = clientmod.Client(profile='iptest')
115 c = clientmod.Client(profile='iptest')
107 # self.add_engines(1)
116 # self.add_engines(1)
108 t = c.ids[-1]
117 t = c.ids[-1]
109 ar = c[t].apply_async(wait, 1)
118 ar = c[t].apply_async(wait, 1)
110 # give the monitor time to notice the message
119 # give the monitor time to notice the message
111 time.sleep(.25)
120 time.sleep(.25)
112 ahr = self.client.get_result(ar.msg_ids)
121 ahr = self.client.get_result(ar.msg_ids)
113 self.assertTrue(isinstance(ahr, AsyncHubResult))
122 self.assertTrue(isinstance(ahr, AsyncHubResult))
114 self.assertEquals(ahr.get(), ar.get())
123 self.assertEquals(ahr.get(), ar.get())
115 ar2 = self.client.get_result(ar.msg_ids)
124 ar2 = self.client.get_result(ar.msg_ids)
116 self.assertFalse(isinstance(ar2, AsyncHubResult))
125 self.assertFalse(isinstance(ar2, AsyncHubResult))
117 c.close()
126 c.close()
118
127
119 def test_ids_list(self):
128 def test_ids_list(self):
120 """test client.ids"""
129 """test client.ids"""
121 # self.add_engines(2)
130 # self.add_engines(2)
122 ids = self.client.ids
131 ids = self.client.ids
123 self.assertEquals(ids, self.client._ids)
132 self.assertEquals(ids, self.client._ids)
124 self.assertFalse(ids is self.client._ids)
133 self.assertFalse(ids is self.client._ids)
125 ids.remove(ids[-1])
134 ids.remove(ids[-1])
126 self.assertNotEquals(ids, self.client._ids)
135 self.assertNotEquals(ids, self.client._ids)
127
136
128 def test_queue_status(self):
137 def test_queue_status(self):
129 # self.addEngine(4)
138 # self.addEngine(4)
130 ids = self.client.ids
139 ids = self.client.ids
131 id0 = ids[0]
140 id0 = ids[0]
132 qs = self.client.queue_status(targets=id0)
141 qs = self.client.queue_status(targets=id0)
133 self.assertTrue(isinstance(qs, dict))
142 self.assertTrue(isinstance(qs, dict))
134 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
143 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
135 allqs = self.client.queue_status()
144 allqs = self.client.queue_status()
136 self.assertTrue(isinstance(allqs, dict))
145 self.assertTrue(isinstance(allqs, dict))
137 intkeys = list(allqs.keys())
146 intkeys = list(allqs.keys())
138 intkeys.remove('unassigned')
147 intkeys.remove('unassigned')
139 self.assertEquals(sorted(intkeys), sorted(self.client.ids))
148 self.assertEquals(sorted(intkeys), sorted(self.client.ids))
140 unassigned = allqs.pop('unassigned')
149 unassigned = allqs.pop('unassigned')
141 for eid,qs in allqs.items():
150 for eid,qs in allqs.items():
142 self.assertTrue(isinstance(qs, dict))
151 self.assertTrue(isinstance(qs, dict))
143 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
152 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
144
153
145 def test_shutdown(self):
154 def test_shutdown(self):
146 # self.addEngine(4)
155 # self.addEngine(4)
147 ids = self.client.ids
156 ids = self.client.ids
148 id0 = ids[0]
157 id0 = ids[0]
149 self.client.shutdown(id0, block=True)
158 self.client.shutdown(id0, block=True)
150 while id0 in self.client.ids:
159 while id0 in self.client.ids:
151 time.sleep(0.1)
160 time.sleep(0.1)
152 self.client.spin()
161 self.client.spin()
153
162
154 self.assertRaises(IndexError, lambda : self.client[id0])
163 self.assertRaises(IndexError, lambda : self.client[id0])
155
164
156 def test_result_status(self):
165 def test_result_status(self):
157 pass
166 pass
158 # to be written
167 # to be written
159
168
160 def test_db_query_dt(self):
169 def test_db_query_dt(self):
161 """test db query by date"""
170 """test db query by date"""
162 hist = self.client.hub_history()
171 hist = self.client.hub_history()
163 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
172 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
164 tic = middle['submitted']
173 tic = middle['submitted']
165 before = self.client.db_query({'submitted' : {'$lt' : tic}})
174 before = self.client.db_query({'submitted' : {'$lt' : tic}})
166 after = self.client.db_query({'submitted' : {'$gte' : tic}})
175 after = self.client.db_query({'submitted' : {'$gte' : tic}})
167 self.assertEquals(len(before)+len(after),len(hist))
176 self.assertEquals(len(before)+len(after),len(hist))
168 for b in before:
177 for b in before:
169 self.assertTrue(b['submitted'] < tic)
178 self.assertTrue(b['submitted'] < tic)
170 for a in after:
179 for a in after:
171 self.assertTrue(a['submitted'] >= tic)
180 self.assertTrue(a['submitted'] >= tic)
172 same = self.client.db_query({'submitted' : tic})
181 same = self.client.db_query({'submitted' : tic})
173 for s in same:
182 for s in same:
174 self.assertTrue(s['submitted'] == tic)
183 self.assertTrue(s['submitted'] == tic)
175
184
176 def test_db_query_keys(self):
185 def test_db_query_keys(self):
177 """test extracting subset of record keys"""
186 """test extracting subset of record keys"""
178 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
187 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
179 for rec in found:
188 for rec in found:
180 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
189 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
181
190
182 def test_db_query_msg_id(self):
191 def test_db_query_msg_id(self):
183 """ensure msg_id is always in db queries"""
192 """ensure msg_id is always in db queries"""
184 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
193 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
185 for rec in found:
194 for rec in found:
186 self.assertTrue('msg_id' in rec.keys())
195 self.assertTrue('msg_id' in rec.keys())
187 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
196 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
188 for rec in found:
197 for rec in found:
189 self.assertTrue('msg_id' in rec.keys())
198 self.assertTrue('msg_id' in rec.keys())
190 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
199 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
191 for rec in found:
200 for rec in found:
192 self.assertTrue('msg_id' in rec.keys())
201 self.assertTrue('msg_id' in rec.keys())
193
202
194 def test_db_query_in(self):
203 def test_db_query_in(self):
195 """test db query with '$in','$nin' operators"""
204 """test db query with '$in','$nin' operators"""
196 hist = self.client.hub_history()
205 hist = self.client.hub_history()
197 even = hist[::2]
206 even = hist[::2]
198 odd = hist[1::2]
207 odd = hist[1::2]
199 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
208 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
200 found = [ r['msg_id'] for r in recs ]
209 found = [ r['msg_id'] for r in recs ]
201 self.assertEquals(set(even), set(found))
210 self.assertEquals(set(even), set(found))
202 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
211 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
203 found = [ r['msg_id'] for r in recs ]
212 found = [ r['msg_id'] for r in recs ]
204 self.assertEquals(set(odd), set(found))
213 self.assertEquals(set(odd), set(found))
205
214
206 def test_hub_history(self):
215 def test_hub_history(self):
207 hist = self.client.hub_history()
216 hist = self.client.hub_history()
208 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
217 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
209 recdict = {}
218 recdict = {}
210 for rec in recs:
219 for rec in recs:
211 recdict[rec['msg_id']] = rec
220 recdict[rec['msg_id']] = rec
212
221
213 latest = datetime(1984,1,1)
222 latest = datetime(1984,1,1)
214 for msg_id in hist:
223 for msg_id in hist:
215 rec = recdict[msg_id]
224 rec = recdict[msg_id]
216 newt = rec['submitted']
225 newt = rec['submitted']
217 self.assertTrue(newt >= latest)
226 self.assertTrue(newt >= latest)
218 latest = newt
227 latest = newt
219 ar = self.client[-1].apply_async(lambda : 1)
228 ar = self.client[-1].apply_async(lambda : 1)
220 ar.get()
229 ar.get()
221 time.sleep(0.25)
230 time.sleep(0.25)
222 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
231 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
223
232
224 def test_resubmit(self):
233 def test_resubmit(self):
225 def f():
234 def f():
226 import random
235 import random
227 return random.random()
236 return random.random()
228 v = self.client.load_balanced_view()
237 v = self.client.load_balanced_view()
229 ar = v.apply_async(f)
238 ar = v.apply_async(f)
230 r1 = ar.get(1)
239 r1 = ar.get(1)
231 ahr = self.client.resubmit(ar.msg_ids)
240 ahr = self.client.resubmit(ar.msg_ids)
232 r2 = ahr.get(1)
241 r2 = ahr.get(1)
233 self.assertFalse(r1 == r2)
242 self.assertFalse(r1 == r2)
234
243
235 def test_resubmit_inflight(self):
244 def test_resubmit_inflight(self):
236 """ensure ValueError on resubmit of inflight task"""
245 """ensure ValueError on resubmit of inflight task"""
237 v = self.client.load_balanced_view()
246 v = self.client.load_balanced_view()
238 ar = v.apply_async(time.sleep,1)
247 ar = v.apply_async(time.sleep,1)
239 # give the message a chance to arrive
248 # give the message a chance to arrive
240 time.sleep(0.2)
249 time.sleep(0.2)
241 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
250 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
242 ar.get(2)
251 ar.get(2)
243
252
244 def test_resubmit_badkey(self):
253 def test_resubmit_badkey(self):
245 """ensure KeyError on resubmit of nonexistant task"""
254 """ensure KeyError on resubmit of nonexistant task"""
246 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
255 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
247
256
248 def test_purge_results(self):
257 def test_purge_results(self):
249 # ensure there are some tasks
258 # ensure there are some tasks
250 for i in range(5):
259 for i in range(5):
251 self.client[:].apply_sync(lambda : 1)
260 self.client[:].apply_sync(lambda : 1)
252 # Wait for the Hub to realise the result is done:
261 # Wait for the Hub to realise the result is done:
253 # This prevents a race condition, where we
262 # This prevents a race condition, where we
254 # might purge a result the Hub still thinks is pending.
263 # might purge a result the Hub still thinks is pending.
255 time.sleep(0.1)
264 time.sleep(0.1)
256 rc2 = clientmod.Client(profile='iptest')
265 rc2 = clientmod.Client(profile='iptest')
257 hist = self.client.hub_history()
266 hist = self.client.hub_history()
258 ahr = rc2.get_result([hist[-1]])
267 ahr = rc2.get_result([hist[-1]])
259 ahr.wait(10)
268 ahr.wait(10)
260 self.client.purge_results(hist[-1])
269 self.client.purge_results(hist[-1])
261 newhist = self.client.hub_history()
270 newhist = self.client.hub_history()
262 self.assertEquals(len(newhist)+1,len(hist))
271 self.assertEquals(len(newhist)+1,len(hist))
263 rc2.spin()
272 rc2.spin()
264 rc2.close()
273 rc2.close()
265
274
266 def test_purge_all_results(self):
275 def test_purge_all_results(self):
267 self.client.purge_results('all')
276 self.client.purge_results('all')
268 hist = self.client.hub_history()
277 hist = self.client.hub_history()
269 self.assertEquals(len(hist), 0)
278 self.assertEquals(len(hist), 0)
270
279
General Comments 0
You need to be logged in to leave comments. Login now