##// END OF EJS Templates
add Client.spin_thread()...
MinRK -
Show More
@@ -1,1454 +1,1505 b''
1 """A semi-synchronous Client for the ZMQ cluster
1 """A semi-synchronous Client for the ZMQ cluster
2
2
3 Authors:
3 Authors:
4
4
5 * MinRK
5 * MinRK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import os
18 import os
19 import json
19 import json
20 import sys
20 import sys
21 from threading import Thread, Event
21 import time
22 import time
22 import warnings
23 import warnings
23 from datetime import datetime
24 from datetime import datetime
24 from getpass import getpass
25 from getpass import getpass
25 from pprint import pprint
26 from pprint import pprint
26
27
27 pjoin = os.path.join
28 pjoin = os.path.join
28
29
29 import zmq
30 import zmq
30 # from zmq.eventloop import ioloop, zmqstream
31 # from zmq.eventloop import ioloop, zmqstream
31
32
32 from IPython.config.configurable import MultipleInstanceError
33 from IPython.config.configurable import MultipleInstanceError
33 from IPython.core.application import BaseIPythonApplication
34 from IPython.core.application import BaseIPythonApplication
34
35
35 from IPython.utils.jsonutil import rekey
36 from IPython.utils.jsonutil import rekey
36 from IPython.utils.localinterfaces import LOCAL_IPS
37 from IPython.utils.localinterfaces import LOCAL_IPS
37 from IPython.utils.path import get_ipython_dir
38 from IPython.utils.path import get_ipython_dir
38 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
39 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
39 Dict, List, Bool, Set)
40 Dict, List, Bool, Set, Any)
40 from IPython.external.decorator import decorator
41 from IPython.external.decorator import decorator
41 from IPython.external.ssh import tunnel
42 from IPython.external.ssh import tunnel
42
43
43 from IPython.parallel import Reference
44 from IPython.parallel import Reference
44 from IPython.parallel import error
45 from IPython.parallel import error
45 from IPython.parallel import util
46 from IPython.parallel import util
46
47
47 from IPython.zmq.session import Session, Message
48 from IPython.zmq.session import Session, Message
48
49
49 from .asyncresult import AsyncResult, AsyncHubResult
50 from .asyncresult import AsyncResult, AsyncHubResult
50 from IPython.core.profiledir import ProfileDir, ProfileDirError
51 from IPython.core.profiledir import ProfileDir, ProfileDirError
51 from .view import DirectView, LoadBalancedView
52 from .view import DirectView, LoadBalancedView
52
53
53 if sys.version_info[0] >= 3:
54 if sys.version_info[0] >= 3:
54 # xrange is used in a couple 'isinstance' tests in py2
55 # xrange is used in a couple 'isinstance' tests in py2
55 # should be just 'range' in 3k
56 # should be just 'range' in 3k
56 xrange = range
57 xrange = range
57
58
58 #--------------------------------------------------------------------------
59 #--------------------------------------------------------------------------
59 # Decorators for Client methods
60 # Decorators for Client methods
60 #--------------------------------------------------------------------------
61 #--------------------------------------------------------------------------
61
62
@decorator
def spin_first(f, self, *args, **kwargs):
    """Decorator for Client methods: flush incoming state via spin() first.

    Guarantees registration/result state is current before the wrapped
    method executes, then delegates to it unchanged.
    """
    self.spin()
    result = f(self, *args, **kwargs)
    return result
67
68
68
69
69 #--------------------------------------------------------------------------
70 #--------------------------------------------------------------------------
70 # Classes
71 # Classes
71 #--------------------------------------------------------------------------
72 #--------------------------------------------------------------------------
72
73
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # the full, fixed set of allowed keys, with their default values;
        # __setitem__ below rejects anything outside this set
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
            }
        self.update(md)
        # caller-supplied values override the defaults (and are key-checked
        # by __setitem__ via dict.update? no -- dict.update bypasses it, so
        # unknown keys in args/kwargs are accepted here by design)
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # `key in self` is an O(1) dict membership test that works on both
        # Python 2 and 3, unlike the linear scan over iterkeys() it replaces
        # (dict.iterkeys() does not exist on Python 3).
        if key in self:
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        if key in self:
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self:
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
123
124
124
125
125 class Client(HasTraits):
126 class Client(HasTraits):
126 """A semi-synchronous client to the IPython ZMQ cluster
127 """A semi-synchronous client to the IPython ZMQ cluster
127
128
128 Parameters
129 Parameters
129 ----------
130 ----------
130
131
131 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
132 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
132 Connection information for the Hub's registration. If a json connector
133 Connection information for the Hub's registration. If a json connector
133 file is given, then likely no further configuration is necessary.
134 file is given, then likely no further configuration is necessary.
134 [Default: use profile]
135 [Default: use profile]
135 profile : bytes
136 profile : bytes
136 The name of the Cluster profile to be used to find connector information.
137 The name of the Cluster profile to be used to find connector information.
137 If run from an IPython application, the default profile will be the same
138 If run from an IPython application, the default profile will be the same
138 as the running application, otherwise it will be 'default'.
139 as the running application, otherwise it will be 'default'.
139 context : zmq.Context
140 context : zmq.Context
140 Pass an existing zmq.Context instance, otherwise the client will create its own.
141 Pass an existing zmq.Context instance, otherwise the client will create its own.
141 debug : bool
142 debug : bool
142 flag for lots of message printing for debug purposes
143 flag for lots of message printing for debug purposes
143 timeout : int/float
144 timeout : int/float
144 time (in seconds) to wait for connection replies from the Hub
145 time (in seconds) to wait for connection replies from the Hub
145 [Default: 10]
146 [Default: 10]
146
147
147 #-------------- session related args ----------------
148 #-------------- session related args ----------------
148
149
149 config : Config object
150 config : Config object
150 If specified, this will be relayed to the Session for configuration
151 If specified, this will be relayed to the Session for configuration
151 username : str
152 username : str
152 set username for the session object
153 set username for the session object
153 packer : str (import_string) or callable
154 packer : str (import_string) or callable
154 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
155 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
155 function to serialize messages. Must support same input as
156 function to serialize messages. Must support same input as
156 JSON, and output must be bytes.
157 JSON, and output must be bytes.
157 You can pass a callable directly as `pack`
158 You can pass a callable directly as `pack`
158 unpacker : str (import_string) or callable
159 unpacker : str (import_string) or callable
159 The inverse of packer. Only necessary if packer is specified as *not* one
160 The inverse of packer. Only necessary if packer is specified as *not* one
160 of 'json' or 'pickle'.
161 of 'json' or 'pickle'.
161
162
162 #-------------- ssh related args ----------------
163 #-------------- ssh related args ----------------
163 # These are args for configuring the ssh tunnel to be used
164 # These are args for configuring the ssh tunnel to be used
164 # credentials are used to forward connections over ssh to the Controller
165 # credentials are used to forward connections over ssh to the Controller
165 # Note that the ip given in `addr` needs to be relative to sshserver
166 # Note that the ip given in `addr` needs to be relative to sshserver
166 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
167 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
167 # and set sshserver as the same machine the Controller is on. However,
168 # and set sshserver as the same machine the Controller is on. However,
168 # the only requirement is that sshserver is able to see the Controller
169 # the only requirement is that sshserver is able to see the Controller
169 # (i.e. is within the same trusted network).
170 # (i.e. is within the same trusted network).
170
171
171 sshserver : str
172 sshserver : str
172 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
173 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
173 If keyfile or password is specified, and this is not, it will default to
174 If keyfile or password is specified, and this is not, it will default to
174 the ip given in addr.
175 the ip given in addr.
175 sshkey : str; path to ssh private key file
176 sshkey : str; path to ssh private key file
176 This specifies a key to be used in ssh login, default None.
177 This specifies a key to be used in ssh login, default None.
177 Regular default ssh keys will be used without specifying this argument.
178 Regular default ssh keys will be used without specifying this argument.
178 password : str
179 password : str
179 Your ssh password to sshserver. Note that if this is left None,
180 Your ssh password to sshserver. Note that if this is left None,
180 you will be prompted for it if passwordless key based login is unavailable.
181 you will be prompted for it if passwordless key based login is unavailable.
181 paramiko : bool
182 paramiko : bool
182 flag for whether to use paramiko instead of shell ssh for tunneling.
183 flag for whether to use paramiko instead of shell ssh for tunneling.
183 [default: True on win32, False else]
184 [default: True on win32, False else]
184
185
185 ------- exec authentication args -------
186 ------- exec authentication args -------
186 If even localhost is untrusted, you can have some protection against
187 If even localhost is untrusted, you can have some protection against
187 unauthorized execution by signing messages with HMAC digests.
188 unauthorized execution by signing messages with HMAC digests.
188 Messages are still sent as cleartext, so if someone can snoop your
189 Messages are still sent as cleartext, so if someone can snoop your
189 loopback traffic this will not protect your privacy, but will prevent
190 loopback traffic this will not protect your privacy, but will prevent
190 unauthorized execution.
191 unauthorized execution.
191
192
192 exec_key : str
193 exec_key : str
193 an authentication key or file containing a key
194 an authentication key or file containing a key
194 default: None
195 default: None
195
196
196
197
197 Attributes
198 Attributes
198 ----------
199 ----------
199
200
200 ids : list of int engine IDs
201 ids : list of int engine IDs
201 requesting the ids attribute always synchronizes
202 requesting the ids attribute always synchronizes
202 the registration state. To request ids without synchronization,
203 the registration state. To request ids without synchronization,
203 use semi-private _ids attributes.
204 use semi-private _ids attributes.
204
205
205 history : list of msg_ids
206 history : list of msg_ids
206 a list of msg_ids, keeping track of all the execution
207 a list of msg_ids, keeping track of all the execution
207 messages you have submitted in order.
208 messages you have submitted in order.
208
209
209 outstanding : set of msg_ids
210 outstanding : set of msg_ids
210 a set of msg_ids that have been submitted, but whose
211 a set of msg_ids that have been submitted, but whose
211 results have not yet been received.
212 results have not yet been received.
212
213
213 results : dict
214 results : dict
214 a dict of all our results, keyed by msg_id
215 a dict of all our results, keyed by msg_id
215
216
216 block : bool
217 block : bool
217 determines default behavior when block not specified
218 determines default behavior when block not specified
218 in execution methods
219 in execution methods
219
220
220 Methods
221 Methods
221 -------
222 -------
222
223
223 spin
224 spin
224 flushes incoming results and registration state changes
225 flushes incoming results and registration state changes
225 control methods spin, and requesting `ids` also ensures up to date
226 control methods spin, and requesting `ids` also ensures up to date
226
227
227 wait
228 wait
228 wait on one or more msg_ids
229 wait on one or more msg_ids
229
230
230 execution methods
231 execution methods
231 apply
232 apply
232 legacy: execute, run
233 legacy: execute, run
233
234
234 data movement
235 data movement
235 push, pull, scatter, gather
236 push, pull, scatter, gather
236
237
237 query methods
238 query methods
238 queue_status, get_result, purge, result_status
239 queue_status, get_result, purge, result_status
239
240
240 control methods
241 control methods
241 abort, shutdown
242 abort, shutdown
242
243
243 """
244 """
244
245
245
246
    block = Bool(False)      # default blocking behavior for execution methods
    outstanding = Set()      # msg_ids submitted whose results have not arrived
    results = Instance('collections.defaultdict', (dict,))       # msg_id -> result
    metadata = Instance('collections.defaultdict', (Metadata,))  # msg_id -> Metadata
    history = List()         # all execution msg_ids, in submission order
    debug = Bool(False)
    # background spin thread and its stop flag; _stop_spinning is set to a
    # threading.Event in __init__ (per this commit's spin_thread support)
    _spin_thread = Any()
    _stop_spinning = Any()
252
255
253 profile=Unicode()
256 profile=Unicode()
254 def _profile_default(self):
257 def _profile_default(self):
255 if BaseIPythonApplication.initialized():
258 if BaseIPythonApplication.initialized():
256 # an IPython app *might* be running, try to get its profile
259 # an IPython app *might* be running, try to get its profile
257 try:
260 try:
258 return BaseIPythonApplication.instance().profile
261 return BaseIPythonApplication.instance().profile
259 except (AttributeError, MultipleInstanceError):
262 except (AttributeError, MultipleInstanceError):
260 # could be a *different* subclass of config.Application,
263 # could be a *different* subclass of config.Application,
261 # which would raise one of these two errors.
264 # which would raise one of these two errors.
262 return u'default'
265 return u'default'
263 else:
266 else:
264 return u'default'
267 return u'default'
265
268
266
269
    # NOTE(review): appears to map engine -> set of its outstanding msg_ids;
    # confirm the key type against the flush/handler code
    _outstanding_dict = Instance('collections.defaultdict', (set,))
    _ids = List()                 # sorted list of registered engine ids
    _connected=Bool(False)        # set True once _connect() has run
    _ssh=Bool(False)              # whether connections are tunneled over ssh
    _context = Instance('zmq.Context')
    _config = Dict()              # connection info loaded from url/JSON file
    _engines=Instance(util.ReverseDict, (), {})   # engine id <-> uuid
    # _hub_socket=Instance('zmq.Socket')
    _query_socket=Instance('zmq.Socket')
    _control_socket=Instance('zmq.Socket')
    _iopub_socket=Instance('zmq.Socket')
    _notification_socket=Instance('zmq.Socket')
    _mux_socket=Instance('zmq.Socket')
    _task_socket=Instance('zmq.Socket')
    _task_scheme=Unicode()        # task scheduler scheme, e.g. 'pure' (see _update_engines)
    _closed = False
    _ignored_control_replies=Integer(0)
    _ignored_hub_replies=Integer(0)
