view.abort() aborts all outstanding tasks...
MinRK
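The change in one line: `abort()` now defaults `jobs` to all outstanding msg_ids, so a bare `view.abort()` / `client.abort()` cancels everything still queued. A minimal usage sketch (assumes a running ipcontroller with registered engines; the view calls are the standard IPython.parallel API and are not part of this patch):

    from IPython.parallel import Client

    rc = Client()                       # connect via the default profile's connection file
    v = rc.load_balanced_view()         # task scheduler view; rc[:] would give a DirectView
    ars = [v.apply_async(pow, 2, i) for i in range(10)]   # queue some work
    v.abort()                           # no jobs given -> abort all outstanding jobs
    rc.wait()                           # aborted tasks resolve as error.TaskAborted results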
@@ -1,1440 +1,1443 @@
1 """A semi-synchronous Client for the ZMQ cluster
1 """A semi-synchronous Client for the ZMQ cluster
2
2
3 Authors:
3 Authors:
4
4
5 * MinRK
5 * MinRK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import os
18 import os
19 import json
19 import json
20 import sys
20 import sys
21 import time
21 import time
22 import warnings
22 import warnings
23 from datetime import datetime
23 from datetime import datetime
24 from getpass import getpass
24 from getpass import getpass
25 from pprint import pprint
25 from pprint import pprint
26
26
27 pjoin = os.path.join
27 pjoin = os.path.join
28
28
29 import zmq
29 import zmq
30 # from zmq.eventloop import ioloop, zmqstream
30 # from zmq.eventloop import ioloop, zmqstream
31
31
32 from IPython.config.configurable import MultipleInstanceError
32 from IPython.config.configurable import MultipleInstanceError
33 from IPython.core.application import BaseIPythonApplication
33 from IPython.core.application import BaseIPythonApplication
34
34
35 from IPython.utils.jsonutil import rekey
35 from IPython.utils.jsonutil import rekey
36 from IPython.utils.localinterfaces import LOCAL_IPS
36 from IPython.utils.localinterfaces import LOCAL_IPS
37 from IPython.utils.path import get_ipython_dir
37 from IPython.utils.path import get_ipython_dir
38 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
38 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
39 Dict, List, Bool, Set)
39 Dict, List, Bool, Set)
40 from IPython.external.decorator import decorator
40 from IPython.external.decorator import decorator
41 from IPython.external.ssh import tunnel
41 from IPython.external.ssh import tunnel
42
42
43 from IPython.parallel import error
43 from IPython.parallel import error
44 from IPython.parallel import util
44 from IPython.parallel import util
45
45
46 from IPython.zmq.session import Session, Message
46 from IPython.zmq.session import Session, Message
47
47
48 from .asyncresult import AsyncResult, AsyncHubResult
48 from .asyncresult import AsyncResult, AsyncHubResult
49 from IPython.core.profiledir import ProfileDir, ProfileDirError
49 from IPython.core.profiledir import ProfileDir, ProfileDirError
50 from .view import DirectView, LoadBalancedView
50 from .view import DirectView, LoadBalancedView
51
51
52 if sys.version_info[0] >= 3:
52 if sys.version_info[0] >= 3:
53 # xrange is used in a couple 'isinstance' tests in py2
53 # xrange is used in a couple 'isinstance' tests in py2
54 # should be just 'range' in 3k
54 # should be just 'range' in 3k
55 xrange = range
55 xrange = range
56
56
57 #--------------------------------------------------------------------------
57 #--------------------------------------------------------------------------
58 # Decorators for Client methods
58 # Decorators for Client methods
59 #--------------------------------------------------------------------------
59 #--------------------------------------------------------------------------
60
60
61 @decorator
61 @decorator
62 def spin_first(f, self, *args, **kwargs):
62 def spin_first(f, self, *args, **kwargs):
63 """Call spin() to sync state prior to calling the method."""
63 """Call spin() to sync state prior to calling the method."""
64 self.spin()
64 self.spin()
65 return f(self, *args, **kwargs)
65 return f(self, *args, **kwargs)
66
66
67
67
68 #--------------------------------------------------------------------------
68 #--------------------------------------------------------------------------
69 # Classes
69 # Classes
70 #--------------------------------------------------------------------------
70 #--------------------------------------------------------------------------
71
71
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
            }
        self.update(md)
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        if key in self.iterkeys():
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        if key in self.iterkeys():
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self.iterkeys():
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)


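# Illustrative sketch of Metadata's strict-key behaviour (comment only, not
# part of the original module):
#
#   md = Metadata(status='ok')
#   md.status            # 'ok'  -- attribute access reads the fixed keys
#   md.stdout            # ''    -- stream keys default to empty strings
#   md['bogus'] = 1      # raises KeyError: the key set is fixed
#   md.bogus = 1         # raises AttributeError for the same reason
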
class Client(HasTraits):
    """A semi-synchronous client to the IPython ZMQ cluster

    Parameters
    ----------

    url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
        Connection information for the Hub's registration. If a json connector
        file is given, then likely no further configuration is necessary.
        [Default: use profile]
    profile : bytes
        The name of the Cluster profile to be used to find connector information.
        If run from an IPython application, the default profile will be the same
        as the running application, otherwise it will be 'default'.
    context : zmq.Context
        Pass an existing zmq.Context instance, otherwise the client will create its own.
    debug : bool
        flag for lots of message printing for debug purposes
    timeout : int/float
        time (in seconds) to wait for connection replies from the Hub
        [Default: 10]

    #-------------- session related args ----------------

    config : Config object
        If specified, this will be relayed to the Session for configuration
    username : str
        set username for the session object
    packer : str (import_string) or callable
        Can be either the simple keyword 'json' or 'pickle', or an import_string to a
        function to serialize messages. Must support same input as
        JSON, and output must be bytes.
        You can pass a callable directly as `pack`
    unpacker : str (import_string) or callable
        The inverse of packer. Only necessary if packer is specified as *not* one
        of 'json' or 'pickle'.

    #-------------- ssh related args ----------------
    # These are args for configuring the ssh tunnel to be used
    # credentials are used to forward connections over ssh to the Controller
    # Note that the ip given in `addr` needs to be relative to sshserver
    # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
    # and set sshserver as the same machine the Controller is on. However,
    # the only requirement is that sshserver is able to see the Controller
    # (i.e. is within the same trusted network).

    sshserver : str
        A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
        If keyfile or password is specified, and this is not, it will default to
        the ip given in addr.
    sshkey : str; path to ssh private key file
        This specifies a key to be used in ssh login, default None.
        Regular default ssh keys will be used without specifying this argument.
    password : str
        Your ssh password to sshserver. Note that if this is left None,
        you will be prompted for it if passwordless key based login is unavailable.
    paramiko : bool
        flag for whether to use paramiko instead of shell ssh for tunneling.
        [default: True on win32, False else]

    ------- exec authentication args -------
    If even localhost is untrusted, you can have some protection against
    unauthorized execution by signing messages with HMAC digests.
    Messages are still sent as cleartext, so if someone can snoop your
    loopback traffic this will not protect your privacy, but will prevent
    unauthorized execution.

    exec_key : str
        an authentication key or file containing a key
        default: None


    Attributes
    ----------

    ids : list of int engine IDs
        requesting the ids attribute always synchronizes
        the registration state. To request ids without synchronization,
        use semi-private _ids attributes.

    history : list of msg_ids
        a list of msg_ids, keeping track of all the execution
        messages you have submitted in order.

    outstanding : set of msg_ids
        a set of msg_ids that have been submitted, but whose
        results have not yet been received.

    results : dict
        a dict of all our results, keyed by msg_id

    block : bool
        determines default behavior when block not specified
        in execution methods

    Methods
    -------

    spin
        flushes incoming results and registration state changes
        control methods spin, and requesting `ids` also ensures up to date

    wait
        wait on one or more msg_ids

    execution methods
        apply
        legacy: execute, run

    data movement
        push, pull, scatter, gather

    query methods
        queue_status, get_result, purge, result_status

    control methods
        abort, shutdown

    """


    block = Bool(False)
    outstanding = Set()
    results = Instance('collections.defaultdict', (dict,))
    metadata = Instance('collections.defaultdict', (Metadata,))
    history = List()
    debug = Bool(False)

    profile=Unicode()
    def _profile_default(self):
        if BaseIPythonApplication.initialized():
            # an IPython app *might* be running, try to get its profile
            try:
                return BaseIPythonApplication.instance().profile
            except (AttributeError, MultipleInstanceError):
                # could be a *different* subclass of config.Application,
                # which would raise one of these two errors.
                return u'default'
        else:
            return u'default'


    _outstanding_dict = Instance('collections.defaultdict', (set,))
    _ids = List()
    _connected=Bool(False)
    _ssh=Bool(False)
    _context = Instance('zmq.Context')
    _config = Dict()
    _engines=Instance(util.ReverseDict, (), {})
    # _hub_socket=Instance('zmq.Socket')
    _query_socket=Instance('zmq.Socket')
    _control_socket=Instance('zmq.Socket')
    _iopub_socket=Instance('zmq.Socket')
    _notification_socket=Instance('zmq.Socket')
    _mux_socket=Instance('zmq.Socket')
    _task_socket=Instance('zmq.Socket')
    _task_scheme=Unicode()
    _closed = False
    _ignored_control_replies=Integer(0)
    _ignored_hub_replies=Integer(0)

    def __new__(self, *args, **kw):
        # don't raise on positional args
        return HasTraits.__new__(self, **kw)

    def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
            context=None, debug=False, exec_key=None,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            timeout=10, **extra_args
            ):
        if profile:
            super(Client, self).__init__(debug=debug, profile=profile)
        else:
            super(Client, self).__init__(debug=debug)
        if context is None:
            context = zmq.Context.instance()
        self._context = context

        self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
            " Please specify at least one of url_or_file or profile."

        if not util.is_url(url_or_file):
            # it's not a url, try for a file
            if not os.path.exists(url_or_file):
                if self._cd:
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}

        # sync defaults from args, json:
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg['exec_key']
        location = cfg.setdefault('location', None)
        cfg['url'] = util.disambiguate_url(cfg['url'], location)
        url = cfg['url']
        proto,addr,port = util.split_url(url)
        if location is not None and addr == '127.0.0.1':
            # location specified, and connection is expected to be local
            if location not in LOCAL_IPS and not sshserver:
                # load ssh from JSON *only* if the controller is not on
                # this machine
                sshserver=cfg['ssh']
            if location not in LOCAL_IPS and not sshserver:
                # warn if no ssh specified, but SSH is probably needed
                # This is only a warning, because the most likely cause
                # is a local Controller on a laptop whose IP is dynamic
                warnings.warn("""
Controller appears to be listening on localhost, but not on this machine.
If this is true, you should specify Client(...,sshserver='you@%s')
or instruct your controller to listen on an external IP."""%location,
                    RuntimeWarning)
        elif not sshserver:
            # otherwise sync with cfg
            sshserver = cfg['ssh']

        self._config = cfg

        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

        # configure and construct the session
        if exec_key is not None:
            if os.path.isfile(exec_key):
                extra_args['keyfile'] = exec_key
            else:
                exec_key = util.asbytes(exec_key)
                extra_args['key'] = exec_key
        self.session = Session(**extra_args)

        self._query_socket = self._context.socket(zmq.DEALER)
        self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        self._notification_handlers = {'registration_notification' : self._register_engine,
                                       'unregistration_notification' : self._unregister_engine,
                                       'shutdown_notification' : lambda msg: self.close(),
                                       }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs, timeout)

    def __del__(self):
        """cleanup sockets, but _not_ context."""
        self.close()

    def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
        if ipython_dir is None:
            ipython_dir = get_ipython_dir()
        if profile_dir is not None:
            try:
                self._cd = ProfileDir.find_profile_dir(profile_dir)
                return
            except ProfileDirError:
                pass
        elif profile is not None:
            try:
                self._cd = ProfileDir.find_profile_dir_by_name(
                    ipython_dir, profile)
                return
            except ProfileDirError:
                pass
        self._cd = None

    def _update_engines(self, engines):
        """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
        for k,v in engines.iteritems():
            eid = int(k)
            self._engines[eid] = v
            self._ids.append(eid)
        self._ids = sorted(self._ids)
        if sorted(self._engines.keys()) != range(len(self._engines)) and \
                        self._task_scheme == 'pure' and self._task_socket:
            self._stop_scheduling_tasks()

    def _stop_scheduling_tasks(self):
        """Stop scheduling tasks because an engine has been unregistered
        from a pure ZMQ scheduler.
        """
        self._task_socket.close()
        self._task_socket = None
        msg = "An engine has been unregistered, and we are using pure " +\
              "ZMQ task scheduling. Task farming will be disabled."
        if self.outstanding:
            msg += " If you were running tasks when this happened, " +\
                   "some `outstanding` msg_ids may never resolve."
        warnings.warn(msg, RuntimeWarning)

    def _build_targets(self, targets):
        """Turn valid target IDs or 'all' into two lists:
        (int_ids, uuids).
        """
        if not self._ids:
            # flush notification socket if no engines yet, just in case
            if not self.ids:
                raise error.NoEnginesRegistered("Can't build targets without any engines")

        if targets is None:
            targets = self._ids
        elif isinstance(targets, basestring):
            if targets.lower() == 'all':
                targets = self._ids
            else:
                raise TypeError("%r not valid str target, must be 'all'"%(targets))
        elif isinstance(targets, int):
            if targets < 0:
                targets = self.ids[targets]
            if targets not in self._ids:
                raise IndexError("No such engine: %i"%targets)
            targets = [targets]

        if isinstance(targets, slice):
            indices = range(len(self._ids))[targets]
            ids = self.ids
            targets = [ ids[i] for i in indices ]

        if not isinstance(targets, (tuple, list, xrange)):
            raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))

        return [util.asbytes(self._engines[t]) for t in targets], list(targets)

    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__."""

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            url = util.disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        self.session.send(self._query_socket, 'connection_request')
        # use Poller because zmq.select has wrong units in pyzmq 2.1.7
        poller = zmq.Poller()
        poller.register(self._query_socket, zmq.POLLIN)
        # poll expects milliseconds, timeout is seconds
        evts = poller.poll(timeout*1000)
        if not evts:
            raise error.TimeoutError("Hub connection request timed out")
        idents,msg = self.session.recv(self._query_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = Message(msg)
        content = msg.content
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            ident = self.session.bsession
            if content.mux:
                self._mux_socket = self._context.socket(zmq.DEALER)
                self._mux_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._mux_socket, content.mux)
            if content.task:
                self._task_scheme, task_addr = content.task
                self._task_socket = self._context.socket(zmq.DEALER)
                self._task_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._task_socket, task_addr)
            if content.notification:
                self._notification_socket = self._context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            # if content.query:
            #     self._query_socket = self._context.socket(zmq.DEALER)
            #     self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
            #     connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self._context.socket(zmq.DEALER)
                self._control_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self._context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
                self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._iopub_socket, content.iopub)
            self._update_engines(dict(content.engines))
        else:
            self._connected = False
            raise Exception("Failed to connect!")

    #--------------------------------------------------------------------------
    # handlers and callbacks for incoming messages
    #--------------------------------------------------------------------------

    def _unwrap_exception(self, content):
        """unwrap exception, and remap engine_id to int."""
        e = error.unwrap_exception(content)
        # print e.traceback
        if e.engine_info:
            e_uuid = e.engine_info['engine_uuid']
            eid = self._engines[e_uuid]
            e.engine_info['engine_id'] = eid
        return e

    def _extract_metadata(self, header, parent, content):
        md = {'msg_id' : parent['msg_id'],
              'received' : datetime.now(),
              'engine_uuid' : header.get('engine', None),
              'follow' : parent.get('follow', []),
              'after' : parent.get('after', []),
              'status' : content['status'],
            }

        if md['engine_uuid'] is not None:
            md['engine_id'] = self._engines.get(md['engine_uuid'], None)

        if 'date' in parent:
            md['submitted'] = parent['date']
        if 'started' in header:
            md['started'] = header['started']
        if 'date' in header:
            md['completed'] = header['date']
        return md

    def _register_engine(self, msg):
        """Register a new engine, and update our connection info."""
        content = msg['content']
        eid = content['id']
        d = {eid : content['queue']}
        self._update_engines(d)

    def _unregister_engine(self, msg):
        """Unregister an engine that has died."""
        content = msg['content']
        eid = int(content['id'])
        if eid in self._ids:
            self._ids.remove(eid)
            uuid = self._engines.pop(eid)

            self._handle_stranded_msgs(eid, uuid)

        if self._task_socket and self._task_scheme == 'pure':
            self._stop_scheduling_tasks()

    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.
        """

        outstanding = self._outstanding_dict[uuid]

        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already have this result
                continue
            try:
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake message:
            parent = {}
            header = {}
            parent['msg_id'] = msg_id
            header['engine'] = uuid
            header['date'] = datetime.now()
            msg = dict(parent_header=parent, header=header, content=content)
            self._handle_apply_reply(msg)

    def _handle_execute_reply(self, msg):
        """Save the reply to an execute_request into our results.

        execute messages are never actually used. apply is used instead.
        """

        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
            self.results[msg_id] = self._unwrap_exception(msg['content'])

    def _handle_apply_reply(self, msg):
        """Save the reply to an apply_request into our results."""
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
                print self.results[msg_id]
                print msg
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
            content = msg['content']
            header = msg['header']

            # construct metadata:
            md = self.metadata[msg_id]
            md.update(self._extract_metadata(header, parent, content))
            # is this redundant?
            self.metadata[msg_id] = md

            e_outstanding = self._outstanding_dict[md['engine_uuid']]
            if msg_id in e_outstanding:
                e_outstanding.remove(msg_id)

            # construct result:
            if content['status'] == 'ok':
                self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
            elif content['status'] == 'aborted':
                self.results[msg_id] = error.TaskAborted(msg_id)
            elif content['status'] == 'resubmitted':
                # TODO: handle resubmission
                pass
            else:
                self.results[msg_id] = self._unwrap_exception(content)

    def _flush_notifications(self):
        """Flush notifications of engine registrations waiting
        in ZMQ queue."""
        idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg_type = msg['header']['msg_type']
            handler = self._notification_handlers.get(msg_type, None)
            if handler is None:
                raise Exception("Unhandled message type: %s"%msg.msg_type)
            else:
                handler(msg)
            idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)

    def _flush_results(self, sock):
        """Flush task or queue results waiting in ZMQ queue."""
        idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg_type = msg['header']['msg_type']
            handler = self._queue_handlers.get(msg_type, None)
            if handler is None:
                raise Exception("Unhandled message type: %s"%msg.msg_type)
            else:
                handler(msg)
            idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    def _flush_control(self, sock):
        """Flush replies from the control channel waiting
        in the ZMQ queue.

        Currently: ignore them."""
        if self._ignored_control_replies <= 0:
            return
        idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            self._ignored_control_replies -= 1
            if self.debug:
                pprint(msg)
            idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    def _flush_ignored_control(self):
        """flush ignored control replies"""
        while self._ignored_control_replies > 0:
            self.session.recv(self._control_socket)
            self._ignored_control_replies -= 1

    def _flush_ignored_hub_replies(self):
        ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
        while msg is not None:
            ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)

    def _flush_iopub(self, sock):
        """Flush replies from the iopub channel waiting
        in the ZMQ queue.
        """
        idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            parent = msg['parent_header']
            msg_id = parent['msg_id']
            content = msg['content']
            header = msg['header']
            msg_type = msg['header']['msg_type']

            # init metadata:
            md = self.metadata[msg_id]

            if msg_type == 'stream':
                name = content['name']
                s = md[name] or ''
                md[name] = s + content['data']
            elif msg_type == 'pyerr':
                md.update({'pyerr' : self._unwrap_exception(content)})
            elif msg_type == 'pyin':
                md.update({'pyin' : content['code']})
            else:
                md.update({msg_type : content.get('data', '')})

            # redundant?
            self.metadata[msg_id] = md

            idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    #--------------------------------------------------------------------------
    # len, getitem
    #--------------------------------------------------------------------------

    def __len__(self):
        """len(client) returns # of engines."""
        return len(self.ids)

    def __getitem__(self, key):
        """index access returns DirectView multiplexer objects

        Must be int, slice, or list/tuple/xrange of ints"""
        if not isinstance(key, (int, slice, tuple, list, xrange)):
            raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
        else:
            return self.direct_view(key)

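    # Illustrative sketch of index access (comment only, not part of the
    # original file):
    #
    #   rc = Client()
    #   rc[0]      # DirectView on engine 0
    #   rc[::2]    # DirectView on every other engine
    #   rc[:]      # DirectView on all engines
    #   len(rc)    # number of currently registered engines
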
769 #--------------------------------------------------------------------------
769 #--------------------------------------------------------------------------
770 # Begin public methods
770 # Begin public methods
771 #--------------------------------------------------------------------------
771 #--------------------------------------------------------------------------
772
772
773 @property
773 @property
774 def ids(self):
774 def ids(self):
775 """Always up-to-date ids property."""
775 """Always up-to-date ids property."""
776 self._flush_notifications()
776 self._flush_notifications()
777 # always copy:
777 # always copy:
778 return list(self._ids)
778 return list(self._ids)
779
779
780 def close(self):
780 def close(self):
781 if self._closed:
781 if self._closed:
782 return
782 return
783 snames = filter(lambda n: n.endswith('socket'), dir(self))
783 snames = filter(lambda n: n.endswith('socket'), dir(self))
784 for socket in map(lambda name: getattr(self, name), snames):
784 for socket in map(lambda name: getattr(self, name), snames):
785 if isinstance(socket, zmq.Socket) and not socket.closed:
785 if isinstance(socket, zmq.Socket) and not socket.closed:
786 socket.close()
786 socket.close()
787 self._closed = True
787 self._closed = True
788
788
789 def spin(self):
789 def spin(self):
790 """Flush any registration notifications and execution results
790 """Flush any registration notifications and execution results
791 waiting in the ZMQ queue.
791 waiting in the ZMQ queue.
792 """
792 """
793 if self._notification_socket:
793 if self._notification_socket:
794 self._flush_notifications()
794 self._flush_notifications()
795 if self._mux_socket:
795 if self._mux_socket:
796 self._flush_results(self._mux_socket)
796 self._flush_results(self._mux_socket)
797 if self._task_socket:
797 if self._task_socket:
798 self._flush_results(self._task_socket)
798 self._flush_results(self._task_socket)
799 if self._control_socket:
799 if self._control_socket:
800 self._flush_control(self._control_socket)
800 self._flush_control(self._control_socket)
801 if self._iopub_socket:
801 if self._iopub_socket:
802 self._flush_iopub(self._iopub_socket)
802 self._flush_iopub(self._iopub_socket)
803 if self._query_socket:
803 if self._query_socket:
804 self._flush_ignored_hub_replies()
804 self._flush_ignored_hub_replies()
805
805
806 def wait(self, jobs=None, timeout=-1):
806 def wait(self, jobs=None, timeout=-1):
807 """waits on one or more `jobs`, for up to `timeout` seconds.
807 """waits on one or more `jobs`, for up to `timeout` seconds.
808
808
809 Parameters
809 Parameters
810 ----------
810 ----------
811
811
812 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
812 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
813 ints are indices to self.history
813 ints are indices to self.history
814 strs are msg_ids
814 strs are msg_ids
815 default: wait on all outstanding messages
815 default: wait on all outstanding messages
816 timeout : float
816 timeout : float
817 a time in seconds, after which to give up.
817 a time in seconds, after which to give up.
818 default is -1, which means no timeout
818 default is -1, which means no timeout
819
819
820 Returns
820 Returns
821 -------
821 -------
822
822
823 True : when all msg_ids are done
823 True : when all msg_ids are done
824 False : timeout reached, some msg_ids still outstanding
824 False : timeout reached, some msg_ids still outstanding
825 """
825 """
826 tic = time.time()
826 tic = time.time()
827 if jobs is None:
827 if jobs is None:
828 theids = self.outstanding
828 theids = self.outstanding
829 else:
829 else:
830 if isinstance(jobs, (int, basestring, AsyncResult)):
830 if isinstance(jobs, (int, basestring, AsyncResult)):
831 jobs = [jobs]
831 jobs = [jobs]
832 theids = set()
832 theids = set()
833 for job in jobs:
833 for job in jobs:
834 if isinstance(job, int):
834 if isinstance(job, int):
835 # index access
835 # index access
836 job = self.history[job]
836 job = self.history[job]
837 elif isinstance(job, AsyncResult):
837 elif isinstance(job, AsyncResult):
838 map(theids.add, job.msg_ids)
838 map(theids.add, job.msg_ids)
839 continue
839 continue
840 theids.add(job)
840 theids.add(job)
841 if not theids.intersection(self.outstanding):
841 if not theids.intersection(self.outstanding):
842 return True
842 return True
843 self.spin()
843 self.spin()
844 while theids.intersection(self.outstanding):
844 while theids.intersection(self.outstanding):
845 if timeout >= 0 and ( time.time()-tic ) > timeout:
845 if timeout >= 0 and ( time.time()-tic ) > timeout:
846 break
846 break
847 time.sleep(1e-3)
847 time.sleep(1e-3)
848 self.spin()
848 self.spin()
849 return len(theids.intersection(self.outstanding)) == 0
849 return len(theids.intersection(self.outstanding)) == 0
850
850
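Illustrative sketch (not part of the diff) of how wait() is typically used; it assumes a cluster has already been started, e.g. with `ipcluster start`:

import os
from IPython.parallel import Client

rc = Client()                      # connect with the default profile (assumed running)
view = rc.load_balanced_view()
ar = view.apply_async(os.getpid)

# Wait on one AsyncResult for up to 5 seconds:
# returns True if all of its msg_ids finished, False if the timeout was hit.
finished = rc.wait([ar], timeout=5)

# With no arguments, block until every outstanding message is done.
rc.wait()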
851 #--------------------------------------------------------------------------
851 #--------------------------------------------------------------------------
852 # Control methods
852 # Control methods
853 #--------------------------------------------------------------------------
853 #--------------------------------------------------------------------------
854
854
855 @spin_first
855 @spin_first
856 def clear(self, targets=None, block=None):
856 def clear(self, targets=None, block=None):
857 """Clear the namespace in target(s)."""
857 """Clear the namespace in target(s)."""
858 block = self.block if block is None else block
858 block = self.block if block is None else block
859 targets = self._build_targets(targets)[0]
859 targets = self._build_targets(targets)[0]
860 for t in targets:
860 for t in targets:
861 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
861 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
862 error = False
862 error = False
863 if block:
863 if block:
864 self._flush_ignored_control()
864 self._flush_ignored_control()
865 for i in range(len(targets)):
865 for i in range(len(targets)):
866 idents,msg = self.session.recv(self._control_socket,0)
866 idents,msg = self.session.recv(self._control_socket,0)
867 if self.debug:
867 if self.debug:
868 pprint(msg)
868 pprint(msg)
869 if msg['content']['status'] != 'ok':
869 if msg['content']['status'] != 'ok':
870 error = self._unwrap_exception(msg['content'])
870 error = self._unwrap_exception(msg['content'])
871 else:
871 else:
872 self._ignored_control_replies += len(targets)
872 self._ignored_control_replies += len(targets)
873 if error:
873 if error:
874 raise error
874 raise error
875
875
876
876
877 @spin_first
877 @spin_first
878 def abort(self, jobs=None, targets=None, block=None):
878 def abort(self, jobs=None, targets=None, block=None):
879 """Abort specific jobs from the execution queues of target(s).
879 """Abort specific jobs from the execution queues of target(s).
880
880
881 This is a mechanism to prevent jobs that have already been submitted
881 This is a mechanism to prevent jobs that have already been submitted
882 from executing.
882 from executing.
883
883
884 Parameters
884 Parameters
885 ----------
885 ----------
886
886
887 jobs : msg_id, list of msg_ids, or AsyncResult
887 jobs : msg_id, list of msg_ids, or AsyncResult
888 The jobs to be aborted
888 The jobs to be aborted
889
889
890 If unspecified/None: abort all outstanding jobs.
890
891
891 """
892 """
892 block = self.block if block is None else block
893 block = self.block if block is None else block
894 jobs = jobs if jobs is not None else list(self.outstanding)
893 targets = self._build_targets(targets)[0]
895 targets = self._build_targets(targets)[0]
896
894 msg_ids = []
897 msg_ids = []
895 if isinstance(jobs, (basestring,AsyncResult)):
898 if isinstance(jobs, (basestring,AsyncResult)):
896 jobs = [jobs]
899 jobs = [jobs]
897 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
900 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
898 if bad_ids:
901 if bad_ids:
899 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
902 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
900 for j in jobs:
903 for j in jobs:
901 if isinstance(j, AsyncResult):
904 if isinstance(j, AsyncResult):
902 msg_ids.extend(j.msg_ids)
905 msg_ids.extend(j.msg_ids)
903 else:
906 else:
904 msg_ids.append(j)
907 msg_ids.append(j)
905 content = dict(msg_ids=msg_ids)
908 content = dict(msg_ids=msg_ids)
906 for t in targets:
909 for t in targets:
907 self.session.send(self._control_socket, 'abort_request',
910 self.session.send(self._control_socket, 'abort_request',
908 content=content, ident=t)
911 content=content, ident=t)
909 error = False
912 error = False
910 if block:
913 if block:
911 self._flush_ignored_control()
914 self._flush_ignored_control()
912 for i in range(len(targets)):
915 for i in range(len(targets)):
913 idents,msg = self.session.recv(self._control_socket,0)
916 idents,msg = self.session.recv(self._control_socket,0)
914 if self.debug:
917 if self.debug:
915 pprint(msg)
918 pprint(msg)
916 if msg['content']['status'] != 'ok':
919 if msg['content']['status'] != 'ok':
917 error = self._unwrap_exception(msg['content'])
920 error = self._unwrap_exception(msg['content'])
918 else:
921 else:
919 self._ignored_control_replies += len(targets)
922 self._ignored_control_replies += len(targets)
920 if error:
923 if error:
921 raise error
924 raise error
922
925
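A minimal sketch of abort(), reusing the assumed `rc` Client and `view` from the sketch above; abort() can only stop jobs that are still waiting in an engine's queue:

ar = view.apply_async(os.getpid)   # a job that may still be queued
rc.abort(jobs=ar, block=True)      # abort the msg_ids behind one AsyncResult
rc.abort()                         # jobs=None: abort all outstanding jobs on all engines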
923 @spin_first
926 @spin_first
924 def shutdown(self, targets=None, restart=False, hub=False, block=None):
927 def shutdown(self, targets=None, restart=False, hub=False, block=None):
925 """Terminates one or more engine processes, optionally including the hub."""
928 """Terminates one or more engine processes, optionally including the hub."""
926 block = self.block if block is None else block
929 block = self.block if block is None else block
927 if hub:
930 if hub:
928 targets = 'all'
931 targets = 'all'
929 targets = self._build_targets(targets)[0]
932 targets = self._build_targets(targets)[0]
930 for t in targets:
933 for t in targets:
931 self.session.send(self._control_socket, 'shutdown_request',
934 self.session.send(self._control_socket, 'shutdown_request',
932 content={'restart':restart},ident=t)
935 content={'restart':restart},ident=t)
933 error = False
936 error = False
934 if block or hub:
937 if block or hub:
935 self._flush_ignored_control()
938 self._flush_ignored_control()
936 for i in range(len(targets)):
939 for i in range(len(targets)):
937 idents,msg = self.session.recv(self._control_socket, 0)
940 idents,msg = self.session.recv(self._control_socket, 0)
938 if self.debug:
941 if self.debug:
939 pprint(msg)
942 pprint(msg)
940 if msg['content']['status'] != 'ok':
943 if msg['content']['status'] != 'ok':
941 error = self._unwrap_exception(msg['content'])
944 error = self._unwrap_exception(msg['content'])
942 else:
945 else:
943 self._ignored_control_replies += len(targets)
946 self._ignored_control_replies += len(targets)
944
947
945 if hub:
948 if hub:
946 time.sleep(0.25)
949 time.sleep(0.25)
947 self.session.send(self._query_socket, 'shutdown_request')
950 self.session.send(self._query_socket, 'shutdown_request')
948 idents,msg = self.session.recv(self._query_socket, 0)
951 idents,msg = self.session.recv(self._query_socket, 0)
949 if self.debug:
952 if self.debug:
950 pprint(msg)
953 pprint(msg)
951 if msg['content']['status'] != 'ok':
954 if msg['content']['status'] != 'ok':
952 error = self._unwrap_exception(msg['content'])
955 error = self._unwrap_exception(msg['content'])
953
956
954 if error:
957 if error:
955 raise error
958 raise error
956
959
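A sketch of the shutdown variants (the engine ids are assumptions, not values from this diff):

rc.shutdown(targets=[0, 1], block=True)   # stop two specific engines
rc.shutdown(restart=True)                 # ask all engines to restart
rc.shutdown(hub=True)                     # stop every engine and the Hub itself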
957 #--------------------------------------------------------------------------
960 #--------------------------------------------------------------------------
958 # Execution related methods
961 # Execution related methods
959 #--------------------------------------------------------------------------
962 #--------------------------------------------------------------------------
960
963
961 def _maybe_raise(self, result):
964 def _maybe_raise(self, result):
962 """wrapper for maybe raising an exception if apply failed."""
965 """wrapper for maybe raising an exception if apply failed."""
963 if isinstance(result, error.RemoteError):
966 if isinstance(result, error.RemoteError):
964 raise result
967 raise result
965
968
966 return result
969 return result
967
970
968 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
971 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
969 ident=None):
972 ident=None):
970 """construct and send an apply message via a socket.
973 """construct and send an apply message via a socket.
971
974
972 This is the principal method with which all engine execution is performed by views.
975 This is the principal method with which all engine execution is performed by views.
973 """
976 """
974
977
975 assert not self._closed, "cannot use me anymore, I'm closed!"
978 assert not self._closed, "cannot use me anymore, I'm closed!"
976 # defaults:
979 # defaults:
977 args = args if args is not None else []
980 args = args if args is not None else []
978 kwargs = kwargs if kwargs is not None else {}
981 kwargs = kwargs if kwargs is not None else {}
979 subheader = subheader if subheader is not None else {}
982 subheader = subheader if subheader is not None else {}
980
983
981 # validate arguments
984 # validate arguments
982 if not callable(f):
985 if not callable(f):
983 raise TypeError("f must be callable, not %s"%type(f))
986 raise TypeError("f must be callable, not %s"%type(f))
984 if not isinstance(args, (tuple, list)):
987 if not isinstance(args, (tuple, list)):
985 raise TypeError("args must be tuple or list, not %s"%type(args))
988 raise TypeError("args must be tuple or list, not %s"%type(args))
986 if not isinstance(kwargs, dict):
989 if not isinstance(kwargs, dict):
987 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
990 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
988 if not isinstance(subheader, dict):
991 if not isinstance(subheader, dict):
989 raise TypeError("subheader must be dict, not %s"%type(subheader))
992 raise TypeError("subheader must be dict, not %s"%type(subheader))
990
993
991 bufs = util.pack_apply_message(f,args,kwargs)
994 bufs = util.pack_apply_message(f,args,kwargs)
992
995
993 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
996 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
994 subheader=subheader, track=track)
997 subheader=subheader, track=track)
995
998
996 msg_id = msg['header']['msg_id']
999 msg_id = msg['header']['msg_id']
997 self.outstanding.add(msg_id)
1000 self.outstanding.add(msg_id)
998 if ident:
1001 if ident:
999 # possibly routed to a specific engine
1002 # possibly routed to a specific engine
1000 if isinstance(ident, list):
1003 if isinstance(ident, list):
1001 ident = ident[-1]
1004 ident = ident[-1]
1002 if ident in self._engines.values():
1005 if ident in self._engines.values():
1003 # save for later, in case of engine death
1006 # save for later, in case of engine death
1004 self._outstanding_dict[ident].add(msg_id)
1007 self._outstanding_dict[ident].add(msg_id)
1005 self.history.append(msg_id)
1008 self.history.append(msg_id)
1006 self.metadata[msg_id]['submitted'] = datetime.now()
1009 self.metadata[msg_id]['submitted'] = datetime.now()
1007
1010
1008 return msg
1011 return msg
1009
1012
1010 #--------------------------------------------------------------------------
1013 #--------------------------------------------------------------------------
1011 # construct a View object
1014 # construct a View object
1012 #--------------------------------------------------------------------------
1015 #--------------------------------------------------------------------------
1013
1016
1014 def load_balanced_view(self, targets=None):
1017 def load_balanced_view(self, targets=None):
1015 """construct a DirectView object.
1018 """construct a DirectView object.
1016
1019
1017 If no arguments are specified, create a LoadBalancedView
1020 If no arguments are specified, create a LoadBalancedView
1018 using all engines.
1021 using all engines.
1019
1022
1020 Parameters
1023 Parameters
1021 ----------
1024 ----------
1022
1025
1023 targets: list,slice,int,etc. [default: use all engines]
1026 targets: list,slice,int,etc. [default: use all engines]
1024 The subset of engines across which to load-balance
1027 The subset of engines across which to load-balance
1025 """
1028 """
1026 if targets == 'all':
1029 if targets == 'all':
1027 targets = None
1030 targets = None
1028 if targets is not None:
1031 if targets is not None:
1029 targets = self._build_targets(targets)[1]
1032 targets = self._build_targets(targets)[1]
1030 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1033 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1031
1034
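For example (engine ids assumed):

lview = rc.load_balanced_view()        # load-balance over all engines
lsub = rc.load_balanced_view([0, 2])   # load-balance over a subset of engines
ar = lview.apply_async(os.getpid)      # the scheduler picks the engine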
1032 def direct_view(self, targets='all'):
1035 def direct_view(self, targets='all'):
1033 """construct a DirectView object.
1036 """construct a DirectView object.
1034
1037
1035 If no targets are specified, create a DirectView using all engines.
1038 If no targets are specified, create a DirectView using all engines.
1036
1039
1037 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1040 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1038 evaluate the target engines at each execution, whereas rc[:] will connect to
1041 evaluate the target engines at each execution, whereas rc[:] will connect to
1039 all *current* engines, and that list will not change.
1042 all *current* engines, and that list will not change.
1040
1043
1041 That is, 'all' will always use all engines, whereas rc[:] will not use
1044 That is, 'all' will always use all engines, whereas rc[:] will not use
1042 engines added after the DirectView is constructed.
1045 engines added after the DirectView is constructed.
1043
1046
1044 Parameters
1047 Parameters
1045 ----------
1048 ----------
1046
1049
1047 targets: list,slice,int,etc. [default: use all engines]
1050 targets: list,slice,int,etc. [default: use all engines]
1048 The engines to use for the View
1051 The engines to use for the View
1049 """
1052 """
1050 single = isinstance(targets, int)
1053 single = isinstance(targets, int)
1051 # allow 'all' to be lazily evaluated at each execution
1054 # allow 'all' to be lazily evaluated at each execution
1052 if targets != 'all':
1055 if targets != 'all':
1053 targets = self._build_targets(targets)[1]
1056 targets = self._build_targets(targets)[1]
1054 if single:
1057 if single:
1055 targets = targets[0]
1058 targets = targets[0]
1056 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1059 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1057
1060
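A short sketch of the distinction described in the docstring above:

dv_all = rc.direct_view('all')   # re-evaluates 'all', so engines added later are included
dv_now = rc[:]                   # fixed to the engines registered right now
e0 = rc.direct_view(0)           # single-engine DirectView (engine id assumed)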
1058 #--------------------------------------------------------------------------
1061 #--------------------------------------------------------------------------
1059 # Query methods
1062 # Query methods
1060 #--------------------------------------------------------------------------
1063 #--------------------------------------------------------------------------
1061
1064
1062 @spin_first
1065 @spin_first
1063 def get_result(self, indices_or_msg_ids=None, block=None):
1066 def get_result(self, indices_or_msg_ids=None, block=None):
1064 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1067 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1065
1068
1066 If the client already has the results, no request to the Hub will be made.
1069 If the client already has the results, no request to the Hub will be made.
1067
1070
1068 This is a convenient way to construct AsyncResult objects, which are wrappers
1071 This is a convenient way to construct AsyncResult objects, which are wrappers
1069 that include metadata about execution, and allow for awaiting results that
1072 that include metadata about execution, and allow for awaiting results that
1070 were not submitted by this Client.
1073 were not submitted by this Client.
1071
1074
1072 It can also be a convenient way to retrieve the metadata associated with
1075 It can also be a convenient way to retrieve the metadata associated with
1073 blocking execution, since it always retrieves the full result record, including metadata.
1076 blocking execution, since it always retrieves the full result record, including metadata.
1074
1077
1075 Examples
1078 Examples
1076 --------
1079 --------
1077 ::
1080 ::
1078
1081
1079 In [10]: ar = client.get_result(-1)  # the most recent task, as an AsyncResult
1082 In [10]: ar = client.get_result(-1)  # the most recent task, as an AsyncResult
1080
1083
1081 Parameters
1084 Parameters
1082 ----------
1085 ----------
1083
1086
1084 indices_or_msg_ids : integer history index, str msg_id, or list of either
1087 indices_or_msg_ids : integer history index, str msg_id, or list of either
1085 The history indices or msg_ids of the results to be retrieved
1088 The history indices or msg_ids of the results to be retrieved
1086
1089
1087 block : bool
1090 block : bool
1088 Whether to wait for the result to be done
1091 Whether to wait for the result to be done
1089
1092
1090 Returns
1093 Returns
1091 -------
1094 -------
1092
1095
1093 AsyncResult
1096 AsyncResult
1094 A single AsyncResult object will always be returned.
1097 A single AsyncResult object will always be returned.
1095
1098
1096 AsyncHubResult
1099 AsyncHubResult
1097 A subclass of AsyncResult that retrieves results from the Hub
1100 A subclass of AsyncResult that retrieves results from the Hub
1098
1101
1099 """
1102 """
1100 block = self.block if block is None else block
1103 block = self.block if block is None else block
1101 if indices_or_msg_ids is None:
1104 if indices_or_msg_ids is None:
1102 indices_or_msg_ids = -1
1105 indices_or_msg_ids = -1
1103
1106
1104 if not isinstance(indices_or_msg_ids, (list,tuple)):
1107 if not isinstance(indices_or_msg_ids, (list,tuple)):
1105 indices_or_msg_ids = [indices_or_msg_ids]
1108 indices_or_msg_ids = [indices_or_msg_ids]
1106
1109
1107 theids = []
1110 theids = []
1108 for id in indices_or_msg_ids:
1111 for id in indices_or_msg_ids:
1109 if isinstance(id, int):
1112 if isinstance(id, int):
1110 id = self.history[id]
1113 id = self.history[id]
1111 if not isinstance(id, basestring):
1114 if not isinstance(id, basestring):
1112 raise TypeError("indices must be str or int, not %r"%id)
1115 raise TypeError("indices must be str or int, not %r"%id)
1113 theids.append(id)
1116 theids.append(id)
1114
1117
1115 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1118 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1116 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1119 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1117
1120
1118 if remote_ids:
1121 if remote_ids:
1119 ar = AsyncHubResult(self, msg_ids=theids)
1122 ar = AsyncHubResult(self, msg_ids=theids)
1120 else:
1123 else:
1121 ar = AsyncResult(self, msg_ids=theids)
1124 ar = AsyncResult(self, msg_ids=theids)
1122
1125
1123 if block:
1126 if block:
1124 ar.wait()
1127 ar.wait()
1125
1128
1126 return ar
1129 return ar
1127
1130
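Illustrative use of get_result(), continuing with the assumed `rc`/`view` objects from the earlier sketches:

ar = view.apply_async(os.getpid)
msg_id = ar.msg_ids[0]

ar2 = rc.get_result(msg_id)           # AsyncResult (AsyncHubResult if not known locally)
value = ar2.get()                     # block for the value
last = rc.get_result(-1, block=True)  # most recent entry in rc.history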
1128 @spin_first
1131 @spin_first
1129 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1132 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1130 """Resubmit one or more tasks.
1133 """Resubmit one or more tasks.
1131
1134
1132 in-flight tasks may not be resubmitted.
1135 in-flight tasks may not be resubmitted.
1133
1136
1134 Parameters
1137 Parameters
1135 ----------
1138 ----------
1136
1139
1137 indices_or_msg_ids : integer history index, str msg_id, or list of either
1140 indices_or_msg_ids : integer history index, str msg_id, or list of either
1138 The history indices or msg_ids of the tasks to be resubmitted
1141 The history indices or msg_ids of the tasks to be resubmitted
1139
1142
1140 block : bool
1143 block : bool
1141 Whether to wait for the result to be done
1144 Whether to wait for the result to be done
1142
1145
1143 Returns
1146 Returns
1144 -------
1147 -------
1145
1148
1146 AsyncHubResult
1149 AsyncHubResult
1147 A subclass of AsyncResult that retrieves results from the Hub
1150 A subclass of AsyncResult that retrieves results from the Hub
1148
1151
1149 """
1152 """
1150 block = self.block if block is None else block
1153 block = self.block if block is None else block
1151 if indices_or_msg_ids is None:
1154 if indices_or_msg_ids is None:
1152 indices_or_msg_ids = -1
1155 indices_or_msg_ids = -1
1153
1156
1154 if not isinstance(indices_or_msg_ids, (list,tuple)):
1157 if not isinstance(indices_or_msg_ids, (list,tuple)):
1155 indices_or_msg_ids = [indices_or_msg_ids]
1158 indices_or_msg_ids = [indices_or_msg_ids]
1156
1159
1157 theids = []
1160 theids = []
1158 for id in indices_or_msg_ids:
1161 for id in indices_or_msg_ids:
1159 if isinstance(id, int):
1162 if isinstance(id, int):
1160 id = self.history[id]
1163 id = self.history[id]
1161 if not isinstance(id, basestring):
1164 if not isinstance(id, basestring):
1162 raise TypeError("indices must be str or int, not %r"%id)
1165 raise TypeError("indices must be str or int, not %r"%id)
1163 theids.append(id)
1166 theids.append(id)
1164
1167
1165 for msg_id in theids:
1168 for msg_id in theids:
1166 self.outstanding.discard(msg_id)
1169 self.outstanding.discard(msg_id)
1167 if msg_id in self.history:
1170 if msg_id in self.history:
1168 self.history.remove(msg_id)
1171 self.history.remove(msg_id)
1169 self.results.pop(msg_id, None)
1172 self.results.pop(msg_id, None)
1170 self.metadata.pop(msg_id, None)
1173 self.metadata.pop(msg_id, None)
1171 content = dict(msg_ids = theids)
1174 content = dict(msg_ids = theids)
1172
1175
1173 self.session.send(self._query_socket, 'resubmit_request', content)
1176 self.session.send(self._query_socket, 'resubmit_request', content)
1174
1177
1175 zmq.select([self._query_socket], [], [])
1178 zmq.select([self._query_socket], [], [])
1176 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1179 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1177 if self.debug:
1180 if self.debug:
1178 pprint(msg)
1181 pprint(msg)
1179 content = msg['content']
1182 content = msg['content']
1180 if content['status'] != 'ok':
1183 if content['status'] != 'ok':
1181 raise self._unwrap_exception(content)
1184 raise self._unwrap_exception(content)
1182
1185
1183 ar = AsyncHubResult(self, msg_ids=theids)
1186 ar = AsyncHubResult(self, msg_ids=theids)
1184
1187
1185 if block:
1188 if block:
1186 ar.wait()
1189 ar.wait()
1187
1190
1188 return ar
1191 return ar
1189
1192
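A sketch of resubmitting a finished task; the msg_ids come from an assumed earlier AsyncResult `ar`:

ar.wait()                        # in-flight tasks may not be resubmitted
ar2 = rc.resubmit(ar.msg_ids)    # returns an AsyncHubResult for the re-run
ar2.get()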
1190 @spin_first
1193 @spin_first
1191 def result_status(self, msg_ids, status_only=True):
1194 def result_status(self, msg_ids, status_only=True):
1192 """Check on the status of the result(s) of the apply request with `msg_ids`.
1195 """Check on the status of the result(s) of the apply request with `msg_ids`.
1193
1196
1194 If status_only is False, then the actual results will be retrieved, else
1197 If status_only is False, then the actual results will be retrieved, else
1195 only the status of the results will be checked.
1198 only the status of the results will be checked.
1196
1199
1197 Parameters
1200 Parameters
1198 ----------
1201 ----------
1199
1202
1200 msg_ids : list of msg_ids
1203 msg_ids : list of msg_ids
1201 if int:
1204 if int:
1202 Passed as index to self.history for convenience.
1205 Passed as index to self.history for convenience.
1203 status_only : bool (default: True)
1206 status_only : bool (default: True)
1204 if False:
1207 if False:
1205 Retrieve the actual results of completed tasks.
1208 Retrieve the actual results of completed tasks.
1206
1209
1207 Returns
1210 Returns
1208 -------
1211 -------
1209
1212
1210 results : dict
1213 results : dict
1211 There will always be the keys 'pending' and 'completed', which will
1214 There will always be the keys 'pending' and 'completed', which will
1212 be lists of msg_ids that are incomplete or complete. If `status_only`
1215 be lists of msg_ids that are incomplete or complete. If `status_only`
1213 is False, then completed results will be keyed by their `msg_id`.
1216 is False, then completed results will be keyed by their `msg_id`.
1214 """
1217 """
1215 if not isinstance(msg_ids, (list,tuple)):
1218 if not isinstance(msg_ids, (list,tuple)):
1216 msg_ids = [msg_ids]
1219 msg_ids = [msg_ids]
1217
1220
1218 theids = []
1221 theids = []
1219 for msg_id in msg_ids:
1222 for msg_id in msg_ids:
1220 if isinstance(msg_id, int):
1223 if isinstance(msg_id, int):
1221 msg_id = self.history[msg_id]
1224 msg_id = self.history[msg_id]
1222 if not isinstance(msg_id, basestring):
1225 if not isinstance(msg_id, basestring):
1223 raise TypeError("msg_ids must be str, not %r"%msg_id)
1226 raise TypeError("msg_ids must be str, not %r"%msg_id)
1224 theids.append(msg_id)
1227 theids.append(msg_id)
1225
1228
1226 completed = []
1229 completed = []
1227 local_results = {}
1230 local_results = {}
1228
1231
1229 # comment this block out to temporarily disable local shortcut:
1232 # comment this block out to temporarily disable local shortcut:
1230 for msg_id in theids:
1233 for msg_id in theids:
1231 if msg_id in self.results:
1234 if msg_id in self.results:
1232 completed.append(msg_id)
1235 completed.append(msg_id)
1233 local_results[msg_id] = self.results[msg_id]
1236 local_results[msg_id] = self.results[msg_id]
1234 theids.remove(msg_id)
1237 theids.remove(msg_id)
1235
1238
1236 if theids: # some not locally cached
1239 if theids: # some not locally cached
1237 content = dict(msg_ids=theids, status_only=status_only)
1240 content = dict(msg_ids=theids, status_only=status_only)
1238 msg = self.session.send(self._query_socket, "result_request", content=content)
1241 msg = self.session.send(self._query_socket, "result_request", content=content)
1239 zmq.select([self._query_socket], [], [])
1242 zmq.select([self._query_socket], [], [])
1240 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1243 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1241 if self.debug:
1244 if self.debug:
1242 pprint(msg)
1245 pprint(msg)
1243 content = msg['content']
1246 content = msg['content']
1244 if content['status'] != 'ok':
1247 if content['status'] != 'ok':
1245 raise self._unwrap_exception(content)
1248 raise self._unwrap_exception(content)
1246 buffers = msg['buffers']
1249 buffers = msg['buffers']
1247 else:
1250 else:
1248 content = dict(completed=[],pending=[])
1251 content = dict(completed=[],pending=[])
1249
1252
1250 content['completed'].extend(completed)
1253 content['completed'].extend(completed)
1251
1254
1252 if status_only:
1255 if status_only:
1253 return content
1256 return content
1254
1257
1255 failures = []
1258 failures = []
1256 # load cached results into result:
1259 # load cached results into result:
1257 content.update(local_results)
1260 content.update(local_results)
1258
1261
1259 # update cache with results:
1262 # update cache with results:
1260 for msg_id in sorted(theids):
1263 for msg_id in sorted(theids):
1261 if msg_id in content['completed']:
1264 if msg_id in content['completed']:
1262 rec = content[msg_id]
1265 rec = content[msg_id]
1263 parent = rec['header']
1266 parent = rec['header']
1264 header = rec['result_header']
1267 header = rec['result_header']
1265 rcontent = rec['result_content']
1268 rcontent = rec['result_content']
1266 iodict = rec['io']
1269 iodict = rec['io']
1267 if isinstance(rcontent, str):
1270 if isinstance(rcontent, str):
1268 rcontent = self.session.unpack(rcontent)
1271 rcontent = self.session.unpack(rcontent)
1269
1272
1270 md = self.metadata[msg_id]
1273 md = self.metadata[msg_id]
1271 md.update(self._extract_metadata(header, parent, rcontent))
1274 md.update(self._extract_metadata(header, parent, rcontent))
1272 md.update(iodict)
1275 md.update(iodict)
1273
1276
1274 if rcontent['status'] == 'ok':
1277 if rcontent['status'] == 'ok':
1275 res,buffers = util.unserialize_object(buffers)
1278 res,buffers = util.unserialize_object(buffers)
1276 else:
1279 else:
1277 print rcontent
1280 print rcontent
1278 res = self._unwrap_exception(rcontent)
1281 res = self._unwrap_exception(rcontent)
1279 failures.append(res)
1282 failures.append(res)
1280
1283
1281 self.results[msg_id] = res
1284 self.results[msg_id] = res
1282 content[msg_id] = res
1285 content[msg_id] = res
1283
1286
1284 if len(theids) == 1 and failures:
1287 if len(theids) == 1 and failures:
1285 raise failures[0]
1288 raise failures[0]
1286
1289
1287 error.collect_exceptions(failures, "result_status")
1290 error.collect_exceptions(failures, "result_status")
1288 return content
1291 return content
1289
1292
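For example, using msg_ids from an assumed AsyncResult `ar`:

status = rc.result_status(ar.msg_ids)                   # only 'pending'/'completed' lists
full = rc.result_status(ar.msg_ids, status_only=False)  # completed results keyed by msg_id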
1290 @spin_first
1293 @spin_first
1291 def queue_status(self, targets='all', verbose=False):
1294 def queue_status(self, targets='all', verbose=False):
1292 """Fetch the status of engine queues.
1295 """Fetch the status of engine queues.
1293
1296
1294 Parameters
1297 Parameters
1295 ----------
1298 ----------
1296
1299
1297 targets : int/str/list of ints/strs
1300 targets : int/str/list of ints/strs
1298 the engines whose states are to be queried.
1301 the engines whose states are to be queried.
1299 default : all
1302 default : all
1300 verbose : bool
1303 verbose : bool
1301 Whether to return lengths only, or lists of ids for each element
1304 Whether to return lengths only, or lists of ids for each element
1302 """
1305 """
1303 engine_ids = self._build_targets(targets)[1]
1306 engine_ids = self._build_targets(targets)[1]
1304 content = dict(targets=engine_ids, verbose=verbose)
1307 content = dict(targets=engine_ids, verbose=verbose)
1305 self.session.send(self._query_socket, "queue_request", content=content)
1308 self.session.send(self._query_socket, "queue_request", content=content)
1306 idents,msg = self.session.recv(self._query_socket, 0)
1309 idents,msg = self.session.recv(self._query_socket, 0)
1307 if self.debug:
1310 if self.debug:
1308 pprint(msg)
1311 pprint(msg)
1309 content = msg['content']
1312 content = msg['content']
1310 status = content.pop('status')
1313 status = content.pop('status')
1311 if status != 'ok':
1314 if status != 'ok':
1312 raise self._unwrap_exception(content)
1315 raise self._unwrap_exception(content)
1313 content = rekey(content)
1316 content = rekey(content)
1314 if isinstance(targets, int):
1317 if isinstance(targets, int):
1315 return content[targets]
1318 return content[targets]
1316 else:
1319 else:
1317 return content
1320 return content
1318
1321
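For example:

rc.queue_status()                  # per-engine counts of queued, working, and completed tasks
rc.queue_status(0, verbose=True)   # msg_id lists for one engine (engine id assumed)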
1319 @spin_first
1322 @spin_first
1320 def purge_results(self, jobs=[], targets=[]):
1323 def purge_results(self, jobs=[], targets=[]):
1321 """Tell the Hub to forget results.
1324 """Tell the Hub to forget results.
1322
1325
1323 Individual results can be purged by msg_id, or the entire
1326 Individual results can be purged by msg_id, or the entire
1324 history of specific targets can be purged.
1327 history of specific targets can be purged.
1325
1328
1326 Use `purge_results('all')` to scrub everything from the Hub's db.
1329 Use `purge_results('all')` to scrub everything from the Hub's db.
1327
1330
1328 Parameters
1331 Parameters
1329 ----------
1332 ----------
1330
1333
1331 jobs : str or list of str or AsyncResult objects
1334 jobs : str or list of str or AsyncResult objects
1332 the msg_ids whose results should be forgotten.
1335 the msg_ids whose results should be forgotten.
1333 targets : int/str/list of ints/strs
1336 targets : int/str/list of ints/strs
1334 The targets, by int_id, whose entire history is to be purged.
1337 The targets, by int_id, whose entire history is to be purged.
1335
1338
1336 default : None
1339 default : None
1337 """
1340 """
1338 if not targets and not jobs:
1341 if not targets and not jobs:
1339 raise ValueError("Must specify at least one of `targets` and `jobs`")
1342 raise ValueError("Must specify at least one of `targets` and `jobs`")
1340 if targets:
1343 if targets:
1341 targets = self._build_targets(targets)[1]
1344 targets = self._build_targets(targets)[1]
1342
1345
1343 # construct msg_ids from jobs
1346 # construct msg_ids from jobs
1344 if jobs == 'all':
1347 if jobs == 'all':
1345 msg_ids = jobs
1348 msg_ids = jobs
1346 else:
1349 else:
1347 msg_ids = []
1350 msg_ids = []
1348 if isinstance(jobs, (basestring,AsyncResult)):
1351 if isinstance(jobs, (basestring,AsyncResult)):
1349 jobs = [jobs]
1352 jobs = [jobs]
1350 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1353 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1351 if bad_ids:
1354 if bad_ids:
1352 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1355 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1353 for j in jobs:
1356 for j in jobs:
1354 if isinstance(j, AsyncResult):
1357 if isinstance(j, AsyncResult):
1355 msg_ids.extend(j.msg_ids)
1358 msg_ids.extend(j.msg_ids)
1356 else:
1359 else:
1357 msg_ids.append(j)
1360 msg_ids.append(j)
1358
1361
1359 content = dict(engine_ids=targets, msg_ids=msg_ids)
1362 content = dict(engine_ids=targets, msg_ids=msg_ids)
1360 self.session.send(self._query_socket, "purge_request", content=content)
1363 self.session.send(self._query_socket, "purge_request", content=content)
1361 idents, msg = self.session.recv(self._query_socket, 0)
1364 idents, msg = self.session.recv(self._query_socket, 0)
1362 if self.debug:
1365 if self.debug:
1363 pprint(msg)
1366 pprint(msg)
1364 content = msg['content']
1367 content = msg['content']
1365 if content['status'] != 'ok':
1368 if content['status'] != 'ok':
1366 raise self._unwrap_exception(content)
1369 raise self._unwrap_exception(content)
1367
1370
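For example (engine id assumed):

rc.purge_results(jobs=ar)       # forget the results behind one AsyncResult
rc.purge_results(targets=[0])   # forget the entire history of engine 0
rc.purge_results('all')         # scrub everything from the Hub's database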
1368 @spin_first
1371 @spin_first
1369 def hub_history(self):
1372 def hub_history(self):
1370 """Get the Hub's history
1373 """Get the Hub's history
1371
1374
1372 Just like the Client, the Hub has a history, which is a list of msg_ids.
1375 Just like the Client, the Hub has a history, which is a list of msg_ids.
1373 This will contain the history of all clients, and, depending on configuration,
1376 This will contain the history of all clients, and, depending on configuration,
1374 may contain history across multiple cluster sessions.
1377 may contain history across multiple cluster sessions.
1375
1378
1376 Any msg_id returned here is a valid argument to `get_result`.
1379 Any msg_id returned here is a valid argument to `get_result`.
1377
1380
1378 Returns
1381 Returns
1379 -------
1382 -------
1380
1383
1381 msg_ids : list of strs
1384 msg_ids : list of strs
1382 list of all msg_ids, ordered by task submission time.
1385 list of all msg_ids, ordered by task submission time.
1383 """
1386 """
1384
1387
1385 self.session.send(self._query_socket, "history_request", content={})
1388 self.session.send(self._query_socket, "history_request", content={})
1386 idents, msg = self.session.recv(self._query_socket, 0)
1389 idents, msg = self.session.recv(self._query_socket, 0)
1387
1390
1388 if self.debug:
1391 if self.debug:
1389 pprint(msg)
1392 pprint(msg)
1390 content = msg['content']
1393 content = msg['content']
1391 if content['status'] != 'ok':
1394 if content['status'] != 'ok':
1392 raise self._unwrap_exception(content)
1395 raise self._unwrap_exception(content)
1393 else:
1396 else:
1394 return content['history']
1397 return content['history']
1395
1398
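For example:

all_ids = rc.hub_history()      # every msg_id the Hub knows about, by submission time
recent = all_ids[-10:]          # e.g. the ten most recent tasks
ar = rc.get_result(recent)      # any of these ids is a valid get_result argument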
1396 @spin_first
1399 @spin_first
1397 def db_query(self, query, keys=None):
1400 def db_query(self, query, keys=None):
1398 """Query the Hub's TaskRecord database
1401 """Query the Hub's TaskRecord database
1399
1402
1400 This will return a list of task record dicts that match `query`
1403 This will return a list of task record dicts that match `query`
1401
1404
1402 Parameters
1405 Parameters
1403 ----------
1406 ----------
1404
1407
1405 query : mongodb query dict
1408 query : mongodb query dict
1406 The search dict. See mongodb query docs for details.
1409 The search dict. See mongodb query docs for details.
1407 keys : list of strs [optional]
1410 keys : list of strs [optional]
1408 The subset of keys to be returned. The default is to fetch everything but buffers.
1411 The subset of keys to be returned. The default is to fetch everything but buffers.
1409 'msg_id' will *always* be included.
1412 'msg_id' will *always* be included.
1410 """
1413 """
1411 if isinstance(keys, basestring):
1414 if isinstance(keys, basestring):
1412 keys = [keys]
1415 keys = [keys]
1413 content = dict(query=query, keys=keys)
1416 content = dict(query=query, keys=keys)
1414 self.session.send(self._query_socket, "db_request", content=content)
1417 self.session.send(self._query_socket, "db_request", content=content)
1415 idents, msg = self.session.recv(self._query_socket, 0)
1418 idents, msg = self.session.recv(self._query_socket, 0)
1416 if self.debug:
1419 if self.debug:
1417 pprint(msg)
1420 pprint(msg)
1418 content = msg['content']
1421 content = msg['content']
1419 if content['status'] != 'ok':
1422 if content['status'] != 'ok':
1420 raise self._unwrap_exception(content)
1423 raise self._unwrap_exception(content)
1421
1424
1422 records = content['records']
1425 records = content['records']
1423
1426
1424 buffer_lens = content['buffer_lens']
1427 buffer_lens = content['buffer_lens']
1425 result_buffer_lens = content['result_buffer_lens']
1428 result_buffer_lens = content['result_buffer_lens']
1426 buffers = msg['buffers']
1429 buffers = msg['buffers']
1427 has_bufs = buffer_lens is not None
1430 has_bufs = buffer_lens is not None
1428 has_rbufs = result_buffer_lens is not None
1431 has_rbufs = result_buffer_lens is not None
1429 for i,rec in enumerate(records):
1432 for i,rec in enumerate(records):
1430 # relink buffers
1433 # relink buffers
1431 if has_bufs:
1434 if has_bufs:
1432 blen = buffer_lens[i]
1435 blen = buffer_lens[i]
1433 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1436 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1434 if has_rbufs:
1437 if has_rbufs:
1435 blen = result_buffer_lens[i]
1438 blen = result_buffer_lens[i]
1436 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1439 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1437
1440
1438 return records
1441 return records
1439
1442
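An illustrative query against the Hub's TaskRecord database; the msg_ids and the chosen keys are examples of standard TaskRecord fields, not values from this diff:

recs = rc.db_query({'msg_id': {'$in': ar.msg_ids}},
                   keys=['msg_id', 'completed', 'engine_uuid'])
finished_at = dict((rec['msg_id'], rec['completed']) for rec in recs)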
1440 __all__ = [ 'Client' ]
1443 __all__ = [ 'Client' ]
@@ -1,1057 +1,1059 b''
1 """Views of remote engines.
1 """Views of remote engines.
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import imp
18 import imp
19 import sys
19 import sys
20 import warnings
20 import warnings
21 from contextlib import contextmanager
21 from contextlib import contextmanager
22 from types import ModuleType
22 from types import ModuleType
23
23
24 import zmq
24 import zmq
25
25
26 from IPython.testing.skipdoctest import skip_doctest
26 from IPython.testing.skipdoctest import skip_doctest
27 from IPython.utils.traitlets import (
27 from IPython.utils.traitlets import (
28 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
28 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
29 )
29 )
30 from IPython.external.decorator import decorator
30 from IPython.external.decorator import decorator
31
31
32 from IPython.parallel import util
32 from IPython.parallel import util
33 from IPython.parallel.controller.dependency import Dependency, dependent
33 from IPython.parallel.controller.dependency import Dependency, dependent
34
34
35 from . import map as Map
35 from . import map as Map
36 from .asyncresult import AsyncResult, AsyncMapResult
36 from .asyncresult import AsyncResult, AsyncMapResult
37 from .remotefunction import ParallelFunction, parallel, remote
37 from .remotefunction import ParallelFunction, parallel, remote
38
38
39 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
40 # Decorators
40 # Decorators
41 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
42
42
43 @decorator
43 @decorator
44 def save_ids(f, self, *args, **kwargs):
44 def save_ids(f, self, *args, **kwargs):
45 """Keep our history and outstanding attributes up to date after a method call."""
45 """Keep our history and outstanding attributes up to date after a method call."""
46 n_previous = len(self.client.history)
46 n_previous = len(self.client.history)
47 try:
47 try:
48 ret = f(self, *args, **kwargs)
48 ret = f(self, *args, **kwargs)
49 finally:
49 finally:
50 nmsgs = len(self.client.history) - n_previous
50 nmsgs = len(self.client.history) - n_previous
51 msg_ids = self.client.history[-nmsgs:]
51 msg_ids = self.client.history[-nmsgs:]
52 self.history.extend(msg_ids)
52 self.history.extend(msg_ids)
53 map(self.outstanding.add, msg_ids)
53 map(self.outstanding.add, msg_ids)
54 return ret
54 return ret
55
55
56 @decorator
56 @decorator
57 def sync_results(f, self, *args, **kwargs):
57 def sync_results(f, self, *args, **kwargs):
58 """sync relevant results from self.client to our results attribute."""
58 """sync relevant results from self.client to our results attribute."""
59 ret = f(self, *args, **kwargs)
59 ret = f(self, *args, **kwargs)
60 delta = self.outstanding.difference(self.client.outstanding)
60 delta = self.outstanding.difference(self.client.outstanding)
61 completed = self.outstanding.intersection(delta)
61 completed = self.outstanding.intersection(delta)
62 self.outstanding = self.outstanding.difference(completed)
62 self.outstanding = self.outstanding.difference(completed)
63 for msg_id in completed:
63 for msg_id in completed:
64 self.results[msg_id] = self.client.results[msg_id]
64 self.results[msg_id] = self.client.results[msg_id]
65 return ret
65 return ret
66
66
67 @decorator
67 @decorator
68 def spin_after(f, self, *args, **kwargs):
68 def spin_after(f, self, *args, **kwargs):
69 """call spin after the method."""
69 """call spin after the method."""
70 ret = f(self, *args, **kwargs)
70 ret = f(self, *args, **kwargs)
71 self.spin()
71 self.spin()
72 return ret
72 return ret
73
73
74 #-----------------------------------------------------------------------------
74 #-----------------------------------------------------------------------------
75 # Classes
75 # Classes
76 #-----------------------------------------------------------------------------
76 #-----------------------------------------------------------------------------
77
77
78 @skip_doctest
78 @skip_doctest
79 class View(HasTraits):
79 class View(HasTraits):
80 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
80 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
81
81
82 Don't use this class, use subclasses.
82 Don't use this class, use subclasses.
83
83
84 Methods
84 Methods
85 -------
85 -------
86
86
87 spin
87 spin
88 flushes incoming results and registration state changes
88 flushes incoming results and registration state changes
89 control methods spin, and requesting `ids` also keeps the registration state up to date
89 control methods spin, and requesting `ids` also keeps the registration state up to date
90
90
91 wait
91 wait
92 wait on one or more msg_ids
92 wait on one or more msg_ids
93
93
94 execution methods
94 execution methods
95 apply
95 apply
96 legacy: execute, run
96 legacy: execute, run
97
97
98 data movement
98 data movement
99 push, pull, scatter, gather
99 push, pull, scatter, gather
100
100
101 query methods
101 query methods
102 get_result, queue_status, purge_results, result_status
102 get_result, queue_status, purge_results, result_status
103
103
104 control methods
104 control methods
105 abort, shutdown
105 abort, shutdown
106
106
107 """
107 """
108 # flags
108 # flags
109 block=Bool(False)
109 block=Bool(False)
110 track=Bool(True)
110 track=Bool(True)
111 targets = Any()
111 targets = Any()
112
112
113 history=List()
113 history=List()
114 outstanding = Set()
114 outstanding = Set()
115 results = Dict()
115 results = Dict()
116 client = Instance('IPython.parallel.Client')
116 client = Instance('IPython.parallel.Client')
117
117
118 _socket = Instance('zmq.Socket')
118 _socket = Instance('zmq.Socket')
119 _flag_names = List(['targets', 'block', 'track'])
119 _flag_names = List(['targets', 'block', 'track'])
120 _targets = Any()
120 _targets = Any()
121 _idents = Any()
121 _idents = Any()
122
122
123 def __init__(self, client=None, socket=None, **flags):
123 def __init__(self, client=None, socket=None, **flags):
124 super(View, self).__init__(client=client, _socket=socket)
124 super(View, self).__init__(client=client, _socket=socket)
125 self.block = client.block
125 self.block = client.block
126
126
127 self.set_flags(**flags)
127 self.set_flags(**flags)
128
128
129 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
129 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
130
130
131
131
132 def __repr__(self):
132 def __repr__(self):
133 strtargets = str(self.targets)
133 strtargets = str(self.targets)
134 if len(strtargets) > 16:
134 if len(strtargets) > 16:
135 strtargets = strtargets[:12]+'...]'
135 strtargets = strtargets[:12]+'...]'
136 return "<%s %s>"%(self.__class__.__name__, strtargets)
136 return "<%s %s>"%(self.__class__.__name__, strtargets)
137
137
138 def set_flags(self, **kwargs):
138 def set_flags(self, **kwargs):
139 """set my attribute flags by keyword.
139 """set my attribute flags by keyword.
140
140
141 Views determine behavior with a few attributes (`block`, `track`, etc.).
141 Views determine behavior with a few attributes (`block`, `track`, etc.).
142 These attributes can be set all at once by name with this method.
142 These attributes can be set all at once by name with this method.
143
143
144 Parameters
144 Parameters
145 ----------
145 ----------
146
146
147 block : bool
147 block : bool
148 whether to wait for results
148 whether to wait for results
149 track : bool
149 track : bool
150 whether to create a MessageTracker to allow the user to
150 whether to create a MessageTracker to allow the user to
151 safely edit arrays and buffers after non-copying
151 safely edit arrays and buffers after non-copying
152 sends.
152 sends.
153 """
153 """
154 for name, value in kwargs.iteritems():
154 for name, value in kwargs.iteritems():
155 if name not in self._flag_names:
155 if name not in self._flag_names:
156 raise KeyError("Invalid name: %r"%name)
156 raise KeyError("Invalid name: %r"%name)
157 else:
157 else:
158 setattr(self, name, value)
158 setattr(self, name, value)
159
159
160 @contextmanager
160 @contextmanager
161 def temp_flags(self, **kwargs):
161 def temp_flags(self, **kwargs):
162 """temporarily set flags, for use in `with` statements.
162 """temporarily set flags, for use in `with` statements.
163
163
164 See set_flags for permanent setting of flags
164 See set_flags for permanent setting of flags
165
165
166 Examples
166 Examples
167 --------
167 --------
168
168
169 >>> view.track=False
169 >>> view.track=False
170 ...
170 ...
171 >>> with view.temp_flags(track=True):
171 >>> with view.temp_flags(track=True):
172 ... ar = view.apply(dostuff, my_big_array)
172 ... ar = view.apply(dostuff, my_big_array)
173 ... ar.tracker.wait() # wait for send to finish
173 ... ar.tracker.wait() # wait for send to finish
174 >>> view.track
174 >>> view.track
175 False
175 False
176
176
177 """
177 """
178 # preflight: save flags, and set temporaries
178 # preflight: save flags, and set temporaries
179 saved_flags = {}
179 saved_flags = {}
180 for f in self._flag_names:
180 for f in self._flag_names:
181 saved_flags[f] = getattr(self, f)
181 saved_flags[f] = getattr(self, f)
182 self.set_flags(**kwargs)
182 self.set_flags(**kwargs)
183 # yield to the with-statement block
183 # yield to the with-statement block
184 try:
184 try:
185 yield
185 yield
186 finally:
186 finally:
187 # postflight: restore saved flags
187 # postflight: restore saved flags
188 self.set_flags(**saved_flags)
188 self.set_flags(**saved_flags)
189
189
190
190
191 #----------------------------------------------------------------
191 #----------------------------------------------------------------
192 # apply
192 # apply
193 #----------------------------------------------------------------
193 #----------------------------------------------------------------
194
194
195 @sync_results
195 @sync_results
196 @save_ids
196 @save_ids
197 def _really_apply(self, f, args, kwargs, block=None, **options):
197 def _really_apply(self, f, args, kwargs, block=None, **options):
198 """wrapper for client.send_apply_message"""
198 """wrapper for client.send_apply_message"""
199 raise NotImplementedError("Implement in subclasses")
199 raise NotImplementedError("Implement in subclasses")
200
200
201 def apply(self, f, *args, **kwargs):
201 def apply(self, f, *args, **kwargs):
202 """calls f(*args, **kwargs) on remote engines, returning the result.
202 """calls f(*args, **kwargs) on remote engines, returning the result.
203
203
204 This method sets all apply flags via this View's attributes.
204 This method sets all apply flags via this View's attributes.
205
205
206 if self.block is False:
206 if self.block is False:
207 returns AsyncResult
207 returns AsyncResult
208 else:
208 else:
209 returns actual result of f(*args, **kwargs)
209 returns actual result of f(*args, **kwargs)
210 """
210 """
211 return self._really_apply(f, args, kwargs)
211 return self._really_apply(f, args, kwargs)
212
212
213 def apply_async(self, f, *args, **kwargs):
213 def apply_async(self, f, *args, **kwargs):
214 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
214 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
215
215
216 returns AsyncResult
216 returns AsyncResult
217 """
217 """
218 return self._really_apply(f, args, kwargs, block=False)
218 return self._really_apply(f, args, kwargs, block=False)
219
219
220 @spin_after
220 @spin_after
221 def apply_sync(self, f, *args, **kwargs):
221 def apply_sync(self, f, *args, **kwargs):
222 """calls f(*args, **kwargs) on remote engines in a blocking manner,
222 """calls f(*args, **kwargs) on remote engines in a blocking manner,
223 returning the result.
223 returning the result.
224
224
225 returns: actual result of f(*args, **kwargs)
225 returns: actual result of f(*args, **kwargs)
226 """
226 """
227 return self._really_apply(f, args, kwargs, block=True)
227 return self._really_apply(f, args, kwargs, block=True)
228
228
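A minimal sketch of the three apply flavours on a DirectView, assuming a connected Client `rc` as in the earlier sketches:

import os
dv = rc[:]                       # DirectView on the currently-registered engines
pids = dv.apply_sync(os.getpid)  # block and return the results
ar = dv.apply_async(os.getpid)   # return an AsyncResult immediately
dv.block = True
dv.apply(os.getpid)              # apply() honours the view's `block` flag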
229 #----------------------------------------------------------------
229 #----------------------------------------------------------------
230 # wrappers for client and control methods
230 # wrappers for client and control methods
231 #----------------------------------------------------------------
231 #----------------------------------------------------------------
232 @sync_results
232 @sync_results
233 def spin(self):
233 def spin(self):
234 """spin the client, and sync"""
234 """spin the client, and sync"""
235 self.client.spin()
235 self.client.spin()
236
236
237 @sync_results
237 @sync_results
238 def wait(self, jobs=None, timeout=-1):
238 def wait(self, jobs=None, timeout=-1):
239 """waits on one or more `jobs`, for up to `timeout` seconds.
239 """waits on one or more `jobs`, for up to `timeout` seconds.
240
240
241 Parameters
241 Parameters
242 ----------
242 ----------
243
243
244 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
244 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
245 ints are indices to self.history
245 ints are indices to self.history
246 strs are msg_ids
246 strs are msg_ids
247 default: wait on all outstanding messages
247 default: wait on all outstanding messages
248 timeout : float
248 timeout : float
249 a time in seconds, after which to give up.
249 a time in seconds, after which to give up.
250 default is -1, which means no timeout
250 default is -1, which means no timeout
251
251
252 Returns
252 Returns
253 -------
253 -------
254
254
255 True : when all msg_ids are done
255 True : when all msg_ids are done
256 False : timeout reached, some msg_ids still outstanding
256 False : timeout reached, some msg_ids still outstanding
257 """
257 """
258 if jobs is None:
258 if jobs is None:
259 jobs = self.history
259 jobs = self.history
260 return self.client.wait(jobs, timeout)
260 return self.client.wait(jobs, timeout)
261
261
262 def abort(self, jobs=None, targets=None, block=None):
262 def abort(self, jobs=None, targets=None, block=None):
263 """Abort jobs on my engines.
263 """Abort jobs on my engines.
264
264
265 Parameters
265 Parameters
266 ----------
266 ----------
267
267
268 jobs : None, str, list of strs, optional
268 jobs : None, str, list of strs, optional
269 if None: abort all jobs.
269 if None: abort all jobs.
270 else: abort specific msg_id(s).
270 else: abort specific msg_id(s).
271 """
271 """
272 block = block if block is not None else self.block
272 block = block if block is not None else self.block
273 targets = targets if targets is not None else self.targets
273 targets = targets if targets is not None else self.targets
274 jobs = jobs if jobs is not None else list(self.outstanding)
275
274 return self.client.abort(jobs=jobs, targets=targets, block=block)
276 return self.client.abort(jobs=jobs, targets=targets, block=block)
275
277
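A sketch of the behaviour this change enables: with no `jobs` argument, abort() now targets all of the view's outstanding tasks (the sleeping job is only illustrative):

import time
ar = dv.apply_async(time.sleep, 10)   # something likely still queued
dv.abort(ar.msg_ids)                  # abort specific msg_ids
dv.abort()                            # jobs=None: abort everything still outstanding on this view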
276 def queue_status(self, targets=None, verbose=False):
278 def queue_status(self, targets=None, verbose=False):
277 """Fetch the Queue status of my engines"""
279 """Fetch the Queue status of my engines"""
278 targets = targets if targets is not None else self.targets
280 targets = targets if targets is not None else self.targets
279 return self.client.queue_status(targets=targets, verbose=verbose)
281 return self.client.queue_status(targets=targets, verbose=verbose)
280
282
281 def purge_results(self, jobs=[], targets=[]):
283 def purge_results(self, jobs=[], targets=[]):
282 """Instruct the controller to forget specific results."""
284 """Instruct the controller to forget specific results."""
283 if targets is None or targets == 'all':
285 if targets is None or targets == 'all':
284 targets = self.targets
286 targets = self.targets
285 return self.client.purge_results(jobs=jobs, targets=targets)
287 return self.client.purge_results(jobs=jobs, targets=targets)
286
288
287 def shutdown(self, targets=None, restart=False, hub=False, block=None):
289 def shutdown(self, targets=None, restart=False, hub=False, block=None):
288 """Terminates one or more engine processes, optionally including the hub.
290 """Terminates one or more engine processes, optionally including the hub.
289 """
291 """
290 block = self.block if block is None else block
292 block = self.block if block is None else block
291 if targets is None or targets == 'all':
293 if targets is None or targets == 'all':
292 targets = self.targets
294 targets = self.targets
293 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
295 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
294
296
295 @spin_after
297 @spin_after
296 def get_result(self, indices_or_msg_ids=None):
298 def get_result(self, indices_or_msg_ids=None):
297 """return one or more results, specified by history index or msg_id.
299 """return one or more results, specified by history index or msg_id.
298
300
299 See client.get_result for details.
301 See client.get_result for details.
300
302
301 """
303 """
302
304
303 if indices_or_msg_ids is None:
305 if indices_or_msg_ids is None:
304 indices_or_msg_ids = -1
306 indices_or_msg_ids = -1
305 if isinstance(indices_or_msg_ids, int):
307 if isinstance(indices_or_msg_ids, int):
306 indices_or_msg_ids = self.history[indices_or_msg_ids]
308 indices_or_msg_ids = self.history[indices_or_msg_ids]
307 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
309 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
308 indices_or_msg_ids = list(indices_or_msg_ids)
310 indices_or_msg_ids = list(indices_or_msg_ids)
309 for i,index in enumerate(indices_or_msg_ids):
311 for i,index in enumerate(indices_or_msg_ids):
310 if isinstance(index, int):
312 if isinstance(index, int):
311 indices_or_msg_ids[i] = self.history[index]
313 indices_or_msg_ids[i] = self.history[index]
312 return self.client.get_result(indices_or_msg_ids)
314 return self.client.get_result(indices_or_msg_ids)
313
315
314 #-------------------------------------------------------------------
316 #-------------------------------------------------------------------
315 # Map
317 # Map
316 #-------------------------------------------------------------------
318 #-------------------------------------------------------------------
317
319
318 def map(self, f, *sequences, **kwargs):
320 def map(self, f, *sequences, **kwargs):
319 """override in subclasses"""
321 """override in subclasses"""
320 raise NotImplementedError
322 raise NotImplementedError
321
323
322 def map_async(self, f, *sequences, **kwargs):
324 def map_async(self, f, *sequences, **kwargs):
323 """Parallel version of builtin `map`, using this view's engines.
325 """Parallel version of builtin `map`, using this view's engines.
324
326
325 This is equivalent to map(..., block=False).
327 This is equivalent to map(..., block=False).
326
328
327 See `self.map` for details.
329 See `self.map` for details.
328 """
330 """
329 if 'block' in kwargs:
331 if 'block' in kwargs:
330 raise TypeError("map_async doesn't take a `block` keyword argument.")
332 raise TypeError("map_async doesn't take a `block` keyword argument.")
331 kwargs['block'] = False
333 kwargs['block'] = False
332 return self.map(f,*sequences,**kwargs)
334 return self.map(f,*sequences,**kwargs)
333
335
334 def map_sync(self, f, *sequences, **kwargs):
336 def map_sync(self, f, *sequences, **kwargs):
335 """Parallel version of builtin `map`, using this view's engines.
337 """Parallel version of builtin `map`, using this view's engines.
336
338
337 This is equivalent to map(..., block=True).
339 This is equivalent to map(..., block=True).
338
340
339 See `self.map` for details.
341 See `self.map` for details.
340 """
342 """
341 if 'block' in kwargs:
343 if 'block' in kwargs:
342 raise TypeError("map_sync doesn't take a `block` keyword argument.")
344 raise TypeError("map_sync doesn't take a `block` keyword argument.")
343 kwargs['block'] = True
345 kwargs['block'] = True
344 return self.map(f,*sequences,**kwargs)
346 return self.map(f,*sequences,**kwargs)
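As a brief illustration of the block-flag relationship described above (a sketch, assuming a running cluster):

    from IPython.parallel import Client

    rc = Client()                         # assumes a running ipcluster
    dv = rc[:]

    # blocking: returns the result list directly
    squares = dv.map_sync(lambda x: x * x, range(8))

    # non-blocking: returns an AsyncMapResult immediately
    amr = dv.map_async(lambda x: x * x, range(8))
    squares_again = amr.get()             # same values, fetched later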
345
347
346 def imap(self, f, *sequences, **kwargs):
348 def imap(self, f, *sequences, **kwargs):
347 """Parallel version of `itertools.imap`.
349 """Parallel version of `itertools.imap`.
348
350
349 See `self.map` for details.
351 See `self.map` for details.
350
352
351 """
353 """
352
354
353 return iter(self.map_async(f,*sequences, **kwargs))
355 return iter(self.map_async(f,*sequences, **kwargs))
354
356
355 #-------------------------------------------------------------------
357 #-------------------------------------------------------------------
356 # Decorators
358 # Decorators
357 #-------------------------------------------------------------------
359 #-------------------------------------------------------------------
358
360
359 def remote(self, block=True, **flags):
361 def remote(self, block=True, **flags):
360 """Decorator for making a RemoteFunction"""
362 """Decorator for making a RemoteFunction"""
361 block = self.block if block is None else block
363 block = self.block if block is None else block
362 return remote(self, block=block, **flags)
364 return remote(self, block=block, **flags)
363
365
364 def parallel(self, dist='b', block=None, **flags):
366 def parallel(self, dist='b', block=None, **flags):
365 """Decorator for making a ParallelFunction"""
367 """Decorator for making a ParallelFunction"""
366 block = self.block if block is None else block
368 block = self.block if block is None else block
367 return parallel(self, dist=dist, block=block, **flags)
369 return parallel(self, dist=dist, block=block, **flags)
368
370
369 @skip_doctest
371 @skip_doctest
370 class DirectView(View):
372 class DirectView(View):
371 """Direct Multiplexer View of one or more engines.
373 """Direct Multiplexer View of one or more engines.
372
374
373 These are created via indexed access to a client:
375 These are created via indexed access to a client:
374
376
375 >>> dv_1 = client[1]
377 >>> dv_1 = client[1]
376 >>> dv_all = client[:]
378 >>> dv_all = client[:]
377 >>> dv_even = client[::2]
379 >>> dv_even = client[::2]
378 >>> dv_some = client[1:3]
380 >>> dv_some = client[1:3]
379
381
380 This object provides dictionary access to engine namespaces:
382 This object provides dictionary access to engine namespaces:
381
383
382 # push a=5:
384 # push a=5:
383 >>> dv['a'] = 5
385 >>> dv['a'] = 5
384 # pull 'foo':
386 # pull 'foo':
385 >>> dv['foo']
387 >>> dv['foo']
386
388
387 """
389 """
388
390
389 def __init__(self, client=None, socket=None, targets=None):
391 def __init__(self, client=None, socket=None, targets=None):
390 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
392 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
391
393
392 @property
394 @property
393 def importer(self):
395 def importer(self):
394 """sync_imports(local=True) as a property.
396 """sync_imports(local=True) as a property.
395
397
396 See sync_imports for details.
398 See sync_imports for details.
397
399
398 """
400 """
399 return self.sync_imports(True)
401 return self.sync_imports(True)
400
402
401 @contextmanager
403 @contextmanager
402 def sync_imports(self, local=True):
404 def sync_imports(self, local=True):
403 """Context Manager for performing simultaneous local and remote imports.
405 """Context Manager for performing simultaneous local and remote imports.
404
406
405 'import x as y' will *not* work. The 'as y' part will simply be ignored.
407 'import x as y' will *not* work. The 'as y' part will simply be ignored.
406
408
407 >>> with view.sync_imports():
409 >>> with view.sync_imports():
408 ... from numpy import recarray
410 ... from numpy import recarray
409 importing recarray from numpy on engine(s)
411 importing recarray from numpy on engine(s)
410
412
411 """
413 """
412 import __builtin__
414 import __builtin__
413 local_import = __builtin__.__import__
415 local_import = __builtin__.__import__
414 modules = set()
416 modules = set()
415 results = []
417 results = []
416 @util.interactive
418 @util.interactive
417 def remote_import(name, fromlist, level):
419 def remote_import(name, fromlist, level):
418 """the function to be passed to apply, that actually performs the import
420 """the function to be passed to apply, that actually performs the import
419 on the engine, and loads up the user namespace.
421 on the engine, and loads up the user namespace.
420 """
422 """
421 import sys
423 import sys
422 user_ns = globals()
424 user_ns = globals()
423 mod = __import__(name, fromlist=fromlist, level=level)
425 mod = __import__(name, fromlist=fromlist, level=level)
424 if fromlist:
426 if fromlist:
425 for key in fromlist:
427 for key in fromlist:
426 user_ns[key] = getattr(mod, key)
428 user_ns[key] = getattr(mod, key)
427 else:
429 else:
428 user_ns[name] = sys.modules[name]
430 user_ns[name] = sys.modules[name]
429
431
430 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
432 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
431 """the drop-in replacement for __import__, that optionally imports
433 """the drop-in replacement for __import__, that optionally imports
432 locally as well.
434 locally as well.
433 """
435 """
434 # don't override nested imports
436 # don't override nested imports
435 save_import = __builtin__.__import__
437 save_import = __builtin__.__import__
436 __builtin__.__import__ = local_import
438 __builtin__.__import__ = local_import
437
439
438 if imp.lock_held():
440 if imp.lock_held():
439 # this is a side-effect import, don't do it remotely, or even
441 # this is a side-effect import, don't do it remotely, or even
440 # ignore the local effects
442 # ignore the local effects
441 return local_import(name, globals, locals, fromlist, level)
443 return local_import(name, globals, locals, fromlist, level)
442
444
443 imp.acquire_lock()
445 imp.acquire_lock()
444 if local:
446 if local:
445 mod = local_import(name, globals, locals, fromlist, level)
447 mod = local_import(name, globals, locals, fromlist, level)
446 else:
448 else:
447 raise NotImplementedError("remote-only imports not yet implemented")
449 raise NotImplementedError("remote-only imports not yet implemented")
448 imp.release_lock()
450 imp.release_lock()
449
451
450 key = name+':'+','.join(fromlist or [])
452 key = name+':'+','.join(fromlist or [])
451 if level == -1 and key not in modules:
453 if level == -1 and key not in modules:
452 modules.add(key)
454 modules.add(key)
453 if fromlist:
455 if fromlist:
454 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
456 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
455 else:
457 else:
456 print "importing %s on engine(s)"%name
458 print "importing %s on engine(s)"%name
457 results.append(self.apply_async(remote_import, name, fromlist, level))
459 results.append(self.apply_async(remote_import, name, fromlist, level))
458 # restore override
460 # restore override
459 __builtin__.__import__ = save_import
461 __builtin__.__import__ = save_import
460
462
461 return mod
463 return mod
462
464
463 # override __import__
465 # override __import__
464 __builtin__.__import__ = view_import
466 __builtin__.__import__ = view_import
465 try:
467 try:
466 # enter the block
468 # enter the block
467 yield
469 yield
468 except ImportError:
470 except ImportError:
469 if not local:
471 if not local:
470 # ignore import errors if not doing local imports
472 # ignore import errors if not doing local imports
471 pass
473 pass
472 finally:
474 finally:
473 # always restore __import__
475 # always restore __import__
474 __builtin__.__import__ = local_import
476 __builtin__.__import__ = local_import
475
477
476 for r in results:
478 for r in results:
477 # raise possible remote ImportErrors here
479 # raise possible remote ImportErrors here
478 r.get()
480 r.get()
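A small sketch of the context manager in use (assuming a running cluster); note the caveat above that 'import x as y' silently drops the 'as y' part.

    from IPython.parallel import Client

    rc = Client()                 # assumes a running ipcluster
    dv = rc[:]

    # import here and on every engine at the same time
    with dv.sync_imports():
        import numpy
        from numpy import linalg
    # any remote ImportError is re-raised when the block exits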
479
481
480
482
481 @sync_results
483 @sync_results
482 @save_ids
484 @save_ids
483 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
485 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
484 """calls f(*args, **kwargs) on remote engines, returning the result.
486 """calls f(*args, **kwargs) on remote engines, returning the result.
485
487
486 This method sets all of `apply`'s flags via this View's attributes.
488 This method sets all of `apply`'s flags via this View's attributes.
487
489
488 Parameters
490 Parameters
489 ----------
491 ----------
490
492
491 f : callable
493 f : callable
492
494
493 args : list [default: empty]
495 args : list [default: empty]
494
496
495 kwargs : dict [default: empty]
497 kwargs : dict [default: empty]
496
498
497 targets : target list [default: self.targets]
499 targets : target list [default: self.targets]
498 where to run
500 where to run
499 block : bool [default: self.block]
501 block : bool [default: self.block]
500 whether to block
502 whether to block
501 track : bool [default: self.track]
503 track : bool [default: self.track]
502 whether to ask zmq to track the message, for safe non-copying sends
504 whether to ask zmq to track the message, for safe non-copying sends
503
505
504 Returns
506 Returns
505 -------
507 -------
506
508
507 if self.block is False:
509 if self.block is False:
508 returns AsyncResult
510 returns AsyncResult
509 else:
511 else:
510 returns actual result of f(*args, **kwargs) on the engine(s)
512 returns actual result of f(*args, **kwargs) on the engine(s)
511 This will be a list if self.targets is also a list (even length 1), or
513 This will be a list if self.targets is also a list (even length 1), or
512 the single result if self.targets is an integer engine id
514 the single result if self.targets is an integer engine id
513 """
515 """
514 args = [] if args is None else args
516 args = [] if args is None else args
515 kwargs = {} if kwargs is None else kwargs
517 kwargs = {} if kwargs is None else kwargs
516 block = self.block if block is None else block
518 block = self.block if block is None else block
517 track = self.track if track is None else track
519 track = self.track if track is None else track
518 targets = self.targets if targets is None else targets
520 targets = self.targets if targets is None else targets
519
521
520 _idents = self.client._build_targets(targets)[0]
522 _idents = self.client._build_targets(targets)[0]
521 msg_ids = []
523 msg_ids = []
522 trackers = []
524 trackers = []
523 for ident in _idents:
525 for ident in _idents:
524 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
526 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
525 ident=ident)
527 ident=ident)
526 if track:
528 if track:
527 trackers.append(msg['tracker'])
529 trackers.append(msg['tracker'])
528 msg_ids.append(msg['header']['msg_id'])
530 msg_ids.append(msg['header']['msg_id'])
529 tracker = None if track is False else zmq.MessageTracker(*trackers)
531 tracker = None if track is False else zmq.MessageTracker(*trackers)
530 ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
532 ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
531 if block:
533 if block:
532 try:
534 try:
533 return ar.get()
535 return ar.get()
534 except KeyboardInterrupt:
536 except KeyboardInterrupt:
535 pass
537 pass
536 return ar
538 return ar
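`_really_apply` is what the public `apply` / `apply_sync` / `apply_async` family ends up calling; a hedged usage sketch (assuming a running cluster):

    from IPython.parallel import Client

    rc = Client()                            # assumes a running ipcluster
    dv = rc[:]

    def add(a, b):
        return a + b

    result_list = dv.apply_sync(add, 1, 2)   # one result per engine
    ar = dv.apply_async(add, 1, 2)           # AsyncResult right away
    result_list_again = ar.get()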
537
539
538 @spin_after
540 @spin_after
539 def map(self, f, *sequences, **kwargs):
541 def map(self, f, *sequences, **kwargs):
540 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
542 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
541
543
542 Parallel version of builtin `map`, using this View's `targets`.
544 Parallel version of builtin `map`, using this View's `targets`.
543
545
544 There will be one task per target, so work will be chunked
546 There will be one task per target, so work will be chunked
545 if the sequences are longer than `targets`.
547 if the sequences are longer than `targets`.
546
548
547 Results can be iterated as they are ready, but will become available in chunks.
549 Results can be iterated as they are ready, but will become available in chunks.
548
550
549 Parameters
551 Parameters
550 ----------
552 ----------
551
553
552 f : callable
554 f : callable
553 function to be mapped
555 function to be mapped
554 *sequences: one or more sequences of matching length
556 *sequences: one or more sequences of matching length
555 the sequences to be distributed and passed to `f`
557 the sequences to be distributed and passed to `f`
556 block : bool
558 block : bool
557 whether to wait for the result or not [default self.block]
559 whether to wait for the result or not [default self.block]
558
560
559 Returns
561 Returns
560 -------
562 -------
561
563
562 if block=False:
564 if block=False:
563 AsyncMapResult
565 AsyncMapResult
564 An object like AsyncResult, but which reassembles the sequence of results
566 An object like AsyncResult, but which reassembles the sequence of results
565 into a single list. AsyncMapResults can be iterated through before all
567 into a single list. AsyncMapResults can be iterated through before all
566 results are complete.
568 results are complete.
567 else:
569 else:
568 list
570 list
569 the result of map(f,*sequences)
571 the result of map(f,*sequences)
570 """
572 """
571
573
572 block = kwargs.pop('block', self.block)
574 block = kwargs.pop('block', self.block)
573 for k in kwargs.keys():
575 for k in kwargs.keys():
574 if k not in ['block', 'track']:
576 if k not in ['block', 'track']:
575 raise TypeError("invalid keyword arg, %r"%k)
577 raise TypeError("invalid keyword arg, %r"%k)
576
578
577 assert len(sequences) > 0, "must have some sequences to map onto!"
579 assert len(sequences) > 0, "must have some sequences to map onto!"
578 pf = ParallelFunction(self, f, block=block, **kwargs)
580 pf = ParallelFunction(self, f, block=block, **kwargs)
579 return pf.map(*sequences)
581 return pf.map(*sequences)
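A sketch of the one-task-per-engine chunking described in the docstring (assuming a running cluster):

    from IPython.parallel import Client

    rc = Client()                     # assumes a running ipcluster
    dv = rc[:]

    # a 100-element sequence on, say, 4 engines becomes 4 tasks of ~25 elements
    amr = dv.map(lambda x: x ** 2, range(100), block=False)
    for value in amr:                 # elements arrive chunk-by-chunk as engines finish
        pass
    results = amr.get()               # the fully reassembled list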
580
582
581 def execute(self, code, targets=None, block=None):
583 def execute(self, code, targets=None, block=None):
582 """Executes `code` on `targets` in blocking or nonblocking manner.
584 """Executes `code` on `targets` in blocking or nonblocking manner.
583
585
584 ``execute`` is always `bound` (affects engine namespace)
586 ``execute`` is always `bound` (affects engine namespace)
585
587
586 Parameters
588 Parameters
587 ----------
589 ----------
588
590
589 code : str
591 code : str
590 the code string to be executed
592 the code string to be executed
591 block : bool
593 block : bool
592 whether or not to wait until done to return
594 whether or not to wait until done to return
593 default: self.block
595 default: self.block
594 """
596 """
595 return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
597 return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
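Since `execute` is always bound, it is the natural way to mutate engine namespaces; a hedged sketch (assuming a running cluster):

    from IPython.parallel import Client

    rc = Client()                     # assumes a running ipcluster
    dv = rc[:]

    dv.execute('import os; pid = os.getpid()', block=True)
    pids = dv['pid']                  # dictionary-style pull of the new variable
    print pids                        # one pid per engine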
596
598
597 def run(self, filename, targets=None, block=None):
599 def run(self, filename, targets=None, block=None):
598 """Execute contents of `filename` on my engine(s).
600 """Execute contents of `filename` on my engine(s).
599
601
600 This simply reads the contents of the file and calls `execute`.
602 This simply reads the contents of the file and calls `execute`.
601
603
602 Parameters
604 Parameters
603 ----------
605 ----------
604
606
605 filename : str
607 filename : str
606 The path to the file
608 The path to the file
607 targets : int/str/list of ints/strs
609 targets : int/str/list of ints/strs
608 the engines on which to execute
610 the engines on which to execute
609 default : all
611 default : all
610 block : bool
612 block : bool
611 whether or not to wait until done
613 whether or not to wait until done
612 default: self.block
614 default: self.block
613
615
614 """
616 """
615 with open(filename, 'r') as f:
617 with open(filename, 'r') as f:
616 # add newline in case of trailing indented whitespace
618 # add newline in case of trailing indented whitespace
617 # which will cause SyntaxError
619 # which will cause SyntaxError
618 code = f.read()+'\n'
620 code = f.read()+'\n'
619 return self.execute(code, block=block, targets=targets)
621 return self.execute(code, block=block, targets=targets)
620
622
621 def update(self, ns):
623 def update(self, ns):
622 """update remote namespace with dict `ns`
624 """update remote namespace with dict `ns`
623
625
624 See `push` for details.
626 See `push` for details.
625 """
627 """
626 return self.push(ns, block=self.block, track=self.track)
628 return self.push(ns, block=self.block, track=self.track)
627
629
628 def push(self, ns, targets=None, block=None, track=None):
630 def push(self, ns, targets=None, block=None, track=None):
629 """update remote namespace with dict `ns`
631 """update remote namespace with dict `ns`
630
632
631 Parameters
633 Parameters
632 ----------
634 ----------
633
635
634 ns : dict
636 ns : dict
635 dict of keys with which to update engine namespace(s)
637 dict of keys with which to update engine namespace(s)
636 block : bool [default : self.block]
638 block : bool [default : self.block]
637 whether to wait to be notified of engine receipt
639 whether to wait to be notified of engine receipt
638
640
639 """
641 """
640
642
641 block = block if block is not None else self.block
643 block = block if block is not None else self.block
642 track = track if track is not None else self.track
644 track = track if track is not None else self.track
643 targets = targets if targets is not None else self.targets
645 targets = targets if targets is not None else self.targets
644 # applier = self.apply_sync if block else self.apply_async
646 # applier = self.apply_sync if block else self.apply_async
645 if not isinstance(ns, dict):
647 if not isinstance(ns, dict):
646 raise TypeError("Must be a dict, not %s"%type(ns))
648 raise TypeError("Must be a dict, not %s"%type(ns))
647 return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets)
649 return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets)
648
650
649 def get(self, key_s):
651 def get(self, key_s):
650 """get object(s) by `key_s` from remote namespace
652 """get object(s) by `key_s` from remote namespace
651
653
652 see `pull` for details.
654 see `pull` for details.
653 """
655 """
654 # block = block if block is not None else self.block
656 # block = block if block is not None else self.block
655 return self.pull(key_s, block=True)
657 return self.pull(key_s, block=True)
656
658
657 def pull(self, names, targets=None, block=None):
659 def pull(self, names, targets=None, block=None):
658 """get object(s) by `name` from remote namespace
660 """get object(s) by `name` from remote namespace
659
661
660 Returns one object if `names` is a single key;
662 Returns one object if `names` is a single key;
661 if given a list of keys, it will return a list of objects.
663 if given a list of keys, it will return a list of objects.
662 """
664 """
663 block = block if block is not None else self.block
665 block = block if block is not None else self.block
664 targets = targets if targets is not None else self.targets
666 targets = targets if targets is not None else self.targets
665 applier = self.apply_sync if block else self.apply_async
667 applier = self.apply_sync if block else self.apply_async
666 if isinstance(names, basestring):
668 if isinstance(names, basestring):
667 pass
669 pass
668 elif isinstance(names, (list,tuple,set)):
670 elif isinstance(names, (list,tuple,set)):
669 for key in names:
671 for key in names:
670 if not isinstance(key, basestring):
672 if not isinstance(key, basestring):
671 raise TypeError("keys must be str, not type %r"%type(key))
673 raise TypeError("keys must be str, not type %r"%type(key))
672 else:
674 else:
673 raise TypeError("names must be strs, not %r"%names)
675 raise TypeError("names must be strs, not %r"%names)
674 return self._really_apply(util._pull, (names,), block=block, targets=targets)
676 return self._really_apply(util._pull, (names,), block=block, targets=targets)
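push/pull and the dictionary-style shortcuts defined below go through the same machinery; a hedged sketch (assuming a running cluster):

    from IPython.parallel import Client

    rc = Client()                             # assumes a running ipcluster
    dv = rc[:]

    dv.push(dict(a=1, b=2), block=True)       # update every engine namespace
    a_per_engine = dv.pull('a', block=True)   # one value per engine
    a_and_b = dv.pull(('a', 'b'), block=True) # list of [a, b] pairs

    dv['c'] = 30                              # shorthand for push
    print dv['c']                             # shorthand for blocking pull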
675
677
676 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
678 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
677 """
679 """
678 Partition a Python sequence and send the partitions to a set of engines.
680 Partition a Python sequence and send the partitions to a set of engines.
679 """
681 """
680 block = block if block is not None else self.block
682 block = block if block is not None else self.block
681 track = track if track is not None else self.track
683 track = track if track is not None else self.track
682 targets = targets if targets is not None else self.targets
684 targets = targets if targets is not None else self.targets
683
685
684 mapObject = Map.dists[dist]()
686 mapObject = Map.dists[dist]()
685 nparts = len(targets)
687 nparts = len(targets)
686 msg_ids = []
688 msg_ids = []
687 trackers = []
689 trackers = []
688 for index, engineid in enumerate(targets):
690 for index, engineid in enumerate(targets):
689 partition = mapObject.getPartition(seq, index, nparts)
691 partition = mapObject.getPartition(seq, index, nparts)
690 if flatten and len(partition) == 1:
692 if flatten and len(partition) == 1:
691 ns = {key: partition[0]}
693 ns = {key: partition[0]}
692 else:
694 else:
693 ns = {key: partition}
695 ns = {key: partition}
694 r = self.push(ns, block=False, track=track, targets=engineid)
696 r = self.push(ns, block=False, track=track, targets=engineid)
695 msg_ids.extend(r.msg_ids)
697 msg_ids.extend(r.msg_ids)
696 if track:
698 if track:
697 trackers.append(r._tracker)
699 trackers.append(r._tracker)
698
700
699 if track:
701 if track:
700 tracker = zmq.MessageTracker(*trackers)
702 tracker = zmq.MessageTracker(*trackers)
701 else:
703 else:
702 tracker = None
704 tracker = None
703
705
704 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
706 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
705 if block:
707 if block:
706 r.wait()
708 r.wait()
707 else:
709 else:
708 return r
710 return r
709
711
710 @sync_results
712 @sync_results
711 @save_ids
713 @save_ids
712 def gather(self, key, dist='b', targets=None, block=None):
714 def gather(self, key, dist='b', targets=None, block=None):
713 """
715 """
714 Gather a partitioned sequence on a set of engines as a single local seq.
716 Gather a partitioned sequence on a set of engines as a single local seq.
715 """
717 """
716 block = block if block is not None else self.block
718 block = block if block is not None else self.block
717 targets = targets if targets is not None else self.targets
719 targets = targets if targets is not None else self.targets
718 mapObject = Map.dists[dist]()
720 mapObject = Map.dists[dist]()
719 msg_ids = []
721 msg_ids = []
720
722
721 for index, engineid in enumerate(targets):
723 for index, engineid in enumerate(targets):
722 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
724 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
723
725
724 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
726 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
725
727
726 if block:
728 if block:
727 try:
729 try:
728 return r.get()
730 return r.get()
729 except KeyboardInterrupt:
731 except KeyboardInterrupt:
730 pass
732 pass
731 return r
733 return r
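scatter and gather are designed to round-trip a partitioned sequence; a hedged sketch (assuming a running cluster):

    from IPython.parallel import Client

    rc = Client()                                    # assumes a running ipcluster
    dv = rc[:]

    dv.scatter('x', range(16), block=True)           # partition across engines
    dv.execute('y = [i * 2 for i in x]', block=True) # work on each partition
    y = dv.gather('y', block=True)                   # reassemble locally
    assert y == [i * 2 for i in range(16)]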
732
734
733 def __getitem__(self, key):
735 def __getitem__(self, key):
734 return self.get(key)
736 return self.get(key)
735
737
736 def __setitem__(self,key, value):
738 def __setitem__(self,key, value):
737 self.update({key:value})
739 self.update({key:value})
738
740
739 def clear(self, targets=None, block=False):
741 def clear(self, targets=None, block=False):
740 """Clear the remote namespaces on my engines."""
742 """Clear the remote namespaces on my engines."""
741 block = block if block is not None else self.block
743 block = block if block is not None else self.block
742 targets = targets if targets is not None else self.targets
744 targets = targets if targets is not None else self.targets
743 return self.client.clear(targets=targets, block=block)
745 return self.client.clear(targets=targets, block=block)
744
746
745 def kill(self, targets=None, block=True):
747 def kill(self, targets=None, block=True):
746 """Kill my engines."""
748 """Kill my engines."""
747 block = block if block is not None else self.block
749 block = block if block is not None else self.block
748 targets = targets if targets is not None else self.targets
750 targets = targets if targets is not None else self.targets
749 return self.client.kill(targets=targets, block=block)
751 return self.client.kill(targets=targets, block=block)
750
752
751 #----------------------------------------
753 #----------------------------------------
752 # activate for %px,%autopx magics
754 # activate for %px,%autopx magics
753 #----------------------------------------
755 #----------------------------------------
754 def activate(self):
756 def activate(self):
755 """Make this `View` active for parallel magic commands.
757 """Make this `View` active for parallel magic commands.
756
758
757 IPython has a magic command syntax to work with `MultiEngineClient` objects.
759 IPython has a magic command syntax to work with `MultiEngineClient` objects.
758 In a given IPython session there is a single active one. While
760 In a given IPython session there is a single active one. While
759 there can be many `Views` created and used by the user,
761 there can be many `Views` created and used by the user,
760 there is only one active one. The active `View` is used whenever
762 there is only one active one. The active `View` is used whenever
761 the magic commands %px and %autopx are used.
763 the magic commands %px and %autopx are used.
762
764
763 The activate() method is called on a given `View` to make it
765 The activate() method is called on a given `View` to make it
764 active. Once this has been done, the magic commands can be used.
766 active. Once this has been done, the magic commands can be used.
765 """
767 """
766
768
767 try:
769 try:
768 # This is injected into __builtins__.
770 # This is injected into __builtins__.
769 ip = get_ipython()
771 ip = get_ipython()
770 except NameError:
772 except NameError:
771 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
773 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
772 else:
774 else:
773 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
775 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
774 if pmagic is None:
776 if pmagic is None:
775 ip.magic_load_ext('parallelmagic')
777 ip.magic_load_ext('parallelmagic')
776 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
778 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
777
779
778 pmagic.active_view = self
780 pmagic.active_view = self
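A sketch of activating a view for the %px/%autopx magics; this only makes sense inside an interactive IPython session, and assumes a running cluster.

    # run inside an interactive IPython session
    from IPython.parallel import Client

    rc = Client()          # assumes a running ipcluster
    dv = rc[:]
    dv.activate()          # loads the 'parallelmagic' extension if needed

    # afterwards, for example:
    #   %px import numpy
    #   %autopx            # toggle: run every subsequent input remotely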
779
781
780
782
781 @skip_doctest
783 @skip_doctest
782 class LoadBalancedView(View):
784 class LoadBalancedView(View):
783 """An load-balancing View that only executes via the Task scheduler.
785 """An load-balancing View that only executes via the Task scheduler.
784
786
785 Load-balanced views can be created with the client's `view` method:
787 Load-balanced views can be created with the client's `view` method:
786
788
787 >>> v = client.load_balanced_view()
789 >>> v = client.load_balanced_view()
788
790
789 or targets can be specified, to restrict the potential destinations:
791 or targets can be specified, to restrict the potential destinations:
790
792
791 >>> v = client.load_balanced_view([1,3])
793 >>> v = client.load_balanced_view([1,3])
792
794
793 which restricts load-balancing to engines 1 and 3.
795 which restricts load-balancing to engines 1 and 3.
794
796
795 """
797 """
796
798
797 follow=Any()
799 follow=Any()
798 after=Any()
800 after=Any()
799 timeout=CFloat()
801 timeout=CFloat()
800 retries = Integer(0)
802 retries = Integer(0)
801
803
802 _task_scheme = Any()
804 _task_scheme = Any()
803 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
805 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
804
806
805 def __init__(self, client=None, socket=None, **flags):
807 def __init__(self, client=None, socket=None, **flags):
806 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
808 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
807 self._task_scheme=client._task_scheme
809 self._task_scheme=client._task_scheme
808
810
809 def _validate_dependency(self, dep):
811 def _validate_dependency(self, dep):
810 """validate a dependency.
812 """validate a dependency.
811
813
812 For use in `set_flags`.
814 For use in `set_flags`.
813 """
815 """
814 if dep is None or isinstance(dep, (basestring, AsyncResult, Dependency)):
816 if dep is None or isinstance(dep, (basestring, AsyncResult, Dependency)):
815 return True
817 return True
816 elif isinstance(dep, (list,set, tuple)):
818 elif isinstance(dep, (list,set, tuple)):
817 for d in dep:
819 for d in dep:
818 if not isinstance(d, (basestring, AsyncResult)):
820 if not isinstance(d, (basestring, AsyncResult)):
819 return False
821 return False
820 elif isinstance(dep, dict):
822 elif isinstance(dep, dict):
821 if set(dep.keys()) != set(Dependency().as_dict().keys()):
823 if set(dep.keys()) != set(Dependency().as_dict().keys()):
822 return False
824 return False
823 if not isinstance(dep['msg_ids'], list):
825 if not isinstance(dep['msg_ids'], list):
824 return False
826 return False
825 for d in dep['msg_ids']:
827 for d in dep['msg_ids']:
826 if not isinstance(d, basestring):
828 if not isinstance(d, basestring):
827 return False
829 return False
828 else:
830 else:
829 return False
831 return False
830
832
831 return True
833 return True
832
834
833 def _render_dependency(self, dep):
835 def _render_dependency(self, dep):
834 """helper for building jsonable dependencies from various input forms."""
836 """helper for building jsonable dependencies from various input forms."""
835 if isinstance(dep, Dependency):
837 if isinstance(dep, Dependency):
836 return dep.as_dict()
838 return dep.as_dict()
837 elif isinstance(dep, AsyncResult):
839 elif isinstance(dep, AsyncResult):
838 return dep.msg_ids
840 return dep.msg_ids
839 elif dep is None:
841 elif dep is None:
840 return []
842 return []
841 else:
843 else:
842 # pass to Dependency constructor
844 # pass to Dependency constructor
843 return list(Dependency(dep))
845 return list(Dependency(dep))
844
846
845 def set_flags(self, **kwargs):
847 def set_flags(self, **kwargs):
846 """set my attribute flags by keyword.
848 """set my attribute flags by keyword.
847
849
848 A View is a wrapper for the Client's apply method, but with attributes
850 A View is a wrapper for the Client's apply method, but with attributes
849 that specify keyword arguments. Those attributes can be set by keyword
851 that specify keyword arguments. Those attributes can be set by keyword
850 argument with this method.
852 argument with this method.
851
853
852 Parameters
854 Parameters
853 ----------
855 ----------
854
856
855 block : bool
857 block : bool
856 whether to wait for results
858 whether to wait for results
857 track : bool
859 track : bool
858 whether to create a MessageTracker so the user can know when it
860 whether to create a MessageTracker so the user can know when it
859 is safe to edit arrays and buffers again after non-copying
861 is safe to edit arrays and buffers again after non-copying
860 sends.
862 sends.
861
863
862 after : Dependency or collection of msg_ids
864 after : Dependency or collection of msg_ids
863 Only for load-balanced execution (targets=None)
865 Only for load-balanced execution (targets=None)
864 Specify a list of msg_ids as a time-based dependency.
866 Specify a list of msg_ids as a time-based dependency.
865 This job will only be run *after* the dependencies
867 This job will only be run *after* the dependencies
866 have been met.
868 have been met.
867
869
868 follow : Dependency or collection of msg_ids
870 follow : Dependency or collection of msg_ids
869 Only for load-balanced execution (targets=None)
871 Only for load-balanced execution (targets=None)
870 Specify a list of msg_ids as a location-based dependency.
872 Specify a list of msg_ids as a location-based dependency.
871 This job will only be run on an engine where this dependency
873 This job will only be run on an engine where this dependency
872 is met.
874 is met.
873
875
874 timeout : float/int or None
876 timeout : float/int or None
875 Only for load-balanced execution (targets=None)
877 Only for load-balanced execution (targets=None)
876 Specify an amount of time (in seconds) for the scheduler to
878 Specify an amount of time (in seconds) for the scheduler to
877 wait for dependencies to be met before failing with a
879 wait for dependencies to be met before failing with a
878 DependencyTimeout.
880 DependencyTimeout.
879
881
880 retries : int
882 retries : int
881 Number of times a task will be retried on failure.
883 Number of times a task will be retried on failure.
882 """
884 """
883
885
884 super(LoadBalancedView, self).set_flags(**kwargs)
886 super(LoadBalancedView, self).set_flags(**kwargs)
885 for name in ('follow', 'after'):
887 for name in ('follow', 'after'):
886 if name in kwargs:
888 if name in kwargs:
887 value = kwargs[name]
889 value = kwargs[name]
888 if self._validate_dependency(value):
890 if self._validate_dependency(value):
889 setattr(self, name, value)
891 setattr(self, name, value)
890 else:
892 else:
891 raise ValueError("Invalid dependency: %r"%value)
893 raise ValueError("Invalid dependency: %r"%value)
892 if 'timeout' in kwargs:
894 if 'timeout' in kwargs:
893 t = kwargs['timeout']
895 t = kwargs['timeout']
894 if not isinstance(t, (int, long, float, type(None))):
896 if not isinstance(t, (int, long, float, type(None))):
895 raise TypeError("Invalid type for timeout: %r"%type(t))
897 raise TypeError("Invalid type for timeout: %r"%type(t))
896 if t is not None:
898 if t is not None:
897 if t < 0:
899 if t < 0:
898 raise ValueError("Invalid timeout: %s"%t)
900 raise ValueError("Invalid timeout: %s"%t)
899 self.timeout = t
901 self.timeout = t
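A hedged sketch of the scheduler flags in use, via `temp_flags` (defined on the View base class elsewhere in this module), assuming a running cluster:

    from IPython.parallel import Client

    rc = Client()                          # assumes a running ipcluster
    lv = rc.load_balanced_view()

    ar_setup = lv.apply_async(lambda: 'setup done')

    # for one submission only: run after ar_setup has finished,
    # retrying up to twice if the task fails
    with lv.temp_flags(after=ar_setup, retries=2):
        ar_work = lv.apply_async(lambda: 'real work')

    print ar_work.get()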
900
902
901 @sync_results
903 @sync_results
902 @save_ids
904 @save_ids
903 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
905 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
904 after=None, follow=None, timeout=None,
906 after=None, follow=None, timeout=None,
905 targets=None, retries=None):
907 targets=None, retries=None):
906 """calls f(*args, **kwargs) on a remote engine, returning the result.
908 """calls f(*args, **kwargs) on a remote engine, returning the result.
907
909
908 This method temporarily sets all of `apply`'s flags for a single call.
910 This method temporarily sets all of `apply`'s flags for a single call.
909
911
910 Parameters
912 Parameters
911 ----------
913 ----------
912
914
913 f : callable
915 f : callable
914
916
915 args : list [default: empty]
917 args : list [default: empty]
916
918
917 kwargs : dict [default: empty]
919 kwargs : dict [default: empty]
918
920
919 block : bool [default: self.block]
921 block : bool [default: self.block]
920 whether to block
922 whether to block
921 track : bool [default: self.track]
923 track : bool [default: self.track]
922 whether to ask zmq to track the message, for safe non-copying sends
924 whether to ask zmq to track the message, for safe non-copying sends
923
925
924 after, follow, timeout, targets, retries : scheduler flags, as in `set_flags`, each defaulting to the corresponding View attribute when None
926 after, follow, timeout, targets, retries : scheduler flags, as in `set_flags`, each defaulting to the corresponding View attribute when None
925
927
926 Returns
928 Returns
927 -------
929 -------
928
930
929 if self.block is False:
931 if self.block is False:
930 returns AsyncResult
932 returns AsyncResult
931 else:
933 else:
932 returns actual result of f(*args, **kwargs) on the engine(s)
934 returns actual result of f(*args, **kwargs) on the engine(s)
933 This will be a list if self.targets is also a list (even length 1), or
935 This will be a list if self.targets is also a list (even length 1), or
934 the single result if self.targets is an integer engine id
936 the single result if self.targets is an integer engine id
935 """
937 """
936
938
937 # validate whether we can run
939 # validate whether we can run
938 if self._socket.closed:
940 if self._socket.closed:
939 msg = "Task farming is disabled"
941 msg = "Task farming is disabled"
940 if self._task_scheme == 'pure':
942 if self._task_scheme == 'pure':
941 msg += " because the pure ZMQ scheduler cannot handle"
943 msg += " because the pure ZMQ scheduler cannot handle"
942 msg += " disappearing engines."
944 msg += " disappearing engines."
943 raise RuntimeError(msg)
945 raise RuntimeError(msg)
944
946
945 if self._task_scheme == 'pure':
947 if self._task_scheme == 'pure':
946 # pure zmq scheme doesn't support extra features
948 # pure zmq scheme doesn't support extra features
947 msg = "Pure ZMQ scheduler doesn't support the following flags:"
949 msg = "Pure ZMQ scheduler doesn't support the following flags:"
948 "follow, after, retries, targets, timeout"
950 "follow, after, retries, targets, timeout"
949 if (follow or after or retries or targets or timeout):
951 if (follow or after or retries or targets or timeout):
950 # hard fail on Scheduler flags
952 # hard fail on Scheduler flags
951 raise RuntimeError(msg)
953 raise RuntimeError(msg)
952 if isinstance(f, dependent):
954 if isinstance(f, dependent):
953 # soft warn on functional dependencies
955 # soft warn on functional dependencies
954 warnings.warn(msg, RuntimeWarning)
956 warnings.warn(msg, RuntimeWarning)
955
957
956 # build args
958 # build args
957 args = [] if args is None else args
959 args = [] if args is None else args
958 kwargs = {} if kwargs is None else kwargs
960 kwargs = {} if kwargs is None else kwargs
959 block = self.block if block is None else block
961 block = self.block if block is None else block
960 track = self.track if track is None else track
962 track = self.track if track is None else track
961 after = self.after if after is None else after
963 after = self.after if after is None else after
962 retries = self.retries if retries is None else retries
964 retries = self.retries if retries is None else retries
963 follow = self.follow if follow is None else follow
965 follow = self.follow if follow is None else follow
964 timeout = self.timeout if timeout is None else timeout
966 timeout = self.timeout if timeout is None else timeout
965 targets = self.targets if targets is None else targets
967 targets = self.targets if targets is None else targets
966
968
967 if not isinstance(retries, int):
969 if not isinstance(retries, int):
968 raise TypeError('retries must be int, not %r'%type(retries))
970 raise TypeError('retries must be int, not %r'%type(retries))
969
971
970 if targets is None:
972 if targets is None:
971 idents = []
973 idents = []
972 else:
974 else:
973 idents = self.client._build_targets(targets)[0]
975 idents = self.client._build_targets(targets)[0]
974 # ensure *not* bytes
976 # ensure *not* bytes
975 idents = [ ident.decode() for ident in idents ]
977 idents = [ ident.decode() for ident in idents ]
976
978
977 after = self._render_dependency(after)
979 after = self._render_dependency(after)
978 follow = self._render_dependency(follow)
980 follow = self._render_dependency(follow)
979 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
981 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
980
982
981 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
983 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
982 subheader=subheader)
984 subheader=subheader)
983 tracker = None if track is False else msg['tracker']
985 tracker = None if track is False else msg['tracker']
984
986
985 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=f.__name__, targets=None, tracker=tracker)
987 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=f.__name__, targets=None, tracker=tracker)
986
988
987 if block:
989 if block:
988 try:
990 try:
989 return ar.get()
991 return ar.get()
990 except KeyboardInterrupt:
992 except KeyboardInterrupt:
991 pass
993 pass
992 return ar
994 return ar
993
995
994 @spin_after
996 @spin_after
995 @save_ids
997 @save_ids
996 def map(self, f, *sequences, **kwargs):
998 def map(self, f, *sequences, **kwargs):
997 """view.map(f, *sequences, block=self.block, chunksize=1, ordered=True) => list|AsyncMapResult
999 """view.map(f, *sequences, block=self.block, chunksize=1, ordered=True) => list|AsyncMapResult
998
1000
999 Parallel version of builtin `map`, load-balanced by this View.
1001 Parallel version of builtin `map`, load-balanced by this View.
1000
1002
1001 `block`, `chunksize`, and `ordered` can be specified by keyword only.
1003 `block`, `chunksize`, and `ordered` can be specified by keyword only.
1002
1004
1003 Each `chunksize` elements will be a separate task, and will be
1005 Each `chunksize` elements will be a separate task, and will be
1004 load-balanced. This lets individual elements be available for iteration
1006 load-balanced. This lets individual elements be available for iteration
1005 as soon as they arrive.
1007 as soon as they arrive.
1006
1008
1007 Parameters
1009 Parameters
1008 ----------
1010 ----------
1009
1011
1010 f : callable
1012 f : callable
1011 function to be mapped
1013 function to be mapped
1012 *sequences: one or more sequences of matching length
1014 *sequences: one or more sequences of matching length
1013 the sequences to be distributed and passed to `f`
1015 the sequences to be distributed and passed to `f`
1014 block : bool [default self.block]
1016 block : bool [default self.block]
1015 whether to wait for the result or not
1017 whether to wait for the result or not
1016 track : bool
1018 track : bool
1017 whether to create a MessageTracker so the user can know when it
1019 whether to create a MessageTracker so the user can know when it
1018 is safe to edit arrays and buffers again after non-copying
1020 is safe to edit arrays and buffers again after non-copying
1019 sends.
1021 sends.
1020 chunksize : int [default 1]
1022 chunksize : int [default 1]
1021 how many elements should be in each task.
1023 how many elements should be in each task.
1022 ordered : bool [default True]
1024 ordered : bool [default True]
1023 Whether the results should be gathered as they arrive, or enforce
1025 Whether the results should be gathered as they arrive, or enforce
1024 the order of submission.
1026 the order of submission.
1025
1027
1026 Only applies when iterating through AsyncMapResult as results arrive.
1028 Only applies when iterating through AsyncMapResult as results arrive.
1027 Has no effect when block=True.
1029 Has no effect when block=True.
1028
1030
1029 Returns
1031 Returns
1030 -------
1032 -------
1031
1033
1032 if block=False:
1034 if block=False:
1033 AsyncMapResult
1035 AsyncMapResult
1034 An object like AsyncResult, but which reassembles the sequence of results
1036 An object like AsyncResult, but which reassembles the sequence of results
1035 into a single list. AsyncMapResults can be iterated through before all
1037 into a single list. AsyncMapResults can be iterated through before all
1036 results are complete.
1038 results are complete.
1037 else:
1039 else:
1038 the result of map(f,*sequences)
1040 the result of map(f,*sequences)
1039
1041
1040 """
1042 """
1041
1043
1042 # default
1044 # default
1043 block = kwargs.get('block', self.block)
1045 block = kwargs.get('block', self.block)
1044 chunksize = kwargs.get('chunksize', 1)
1046 chunksize = kwargs.get('chunksize', 1)
1045 ordered = kwargs.get('ordered', True)
1047 ordered = kwargs.get('ordered', True)
1046
1048
1047 keyset = set(kwargs.keys())
1049 keyset = set(kwargs.keys())
1048 extra_keys = keyset.difference(set(['block', 'chunksize', 'ordered']))
1050 extra_keys = keyset.difference(set(['block', 'chunksize', 'ordered']))
1049 if extra_keys:
1051 if extra_keys:
1050 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1052 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1051
1053
1052 assert len(sequences) > 0, "must have some sequences to map onto!"
1054 assert len(sequences) > 0, "must have some sequences to map onto!"
1053
1055
1054 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1056 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1055 return pf.map(*sequences)
1057 return pf.map(*sequences)
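A hedged sketch of `chunksize` and `ordered` in a load-balanced map (assuming a running cluster):

    from IPython.parallel import Client

    rc = Client()                          # assumes a running ipcluster
    lv = rc.load_balanced_view()

    # 100 elements at 10 per task -> 10 load-balanced tasks
    amr = lv.map(lambda x: x * x, range(100),
                 block=False, chunksize=10, ordered=False)

    # ordered=False yields each chunk's results as soon as its task finishes
    for value in amr:
        print value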
1056
1058
1057 __all__ = ['LoadBalancedView', 'DirectView']
1059 __all__ = ['LoadBalancedView', 'DirectView']
@@ -1,463 +1,472 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import time
20 import time
21 from tempfile import mktemp
21 from tempfile import mktemp
22 from StringIO import StringIO
22 from StringIO import StringIO
23
23
24 import zmq
24 import zmq
25 from nose import SkipTest
25 from nose import SkipTest
26
26
27 from IPython.testing import decorators as dec
27 from IPython.testing import decorators as dec
28
28
29 from IPython import parallel as pmod
29 from IPython import parallel as pmod
30 from IPython.parallel import error
30 from IPython.parallel import error
31 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
31 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
32 from IPython.parallel import DirectView
32 from IPython.parallel import DirectView
33 from IPython.parallel.util import interactive
33 from IPython.parallel.util import interactive
34
34
35 from IPython.parallel.tests import add_engines
35 from IPython.parallel.tests import add_engines
36
36
37 from .clienttest import ClusterTestCase, crash, wait, skip_without
37 from .clienttest import ClusterTestCase, crash, wait, skip_without
38
38
39 def setup():
39 def setup():
40 add_engines(3)
40 add_engines(3)
41
41
42 class TestView(ClusterTestCase):
42 class TestView(ClusterTestCase):
43
43
44 def test_z_crash_mux(self):
44 def test_z_crash_mux(self):
45 """test graceful handling of engine death (direct)"""
45 """test graceful handling of engine death (direct)"""
46 raise SkipTest("crash tests disabled, due to undesirable crash reports")
46 raise SkipTest("crash tests disabled, due to undesirable crash reports")
47 # self.add_engines(1)
47 # self.add_engines(1)
48 eid = self.client.ids[-1]
48 eid = self.client.ids[-1]
49 ar = self.client[eid].apply_async(crash)
49 ar = self.client[eid].apply_async(crash)
50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
51 eid = ar.engine_id
51 eid = ar.engine_id
52 tic = time.time()
52 tic = time.time()
53 while eid in self.client.ids and time.time()-tic < 5:
53 while eid in self.client.ids and time.time()-tic < 5:
54 time.sleep(.01)
54 time.sleep(.01)
55 self.client.spin()
55 self.client.spin()
56 self.assertFalse(eid in self.client.ids, "Engine should have died")
56 self.assertFalse(eid in self.client.ids, "Engine should have died")
57
57
58 def test_push_pull(self):
58 def test_push_pull(self):
59 """test pushing and pulling"""
59 """test pushing and pulling"""
60 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
60 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
61 t = self.client.ids[-1]
61 t = self.client.ids[-1]
62 v = self.client[t]
62 v = self.client[t]
63 push = v.push
63 push = v.push
64 pull = v.pull
64 pull = v.pull
65 v.block=True
65 v.block=True
66 nengines = len(self.client)
66 nengines = len(self.client)
67 push({'data':data})
67 push({'data':data})
68 d = pull('data')
68 d = pull('data')
69 self.assertEquals(d, data)
69 self.assertEquals(d, data)
70 self.client[:].push({'data':data})
70 self.client[:].push({'data':data})
71 d = self.client[:].pull('data', block=True)
71 d = self.client[:].pull('data', block=True)
72 self.assertEquals(d, nengines*[data])
72 self.assertEquals(d, nengines*[data])
73 ar = push({'data':data}, block=False)
73 ar = push({'data':data}, block=False)
74 self.assertTrue(isinstance(ar, AsyncResult))
74 self.assertTrue(isinstance(ar, AsyncResult))
75 r = ar.get()
75 r = ar.get()
76 ar = self.client[:].pull('data', block=False)
76 ar = self.client[:].pull('data', block=False)
77 self.assertTrue(isinstance(ar, AsyncResult))
77 self.assertTrue(isinstance(ar, AsyncResult))
78 r = ar.get()
78 r = ar.get()
79 self.assertEquals(r, nengines*[data])
79 self.assertEquals(r, nengines*[data])
80 self.client[:].push(dict(a=10,b=20))
80 self.client[:].push(dict(a=10,b=20))
81 r = self.client[:].pull(('a','b'), block=True)
81 r = self.client[:].pull(('a','b'), block=True)
82 self.assertEquals(r, nengines*[[10,20]])
82 self.assertEquals(r, nengines*[[10,20]])
83
83
84 def test_push_pull_function(self):
84 def test_push_pull_function(self):
85 "test pushing and pulling functions"
85 "test pushing and pulling functions"
86 def testf(x):
86 def testf(x):
87 return 2.0*x
87 return 2.0*x
88
88
89 t = self.client.ids[-1]
89 t = self.client.ids[-1]
90 v = self.client[t]
90 v = self.client[t]
91 v.block=True
91 v.block=True
92 push = v.push
92 push = v.push
93 pull = v.pull
93 pull = v.pull
94 execute = v.execute
94 execute = v.execute
95 push({'testf':testf})
95 push({'testf':testf})
96 r = pull('testf')
96 r = pull('testf')
97 self.assertEqual(r(1.0), testf(1.0))
97 self.assertEqual(r(1.0), testf(1.0))
98 execute('r = testf(10)')
98 execute('r = testf(10)')
99 r = pull('r')
99 r = pull('r')
100 self.assertEquals(r, testf(10))
100 self.assertEquals(r, testf(10))
101 ar = self.client[:].push({'testf':testf}, block=False)
101 ar = self.client[:].push({'testf':testf}, block=False)
102 ar.get()
102 ar.get()
103 ar = self.client[:].pull('testf', block=False)
103 ar = self.client[:].pull('testf', block=False)
104 rlist = ar.get()
104 rlist = ar.get()
105 for r in rlist:
105 for r in rlist:
106 self.assertEqual(r(1.0), testf(1.0))
106 self.assertEqual(r(1.0), testf(1.0))
107 execute("def g(x): return x*x")
107 execute("def g(x): return x*x")
108 r = pull(('testf','g'))
108 r = pull(('testf','g'))
109 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
109 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
110
110
111 def test_push_function_globals(self):
111 def test_push_function_globals(self):
112 """test that pushed functions have access to globals"""
112 """test that pushed functions have access to globals"""
113 @interactive
113 @interactive
114 def geta():
114 def geta():
115 return a
115 return a
116 # self.add_engines(1)
116 # self.add_engines(1)
117 v = self.client[-1]
117 v = self.client[-1]
118 v.block=True
118 v.block=True
119 v['f'] = geta
119 v['f'] = geta
120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
121 v.execute('a=5')
121 v.execute('a=5')
122 v.execute('b=f()')
122 v.execute('b=f()')
123 self.assertEquals(v['b'], 5)
123 self.assertEquals(v['b'], 5)
124
124
125 def test_push_function_defaults(self):
125 def test_push_function_defaults(self):
126 """test that pushed functions preserve default args"""
126 """test that pushed functions preserve default args"""
127 def echo(a=10):
127 def echo(a=10):
128 return a
128 return a
129 v = self.client[-1]
129 v = self.client[-1]
130 v.block=True
130 v.block=True
131 v['f'] = echo
131 v['f'] = echo
132 v.execute('b=f()')
132 v.execute('b=f()')
133 self.assertEquals(v['b'], 10)
133 self.assertEquals(v['b'], 10)
134
134
135 def test_get_result(self):
135 def test_get_result(self):
136 """test getting results from the Hub."""
136 """test getting results from the Hub."""
137 c = pmod.Client(profile='iptest')
137 c = pmod.Client(profile='iptest')
138 # self.add_engines(1)
138 # self.add_engines(1)
139 t = c.ids[-1]
139 t = c.ids[-1]
140 v = c[t]
140 v = c[t]
141 v2 = self.client[t]
141 v2 = self.client[t]
142 ar = v.apply_async(wait, 1)
142 ar = v.apply_async(wait, 1)
143 # give the monitor time to notice the message
143 # give the monitor time to notice the message
144 time.sleep(.25)
144 time.sleep(.25)
145 ahr = v2.get_result(ar.msg_ids)
145 ahr = v2.get_result(ar.msg_ids)
146 self.assertTrue(isinstance(ahr, AsyncHubResult))
146 self.assertTrue(isinstance(ahr, AsyncHubResult))
147 self.assertEquals(ahr.get(), ar.get())
147 self.assertEquals(ahr.get(), ar.get())
148 ar2 = v2.get_result(ar.msg_ids)
148 ar2 = v2.get_result(ar.msg_ids)
149 self.assertFalse(isinstance(ar2, AsyncHubResult))
149 self.assertFalse(isinstance(ar2, AsyncHubResult))
150 c.spin()
150 c.spin()
151 c.close()
151 c.close()
152
152
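Because the Hub records completed tasks, a client other than the one that submitted the work can still fetch a result by msg_id; that is what get_result does, returning an AsyncHubResult when the answer has to come from the Hub rather than the local results cache. A minimal sketch of the same flow, assuming a running 'iptest' cluster:

    import time
    from IPython.parallel import Client

    rc1 = Client(profile='iptest')
    rc2 = Client(profile='iptest')            # a second, independent client

    ar = rc1[-1].apply_async(time.sleep, 1)   # submit with the first client
    time.sleep(0.25)                          # give the Hub's monitor time to record it

    ahr = rc2[-1].get_result(ar.msg_ids)      # fetch by msg_id with the second client
    ahr.get()                                 # blocks until the task has finished

    rc1.close(); rc2.close()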
153 def test_run_newline(self):
153 def test_run_newline(self):
154 """test that run appends newline to files"""
154 """test that run appends newline to files"""
155 tmpfile = mktemp()
155 tmpfile = mktemp()
156 with open(tmpfile, 'w') as f:
156 with open(tmpfile, 'w') as f:
157 f.write("""def g():
157 f.write("""def g():
158 return 5
158 return 5
159 """)
159 """)
160 v = self.client[-1]
160 v = self.client[-1]
161 v.run(tmpfile, block=True)
161 v.run(tmpfile, block=True)
162 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
162 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
163
163
164 def test_apply_tracked(self):
164 def test_apply_tracked(self):
165 """test tracking for apply"""
165 """test tracking for apply"""
166 # self.add_engines(1)
166 # self.add_engines(1)
167 t = self.client.ids[-1]
167 t = self.client.ids[-1]
168 v = self.client[t]
168 v = self.client[t]
169 v.block=False
169 v.block=False
170 def echo(n=1024*1024, **kwargs):
170 def echo(n=1024*1024, **kwargs):
171 with v.temp_flags(**kwargs):
171 with v.temp_flags(**kwargs):
172 return v.apply(lambda x: x, 'x'*n)
172 return v.apply(lambda x: x, 'x'*n)
173 ar = echo(1, track=False)
173 ar = echo(1, track=False)
174 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
174 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
175 self.assertTrue(ar.sent)
175 self.assertTrue(ar.sent)
176 ar = echo(track=True)
176 ar = echo(track=True)
177 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
177 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
178 self.assertEquals(ar.sent, ar._tracker.done)
178 self.assertEquals(ar.sent, ar._tracker.done)
179 ar._tracker.wait()
179 ar._tracker.wait()
180 self.assertTrue(ar.sent)
180 self.assertTrue(ar.sent)
181
181
182 def test_push_tracked(self):
182 def test_push_tracked(self):
183 t = self.client.ids[-1]
183 t = self.client.ids[-1]
184 ns = dict(x='x'*1024*1024)
184 ns = dict(x='x'*1024*1024)
185 v = self.client[t]
185 v = self.client[t]
186 ar = v.push(ns, block=False, track=False)
186 ar = v.push(ns, block=False, track=False)
187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 self.assertTrue(ar.sent)
188 self.assertTrue(ar.sent)
189
189
190 ar = v.push(ns, block=False, track=True)
190 ar = v.push(ns, block=False, track=True)
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 ar._tracker.wait()
192 ar._tracker.wait()
193 self.assertEquals(ar.sent, ar._tracker.done)
193 self.assertEquals(ar.sent, ar._tracker.done)
194 self.assertTrue(ar.sent)
194 self.assertTrue(ar.sent)
195 ar.get()
195 ar.get()
196
196
197 def test_scatter_tracked(self):
197 def test_scatter_tracked(self):
198 t = self.client.ids
198 t = self.client.ids
199 x='x'*1024*1024
199 x='x'*1024*1024
200 ar = self.client[t].scatter('x', x, block=False, track=False)
200 ar = self.client[t].scatter('x', x, block=False, track=False)
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 self.assertTrue(ar.sent)
202 self.assertTrue(ar.sent)
203
203
204 ar = self.client[t].scatter('x', x, block=False, track=True)
204 ar = self.client[t].scatter('x', x, block=False, track=True)
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 self.assertEquals(ar.sent, ar._tracker.done)
206 self.assertEquals(ar.sent, ar._tracker.done)
207 ar._tracker.wait()
207 ar._tracker.wait()
208 self.assertTrue(ar.sent)
208 self.assertTrue(ar.sent)
209 ar.get()
209 ar.get()
210
210
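track=True attaches a zmq.MessageTracker to the returned AsyncResult so the caller can tell when zeromq has actually finished sending a large buffer, e.g. before mutating it in place; ar.sent mirrors the tracker's done flag. A minimal sketch of the same idea (like the tests, it reaches into the private _tracker attribute), assuming a running 'iptest' cluster:

    from IPython.parallel import Client

    rc = Client(profile='iptest')
    v = rc[-1]

    big = 'x' * 1024 * 1024                          # a payload large enough to be worth tracking
    ar = v.push({'x': big}, block=False, track=True)

    ar._tracker.wait()                               # block until zmq no longer needs the buffer
    assert ar.sent                                   # safe to reuse the local buffer now
    ar.get()                                         # wait for the engine's reply as usual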
211 def test_remote_reference(self):
211 def test_remote_reference(self):
212 v = self.client[-1]
212 v = self.client[-1]
213 v['a'] = 123
213 v['a'] = 123
214 ra = pmod.Reference('a')
214 ra = pmod.Reference('a')
215 b = v.apply_sync(lambda x: x, ra)
215 b = v.apply_sync(lambda x: x, ra)
216 self.assertEquals(b, 123)
216 self.assertEquals(b, 123)
217
217
218
218
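A Reference is a placeholder that is resolved to the engine-side object of that name at call time, so apply can operate on data that only exists remotely without pulling it first. A minimal sketch, assuming Reference is importable from IPython.parallel as the tests' pmod.Reference suggests:

    from IPython.parallel import Client, Reference

    rc = Client(profile='iptest')
    v = rc[-1]

    v['a'] = 123                                      # store a value in the engine's namespace
    ra = Reference('a')                               # names the remote object, holds no data
    assert v.apply_sync(lambda x: x + 1, ra) == 124   # 'a' is resolved to 123 on the engine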
219 def test_scatter_gather(self):
219 def test_scatter_gather(self):
220 view = self.client[:]
220 view = self.client[:]
221 seq1 = range(16)
221 seq1 = range(16)
222 view.scatter('a', seq1)
222 view.scatter('a', seq1)
223 seq2 = view.gather('a', block=True)
223 seq2 = view.gather('a', block=True)
224 self.assertEquals(seq2, seq1)
224 self.assertEquals(seq2, seq1)
225 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
225 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
226
226
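scatter splits a sequence into roughly equal slices, binds one slice per engine under the given name, and gather concatenates the pieces back in engine order, so a scatter followed by a blocking gather reproduces the original sequence. A minimal sketch, assuming a running 'iptest' cluster:

    from IPython.parallel import Client

    rc = Client(profile='iptest')
    view = rc[:]                                  # DirectView on all engines

    seq = range(16)
    view.scatter('a', seq)                        # each engine gets one slice, bound to 'a'
    assert view.gather('a', block=True) == list(seq)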
227 @skip_without('numpy')
227 @skip_without('numpy')
228 def test_scatter_gather_numpy(self):
228 def test_scatter_gather_numpy(self):
229 import numpy
229 import numpy
230 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
230 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
231 view = self.client[:]
231 view = self.client[:]
232 a = numpy.arange(64)
232 a = numpy.arange(64)
233 view.scatter('a', a)
233 view.scatter('a', a)
234 b = view.gather('a', block=True)
234 b = view.gather('a', block=True)
235 assert_array_equal(b, a)
235 assert_array_equal(b, a)
236
236
237 def test_map(self):
237 def test_map(self):
238 view = self.client[:]
238 view = self.client[:]
239 def f(x):
239 def f(x):
240 return x**2
240 return x**2
241 data = range(16)
241 data = range(16)
242 r = view.map_sync(f, data)
242 r = view.map_sync(f, data)
243 self.assertEquals(r, map(f, data))
243 self.assertEquals(r, map(f, data))
244
244
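map_sync is the parallel counterpart of the builtin map: the input sequence is partitioned across the engines, the function is applied to each slice remotely, and the results come back in the original order. A minimal sketch:

    from IPython.parallel import Client

    rc = Client(profile='iptest')
    view = rc[:]

    def square(x):
        return x ** 2

    data = range(16)
    result = view.map_sync(square, data)          # runs square on the engines, in parallel
    assert result == [square(x) for x in data]    # result order matches the input order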
245 def test_map_iterable(self):
245 def test_map_iterable(self):
246 """test map on iterables (direct)"""
246 """test map on iterables (direct)"""
247 view = self.client[:]
247 view = self.client[:]
248 # 101 is prime, so it won't be evenly distributed
248 # 101 is prime, so it won't be evenly distributed
249 arr = range(101)
249 arr = range(101)
250 # ensure it will be an iterator, even in Python 3
250 # ensure it will be an iterator, even in Python 3
251 it = iter(arr)
251 it = iter(arr)
252 r = view.map_sync(lambda x:x, arr)
252 r = view.map_sync(lambda x:x, arr)
253 self.assertEquals(r, list(arr))
253 self.assertEquals(r, list(arr))
254
254
255 def test_scatterGatherNonblocking(self):
255 def test_scatterGatherNonblocking(self):
256 data = range(16)
256 data = range(16)
257 view = self.client[:]
257 view = self.client[:]
258 view.scatter('a', data, block=False)
258 view.scatter('a', data, block=False)
259 ar = view.gather('a', block=False)
259 ar = view.gather('a', block=False)
260 self.assertEquals(ar.get(), data)
260 self.assertEquals(ar.get(), data)
261
261
262 @skip_without('numpy')
262 @skip_without('numpy')
263 def test_scatter_gather_numpy_nonblocking(self):
263 def test_scatter_gather_numpy_nonblocking(self):
264 import numpy
264 import numpy
265 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
265 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
266 a = numpy.arange(64)
266 a = numpy.arange(64)
267 view = self.client[:]
267 view = self.client[:]
268 ar = view.scatter('a', a, block=False)
268 ar = view.scatter('a', a, block=False)
269 self.assertTrue(isinstance(ar, AsyncResult))
269 self.assertTrue(isinstance(ar, AsyncResult))
270 amr = view.gather('a', block=False)
270 amr = view.gather('a', block=False)
271 self.assertTrue(isinstance(amr, AsyncMapResult))
271 self.assertTrue(isinstance(amr, AsyncMapResult))
272 assert_array_equal(amr.get(), a)
272 assert_array_equal(amr.get(), a)
273
273
274 def test_execute(self):
274 def test_execute(self):
275 view = self.client[:]
275 view = self.client[:]
276 # self.client.debug=True
276 # self.client.debug=True
277 execute = view.execute
277 execute = view.execute
278 ar = execute('c=30', block=False)
278 ar = execute('c=30', block=False)
279 self.assertTrue(isinstance(ar, AsyncResult))
279 self.assertTrue(isinstance(ar, AsyncResult))
280 ar = execute('d=[0,1,2]', block=False)
280 ar = execute('d=[0,1,2]', block=False)
281 self.client.wait(ar, 1)
281 self.client.wait(ar, 1)
282 self.assertEquals(len(ar.get()), len(self.client))
282 self.assertEquals(len(ar.get()), len(self.client))
283 for c in view['c']:
283 for c in view['c']:
284 self.assertEquals(c, 30)
284 self.assertEquals(c, 30)
285
285
286 def test_abort(self):
286 def test_abort(self):
287 view = self.client[-1]
287 view = self.client[-1]
288 ar = view.execute('import time; time.sleep(1)', block=False)
288 ar = view.execute('import time; time.sleep(1)', block=False)
289 ar2 = view.apply_async(lambda : 2)
289 ar2 = view.apply_async(lambda : 2)
290 ar3 = view.apply_async(lambda : 3)
290 ar3 = view.apply_async(lambda : 3)
291 view.abort(ar2)
291 view.abort(ar2)
292 view.abort(ar3.msg_ids)
292 view.abort(ar3.msg_ids)
293 self.assertRaises(error.TaskAborted, ar2.get)
293 self.assertRaises(error.TaskAborted, ar2.get)
294 self.assertRaises(error.TaskAborted, ar3.get)
294 self.assertRaises(error.TaskAborted, ar3.get)
295
295
296 def test_abort_all(self):
297 """view.abort() aborts all outstanding tasks"""
298 view = self.client[-1]
299 ars = [ view.apply_async(time.sleep, 1) for i in range(10) ]
300 view.abort()
301 view.wait(timeout=5)
302 for ar in ars[5:]:
303 self.assertRaises(error.TaskAborted, ar.get)
304
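The new test covers the behaviour this changeset adds: calling view.abort() with no arguments aborts every task still outstanding on the view, instead of requiring an explicit msg_id or AsyncResult. A minimal sketch of the same usage, assuming a running 'iptest' cluster; tasks already running finish normally, the rest raise TaskAborted:

    import time
    from IPython.parallel import Client, error

    rc = Client(profile='iptest')
    view = rc[-1]

    ars = [view.apply_async(time.sleep, 1) for _ in range(10)]
    view.abort()                      # no arguments: abort everything still queued
    view.wait(timeout=5)

    for ar in ars:
        try:
            ar.get()
        except error.TaskAborted:
            pass                      # aborted before it started running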
296 def test_temp_flags(self):
305 def test_temp_flags(self):
297 view = self.client[-1]
306 view = self.client[-1]
298 view.block=True
307 view.block=True
299 with view.temp_flags(block=False):
308 with view.temp_flags(block=False):
300 self.assertFalse(view.block)
309 self.assertFalse(view.block)
301 self.assertTrue(view.block)
310 self.assertTrue(view.block)
302
311
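temp_flags is a context manager that overrides view attributes such as block (or track) for the duration of the with-block and restores the previous values on exit, so one-off flag changes do not leak into later calls. A minimal sketch:

    from IPython.parallel import Client

    rc = Client(profile='iptest')
    view = rc[-1]
    view.block = True

    with view.temp_flags(block=False):
        ar = view.apply(lambda: 42)   # non-blocking here, so an AsyncResult comes back
    print(ar.get())                   # 42
    print(view.block)                 # True again: the flag was restored on exit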
303 @dec.known_failure_py3
312 @dec.known_failure_py3
304 def test_importer(self):
313 def test_importer(self):
305 view = self.client[-1]
314 view = self.client[-1]
306 view.clear(block=True)
315 view.clear(block=True)
307 with view.importer:
316 with view.importer:
308 import re
317 import re
309
318
310 @interactive
319 @interactive
311 def findall(pat, s):
320 def findall(pat, s):
312 # this globals() step isn't necessary in real code
321 # this globals() step isn't necessary in real code
313 # only to prevent a closure in the test
322 # only to prevent a closure in the test
314 re = globals()['re']
323 re = globals()['re']
315 return re.findall(pat, s)
324 return re.findall(pat, s)
316
325
317 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
326 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
318
327
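view.importer is a context manager that mirrors imports: a module imported inside the with-block is imported both locally and on the view's engines, so remote code can refer to it by name afterwards. A minimal sketch:

    from IPython.parallel import Client

    rc = Client(profile='iptest')
    view = rc[-1]
    view.block = True

    with view.importer:
        import re                     # imported here and on the engine

    view.execute('words = re.findall("[a-z]+", "hello world")')
    print(view['words'])              # ['hello', 'world']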
319 # parallel magic tests
328 # parallel magic tests
320
329
321 def test_magic_px_blocking(self):
330 def test_magic_px_blocking(self):
322 ip = get_ipython()
331 ip = get_ipython()
323 v = self.client[-1]
332 v = self.client[-1]
324 v.activate()
333 v.activate()
325 v.block=True
334 v.block=True
326
335
327 ip.magic_px('a=5')
336 ip.magic_px('a=5')
328 self.assertEquals(v['a'], 5)
337 self.assertEquals(v['a'], 5)
329 ip.magic_px('a=10')
338 ip.magic_px('a=10')
330 self.assertEquals(v['a'], 10)
339 self.assertEquals(v['a'], 10)
331 sio = StringIO()
340 sio = StringIO()
332 savestdout = sys.stdout
341 savestdout = sys.stdout
333 sys.stdout = sio
342 sys.stdout = sio
334 # just 'print a' works ~99% of the time, but this ensures that
343 # just 'print a' works ~99% of the time, but this ensures that
335 # the stdout message has arrived when the result is finished:
344 # the stdout message has arrived when the result is finished:
336 ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
345 ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
337 sys.stdout = savestdout
346 sys.stdout = savestdout
338 buf = sio.getvalue()
347 buf = sio.getvalue()
339 self.assertTrue('[stdout:' in buf, buf)
348 self.assertTrue('[stdout:' in buf, buf)
340 self.assertTrue(buf.rstrip().endswith('10'))
349 self.assertTrue(buf.rstrip().endswith('10'))
341 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
350 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
342
351
343 def test_magic_px_nonblocking(self):
352 def test_magic_px_nonblocking(self):
344 ip = get_ipython()
353 ip = get_ipython()
345 v = self.client[-1]
354 v = self.client[-1]
346 v.activate()
355 v.activate()
347 v.block=False
356 v.block=False
348
357
349 ip.magic_px('a=5')
358 ip.magic_px('a=5')
350 self.assertEquals(v['a'], 5)
359 self.assertEquals(v['a'], 5)
351 ip.magic_px('a=10')
360 ip.magic_px('a=10')
352 self.assertEquals(v['a'], 10)
361 self.assertEquals(v['a'], 10)
353 sio = StringIO()
362 sio = StringIO()
354 savestdout = sys.stdout
363 savestdout = sys.stdout
355 sys.stdout = sio
364 sys.stdout = sio
356 ip.magic_px('print a')
365 ip.magic_px('print a')
357 sys.stdout = savestdout
366 sys.stdout = savestdout
358 buf = sio.getvalue()
367 buf = sio.getvalue()
359 self.assertFalse('[stdout:%i]'%v.targets in buf)
368 self.assertFalse('[stdout:%i]'%v.targets in buf)
360 ip.magic_px('1/0')
369 ip.magic_px('1/0')
361 ar = v.get_result(-1)
370 ar = v.get_result(-1)
362 self.assertRaisesRemote(ZeroDivisionError, ar.get)
371 self.assertRaisesRemote(ZeroDivisionError, ar.get)
363
372
364 def test_magic_autopx_blocking(self):
373 def test_magic_autopx_blocking(self):
365 ip = get_ipython()
374 ip = get_ipython()
366 v = self.client[-1]
375 v = self.client[-1]
367 v.activate()
376 v.activate()
368 v.block=True
377 v.block=True
369
378
370 sio = StringIO()
379 sio = StringIO()
371 savestdout = sys.stdout
380 savestdout = sys.stdout
372 sys.stdout = sio
381 sys.stdout = sio
373 ip.magic_autopx()
382 ip.magic_autopx()
374 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
383 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
375 ip.run_cell('print b')
384 ip.run_cell('print b')
376 ip.run_cell("b/c")
385 ip.run_cell("b/c")
377 ip.run_code(compile('b*=2', '', 'single'))
386 ip.run_code(compile('b*=2', '', 'single'))
378 ip.magic_autopx()
387 ip.magic_autopx()
379 sys.stdout = savestdout
388 sys.stdout = savestdout
380 output = sio.getvalue().strip()
389 output = sio.getvalue().strip()
381 self.assertTrue(output.startswith('%autopx enabled'))
390 self.assertTrue(output.startswith('%autopx enabled'))
382 self.assertTrue(output.endswith('%autopx disabled'))
391 self.assertTrue(output.endswith('%autopx disabled'))
383 self.assertTrue('RemoteError: ZeroDivisionError' in output)
392 self.assertTrue('RemoteError: ZeroDivisionError' in output)
384 ar = v.get_result(-2)
393 ar = v.get_result(-2)
385 self.assertEquals(v['a'], 5)
394 self.assertEquals(v['a'], 5)
386 self.assertEquals(v['b'], 20)
395 self.assertEquals(v['b'], 20)
387 self.assertRaisesRemote(ZeroDivisionError, ar.get)
396 self.assertRaisesRemote(ZeroDivisionError, ar.get)
388
397
389 def test_magic_autopx_nonblocking(self):
398 def test_magic_autopx_nonblocking(self):
390 ip = get_ipython()
399 ip = get_ipython()
391 v = self.client[-1]
400 v = self.client[-1]
392 v.activate()
401 v.activate()
393 v.block=False
402 v.block=False
394
403
395 sio = StringIO()
404 sio = StringIO()
396 savestdout = sys.stdout
405 savestdout = sys.stdout
397 sys.stdout = sio
406 sys.stdout = sio
398 ip.magic_autopx()
407 ip.magic_autopx()
399 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
408 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
400 ip.run_cell('print b')
409 ip.run_cell('print b')
401 ip.run_cell("b/c")
410 ip.run_cell("b/c")
402 ip.run_code(compile('b*=2', '', 'single'))
411 ip.run_code(compile('b*=2', '', 'single'))
403 ip.magic_autopx()
412 ip.magic_autopx()
404 sys.stdout = savestdout
413 sys.stdout = savestdout
405 output = sio.getvalue().strip()
414 output = sio.getvalue().strip()
406 self.assertTrue(output.startswith('%autopx enabled'))
415 self.assertTrue(output.startswith('%autopx enabled'))
407 self.assertTrue(output.endswith('%autopx disabled'))
416 self.assertTrue(output.endswith('%autopx disabled'))
408 self.assertFalse('ZeroDivisionError' in output)
417 self.assertFalse('ZeroDivisionError' in output)
409 ar = v.get_result(-2)
418 ar = v.get_result(-2)
410 self.assertEquals(v['a'], 5)
419 self.assertEquals(v['a'], 5)
411 self.assertEquals(v['b'], 20)
420 self.assertEquals(v['b'], 20)
412 self.assertRaisesRemote(ZeroDivisionError, ar.get)
421 self.assertRaisesRemote(ZeroDivisionError, ar.get)
413
422
414 def test_magic_result(self):
423 def test_magic_result(self):
415 ip = get_ipython()
424 ip = get_ipython()
416 v = self.client[-1]
425 v = self.client[-1]
417 v.activate()
426 v.activate()
418 v['a'] = 111
427 v['a'] = 111
419 ra = v['a']
428 ra = v['a']
420
429
421 ar = ip.magic_result()
430 ar = ip.magic_result()
422 self.assertEquals(ar.msg_ids, [v.history[-1]])
431 self.assertEquals(ar.msg_ids, [v.history[-1]])
423 self.assertEquals(ar.get(), 111)
432 self.assertEquals(ar.get(), 111)
424 ar = ip.magic_result('-2')
433 ar = ip.magic_result('-2')
425 self.assertEquals(ar.msg_ids, [v.history[-2]])
434 self.assertEquals(ar.msg_ids, [v.history[-2]])
426
435
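The %px, %autopx and %result magics exercised in these tests become available once activate() is called on a view; after that, %px runs a line on the view's engines, %autopx toggles sending every subsequent cell there, and %result fetches the AsyncResult of a recent %px by history index. A minimal interactive sketch (an IPython session rather than a plain script), assuming a running 'iptest' cluster:

    In [1]: from IPython.parallel import Client
    In [2]: v = Client(profile='iptest')[-1]
    In [3]: v.activate()          # binds %px/%autopx/%result to this view
    In [4]: v.block = True
    In [5]: %px a = 5             # executes on the engine behind v
    In [6]: v['a']
    Out[6]: 5
    In [7]: %result               # AsyncResult of the most recent %px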
427 def test_unicode_execute(self):
436 def test_unicode_execute(self):
428 """test executing unicode strings"""
437 """test executing unicode strings"""
429 v = self.client[-1]
438 v = self.client[-1]
430 v.block=True
439 v.block=True
431 if sys.version_info[0] >= 3:
440 if sys.version_info[0] >= 3:
432 code="a='é'"
441 code="a='é'"
433 else:
442 else:
434 code=u"a=u'é'"
443 code=u"a=u'é'"
435 v.execute(code)
444 v.execute(code)
436 self.assertEquals(v['a'], u'é')
445 self.assertEquals(v['a'], u'é')
437
446
438 def test_unicode_apply_result(self):
447 def test_unicode_apply_result(self):
439 """test unicode apply results"""
448 """test unicode apply results"""
440 v = self.client[-1]
449 v = self.client[-1]
441 r = v.apply_sync(lambda : u'é')
450 r = v.apply_sync(lambda : u'é')
442 self.assertEquals(r, u'é')
451 self.assertEquals(r, u'é')
443
452
444 def test_unicode_apply_arg(self):
453 def test_unicode_apply_arg(self):
445 """test passing unicode arguments to apply"""
454 """test passing unicode arguments to apply"""
446 v = self.client[-1]
455 v = self.client[-1]
447
456
448 @interactive
457 @interactive
449 def check_unicode(a, check):
458 def check_unicode(a, check):
450 assert isinstance(a, unicode), "%r is not unicode"%a
459 assert isinstance(a, unicode), "%r is not unicode"%a
451 assert isinstance(check, bytes), "%r is not bytes"%check
460 assert isinstance(check, bytes), "%r is not bytes"%check
452 assert a.encode('utf8') == check, "%s != %s"%(a,check)
461 assert a.encode('utf8') == check, "%s != %s"%(a,check)
453
462
454 for s in [ u'é', u'ßø®∫',u'asdf' ]:
463 for s in [ u'é', u'ßø®∫',u'asdf' ]:
455 try:
464 try:
456 v.apply_sync(check_unicode, s, s.encode('utf8'))
465 v.apply_sync(check_unicode, s, s.encode('utf8'))
457 except error.RemoteError as e:
466 except error.RemoteError as e:
458 if e.ename == 'AssertionError':
467 if e.ename == 'AssertionError':
459 self.fail(e.evalue)
468 self.fail(e.evalue)
460 else:
469 else:
461 raise e
470 raise e
462
471
463
472