285
288
    def __new__(self, *args, **kw):
        # don't raise on positional args
        # (only keywords are forwarded to HasTraits.__new__; positional
        # args are consumed by __init__ instead)
        return HasTraits.__new__(self, **kw)
289
292
    def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
            context=None, debug=False, exec_key=None,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            timeout=10, **extra_args
            ):
        """Locate connection info, build the Session, and connect to the Hub.

        See the class docstring for the meaning of the arguments.  Any
        **extra_args are forwarded to the Session constructor (with
        `key`/`keyfile` possibly injected from exec_key first).  Raises
        AssertionError when no connection info can be found, and
        error.TimeoutError when the Hub does not answer within `timeout`
        seconds.
        """
        if profile:
            super(Client, self).__init__(debug=debug, profile=profile)
        else:
            super(Client, self).__init__(debug=debug)
        if context is None:
            # share the process-global zmq Context by default
            context = zmq.Context.instance()
        self._context = context
        # Event used to signal a background spin thread to stop
        self._stop_spinning = Event()

        self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
            " Please specify at least one of url_or_file or profile."

        if not util.is_url(url_or_file):
            # it's not a url, try for a file
            if not os.path.exists(url_or_file):
                # not found as given: retry relative to the profile's security dir
                if self._cd:
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}

        # sync defaults from args, json:
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg['exec_key']
        location = cfg.setdefault('location', None)
        cfg['url'] = util.disambiguate_url(cfg['url'], location)
        url = cfg['url']
        proto,addr,port = util.split_url(url)
        if location is not None and addr == '127.0.0.1':
            # location specified, and connection is expected to be local
            if location not in LOCAL_IPS and not sshserver:
                # load ssh from JSON *only* if the controller is not on
                # this machine
                sshserver=cfg['ssh']
            if location not in LOCAL_IPS and not sshserver:
                # warn if no ssh specified, but SSH is probably needed
                # This is only a warning, because the most likely cause
                # is a local Controller on a laptop whose IP is dynamic
                warnings.warn("""
            Controller appears to be listening on localhost, but not on this machine.
            If this is true, you should specify Client(...,sshserver='you@%s')
            or instruct your controller to listen on an external IP."""%location,
                RuntimeWarning)
        elif not sshserver:
            # otherwise sync with cfg
            sshserver = cfg['ssh']

        self._config = cfg

        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                # passwordless login works; False suppresses any prompt
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

        # configure and construct the session
        if exec_key is not None:
            if os.path.isfile(exec_key):
                # exec_key may be a path to a keyfile rather than the key itself
                extra_args['keyfile'] = exec_key
            else:
                exec_key = util.asbytes(exec_key)
                extra_args['key'] = exec_key
        self.session = Session(**extra_args)

        self._query_socket = self._context.socket(zmq.DEALER)
        self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        # dispatch tables: message type -> handler
        self._notification_handlers = {'registration_notification' : self._register_engine,
                                    'unregistration_notification' : self._unregister_engine,
                                    'shutdown_notification' : lambda msg: self.close(),
                                    }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs, timeout)
388
392
    def __del__(self):
        """cleanup sockets, but _not_ context.

        The Context may be shared (zmq.Context.instance()), so it is
        deliberately left for its owner to terminate.
        """
        self.close()
392
396
393 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
397 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
394 if ipython_dir is None:
398 if ipython_dir is None:
395 ipython_dir = get_ipython_dir()
399 ipython_dir = get_ipython_dir()
396 if profile_dir is not None:
400 if profile_dir is not None:
397 try:
401 try:
398 self._cd = ProfileDir.find_profile_dir(profile_dir)
402 self._cd = ProfileDir.find_profile_dir(profile_dir)
399 return
403 return
400 except ProfileDirError:
404 except ProfileDirError:
401 pass
405 pass
402 elif profile is not None:
406 elif profile is not None:
403 try:
407 try:
404 self._cd = ProfileDir.find_profile_dir_by_name(
408 self._cd = ProfileDir.find_profile_dir_by_name(
405 ipython_dir, profile)
409 ipython_dir, profile)
406 return
410 return
407 except ProfileDirError:
411 except ProfileDirError:
408 pass
412 pass
409 self._cd = None
413 self._cd = None
410
414
411 def _update_engines(self, engines):
415 def _update_engines(self, engines):
412 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
416 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
413 for k,v in engines.iteritems():
417 for k,v in engines.iteritems():
414 eid = int(k)
418 eid = int(k)
415 self._engines[eid] = v
419 self._engines[eid] = v
416 self._ids.append(eid)
420 self._ids.append(eid)
417 self._ids = sorted(self._ids)
421 self._ids = sorted(self._ids)
418 if sorted(self._engines.keys()) != range(len(self._engines)) and \
422 if sorted(self._engines.keys()) != range(len(self._engines)) and \
419 self._task_scheme == 'pure' and self._task_socket:
423 self._task_scheme == 'pure' and self._task_socket:
420 self._stop_scheduling_tasks()
424 self._stop_scheduling_tasks()
421
425
422 def _stop_scheduling_tasks(self):
426 def _stop_scheduling_tasks(self):
423 """Stop scheduling tasks because an engine has been unregistered
427 """Stop scheduling tasks because an engine has been unregistered
424 from a pure ZMQ scheduler.
428 from a pure ZMQ scheduler.
425 """
429 """
426 self._task_socket.close()
430 self._task_socket.close()
427 self._task_socket = None
431 self._task_socket = None
428 msg = "An engine has been unregistered, and we are using pure " +\
432 msg = "An engine has been unregistered, and we are using pure " +\
429 "ZMQ task scheduling. Task farming will be disabled."
433 "ZMQ task scheduling. Task farming will be disabled."
430 if self.outstanding:
434 if self.outstanding:
431 msg += " If you were running tasks when this happened, " +\
435 msg += " If you were running tasks when this happened, " +\
432 "some `outstanding` msg_ids may never resolve."
436 "some `outstanding` msg_ids may never resolve."
433 warnings.warn(msg, RuntimeWarning)
437 warnings.warn(msg, RuntimeWarning)
434
438
    def _build_targets(self, targets):
        """Turn valid target IDs or 'all' into two lists:
        (int_ids, uuids).

        Accepts None (meaning all engines), the string 'all', a single int
        (negative ints index from the end), a slice, or a collection of
        ints.  Raises NoEnginesRegistered when no engines are available,
        TypeError for invalid target specs, and IndexError for unknown ids.
        """
        if not self._ids:
            # flush notification socket if no engines yet, just in case
            # (reading self.ids spins, which may pick up new registrations)
            if not self.ids:
                raise error.NoEnginesRegistered("Can't build targets without any engines")

        if targets is None:
            targets = self._ids
        elif isinstance(targets, basestring):
            if targets.lower() == 'all':
                targets = self._ids
            else:
                raise TypeError("%r not valid str target, must be 'all'"%(targets))
        elif isinstance(targets, int):
            if targets < 0:
                # negative index counts from the end, list-style
                targets = self.ids[targets]
            if targets not in self._ids:
                raise IndexError("No such engine: %i"%targets)
            # wrap the single id so the collection check below passes
            targets = [targets]

        if isinstance(targets, slice):
            indices = range(len(self._ids))[targets]
            ids = self.ids
            targets = [ ids[i] for i in indices ]

        # xrange is shimmed to range on Python 3 at the top of this file
        if not isinstance(targets, (tuple, list, xrange)):
            raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))

        return [util.asbytes(self._engines[t]) for t in targets], list(targets)
467
471
468 def _connect(self, sshserver, ssh_kwargs, timeout):
472 def _connect(self, sshserver, ssh_kwargs, timeout):
469 """setup all our socket connections to the cluster. This is called from
473 """setup all our socket connections to the cluster. This is called from
470 __init__."""
474 __init__."""
471
475
472 # Maybe allow reconnecting?
476 # Maybe allow reconnecting?
473 if self._connected:
477 if self._connected:
474 return
478 return
475 self._connected=True
479 self._connected=True
476
480
477 def connect_socket(s, url):
481 def connect_socket(s, url):
478 url = util.disambiguate_url(url, self._config['location'])
482 url = util.disambiguate_url(url, self._config['location'])
479 if self._ssh:
483 if self._ssh:
480 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
484 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
481 else:
485 else:
482 return s.connect(url)
486 return s.connect(url)
483
487
484 self.session.send(self._query_socket, 'connection_request')
488 self.session.send(self._query_socket, 'connection_request')
485 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
489 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
486 poller = zmq.Poller()
490 poller = zmq.Poller()
487 poller.register(self._query_socket, zmq.POLLIN)
491 poller.register(self._query_socket, zmq.POLLIN)
488 # poll expects milliseconds, timeout is seconds
492 # poll expects milliseconds, timeout is seconds
489 evts = poller.poll(timeout*1000)
493 evts = poller.poll(timeout*1000)
490 if not evts:
494 if not evts:
491 raise error.TimeoutError("Hub connection request timed out")
495 raise error.TimeoutError("Hub connection request timed out")
492 idents,msg = self.session.recv(self._query_socket,mode=0)
496 idents,msg = self.session.recv(self._query_socket,mode=0)
493 if self.debug:
497 if self.debug:
494 pprint(msg)
498 pprint(msg)
495 msg = Message(msg)
499 msg = Message(msg)
496 content = msg.content
500 content = msg.content
497 self._config['registration'] = dict(content)
501 self._config['registration'] = dict(content)
498 if content.status == 'ok':
502 if content.status == 'ok':
499 ident = self.session.bsession
503 ident = self.session.bsession
500 if content.mux:
504 if content.mux:
501 self._mux_socket = self._context.socket(zmq.DEALER)
505 self._mux_socket = self._context.socket(zmq.DEALER)
502 self._mux_socket.setsockopt(zmq.IDENTITY, ident)
506 self._mux_socket.setsockopt(zmq.IDENTITY, ident)
503 connect_socket(self._mux_socket, content.mux)
507 connect_socket(self._mux_socket, content.mux)
504 if content.task:
508 if content.task:
505 self._task_scheme, task_addr = content.task
509 self._task_scheme, task_addr = content.task
506 self._task_socket = self._context.socket(zmq.DEALER)
510 self._task_socket = self._context.socket(zmq.DEALER)
507 self._task_socket.setsockopt(zmq.IDENTITY, ident)
511 self._task_socket.setsockopt(zmq.IDENTITY, ident)
508 connect_socket(self._task_socket, task_addr)
512 connect_socket(self._task_socket, task_addr)
509 if content.notification:
513 if content.notification:
510 self._notification_socket = self._context.socket(zmq.SUB)
514 self._notification_socket = self._context.socket(zmq.SUB)
511 connect_socket(self._notification_socket, content.notification)
515 connect_socket(self._notification_socket, content.notification)
512 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
516 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
513 # if content.query:
517 # if content.query:
514 # self._query_socket = self._context.socket(zmq.DEALER)
518 # self._query_socket = self._context.socket(zmq.DEALER)
515 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
519 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
516 # connect_socket(self._query_socket, content.query)
520 # connect_socket(self._query_socket, content.query)
517 if content.control:
521 if content.control:
518 self._control_socket = self._context.socket(zmq.DEALER)
522 self._control_socket = self._context.socket(zmq.DEALER)
519 self._control_socket.setsockopt(zmq.IDENTITY, ident)
523 self._control_socket.setsockopt(zmq.IDENTITY, ident)
520 connect_socket(self._control_socket, content.control)
524 connect_socket(self._control_socket, content.control)
521 if content.iopub:
525 if content.iopub:
522 self._iopub_socket = self._context.socket(zmq.SUB)
526 self._iopub_socket = self._context.socket(zmq.SUB)
523 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
527 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
524 self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
528 self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
525 connect_socket(self._iopub_socket, content.iopub)
529 connect_socket(self._iopub_socket, content.iopub)
526 self._update_engines(dict(content.engines))
530 self._update_engines(dict(content.engines))
527 else:
531 else:
528 self._connected = False
532 self._connected = False
529 raise Exception("Failed to connect!")
533 raise Exception("Failed to connect!")
530
534
531 #--------------------------------------------------------------------------
535 #--------------------------------------------------------------------------
532 # handlers and callbacks for incoming messages
536 # handlers and callbacks for incoming messages
533 #--------------------------------------------------------------------------
537 #--------------------------------------------------------------------------
534
538
535 def _unwrap_exception(self, content):
539 def _unwrap_exception(self, content):
536 """unwrap exception, and remap engine_id to int."""
540 """unwrap exception, and remap engine_id to int."""
537 e = error.unwrap_exception(content)
541 e = error.unwrap_exception(content)
538 # print e.traceback
542 # print e.traceback
539 if e.engine_info:
543 if e.engine_info:
540 e_uuid = e.engine_info['engine_uuid']
544 e_uuid = e.engine_info['engine_uuid']
541 eid = self._engines[e_uuid]
545 eid = self._engines[e_uuid]
542 e.engine_info['engine_id'] = eid
546 e.engine_info['engine_id'] = eid
543 return e
547 return e
544
548
545 def _extract_metadata(self, header, parent, content):
549 def _extract_metadata(self, header, parent, content):
546 md = {'msg_id' : parent['msg_id'],
550 md = {'msg_id' : parent['msg_id'],
547 'received' : datetime.now(),
551 'received' : datetime.now(),
548 'engine_uuid' : header.get('engine', None),
552 'engine_uuid' : header.get('engine', None),
549 'follow' : parent.get('follow', []),
553 'follow' : parent.get('follow', []),
550 'after' : parent.get('after', []),
554 'after' : parent.get('after', []),
551 'status' : content['status'],
555 'status' : content['status'],
552 }
556 }
553
557
554 if md['engine_uuid'] is not None:
558 if md['engine_uuid'] is not None:
555 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
559 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
556
560
557 if 'date' in parent:
561 if 'date' in parent:
558 md['submitted'] = parent['date']
562 md['submitted'] = parent['date']
559 if 'started' in header:
563 if 'started' in header:
560 md['started'] = header['started']
564 md['started'] = header['started']
561 if 'date' in header:
565 if 'date' in header:
562 md['completed'] = header['date']
566 md['completed'] = header['date']
563 return md
567 return md
564
568
565 def _register_engine(self, msg):
569 def _register_engine(self, msg):
566 """Register a new engine, and update our connection info."""
570 """Register a new engine, and update our connection info."""
567 content = msg['content']
571 content = msg['content']
568 eid = content['id']
572 eid = content['id']
569 d = {eid : content['queue']}
573 d = {eid : content['queue']}
570 self._update_engines(d)
574 self._update_engines(d)
571
575
572 def _unregister_engine(self, msg):
576 def _unregister_engine(self, msg):
573 """Unregister an engine that has died."""
577 """Unregister an engine that has died."""
574 content = msg['content']
578 content = msg['content']
575 eid = int(content['id'])
579 eid = int(content['id'])
576 if eid in self._ids:
580 if eid in self._ids:
577 self._ids.remove(eid)
581 self._ids.remove(eid)
578 uuid = self._engines.pop(eid)
582 uuid = self._engines.pop(eid)
579
583
580 self._handle_stranded_msgs(eid, uuid)
584 self._handle_stranded_msgs(eid, uuid)
581
585
582 if self._task_socket and self._task_scheme == 'pure':
586 if self._task_socket and self._task_scheme == 'pure':
583 self._stop_scheduling_tasks()
587 self._stop_scheduling_tasks()
584
588
585 def _handle_stranded_msgs(self, eid, uuid):
589 def _handle_stranded_msgs(self, eid, uuid):
586 """Handle messages known to be on an engine when the engine unregisters.
590 """Handle messages known to be on an engine when the engine unregisters.
587
591
588 It is possible that this will fire prematurely - that is, an engine will
592 It is possible that this will fire prematurely - that is, an engine will
589 go down after completing a result, and the client will be notified
593 go down after completing a result, and the client will be notified
590 of the unregistration and later receive the successful result.
594 of the unregistration and later receive the successful result.
591 """
595 """
592
596
593 outstanding = self._outstanding_dict[uuid]
597 outstanding = self._outstanding_dict[uuid]
594
598
595 for msg_id in list(outstanding):
599 for msg_id in list(outstanding):
596 if msg_id in self.results:
600 if msg_id in self.results:
597 # we already
601 # we already
598 continue
602 continue
599 try:
603 try:
600 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
604 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
601 except:
605 except:
602 content = error.wrap_exception()
606 content = error.wrap_exception()
603 # build a fake message:
607 # build a fake message:
604 parent = {}
608 parent = {}
605 header = {}
609 header = {}
606 parent['msg_id'] = msg_id
610 parent['msg_id'] = msg_id
607 header['engine'] = uuid
611 header['engine'] = uuid
608 header['date'] = datetime.now()
612 header['date'] = datetime.now()
609 msg = dict(parent_header=parent, header=header, content=content)
613 msg = dict(parent_header=parent, header=header, content=content)
610 self._handle_apply_reply(msg)
614 self._handle_apply_reply(msg)
611
615
612 def _handle_execute_reply(self, msg):
616 def _handle_execute_reply(self, msg):
613 """Save the reply to an execute_request into our results.
617 """Save the reply to an execute_request into our results.
614
618
615 execute messages are never actually used. apply is used instead.
619 execute messages are never actually used. apply is used instead.
616 """
620 """
617
621
618 parent = msg['parent_header']
622 parent = msg['parent_header']
619 msg_id = parent['msg_id']
623 msg_id = parent['msg_id']
620 if msg_id not in self.outstanding:
624 if msg_id not in self.outstanding:
621 if msg_id in self.history:
625 if msg_id in self.history:
622 print ("got stale result: %s"%msg_id)
626 print ("got stale result: %s"%msg_id)
623 else:
627 else:
624 print ("got unknown result: %s"%msg_id)
628 print ("got unknown result: %s"%msg_id)
625 else:
629 else:
626 self.outstanding.remove(msg_id)
630 self.outstanding.remove(msg_id)
627 self.results[msg_id] = self._unwrap_exception(msg['content'])
631 self.results[msg_id] = self._unwrap_exception(msg['content'])
628
632
629 def _handle_apply_reply(self, msg):
633 def _handle_apply_reply(self, msg):
630 """Save the reply to an apply_request into our results."""
634 """Save the reply to an apply_request into our results."""
631 parent = msg['parent_header']
635 parent = msg['parent_header']
632 msg_id = parent['msg_id']
636 msg_id = parent['msg_id']
633 if msg_id not in self.outstanding:
637 if msg_id not in self.outstanding:
634 if msg_id in self.history:
638 if msg_id in self.history:
635 print ("got stale result: %s"%msg_id)
639 print ("got stale result: %s"%msg_id)
636 print self.results[msg_id]
640 print self.results[msg_id]
637 print msg
641 print msg
638 else:
642 else:
639 print ("got unknown result: %s"%msg_id)
643 print ("got unknown result: %s"%msg_id)
640 else:
644 else:
641 self.outstanding.remove(msg_id)
645 self.outstanding.remove(msg_id)
642 content = msg['content']
646 content = msg['content']
643 header = msg['header']
647 header = msg['header']
644
648
645 # construct metadata:
649 # construct metadata:
646 md = self.metadata[msg_id]
650 md = self.metadata[msg_id]
647 md.update(self._extract_metadata(header, parent, content))
651 md.update(self._extract_metadata(header, parent, content))
648 # is this redundant?
652 # is this redundant?
649 self.metadata[msg_id] = md
653 self.metadata[msg_id] = md
650
654
651 e_outstanding = self._outstanding_dict[md['engine_uuid']]
655 e_outstanding = self._outstanding_dict[md['engine_uuid']]
652 if msg_id in e_outstanding:
656 if msg_id in e_outstanding:
653 e_outstanding.remove(msg_id)
657 e_outstanding.remove(msg_id)
654
658
655 # construct result:
659 # construct result:
656 if content['status'] == 'ok':
660 if content['status'] == 'ok':
657 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
661 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
658 elif content['status'] == 'aborted':
662 elif content['status'] == 'aborted':
659 self.results[msg_id] = error.TaskAborted(msg_id)
663 self.results[msg_id] = error.TaskAborted(msg_id)
660 elif content['status'] == 'resubmitted':
664 elif content['status'] == 'resubmitted':
661 # TODO: handle resubmission
665 # TODO: handle resubmission
662 pass
666 pass
663 else:
667 else:
664 self.results[msg_id] = self._unwrap_exception(content)
668 self.results[msg_id] = self._unwrap_exception(content)
665
669
666 def _flush_notifications(self):
670 def _flush_notifications(self):
667 """Flush notifications of engine registrations waiting
671 """Flush notifications of engine registrations waiting
668 in ZMQ queue."""
672 in ZMQ queue."""
669 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
673 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
670 while msg is not None:
674 while msg is not None:
671 if self.debug:
675 if self.debug:
672 pprint(msg)
676 pprint(msg)
673 msg_type = msg['header']['msg_type']
677 msg_type = msg['header']['msg_type']
674 handler = self._notification_handlers.get(msg_type, None)
678 handler = self._notification_handlers.get(msg_type, None)
675 if handler is None:
679 if handler is None:
676 raise Exception("Unhandled message type: %s"%msg.msg_type)
680 raise Exception("Unhandled message type: %s"%msg.msg_type)
677 else:
681 else:
678 handler(msg)
682 handler(msg)
679 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
683 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
680
684
681 def _flush_results(self, sock):
685 def _flush_results(self, sock):
682 """Flush task or queue results waiting in ZMQ queue."""
686 """Flush task or queue results waiting in ZMQ queue."""
683 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
687 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
684 while msg is not None:
688 while msg is not None:
685 if self.debug:
689 if self.debug:
686 pprint(msg)
690 pprint(msg)
687 msg_type = msg['header']['msg_type']
691 msg_type = msg['header']['msg_type']
688 handler = self._queue_handlers.get(msg_type, None)
692 handler = self._queue_handlers.get(msg_type, None)
689 if handler is None:
693 if handler is None:
690 raise Exception("Unhandled message type: %s"%msg.msg_type)
694 raise Exception("Unhandled message type: %s"%msg.msg_type)
691 else:
695 else:
692 handler(msg)
696 handler(msg)
693 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
697 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
694
698
695 def _flush_control(self, sock):
699 def _flush_control(self, sock):
696 """Flush replies from the control channel waiting
700 """Flush replies from the control channel waiting
697 in the ZMQ queue.
701 in the ZMQ queue.
698
702
699 Currently: ignore them."""
703 Currently: ignore them."""
700 if self._ignored_control_replies <= 0:
704 if self._ignored_control_replies <= 0:
701 return
705 return
702 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
706 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
703 while msg is not None:
707 while msg is not None:
704 self._ignored_control_replies -= 1
708 self._ignored_control_replies -= 1
705 if self.debug:
709 if self.debug:
706 pprint(msg)
710 pprint(msg)
707 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
711 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
708
712
709 def _flush_ignored_control(self):
713 def _flush_ignored_control(self):
710 """flush ignored control replies"""
714 """flush ignored control replies"""
711 while self._ignored_control_replies > 0:
715 while self._ignored_control_replies > 0:
712 self.session.recv(self._control_socket)
716 self.session.recv(self._control_socket)
713 self._ignored_control_replies -= 1
717 self._ignored_control_replies -= 1
714
718
715 def _flush_ignored_hub_replies(self):
719 def _flush_ignored_hub_replies(self):
716 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
720 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
717 while msg is not None:
721 while msg is not None:
718 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
722 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
719
723
720 def _flush_iopub(self, sock):
724 def _flush_iopub(self, sock):
721 """Flush replies from the iopub channel waiting
725 """Flush replies from the iopub channel waiting
722 in the ZMQ queue.
726 in the ZMQ queue.
723 """
727 """
724 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
728 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
725 while msg is not None:
729 while msg is not None:
726 if self.debug:
730 if self.debug:
727 pprint(msg)
731 pprint(msg)
728 parent = msg['parent_header']
732 parent = msg['parent_header']
729 # ignore IOPub messages with no parent.
733 # ignore IOPub messages with no parent.
730 # Caused by print statements or warnings from before the first execution.
734 # Caused by print statements or warnings from before the first execution.
731 if not parent:
735 if not parent:
732 continue
736 continue
733 msg_id = parent['msg_id']
737 msg_id = parent['msg_id']
734 content = msg['content']
738 content = msg['content']
735 header = msg['header']
739 header = msg['header']
736 msg_type = msg['header']['msg_type']
740 msg_type = msg['header']['msg_type']
737
741
738 # init metadata:
742 # init metadata:
739 md = self.metadata[msg_id]
743 md = self.metadata[msg_id]
740
744
741 if msg_type == 'stream':
745 if msg_type == 'stream':
742 name = content['name']
746 name = content['name']
743 s = md[name] or ''
747 s = md[name] or ''
744 md[name] = s + content['data']
748 md[name] = s + content['data']
745 elif msg_type == 'pyerr':
749 elif msg_type == 'pyerr':
746 md.update({'pyerr' : self._unwrap_exception(content)})
750 md.update({'pyerr' : self._unwrap_exception(content)})
747 elif msg_type == 'pyin':
751 elif msg_type == 'pyin':
748 md.update({'pyin' : content['code']})
752 md.update({'pyin' : content['code']})
749 else:
753 else:
750 md.update({msg_type : content.get('data', '')})
754 md.update({msg_type : content.get('data', '')})
751
755
752 # reduntant?
756 # reduntant?
753 self.metadata[msg_id] = md
757 self.metadata[msg_id] = md
754
758
755 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
759 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
756
760
757 #--------------------------------------------------------------------------
761 #--------------------------------------------------------------------------
758 # len, getitem
762 # len, getitem
759 #--------------------------------------------------------------------------
763 #--------------------------------------------------------------------------
760
764
761 def __len__(self):
765 def __len__(self):
762 """len(client) returns # of engines."""
766 """len(client) returns # of engines."""
763 return len(self.ids)
767 return len(self.ids)
764
768
765 def __getitem__(self, key):
769 def __getitem__(self, key):
766 """index access returns DirectView multiplexer objects
770 """index access returns DirectView multiplexer objects
767
771
768 Must be int, slice, or list/tuple/xrange of ints"""
772 Must be int, slice, or list/tuple/xrange of ints"""
769 if not isinstance(key, (int, slice, tuple, list, xrange)):
773 if not isinstance(key, (int, slice, tuple, list, xrange)):
770 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
774 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
771 else:
775 else:
772 return self.direct_view(key)
776 return self.direct_view(key)
773
777
774 #--------------------------------------------------------------------------
778 #--------------------------------------------------------------------------
775 # Begin public methods
779 # Begin public methods
776 #--------------------------------------------------------------------------
780 #--------------------------------------------------------------------------
777
781
778 @property
782 @property
779 def ids(self):
783 def ids(self):
780 """Always up-to-date ids property."""
784 """Always up-to-date ids property."""
781 self._flush_notifications()
785 self._flush_notifications()
782 # always copy:
786 # always copy:
783 return list(self._ids)
787 return list(self._ids)
784
788
785 def close(self):
789 def close(self):
786 if self._closed:
790 if self._closed:
787 return
791 return
792 self.stop_spin_thread()
788 snames = filter(lambda n: n.endswith('socket'), dir(self))
793 snames = filter(lambda n: n.endswith('socket'), dir(self))
789 for socket in map(lambda name: getattr(self, name), snames):
794 for socket in map(lambda name: getattr(self, name), snames):
790 if isinstance(socket, zmq.Socket) and not socket.closed:
795 if isinstance(socket, zmq.Socket) and not socket.closed:
791 socket.close()
796 socket.close()
792 self._closed = True
797 self._closed = True
793
798
799 def _spin_every(self, interval=1):
800 """target func for use in spin_thread"""
801 while True:
802 if self._stop_spinning.is_set():
803 return
804 time.sleep(interval)
805 self.spin()
806
807 def spin_thread(self, interval=1):
808 """call Client.spin() in a background thread on some regular interval
809
810 This helps ensure that messages don't pile up too much in the zmq queue
811 while you are working on other things, or just leaving an idle terminal.
812
813 It also helps limit potential padding of the `received` timestamp
814 on AsyncResult objects, used for timings.
815
816 Parameters
817 ----------
818
819 interval : float, optional
820 The interval on which to spin the client in the background thread
821 (simply passed to time.sleep).
822
823 Notes
824 -----
825
826 For precision timing, you may want to use this method to put a bound
827 on the jitter (in seconds) in `received` timestamps used
828 in AsyncResult.wall_time.
829
830 """
831 if self._spin_thread is not None:
832 self.stop_spin_thread()
833 self._stop_spinning.clear()
834 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
835 self._spin_thread.daemon = True
836 self._spin_thread.start()
837
838 def stop_spin_thread(self):
839 """stop background spin_thread, if any"""
840 if self._spin_thread is not None:
841 self._stop_spinning.set()
842 self._spin_thread.join()
843 self._spin_thread = None
844
794 def spin(self):
845 def spin(self):
795 """Flush any registration notifications and execution results
846 """Flush any registration notifications and execution results
796 waiting in the ZMQ queue.
847 waiting in the ZMQ queue.
797 """
848 """
798 if self._notification_socket:
849 if self._notification_socket:
799 self._flush_notifications()
850 self._flush_notifications()
800 if self._mux_socket:
851 if self._mux_socket:
801 self._flush_results(self._mux_socket)
852 self._flush_results(self._mux_socket)
802 if self._task_socket:
853 if self._task_socket:
803 self._flush_results(self._task_socket)
854 self._flush_results(self._task_socket)
804 if self._control_socket:
855 if self._control_socket:
805 self._flush_control(self._control_socket)
856 self._flush_control(self._control_socket)
806 if self._iopub_socket:
857 if self._iopub_socket:
807 self._flush_iopub(self._iopub_socket)
858 self._flush_iopub(self._iopub_socket)
808 if self._query_socket:
859 if self._query_socket:
809 self._flush_ignored_hub_replies()
860 self._flush_ignored_hub_replies()
810
861
811 def wait(self, jobs=None, timeout=-1):
862 def wait(self, jobs=None, timeout=-1):
812 """waits on one or more `jobs`, for up to `timeout` seconds.
863 """waits on one or more `jobs`, for up to `timeout` seconds.
813
864
814 Parameters
865 Parameters
815 ----------
866 ----------
816
867
817 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
868 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
818 ints are indices to self.history
869 ints are indices to self.history
819 strs are msg_ids
870 strs are msg_ids
820 default: wait on all outstanding messages
871 default: wait on all outstanding messages
821 timeout : float
872 timeout : float
822 a time in seconds, after which to give up.
873 a time in seconds, after which to give up.
823 default is -1, which means no timeout
874 default is -1, which means no timeout
824
875
825 Returns
876 Returns
826 -------
877 -------
827
878
828 True : when all msg_ids are done
879 True : when all msg_ids are done
829 False : timeout reached, some msg_ids still outstanding
880 False : timeout reached, some msg_ids still outstanding
830 """
881 """
831 tic = time.time()
882 tic = time.time()
832 if jobs is None:
883 if jobs is None:
833 theids = self.outstanding
884 theids = self.outstanding
834 else:
885 else:
835 if isinstance(jobs, (int, basestring, AsyncResult)):
886 if isinstance(jobs, (int, basestring, AsyncResult)):
836 jobs = [jobs]
887 jobs = [jobs]
837 theids = set()
888 theids = set()
838 for job in jobs:
889 for job in jobs:
839 if isinstance(job, int):
890 if isinstance(job, int):
840 # index access
891 # index access
841 job = self.history[job]
892 job = self.history[job]
842 elif isinstance(job, AsyncResult):
893 elif isinstance(job, AsyncResult):
843 map(theids.add, job.msg_ids)
894 map(theids.add, job.msg_ids)
844 continue
895 continue
845 theids.add(job)
896 theids.add(job)
846 if not theids.intersection(self.outstanding):
897 if not theids.intersection(self.outstanding):
847 return True
898 return True
848 self.spin()
899 self.spin()
849 while theids.intersection(self.outstanding):
900 while theids.intersection(self.outstanding):
850 if timeout >= 0 and ( time.time()-tic ) > timeout:
901 if timeout >= 0 and ( time.time()-tic ) > timeout:
851 break
902 break
852 time.sleep(1e-3)
903 time.sleep(1e-3)
853 self.spin()
904 self.spin()
854 return len(theids.intersection(self.outstanding)) == 0
905 return len(theids.intersection(self.outstanding)) == 0
855
906
856 #--------------------------------------------------------------------------
907 #--------------------------------------------------------------------------
857 # Control methods
908 # Control methods
858 #--------------------------------------------------------------------------
909 #--------------------------------------------------------------------------
859
910
860 @spin_first
911 @spin_first
861 def clear(self, targets=None, block=None):
912 def clear(self, targets=None, block=None):
862 """Clear the namespace in target(s)."""
913 """Clear the namespace in target(s)."""
863 block = self.block if block is None else block
914 block = self.block if block is None else block
864 targets = self._build_targets(targets)[0]
915 targets = self._build_targets(targets)[0]
865 for t in targets:
916 for t in targets:
866 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
917 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
867 error = False
918 error = False
868 if block:
919 if block:
869 self._flush_ignored_control()
920 self._flush_ignored_control()
870 for i in range(len(targets)):
921 for i in range(len(targets)):
871 idents,msg = self.session.recv(self._control_socket,0)
922 idents,msg = self.session.recv(self._control_socket,0)
872 if self.debug:
923 if self.debug:
873 pprint(msg)
924 pprint(msg)
874 if msg['content']['status'] != 'ok':
925 if msg['content']['status'] != 'ok':
875 error = self._unwrap_exception(msg['content'])
926 error = self._unwrap_exception(msg['content'])
876 else:
927 else:
877 self._ignored_control_replies += len(targets)
928 self._ignored_control_replies += len(targets)
878 if error:
929 if error:
879 raise error
930 raise error
880
931
881
932
882 @spin_first
933 @spin_first
883 def abort(self, jobs=None, targets=None, block=None):
934 def abort(self, jobs=None, targets=None, block=None):
884 """Abort specific jobs from the execution queues of target(s).
935 """Abort specific jobs from the execution queues of target(s).
885
936
886 This is a mechanism to prevent jobs that have already been submitted
937 This is a mechanism to prevent jobs that have already been submitted
887 from executing.
938 from executing.
888
939
889 Parameters
940 Parameters
890 ----------
941 ----------
891
942
892 jobs : msg_id, list of msg_ids, or AsyncResult
943 jobs : msg_id, list of msg_ids, or AsyncResult
893 The jobs to be aborted
944 The jobs to be aborted
894
945
895 If unspecified/None: abort all outstanding jobs.
946 If unspecified/None: abort all outstanding jobs.
896
947
897 """
948 """
898 block = self.block if block is None else block
949 block = self.block if block is None else block
899 jobs = jobs if jobs is not None else list(self.outstanding)
950 jobs = jobs if jobs is not None else list(self.outstanding)
900 targets = self._build_targets(targets)[0]
951 targets = self._build_targets(targets)[0]
901
952
902 msg_ids = []
953 msg_ids = []
903 if isinstance(jobs, (basestring,AsyncResult)):
954 if isinstance(jobs, (basestring,AsyncResult)):
904 jobs = [jobs]
955 jobs = [jobs]
905 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
956 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
906 if bad_ids:
957 if bad_ids:
907 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
958 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
908 for j in jobs:
959 for j in jobs:
909 if isinstance(j, AsyncResult):
960 if isinstance(j, AsyncResult):
910 msg_ids.extend(j.msg_ids)
961 msg_ids.extend(j.msg_ids)
911 else:
962 else:
912 msg_ids.append(j)
963 msg_ids.append(j)
913 content = dict(msg_ids=msg_ids)
964 content = dict(msg_ids=msg_ids)
914 for t in targets:
965 for t in targets:
915 self.session.send(self._control_socket, 'abort_request',
966 self.session.send(self._control_socket, 'abort_request',
916 content=content, ident=t)
967 content=content, ident=t)
917 error = False
968 error = False
918 if block:
969 if block:
919 self._flush_ignored_control()
970 self._flush_ignored_control()
920 for i in range(len(targets)):
971 for i in range(len(targets)):
921 idents,msg = self.session.recv(self._control_socket,0)
972 idents,msg = self.session.recv(self._control_socket,0)
922 if self.debug:
973 if self.debug:
923 pprint(msg)
974 pprint(msg)
924 if msg['content']['status'] != 'ok':
975 if msg['content']['status'] != 'ok':
925 error = self._unwrap_exception(msg['content'])
976 error = self._unwrap_exception(msg['content'])
926 else:
977 else:
927 self._ignored_control_replies += len(targets)
978 self._ignored_control_replies += len(targets)
928 if error:
979 if error:
929 raise error
980 raise error
930
981
931 @spin_first
982 @spin_first
932 def shutdown(self, targets=None, restart=False, hub=False, block=None):
983 def shutdown(self, targets=None, restart=False, hub=False, block=None):
933 """Terminates one or more engine processes, optionally including the hub."""
984 """Terminates one or more engine processes, optionally including the hub."""
934 block = self.block if block is None else block
985 block = self.block if block is None else block
935 if hub:
986 if hub:
936 targets = 'all'
987 targets = 'all'
937 targets = self._build_targets(targets)[0]
988 targets = self._build_targets(targets)[0]
938 for t in targets:
989 for t in targets:
939 self.session.send(self._control_socket, 'shutdown_request',
990 self.session.send(self._control_socket, 'shutdown_request',
940 content={'restart':restart},ident=t)
991 content={'restart':restart},ident=t)
941 error = False
992 error = False
942 if block or hub:
993 if block or hub:
943 self._flush_ignored_control()
994 self._flush_ignored_control()
944 for i in range(len(targets)):
995 for i in range(len(targets)):
945 idents,msg = self.session.recv(self._control_socket, 0)
996 idents,msg = self.session.recv(self._control_socket, 0)
946 if self.debug:
997 if self.debug:
947 pprint(msg)
998 pprint(msg)
948 if msg['content']['status'] != 'ok':
999 if msg['content']['status'] != 'ok':
949 error = self._unwrap_exception(msg['content'])
1000 error = self._unwrap_exception(msg['content'])
950 else:
1001 else:
951 self._ignored_control_replies += len(targets)
1002 self._ignored_control_replies += len(targets)
952
1003
953 if hub:
1004 if hub:
954 time.sleep(0.25)
1005 time.sleep(0.25)
955 self.session.send(self._query_socket, 'shutdown_request')
1006 self.session.send(self._query_socket, 'shutdown_request')
956 idents,msg = self.session.recv(self._query_socket, 0)
1007 idents,msg = self.session.recv(self._query_socket, 0)
957 if self.debug:
1008 if self.debug:
958 pprint(msg)
1009 pprint(msg)
959 if msg['content']['status'] != 'ok':
1010 if msg['content']['status'] != 'ok':
960 error = self._unwrap_exception(msg['content'])
1011 error = self._unwrap_exception(msg['content'])
961
1012
962 if error:
1013 if error:
963 raise error
1014 raise error
964
1015
965 #--------------------------------------------------------------------------
1016 #--------------------------------------------------------------------------
966 # Execution related methods
1017 # Execution related methods
967 #--------------------------------------------------------------------------
1018 #--------------------------------------------------------------------------
968
1019
969 def _maybe_raise(self, result):
1020 def _maybe_raise(self, result):
970 """wrapper for maybe raising an exception if apply failed."""
1021 """wrapper for maybe raising an exception if apply failed."""
971 if isinstance(result, error.RemoteError):
1022 if isinstance(result, error.RemoteError):
972 raise result
1023 raise result
973
1024
974 return result
1025 return result
975
1026
976 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
1027 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
977 ident=None):
1028 ident=None):
978 """construct and send an apply message via a socket.
1029 """construct and send an apply message via a socket.
979
1030
980 This is the principal method with which all engine execution is performed by views.
1031 This is the principal method with which all engine execution is performed by views.
981 """
1032 """
982
1033
983 assert not self._closed, "cannot use me anymore, I'm closed!"
1034 assert not self._closed, "cannot use me anymore, I'm closed!"
984 # defaults:
1035 # defaults:
985 args = args if args is not None else []
1036 args = args if args is not None else []
986 kwargs = kwargs if kwargs is not None else {}
1037 kwargs = kwargs if kwargs is not None else {}
987 subheader = subheader if subheader is not None else {}
1038 subheader = subheader if subheader is not None else {}
988
1039
989 # validate arguments
1040 # validate arguments
990 if not callable(f) and not isinstance(f, Reference):
1041 if not callable(f) and not isinstance(f, Reference):
991 raise TypeError("f must be callable, not %s"%type(f))
1042 raise TypeError("f must be callable, not %s"%type(f))
992 if not isinstance(args, (tuple, list)):
1043 if not isinstance(args, (tuple, list)):
993 raise TypeError("args must be tuple or list, not %s"%type(args))
1044 raise TypeError("args must be tuple or list, not %s"%type(args))
994 if not isinstance(kwargs, dict):
1045 if not isinstance(kwargs, dict):
995 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1046 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
996 if not isinstance(subheader, dict):
1047 if not isinstance(subheader, dict):
997 raise TypeError("subheader must be dict, not %s"%type(subheader))
1048 raise TypeError("subheader must be dict, not %s"%type(subheader))
998
1049
999 bufs = util.pack_apply_message(f,args,kwargs)
1050 bufs = util.pack_apply_message(f,args,kwargs)
1000
1051
1001 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1052 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1002 subheader=subheader, track=track)
1053 subheader=subheader, track=track)
1003
1054
1004 msg_id = msg['header']['msg_id']
1055 msg_id = msg['header']['msg_id']
1005 self.outstanding.add(msg_id)
1056 self.outstanding.add(msg_id)
1006 if ident:
1057 if ident:
1007 # possibly routed to a specific engine
1058 # possibly routed to a specific engine
1008 if isinstance(ident, list):
1059 if isinstance(ident, list):
1009 ident = ident[-1]
1060 ident = ident[-1]
1010 if ident in self._engines.values():
1061 if ident in self._engines.values():
1011 # save for later, in case of engine death
1062 # save for later, in case of engine death
1012 self._outstanding_dict[ident].add(msg_id)
1063 self._outstanding_dict[ident].add(msg_id)
1013 self.history.append(msg_id)
1064 self.history.append(msg_id)
1014 self.metadata[msg_id]['submitted'] = datetime.now()
1065 self.metadata[msg_id]['submitted'] = datetime.now()
1015
1066
1016 return msg
1067 return msg
1017
1068
1018 #--------------------------------------------------------------------------
1069 #--------------------------------------------------------------------------
1019 # construct a View object
1070 # construct a View object
1020 #--------------------------------------------------------------------------
1071 #--------------------------------------------------------------------------
1021
1072
1022 def load_balanced_view(self, targets=None):
1073 def load_balanced_view(self, targets=None):
1023 """construct a DirectView object.
1074 """construct a DirectView object.
1024
1075
1025 If no arguments are specified, create a LoadBalancedView
1076 If no arguments are specified, create a LoadBalancedView
1026 using all engines.
1077 using all engines.
1027
1078
1028 Parameters
1079 Parameters
1029 ----------
1080 ----------
1030
1081
1031 targets: list,slice,int,etc. [default: use all engines]
1082 targets: list,slice,int,etc. [default: use all engines]
1032 The subset of engines across which to load-balance
1083 The subset of engines across which to load-balance
1033 """
1084 """
1034 if targets == 'all':
1085 if targets == 'all':
1035 targets = None
1086 targets = None
1036 if targets is not None:
1087 if targets is not None:
1037 targets = self._build_targets(targets)[1]
1088 targets = self._build_targets(targets)[1]
1038 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1089 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1039
1090
1040 def direct_view(self, targets='all'):
1091 def direct_view(self, targets='all'):
1041 """construct a DirectView object.
1092 """construct a DirectView object.
1042
1093
1043 If no targets are specified, create a DirectView using all engines.
1094 If no targets are specified, create a DirectView using all engines.
1044
1095
1045 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1096 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1046 evaluate the target engines at each execution, whereas rc[:] will connect to
1097 evaluate the target engines at each execution, whereas rc[:] will connect to
1047 all *current* engines, and that list will not change.
1098 all *current* engines, and that list will not change.
1048
1099
1049 That is, 'all' will always use all engines, whereas rc[:] will not use
1100 That is, 'all' will always use all engines, whereas rc[:] will not use
1050 engines added after the DirectView is constructed.
1101 engines added after the DirectView is constructed.
1051
1102
1052 Parameters
1103 Parameters
1053 ----------
1104 ----------
1054
1105
1055 targets: list,slice,int,etc. [default: use all engines]
1106 targets: list,slice,int,etc. [default: use all engines]
1056 The engines to use for the View
1107 The engines to use for the View
1057 """
1108 """
1058 single = isinstance(targets, int)
1109 single = isinstance(targets, int)
1059 # allow 'all' to be lazily evaluated at each execution
1110 # allow 'all' to be lazily evaluated at each execution
1060 if targets != 'all':
1111 if targets != 'all':
1061 targets = self._build_targets(targets)[1]
1112 targets = self._build_targets(targets)[1]
1062 if single:
1113 if single:
1063 targets = targets[0]
1114 targets = targets[0]
1064 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1115 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1065
1116
1066 #--------------------------------------------------------------------------
1117 #--------------------------------------------------------------------------
1067 # Query methods
1118 # Query methods
1068 #--------------------------------------------------------------------------
1119 #--------------------------------------------------------------------------
1069
1120
1070 @spin_first
1121 @spin_first
1071 def get_result(self, indices_or_msg_ids=None, block=None):
1122 def get_result(self, indices_or_msg_ids=None, block=None):
1072 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1123 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1073
1124
1074 If the client already has the results, no request to the Hub will be made.
1125 If the client already has the results, no request to the Hub will be made.
1075
1126
1076 This is a convenient way to construct AsyncResult objects, which are wrappers
1127 This is a convenient way to construct AsyncResult objects, which are wrappers
1077 that include metadata about execution, and allow for awaiting results that
1128 that include metadata about execution, and allow for awaiting results that
1078 were not submitted by this Client.
1129 were not submitted by this Client.
1079
1130
1080 It can also be a convenient way to retrieve the metadata associated with
1131 It can also be a convenient way to retrieve the metadata associated with
1081 blocking execution, since it always retrieves
1132 blocking execution, since it always retrieves
1082
1133
1083 Examples
1134 Examples
1084 --------
1135 --------
1085 ::
1136 ::
1086
1137
1087 In [10]: r = client.apply()
1138 In [10]: r = client.apply()
1088
1139
1089 Parameters
1140 Parameters
1090 ----------
1141 ----------
1091
1142
1092 indices_or_msg_ids : integer history index, str msg_id, or list of either
1143 indices_or_msg_ids : integer history index, str msg_id, or list of either
1093 The indices or msg_ids of indices to be retrieved
1144 The indices or msg_ids of indices to be retrieved
1094
1145
1095 block : bool
1146 block : bool
1096 Whether to wait for the result to be done
1147 Whether to wait for the result to be done
1097
1148
1098 Returns
1149 Returns
1099 -------
1150 -------
1100
1151
1101 AsyncResult
1152 AsyncResult
1102 A single AsyncResult object will always be returned.
1153 A single AsyncResult object will always be returned.
1103
1154
1104 AsyncHubResult
1155 AsyncHubResult
1105 A subclass of AsyncResult that retrieves results from the Hub
1156 A subclass of AsyncResult that retrieves results from the Hub
1106
1157
1107 """
1158 """
1108 block = self.block if block is None else block
1159 block = self.block if block is None else block
1109 if indices_or_msg_ids is None:
1160 if indices_or_msg_ids is None:
1110 indices_or_msg_ids = -1
1161 indices_or_msg_ids = -1
1111
1162
1112 if not isinstance(indices_or_msg_ids, (list,tuple)):
1163 if not isinstance(indices_or_msg_ids, (list,tuple)):
1113 indices_or_msg_ids = [indices_or_msg_ids]
1164 indices_or_msg_ids = [indices_or_msg_ids]
1114
1165
1115 theids = []
1166 theids = []
1116 for id in indices_or_msg_ids:
1167 for id in indices_or_msg_ids:
1117 if isinstance(id, int):
1168 if isinstance(id, int):
1118 id = self.history[id]
1169 id = self.history[id]
1119 if not isinstance(id, basestring):
1170 if not isinstance(id, basestring):
1120 raise TypeError("indices must be str or int, not %r"%id)
1171 raise TypeError("indices must be str or int, not %r"%id)
1121 theids.append(id)
1172 theids.append(id)
1122
1173
1123 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1174 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1124 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1175 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1125
1176
1126 if remote_ids:
1177 if remote_ids:
1127 ar = AsyncHubResult(self, msg_ids=theids)
1178 ar = AsyncHubResult(self, msg_ids=theids)
1128 else:
1179 else:
1129 ar = AsyncResult(self, msg_ids=theids)
1180 ar = AsyncResult(self, msg_ids=theids)
1130
1181
1131 if block:
1182 if block:
1132 ar.wait()
1183 ar.wait()
1133
1184
1134 return ar
1185 return ar
1135
1186
1136 @spin_first
1187 @spin_first
1137 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1188 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1138 """Resubmit one or more tasks.
1189 """Resubmit one or more tasks.
1139
1190
1140 in-flight tasks may not be resubmitted.
1191 in-flight tasks may not be resubmitted.
1141
1192
1142 Parameters
1193 Parameters
1143 ----------
1194 ----------
1144
1195
1145 indices_or_msg_ids : integer history index, str msg_id, or list of either
1196 indices_or_msg_ids : integer history index, str msg_id, or list of either
1146 The indices or msg_ids of indices to be retrieved
1197 The indices or msg_ids of indices to be retrieved
1147
1198
1148 block : bool
1199 block : bool
1149 Whether to wait for the result to be done
1200 Whether to wait for the result to be done
1150
1201
1151 Returns
1202 Returns
1152 -------
1203 -------
1153
1204
1154 AsyncHubResult
1205 AsyncHubResult
1155 A subclass of AsyncResult that retrieves results from the Hub
1206 A subclass of AsyncResult that retrieves results from the Hub
1156
1207
1157 """
1208 """
1158 block = self.block if block is None else block
1209 block = self.block if block is None else block
1159 if indices_or_msg_ids is None:
1210 if indices_or_msg_ids is None:
1160 indices_or_msg_ids = -1
1211 indices_or_msg_ids = -1
1161
1212
1162 if not isinstance(indices_or_msg_ids, (list,tuple)):
1213 if not isinstance(indices_or_msg_ids, (list,tuple)):
1163 indices_or_msg_ids = [indices_or_msg_ids]
1214 indices_or_msg_ids = [indices_or_msg_ids]
1164
1215
1165 theids = []
1216 theids = []
1166 for id in indices_or_msg_ids:
1217 for id in indices_or_msg_ids:
1167 if isinstance(id, int):
1218 if isinstance(id, int):
1168 id = self.history[id]
1219 id = self.history[id]
1169 if not isinstance(id, basestring):
1220 if not isinstance(id, basestring):
1170 raise TypeError("indices must be str or int, not %r"%id)
1221 raise TypeError("indices must be str or int, not %r"%id)
1171 theids.append(id)
1222 theids.append(id)
1172
1223
1173 for msg_id in theids:
1224 for msg_id in theids:
1174 self.outstanding.discard(msg_id)
1225 self.outstanding.discard(msg_id)
1175 if msg_id in self.history:
1226 if msg_id in self.history:
1176 self.history.remove(msg_id)
1227 self.history.remove(msg_id)
1177 self.results.pop(msg_id, None)
1228 self.results.pop(msg_id, None)
1178 self.metadata.pop(msg_id, None)
1229 self.metadata.pop(msg_id, None)
1179 content = dict(msg_ids = theids)
1230 content = dict(msg_ids = theids)
1180
1231
1181 self.session.send(self._query_socket, 'resubmit_request', content)
1232 self.session.send(self._query_socket, 'resubmit_request', content)
1182
1233
1183 zmq.select([self._query_socket], [], [])
1234 zmq.select([self._query_socket], [], [])
1184 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1235 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1185 if self.debug:
1236 if self.debug:
1186 pprint(msg)
1237 pprint(msg)
1187 content = msg['content']
1238 content = msg['content']
1188 if content['status'] != 'ok':
1239 if content['status'] != 'ok':
1189 raise self._unwrap_exception(content)
1240 raise self._unwrap_exception(content)
1190
1241
1191 ar = AsyncHubResult(self, msg_ids=theids)
1242 ar = AsyncHubResult(self, msg_ids=theids)
1192
1243
1193 if block:
1244 if block:
1194 ar.wait()
1245 ar.wait()
1195
1246
1196 return ar
1247 return ar
1197
1248
1198 @spin_first
1249 @spin_first
1199 def result_status(self, msg_ids, status_only=True):
1250 def result_status(self, msg_ids, status_only=True):
1200 """Check on the status of the result(s) of the apply request with `msg_ids`.
1251 """Check on the status of the result(s) of the apply request with `msg_ids`.
1201
1252
1202 If status_only is False, then the actual results will be retrieved, else
1253 If status_only is False, then the actual results will be retrieved, else
1203 only the status of the results will be checked.
1254 only the status of the results will be checked.
1204
1255
1205 Parameters
1256 Parameters
1206 ----------
1257 ----------
1207
1258
1208 msg_ids : list of msg_ids
1259 msg_ids : list of msg_ids
1209 if int:
1260 if int:
1210 Passed as index to self.history for convenience.
1261 Passed as index to self.history for convenience.
1211 status_only : bool (default: True)
1262 status_only : bool (default: True)
1212 if False:
1263 if False:
1213 Retrieve the actual results of completed tasks.
1264 Retrieve the actual results of completed tasks.
1214
1265
1215 Returns
1266 Returns
1216 -------
1267 -------
1217
1268
1218 results : dict
1269 results : dict
1219 There will always be the keys 'pending' and 'completed', which will
1270 There will always be the keys 'pending' and 'completed', which will
1220 be lists of msg_ids that are incomplete or complete. If `status_only`
1271 be lists of msg_ids that are incomplete or complete. If `status_only`
1221 is False, then completed results will be keyed by their `msg_id`.
1272 is False, then completed results will be keyed by their `msg_id`.
1222 """
1273 """
1223 if not isinstance(msg_ids, (list,tuple)):
1274 if not isinstance(msg_ids, (list,tuple)):
1224 msg_ids = [msg_ids]
1275 msg_ids = [msg_ids]
1225
1276
1226 theids = []
1277 theids = []
1227 for msg_id in msg_ids:
1278 for msg_id in msg_ids:
1228 if isinstance(msg_id, int):
1279 if isinstance(msg_id, int):
1229 msg_id = self.history[msg_id]
1280 msg_id = self.history[msg_id]
1230 if not isinstance(msg_id, basestring):
1281 if not isinstance(msg_id, basestring):
1231 raise TypeError("msg_ids must be str, not %r"%msg_id)
1282 raise TypeError("msg_ids must be str, not %r"%msg_id)
1232 theids.append(msg_id)
1283 theids.append(msg_id)
1233
1284
1234 completed = []
1285 completed = []
1235 local_results = {}
1286 local_results = {}
1236
1287
1237 # comment this block out to temporarily disable local shortcut:
1288 # comment this block out to temporarily disable local shortcut:
1238 for msg_id in theids:
1289 for msg_id in theids:
1239 if msg_id in self.results:
1290 if msg_id in self.results:
1240 completed.append(msg_id)
1291 completed.append(msg_id)
1241 local_results[msg_id] = self.results[msg_id]
1292 local_results[msg_id] = self.results[msg_id]
1242 theids.remove(msg_id)
1293 theids.remove(msg_id)
1243
1294
1244 if theids: # some not locally cached
1295 if theids: # some not locally cached
1245 content = dict(msg_ids=theids, status_only=status_only)
1296 content = dict(msg_ids=theids, status_only=status_only)
1246 msg = self.session.send(self._query_socket, "result_request", content=content)
1297 msg = self.session.send(self._query_socket, "result_request", content=content)
1247 zmq.select([self._query_socket], [], [])
1298 zmq.select([self._query_socket], [], [])
1248 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1299 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1249 if self.debug:
1300 if self.debug:
1250 pprint(msg)
1301 pprint(msg)
1251 content = msg['content']
1302 content = msg['content']
1252 if content['status'] != 'ok':
1303 if content['status'] != 'ok':
1253 raise self._unwrap_exception(content)
1304 raise self._unwrap_exception(content)
1254 buffers = msg['buffers']
1305 buffers = msg['buffers']
1255 else:
1306 else:
1256 content = dict(completed=[],pending=[])
1307 content = dict(completed=[],pending=[])
1257
1308
1258 content['completed'].extend(completed)
1309 content['completed'].extend(completed)
1259
1310
1260 if status_only:
1311 if status_only:
1261 return content
1312 return content
1262
1313
1263 failures = []
1314 failures = []
1264 # load cached results into result:
1315 # load cached results into result:
1265 content.update(local_results)
1316 content.update(local_results)
1266
1317
1267 # update cache with results:
1318 # update cache with results:
1268 for msg_id in sorted(theids):
1319 for msg_id in sorted(theids):
1269 if msg_id in content['completed']:
1320 if msg_id in content['completed']:
1270 rec = content[msg_id]
1321 rec = content[msg_id]
1271 parent = rec['header']
1322 parent = rec['header']
1272 header = rec['result_header']
1323 header = rec['result_header']
1273 rcontent = rec['result_content']
1324 rcontent = rec['result_content']
1274 iodict = rec['io']
1325 iodict = rec['io']
1275 if isinstance(rcontent, str):
1326 if isinstance(rcontent, str):
1276 rcontent = self.session.unpack(rcontent)
1327 rcontent = self.session.unpack(rcontent)
1277
1328
1278 md = self.metadata[msg_id]
1329 md = self.metadata[msg_id]
1279 md.update(self._extract_metadata(header, parent, rcontent))
1330 md.update(self._extract_metadata(header, parent, rcontent))
1280 if rec.get('received'):
1331 if rec.get('received'):
1281 md['received'] = rec['received']
1332 md['received'] = rec['received']
1282 md.update(iodict)
1333 md.update(iodict)
1283
1334
1284 if rcontent['status'] == 'ok':
1335 if rcontent['status'] == 'ok':
1285 res,buffers = util.unserialize_object(buffers)
1336 res,buffers = util.unserialize_object(buffers)
1286 else:
1337 else:
1287 print rcontent
1338 print rcontent
1288 res = self._unwrap_exception(rcontent)
1339 res = self._unwrap_exception(rcontent)
1289 failures.append(res)
1340 failures.append(res)
1290
1341
1291 self.results[msg_id] = res
1342 self.results[msg_id] = res
1292 content[msg_id] = res
1343 content[msg_id] = res
1293
1344
1294 if len(theids) == 1 and failures:
1345 if len(theids) == 1 and failures:
1295 raise failures[0]
1346 raise failures[0]
1296
1347
1297 error.collect_exceptions(failures, "result_status")
1348 error.collect_exceptions(failures, "result_status")
1298 return content
1349 return content
1299
1350
1300 @spin_first
1351 @spin_first
1301 def queue_status(self, targets='all', verbose=False):
1352 def queue_status(self, targets='all', verbose=False):
1302 """Fetch the status of engine queues.
1353 """Fetch the status of engine queues.
1303
1354
1304 Parameters
1355 Parameters
1305 ----------
1356 ----------
1306
1357
1307 targets : int/str/list of ints/strs
1358 targets : int/str/list of ints/strs
1308 the engines whose states are to be queried.
1359 the engines whose states are to be queried.
1309 default : all
1360 default : all
1310 verbose : bool
1361 verbose : bool
1311 Whether to return lengths only, or lists of ids for each element
1362 Whether to return lengths only, or lists of ids for each element
1312 """
1363 """
1313 if targets == 'all':
1364 if targets == 'all':
1314 # allow 'all' to be evaluated on the engine
1365 # allow 'all' to be evaluated on the engine
1315 engine_ids = None
1366 engine_ids = None
1316 else:
1367 else:
1317 engine_ids = self._build_targets(targets)[1]
1368 engine_ids = self._build_targets(targets)[1]
1318 content = dict(targets=engine_ids, verbose=verbose)
1369 content = dict(targets=engine_ids, verbose=verbose)
1319 self.session.send(self._query_socket, "queue_request", content=content)
1370 self.session.send(self._query_socket, "queue_request", content=content)
1320 idents,msg = self.session.recv(self._query_socket, 0)
1371 idents,msg = self.session.recv(self._query_socket, 0)
1321 if self.debug:
1372 if self.debug:
1322 pprint(msg)
1373 pprint(msg)
1323 content = msg['content']
1374 content = msg['content']
1324 status = content.pop('status')
1375 status = content.pop('status')
1325 if status != 'ok':
1376 if status != 'ok':
1326 raise self._unwrap_exception(content)
1377 raise self._unwrap_exception(content)
1327 content = rekey(content)
1378 content = rekey(content)
1328 if isinstance(targets, int):
1379 if isinstance(targets, int):
1329 return content[targets]
1380 return content[targets]
1330 else:
1381 else:
1331 return content
1382 return content
1332
1383
1333 @spin_first
1384 @spin_first
1334 def purge_results(self, jobs=[], targets=[]):
1385 def purge_results(self, jobs=[], targets=[]):
1335 """Tell the Hub to forget results.
1386 """Tell the Hub to forget results.
1336
1387
1337 Individual results can be purged by msg_id, or the entire
1388 Individual results can be purged by msg_id, or the entire
1338 history of specific targets can be purged.
1389 history of specific targets can be purged.
1339
1390
1340 Use `purge_results('all')` to scrub everything from the Hub's db.
1391 Use `purge_results('all')` to scrub everything from the Hub's db.
1341
1392
1342 Parameters
1393 Parameters
1343 ----------
1394 ----------
1344
1395
1345 jobs : str or list of str or AsyncResult objects
1396 jobs : str or list of str or AsyncResult objects
1346 the msg_ids whose results should be forgotten.
1397 the msg_ids whose results should be forgotten.
1347 targets : int/str/list of ints/strs
1398 targets : int/str/list of ints/strs
1348 The targets, by int_id, whose entire history is to be purged.
1399 The targets, by int_id, whose entire history is to be purged.
1349
1400
1350 default : None
1401 default : None
1351 """
1402 """
1352 if not targets and not jobs:
1403 if not targets and not jobs:
1353 raise ValueError("Must specify at least one of `targets` and `jobs`")
1404 raise ValueError("Must specify at least one of `targets` and `jobs`")
1354 if targets:
1405 if targets:
1355 targets = self._build_targets(targets)[1]
1406 targets = self._build_targets(targets)[1]
1356
1407
1357 # construct msg_ids from jobs
1408 # construct msg_ids from jobs
1358 if jobs == 'all':
1409 if jobs == 'all':
1359 msg_ids = jobs
1410 msg_ids = jobs
1360 else:
1411 else:
1361 msg_ids = []
1412 msg_ids = []
1362 if isinstance(jobs, (basestring,AsyncResult)):
1413 if isinstance(jobs, (basestring,AsyncResult)):
1363 jobs = [jobs]
1414 jobs = [jobs]
1364 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1415 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1365 if bad_ids:
1416 if bad_ids:
1366 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1417 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1367 for j in jobs:
1418 for j in jobs:
1368 if isinstance(j, AsyncResult):
1419 if isinstance(j, AsyncResult):
1369 msg_ids.extend(j.msg_ids)
1420 msg_ids.extend(j.msg_ids)
1370 else:
1421 else:
1371 msg_ids.append(j)
1422 msg_ids.append(j)
1372
1423
1373 content = dict(engine_ids=targets, msg_ids=msg_ids)
1424 content = dict(engine_ids=targets, msg_ids=msg_ids)
1374 self.session.send(self._query_socket, "purge_request", content=content)
1425 self.session.send(self._query_socket, "purge_request", content=content)
1375 idents, msg = self.session.recv(self._query_socket, 0)
1426 idents, msg = self.session.recv(self._query_socket, 0)
1376 if self.debug:
1427 if self.debug:
1377 pprint(msg)
1428 pprint(msg)
1378 content = msg['content']
1429 content = msg['content']
1379 if content['status'] != 'ok':
1430 if content['status'] != 'ok':
1380 raise self._unwrap_exception(content)
1431 raise self._unwrap_exception(content)
1381
1432
1382 @spin_first
1433 @spin_first
1383 def hub_history(self):
1434 def hub_history(self):
1384 """Get the Hub's history
1435 """Get the Hub's history
1385
1436
1386 Just like the Client, the Hub has a history, which is a list of msg_ids.
1437 Just like the Client, the Hub has a history, which is a list of msg_ids.
1387 This will contain the history of all clients, and, depending on configuration,
1438 This will contain the history of all clients, and, depending on configuration,
1388 may contain history across multiple cluster sessions.
1439 may contain history across multiple cluster sessions.
1389
1440
1390 Any msg_id returned here is a valid argument to `get_result`.
1441 Any msg_id returned here is a valid argument to `get_result`.
1391
1442
1392 Returns
1443 Returns
1393 -------
1444 -------
1394
1445
1395 msg_ids : list of strs
1446 msg_ids : list of strs
1396 list of all msg_ids, ordered by task submission time.
1447 list of all msg_ids, ordered by task submission time.
1397 """
1448 """
1398
1449
1399 self.session.send(self._query_socket, "history_request", content={})
1450 self.session.send(self._query_socket, "history_request", content={})
1400 idents, msg = self.session.recv(self._query_socket, 0)
1451 idents, msg = self.session.recv(self._query_socket, 0)
1401
1452
1402 if self.debug:
1453 if self.debug:
1403 pprint(msg)
1454 pprint(msg)
1404 content = msg['content']
1455 content = msg['content']
1405 if content['status'] != 'ok':
1456 if content['status'] != 'ok':
1406 raise self._unwrap_exception(content)
1457 raise self._unwrap_exception(content)
1407 else:
1458 else:
1408 return content['history']
1459 return content['history']
1409
1460
1410 @spin_first
1461 @spin_first
1411 def db_query(self, query, keys=None):
1462 def db_query(self, query, keys=None):
1412 """Query the Hub's TaskRecord database
1463 """Query the Hub's TaskRecord database
1413
1464
1414 This will return a list of task record dicts that match `query`
1465 This will return a list of task record dicts that match `query`
1415
1466
1416 Parameters
1467 Parameters
1417 ----------
1468 ----------
1418
1469
1419 query : mongodb query dict
1470 query : mongodb query dict
1420 The search dict. See mongodb query docs for details.
1471 The search dict. See mongodb query docs for details.
1421 keys : list of strs [optional]
1472 keys : list of strs [optional]
1422 The subset of keys to be returned. The default is to fetch everything but buffers.
1473 The subset of keys to be returned. The default is to fetch everything but buffers.
1423 'msg_id' will *always* be included.
1474 'msg_id' will *always* be included.
1424 """
1475 """
1425 if isinstance(keys, basestring):
1476 if isinstance(keys, basestring):
1426 keys = [keys]
1477 keys = [keys]
1427 content = dict(query=query, keys=keys)
1478 content = dict(query=query, keys=keys)
1428 self.session.send(self._query_socket, "db_request", content=content)
1479 self.session.send(self._query_socket, "db_request", content=content)
1429 idents, msg = self.session.recv(self._query_socket, 0)
1480 idents, msg = self.session.recv(self._query_socket, 0)
1430 if self.debug:
1481 if self.debug:
1431 pprint(msg)
1482 pprint(msg)
1432 content = msg['content']
1483 content = msg['content']
1433 if content['status'] != 'ok':
1484 if content['status'] != 'ok':
1434 raise self._unwrap_exception(content)
1485 raise self._unwrap_exception(content)
1435
1486
1436 records = content['records']
1487 records = content['records']
1437
1488
1438 buffer_lens = content['buffer_lens']
1489 buffer_lens = content['buffer_lens']
1439 result_buffer_lens = content['result_buffer_lens']
1490 result_buffer_lens = content['result_buffer_lens']
1440 buffers = msg['buffers']
1491 buffers = msg['buffers']
1441 has_bufs = buffer_lens is not None
1492 has_bufs = buffer_lens is not None
1442 has_rbufs = result_buffer_lens is not None
1493 has_rbufs = result_buffer_lens is not None
1443 for i,rec in enumerate(records):
1494 for i,rec in enumerate(records):
1444 # relink buffers
1495 # relink buffers
1445 if has_bufs:
1496 if has_bufs:
1446 blen = buffer_lens[i]
1497 blen = buffer_lens[i]
1447 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1498 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1448 if has_rbufs:
1499 if has_rbufs:
1449 blen = result_buffer_lens[i]
1500 blen = result_buffer_lens[i]
1450 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1501 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1451
1502
1452 return records
1503 return records
1453
1504
# only Client is public API; everything else in this module is internal
__all__ = [ 'Client' ]
@@ -1,319 +1,337 b''
1 """Tests for parallel client.py
1 """Tests for parallel client.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import time
21 import time
22 from datetime import datetime
22 from datetime import datetime
23 from tempfile import mktemp
23 from tempfile import mktemp
24
24
25 import zmq
25 import zmq
26
26
27 from IPython.parallel.client import client as clientmod
27 from IPython.parallel.client import client as clientmod
28 from IPython.parallel import error
28 from IPython.parallel import error
29 from IPython.parallel import AsyncResult, AsyncHubResult
29 from IPython.parallel import AsyncResult, AsyncHubResult
30 from IPython.parallel import LoadBalancedView, DirectView
30 from IPython.parallel import LoadBalancedView, DirectView
31
31
32 from clienttest import ClusterTestCase, segfault, wait, add_engines
32 from clienttest import ClusterTestCase, segfault, wait, add_engines
33
33
34 def setup():
34 def setup():
35 add_engines(4, total=True)
35 add_engines(4, total=True)
36
36
37 class TestClient(ClusterTestCase):
37 class TestClient(ClusterTestCase):
38
38
39 def test_ids(self):
39 def test_ids(self):
40 n = len(self.client.ids)
40 n = len(self.client.ids)
41 self.add_engines(2)
41 self.add_engines(2)
42 self.assertEquals(len(self.client.ids), n+2)
42 self.assertEquals(len(self.client.ids), n+2)
43
43
44 def test_view_indexing(self):
44 def test_view_indexing(self):
45 """test index access for views"""
45 """test index access for views"""
46 self.minimum_engines(4)
46 self.minimum_engines(4)
47 targets = self.client._build_targets('all')[-1]
47 targets = self.client._build_targets('all')[-1]
48 v = self.client[:]
48 v = self.client[:]
49 self.assertEquals(v.targets, targets)
49 self.assertEquals(v.targets, targets)
50 t = self.client.ids[2]
50 t = self.client.ids[2]
51 v = self.client[t]
51 v = self.client[t]
52 self.assert_(isinstance(v, DirectView))
52 self.assert_(isinstance(v, DirectView))
53 self.assertEquals(v.targets, t)
53 self.assertEquals(v.targets, t)
54 t = self.client.ids[2:4]
54 t = self.client.ids[2:4]
55 v = self.client[t]
55 v = self.client[t]
56 self.assert_(isinstance(v, DirectView))
56 self.assert_(isinstance(v, DirectView))
57 self.assertEquals(v.targets, t)
57 self.assertEquals(v.targets, t)
58 v = self.client[::2]
58 v = self.client[::2]
59 self.assert_(isinstance(v, DirectView))
59 self.assert_(isinstance(v, DirectView))
60 self.assertEquals(v.targets, targets[::2])
60 self.assertEquals(v.targets, targets[::2])
61 v = self.client[1::3]
61 v = self.client[1::3]
62 self.assert_(isinstance(v, DirectView))
62 self.assert_(isinstance(v, DirectView))
63 self.assertEquals(v.targets, targets[1::3])
63 self.assertEquals(v.targets, targets[1::3])
64 v = self.client[:-3]
64 v = self.client[:-3]
65 self.assert_(isinstance(v, DirectView))
65 self.assert_(isinstance(v, DirectView))
66 self.assertEquals(v.targets, targets[:-3])
66 self.assertEquals(v.targets, targets[:-3])
67 v = self.client[-1]
67 v = self.client[-1]
68 self.assert_(isinstance(v, DirectView))
68 self.assert_(isinstance(v, DirectView))
69 self.assertEquals(v.targets, targets[-1])
69 self.assertEquals(v.targets, targets[-1])
70 self.assertRaises(TypeError, lambda : self.client[None])
70 self.assertRaises(TypeError, lambda : self.client[None])
71
71
72 def test_lbview_targets(self):
72 def test_lbview_targets(self):
73 """test load_balanced_view targets"""
73 """test load_balanced_view targets"""
74 v = self.client.load_balanced_view()
74 v = self.client.load_balanced_view()
75 self.assertEquals(v.targets, None)
75 self.assertEquals(v.targets, None)
76 v = self.client.load_balanced_view(-1)
76 v = self.client.load_balanced_view(-1)
77 self.assertEquals(v.targets, [self.client.ids[-1]])
77 self.assertEquals(v.targets, [self.client.ids[-1]])
78 v = self.client.load_balanced_view('all')
78 v = self.client.load_balanced_view('all')
79 self.assertEquals(v.targets, None)
79 self.assertEquals(v.targets, None)
80
80
81 def test_dview_targets(self):
81 def test_dview_targets(self):
82 """test direct_view targets"""
82 """test direct_view targets"""
83 v = self.client.direct_view()
83 v = self.client.direct_view()
84 self.assertEquals(v.targets, 'all')
84 self.assertEquals(v.targets, 'all')
85 v = self.client.direct_view('all')
85 v = self.client.direct_view('all')
86 self.assertEquals(v.targets, 'all')
86 self.assertEquals(v.targets, 'all')
87 v = self.client.direct_view(-1)
87 v = self.client.direct_view(-1)
88 self.assertEquals(v.targets, self.client.ids[-1])
88 self.assertEquals(v.targets, self.client.ids[-1])
89
89
90 def test_lazy_all_targets(self):
90 def test_lazy_all_targets(self):
91 """test lazy evaluation of rc.direct_view('all')"""
91 """test lazy evaluation of rc.direct_view('all')"""
92 v = self.client.direct_view()
92 v = self.client.direct_view()
93 self.assertEquals(v.targets, 'all')
93 self.assertEquals(v.targets, 'all')
94
94
95 def double(x):
95 def double(x):
96 return x*2
96 return x*2
97 seq = range(100)
97 seq = range(100)
98 ref = [ double(x) for x in seq ]
98 ref = [ double(x) for x in seq ]
99
99
100 # add some engines, which should be used
100 # add some engines, which should be used
101 self.add_engines(1)
101 self.add_engines(1)
102 n1 = len(self.client.ids)
102 n1 = len(self.client.ids)
103
103
104 # simple apply
104 # simple apply
105 r = v.apply_sync(lambda : 1)
105 r = v.apply_sync(lambda : 1)
106 self.assertEquals(r, [1] * n1)
106 self.assertEquals(r, [1] * n1)
107
107
108 # map goes through remotefunction
108 # map goes through remotefunction
109 r = v.map_sync(double, seq)
109 r = v.map_sync(double, seq)
110 self.assertEquals(r, ref)
110 self.assertEquals(r, ref)
111
111
112 # add a couple more engines, and try again
112 # add a couple more engines, and try again
113 self.add_engines(2)
113 self.add_engines(2)
114 n2 = len(self.client.ids)
114 n2 = len(self.client.ids)
115 self.assertNotEquals(n2, n1)
115 self.assertNotEquals(n2, n1)
116
116
117 # apply
117 # apply
118 r = v.apply_sync(lambda : 1)
118 r = v.apply_sync(lambda : 1)
119 self.assertEquals(r, [1] * n2)
119 self.assertEquals(r, [1] * n2)
120
120
121 # map
121 # map
122 r = v.map_sync(double, seq)
122 r = v.map_sync(double, seq)
123 self.assertEquals(r, ref)
123 self.assertEquals(r, ref)
124
124
125 def test_targets(self):
125 def test_targets(self):
126 """test various valid targets arguments"""
126 """test various valid targets arguments"""
127 build = self.client._build_targets
127 build = self.client._build_targets
128 ids = self.client.ids
128 ids = self.client.ids
129 idents,targets = build(None)
129 idents,targets = build(None)
130 self.assertEquals(ids, targets)
130 self.assertEquals(ids, targets)
131
131
132 def test_clear(self):
132 def test_clear(self):
133 """test clear behavior"""
133 """test clear behavior"""
134 self.minimum_engines(2)
134 self.minimum_engines(2)
135 v = self.client[:]
135 v = self.client[:]
136 v.block=True
136 v.block=True
137 v.push(dict(a=5))
137 v.push(dict(a=5))
138 v.pull('a')
138 v.pull('a')
139 id0 = self.client.ids[-1]
139 id0 = self.client.ids[-1]
140 self.client.clear(targets=id0, block=True)
140 self.client.clear(targets=id0, block=True)
141 a = self.client[:-1].get('a')
141 a = self.client[:-1].get('a')
142 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
142 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
143 self.client.clear(block=True)
143 self.client.clear(block=True)
144 for i in self.client.ids:
144 for i in self.client.ids:
145 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
145 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
146
146
147 def test_get_result(self):
147 def test_get_result(self):
148 """test getting results from the Hub."""
148 """test getting results from the Hub."""
149 c = clientmod.Client(profile='iptest')
149 c = clientmod.Client(profile='iptest')
150 t = c.ids[-1]
150 t = c.ids[-1]
151 ar = c[t].apply_async(wait, 1)
151 ar = c[t].apply_async(wait, 1)
152 # give the monitor time to notice the message
152 # give the monitor time to notice the message
153 time.sleep(.25)
153 time.sleep(.25)
154 ahr = self.client.get_result(ar.msg_ids)
154 ahr = self.client.get_result(ar.msg_ids)
155 self.assertTrue(isinstance(ahr, AsyncHubResult))
155 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 self.assertEquals(ahr.get(), ar.get())
156 self.assertEquals(ahr.get(), ar.get())
157 ar2 = self.client.get_result(ar.msg_ids)
157 ar2 = self.client.get_result(ar.msg_ids)
158 self.assertFalse(isinstance(ar2, AsyncHubResult))
158 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 c.close()
159 c.close()
160
160
161 def test_ids_list(self):
161 def test_ids_list(self):
162 """test client.ids"""
162 """test client.ids"""
163 ids = self.client.ids
163 ids = self.client.ids
164 self.assertEquals(ids, self.client._ids)
164 self.assertEquals(ids, self.client._ids)
165 self.assertFalse(ids is self.client._ids)
165 self.assertFalse(ids is self.client._ids)
166 ids.remove(ids[-1])
166 ids.remove(ids[-1])
167 self.assertNotEquals(ids, self.client._ids)
167 self.assertNotEquals(ids, self.client._ids)
168
168
169 def test_queue_status(self):
169 def test_queue_status(self):
170 ids = self.client.ids
170 ids = self.client.ids
171 id0 = ids[0]
171 id0 = ids[0]
172 qs = self.client.queue_status(targets=id0)
172 qs = self.client.queue_status(targets=id0)
173 self.assertTrue(isinstance(qs, dict))
173 self.assertTrue(isinstance(qs, dict))
174 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
174 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
175 allqs = self.client.queue_status()
175 allqs = self.client.queue_status()
176 self.assertTrue(isinstance(allqs, dict))
176 self.assertTrue(isinstance(allqs, dict))
177 intkeys = list(allqs.keys())
177 intkeys = list(allqs.keys())
178 intkeys.remove('unassigned')
178 intkeys.remove('unassigned')
179 self.assertEquals(sorted(intkeys), sorted(self.client.ids))
179 self.assertEquals(sorted(intkeys), sorted(self.client.ids))
180 unassigned = allqs.pop('unassigned')
180 unassigned = allqs.pop('unassigned')
181 for eid,qs in allqs.items():
181 for eid,qs in allqs.items():
182 self.assertTrue(isinstance(qs, dict))
182 self.assertTrue(isinstance(qs, dict))
183 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
183 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
184
184
185 def test_shutdown(self):
185 def test_shutdown(self):
186 ids = self.client.ids
186 ids = self.client.ids
187 id0 = ids[0]
187 id0 = ids[0]
188 self.client.shutdown(id0, block=True)
188 self.client.shutdown(id0, block=True)
189 while id0 in self.client.ids:
189 while id0 in self.client.ids:
190 time.sleep(0.1)
190 time.sleep(0.1)
191 self.client.spin()
191 self.client.spin()
192
192
193 self.assertRaises(IndexError, lambda : self.client[id0])
193 self.assertRaises(IndexError, lambda : self.client[id0])
194
194
195 def test_result_status(self):
195 def test_result_status(self):
196 pass
196 pass
197 # to be written
197 # to be written
198
198
199 def test_db_query_dt(self):
199 def test_db_query_dt(self):
200 """test db query by date"""
200 """test db query by date"""
201 hist = self.client.hub_history()
201 hist = self.client.hub_history()
202 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
202 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
203 tic = middle['submitted']
203 tic = middle['submitted']
204 before = self.client.db_query({'submitted' : {'$lt' : tic}})
204 before = self.client.db_query({'submitted' : {'$lt' : tic}})
205 after = self.client.db_query({'submitted' : {'$gte' : tic}})
205 after = self.client.db_query({'submitted' : {'$gte' : tic}})
206 self.assertEquals(len(before)+len(after),len(hist))
206 self.assertEquals(len(before)+len(after),len(hist))
207 for b in before:
207 for b in before:
208 self.assertTrue(b['submitted'] < tic)
208 self.assertTrue(b['submitted'] < tic)
209 for a in after:
209 for a in after:
210 self.assertTrue(a['submitted'] >= tic)
210 self.assertTrue(a['submitted'] >= tic)
211 same = self.client.db_query({'submitted' : tic})
211 same = self.client.db_query({'submitted' : tic})
212 for s in same:
212 for s in same:
213 self.assertTrue(s['submitted'] == tic)
213 self.assertTrue(s['submitted'] == tic)
214
214
215 def test_db_query_keys(self):
215 def test_db_query_keys(self):
216 """test extracting subset of record keys"""
216 """test extracting subset of record keys"""
217 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
217 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
218 for rec in found:
218 for rec in found:
219 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
219 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
220
220
221 def test_db_query_default_keys(self):
221 def test_db_query_default_keys(self):
222 """default db_query excludes buffers"""
222 """default db_query excludes buffers"""
223 found = self.client.db_query({'msg_id': {'$ne' : ''}})
223 found = self.client.db_query({'msg_id': {'$ne' : ''}})
224 for rec in found:
224 for rec in found:
225 keys = set(rec.keys())
225 keys = set(rec.keys())
226 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
226 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
227 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
227 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
228
228
229 def test_db_query_msg_id(self):
229 def test_db_query_msg_id(self):
230 """ensure msg_id is always in db queries"""
230 """ensure msg_id is always in db queries"""
231 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
231 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
232 for rec in found:
232 for rec in found:
233 self.assertTrue('msg_id' in rec.keys())
233 self.assertTrue('msg_id' in rec.keys())
234 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
234 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
235 for rec in found:
235 for rec in found:
236 self.assertTrue('msg_id' in rec.keys())
236 self.assertTrue('msg_id' in rec.keys())
237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
238 for rec in found:
238 for rec in found:
239 self.assertTrue('msg_id' in rec.keys())
239 self.assertTrue('msg_id' in rec.keys())
240
240
241 def test_db_query_in(self):
241 def test_db_query_in(self):
242 """test db query with '$in','$nin' operators"""
242 """test db query with '$in','$nin' operators"""
243 hist = self.client.hub_history()
243 hist = self.client.hub_history()
244 even = hist[::2]
244 even = hist[::2]
245 odd = hist[1::2]
245 odd = hist[1::2]
246 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
246 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
247 found = [ r['msg_id'] for r in recs ]
247 found = [ r['msg_id'] for r in recs ]
248 self.assertEquals(set(even), set(found))
248 self.assertEquals(set(even), set(found))
249 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
249 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
250 found = [ r['msg_id'] for r in recs ]
250 found = [ r['msg_id'] for r in recs ]
251 self.assertEquals(set(odd), set(found))
251 self.assertEquals(set(odd), set(found))
252
252
253 def test_hub_history(self):
253 def test_hub_history(self):
254 hist = self.client.hub_history()
254 hist = self.client.hub_history()
255 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
255 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
256 recdict = {}
256 recdict = {}
257 for rec in recs:
257 for rec in recs:
258 recdict[rec['msg_id']] = rec
258 recdict[rec['msg_id']] = rec
259
259
260 latest = datetime(1984,1,1)
260 latest = datetime(1984,1,1)
261 for msg_id in hist:
261 for msg_id in hist:
262 rec = recdict[msg_id]
262 rec = recdict[msg_id]
263 newt = rec['submitted']
263 newt = rec['submitted']
264 self.assertTrue(newt >= latest)
264 self.assertTrue(newt >= latest)
265 latest = newt
265 latest = newt
266 ar = self.client[-1].apply_async(lambda : 1)
266 ar = self.client[-1].apply_async(lambda : 1)
267 ar.get()
267 ar.get()
268 time.sleep(0.25)
268 time.sleep(0.25)
269 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
269 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
270
270
271 def test_resubmit(self):
271 def test_resubmit(self):
272 def f():
272 def f():
273 import random
273 import random
274 return random.random()
274 return random.random()
275 v = self.client.load_balanced_view()
275 v = self.client.load_balanced_view()
276 ar = v.apply_async(f)
276 ar = v.apply_async(f)
277 r1 = ar.get(1)
277 r1 = ar.get(1)
278 # give the Hub a chance to notice:
278 # give the Hub a chance to notice:
279 time.sleep(0.5)
279 time.sleep(0.5)
280 ahr = self.client.resubmit(ar.msg_ids)
280 ahr = self.client.resubmit(ar.msg_ids)
281 r2 = ahr.get(1)
281 r2 = ahr.get(1)
282 self.assertFalse(r1 == r2)
282 self.assertFalse(r1 == r2)
283
283
284 def test_resubmit_inflight(self):
284 def test_resubmit_inflight(self):
285 """ensure ValueError on resubmit of inflight task"""
285 """ensure ValueError on resubmit of inflight task"""
286 v = self.client.load_balanced_view()
286 v = self.client.load_balanced_view()
287 ar = v.apply_async(time.sleep,1)
287 ar = v.apply_async(time.sleep,1)
288 # give the message a chance to arrive
288 # give the message a chance to arrive
289 time.sleep(0.2)
289 time.sleep(0.2)
290 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
290 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
291 ar.get(2)
291 ar.get(2)
292
292
293 def test_resubmit_badkey(self):
293 def test_resubmit_badkey(self):
294 """ensure KeyError on resubmit of nonexistant task"""
294 """ensure KeyError on resubmit of nonexistant task"""
295 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
295 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
296
296
297 def test_purge_results(self):
297 def test_purge_results(self):
298 # ensure there are some tasks
298 # ensure there are some tasks
299 for i in range(5):
299 for i in range(5):
300 self.client[:].apply_sync(lambda : 1)
300 self.client[:].apply_sync(lambda : 1)
301 # Wait for the Hub to realise the result is done:
301 # Wait for the Hub to realise the result is done:
302 # This prevents a race condition, where we
302 # This prevents a race condition, where we
303 # might purge a result the Hub still thinks is pending.
303 # might purge a result the Hub still thinks is pending.
304 time.sleep(0.1)
304 time.sleep(0.1)
305 rc2 = clientmod.Client(profile='iptest')
305 rc2 = clientmod.Client(profile='iptest')
306 hist = self.client.hub_history()
306 hist = self.client.hub_history()
307 ahr = rc2.get_result([hist[-1]])
307 ahr = rc2.get_result([hist[-1]])
308 ahr.wait(10)
308 ahr.wait(10)
309 self.client.purge_results(hist[-1])
309 self.client.purge_results(hist[-1])
310 newhist = self.client.hub_history()
310 newhist = self.client.hub_history()
311 self.assertEquals(len(newhist)+1,len(hist))
311 self.assertEquals(len(newhist)+1,len(hist))
312 rc2.spin()
312 rc2.spin()
313 rc2.close()
313 rc2.close()
314
314
315 def test_purge_all_results(self):
315 def test_purge_all_results(self):
316 self.client.purge_results('all')
316 self.client.purge_results('all')
317 hist = self.client.hub_history()
317 hist = self.client.hub_history()
318 self.assertEquals(len(hist), 0)
318 self.assertEquals(len(hist), 0)
319
320 def test_spin_thread(self):
321 self.client.spin_thread(0.01)
322 ar = self.client[-1].apply_async(lambda : 1)
323 time.sleep(0.1)
324 self.assertTrue(ar.wall_time < 0.1,
325 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
326 )
327
328 def test_stop_spin_thread(self):
329 self.client.spin_thread(0.01)
330 self.client.stop_spin_thread()
331 ar = self.client[-1].apply_async(lambda : 1)
332 time.sleep(0.15)
333 self.assertTrue(ar.wall_time > 0.1,
334 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
335 )
336
319
337
General Comments 0
You need to be logged in to leave comments. Login now