##// END OF EJS Templates
allow Reference as callable in map/apply...
MinRK -
Show More
@@ -1,1443 +1,1444
1 """A semi-synchronous Client for the ZMQ cluster
1 """A semi-synchronous Client for the ZMQ cluster
2
2
3 Authors:
3 Authors:
4
4
5 * MinRK
5 * MinRK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import os
18 import os
19 import json
19 import json
20 import sys
20 import sys
21 import time
21 import time
22 import warnings
22 import warnings
23 from datetime import datetime
23 from datetime import datetime
24 from getpass import getpass
24 from getpass import getpass
25 from pprint import pprint
25 from pprint import pprint
26
26
27 pjoin = os.path.join
27 pjoin = os.path.join
28
28
29 import zmq
29 import zmq
30 # from zmq.eventloop import ioloop, zmqstream
30 # from zmq.eventloop import ioloop, zmqstream
31
31
32 from IPython.config.configurable import MultipleInstanceError
32 from IPython.config.configurable import MultipleInstanceError
33 from IPython.core.application import BaseIPythonApplication
33 from IPython.core.application import BaseIPythonApplication
34
34
35 from IPython.utils.jsonutil import rekey
35 from IPython.utils.jsonutil import rekey
36 from IPython.utils.localinterfaces import LOCAL_IPS
36 from IPython.utils.localinterfaces import LOCAL_IPS
37 from IPython.utils.path import get_ipython_dir
37 from IPython.utils.path import get_ipython_dir
38 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
38 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
39 Dict, List, Bool, Set)
39 Dict, List, Bool, Set)
40 from IPython.external.decorator import decorator
40 from IPython.external.decorator import decorator
41 from IPython.external.ssh import tunnel
41 from IPython.external.ssh import tunnel
42
42
43 from IPython.parallel import Reference
43 from IPython.parallel import error
44 from IPython.parallel import error
44 from IPython.parallel import util
45 from IPython.parallel import util
45
46
46 from IPython.zmq.session import Session, Message
47 from IPython.zmq.session import Session, Message
47
48
48 from .asyncresult import AsyncResult, AsyncHubResult
49 from .asyncresult import AsyncResult, AsyncHubResult
49 from IPython.core.profiledir import ProfileDir, ProfileDirError
50 from IPython.core.profiledir import ProfileDir, ProfileDirError
50 from .view import DirectView, LoadBalancedView
51 from .view import DirectView, LoadBalancedView
51
52
52 if sys.version_info[0] >= 3:
53 if sys.version_info[0] >= 3:
53 # xrange is used in a couple 'isinstance' tests in py2
54 # xrange is used in a couple 'isinstance' tests in py2
54 # should be just 'range' in 3k
55 # should be just 'range' in 3k
55 xrange = range
56 xrange = range
56
57
57 #--------------------------------------------------------------------------
58 #--------------------------------------------------------------------------
58 # Decorators for Client methods
59 # Decorators for Client methods
59 #--------------------------------------------------------------------------
60 #--------------------------------------------------------------------------
60
61
61 @decorator
62 @decorator
62 def spin_first(f, self, *args, **kwargs):
63 def spin_first(f, self, *args, **kwargs):
63 """Call spin() to sync state prior to calling the method."""
64 """Call spin() to sync state prior to calling the method."""
64 self.spin()
65 self.spin()
65 return f(self, *args, **kwargs)
66 return f(self, *args, **kwargs)
66
67
67
68
68 #--------------------------------------------------------------------------
69 #--------------------------------------------------------------------------
69 # Classes
70 # Classes
70 #--------------------------------------------------------------------------
71 #--------------------------------------------------------------------------
71
72
72 class Metadata(dict):
73 class Metadata(dict):
73 """Subclass of dict for initializing metadata values.
74 """Subclass of dict for initializing metadata values.
74
75
75 Attribute access works on keys.
76 Attribute access works on keys.
76
77
77 These objects have a strict set of keys - errors will raise if you try
78 These objects have a strict set of keys - errors will raise if you try
78 to add new keys.
79 to add new keys.
79 """
80 """
80 def __init__(self, *args, **kwargs):
81 def __init__(self, *args, **kwargs):
81 dict.__init__(self)
82 dict.__init__(self)
82 md = {'msg_id' : None,
83 md = {'msg_id' : None,
83 'submitted' : None,
84 'submitted' : None,
84 'started' : None,
85 'started' : None,
85 'completed' : None,
86 'completed' : None,
86 'received' : None,
87 'received' : None,
87 'engine_uuid' : None,
88 'engine_uuid' : None,
88 'engine_id' : None,
89 'engine_id' : None,
89 'follow' : None,
90 'follow' : None,
90 'after' : None,
91 'after' : None,
91 'status' : None,
92 'status' : None,
92
93
93 'pyin' : None,
94 'pyin' : None,
94 'pyout' : None,
95 'pyout' : None,
95 'pyerr' : None,
96 'pyerr' : None,
96 'stdout' : '',
97 'stdout' : '',
97 'stderr' : '',
98 'stderr' : '',
98 }
99 }
99 self.update(md)
100 self.update(md)
100 self.update(dict(*args, **kwargs))
101 self.update(dict(*args, **kwargs))
101
102
102 def __getattr__(self, key):
103 def __getattr__(self, key):
103 """getattr aliased to getitem"""
104 """getattr aliased to getitem"""
104 if key in self.iterkeys():
105 if key in self.iterkeys():
105 return self[key]
106 return self[key]
106 else:
107 else:
107 raise AttributeError(key)
108 raise AttributeError(key)
108
109
109 def __setattr__(self, key, value):
110 def __setattr__(self, key, value):
110 """setattr aliased to setitem, with strict"""
111 """setattr aliased to setitem, with strict"""
111 if key in self.iterkeys():
112 if key in self.iterkeys():
112 self[key] = value
113 self[key] = value
113 else:
114 else:
114 raise AttributeError(key)
115 raise AttributeError(key)
115
116
116 def __setitem__(self, key, value):
117 def __setitem__(self, key, value):
117 """strict static key enforcement"""
118 """strict static key enforcement"""
118 if key in self.iterkeys():
119 if key in self.iterkeys():
119 dict.__setitem__(self, key, value)
120 dict.__setitem__(self, key, value)
120 else:
121 else:
121 raise KeyError(key)
122 raise KeyError(key)
122
123
123
124
124 class Client(HasTraits):
125 class Client(HasTraits):
125 """A semi-synchronous client to the IPython ZMQ cluster
126 """A semi-synchronous client to the IPython ZMQ cluster
126
127
127 Parameters
128 Parameters
128 ----------
129 ----------
129
130
130 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
131 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
131 Connection information for the Hub's registration. If a json connector
132 Connection information for the Hub's registration. If a json connector
132 file is given, then likely no further configuration is necessary.
133 file is given, then likely no further configuration is necessary.
133 [Default: use profile]
134 [Default: use profile]
134 profile : bytes
135 profile : bytes
135 The name of the Cluster profile to be used to find connector information.
136 The name of the Cluster profile to be used to find connector information.
136 If run from an IPython application, the default profile will be the same
137 If run from an IPython application, the default profile will be the same
137 as the running application, otherwise it will be 'default'.
138 as the running application, otherwise it will be 'default'.
138 context : zmq.Context
139 context : zmq.Context
139 Pass an existing zmq.Context instance, otherwise the client will create its own.
140 Pass an existing zmq.Context instance, otherwise the client will create its own.
140 debug : bool
141 debug : bool
141 flag for lots of message printing for debug purposes
142 flag for lots of message printing for debug purposes
142 timeout : int/float
143 timeout : int/float
143 time (in seconds) to wait for connection replies from the Hub
144 time (in seconds) to wait for connection replies from the Hub
144 [Default: 10]
145 [Default: 10]
145
146
146 #-------------- session related args ----------------
147 #-------------- session related args ----------------
147
148
148 config : Config object
149 config : Config object
149 If specified, this will be relayed to the Session for configuration
150 If specified, this will be relayed to the Session for configuration
150 username : str
151 username : str
151 set username for the session object
152 set username for the session object
152 packer : str (import_string) or callable
153 packer : str (import_string) or callable
153 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
154 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
154 function to serialize messages. Must support same input as
155 function to serialize messages. Must support same input as
155 JSON, and output must be bytes.
156 JSON, and output must be bytes.
156 You can pass a callable directly as `pack`
157 You can pass a callable directly as `pack`
157 unpacker : str (import_string) or callable
158 unpacker : str (import_string) or callable
158 The inverse of packer. Only necessary if packer is specified as *not* one
159 The inverse of packer. Only necessary if packer is specified as *not* one
159 of 'json' or 'pickle'.
160 of 'json' or 'pickle'.
160
161
161 #-------------- ssh related args ----------------
162 #-------------- ssh related args ----------------
162 # These are args for configuring the ssh tunnel to be used
163 # These are args for configuring the ssh tunnel to be used
163 # credentials are used to forward connections over ssh to the Controller
164 # credentials are used to forward connections over ssh to the Controller
164 # Note that the ip given in `addr` needs to be relative to sshserver
165 # Note that the ip given in `addr` needs to be relative to sshserver
165 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
166 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
166 # and set sshserver as the same machine the Controller is on. However,
167 # and set sshserver as the same machine the Controller is on. However,
167 # the only requirement is that sshserver is able to see the Controller
168 # the only requirement is that sshserver is able to see the Controller
168 # (i.e. is within the same trusted network).
169 # (i.e. is within the same trusted network).
169
170
170 sshserver : str
171 sshserver : str
171 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
172 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
172 If keyfile or password is specified, and this is not, it will default to
173 If keyfile or password is specified, and this is not, it will default to
173 the ip given in addr.
174 the ip given in addr.
174 sshkey : str; path to ssh private key file
175 sshkey : str; path to ssh private key file
175 This specifies a key to be used in ssh login, default None.
176 This specifies a key to be used in ssh login, default None.
176 Regular default ssh keys will be used without specifying this argument.
177 Regular default ssh keys will be used without specifying this argument.
177 password : str
178 password : str
178 Your ssh password to sshserver. Note that if this is left None,
179 Your ssh password to sshserver. Note that if this is left None,
179 you will be prompted for it if passwordless key based login is unavailable.
180 you will be prompted for it if passwordless key based login is unavailable.
180 paramiko : bool
181 paramiko : bool
181 flag for whether to use paramiko instead of shell ssh for tunneling.
182 flag for whether to use paramiko instead of shell ssh for tunneling.
182 [default: True on win32, False else]
183 [default: True on win32, False else]
183
184
184 ------- exec authentication args -------
185 ------- exec authentication args -------
185 If even localhost is untrusted, you can have some protection against
186 If even localhost is untrusted, you can have some protection against
186 unauthorized execution by signing messages with HMAC digests.
187 unauthorized execution by signing messages with HMAC digests.
187 Messages are still sent as cleartext, so if someone can snoop your
188 Messages are still sent as cleartext, so if someone can snoop your
188 loopback traffic this will not protect your privacy, but will prevent
189 loopback traffic this will not protect your privacy, but will prevent
189 unauthorized execution.
190 unauthorized execution.
190
191
191 exec_key : str
192 exec_key : str
192 an authentication key or file containing a key
193 an authentication key or file containing a key
193 default: None
194 default: None
194
195
195
196
196 Attributes
197 Attributes
197 ----------
198 ----------
198
199
199 ids : list of int engine IDs
200 ids : list of int engine IDs
200 requesting the ids attribute always synchronizes
201 requesting the ids attribute always synchronizes
201 the registration state. To request ids without synchronization,
202 the registration state. To request ids without synchronization,
202 use semi-private _ids attributes.
203 use semi-private _ids attributes.
203
204
204 history : list of msg_ids
205 history : list of msg_ids
205 a list of msg_ids, keeping track of all the execution
206 a list of msg_ids, keeping track of all the execution
206 messages you have submitted in order.
207 messages you have submitted in order.
207
208
208 outstanding : set of msg_ids
209 outstanding : set of msg_ids
209 a set of msg_ids that have been submitted, but whose
210 a set of msg_ids that have been submitted, but whose
210 results have not yet been received.
211 results have not yet been received.
211
212
212 results : dict
213 results : dict
213 a dict of all our results, keyed by msg_id
214 a dict of all our results, keyed by msg_id
214
215
215 block : bool
216 block : bool
216 determines default behavior when block not specified
217 determines default behavior when block not specified
217 in execution methods
218 in execution methods
218
219
219 Methods
220 Methods
220 -------
221 -------
221
222
222 spin
223 spin
223 flushes incoming results and registration state changes
224 flushes incoming results and registration state changes
224 control methods spin, and requesting `ids` also ensures up to date
225 control methods spin, and requesting `ids` also ensures up to date
225
226
226 wait
227 wait
227 wait on one or more msg_ids
228 wait on one or more msg_ids
228
229
229 execution methods
230 execution methods
230 apply
231 apply
231 legacy: execute, run
232 legacy: execute, run
232
233
233 data movement
234 data movement
234 push, pull, scatter, gather
235 push, pull, scatter, gather
235
236
236 query methods
237 query methods
237 queue_status, get_result, purge, result_status
238 queue_status, get_result, purge, result_status
238
239
239 control methods
240 control methods
240 abort, shutdown
241 abort, shutdown
241
242
242 """
243 """
243
244
244
245
245 block = Bool(False)
246 block = Bool(False)
246 outstanding = Set()
247 outstanding = Set()
247 results = Instance('collections.defaultdict', (dict,))
248 results = Instance('collections.defaultdict', (dict,))
248 metadata = Instance('collections.defaultdict', (Metadata,))
249 metadata = Instance('collections.defaultdict', (Metadata,))
249 history = List()
250 history = List()
250 debug = Bool(False)
251 debug = Bool(False)
251
252
252 profile=Unicode()
253 profile=Unicode()
253 def _profile_default(self):
254 def _profile_default(self):
254 if BaseIPythonApplication.initialized():
255 if BaseIPythonApplication.initialized():
255 # an IPython app *might* be running, try to get its profile
256 # an IPython app *might* be running, try to get its profile
256 try:
257 try:
257 return BaseIPythonApplication.instance().profile
258 return BaseIPythonApplication.instance().profile
258 except (AttributeError, MultipleInstanceError):
259 except (AttributeError, MultipleInstanceError):
259 # could be a *different* subclass of config.Application,
260 # could be a *different* subclass of config.Application,
260 # which would raise one of these two errors.
261 # which would raise one of these two errors.
261 return u'default'
262 return u'default'
262 else:
263 else:
263 return u'default'
264 return u'default'
264
265
265
266
266 _outstanding_dict = Instance('collections.defaultdict', (set,))
267 _outstanding_dict = Instance('collections.defaultdict', (set,))
267 _ids = List()
268 _ids = List()
268 _connected=Bool(False)
269 _connected=Bool(False)
269 _ssh=Bool(False)
270 _ssh=Bool(False)
270 _context = Instance('zmq.Context')
271 _context = Instance('zmq.Context')
271 _config = Dict()
272 _config = Dict()
272 _engines=Instance(util.ReverseDict, (), {})
273 _engines=Instance(util.ReverseDict, (), {})
273 # _hub_socket=Instance('zmq.Socket')
274 # _hub_socket=Instance('zmq.Socket')
274 _query_socket=Instance('zmq.Socket')
275 _query_socket=Instance('zmq.Socket')
275 _control_socket=Instance('zmq.Socket')
276 _control_socket=Instance('zmq.Socket')
276 _iopub_socket=Instance('zmq.Socket')
277 _iopub_socket=Instance('zmq.Socket')
277 _notification_socket=Instance('zmq.Socket')
278 _notification_socket=Instance('zmq.Socket')
278 _mux_socket=Instance('zmq.Socket')
279 _mux_socket=Instance('zmq.Socket')
279 _task_socket=Instance('zmq.Socket')
280 _task_socket=Instance('zmq.Socket')
280 _task_scheme=Unicode()
281 _task_scheme=Unicode()
281 _closed = False
282 _closed = False
282 _ignored_control_replies=Integer(0)
283 _ignored_control_replies=Integer(0)
283 _ignored_hub_replies=Integer(0)
284 _ignored_hub_replies=Integer(0)
284
285
285 def __new__(self, *args, **kw):
286 def __new__(self, *args, **kw):
286 # don't raise on positional args
287 # don't raise on positional args
287 return HasTraits.__new__(self, **kw)
288 return HasTraits.__new__(self, **kw)
288
289
289 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
290 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
290 context=None, debug=False, exec_key=None,
291 context=None, debug=False, exec_key=None,
291 sshserver=None, sshkey=None, password=None, paramiko=None,
292 sshserver=None, sshkey=None, password=None, paramiko=None,
292 timeout=10, **extra_args
293 timeout=10, **extra_args
293 ):
294 ):
294 if profile:
295 if profile:
295 super(Client, self).__init__(debug=debug, profile=profile)
296 super(Client, self).__init__(debug=debug, profile=profile)
296 else:
297 else:
297 super(Client, self).__init__(debug=debug)
298 super(Client, self).__init__(debug=debug)
298 if context is None:
299 if context is None:
299 context = zmq.Context.instance()
300 context = zmq.Context.instance()
300 self._context = context
301 self._context = context
301
302
302 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
303 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
303 if self._cd is not None:
304 if self._cd is not None:
304 if url_or_file is None:
305 if url_or_file is None:
305 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
306 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
306 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
307 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
307 " Please specify at least one of url_or_file or profile."
308 " Please specify at least one of url_or_file or profile."
308
309
309 if not util.is_url(url_or_file):
310 if not util.is_url(url_or_file):
310 # it's not a url, try for a file
311 # it's not a url, try for a file
311 if not os.path.exists(url_or_file):
312 if not os.path.exists(url_or_file):
312 if self._cd:
313 if self._cd:
313 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
314 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
314 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
315 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
315 with open(url_or_file) as f:
316 with open(url_or_file) as f:
316 cfg = json.loads(f.read())
317 cfg = json.loads(f.read())
317 else:
318 else:
318 cfg = {'url':url_or_file}
319 cfg = {'url':url_or_file}
319
320
320 # sync defaults from args, json:
321 # sync defaults from args, json:
321 if sshserver:
322 if sshserver:
322 cfg['ssh'] = sshserver
323 cfg['ssh'] = sshserver
323 if exec_key:
324 if exec_key:
324 cfg['exec_key'] = exec_key
325 cfg['exec_key'] = exec_key
325 exec_key = cfg['exec_key']
326 exec_key = cfg['exec_key']
326 location = cfg.setdefault('location', None)
327 location = cfg.setdefault('location', None)
327 cfg['url'] = util.disambiguate_url(cfg['url'], location)
328 cfg['url'] = util.disambiguate_url(cfg['url'], location)
328 url = cfg['url']
329 url = cfg['url']
329 proto,addr,port = util.split_url(url)
330 proto,addr,port = util.split_url(url)
330 if location is not None and addr == '127.0.0.1':
331 if location is not None and addr == '127.0.0.1':
331 # location specified, and connection is expected to be local
332 # location specified, and connection is expected to be local
332 if location not in LOCAL_IPS and not sshserver:
333 if location not in LOCAL_IPS and not sshserver:
333 # load ssh from JSON *only* if the controller is not on
334 # load ssh from JSON *only* if the controller is not on
334 # this machine
335 # this machine
335 sshserver=cfg['ssh']
336 sshserver=cfg['ssh']
336 if location not in LOCAL_IPS and not sshserver:
337 if location not in LOCAL_IPS and not sshserver:
337 # warn if no ssh specified, but SSH is probably needed
338 # warn if no ssh specified, but SSH is probably needed
338 # This is only a warning, because the most likely cause
339 # This is only a warning, because the most likely cause
339 # is a local Controller on a laptop whose IP is dynamic
340 # is a local Controller on a laptop whose IP is dynamic
340 warnings.warn("""
341 warnings.warn("""
341 Controller appears to be listening on localhost, but not on this machine.
342 Controller appears to be listening on localhost, but not on this machine.
342 If this is true, you should specify Client(...,sshserver='you@%s')
343 If this is true, you should specify Client(...,sshserver='you@%s')
343 or instruct your controller to listen on an external IP."""%location,
344 or instruct your controller to listen on an external IP."""%location,
344 RuntimeWarning)
345 RuntimeWarning)
345 elif not sshserver:
346 elif not sshserver:
346 # otherwise sync with cfg
347 # otherwise sync with cfg
347 sshserver = cfg['ssh']
348 sshserver = cfg['ssh']
348
349
349 self._config = cfg
350 self._config = cfg
350
351
351 self._ssh = bool(sshserver or sshkey or password)
352 self._ssh = bool(sshserver or sshkey or password)
352 if self._ssh and sshserver is None:
353 if self._ssh and sshserver is None:
353 # default to ssh via localhost
354 # default to ssh via localhost
354 sshserver = url.split('://')[1].split(':')[0]
355 sshserver = url.split('://')[1].split(':')[0]
355 if self._ssh and password is None:
356 if self._ssh and password is None:
356 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
357 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
357 password=False
358 password=False
358 else:
359 else:
359 password = getpass("SSH Password for %s: "%sshserver)
360 password = getpass("SSH Password for %s: "%sshserver)
360 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
361 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
361
362
362 # configure and construct the session
363 # configure and construct the session
363 if exec_key is not None:
364 if exec_key is not None:
364 if os.path.isfile(exec_key):
365 if os.path.isfile(exec_key):
365 extra_args['keyfile'] = exec_key
366 extra_args['keyfile'] = exec_key
366 else:
367 else:
367 exec_key = util.asbytes(exec_key)
368 exec_key = util.asbytes(exec_key)
368 extra_args['key'] = exec_key
369 extra_args['key'] = exec_key
369 self.session = Session(**extra_args)
370 self.session = Session(**extra_args)
370
371
371 self._query_socket = self._context.socket(zmq.DEALER)
372 self._query_socket = self._context.socket(zmq.DEALER)
372 self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
373 self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
373 if self._ssh:
374 if self._ssh:
374 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
375 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
375 else:
376 else:
376 self._query_socket.connect(url)
377 self._query_socket.connect(url)
377
378
378 self.session.debug = self.debug
379 self.session.debug = self.debug
379
380
380 self._notification_handlers = {'registration_notification' : self._register_engine,
381 self._notification_handlers = {'registration_notification' : self._register_engine,
381 'unregistration_notification' : self._unregister_engine,
382 'unregistration_notification' : self._unregister_engine,
382 'shutdown_notification' : lambda msg: self.close(),
383 'shutdown_notification' : lambda msg: self.close(),
383 }
384 }
384 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
385 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
385 'apply_reply' : self._handle_apply_reply}
386 'apply_reply' : self._handle_apply_reply}
386 self._connect(sshserver, ssh_kwargs, timeout)
387 self._connect(sshserver, ssh_kwargs, timeout)
387
388
388 def __del__(self):
389 def __del__(self):
389 """cleanup sockets, but _not_ context."""
390 """cleanup sockets, but _not_ context."""
390 self.close()
391 self.close()
391
392
392 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
393 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
393 if ipython_dir is None:
394 if ipython_dir is None:
394 ipython_dir = get_ipython_dir()
395 ipython_dir = get_ipython_dir()
395 if profile_dir is not None:
396 if profile_dir is not None:
396 try:
397 try:
397 self._cd = ProfileDir.find_profile_dir(profile_dir)
398 self._cd = ProfileDir.find_profile_dir(profile_dir)
398 return
399 return
399 except ProfileDirError:
400 except ProfileDirError:
400 pass
401 pass
401 elif profile is not None:
402 elif profile is not None:
402 try:
403 try:
403 self._cd = ProfileDir.find_profile_dir_by_name(
404 self._cd = ProfileDir.find_profile_dir_by_name(
404 ipython_dir, profile)
405 ipython_dir, profile)
405 return
406 return
406 except ProfileDirError:
407 except ProfileDirError:
407 pass
408 pass
408 self._cd = None
409 self._cd = None
409
410
410 def _update_engines(self, engines):
411 def _update_engines(self, engines):
411 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
412 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
412 for k,v in engines.iteritems():
413 for k,v in engines.iteritems():
413 eid = int(k)
414 eid = int(k)
414 self._engines[eid] = v
415 self._engines[eid] = v
415 self._ids.append(eid)
416 self._ids.append(eid)
416 self._ids = sorted(self._ids)
417 self._ids = sorted(self._ids)
417 if sorted(self._engines.keys()) != range(len(self._engines)) and \
418 if sorted(self._engines.keys()) != range(len(self._engines)) and \
418 self._task_scheme == 'pure' and self._task_socket:
419 self._task_scheme == 'pure' and self._task_socket:
419 self._stop_scheduling_tasks()
420 self._stop_scheduling_tasks()
420
421
421 def _stop_scheduling_tasks(self):
422 def _stop_scheduling_tasks(self):
422 """Stop scheduling tasks because an engine has been unregistered
423 """Stop scheduling tasks because an engine has been unregistered
423 from a pure ZMQ scheduler.
424 from a pure ZMQ scheduler.
424 """
425 """
425 self._task_socket.close()
426 self._task_socket.close()
426 self._task_socket = None
427 self._task_socket = None
427 msg = "An engine has been unregistered, and we are using pure " +\
428 msg = "An engine has been unregistered, and we are using pure " +\
428 "ZMQ task scheduling. Task farming will be disabled."
429 "ZMQ task scheduling. Task farming will be disabled."
429 if self.outstanding:
430 if self.outstanding:
430 msg += " If you were running tasks when this happened, " +\
431 msg += " If you were running tasks when this happened, " +\
431 "some `outstanding` msg_ids may never resolve."
432 "some `outstanding` msg_ids may never resolve."
432 warnings.warn(msg, RuntimeWarning)
433 warnings.warn(msg, RuntimeWarning)
433
434
434 def _build_targets(self, targets):
435 def _build_targets(self, targets):
435 """Turn valid target IDs or 'all' into two lists:
436 """Turn valid target IDs or 'all' into two lists:
436 (int_ids, uuids).
437 (int_ids, uuids).
437 """
438 """
438 if not self._ids:
439 if not self._ids:
439 # flush notification socket if no engines yet, just in case
440 # flush notification socket if no engines yet, just in case
440 if not self.ids:
441 if not self.ids:
441 raise error.NoEnginesRegistered("Can't build targets without any engines")
442 raise error.NoEnginesRegistered("Can't build targets without any engines")
442
443
443 if targets is None:
444 if targets is None:
444 targets = self._ids
445 targets = self._ids
445 elif isinstance(targets, basestring):
446 elif isinstance(targets, basestring):
446 if targets.lower() == 'all':
447 if targets.lower() == 'all':
447 targets = self._ids
448 targets = self._ids
448 else:
449 else:
449 raise TypeError("%r not valid str target, must be 'all'"%(targets))
450 raise TypeError("%r not valid str target, must be 'all'"%(targets))
450 elif isinstance(targets, int):
451 elif isinstance(targets, int):
451 if targets < 0:
452 if targets < 0:
452 targets = self.ids[targets]
453 targets = self.ids[targets]
453 if targets not in self._ids:
454 if targets not in self._ids:
454 raise IndexError("No such engine: %i"%targets)
455 raise IndexError("No such engine: %i"%targets)
455 targets = [targets]
456 targets = [targets]
456
457
457 if isinstance(targets, slice):
458 if isinstance(targets, slice):
458 indices = range(len(self._ids))[targets]
459 indices = range(len(self._ids))[targets]
459 ids = self.ids
460 ids = self.ids
460 targets = [ ids[i] for i in indices ]
461 targets = [ ids[i] for i in indices ]
461
462
462 if not isinstance(targets, (tuple, list, xrange)):
463 if not isinstance(targets, (tuple, list, xrange)):
463 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
464 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
464
465
465 return [util.asbytes(self._engines[t]) for t in targets], list(targets)
466 return [util.asbytes(self._engines[t]) for t in targets], list(targets)
466
467
467 def _connect(self, sshserver, ssh_kwargs, timeout):
468 def _connect(self, sshserver, ssh_kwargs, timeout):
468 """setup all our socket connections to the cluster. This is called from
469 """setup all our socket connections to the cluster. This is called from
469 __init__."""
470 __init__."""
470
471
471 # Maybe allow reconnecting?
472 # Maybe allow reconnecting?
472 if self._connected:
473 if self._connected:
473 return
474 return
474 self._connected=True
475 self._connected=True
475
476
476 def connect_socket(s, url):
477 def connect_socket(s, url):
477 url = util.disambiguate_url(url, self._config['location'])
478 url = util.disambiguate_url(url, self._config['location'])
478 if self._ssh:
479 if self._ssh:
479 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
480 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
480 else:
481 else:
481 return s.connect(url)
482 return s.connect(url)
482
483
483 self.session.send(self._query_socket, 'connection_request')
484 self.session.send(self._query_socket, 'connection_request')
484 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
485 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
485 poller = zmq.Poller()
486 poller = zmq.Poller()
486 poller.register(self._query_socket, zmq.POLLIN)
487 poller.register(self._query_socket, zmq.POLLIN)
487 # poll expects milliseconds, timeout is seconds
488 # poll expects milliseconds, timeout is seconds
488 evts = poller.poll(timeout*1000)
489 evts = poller.poll(timeout*1000)
489 if not evts:
490 if not evts:
490 raise error.TimeoutError("Hub connection request timed out")
491 raise error.TimeoutError("Hub connection request timed out")
491 idents,msg = self.session.recv(self._query_socket,mode=0)
492 idents,msg = self.session.recv(self._query_socket,mode=0)
492 if self.debug:
493 if self.debug:
493 pprint(msg)
494 pprint(msg)
494 msg = Message(msg)
495 msg = Message(msg)
495 content = msg.content
496 content = msg.content
496 self._config['registration'] = dict(content)
497 self._config['registration'] = dict(content)
497 if content.status == 'ok':
498 if content.status == 'ok':
498 ident = self.session.bsession
499 ident = self.session.bsession
499 if content.mux:
500 if content.mux:
500 self._mux_socket = self._context.socket(zmq.DEALER)
501 self._mux_socket = self._context.socket(zmq.DEALER)
501 self._mux_socket.setsockopt(zmq.IDENTITY, ident)
502 self._mux_socket.setsockopt(zmq.IDENTITY, ident)
502 connect_socket(self._mux_socket, content.mux)
503 connect_socket(self._mux_socket, content.mux)
503 if content.task:
504 if content.task:
504 self._task_scheme, task_addr = content.task
505 self._task_scheme, task_addr = content.task
505 self._task_socket = self._context.socket(zmq.DEALER)
506 self._task_socket = self._context.socket(zmq.DEALER)
506 self._task_socket.setsockopt(zmq.IDENTITY, ident)
507 self._task_socket.setsockopt(zmq.IDENTITY, ident)
507 connect_socket(self._task_socket, task_addr)
508 connect_socket(self._task_socket, task_addr)
508 if content.notification:
509 if content.notification:
509 self._notification_socket = self._context.socket(zmq.SUB)
510 self._notification_socket = self._context.socket(zmq.SUB)
510 connect_socket(self._notification_socket, content.notification)
511 connect_socket(self._notification_socket, content.notification)
511 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
512 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
512 # if content.query:
513 # if content.query:
513 # self._query_socket = self._context.socket(zmq.DEALER)
514 # self._query_socket = self._context.socket(zmq.DEALER)
514 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
515 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
515 # connect_socket(self._query_socket, content.query)
516 # connect_socket(self._query_socket, content.query)
516 if content.control:
517 if content.control:
517 self._control_socket = self._context.socket(zmq.DEALER)
518 self._control_socket = self._context.socket(zmq.DEALER)
518 self._control_socket.setsockopt(zmq.IDENTITY, ident)
519 self._control_socket.setsockopt(zmq.IDENTITY, ident)
519 connect_socket(self._control_socket, content.control)
520 connect_socket(self._control_socket, content.control)
520 if content.iopub:
521 if content.iopub:
521 self._iopub_socket = self._context.socket(zmq.SUB)
522 self._iopub_socket = self._context.socket(zmq.SUB)
522 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
523 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
523 self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
524 self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
524 connect_socket(self._iopub_socket, content.iopub)
525 connect_socket(self._iopub_socket, content.iopub)
525 self._update_engines(dict(content.engines))
526 self._update_engines(dict(content.engines))
526 else:
527 else:
527 self._connected = False
528 self._connected = False
528 raise Exception("Failed to connect!")
529 raise Exception("Failed to connect!")
529
530
530 #--------------------------------------------------------------------------
531 #--------------------------------------------------------------------------
531 # handlers and callbacks for incoming messages
532 # handlers and callbacks for incoming messages
532 #--------------------------------------------------------------------------
533 #--------------------------------------------------------------------------
533
534
534 def _unwrap_exception(self, content):
535 def _unwrap_exception(self, content):
535 """unwrap exception, and remap engine_id to int."""
536 """unwrap exception, and remap engine_id to int."""
536 e = error.unwrap_exception(content)
537 e = error.unwrap_exception(content)
537 # print e.traceback
538 # print e.traceback
538 if e.engine_info:
539 if e.engine_info:
539 e_uuid = e.engine_info['engine_uuid']
540 e_uuid = e.engine_info['engine_uuid']
540 eid = self._engines[e_uuid]
541 eid = self._engines[e_uuid]
541 e.engine_info['engine_id'] = eid
542 e.engine_info['engine_id'] = eid
542 return e
543 return e
543
544
544 def _extract_metadata(self, header, parent, content):
545 def _extract_metadata(self, header, parent, content):
545 md = {'msg_id' : parent['msg_id'],
546 md = {'msg_id' : parent['msg_id'],
546 'received' : datetime.now(),
547 'received' : datetime.now(),
547 'engine_uuid' : header.get('engine', None),
548 'engine_uuid' : header.get('engine', None),
548 'follow' : parent.get('follow', []),
549 'follow' : parent.get('follow', []),
549 'after' : parent.get('after', []),
550 'after' : parent.get('after', []),
550 'status' : content['status'],
551 'status' : content['status'],
551 }
552 }
552
553
553 if md['engine_uuid'] is not None:
554 if md['engine_uuid'] is not None:
554 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
555 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
555
556
556 if 'date' in parent:
557 if 'date' in parent:
557 md['submitted'] = parent['date']
558 md['submitted'] = parent['date']
558 if 'started' in header:
559 if 'started' in header:
559 md['started'] = header['started']
560 md['started'] = header['started']
560 if 'date' in header:
561 if 'date' in header:
561 md['completed'] = header['date']
562 md['completed'] = header['date']
562 return md
563 return md
563
564
564 def _register_engine(self, msg):
565 def _register_engine(self, msg):
565 """Register a new engine, and update our connection info."""
566 """Register a new engine, and update our connection info."""
566 content = msg['content']
567 content = msg['content']
567 eid = content['id']
568 eid = content['id']
568 d = {eid : content['queue']}
569 d = {eid : content['queue']}
569 self._update_engines(d)
570 self._update_engines(d)
570
571
571 def _unregister_engine(self, msg):
572 def _unregister_engine(self, msg):
572 """Unregister an engine that has died."""
573 """Unregister an engine that has died."""
573 content = msg['content']
574 content = msg['content']
574 eid = int(content['id'])
575 eid = int(content['id'])
575 if eid in self._ids:
576 if eid in self._ids:
576 self._ids.remove(eid)
577 self._ids.remove(eid)
577 uuid = self._engines.pop(eid)
578 uuid = self._engines.pop(eid)
578
579
579 self._handle_stranded_msgs(eid, uuid)
580 self._handle_stranded_msgs(eid, uuid)
580
581
581 if self._task_socket and self._task_scheme == 'pure':
582 if self._task_socket and self._task_scheme == 'pure':
582 self._stop_scheduling_tasks()
583 self._stop_scheduling_tasks()
583
584
584 def _handle_stranded_msgs(self, eid, uuid):
585 def _handle_stranded_msgs(self, eid, uuid):
585 """Handle messages known to be on an engine when the engine unregisters.
586 """Handle messages known to be on an engine when the engine unregisters.
586
587
587 It is possible that this will fire prematurely - that is, an engine will
588 It is possible that this will fire prematurely - that is, an engine will
588 go down after completing a result, and the client will be notified
589 go down after completing a result, and the client will be notified
589 of the unregistration and later receive the successful result.
590 of the unregistration and later receive the successful result.
590 """
591 """
591
592
592 outstanding = self._outstanding_dict[uuid]
593 outstanding = self._outstanding_dict[uuid]
593
594
594 for msg_id in list(outstanding):
595 for msg_id in list(outstanding):
595 if msg_id in self.results:
596 if msg_id in self.results:
596 # we already
597 # we already
597 continue
598 continue
598 try:
599 try:
599 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
600 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
600 except:
601 except:
601 content = error.wrap_exception()
602 content = error.wrap_exception()
602 # build a fake message:
603 # build a fake message:
603 parent = {}
604 parent = {}
604 header = {}
605 header = {}
605 parent['msg_id'] = msg_id
606 parent['msg_id'] = msg_id
606 header['engine'] = uuid
607 header['engine'] = uuid
607 header['date'] = datetime.now()
608 header['date'] = datetime.now()
608 msg = dict(parent_header=parent, header=header, content=content)
609 msg = dict(parent_header=parent, header=header, content=content)
609 self._handle_apply_reply(msg)
610 self._handle_apply_reply(msg)
610
611
611 def _handle_execute_reply(self, msg):
612 def _handle_execute_reply(self, msg):
612 """Save the reply to an execute_request into our results.
613 """Save the reply to an execute_request into our results.
613
614
614 execute messages are never actually used. apply is used instead.
615 execute messages are never actually used. apply is used instead.
615 """
616 """
616
617
617 parent = msg['parent_header']
618 parent = msg['parent_header']
618 msg_id = parent['msg_id']
619 msg_id = parent['msg_id']
619 if msg_id not in self.outstanding:
620 if msg_id not in self.outstanding:
620 if msg_id in self.history:
621 if msg_id in self.history:
621 print ("got stale result: %s"%msg_id)
622 print ("got stale result: %s"%msg_id)
622 else:
623 else:
623 print ("got unknown result: %s"%msg_id)
624 print ("got unknown result: %s"%msg_id)
624 else:
625 else:
625 self.outstanding.remove(msg_id)
626 self.outstanding.remove(msg_id)
626 self.results[msg_id] = self._unwrap_exception(msg['content'])
627 self.results[msg_id] = self._unwrap_exception(msg['content'])
627
628
628 def _handle_apply_reply(self, msg):
629 def _handle_apply_reply(self, msg):
629 """Save the reply to an apply_request into our results."""
630 """Save the reply to an apply_request into our results."""
630 parent = msg['parent_header']
631 parent = msg['parent_header']
631 msg_id = parent['msg_id']
632 msg_id = parent['msg_id']
632 if msg_id not in self.outstanding:
633 if msg_id not in self.outstanding:
633 if msg_id in self.history:
634 if msg_id in self.history:
634 print ("got stale result: %s"%msg_id)
635 print ("got stale result: %s"%msg_id)
635 print self.results[msg_id]
636 print self.results[msg_id]
636 print msg
637 print msg
637 else:
638 else:
638 print ("got unknown result: %s"%msg_id)
639 print ("got unknown result: %s"%msg_id)
639 else:
640 else:
640 self.outstanding.remove(msg_id)
641 self.outstanding.remove(msg_id)
641 content = msg['content']
642 content = msg['content']
642 header = msg['header']
643 header = msg['header']
643
644
644 # construct metadata:
645 # construct metadata:
645 md = self.metadata[msg_id]
646 md = self.metadata[msg_id]
646 md.update(self._extract_metadata(header, parent, content))
647 md.update(self._extract_metadata(header, parent, content))
647 # is this redundant?
648 # is this redundant?
648 self.metadata[msg_id] = md
649 self.metadata[msg_id] = md
649
650
650 e_outstanding = self._outstanding_dict[md['engine_uuid']]
651 e_outstanding = self._outstanding_dict[md['engine_uuid']]
651 if msg_id in e_outstanding:
652 if msg_id in e_outstanding:
652 e_outstanding.remove(msg_id)
653 e_outstanding.remove(msg_id)
653
654
654 # construct result:
655 # construct result:
655 if content['status'] == 'ok':
656 if content['status'] == 'ok':
656 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
657 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
657 elif content['status'] == 'aborted':
658 elif content['status'] == 'aborted':
658 self.results[msg_id] = error.TaskAborted(msg_id)
659 self.results[msg_id] = error.TaskAborted(msg_id)
659 elif content['status'] == 'resubmitted':
660 elif content['status'] == 'resubmitted':
660 # TODO: handle resubmission
661 # TODO: handle resubmission
661 pass
662 pass
662 else:
663 else:
663 self.results[msg_id] = self._unwrap_exception(content)
664 self.results[msg_id] = self._unwrap_exception(content)
664
665
665 def _flush_notifications(self):
666 def _flush_notifications(self):
666 """Flush notifications of engine registrations waiting
667 """Flush notifications of engine registrations waiting
667 in ZMQ queue."""
668 in ZMQ queue."""
668 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
669 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
669 while msg is not None:
670 while msg is not None:
670 if self.debug:
671 if self.debug:
671 pprint(msg)
672 pprint(msg)
672 msg_type = msg['header']['msg_type']
673 msg_type = msg['header']['msg_type']
673 handler = self._notification_handlers.get(msg_type, None)
674 handler = self._notification_handlers.get(msg_type, None)
674 if handler is None:
675 if handler is None:
675 raise Exception("Unhandled message type: %s"%msg.msg_type)
676 raise Exception("Unhandled message type: %s"%msg.msg_type)
676 else:
677 else:
677 handler(msg)
678 handler(msg)
678 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
679 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
679
680
680 def _flush_results(self, sock):
681 def _flush_results(self, sock):
681 """Flush task or queue results waiting in ZMQ queue."""
682 """Flush task or queue results waiting in ZMQ queue."""
682 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
683 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
683 while msg is not None:
684 while msg is not None:
684 if self.debug:
685 if self.debug:
685 pprint(msg)
686 pprint(msg)
686 msg_type = msg['header']['msg_type']
687 msg_type = msg['header']['msg_type']
687 handler = self._queue_handlers.get(msg_type, None)
688 handler = self._queue_handlers.get(msg_type, None)
688 if handler is None:
689 if handler is None:
689 raise Exception("Unhandled message type: %s"%msg.msg_type)
690 raise Exception("Unhandled message type: %s"%msg.msg_type)
690 else:
691 else:
691 handler(msg)
692 handler(msg)
692 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
693 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
693
694
694 def _flush_control(self, sock):
695 def _flush_control(self, sock):
695 """Flush replies from the control channel waiting
696 """Flush replies from the control channel waiting
696 in the ZMQ queue.
697 in the ZMQ queue.
697
698
698 Currently: ignore them."""
699 Currently: ignore them."""
699 if self._ignored_control_replies <= 0:
700 if self._ignored_control_replies <= 0:
700 return
701 return
701 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
702 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
702 while msg is not None:
703 while msg is not None:
703 self._ignored_control_replies -= 1
704 self._ignored_control_replies -= 1
704 if self.debug:
705 if self.debug:
705 pprint(msg)
706 pprint(msg)
706 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
707 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
707
708
708 def _flush_ignored_control(self):
709 def _flush_ignored_control(self):
709 """flush ignored control replies"""
710 """flush ignored control replies"""
710 while self._ignored_control_replies > 0:
711 while self._ignored_control_replies > 0:
711 self.session.recv(self._control_socket)
712 self.session.recv(self._control_socket)
712 self._ignored_control_replies -= 1
713 self._ignored_control_replies -= 1
713
714
714 def _flush_ignored_hub_replies(self):
715 def _flush_ignored_hub_replies(self):
715 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
716 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
716 while msg is not None:
717 while msg is not None:
717 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
718 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
718
719
719 def _flush_iopub(self, sock):
720 def _flush_iopub(self, sock):
720 """Flush replies from the iopub channel waiting
721 """Flush replies from the iopub channel waiting
721 in the ZMQ queue.
722 in the ZMQ queue.
722 """
723 """
723 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
724 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
724 while msg is not None:
725 while msg is not None:
725 if self.debug:
726 if self.debug:
726 pprint(msg)
727 pprint(msg)
727 parent = msg['parent_header']
728 parent = msg['parent_header']
728 msg_id = parent['msg_id']
729 msg_id = parent['msg_id']
729 content = msg['content']
730 content = msg['content']
730 header = msg['header']
731 header = msg['header']
731 msg_type = msg['header']['msg_type']
732 msg_type = msg['header']['msg_type']
732
733
733 # init metadata:
734 # init metadata:
734 md = self.metadata[msg_id]
735 md = self.metadata[msg_id]
735
736
736 if msg_type == 'stream':
737 if msg_type == 'stream':
737 name = content['name']
738 name = content['name']
738 s = md[name] or ''
739 s = md[name] or ''
739 md[name] = s + content['data']
740 md[name] = s + content['data']
740 elif msg_type == 'pyerr':
741 elif msg_type == 'pyerr':
741 md.update({'pyerr' : self._unwrap_exception(content)})
742 md.update({'pyerr' : self._unwrap_exception(content)})
742 elif msg_type == 'pyin':
743 elif msg_type == 'pyin':
743 md.update({'pyin' : content['code']})
744 md.update({'pyin' : content['code']})
744 else:
745 else:
745 md.update({msg_type : content.get('data', '')})
746 md.update({msg_type : content.get('data', '')})
746
747
747 # reduntant?
748 # reduntant?
748 self.metadata[msg_id] = md
749 self.metadata[msg_id] = md
749
750
750 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
751 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
751
752
752 #--------------------------------------------------------------------------
753 #--------------------------------------------------------------------------
753 # len, getitem
754 # len, getitem
754 #--------------------------------------------------------------------------
755 #--------------------------------------------------------------------------
755
756
756 def __len__(self):
757 def __len__(self):
757 """len(client) returns # of engines."""
758 """len(client) returns # of engines."""
758 return len(self.ids)
759 return len(self.ids)
759
760
760 def __getitem__(self, key):
761 def __getitem__(self, key):
761 """index access returns DirectView multiplexer objects
762 """index access returns DirectView multiplexer objects
762
763
763 Must be int, slice, or list/tuple/xrange of ints"""
764 Must be int, slice, or list/tuple/xrange of ints"""
764 if not isinstance(key, (int, slice, tuple, list, xrange)):
765 if not isinstance(key, (int, slice, tuple, list, xrange)):
765 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
766 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
766 else:
767 else:
767 return self.direct_view(key)
768 return self.direct_view(key)
768
769
769 #--------------------------------------------------------------------------
770 #--------------------------------------------------------------------------
770 # Begin public methods
771 # Begin public methods
771 #--------------------------------------------------------------------------
772 #--------------------------------------------------------------------------
772
773
773 @property
774 @property
774 def ids(self):
775 def ids(self):
775 """Always up-to-date ids property."""
776 """Always up-to-date ids property."""
776 self._flush_notifications()
777 self._flush_notifications()
777 # always copy:
778 # always copy:
778 return list(self._ids)
779 return list(self._ids)
779
780
780 def close(self):
781 def close(self):
781 if self._closed:
782 if self._closed:
782 return
783 return
783 snames = filter(lambda n: n.endswith('socket'), dir(self))
784 snames = filter(lambda n: n.endswith('socket'), dir(self))
784 for socket in map(lambda name: getattr(self, name), snames):
785 for socket in map(lambda name: getattr(self, name), snames):
785 if isinstance(socket, zmq.Socket) and not socket.closed:
786 if isinstance(socket, zmq.Socket) and not socket.closed:
786 socket.close()
787 socket.close()
787 self._closed = True
788 self._closed = True
788
789
789 def spin(self):
790 def spin(self):
790 """Flush any registration notifications and execution results
791 """Flush any registration notifications and execution results
791 waiting in the ZMQ queue.
792 waiting in the ZMQ queue.
792 """
793 """
793 if self._notification_socket:
794 if self._notification_socket:
794 self._flush_notifications()
795 self._flush_notifications()
795 if self._mux_socket:
796 if self._mux_socket:
796 self._flush_results(self._mux_socket)
797 self._flush_results(self._mux_socket)
797 if self._task_socket:
798 if self._task_socket:
798 self._flush_results(self._task_socket)
799 self._flush_results(self._task_socket)
799 if self._control_socket:
800 if self._control_socket:
800 self._flush_control(self._control_socket)
801 self._flush_control(self._control_socket)
801 if self._iopub_socket:
802 if self._iopub_socket:
802 self._flush_iopub(self._iopub_socket)
803 self._flush_iopub(self._iopub_socket)
803 if self._query_socket:
804 if self._query_socket:
804 self._flush_ignored_hub_replies()
805 self._flush_ignored_hub_replies()
805
806
806 def wait(self, jobs=None, timeout=-1):
807 def wait(self, jobs=None, timeout=-1):
807 """waits on one or more `jobs`, for up to `timeout` seconds.
808 """waits on one or more `jobs`, for up to `timeout` seconds.
808
809
809 Parameters
810 Parameters
810 ----------
811 ----------
811
812
812 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
813 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
813 ints are indices to self.history
814 ints are indices to self.history
814 strs are msg_ids
815 strs are msg_ids
815 default: wait on all outstanding messages
816 default: wait on all outstanding messages
816 timeout : float
817 timeout : float
817 a time in seconds, after which to give up.
818 a time in seconds, after which to give up.
818 default is -1, which means no timeout
819 default is -1, which means no timeout
819
820
820 Returns
821 Returns
821 -------
822 -------
822
823
823 True : when all msg_ids are done
824 True : when all msg_ids are done
824 False : timeout reached, some msg_ids still outstanding
825 False : timeout reached, some msg_ids still outstanding
825 """
826 """
826 tic = time.time()
827 tic = time.time()
827 if jobs is None:
828 if jobs is None:
828 theids = self.outstanding
829 theids = self.outstanding
829 else:
830 else:
830 if isinstance(jobs, (int, basestring, AsyncResult)):
831 if isinstance(jobs, (int, basestring, AsyncResult)):
831 jobs = [jobs]
832 jobs = [jobs]
832 theids = set()
833 theids = set()
833 for job in jobs:
834 for job in jobs:
834 if isinstance(job, int):
835 if isinstance(job, int):
835 # index access
836 # index access
836 job = self.history[job]
837 job = self.history[job]
837 elif isinstance(job, AsyncResult):
838 elif isinstance(job, AsyncResult):
838 map(theids.add, job.msg_ids)
839 map(theids.add, job.msg_ids)
839 continue
840 continue
840 theids.add(job)
841 theids.add(job)
841 if not theids.intersection(self.outstanding):
842 if not theids.intersection(self.outstanding):
842 return True
843 return True
843 self.spin()
844 self.spin()
844 while theids.intersection(self.outstanding):
845 while theids.intersection(self.outstanding):
845 if timeout >= 0 and ( time.time()-tic ) > timeout:
846 if timeout >= 0 and ( time.time()-tic ) > timeout:
846 break
847 break
847 time.sleep(1e-3)
848 time.sleep(1e-3)
848 self.spin()
849 self.spin()
849 return len(theids.intersection(self.outstanding)) == 0
850 return len(theids.intersection(self.outstanding)) == 0
850
851
851 #--------------------------------------------------------------------------
852 #--------------------------------------------------------------------------
852 # Control methods
853 # Control methods
853 #--------------------------------------------------------------------------
854 #--------------------------------------------------------------------------
854
855
855 @spin_first
856 @spin_first
856 def clear(self, targets=None, block=None):
857 def clear(self, targets=None, block=None):
857 """Clear the namespace in target(s)."""
858 """Clear the namespace in target(s)."""
858 block = self.block if block is None else block
859 block = self.block if block is None else block
859 targets = self._build_targets(targets)[0]
860 targets = self._build_targets(targets)[0]
860 for t in targets:
861 for t in targets:
861 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
862 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
862 error = False
863 error = False
863 if block:
864 if block:
864 self._flush_ignored_control()
865 self._flush_ignored_control()
865 for i in range(len(targets)):
866 for i in range(len(targets)):
866 idents,msg = self.session.recv(self._control_socket,0)
867 idents,msg = self.session.recv(self._control_socket,0)
867 if self.debug:
868 if self.debug:
868 pprint(msg)
869 pprint(msg)
869 if msg['content']['status'] != 'ok':
870 if msg['content']['status'] != 'ok':
870 error = self._unwrap_exception(msg['content'])
871 error = self._unwrap_exception(msg['content'])
871 else:
872 else:
872 self._ignored_control_replies += len(targets)
873 self._ignored_control_replies += len(targets)
873 if error:
874 if error:
874 raise error
875 raise error
875
876
876
877
877 @spin_first
878 @spin_first
878 def abort(self, jobs=None, targets=None, block=None):
879 def abort(self, jobs=None, targets=None, block=None):
879 """Abort specific jobs from the execution queues of target(s).
880 """Abort specific jobs from the execution queues of target(s).
880
881
881 This is a mechanism to prevent jobs that have already been submitted
882 This is a mechanism to prevent jobs that have already been submitted
882 from executing.
883 from executing.
883
884
884 Parameters
885 Parameters
885 ----------
886 ----------
886
887
887 jobs : msg_id, list of msg_ids, or AsyncResult
888 jobs : msg_id, list of msg_ids, or AsyncResult
888 The jobs to be aborted
889 The jobs to be aborted
889
890
890 If unspecified/None: abort all outstanding jobs.
891 If unspecified/None: abort all outstanding jobs.
891
892
892 """
893 """
893 block = self.block if block is None else block
894 block = self.block if block is None else block
894 jobs = jobs if jobs is not None else list(self.outstanding)
895 jobs = jobs if jobs is not None else list(self.outstanding)
895 targets = self._build_targets(targets)[0]
896 targets = self._build_targets(targets)[0]
896
897
897 msg_ids = []
898 msg_ids = []
898 if isinstance(jobs, (basestring,AsyncResult)):
899 if isinstance(jobs, (basestring,AsyncResult)):
899 jobs = [jobs]
900 jobs = [jobs]
900 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
901 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
901 if bad_ids:
902 if bad_ids:
902 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
903 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
903 for j in jobs:
904 for j in jobs:
904 if isinstance(j, AsyncResult):
905 if isinstance(j, AsyncResult):
905 msg_ids.extend(j.msg_ids)
906 msg_ids.extend(j.msg_ids)
906 else:
907 else:
907 msg_ids.append(j)
908 msg_ids.append(j)
908 content = dict(msg_ids=msg_ids)
909 content = dict(msg_ids=msg_ids)
909 for t in targets:
910 for t in targets:
910 self.session.send(self._control_socket, 'abort_request',
911 self.session.send(self._control_socket, 'abort_request',
911 content=content, ident=t)
912 content=content, ident=t)
912 error = False
913 error = False
913 if block:
914 if block:
914 self._flush_ignored_control()
915 self._flush_ignored_control()
915 for i in range(len(targets)):
916 for i in range(len(targets)):
916 idents,msg = self.session.recv(self._control_socket,0)
917 idents,msg = self.session.recv(self._control_socket,0)
917 if self.debug:
918 if self.debug:
918 pprint(msg)
919 pprint(msg)
919 if msg['content']['status'] != 'ok':
920 if msg['content']['status'] != 'ok':
920 error = self._unwrap_exception(msg['content'])
921 error = self._unwrap_exception(msg['content'])
921 else:
922 else:
922 self._ignored_control_replies += len(targets)
923 self._ignored_control_replies += len(targets)
923 if error:
924 if error:
924 raise error
925 raise error
925
926
926 @spin_first
927 @spin_first
927 def shutdown(self, targets=None, restart=False, hub=False, block=None):
928 def shutdown(self, targets=None, restart=False, hub=False, block=None):
928 """Terminates one or more engine processes, optionally including the hub."""
929 """Terminates one or more engine processes, optionally including the hub."""
929 block = self.block if block is None else block
930 block = self.block if block is None else block
930 if hub:
931 if hub:
931 targets = 'all'
932 targets = 'all'
932 targets = self._build_targets(targets)[0]
933 targets = self._build_targets(targets)[0]
933 for t in targets:
934 for t in targets:
934 self.session.send(self._control_socket, 'shutdown_request',
935 self.session.send(self._control_socket, 'shutdown_request',
935 content={'restart':restart},ident=t)
936 content={'restart':restart},ident=t)
936 error = False
937 error = False
937 if block or hub:
938 if block or hub:
938 self._flush_ignored_control()
939 self._flush_ignored_control()
939 for i in range(len(targets)):
940 for i in range(len(targets)):
940 idents,msg = self.session.recv(self._control_socket, 0)
941 idents,msg = self.session.recv(self._control_socket, 0)
941 if self.debug:
942 if self.debug:
942 pprint(msg)
943 pprint(msg)
943 if msg['content']['status'] != 'ok':
944 if msg['content']['status'] != 'ok':
944 error = self._unwrap_exception(msg['content'])
945 error = self._unwrap_exception(msg['content'])
945 else:
946 else:
946 self._ignored_control_replies += len(targets)
947 self._ignored_control_replies += len(targets)
947
948
948 if hub:
949 if hub:
949 time.sleep(0.25)
950 time.sleep(0.25)
950 self.session.send(self._query_socket, 'shutdown_request')
951 self.session.send(self._query_socket, 'shutdown_request')
951 idents,msg = self.session.recv(self._query_socket, 0)
952 idents,msg = self.session.recv(self._query_socket, 0)
952 if self.debug:
953 if self.debug:
953 pprint(msg)
954 pprint(msg)
954 if msg['content']['status'] != 'ok':
955 if msg['content']['status'] != 'ok':
955 error = self._unwrap_exception(msg['content'])
956 error = self._unwrap_exception(msg['content'])
956
957
957 if error:
958 if error:
958 raise error
959 raise error
959
960
960 #--------------------------------------------------------------------------
961 #--------------------------------------------------------------------------
961 # Execution related methods
962 # Execution related methods
962 #--------------------------------------------------------------------------
963 #--------------------------------------------------------------------------
963
964
964 def _maybe_raise(self, result):
965 def _maybe_raise(self, result):
965 """wrapper for maybe raising an exception if apply failed."""
966 """wrapper for maybe raising an exception if apply failed."""
966 if isinstance(result, error.RemoteError):
967 if isinstance(result, error.RemoteError):
967 raise result
968 raise result
968
969
969 return result
970 return result
970
971
971 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
972 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
972 ident=None):
973 ident=None):
973 """construct and send an apply message via a socket.
974 """construct and send an apply message via a socket.
974
975
975 This is the principal method with which all engine execution is performed by views.
976 This is the principal method with which all engine execution is performed by views.
976 """
977 """
977
978
978 assert not self._closed, "cannot use me anymore, I'm closed!"
979 assert not self._closed, "cannot use me anymore, I'm closed!"
979 # defaults:
980 # defaults:
980 args = args if args is not None else []
981 args = args if args is not None else []
981 kwargs = kwargs if kwargs is not None else {}
982 kwargs = kwargs if kwargs is not None else {}
982 subheader = subheader if subheader is not None else {}
983 subheader = subheader if subheader is not None else {}
983
984
984 # validate arguments
985 # validate arguments
985 if not callable(f):
986 if not callable(f) and not isinstance(f, Reference):
986 raise TypeError("f must be callable, not %s"%type(f))
987 raise TypeError("f must be callable, not %s"%type(f))
987 if not isinstance(args, (tuple, list)):
988 if not isinstance(args, (tuple, list)):
988 raise TypeError("args must be tuple or list, not %s"%type(args))
989 raise TypeError("args must be tuple or list, not %s"%type(args))
989 if not isinstance(kwargs, dict):
990 if not isinstance(kwargs, dict):
990 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
991 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
991 if not isinstance(subheader, dict):
992 if not isinstance(subheader, dict):
992 raise TypeError("subheader must be dict, not %s"%type(subheader))
993 raise TypeError("subheader must be dict, not %s"%type(subheader))
993
994
994 bufs = util.pack_apply_message(f,args,kwargs)
995 bufs = util.pack_apply_message(f,args,kwargs)
995
996
996 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
997 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
997 subheader=subheader, track=track)
998 subheader=subheader, track=track)
998
999
999 msg_id = msg['header']['msg_id']
1000 msg_id = msg['header']['msg_id']
1000 self.outstanding.add(msg_id)
1001 self.outstanding.add(msg_id)
1001 if ident:
1002 if ident:
1002 # possibly routed to a specific engine
1003 # possibly routed to a specific engine
1003 if isinstance(ident, list):
1004 if isinstance(ident, list):
1004 ident = ident[-1]
1005 ident = ident[-1]
1005 if ident in self._engines.values():
1006 if ident in self._engines.values():
1006 # save for later, in case of engine death
1007 # save for later, in case of engine death
1007 self._outstanding_dict[ident].add(msg_id)
1008 self._outstanding_dict[ident].add(msg_id)
1008 self.history.append(msg_id)
1009 self.history.append(msg_id)
1009 self.metadata[msg_id]['submitted'] = datetime.now()
1010 self.metadata[msg_id]['submitted'] = datetime.now()
1010
1011
1011 return msg
1012 return msg
1012
1013
1013 #--------------------------------------------------------------------------
1014 #--------------------------------------------------------------------------
1014 # construct a View object
1015 # construct a View object
1015 #--------------------------------------------------------------------------
1016 #--------------------------------------------------------------------------
1016
1017
1017 def load_balanced_view(self, targets=None):
1018 def load_balanced_view(self, targets=None):
1018 """construct a DirectView object.
1019 """construct a DirectView object.
1019
1020
1020 If no arguments are specified, create a LoadBalancedView
1021 If no arguments are specified, create a LoadBalancedView
1021 using all engines.
1022 using all engines.
1022
1023
1023 Parameters
1024 Parameters
1024 ----------
1025 ----------
1025
1026
1026 targets: list,slice,int,etc. [default: use all engines]
1027 targets: list,slice,int,etc. [default: use all engines]
1027 The subset of engines across which to load-balance
1028 The subset of engines across which to load-balance
1028 """
1029 """
1029 if targets == 'all':
1030 if targets == 'all':
1030 targets = None
1031 targets = None
1031 if targets is not None:
1032 if targets is not None:
1032 targets = self._build_targets(targets)[1]
1033 targets = self._build_targets(targets)[1]
1033 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1034 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1034
1035
1035 def direct_view(self, targets='all'):
1036 def direct_view(self, targets='all'):
1036 """construct a DirectView object.
1037 """construct a DirectView object.
1037
1038
1038 If no targets are specified, create a DirectView using all engines.
1039 If no targets are specified, create a DirectView using all engines.
1039
1040
1040 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1041 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1041 evaluate the target engines at each execution, whereas rc[:] will connect to
1042 evaluate the target engines at each execution, whereas rc[:] will connect to
1042 all *current* engines, and that list will not change.
1043 all *current* engines, and that list will not change.
1043
1044
1044 That is, 'all' will always use all engines, whereas rc[:] will not use
1045 That is, 'all' will always use all engines, whereas rc[:] will not use
1045 engines added after the DirectView is constructed.
1046 engines added after the DirectView is constructed.
1046
1047
1047 Parameters
1048 Parameters
1048 ----------
1049 ----------
1049
1050
1050 targets: list,slice,int,etc. [default: use all engines]
1051 targets: list,slice,int,etc. [default: use all engines]
1051 The engines to use for the View
1052 The engines to use for the View
1052 """
1053 """
1053 single = isinstance(targets, int)
1054 single = isinstance(targets, int)
1054 # allow 'all' to be lazily evaluated at each execution
1055 # allow 'all' to be lazily evaluated at each execution
1055 if targets != 'all':
1056 if targets != 'all':
1056 targets = self._build_targets(targets)[1]
1057 targets = self._build_targets(targets)[1]
1057 if single:
1058 if single:
1058 targets = targets[0]
1059 targets = targets[0]
1059 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1060 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1060
1061
1061 #--------------------------------------------------------------------------
1062 #--------------------------------------------------------------------------
1062 # Query methods
1063 # Query methods
1063 #--------------------------------------------------------------------------
1064 #--------------------------------------------------------------------------
1064
1065
1065 @spin_first
1066 @spin_first
1066 def get_result(self, indices_or_msg_ids=None, block=None):
1067 def get_result(self, indices_or_msg_ids=None, block=None):
1067 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1068 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1068
1069
1069 If the client already has the results, no request to the Hub will be made.
1070 If the client already has the results, no request to the Hub will be made.
1070
1071
1071 This is a convenient way to construct AsyncResult objects, which are wrappers
1072 This is a convenient way to construct AsyncResult objects, which are wrappers
1072 that include metadata about execution, and allow for awaiting results that
1073 that include metadata about execution, and allow for awaiting results that
1073 were not submitted by this Client.
1074 were not submitted by this Client.
1074
1075
1075 It can also be a convenient way to retrieve the metadata associated with
1076 It can also be a convenient way to retrieve the metadata associated with
1076 blocking execution, since it always retrieves
1077 blocking execution, since it always retrieves
1077
1078
1078 Examples
1079 Examples
1079 --------
1080 --------
1080 ::
1081 ::
1081
1082
1082 In [10]: r = client.apply()
1083 In [10]: r = client.apply()
1083
1084
1084 Parameters
1085 Parameters
1085 ----------
1086 ----------
1086
1087
1087 indices_or_msg_ids : integer history index, str msg_id, or list of either
1088 indices_or_msg_ids : integer history index, str msg_id, or list of either
1088 The indices or msg_ids of indices to be retrieved
1089 The indices or msg_ids of indices to be retrieved
1089
1090
1090 block : bool
1091 block : bool
1091 Whether to wait for the result to be done
1092 Whether to wait for the result to be done
1092
1093
1093 Returns
1094 Returns
1094 -------
1095 -------
1095
1096
1096 AsyncResult
1097 AsyncResult
1097 A single AsyncResult object will always be returned.
1098 A single AsyncResult object will always be returned.
1098
1099
1099 AsyncHubResult
1100 AsyncHubResult
1100 A subclass of AsyncResult that retrieves results from the Hub
1101 A subclass of AsyncResult that retrieves results from the Hub
1101
1102
1102 """
1103 """
1103 block = self.block if block is None else block
1104 block = self.block if block is None else block
1104 if indices_or_msg_ids is None:
1105 if indices_or_msg_ids is None:
1105 indices_or_msg_ids = -1
1106 indices_or_msg_ids = -1
1106
1107
1107 if not isinstance(indices_or_msg_ids, (list,tuple)):
1108 if not isinstance(indices_or_msg_ids, (list,tuple)):
1108 indices_or_msg_ids = [indices_or_msg_ids]
1109 indices_or_msg_ids = [indices_or_msg_ids]
1109
1110
1110 theids = []
1111 theids = []
1111 for id in indices_or_msg_ids:
1112 for id in indices_or_msg_ids:
1112 if isinstance(id, int):
1113 if isinstance(id, int):
1113 id = self.history[id]
1114 id = self.history[id]
1114 if not isinstance(id, basestring):
1115 if not isinstance(id, basestring):
1115 raise TypeError("indices must be str or int, not %r"%id)
1116 raise TypeError("indices must be str or int, not %r"%id)
1116 theids.append(id)
1117 theids.append(id)
1117
1118
1118 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1119 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1119 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1120 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1120
1121
1121 if remote_ids:
1122 if remote_ids:
1122 ar = AsyncHubResult(self, msg_ids=theids)
1123 ar = AsyncHubResult(self, msg_ids=theids)
1123 else:
1124 else:
1124 ar = AsyncResult(self, msg_ids=theids)
1125 ar = AsyncResult(self, msg_ids=theids)
1125
1126
1126 if block:
1127 if block:
1127 ar.wait()
1128 ar.wait()
1128
1129
1129 return ar
1130 return ar
1130
1131
1131 @spin_first
1132 @spin_first
1132 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1133 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1133 """Resubmit one or more tasks.
1134 """Resubmit one or more tasks.
1134
1135
1135 in-flight tasks may not be resubmitted.
1136 in-flight tasks may not be resubmitted.
1136
1137
1137 Parameters
1138 Parameters
1138 ----------
1139 ----------
1139
1140
1140 indices_or_msg_ids : integer history index, str msg_id, or list of either
1141 indices_or_msg_ids : integer history index, str msg_id, or list of either
1141 The indices or msg_ids of indices to be retrieved
1142 The indices or msg_ids of indices to be retrieved
1142
1143
1143 block : bool
1144 block : bool
1144 Whether to wait for the result to be done
1145 Whether to wait for the result to be done
1145
1146
1146 Returns
1147 Returns
1147 -------
1148 -------
1148
1149
1149 AsyncHubResult
1150 AsyncHubResult
1150 A subclass of AsyncResult that retrieves results from the Hub
1151 A subclass of AsyncResult that retrieves results from the Hub
1151
1152
1152 """
1153 """
1153 block = self.block if block is None else block
1154 block = self.block if block is None else block
1154 if indices_or_msg_ids is None:
1155 if indices_or_msg_ids is None:
1155 indices_or_msg_ids = -1
1156 indices_or_msg_ids = -1
1156
1157
1157 if not isinstance(indices_or_msg_ids, (list,tuple)):
1158 if not isinstance(indices_or_msg_ids, (list,tuple)):
1158 indices_or_msg_ids = [indices_or_msg_ids]
1159 indices_or_msg_ids = [indices_or_msg_ids]
1159
1160
1160 theids = []
1161 theids = []
1161 for id in indices_or_msg_ids:
1162 for id in indices_or_msg_ids:
1162 if isinstance(id, int):
1163 if isinstance(id, int):
1163 id = self.history[id]
1164 id = self.history[id]
1164 if not isinstance(id, basestring):
1165 if not isinstance(id, basestring):
1165 raise TypeError("indices must be str or int, not %r"%id)
1166 raise TypeError("indices must be str or int, not %r"%id)
1166 theids.append(id)
1167 theids.append(id)
1167
1168
1168 for msg_id in theids:
1169 for msg_id in theids:
1169 self.outstanding.discard(msg_id)
1170 self.outstanding.discard(msg_id)
1170 if msg_id in self.history:
1171 if msg_id in self.history:
1171 self.history.remove(msg_id)
1172 self.history.remove(msg_id)
1172 self.results.pop(msg_id, None)
1173 self.results.pop(msg_id, None)
1173 self.metadata.pop(msg_id, None)
1174 self.metadata.pop(msg_id, None)
1174 content = dict(msg_ids = theids)
1175 content = dict(msg_ids = theids)
1175
1176
1176 self.session.send(self._query_socket, 'resubmit_request', content)
1177 self.session.send(self._query_socket, 'resubmit_request', content)
1177
1178
1178 zmq.select([self._query_socket], [], [])
1179 zmq.select([self._query_socket], [], [])
1179 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1180 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1180 if self.debug:
1181 if self.debug:
1181 pprint(msg)
1182 pprint(msg)
1182 content = msg['content']
1183 content = msg['content']
1183 if content['status'] != 'ok':
1184 if content['status'] != 'ok':
1184 raise self._unwrap_exception(content)
1185 raise self._unwrap_exception(content)
1185
1186
1186 ar = AsyncHubResult(self, msg_ids=theids)
1187 ar = AsyncHubResult(self, msg_ids=theids)
1187
1188
1188 if block:
1189 if block:
1189 ar.wait()
1190 ar.wait()
1190
1191
1191 return ar
1192 return ar
1192
1193
1193 @spin_first
1194 @spin_first
1194 def result_status(self, msg_ids, status_only=True):
1195 def result_status(self, msg_ids, status_only=True):
1195 """Check on the status of the result(s) of the apply request with `msg_ids`.
1196 """Check on the status of the result(s) of the apply request with `msg_ids`.
1196
1197
1197 If status_only is False, then the actual results will be retrieved, else
1198 If status_only is False, then the actual results will be retrieved, else
1198 only the status of the results will be checked.
1199 only the status of the results will be checked.
1199
1200
1200 Parameters
1201 Parameters
1201 ----------
1202 ----------
1202
1203
1203 msg_ids : list of msg_ids
1204 msg_ids : list of msg_ids
1204 if int:
1205 if int:
1205 Passed as index to self.history for convenience.
1206 Passed as index to self.history for convenience.
1206 status_only : bool (default: True)
1207 status_only : bool (default: True)
1207 if False:
1208 if False:
1208 Retrieve the actual results of completed tasks.
1209 Retrieve the actual results of completed tasks.
1209
1210
1210 Returns
1211 Returns
1211 -------
1212 -------
1212
1213
1213 results : dict
1214 results : dict
1214 There will always be the keys 'pending' and 'completed', which will
1215 There will always be the keys 'pending' and 'completed', which will
1215 be lists of msg_ids that are incomplete or complete. If `status_only`
1216 be lists of msg_ids that are incomplete or complete. If `status_only`
1216 is False, then completed results will be keyed by their `msg_id`.
1217 is False, then completed results will be keyed by their `msg_id`.
1217 """
1218 """
1218 if not isinstance(msg_ids, (list,tuple)):
1219 if not isinstance(msg_ids, (list,tuple)):
1219 msg_ids = [msg_ids]
1220 msg_ids = [msg_ids]
1220
1221
1221 theids = []
1222 theids = []
1222 for msg_id in msg_ids:
1223 for msg_id in msg_ids:
1223 if isinstance(msg_id, int):
1224 if isinstance(msg_id, int):
1224 msg_id = self.history[msg_id]
1225 msg_id = self.history[msg_id]
1225 if not isinstance(msg_id, basestring):
1226 if not isinstance(msg_id, basestring):
1226 raise TypeError("msg_ids must be str, not %r"%msg_id)
1227 raise TypeError("msg_ids must be str, not %r"%msg_id)
1227 theids.append(msg_id)
1228 theids.append(msg_id)
1228
1229
1229 completed = []
1230 completed = []
1230 local_results = {}
1231 local_results = {}
1231
1232
1232 # comment this block out to temporarily disable local shortcut:
1233 # comment this block out to temporarily disable local shortcut:
1233 for msg_id in theids:
1234 for msg_id in theids:
1234 if msg_id in self.results:
1235 if msg_id in self.results:
1235 completed.append(msg_id)
1236 completed.append(msg_id)
1236 local_results[msg_id] = self.results[msg_id]
1237 local_results[msg_id] = self.results[msg_id]
1237 theids.remove(msg_id)
1238 theids.remove(msg_id)
1238
1239
1239 if theids: # some not locally cached
1240 if theids: # some not locally cached
1240 content = dict(msg_ids=theids, status_only=status_only)
1241 content = dict(msg_ids=theids, status_only=status_only)
1241 msg = self.session.send(self._query_socket, "result_request", content=content)
1242 msg = self.session.send(self._query_socket, "result_request", content=content)
1242 zmq.select([self._query_socket], [], [])
1243 zmq.select([self._query_socket], [], [])
1243 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1244 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1244 if self.debug:
1245 if self.debug:
1245 pprint(msg)
1246 pprint(msg)
1246 content = msg['content']
1247 content = msg['content']
1247 if content['status'] != 'ok':
1248 if content['status'] != 'ok':
1248 raise self._unwrap_exception(content)
1249 raise self._unwrap_exception(content)
1249 buffers = msg['buffers']
1250 buffers = msg['buffers']
1250 else:
1251 else:
1251 content = dict(completed=[],pending=[])
1252 content = dict(completed=[],pending=[])
1252
1253
1253 content['completed'].extend(completed)
1254 content['completed'].extend(completed)
1254
1255
1255 if status_only:
1256 if status_only:
1256 return content
1257 return content
1257
1258
1258 failures = []
1259 failures = []
1259 # load cached results into result:
1260 # load cached results into result:
1260 content.update(local_results)
1261 content.update(local_results)
1261
1262
1262 # update cache with results:
1263 # update cache with results:
1263 for msg_id in sorted(theids):
1264 for msg_id in sorted(theids):
1264 if msg_id in content['completed']:
1265 if msg_id in content['completed']:
1265 rec = content[msg_id]
1266 rec = content[msg_id]
1266 parent = rec['header']
1267 parent = rec['header']
1267 header = rec['result_header']
1268 header = rec['result_header']
1268 rcontent = rec['result_content']
1269 rcontent = rec['result_content']
1269 iodict = rec['io']
1270 iodict = rec['io']
1270 if isinstance(rcontent, str):
1271 if isinstance(rcontent, str):
1271 rcontent = self.session.unpack(rcontent)
1272 rcontent = self.session.unpack(rcontent)
1272
1273
1273 md = self.metadata[msg_id]
1274 md = self.metadata[msg_id]
1274 md.update(self._extract_metadata(header, parent, rcontent))
1275 md.update(self._extract_metadata(header, parent, rcontent))
1275 md.update(iodict)
1276 md.update(iodict)
1276
1277
1277 if rcontent['status'] == 'ok':
1278 if rcontent['status'] == 'ok':
1278 res,buffers = util.unserialize_object(buffers)
1279 res,buffers = util.unserialize_object(buffers)
1279 else:
1280 else:
1280 print rcontent
1281 print rcontent
1281 res = self._unwrap_exception(rcontent)
1282 res = self._unwrap_exception(rcontent)
1282 failures.append(res)
1283 failures.append(res)
1283
1284
1284 self.results[msg_id] = res
1285 self.results[msg_id] = res
1285 content[msg_id] = res
1286 content[msg_id] = res
1286
1287
1287 if len(theids) == 1 and failures:
1288 if len(theids) == 1 and failures:
1288 raise failures[0]
1289 raise failures[0]
1289
1290
1290 error.collect_exceptions(failures, "result_status")
1291 error.collect_exceptions(failures, "result_status")
1291 return content
1292 return content
1292
1293
1293 @spin_first
1294 @spin_first
1294 def queue_status(self, targets='all', verbose=False):
1295 def queue_status(self, targets='all', verbose=False):
1295 """Fetch the status of engine queues.
1296 """Fetch the status of engine queues.
1296
1297
1297 Parameters
1298 Parameters
1298 ----------
1299 ----------
1299
1300
1300 targets : int/str/list of ints/strs
1301 targets : int/str/list of ints/strs
1301 the engines whose states are to be queried.
1302 the engines whose states are to be queried.
1302 default : all
1303 default : all
1303 verbose : bool
1304 verbose : bool
1304 Whether to return lengths only, or lists of ids for each element
1305 Whether to return lengths only, or lists of ids for each element
1305 """
1306 """
1306 engine_ids = self._build_targets(targets)[1]
1307 engine_ids = self._build_targets(targets)[1]
1307 content = dict(targets=engine_ids, verbose=verbose)
1308 content = dict(targets=engine_ids, verbose=verbose)
1308 self.session.send(self._query_socket, "queue_request", content=content)
1309 self.session.send(self._query_socket, "queue_request", content=content)
1309 idents,msg = self.session.recv(self._query_socket, 0)
1310 idents,msg = self.session.recv(self._query_socket, 0)
1310 if self.debug:
1311 if self.debug:
1311 pprint(msg)
1312 pprint(msg)
1312 content = msg['content']
1313 content = msg['content']
1313 status = content.pop('status')
1314 status = content.pop('status')
1314 if status != 'ok':
1315 if status != 'ok':
1315 raise self._unwrap_exception(content)
1316 raise self._unwrap_exception(content)
1316 content = rekey(content)
1317 content = rekey(content)
1317 if isinstance(targets, int):
1318 if isinstance(targets, int):
1318 return content[targets]
1319 return content[targets]
1319 else:
1320 else:
1320 return content
1321 return content
1321
1322
1322 @spin_first
1323 @spin_first
1323 def purge_results(self, jobs=[], targets=[]):
1324 def purge_results(self, jobs=[], targets=[]):
1324 """Tell the Hub to forget results.
1325 """Tell the Hub to forget results.
1325
1326
1326 Individual results can be purged by msg_id, or the entire
1327 Individual results can be purged by msg_id, or the entire
1327 history of specific targets can be purged.
1328 history of specific targets can be purged.
1328
1329
1329 Use `purge_results('all')` to scrub everything from the Hub's db.
1330 Use `purge_results('all')` to scrub everything from the Hub's db.
1330
1331
1331 Parameters
1332 Parameters
1332 ----------
1333 ----------
1333
1334
1334 jobs : str or list of str or AsyncResult objects
1335 jobs : str or list of str or AsyncResult objects
1335 the msg_ids whose results should be forgotten.
1336 the msg_ids whose results should be forgotten.
1336 targets : int/str/list of ints/strs
1337 targets : int/str/list of ints/strs
1337 The targets, by int_id, whose entire history is to be purged.
1338 The targets, by int_id, whose entire history is to be purged.
1338
1339
1339 default : None
1340 default : None
1340 """
1341 """
1341 if not targets and not jobs:
1342 if not targets and not jobs:
1342 raise ValueError("Must specify at least one of `targets` and `jobs`")
1343 raise ValueError("Must specify at least one of `targets` and `jobs`")
1343 if targets:
1344 if targets:
1344 targets = self._build_targets(targets)[1]
1345 targets = self._build_targets(targets)[1]
1345
1346
1346 # construct msg_ids from jobs
1347 # construct msg_ids from jobs
1347 if jobs == 'all':
1348 if jobs == 'all':
1348 msg_ids = jobs
1349 msg_ids = jobs
1349 else:
1350 else:
1350 msg_ids = []
1351 msg_ids = []
1351 if isinstance(jobs, (basestring,AsyncResult)):
1352 if isinstance(jobs, (basestring,AsyncResult)):
1352 jobs = [jobs]
1353 jobs = [jobs]
1353 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1354 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1354 if bad_ids:
1355 if bad_ids:
1355 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1356 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1356 for j in jobs:
1357 for j in jobs:
1357 if isinstance(j, AsyncResult):
1358 if isinstance(j, AsyncResult):
1358 msg_ids.extend(j.msg_ids)
1359 msg_ids.extend(j.msg_ids)
1359 else:
1360 else:
1360 msg_ids.append(j)
1361 msg_ids.append(j)
1361
1362
1362 content = dict(engine_ids=targets, msg_ids=msg_ids)
1363 content = dict(engine_ids=targets, msg_ids=msg_ids)
1363 self.session.send(self._query_socket, "purge_request", content=content)
1364 self.session.send(self._query_socket, "purge_request", content=content)
1364 idents, msg = self.session.recv(self._query_socket, 0)
1365 idents, msg = self.session.recv(self._query_socket, 0)
1365 if self.debug:
1366 if self.debug:
1366 pprint(msg)
1367 pprint(msg)
1367 content = msg['content']
1368 content = msg['content']
1368 if content['status'] != 'ok':
1369 if content['status'] != 'ok':
1369 raise self._unwrap_exception(content)
1370 raise self._unwrap_exception(content)
1370
1371
1371 @spin_first
1372 @spin_first
1372 def hub_history(self):
1373 def hub_history(self):
1373 """Get the Hub's history
1374 """Get the Hub's history
1374
1375
1375 Just like the Client, the Hub has a history, which is a list of msg_ids.
1376 Just like the Client, the Hub has a history, which is a list of msg_ids.
1376 This will contain the history of all clients, and, depending on configuration,
1377 This will contain the history of all clients, and, depending on configuration,
1377 may contain history across multiple cluster sessions.
1378 may contain history across multiple cluster sessions.
1378
1379
1379 Any msg_id returned here is a valid argument to `get_result`.
1380 Any msg_id returned here is a valid argument to `get_result`.
1380
1381
1381 Returns
1382 Returns
1382 -------
1383 -------
1383
1384
1384 msg_ids : list of strs
1385 msg_ids : list of strs
1385 list of all msg_ids, ordered by task submission time.
1386 list of all msg_ids, ordered by task submission time.
1386 """
1387 """
1387
1388
1388 self.session.send(self._query_socket, "history_request", content={})
1389 self.session.send(self._query_socket, "history_request", content={})
1389 idents, msg = self.session.recv(self._query_socket, 0)
1390 idents, msg = self.session.recv(self._query_socket, 0)
1390
1391
1391 if self.debug:
1392 if self.debug:
1392 pprint(msg)
1393 pprint(msg)
1393 content = msg['content']
1394 content = msg['content']
1394 if content['status'] != 'ok':
1395 if content['status'] != 'ok':
1395 raise self._unwrap_exception(content)
1396 raise self._unwrap_exception(content)
1396 else:
1397 else:
1397 return content['history']
1398 return content['history']
1398
1399
1399 @spin_first
1400 @spin_first
1400 def db_query(self, query, keys=None):
1401 def db_query(self, query, keys=None):
1401 """Query the Hub's TaskRecord database
1402 """Query the Hub's TaskRecord database
1402
1403
1403 This will return a list of task record dicts that match `query`
1404 This will return a list of task record dicts that match `query`
1404
1405
1405 Parameters
1406 Parameters
1406 ----------
1407 ----------
1407
1408
1408 query : mongodb query dict
1409 query : mongodb query dict
1409 The search dict. See mongodb query docs for details.
1410 The search dict. See mongodb query docs for details.
1410 keys : list of strs [optional]
1411 keys : list of strs [optional]
1411 The subset of keys to be returned. The default is to fetch everything but buffers.
1412 The subset of keys to be returned. The default is to fetch everything but buffers.
1412 'msg_id' will *always* be included.
1413 'msg_id' will *always* be included.
1413 """
1414 """
1414 if isinstance(keys, basestring):
1415 if isinstance(keys, basestring):
1415 keys = [keys]
1416 keys = [keys]
1416 content = dict(query=query, keys=keys)
1417 content = dict(query=query, keys=keys)
1417 self.session.send(self._query_socket, "db_request", content=content)
1418 self.session.send(self._query_socket, "db_request", content=content)
1418 idents, msg = self.session.recv(self._query_socket, 0)
1419 idents, msg = self.session.recv(self._query_socket, 0)
1419 if self.debug:
1420 if self.debug:
1420 pprint(msg)
1421 pprint(msg)
1421 content = msg['content']
1422 content = msg['content']
1422 if content['status'] != 'ok':
1423 if content['status'] != 'ok':
1423 raise self._unwrap_exception(content)
1424 raise self._unwrap_exception(content)
1424
1425
1425 records = content['records']
1426 records = content['records']
1426
1427
1427 buffer_lens = content['buffer_lens']
1428 buffer_lens = content['buffer_lens']
1428 result_buffer_lens = content['result_buffer_lens']
1429 result_buffer_lens = content['result_buffer_lens']
1429 buffers = msg['buffers']
1430 buffers = msg['buffers']
1430 has_bufs = buffer_lens is not None
1431 has_bufs = buffer_lens is not None
1431 has_rbufs = result_buffer_lens is not None
1432 has_rbufs = result_buffer_lens is not None
1432 for i,rec in enumerate(records):
1433 for i,rec in enumerate(records):
1433 # relink buffers
1434 # relink buffers
1434 if has_bufs:
1435 if has_bufs:
1435 blen = buffer_lens[i]
1436 blen = buffer_lens[i]
1436 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1437 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1437 if has_rbufs:
1438 if has_rbufs:
1438 blen = result_buffer_lens[i]
1439 blen = result_buffer_lens[i]
1439 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1440 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1440
1441
1441 return records
1442 return records
1442
1443
1443 __all__ = [ 'Client' ]
1444 __all__ = [ 'Client' ]
@@ -1,222 +1,241
1 """Remote Functions and decorators for Views.
1 """Remote Functions and decorators for Views.
2
2
3 Authors:
3 Authors:
4
4
5 * Brian Granger
5 * Brian Granger
6 * Min RK
6 * Min RK
7 """
7 """
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2010-2011 The IPython Development Team
9 # Copyright (C) 2010-2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import sys
21 import sys
22 import warnings
22 import warnings
23
23
24 from IPython.testing.skipdoctest import skip_doctest
24 from IPython.testing.skipdoctest import skip_doctest
25
25
26 from . import map as Map
26 from . import map as Map
27 from .asyncresult import AsyncMapResult
27 from .asyncresult import AsyncMapResult
28
28
29 #-----------------------------------------------------------------------------
29 #-----------------------------------------------------------------------------
30 # Decorators
30 # Functions and Decorators
31 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
32
32
33 @skip_doctest
33 @skip_doctest
34 def remote(view, block=None, **flags):
34 def remote(view, block=None, **flags):
35 """Turn a function into a remote function.
35 """Turn a function into a remote function.
36
36
37 This method can be used for map:
37 This method can be used for map:
38
38
39 In [1]: @remote(view,block=True)
39 In [1]: @remote(view,block=True)
40 ...: def func(a):
40 ...: def func(a):
41 ...: pass
41 ...: pass
42 """
42 """
43
43
44 def remote_function(f):
44 def remote_function(f):
45 return RemoteFunction(view, f, block=block, **flags)
45 return RemoteFunction(view, f, block=block, **flags)
46 return remote_function
46 return remote_function
47
47
48 @skip_doctest
48 @skip_doctest
49 def parallel(view, dist='b', block=None, ordered=True, **flags):
49 def parallel(view, dist='b', block=None, ordered=True, **flags):
50 """Turn a function into a parallel remote function.
50 """Turn a function into a parallel remote function.
51
51
52 This method can be used for map:
52 This method can be used for map:
53
53
54 In [1]: @parallel(view, block=True)
54 In [1]: @parallel(view, block=True)
55 ...: def func(a):
55 ...: def func(a):
56 ...: pass
56 ...: pass
57 """
57 """
58
58
59 def parallel_function(f):
59 def parallel_function(f):
60 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
60 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
61 return parallel_function
61 return parallel_function
62
62
63 def getname(f):
64 """Get the name of an object.
65
66 For use in case of callables that are not functions, and
67 thus may not have __name__ defined.
68
69 Order: f.__name__ > f.name > str(f)
70 """
71 try:
72 return f.__name__
73 except:
74 pass
75 try:
76 return f.name
77 except:
78 pass
79
80 return str(f)
81
63 #--------------------------------------------------------------------------
82 #--------------------------------------------------------------------------
64 # Classes
83 # Classes
65 #--------------------------------------------------------------------------
84 #--------------------------------------------------------------------------
66
85
67 class RemoteFunction(object):
86 class RemoteFunction(object):
68 """Turn an existing function into a remote function.
87 """Turn an existing function into a remote function.
69
88
70 Parameters
89 Parameters
71 ----------
90 ----------
72
91
73 view : View instance
92 view : View instance
74 The view to be used for execution
93 The view to be used for execution
75 f : callable
94 f : callable
76 The function to be wrapped into a remote function
95 The function to be wrapped into a remote function
77 block : bool [default: None]
96 block : bool [default: None]
78 Whether to wait for results or not. The default behavior is
97 Whether to wait for results or not. The default behavior is
79 to use the current `block` attribute of `view`
98 to use the current `block` attribute of `view`
80
99
81 **flags : remaining kwargs are passed to View.temp_flags
100 **flags : remaining kwargs are passed to View.temp_flags
82 """
101 """
83
102
84 view = None # the remote connection
103 view = None # the remote connection
85 func = None # the wrapped function
104 func = None # the wrapped function
86 block = None # whether to block
105 block = None # whether to block
87 flags = None # dict of extra kwargs for temp_flags
106 flags = None # dict of extra kwargs for temp_flags
88
107
89 def __init__(self, view, f, block=None, **flags):
108 def __init__(self, view, f, block=None, **flags):
90 self.view = view
109 self.view = view
91 self.func = f
110 self.func = f
92 self.block=block
111 self.block=block
93 self.flags=flags
112 self.flags=flags
94
113
95 def __call__(self, *args, **kwargs):
114 def __call__(self, *args, **kwargs):
96 block = self.view.block if self.block is None else self.block
115 block = self.view.block if self.block is None else self.block
97 with self.view.temp_flags(block=block, **self.flags):
116 with self.view.temp_flags(block=block, **self.flags):
98 return self.view.apply(self.func, *args, **kwargs)
117 return self.view.apply(self.func, *args, **kwargs)
99
118
100
119
101 class ParallelFunction(RemoteFunction):
120 class ParallelFunction(RemoteFunction):
102 """Class for mapping a function to sequences.
121 """Class for mapping a function to sequences.
103
122
104 This will distribute the sequences according the a mapper, and call
123 This will distribute the sequences according the a mapper, and call
105 the function on each sub-sequence. If called via map, then the function
124 the function on each sub-sequence. If called via map, then the function
106 will be called once on each element, rather that each sub-sequence.
125 will be called once on each element, rather that each sub-sequence.
107
126
108 Parameters
127 Parameters
109 ----------
128 ----------
110
129
111 view : View instance
130 view : View instance
112 The view to be used for execution
131 The view to be used for execution
113 f : callable
132 f : callable
114 The function to be wrapped into a remote function
133 The function to be wrapped into a remote function
115 dist : str [default: 'b']
134 dist : str [default: 'b']
116 The key for which mapObject to use to distribute sequences
135 The key for which mapObject to use to distribute sequences
117 options are:
136 options are:
118 * 'b' : use contiguous chunks in order
137 * 'b' : use contiguous chunks in order
119 * 'r' : use round-robin striping
138 * 'r' : use round-robin striping
120 block : bool [default: None]
139 block : bool [default: None]
121 Whether to wait for results or not. The default behavior is
140 Whether to wait for results or not. The default behavior is
122 to use the current `block` attribute of `view`
141 to use the current `block` attribute of `view`
123 chunksize : int or None
142 chunksize : int or None
124 The size of chunk to use when breaking up sequences in a load-balanced manner
143 The size of chunk to use when breaking up sequences in a load-balanced manner
125 ordered : bool [default: True]
144 ordered : bool [default: True]
126 Whether
145 Whether
127 **flags : remaining kwargs are passed to View.temp_flags
146 **flags : remaining kwargs are passed to View.temp_flags
128 """
147 """
129
148
130 chunksize=None
149 chunksize=None
131 ordered=None
150 ordered=None
132 mapObject=None
151 mapObject=None
133
152
134 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
153 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
135 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
154 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
136 self.chunksize = chunksize
155 self.chunksize = chunksize
137 self.ordered = ordered
156 self.ordered = ordered
138
157
139 mapClass = Map.dists[dist]
158 mapClass = Map.dists[dist]
140 self.mapObject = mapClass()
159 self.mapObject = mapClass()
141
160
142 def __call__(self, *sequences):
161 def __call__(self, *sequences):
143 client = self.view.client
162 client = self.view.client
144
163
145 # check that the length of sequences match
164 # check that the length of sequences match
146 len_0 = len(sequences[0])
165 len_0 = len(sequences[0])
147 for s in sequences:
166 for s in sequences:
148 if len(s)!=len_0:
167 if len(s)!=len_0:
149 msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s))
168 msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s))
150 raise ValueError(msg)
169 raise ValueError(msg)
151 balanced = 'Balanced' in self.view.__class__.__name__
170 balanced = 'Balanced' in self.view.__class__.__name__
152 if balanced:
171 if balanced:
153 if self.chunksize:
172 if self.chunksize:
154 nparts = len_0//self.chunksize + int(len_0%self.chunksize > 0)
173 nparts = len_0//self.chunksize + int(len_0%self.chunksize > 0)
155 else:
174 else:
156 nparts = len_0
175 nparts = len_0
157 targets = [None]*nparts
176 targets = [None]*nparts
158 else:
177 else:
159 if self.chunksize:
178 if self.chunksize:
160 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
179 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
161 # multiplexed:
180 # multiplexed:
162 targets = self.view.targets
181 targets = self.view.targets
163 # 'all' is lazily evaluated at execution time, which is now:
182 # 'all' is lazily evaluated at execution time, which is now:
164 if targets == 'all':
183 if targets == 'all':
165 targets = client._build_targets(targets)[1]
184 targets = client._build_targets(targets)[1]
166 nparts = len(targets)
185 nparts = len(targets)
167
186
168 msg_ids = []
187 msg_ids = []
169 for index, t in enumerate(targets):
188 for index, t in enumerate(targets):
170 args = []
189 args = []
171 for seq in sequences:
190 for seq in sequences:
172 part = self.mapObject.getPartition(seq, index, nparts)
191 part = self.mapObject.getPartition(seq, index, nparts)
173 if len(part) == 0:
192 if len(part) == 0:
174 continue
193 continue
175 else:
194 else:
176 args.append(part)
195 args.append(part)
177 if not args:
196 if not args:
178 continue
197 continue
179
198
180 # print (args)
199 # print (args)
181 if hasattr(self, '_map'):
200 if hasattr(self, '_map'):
182 if sys.version_info[0] >= 3:
201 if sys.version_info[0] >= 3:
183 f = lambda f, *sequences: list(map(f, *sequences))
202 f = lambda f, *sequences: list(map(f, *sequences))
184 else:
203 else:
185 f = map
204 f = map
186 args = [self.func]+args
205 args = [self.func]+args
187 else:
206 else:
188 f=self.func
207 f=self.func
189
208
190 view = self.view if balanced else client[t]
209 view = self.view if balanced else client[t]
191 with view.temp_flags(block=False, **self.flags):
210 with view.temp_flags(block=False, **self.flags):
192 ar = view.apply(f, *args)
211 ar = view.apply(f, *args)
193
212
194 msg_ids.append(ar.msg_ids[0])
213 msg_ids.append(ar.msg_ids[0])
195
214
196 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
215 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
197 fname=self.func.__name__,
216 fname=getname(self.func),
198 ordered=self.ordered
217 ordered=self.ordered
199 )
218 )
200
219
201 if self.block:
220 if self.block:
202 try:
221 try:
203 return r.get()
222 return r.get()
204 except KeyboardInterrupt:
223 except KeyboardInterrupt:
205 return r
224 return r
206 else:
225 else:
207 return r
226 return r
208
227
209 def map(self, *sequences):
228 def map(self, *sequences):
210 """call a function on each element of a sequence remotely.
229 """call a function on each element of a sequence remotely.
211 This should behave very much like the builtin map, but return an AsyncMapResult
230 This should behave very much like the builtin map, but return an AsyncMapResult
212 if self.block is False.
231 if self.block is False.
213 """
232 """
214 # set _map as a flag for use inside self.__call__
233 # set _map as a flag for use inside self.__call__
215 self._map = True
234 self._map = True
216 try:
235 try:
217 ret = self.__call__(*sequences)
236 ret = self.__call__(*sequences)
218 finally:
237 finally:
219 del self._map
238 del self._map
220 return ret
239 return ret
221
240
222 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
241 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,1065 +1,1065
1 """Views of remote engines.
1 """Views of remote engines.
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import imp
18 import imp
19 import sys
19 import sys
20 import warnings
20 import warnings
21 from contextlib import contextmanager
21 from contextlib import contextmanager
22 from types import ModuleType
22 from types import ModuleType
23
23
24 import zmq
24 import zmq
25
25
26 from IPython.testing.skipdoctest import skip_doctest
26 from IPython.testing.skipdoctest import skip_doctest
27 from IPython.utils.traitlets import (
27 from IPython.utils.traitlets import (
28 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
28 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
29 )
29 )
30 from IPython.external.decorator import decorator
30 from IPython.external.decorator import decorator
31
31
32 from IPython.parallel import util
32 from IPython.parallel import util
33 from IPython.parallel.controller.dependency import Dependency, dependent
33 from IPython.parallel.controller.dependency import Dependency, dependent
34
34
35 from . import map as Map
35 from . import map as Map
36 from .asyncresult import AsyncResult, AsyncMapResult
36 from .asyncresult import AsyncResult, AsyncMapResult
37 from .remotefunction import ParallelFunction, parallel, remote
37 from .remotefunction import ParallelFunction, parallel, remote, getname
38
38
39 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
40 # Decorators
40 # Decorators
41 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
42
42
43 @decorator
43 @decorator
44 def save_ids(f, self, *args, **kwargs):
44 def save_ids(f, self, *args, **kwargs):
45 """Keep our history and outstanding attributes up to date after a method call."""
45 """Keep our history and outstanding attributes up to date after a method call."""
46 n_previous = len(self.client.history)
46 n_previous = len(self.client.history)
47 try:
47 try:
48 ret = f(self, *args, **kwargs)
48 ret = f(self, *args, **kwargs)
49 finally:
49 finally:
50 nmsgs = len(self.client.history) - n_previous
50 nmsgs = len(self.client.history) - n_previous
51 msg_ids = self.client.history[-nmsgs:]
51 msg_ids = self.client.history[-nmsgs:]
52 self.history.extend(msg_ids)
52 self.history.extend(msg_ids)
53 map(self.outstanding.add, msg_ids)
53 map(self.outstanding.add, msg_ids)
54 return ret
54 return ret
55
55
56 @decorator
56 @decorator
57 def sync_results(f, self, *args, **kwargs):
57 def sync_results(f, self, *args, **kwargs):
58 """sync relevant results from self.client to our results attribute."""
58 """sync relevant results from self.client to our results attribute."""
59 ret = f(self, *args, **kwargs)
59 ret = f(self, *args, **kwargs)
60 delta = self.outstanding.difference(self.client.outstanding)
60 delta = self.outstanding.difference(self.client.outstanding)
61 completed = self.outstanding.intersection(delta)
61 completed = self.outstanding.intersection(delta)
62 self.outstanding = self.outstanding.difference(completed)
62 self.outstanding = self.outstanding.difference(completed)
63 for msg_id in completed:
63 for msg_id in completed:
64 self.results[msg_id] = self.client.results[msg_id]
64 self.results[msg_id] = self.client.results[msg_id]
65 return ret
65 return ret
66
66
67 @decorator
67 @decorator
68 def spin_after(f, self, *args, **kwargs):
68 def spin_after(f, self, *args, **kwargs):
69 """call spin after the method."""
69 """call spin after the method."""
70 ret = f(self, *args, **kwargs)
70 ret = f(self, *args, **kwargs)
71 self.spin()
71 self.spin()
72 return ret
72 return ret
73
73
74 #-----------------------------------------------------------------------------
74 #-----------------------------------------------------------------------------
75 # Classes
75 # Classes
76 #-----------------------------------------------------------------------------
76 #-----------------------------------------------------------------------------
77
77
78 @skip_doctest
78 @skip_doctest
79 class View(HasTraits):
79 class View(HasTraits):
80 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
80 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
81
81
82 Don't use this class, use subclasses.
82 Don't use this class, use subclasses.
83
83
84 Methods
84 Methods
85 -------
85 -------
86
86
87 spin
87 spin
88 flushes incoming results and registration state changes
88 flushes incoming results and registration state changes
89 control methods spin, and requesting `ids` also ensures up to date
89 control methods spin, and requesting `ids` also ensures up to date
90
90
91 wait
91 wait
92 wait on one or more msg_ids
92 wait on one or more msg_ids
93
93
94 execution methods
94 execution methods
95 apply
95 apply
96 legacy: execute, run
96 legacy: execute, run
97
97
98 data movement
98 data movement
99 push, pull, scatter, gather
99 push, pull, scatter, gather
100
100
101 query methods
101 query methods
102 get_result, queue_status, purge_results, result_status
102 get_result, queue_status, purge_results, result_status
103
103
104 control methods
104 control methods
105 abort, shutdown
105 abort, shutdown
106
106
107 """
107 """
108 # flags
108 # flags
109 block=Bool(False)
109 block=Bool(False)
110 track=Bool(True)
110 track=Bool(True)
111 targets = Any()
111 targets = Any()
112
112
113 history=List()
113 history=List()
114 outstanding = Set()
114 outstanding = Set()
115 results = Dict()
115 results = Dict()
116 client = Instance('IPython.parallel.Client')
116 client = Instance('IPython.parallel.Client')
117
117
118 _socket = Instance('zmq.Socket')
118 _socket = Instance('zmq.Socket')
119 _flag_names = List(['targets', 'block', 'track'])
119 _flag_names = List(['targets', 'block', 'track'])
120 _targets = Any()
120 _targets = Any()
121 _idents = Any()
121 _idents = Any()
122
122
123 def __init__(self, client=None, socket=None, **flags):
123 def __init__(self, client=None, socket=None, **flags):
124 super(View, self).__init__(client=client, _socket=socket)
124 super(View, self).__init__(client=client, _socket=socket)
125 self.block = client.block
125 self.block = client.block
126
126
127 self.set_flags(**flags)
127 self.set_flags(**flags)
128
128
129 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
129 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
130
130
131
131
132 def __repr__(self):
132 def __repr__(self):
133 strtargets = str(self.targets)
133 strtargets = str(self.targets)
134 if len(strtargets) > 16:
134 if len(strtargets) > 16:
135 strtargets = strtargets[:12]+'...]'
135 strtargets = strtargets[:12]+'...]'
136 return "<%s %s>"%(self.__class__.__name__, strtargets)
136 return "<%s %s>"%(self.__class__.__name__, strtargets)
137
137
138 def set_flags(self, **kwargs):
138 def set_flags(self, **kwargs):
139 """set my attribute flags by keyword.
139 """set my attribute flags by keyword.
140
140
141 Views determine behavior with a few attributes (`block`, `track`, etc.).
141 Views determine behavior with a few attributes (`block`, `track`, etc.).
142 These attributes can be set all at once by name with this method.
142 These attributes can be set all at once by name with this method.
143
143
144 Parameters
144 Parameters
145 ----------
145 ----------
146
146
147 block : bool
147 block : bool
148 whether to wait for results
148 whether to wait for results
149 track : bool
149 track : bool
150 whether to create a MessageTracker to allow the user to
150 whether to create a MessageTracker to allow the user to
151 safely edit after arrays and buffers during non-copying
151 safely edit after arrays and buffers during non-copying
152 sends.
152 sends.
153 """
153 """
154 for name, value in kwargs.iteritems():
154 for name, value in kwargs.iteritems():
155 if name not in self._flag_names:
155 if name not in self._flag_names:
156 raise KeyError("Invalid name: %r"%name)
156 raise KeyError("Invalid name: %r"%name)
157 else:
157 else:
158 setattr(self, name, value)
158 setattr(self, name, value)
159
159
160 @contextmanager
160 @contextmanager
161 def temp_flags(self, **kwargs):
161 def temp_flags(self, **kwargs):
162 """temporarily set flags, for use in `with` statements.
162 """temporarily set flags, for use in `with` statements.
163
163
164 See set_flags for permanent setting of flags
164 See set_flags for permanent setting of flags
165
165
166 Examples
166 Examples
167 --------
167 --------
168
168
169 >>> view.track=False
169 >>> view.track=False
170 ...
170 ...
171 >>> with view.temp_flags(track=True):
171 >>> with view.temp_flags(track=True):
172 ... ar = view.apply(dostuff, my_big_array)
172 ... ar = view.apply(dostuff, my_big_array)
173 ... ar.tracker.wait() # wait for send to finish
173 ... ar.tracker.wait() # wait for send to finish
174 >>> view.track
174 >>> view.track
175 False
175 False
176
176
177 """
177 """
178 # preflight: save flags, and set temporaries
178 # preflight: save flags, and set temporaries
179 saved_flags = {}
179 saved_flags = {}
180 for f in self._flag_names:
180 for f in self._flag_names:
181 saved_flags[f] = getattr(self, f)
181 saved_flags[f] = getattr(self, f)
182 self.set_flags(**kwargs)
182 self.set_flags(**kwargs)
183 # yield to the with-statement block
183 # yield to the with-statement block
184 try:
184 try:
185 yield
185 yield
186 finally:
186 finally:
187 # postflight: restore saved flags
187 # postflight: restore saved flags
188 self.set_flags(**saved_flags)
188 self.set_flags(**saved_flags)
189
189
190
190
191 #----------------------------------------------------------------
191 #----------------------------------------------------------------
192 # apply
192 # apply
193 #----------------------------------------------------------------
193 #----------------------------------------------------------------
194
194
195 @sync_results
195 @sync_results
196 @save_ids
196 @save_ids
197 def _really_apply(self, f, args, kwargs, block=None, **options):
197 def _really_apply(self, f, args, kwargs, block=None, **options):
198 """wrapper for client.send_apply_message"""
198 """wrapper for client.send_apply_message"""
199 raise NotImplementedError("Implement in subclasses")
199 raise NotImplementedError("Implement in subclasses")
200
200
201 def apply(self, f, *args, **kwargs):
201 def apply(self, f, *args, **kwargs):
202 """calls f(*args, **kwargs) on remote engines, returning the result.
202 """calls f(*args, **kwargs) on remote engines, returning the result.
203
203
204 This method sets all apply flags via this View's attributes.
204 This method sets all apply flags via this View's attributes.
205
205
206 if self.block is False:
206 if self.block is False:
207 returns AsyncResult
207 returns AsyncResult
208 else:
208 else:
209 returns actual result of f(*args, **kwargs)
209 returns actual result of f(*args, **kwargs)
210 """
210 """
211 return self._really_apply(f, args, kwargs)
211 return self._really_apply(f, args, kwargs)
212
212
213 def apply_async(self, f, *args, **kwargs):
213 def apply_async(self, f, *args, **kwargs):
214 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
214 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
215
215
216 returns AsyncResult
216 returns AsyncResult
217 """
217 """
218 return self._really_apply(f, args, kwargs, block=False)
218 return self._really_apply(f, args, kwargs, block=False)
219
219
220 @spin_after
220 @spin_after
221 def apply_sync(self, f, *args, **kwargs):
221 def apply_sync(self, f, *args, **kwargs):
222 """calls f(*args, **kwargs) on remote engines in a blocking manner,
222 """calls f(*args, **kwargs) on remote engines in a blocking manner,
223 returning the result.
223 returning the result.
224
224
225 returns: actual result of f(*args, **kwargs)
225 returns: actual result of f(*args, **kwargs)
226 """
226 """
227 return self._really_apply(f, args, kwargs, block=True)
227 return self._really_apply(f, args, kwargs, block=True)
228
228
229 #----------------------------------------------------------------
229 #----------------------------------------------------------------
230 # wrappers for client and control methods
230 # wrappers for client and control methods
231 #----------------------------------------------------------------
231 #----------------------------------------------------------------
232 @sync_results
232 @sync_results
233 def spin(self):
233 def spin(self):
234 """spin the client, and sync"""
234 """spin the client, and sync"""
235 self.client.spin()
235 self.client.spin()
236
236
237 @sync_results
237 @sync_results
238 def wait(self, jobs=None, timeout=-1):
238 def wait(self, jobs=None, timeout=-1):
239 """waits on one or more `jobs`, for up to `timeout` seconds.
239 """waits on one or more `jobs`, for up to `timeout` seconds.
240
240
241 Parameters
241 Parameters
242 ----------
242 ----------
243
243
244 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
244 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
245 ints are indices to self.history
245 ints are indices to self.history
246 strs are msg_ids
246 strs are msg_ids
247 default: wait on all outstanding messages
247 default: wait on all outstanding messages
248 timeout : float
248 timeout : float
249 a time in seconds, after which to give up.
249 a time in seconds, after which to give up.
250 default is -1, which means no timeout
250 default is -1, which means no timeout
251
251
252 Returns
252 Returns
253 -------
253 -------
254
254
255 True : when all msg_ids are done
255 True : when all msg_ids are done
256 False : timeout reached, some msg_ids still outstanding
256 False : timeout reached, some msg_ids still outstanding
257 """
257 """
258 if jobs is None:
258 if jobs is None:
259 jobs = self.history
259 jobs = self.history
260 return self.client.wait(jobs, timeout)
260 return self.client.wait(jobs, timeout)
261
261
262 def abort(self, jobs=None, targets=None, block=None):
262 def abort(self, jobs=None, targets=None, block=None):
263 """Abort jobs on my engines.
263 """Abort jobs on my engines.
264
264
265 Parameters
265 Parameters
266 ----------
266 ----------
267
267
268 jobs : None, str, list of strs, optional
268 jobs : None, str, list of strs, optional
269 if None: abort all jobs.
269 if None: abort all jobs.
270 else: abort specific msg_id(s).
270 else: abort specific msg_id(s).
271 """
271 """
272 block = block if block is not None else self.block
272 block = block if block is not None else self.block
273 targets = targets if targets is not None else self.targets
273 targets = targets if targets is not None else self.targets
274 jobs = jobs if jobs is not None else list(self.outstanding)
274 jobs = jobs if jobs is not None else list(self.outstanding)
275
275
276 return self.client.abort(jobs=jobs, targets=targets, block=block)
276 return self.client.abort(jobs=jobs, targets=targets, block=block)
277
277
278 def queue_status(self, targets=None, verbose=False):
278 def queue_status(self, targets=None, verbose=False):
279 """Fetch the Queue status of my engines"""
279 """Fetch the Queue status of my engines"""
280 targets = targets if targets is not None else self.targets
280 targets = targets if targets is not None else self.targets
281 return self.client.queue_status(targets=targets, verbose=verbose)
281 return self.client.queue_status(targets=targets, verbose=verbose)
282
282
283 def purge_results(self, jobs=[], targets=[]):
283 def purge_results(self, jobs=[], targets=[]):
284 """Instruct the controller to forget specific results."""
284 """Instruct the controller to forget specific results."""
285 if targets is None or targets == 'all':
285 if targets is None or targets == 'all':
286 targets = self.targets
286 targets = self.targets
287 return self.client.purge_results(jobs=jobs, targets=targets)
287 return self.client.purge_results(jobs=jobs, targets=targets)
288
288
289 def shutdown(self, targets=None, restart=False, hub=False, block=None):
289 def shutdown(self, targets=None, restart=False, hub=False, block=None):
290 """Terminates one or more engine processes, optionally including the hub.
290 """Terminates one or more engine processes, optionally including the hub.
291 """
291 """
292 block = self.block if block is None else block
292 block = self.block if block is None else block
293 if targets is None or targets == 'all':
293 if targets is None or targets == 'all':
294 targets = self.targets
294 targets = self.targets
295 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
295 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
296
296
297 @spin_after
297 @spin_after
298 def get_result(self, indices_or_msg_ids=None):
298 def get_result(self, indices_or_msg_ids=None):
299 """return one or more results, specified by history index or msg_id.
299 """return one or more results, specified by history index or msg_id.
300
300
301 See client.get_result for details.
301 See client.get_result for details.
302
302
303 """
303 """
304
304
305 if indices_or_msg_ids is None:
305 if indices_or_msg_ids is None:
306 indices_or_msg_ids = -1
306 indices_or_msg_ids = -1
307 if isinstance(indices_or_msg_ids, int):
307 if isinstance(indices_or_msg_ids, int):
308 indices_or_msg_ids = self.history[indices_or_msg_ids]
308 indices_or_msg_ids = self.history[indices_or_msg_ids]
309 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
309 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
310 indices_or_msg_ids = list(indices_or_msg_ids)
310 indices_or_msg_ids = list(indices_or_msg_ids)
311 for i,index in enumerate(indices_or_msg_ids):
311 for i,index in enumerate(indices_or_msg_ids):
312 if isinstance(index, int):
312 if isinstance(index, int):
313 indices_or_msg_ids[i] = self.history[index]
313 indices_or_msg_ids[i] = self.history[index]
314 return self.client.get_result(indices_or_msg_ids)
314 return self.client.get_result(indices_or_msg_ids)
315
315
316 #-------------------------------------------------------------------
316 #-------------------------------------------------------------------
317 # Map
317 # Map
318 #-------------------------------------------------------------------
318 #-------------------------------------------------------------------
319
319
320 def map(self, f, *sequences, **kwargs):
320 def map(self, f, *sequences, **kwargs):
321 """override in subclasses"""
321 """override in subclasses"""
322 raise NotImplementedError
322 raise NotImplementedError
323
323
324 def map_async(self, f, *sequences, **kwargs):
324 def map_async(self, f, *sequences, **kwargs):
325 """Parallel version of builtin `map`, using this view's engines.
325 """Parallel version of builtin `map`, using this view's engines.
326
326
327 This is equivalent to map(...block=False)
327 This is equivalent to map(...block=False)
328
328
329 See `self.map` for details.
329 See `self.map` for details.
330 """
330 """
331 if 'block' in kwargs:
331 if 'block' in kwargs:
332 raise TypeError("map_async doesn't take a `block` keyword argument.")
332 raise TypeError("map_async doesn't take a `block` keyword argument.")
333 kwargs['block'] = False
333 kwargs['block'] = False
334 return self.map(f,*sequences,**kwargs)
334 return self.map(f,*sequences,**kwargs)
335
335
336 def map_sync(self, f, *sequences, **kwargs):
336 def map_sync(self, f, *sequences, **kwargs):
337 """Parallel version of builtin `map`, using this view's engines.
337 """Parallel version of builtin `map`, using this view's engines.
338
338
339 This is equivalent to map(...block=True)
339 This is equivalent to map(...block=True)
340
340
341 See `self.map` for details.
341 See `self.map` for details.
342 """
342 """
343 if 'block' in kwargs:
343 if 'block' in kwargs:
344 raise TypeError("map_sync doesn't take a `block` keyword argument.")
344 raise TypeError("map_sync doesn't take a `block` keyword argument.")
345 kwargs['block'] = True
345 kwargs['block'] = True
346 return self.map(f,*sequences,**kwargs)
346 return self.map(f,*sequences,**kwargs)
347
347
348 def imap(self, f, *sequences, **kwargs):
348 def imap(self, f, *sequences, **kwargs):
349 """Parallel version of `itertools.imap`.
349 """Parallel version of `itertools.imap`.
350
350
351 See `self.map` for details.
351 See `self.map` for details.
352
352
353 """
353 """
354
354
355 return iter(self.map_async(f,*sequences, **kwargs))
355 return iter(self.map_async(f,*sequences, **kwargs))
356
356
357 #-------------------------------------------------------------------
357 #-------------------------------------------------------------------
358 # Decorators
358 # Decorators
359 #-------------------------------------------------------------------
359 #-------------------------------------------------------------------
360
360
361 def remote(self, block=True, **flags):
361 def remote(self, block=True, **flags):
362 """Decorator for making a RemoteFunction"""
362 """Decorator for making a RemoteFunction"""
363 block = self.block if block is None else block
363 block = self.block if block is None else block
364 return remote(self, block=block, **flags)
364 return remote(self, block=block, **flags)
365
365
366 def parallel(self, dist='b', block=None, **flags):
366 def parallel(self, dist='b', block=None, **flags):
367 """Decorator for making a ParallelFunction"""
367 """Decorator for making a ParallelFunction"""
368 block = self.block if block is None else block
368 block = self.block if block is None else block
369 return parallel(self, dist=dist, block=block, **flags)
369 return parallel(self, dist=dist, block=block, **flags)
370
370
371 @skip_doctest
371 @skip_doctest
372 class DirectView(View):
372 class DirectView(View):
373 """Direct Multiplexer View of one or more engines.
373 """Direct Multiplexer View of one or more engines.
374
374
375 These are created via indexed access to a client:
375 These are created via indexed access to a client:
376
376
377 >>> dv_1 = client[1]
377 >>> dv_1 = client[1]
378 >>> dv_all = client[:]
378 >>> dv_all = client[:]
379 >>> dv_even = client[::2]
379 >>> dv_even = client[::2]
380 >>> dv_some = client[1:3]
380 >>> dv_some = client[1:3]
381
381
382 This object provides dictionary access to engine namespaces:
382 This object provides dictionary access to engine namespaces:
383
383
384 # push a=5:
384 # push a=5:
385 >>> dv['a'] = 5
385 >>> dv['a'] = 5
386 # pull 'foo':
386 # pull 'foo':
387 >>> db['foo']
387 >>> db['foo']
388
388
389 """
389 """
390
390
391 def __init__(self, client=None, socket=None, targets=None):
391 def __init__(self, client=None, socket=None, targets=None):
392 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
392 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
393
393
394 @property
394 @property
395 def importer(self):
395 def importer(self):
396 """sync_imports(local=True) as a property.
396 """sync_imports(local=True) as a property.
397
397
398 See sync_imports for details.
398 See sync_imports for details.
399
399
400 """
400 """
401 return self.sync_imports(True)
401 return self.sync_imports(True)
402
402
403 @contextmanager
403 @contextmanager
404 def sync_imports(self, local=True):
404 def sync_imports(self, local=True):
405 """Context Manager for performing simultaneous local and remote imports.
405 """Context Manager for performing simultaneous local and remote imports.
406
406
407 'import x as y' will *not* work. The 'as y' part will simply be ignored.
407 'import x as y' will *not* work. The 'as y' part will simply be ignored.
408
408
409 If `local=True`, then the package will also be imported locally.
409 If `local=True`, then the package will also be imported locally.
410
410
411 Note that remote-only (`local=False`) imports have not been implemented.
411 Note that remote-only (`local=False`) imports have not been implemented.
412
412
413 >>> with view.sync_imports():
413 >>> with view.sync_imports():
414 ... from numpy import recarray
414 ... from numpy import recarray
415 importing recarray from numpy on engine(s)
415 importing recarray from numpy on engine(s)
416
416
417 """
417 """
418 import __builtin__
418 import __builtin__
419 local_import = __builtin__.__import__
419 local_import = __builtin__.__import__
420 modules = set()
420 modules = set()
421 results = []
421 results = []
422 @util.interactive
422 @util.interactive
423 def remote_import(name, fromlist, level):
423 def remote_import(name, fromlist, level):
424 """the function to be passed to apply, that actually performs the import
424 """the function to be passed to apply, that actually performs the import
425 on the engine, and loads up the user namespace.
425 on the engine, and loads up the user namespace.
426 """
426 """
427 import sys
427 import sys
428 user_ns = globals()
428 user_ns = globals()
429 mod = __import__(name, fromlist=fromlist, level=level)
429 mod = __import__(name, fromlist=fromlist, level=level)
430 if fromlist:
430 if fromlist:
431 for key in fromlist:
431 for key in fromlist:
432 user_ns[key] = getattr(mod, key)
432 user_ns[key] = getattr(mod, key)
433 else:
433 else:
434 user_ns[name] = sys.modules[name]
434 user_ns[name] = sys.modules[name]
435
435
436 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
436 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
437 """the drop-in replacement for __import__, that optionally imports
437 """the drop-in replacement for __import__, that optionally imports
438 locally as well.
438 locally as well.
439 """
439 """
440 # don't override nested imports
440 # don't override nested imports
441 save_import = __builtin__.__import__
441 save_import = __builtin__.__import__
442 __builtin__.__import__ = local_import
442 __builtin__.__import__ = local_import
443
443
444 if imp.lock_held():
444 if imp.lock_held():
445 # this is a side-effect import, don't do it remotely, or even
445 # this is a side-effect import, don't do it remotely, or even
446 # ignore the local effects
446 # ignore the local effects
447 return local_import(name, globals, locals, fromlist, level)
447 return local_import(name, globals, locals, fromlist, level)
448
448
449 imp.acquire_lock()
449 imp.acquire_lock()
450 if local:
450 if local:
451 mod = local_import(name, globals, locals, fromlist, level)
451 mod = local_import(name, globals, locals, fromlist, level)
452 else:
452 else:
453 raise NotImplementedError("remote-only imports not yet implemented")
453 raise NotImplementedError("remote-only imports not yet implemented")
454 imp.release_lock()
454 imp.release_lock()
455
455
456 key = name+':'+','.join(fromlist or [])
456 key = name+':'+','.join(fromlist or [])
457 if level == -1 and key not in modules:
457 if level == -1 and key not in modules:
458 modules.add(key)
458 modules.add(key)
459 if fromlist:
459 if fromlist:
460 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
460 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
461 else:
461 else:
462 print "importing %s on engine(s)"%name
462 print "importing %s on engine(s)"%name
463 results.append(self.apply_async(remote_import, name, fromlist, level))
463 results.append(self.apply_async(remote_import, name, fromlist, level))
464 # restore override
464 # restore override
465 __builtin__.__import__ = save_import
465 __builtin__.__import__ = save_import
466
466
467 return mod
467 return mod
468
468
469 # override __import__
469 # override __import__
470 __builtin__.__import__ = view_import
470 __builtin__.__import__ = view_import
471 try:
471 try:
472 # enter the block
472 # enter the block
473 yield
473 yield
474 except ImportError:
474 except ImportError:
475 if local:
475 if local:
476 raise
476 raise
477 else:
477 else:
478 # ignore import errors if not doing local imports
478 # ignore import errors if not doing local imports
479 pass
479 pass
480 finally:
480 finally:
481 # always restore __import__
481 # always restore __import__
482 __builtin__.__import__ = local_import
482 __builtin__.__import__ = local_import
483
483
484 for r in results:
484 for r in results:
485 # raise possible remote ImportErrors here
485 # raise possible remote ImportErrors here
486 r.get()
486 r.get()
487
487
488
488
489 @sync_results
489 @sync_results
490 @save_ids
490 @save_ids
491 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
491 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
492 """calls f(*args, **kwargs) on remote engines, returning the result.
492 """calls f(*args, **kwargs) on remote engines, returning the result.
493
493
494 This method sets all of `apply`'s flags via this View's attributes.
494 This method sets all of `apply`'s flags via this View's attributes.
495
495
496 Parameters
496 Parameters
497 ----------
497 ----------
498
498
499 f : callable
499 f : callable
500
500
501 args : list [default: empty]
501 args : list [default: empty]
502
502
503 kwargs : dict [default: empty]
503 kwargs : dict [default: empty]
504
504
505 targets : target list [default: self.targets]
505 targets : target list [default: self.targets]
506 where to run
506 where to run
507 block : bool [default: self.block]
507 block : bool [default: self.block]
508 whether to block
508 whether to block
509 track : bool [default: self.track]
509 track : bool [default: self.track]
510 whether to ask zmq to track the message, for safe non-copying sends
510 whether to ask zmq to track the message, for safe non-copying sends
511
511
512 Returns
512 Returns
513 -------
513 -------
514
514
515 if self.block is False:
515 if self.block is False:
516 returns AsyncResult
516 returns AsyncResult
517 else:
517 else:
518 returns actual result of f(*args, **kwargs) on the engine(s)
518 returns actual result of f(*args, **kwargs) on the engine(s)
519 This will be a list of self.targets is also a list (even length 1), or
519 This will be a list of self.targets is also a list (even length 1), or
520 the single result if self.targets is an integer engine id
520 the single result if self.targets is an integer engine id
521 """
521 """
522 args = [] if args is None else args
522 args = [] if args is None else args
523 kwargs = {} if kwargs is None else kwargs
523 kwargs = {} if kwargs is None else kwargs
524 block = self.block if block is None else block
524 block = self.block if block is None else block
525 track = self.track if track is None else track
525 track = self.track if track is None else track
526 targets = self.targets if targets is None else targets
526 targets = self.targets if targets is None else targets
527
527
528 _idents = self.client._build_targets(targets)[0]
528 _idents = self.client._build_targets(targets)[0]
529 msg_ids = []
529 msg_ids = []
530 trackers = []
530 trackers = []
531 for ident in _idents:
531 for ident in _idents:
532 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
532 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
533 ident=ident)
533 ident=ident)
534 if track:
534 if track:
535 trackers.append(msg['tracker'])
535 trackers.append(msg['tracker'])
536 msg_ids.append(msg['header']['msg_id'])
536 msg_ids.append(msg['header']['msg_id'])
537 tracker = None if track is False else zmq.MessageTracker(*trackers)
537 tracker = None if track is False else zmq.MessageTracker(*trackers)
538 ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
538 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=targets, tracker=tracker)
539 if block:
539 if block:
540 try:
540 try:
541 return ar.get()
541 return ar.get()
542 except KeyboardInterrupt:
542 except KeyboardInterrupt:
543 pass
543 pass
544 return ar
544 return ar
545
545
546 @spin_after
546 @spin_after
547 def map(self, f, *sequences, **kwargs):
547 def map(self, f, *sequences, **kwargs):
548 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
548 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
549
549
550 Parallel version of builtin `map`, using this View's `targets`.
550 Parallel version of builtin `map`, using this View's `targets`.
551
551
552 There will be one task per target, so work will be chunked
552 There will be one task per target, so work will be chunked
553 if the sequences are longer than `targets`.
553 if the sequences are longer than `targets`.
554
554
555 Results can be iterated as they are ready, but will become available in chunks.
555 Results can be iterated as they are ready, but will become available in chunks.
556
556
557 Parameters
557 Parameters
558 ----------
558 ----------
559
559
560 f : callable
560 f : callable
561 function to be mapped
561 function to be mapped
562 *sequences: one or more sequences of matching length
562 *sequences: one or more sequences of matching length
563 the sequences to be distributed and passed to `f`
563 the sequences to be distributed and passed to `f`
564 block : bool
564 block : bool
565 whether to wait for the result or not [default self.block]
565 whether to wait for the result or not [default self.block]
566
566
567 Returns
567 Returns
568 -------
568 -------
569
569
570 if block=False:
570 if block=False:
571 AsyncMapResult
571 AsyncMapResult
572 An object like AsyncResult, but which reassembles the sequence of results
572 An object like AsyncResult, but which reassembles the sequence of results
573 into a single list. AsyncMapResults can be iterated through before all
573 into a single list. AsyncMapResults can be iterated through before all
574 results are complete.
574 results are complete.
575 else:
575 else:
576 list
576 list
577 the result of map(f,*sequences)
577 the result of map(f,*sequences)
578 """
578 """
579
579
580 block = kwargs.pop('block', self.block)
580 block = kwargs.pop('block', self.block)
581 for k in kwargs.keys():
581 for k in kwargs.keys():
582 if k not in ['block', 'track']:
582 if k not in ['block', 'track']:
583 raise TypeError("invalid keyword arg, %r"%k)
583 raise TypeError("invalid keyword arg, %r"%k)
584
584
585 assert len(sequences) > 0, "must have some sequences to map onto!"
585 assert len(sequences) > 0, "must have some sequences to map onto!"
586 pf = ParallelFunction(self, f, block=block, **kwargs)
586 pf = ParallelFunction(self, f, block=block, **kwargs)
587 return pf.map(*sequences)
587 return pf.map(*sequences)
588
588
589 def execute(self, code, targets=None, block=None):
589 def execute(self, code, targets=None, block=None):
590 """Executes `code` on `targets` in blocking or nonblocking manner.
590 """Executes `code` on `targets` in blocking or nonblocking manner.
591
591
592 ``execute`` is always `bound` (affects engine namespace)
592 ``execute`` is always `bound` (affects engine namespace)
593
593
594 Parameters
594 Parameters
595 ----------
595 ----------
596
596
597 code : str
597 code : str
598 the code string to be executed
598 the code string to be executed
599 block : bool
599 block : bool
600 whether or not to wait until done to return
600 whether or not to wait until done to return
601 default: self.block
601 default: self.block
602 """
602 """
603 return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
603 return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
604
604
605 def run(self, filename, targets=None, block=None):
605 def run(self, filename, targets=None, block=None):
606 """Execute contents of `filename` on my engine(s).
606 """Execute contents of `filename` on my engine(s).
607
607
608 This simply reads the contents of the file and calls `execute`.
608 This simply reads the contents of the file and calls `execute`.
609
609
610 Parameters
610 Parameters
611 ----------
611 ----------
612
612
613 filename : str
613 filename : str
614 The path to the file
614 The path to the file
615 targets : int/str/list of ints/strs
615 targets : int/str/list of ints/strs
616 the engines on which to execute
616 the engines on which to execute
617 default : all
617 default : all
618 block : bool
618 block : bool
619 whether or not to wait until done
619 whether or not to wait until done
620 default: self.block
620 default: self.block
621
621
622 """
622 """
623 with open(filename, 'r') as f:
623 with open(filename, 'r') as f:
624 # add newline in case of trailing indented whitespace
624 # add newline in case of trailing indented whitespace
625 # which will cause SyntaxError
625 # which will cause SyntaxError
626 code = f.read()+'\n'
626 code = f.read()+'\n'
627 return self.execute(code, block=block, targets=targets)
627 return self.execute(code, block=block, targets=targets)
628
628
629 def update(self, ns):
629 def update(self, ns):
630 """update remote namespace with dict `ns`
630 """update remote namespace with dict `ns`
631
631
632 See `push` for details.
632 See `push` for details.
633 """
633 """
634 return self.push(ns, block=self.block, track=self.track)
634 return self.push(ns, block=self.block, track=self.track)
635
635
636 def push(self, ns, targets=None, block=None, track=None):
636 def push(self, ns, targets=None, block=None, track=None):
637 """update remote namespace with dict `ns`
637 """update remote namespace with dict `ns`
638
638
639 Parameters
639 Parameters
640 ----------
640 ----------
641
641
642 ns : dict
642 ns : dict
643 dict of keys with which to update engine namespace(s)
643 dict of keys with which to update engine namespace(s)
644 block : bool [default : self.block]
644 block : bool [default : self.block]
645 whether to wait to be notified of engine receipt
645 whether to wait to be notified of engine receipt
646
646
647 """
647 """
648
648
649 block = block if block is not None else self.block
649 block = block if block is not None else self.block
650 track = track if track is not None else self.track
650 track = track if track is not None else self.track
651 targets = targets if targets is not None else self.targets
651 targets = targets if targets is not None else self.targets
652 # applier = self.apply_sync if block else self.apply_async
652 # applier = self.apply_sync if block else self.apply_async
653 if not isinstance(ns, dict):
653 if not isinstance(ns, dict):
654 raise TypeError("Must be a dict, not %s"%type(ns))
654 raise TypeError("Must be a dict, not %s"%type(ns))
655 return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets)
655 return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets)
656
656
657 def get(self, key_s):
657 def get(self, key_s):
658 """get object(s) by `key_s` from remote namespace
658 """get object(s) by `key_s` from remote namespace
659
659
660 see `pull` for details.
660 see `pull` for details.
661 """
661 """
662 # block = block if block is not None else self.block
662 # block = block if block is not None else self.block
663 return self.pull(key_s, block=True)
663 return self.pull(key_s, block=True)
664
664
665 def pull(self, names, targets=None, block=None):
665 def pull(self, names, targets=None, block=None):
666 """get object(s) by `name` from remote namespace
666 """get object(s) by `name` from remote namespace
667
667
668 will return one object if it is a key.
668 will return one object if it is a key.
669 can also take a list of keys, in which case it will return a list of objects.
669 can also take a list of keys, in which case it will return a list of objects.
670 """
670 """
671 block = block if block is not None else self.block
671 block = block if block is not None else self.block
672 targets = targets if targets is not None else self.targets
672 targets = targets if targets is not None else self.targets
673 applier = self.apply_sync if block else self.apply_async
673 applier = self.apply_sync if block else self.apply_async
674 if isinstance(names, basestring):
674 if isinstance(names, basestring):
675 pass
675 pass
676 elif isinstance(names, (list,tuple,set)):
676 elif isinstance(names, (list,tuple,set)):
677 for key in names:
677 for key in names:
678 if not isinstance(key, basestring):
678 if not isinstance(key, basestring):
679 raise TypeError("keys must be str, not type %r"%type(key))
679 raise TypeError("keys must be str, not type %r"%type(key))
680 else:
680 else:
681 raise TypeError("names must be strs, not %r"%names)
681 raise TypeError("names must be strs, not %r"%names)
682 return self._really_apply(util._pull, (names,), block=block, targets=targets)
682 return self._really_apply(util._pull, (names,), block=block, targets=targets)
683
683
684 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
684 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
685 """
685 """
686 Partition a Python sequence and send the partitions to a set of engines.
686 Partition a Python sequence and send the partitions to a set of engines.
687 """
687 """
688 block = block if block is not None else self.block
688 block = block if block is not None else self.block
689 track = track if track is not None else self.track
689 track = track if track is not None else self.track
690 targets = targets if targets is not None else self.targets
690 targets = targets if targets is not None else self.targets
691
691
692 mapObject = Map.dists[dist]()
692 mapObject = Map.dists[dist]()
693 nparts = len(targets)
693 nparts = len(targets)
694 msg_ids = []
694 msg_ids = []
695 trackers = []
695 trackers = []
696 for index, engineid in enumerate(targets):
696 for index, engineid in enumerate(targets):
697 partition = mapObject.getPartition(seq, index, nparts)
697 partition = mapObject.getPartition(seq, index, nparts)
698 if flatten and len(partition) == 1:
698 if flatten and len(partition) == 1:
699 ns = {key: partition[0]}
699 ns = {key: partition[0]}
700 else:
700 else:
701 ns = {key: partition}
701 ns = {key: partition}
702 r = self.push(ns, block=False, track=track, targets=engineid)
702 r = self.push(ns, block=False, track=track, targets=engineid)
703 msg_ids.extend(r.msg_ids)
703 msg_ids.extend(r.msg_ids)
704 if track:
704 if track:
705 trackers.append(r._tracker)
705 trackers.append(r._tracker)
706
706
707 if track:
707 if track:
708 tracker = zmq.MessageTracker(*trackers)
708 tracker = zmq.MessageTracker(*trackers)
709 else:
709 else:
710 tracker = None
710 tracker = None
711
711
712 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
712 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
713 if block:
713 if block:
714 r.wait()
714 r.wait()
715 else:
715 else:
716 return r
716 return r
717
717
718 @sync_results
718 @sync_results
719 @save_ids
719 @save_ids
720 def gather(self, key, dist='b', targets=None, block=None):
720 def gather(self, key, dist='b', targets=None, block=None):
721 """
721 """
722 Gather a partitioned sequence on a set of engines as a single local seq.
722 Gather a partitioned sequence on a set of engines as a single local seq.
723 """
723 """
724 block = block if block is not None else self.block
724 block = block if block is not None else self.block
725 targets = targets if targets is not None else self.targets
725 targets = targets if targets is not None else self.targets
726 mapObject = Map.dists[dist]()
726 mapObject = Map.dists[dist]()
727 msg_ids = []
727 msg_ids = []
728
728
729 for index, engineid in enumerate(targets):
729 for index, engineid in enumerate(targets):
730 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
730 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
731
731
732 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
732 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
733
733
734 if block:
734 if block:
735 try:
735 try:
736 return r.get()
736 return r.get()
737 except KeyboardInterrupt:
737 except KeyboardInterrupt:
738 pass
738 pass
739 return r
739 return r
740
740
741 def __getitem__(self, key):
741 def __getitem__(self, key):
742 return self.get(key)
742 return self.get(key)
743
743
744 def __setitem__(self,key, value):
744 def __setitem__(self,key, value):
745 self.update({key:value})
745 self.update({key:value})
746
746
747 def clear(self, targets=None, block=False):
747 def clear(self, targets=None, block=False):
748 """Clear the remote namespaces on my engines."""
748 """Clear the remote namespaces on my engines."""
749 block = block if block is not None else self.block
749 block = block if block is not None else self.block
750 targets = targets if targets is not None else self.targets
750 targets = targets if targets is not None else self.targets
751 return self.client.clear(targets=targets, block=block)
751 return self.client.clear(targets=targets, block=block)
752
752
753 def kill(self, targets=None, block=True):
753 def kill(self, targets=None, block=True):
754 """Kill my engines."""
754 """Kill my engines."""
755 block = block if block is not None else self.block
755 block = block if block is not None else self.block
756 targets = targets if targets is not None else self.targets
756 targets = targets if targets is not None else self.targets
757 return self.client.kill(targets=targets, block=block)
757 return self.client.kill(targets=targets, block=block)
758
758
759 #----------------------------------------
759 #----------------------------------------
760 # activate for %px,%autopx magics
760 # activate for %px,%autopx magics
761 #----------------------------------------
761 #----------------------------------------
762 def activate(self):
762 def activate(self):
763 """Make this `View` active for parallel magic commands.
763 """Make this `View` active for parallel magic commands.
764
764
765 IPython has a magic command syntax to work with `MultiEngineClient` objects.
765 IPython has a magic command syntax to work with `MultiEngineClient` objects.
766 In a given IPython session there is a single active one. While
766 In a given IPython session there is a single active one. While
767 there can be many `Views` created and used by the user,
767 there can be many `Views` created and used by the user,
768 there is only one active one. The active `View` is used whenever
768 there is only one active one. The active `View` is used whenever
769 the magic commands %px and %autopx are used.
769 the magic commands %px and %autopx are used.
770
770
771 The activate() method is called on a given `View` to make it
771 The activate() method is called on a given `View` to make it
772 active. Once this has been done, the magic commands can be used.
772 active. Once this has been done, the magic commands can be used.
773 """
773 """
774
774
775 try:
775 try:
776 # This is injected into __builtins__.
776 # This is injected into __builtins__.
777 ip = get_ipython()
777 ip = get_ipython()
778 except NameError:
778 except NameError:
779 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
779 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
780 else:
780 else:
781 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
781 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
782 if pmagic is None:
782 if pmagic is None:
783 ip.magic_load_ext('parallelmagic')
783 ip.magic_load_ext('parallelmagic')
784 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
784 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
785
785
786 pmagic.active_view = self
786 pmagic.active_view = self
787
787
788
788
789 @skip_doctest
789 @skip_doctest
790 class LoadBalancedView(View):
790 class LoadBalancedView(View):
791 """An load-balancing View that only executes via the Task scheduler.
791 """An load-balancing View that only executes via the Task scheduler.
792
792
793 Load-balanced views can be created with the client's `view` method:
793 Load-balanced views can be created with the client's `view` method:
794
794
795 >>> v = client.load_balanced_view()
795 >>> v = client.load_balanced_view()
796
796
797 or targets can be specified, to restrict the potential destinations:
797 or targets can be specified, to restrict the potential destinations:
798
798
799 >>> v = client.client.load_balanced_view([1,3])
799 >>> v = client.client.load_balanced_view([1,3])
800
800
801 which would restrict loadbalancing to between engines 1 and 3.
801 which would restrict loadbalancing to between engines 1 and 3.
802
802
803 """
803 """
804
804
805 follow=Any()
805 follow=Any()
806 after=Any()
806 after=Any()
807 timeout=CFloat()
807 timeout=CFloat()
808 retries = Integer(0)
808 retries = Integer(0)
809
809
810 _task_scheme = Any()
810 _task_scheme = Any()
811 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
811 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
812
812
813 def __init__(self, client=None, socket=None, **flags):
813 def __init__(self, client=None, socket=None, **flags):
814 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
814 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
815 self._task_scheme=client._task_scheme
815 self._task_scheme=client._task_scheme
816
816
817 def _validate_dependency(self, dep):
817 def _validate_dependency(self, dep):
818 """validate a dependency.
818 """validate a dependency.
819
819
820 For use in `set_flags`.
820 For use in `set_flags`.
821 """
821 """
822 if dep is None or isinstance(dep, (basestring, AsyncResult, Dependency)):
822 if dep is None or isinstance(dep, (basestring, AsyncResult, Dependency)):
823 return True
823 return True
824 elif isinstance(dep, (list,set, tuple)):
824 elif isinstance(dep, (list,set, tuple)):
825 for d in dep:
825 for d in dep:
826 if not isinstance(d, (basestring, AsyncResult)):
826 if not isinstance(d, (basestring, AsyncResult)):
827 return False
827 return False
828 elif isinstance(dep, dict):
828 elif isinstance(dep, dict):
829 if set(dep.keys()) != set(Dependency().as_dict().keys()):
829 if set(dep.keys()) != set(Dependency().as_dict().keys()):
830 return False
830 return False
831 if not isinstance(dep['msg_ids'], list):
831 if not isinstance(dep['msg_ids'], list):
832 return False
832 return False
833 for d in dep['msg_ids']:
833 for d in dep['msg_ids']:
834 if not isinstance(d, basestring):
834 if not isinstance(d, basestring):
835 return False
835 return False
836 else:
836 else:
837 return False
837 return False
838
838
839 return True
839 return True
840
840
841 def _render_dependency(self, dep):
841 def _render_dependency(self, dep):
842 """helper for building jsonable dependencies from various input forms."""
842 """helper for building jsonable dependencies from various input forms."""
843 if isinstance(dep, Dependency):
843 if isinstance(dep, Dependency):
844 return dep.as_dict()
844 return dep.as_dict()
845 elif isinstance(dep, AsyncResult):
845 elif isinstance(dep, AsyncResult):
846 return dep.msg_ids
846 return dep.msg_ids
847 elif dep is None:
847 elif dep is None:
848 return []
848 return []
849 else:
849 else:
850 # pass to Dependency constructor
850 # pass to Dependency constructor
851 return list(Dependency(dep))
851 return list(Dependency(dep))
852
852
853 def set_flags(self, **kwargs):
853 def set_flags(self, **kwargs):
854 """set my attribute flags by keyword.
854 """set my attribute flags by keyword.
855
855
856 A View is a wrapper for the Client's apply method, but with attributes
856 A View is a wrapper for the Client's apply method, but with attributes
857 that specify keyword arguments, those attributes can be set by keyword
857 that specify keyword arguments, those attributes can be set by keyword
858 argument with this method.
858 argument with this method.
859
859
860 Parameters
860 Parameters
861 ----------
861 ----------
862
862
863 block : bool
863 block : bool
864 whether to wait for results
864 whether to wait for results
865 track : bool
865 track : bool
866 whether to create a MessageTracker to allow the user to
866 whether to create a MessageTracker to allow the user to
867 safely edit after arrays and buffers during non-copying
867 safely edit after arrays and buffers during non-copying
868 sends.
868 sends.
869
869
870 after : Dependency or collection of msg_ids
870 after : Dependency or collection of msg_ids
871 Only for load-balanced execution (targets=None)
871 Only for load-balanced execution (targets=None)
872 Specify a list of msg_ids as a time-based dependency.
872 Specify a list of msg_ids as a time-based dependency.
873 This job will only be run *after* the dependencies
873 This job will only be run *after* the dependencies
874 have been met.
874 have been met.
875
875
876 follow : Dependency or collection of msg_ids
876 follow : Dependency or collection of msg_ids
877 Only for load-balanced execution (targets=None)
877 Only for load-balanced execution (targets=None)
878 Specify a list of msg_ids as a location-based dependency.
878 Specify a list of msg_ids as a location-based dependency.
879 This job will only be run on an engine where this dependency
879 This job will only be run on an engine where this dependency
880 is met.
880 is met.
881
881
882 timeout : float/int or None
882 timeout : float/int or None
883 Only for load-balanced execution (targets=None)
883 Only for load-balanced execution (targets=None)
884 Specify an amount of time (in seconds) for the scheduler to
884 Specify an amount of time (in seconds) for the scheduler to
885 wait for dependencies to be met before failing with a
885 wait for dependencies to be met before failing with a
886 DependencyTimeout.
886 DependencyTimeout.
887
887
888 retries : int
888 retries : int
889 Number of times a task will be retried on failure.
889 Number of times a task will be retried on failure.
890 """
890 """
891
891
892 super(LoadBalancedView, self).set_flags(**kwargs)
892 super(LoadBalancedView, self).set_flags(**kwargs)
893 for name in ('follow', 'after'):
893 for name in ('follow', 'after'):
894 if name in kwargs:
894 if name in kwargs:
895 value = kwargs[name]
895 value = kwargs[name]
896 if self._validate_dependency(value):
896 if self._validate_dependency(value):
897 setattr(self, name, value)
897 setattr(self, name, value)
898 else:
898 else:
899 raise ValueError("Invalid dependency: %r"%value)
899 raise ValueError("Invalid dependency: %r"%value)
900 if 'timeout' in kwargs:
900 if 'timeout' in kwargs:
901 t = kwargs['timeout']
901 t = kwargs['timeout']
902 if not isinstance(t, (int, long, float, type(None))):
902 if not isinstance(t, (int, long, float, type(None))):
903 raise TypeError("Invalid type for timeout: %r"%type(t))
903 raise TypeError("Invalid type for timeout: %r"%type(t))
904 if t is not None:
904 if t is not None:
905 if t < 0:
905 if t < 0:
906 raise ValueError("Invalid timeout: %s"%t)
906 raise ValueError("Invalid timeout: %s"%t)
907 self.timeout = t
907 self.timeout = t
908
908
909 @sync_results
909 @sync_results
910 @save_ids
910 @save_ids
911 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
911 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
912 after=None, follow=None, timeout=None,
912 after=None, follow=None, timeout=None,
913 targets=None, retries=None):
913 targets=None, retries=None):
914 """calls f(*args, **kwargs) on a remote engine, returning the result.
914 """calls f(*args, **kwargs) on a remote engine, returning the result.
915
915
916 This method temporarily sets all of `apply`'s flags for a single call.
916 This method temporarily sets all of `apply`'s flags for a single call.
917
917
918 Parameters
918 Parameters
919 ----------
919 ----------
920
920
921 f : callable
921 f : callable
922
922
923 args : list [default: empty]
923 args : list [default: empty]
924
924
925 kwargs : dict [default: empty]
925 kwargs : dict [default: empty]
926
926
927 block : bool [default: self.block]
927 block : bool [default: self.block]
928 whether to block
928 whether to block
929 track : bool [default: self.track]
929 track : bool [default: self.track]
930 whether to ask zmq to track the message, for safe non-copying sends
930 whether to ask zmq to track the message, for safe non-copying sends
931
931
932 !!!!!! TODO: THE REST HERE !!!!
932 !!!!!! TODO: THE REST HERE !!!!
933
933
934 Returns
934 Returns
935 -------
935 -------
936
936
937 if self.block is False:
937 if self.block is False:
938 returns AsyncResult
938 returns AsyncResult
939 else:
939 else:
940 returns actual result of f(*args, **kwargs) on the engine(s)
940 returns actual result of f(*args, **kwargs) on the engine(s)
941 This will be a list of self.targets is also a list (even length 1), or
941 This will be a list of self.targets is also a list (even length 1), or
942 the single result if self.targets is an integer engine id
942 the single result if self.targets is an integer engine id
943 """
943 """
944
944
945 # validate whether we can run
945 # validate whether we can run
946 if self._socket.closed:
946 if self._socket.closed:
947 msg = "Task farming is disabled"
947 msg = "Task farming is disabled"
948 if self._task_scheme == 'pure':
948 if self._task_scheme == 'pure':
949 msg += " because the pure ZMQ scheduler cannot handle"
949 msg += " because the pure ZMQ scheduler cannot handle"
950 msg += " disappearing engines."
950 msg += " disappearing engines."
951 raise RuntimeError(msg)
951 raise RuntimeError(msg)
952
952
953 if self._task_scheme == 'pure':
953 if self._task_scheme == 'pure':
954 # pure zmq scheme doesn't support extra features
954 # pure zmq scheme doesn't support extra features
955 msg = "Pure ZMQ scheduler doesn't support the following flags:"
955 msg = "Pure ZMQ scheduler doesn't support the following flags:"
956 "follow, after, retries, targets, timeout"
956 "follow, after, retries, targets, timeout"
957 if (follow or after or retries or targets or timeout):
957 if (follow or after or retries or targets or timeout):
958 # hard fail on Scheduler flags
958 # hard fail on Scheduler flags
959 raise RuntimeError(msg)
959 raise RuntimeError(msg)
960 if isinstance(f, dependent):
960 if isinstance(f, dependent):
961 # soft warn on functional dependencies
961 # soft warn on functional dependencies
962 warnings.warn(msg, RuntimeWarning)
962 warnings.warn(msg, RuntimeWarning)
963
963
964 # build args
964 # build args
965 args = [] if args is None else args
965 args = [] if args is None else args
966 kwargs = {} if kwargs is None else kwargs
966 kwargs = {} if kwargs is None else kwargs
967 block = self.block if block is None else block
967 block = self.block if block is None else block
968 track = self.track if track is None else track
968 track = self.track if track is None else track
969 after = self.after if after is None else after
969 after = self.after if after is None else after
970 retries = self.retries if retries is None else retries
970 retries = self.retries if retries is None else retries
971 follow = self.follow if follow is None else follow
971 follow = self.follow if follow is None else follow
972 timeout = self.timeout if timeout is None else timeout
972 timeout = self.timeout if timeout is None else timeout
973 targets = self.targets if targets is None else targets
973 targets = self.targets if targets is None else targets
974
974
975 if not isinstance(retries, int):
975 if not isinstance(retries, int):
976 raise TypeError('retries must be int, not %r'%type(retries))
976 raise TypeError('retries must be int, not %r'%type(retries))
977
977
978 if targets is None:
978 if targets is None:
979 idents = []
979 idents = []
980 else:
980 else:
981 idents = self.client._build_targets(targets)[0]
981 idents = self.client._build_targets(targets)[0]
982 # ensure *not* bytes
982 # ensure *not* bytes
983 idents = [ ident.decode() for ident in idents ]
983 idents = [ ident.decode() for ident in idents ]
984
984
985 after = self._render_dependency(after)
985 after = self._render_dependency(after)
986 follow = self._render_dependency(follow)
986 follow = self._render_dependency(follow)
987 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
987 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
988
988
989 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
989 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
990 subheader=subheader)
990 subheader=subheader)
991 tracker = None if track is False else msg['tracker']
991 tracker = None if track is False else msg['tracker']
992
992
993 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=f.__name__, targets=None, tracker=tracker)
993 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)
994
994
995 if block:
995 if block:
996 try:
996 try:
997 return ar.get()
997 return ar.get()
998 except KeyboardInterrupt:
998 except KeyboardInterrupt:
999 pass
999 pass
1000 return ar
1000 return ar
1001
1001
1002 @spin_after
1002 @spin_after
1003 @save_ids
1003 @save_ids
1004 def map(self, f, *sequences, **kwargs):
1004 def map(self, f, *sequences, **kwargs):
1005 """view.map(f, *sequences, block=self.block, chunksize=1, ordered=True) => list|AsyncMapResult
1005 """view.map(f, *sequences, block=self.block, chunksize=1, ordered=True) => list|AsyncMapResult
1006
1006
1007 Parallel version of builtin `map`, load-balanced by this View.
1007 Parallel version of builtin `map`, load-balanced by this View.
1008
1008
1009 `block`, and `chunksize` can be specified by keyword only.
1009 `block`, and `chunksize` can be specified by keyword only.
1010
1010
1011 Each `chunksize` elements will be a separate task, and will be
1011 Each `chunksize` elements will be a separate task, and will be
1012 load-balanced. This lets individual elements be available for iteration
1012 load-balanced. This lets individual elements be available for iteration
1013 as soon as they arrive.
1013 as soon as they arrive.
1014
1014
1015 Parameters
1015 Parameters
1016 ----------
1016 ----------
1017
1017
1018 f : callable
1018 f : callable
1019 function to be mapped
1019 function to be mapped
1020 *sequences: one or more sequences of matching length
1020 *sequences: one or more sequences of matching length
1021 the sequences to be distributed and passed to `f`
1021 the sequences to be distributed and passed to `f`
1022 block : bool [default self.block]
1022 block : bool [default self.block]
1023 whether to wait for the result or not
1023 whether to wait for the result or not
1024 track : bool
1024 track : bool
1025 whether to create a MessageTracker to allow the user to
1025 whether to create a MessageTracker to allow the user to
1026 safely edit after arrays and buffers during non-copying
1026 safely edit after arrays and buffers during non-copying
1027 sends.
1027 sends.
1028 chunksize : int [default 1]
1028 chunksize : int [default 1]
1029 how many elements should be in each task.
1029 how many elements should be in each task.
1030 ordered : bool [default True]
1030 ordered : bool [default True]
1031 Whether the results should be gathered as they arrive, or enforce
1031 Whether the results should be gathered as they arrive, or enforce
1032 the order of submission.
1032 the order of submission.
1033
1033
1034 Only applies when iterating through AsyncMapResult as results arrive.
1034 Only applies when iterating through AsyncMapResult as results arrive.
1035 Has no effect when block=True.
1035 Has no effect when block=True.
1036
1036
1037 Returns
1037 Returns
1038 -------
1038 -------
1039
1039
1040 if block=False:
1040 if block=False:
1041 AsyncMapResult
1041 AsyncMapResult
1042 An object like AsyncResult, but which reassembles the sequence of results
1042 An object like AsyncResult, but which reassembles the sequence of results
1043 into a single list. AsyncMapResults can be iterated through before all
1043 into a single list. AsyncMapResults can be iterated through before all
1044 results are complete.
1044 results are complete.
1045 else:
1045 else:
1046 the result of map(f,*sequences)
1046 the result of map(f,*sequences)
1047
1047
1048 """
1048 """
1049
1049
1050 # default
1050 # default
1051 block = kwargs.get('block', self.block)
1051 block = kwargs.get('block', self.block)
1052 chunksize = kwargs.get('chunksize', 1)
1052 chunksize = kwargs.get('chunksize', 1)
1053 ordered = kwargs.get('ordered', True)
1053 ordered = kwargs.get('ordered', True)
1054
1054
1055 keyset = set(kwargs.keys())
1055 keyset = set(kwargs.keys())
1056 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1056 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1057 if extra_keys:
1057 if extra_keys:
1058 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1058 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1059
1059
1060 assert len(sequences) > 0, "must have some sequences to map onto!"
1060 assert len(sequences) > 0, "must have some sequences to map onto!"
1061
1061
1062 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1062 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1063 return pf.map(*sequences)
1063 return pf.map(*sequences)
1064
1064
1065 __all__ = ['LoadBalancedView', 'DirectView']
1065 __all__ = ['LoadBalancedView', 'DirectView']
@@ -1,472 +1,493
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import time
20 import time
21 from tempfile import mktemp
21 from tempfile import mktemp
22 from StringIO import StringIO
22 from StringIO import StringIO
23
23
24 import zmq
24 import zmq
25 from nose import SkipTest
25 from nose import SkipTest
26
26
27 from IPython.testing import decorators as dec
27 from IPython.testing import decorators as dec
28
28
29 from IPython import parallel as pmod
29 from IPython import parallel as pmod
30 from IPython.parallel import error
30 from IPython.parallel import error
31 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
31 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
32 from IPython.parallel import DirectView
32 from IPython.parallel import DirectView
33 from IPython.parallel.util import interactive
33 from IPython.parallel.util import interactive
34
34
35 from IPython.parallel.tests import add_engines
35 from IPython.parallel.tests import add_engines
36
36
37 from .clienttest import ClusterTestCase, crash, wait, skip_without
37 from .clienttest import ClusterTestCase, crash, wait, skip_without
38
38
39 def setup():
39 def setup():
40 add_engines(3)
40 add_engines(3)
41
41
42 class TestView(ClusterTestCase):
42 class TestView(ClusterTestCase):
43
43
44 def test_z_crash_mux(self):
44 def test_z_crash_mux(self):
45 """test graceful handling of engine death (direct)"""
45 """test graceful handling of engine death (direct)"""
46 raise SkipTest("crash tests disabled, due to undesirable crash reports")
46 raise SkipTest("crash tests disabled, due to undesirable crash reports")
47 # self.add_engines(1)
47 # self.add_engines(1)
48 eid = self.client.ids[-1]
48 eid = self.client.ids[-1]
49 ar = self.client[eid].apply_async(crash)
49 ar = self.client[eid].apply_async(crash)
50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
51 eid = ar.engine_id
51 eid = ar.engine_id
52 tic = time.time()
52 tic = time.time()
53 while eid in self.client.ids and time.time()-tic < 5:
53 while eid in self.client.ids and time.time()-tic < 5:
54 time.sleep(.01)
54 time.sleep(.01)
55 self.client.spin()
55 self.client.spin()
56 self.assertFalse(eid in self.client.ids, "Engine should have died")
56 self.assertFalse(eid in self.client.ids, "Engine should have died")
57
57
58 def test_push_pull(self):
58 def test_push_pull(self):
59 """test pushing and pulling"""
59 """test pushing and pulling"""
60 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
60 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
61 t = self.client.ids[-1]
61 t = self.client.ids[-1]
62 v = self.client[t]
62 v = self.client[t]
63 push = v.push
63 push = v.push
64 pull = v.pull
64 pull = v.pull
65 v.block=True
65 v.block=True
66 nengines = len(self.client)
66 nengines = len(self.client)
67 push({'data':data})
67 push({'data':data})
68 d = pull('data')
68 d = pull('data')
69 self.assertEquals(d, data)
69 self.assertEquals(d, data)
70 self.client[:].push({'data':data})
70 self.client[:].push({'data':data})
71 d = self.client[:].pull('data', block=True)
71 d = self.client[:].pull('data', block=True)
72 self.assertEquals(d, nengines*[data])
72 self.assertEquals(d, nengines*[data])
73 ar = push({'data':data}, block=False)
73 ar = push({'data':data}, block=False)
74 self.assertTrue(isinstance(ar, AsyncResult))
74 self.assertTrue(isinstance(ar, AsyncResult))
75 r = ar.get()
75 r = ar.get()
76 ar = self.client[:].pull('data', block=False)
76 ar = self.client[:].pull('data', block=False)
77 self.assertTrue(isinstance(ar, AsyncResult))
77 self.assertTrue(isinstance(ar, AsyncResult))
78 r = ar.get()
78 r = ar.get()
79 self.assertEquals(r, nengines*[data])
79 self.assertEquals(r, nengines*[data])
80 self.client[:].push(dict(a=10,b=20))
80 self.client[:].push(dict(a=10,b=20))
81 r = self.client[:].pull(('a','b'), block=True)
81 r = self.client[:].pull(('a','b'), block=True)
82 self.assertEquals(r, nengines*[[10,20]])
82 self.assertEquals(r, nengines*[[10,20]])
83
83
84 def test_push_pull_function(self):
84 def test_push_pull_function(self):
85 "test pushing and pulling functions"
85 "test pushing and pulling functions"
86 def testf(x):
86 def testf(x):
87 return 2.0*x
87 return 2.0*x
88
88
89 t = self.client.ids[-1]
89 t = self.client.ids[-1]
90 v = self.client[t]
90 v = self.client[t]
91 v.block=True
91 v.block=True
92 push = v.push
92 push = v.push
93 pull = v.pull
93 pull = v.pull
94 execute = v.execute
94 execute = v.execute
95 push({'testf':testf})
95 push({'testf':testf})
96 r = pull('testf')
96 r = pull('testf')
97 self.assertEqual(r(1.0), testf(1.0))
97 self.assertEqual(r(1.0), testf(1.0))
98 execute('r = testf(10)')
98 execute('r = testf(10)')
99 r = pull('r')
99 r = pull('r')
100 self.assertEquals(r, testf(10))
100 self.assertEquals(r, testf(10))
101 ar = self.client[:].push({'testf':testf}, block=False)
101 ar = self.client[:].push({'testf':testf}, block=False)
102 ar.get()
102 ar.get()
103 ar = self.client[:].pull('testf', block=False)
103 ar = self.client[:].pull('testf', block=False)
104 rlist = ar.get()
104 rlist = ar.get()
105 for r in rlist:
105 for r in rlist:
106 self.assertEqual(r(1.0), testf(1.0))
106 self.assertEqual(r(1.0), testf(1.0))
107 execute("def g(x): return x*x")
107 execute("def g(x): return x*x")
108 r = pull(('testf','g'))
108 r = pull(('testf','g'))
109 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
109 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
110
110
111 def test_push_function_globals(self):
111 def test_push_function_globals(self):
112 """test that pushed functions have access to globals"""
112 """test that pushed functions have access to globals"""
113 @interactive
113 @interactive
114 def geta():
114 def geta():
115 return a
115 return a
116 # self.add_engines(1)
116 # self.add_engines(1)
117 v = self.client[-1]
117 v = self.client[-1]
118 v.block=True
118 v.block=True
119 v['f'] = geta
119 v['f'] = geta
120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
121 v.execute('a=5')
121 v.execute('a=5')
122 v.execute('b=f()')
122 v.execute('b=f()')
123 self.assertEquals(v['b'], 5)
123 self.assertEquals(v['b'], 5)
124
124
125 def test_push_function_defaults(self):
125 def test_push_function_defaults(self):
126 """test that pushed functions preserve default args"""
126 """test that pushed functions preserve default args"""
127 def echo(a=10):
127 def echo(a=10):
128 return a
128 return a
129 v = self.client[-1]
129 v = self.client[-1]
130 v.block=True
130 v.block=True
131 v['f'] = echo
131 v['f'] = echo
132 v.execute('b=f()')
132 v.execute('b=f()')
133 self.assertEquals(v['b'], 10)
133 self.assertEquals(v['b'], 10)
134
134
135 def test_get_result(self):
135 def test_get_result(self):
136 """test getting results from the Hub."""
136 """test getting results from the Hub."""
137 c = pmod.Client(profile='iptest')
137 c = pmod.Client(profile='iptest')
138 # self.add_engines(1)
138 # self.add_engines(1)
139 t = c.ids[-1]
139 t = c.ids[-1]
140 v = c[t]
140 v = c[t]
141 v2 = self.client[t]
141 v2 = self.client[t]
142 ar = v.apply_async(wait, 1)
142 ar = v.apply_async(wait, 1)
143 # give the monitor time to notice the message
143 # give the monitor time to notice the message
144 time.sleep(.25)
144 time.sleep(.25)
145 ahr = v2.get_result(ar.msg_ids)
145 ahr = v2.get_result(ar.msg_ids)
146 self.assertTrue(isinstance(ahr, AsyncHubResult))
146 self.assertTrue(isinstance(ahr, AsyncHubResult))
147 self.assertEquals(ahr.get(), ar.get())
147 self.assertEquals(ahr.get(), ar.get())
148 ar2 = v2.get_result(ar.msg_ids)
148 ar2 = v2.get_result(ar.msg_ids)
149 self.assertFalse(isinstance(ar2, AsyncHubResult))
149 self.assertFalse(isinstance(ar2, AsyncHubResult))
150 c.spin()
150 c.spin()
151 c.close()
151 c.close()
152
152
153 def test_run_newline(self):
153 def test_run_newline(self):
154 """test that run appends newline to files"""
154 """test that run appends newline to files"""
155 tmpfile = mktemp()
155 tmpfile = mktemp()
156 with open(tmpfile, 'w') as f:
156 with open(tmpfile, 'w') as f:
157 f.write("""def g():
157 f.write("""def g():
158 return 5
158 return 5
159 """)
159 """)
160 v = self.client[-1]
160 v = self.client[-1]
161 v.run(tmpfile, block=True)
161 v.run(tmpfile, block=True)
162 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
162 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
163
163
164 def test_apply_tracked(self):
164 def test_apply_tracked(self):
165 """test tracking for apply"""
165 """test tracking for apply"""
166 # self.add_engines(1)
166 # self.add_engines(1)
167 t = self.client.ids[-1]
167 t = self.client.ids[-1]
168 v = self.client[t]
168 v = self.client[t]
169 v.block=False
169 v.block=False
170 def echo(n=1024*1024, **kwargs):
170 def echo(n=1024*1024, **kwargs):
171 with v.temp_flags(**kwargs):
171 with v.temp_flags(**kwargs):
172 return v.apply(lambda x: x, 'x'*n)
172 return v.apply(lambda x: x, 'x'*n)
173 ar = echo(1, track=False)
173 ar = echo(1, track=False)
174 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
174 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
175 self.assertTrue(ar.sent)
175 self.assertTrue(ar.sent)
176 ar = echo(track=True)
176 ar = echo(track=True)
177 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
177 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
178 self.assertEquals(ar.sent, ar._tracker.done)
178 self.assertEquals(ar.sent, ar._tracker.done)
179 ar._tracker.wait()
179 ar._tracker.wait()
180 self.assertTrue(ar.sent)
180 self.assertTrue(ar.sent)
181
181
182 def test_push_tracked(self):
182 def test_push_tracked(self):
183 t = self.client.ids[-1]
183 t = self.client.ids[-1]
184 ns = dict(x='x'*1024*1024)
184 ns = dict(x='x'*1024*1024)
185 v = self.client[t]
185 v = self.client[t]
186 ar = v.push(ns, block=False, track=False)
186 ar = v.push(ns, block=False, track=False)
187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 self.assertTrue(ar.sent)
188 self.assertTrue(ar.sent)
189
189
190 ar = v.push(ns, block=False, track=True)
190 ar = v.push(ns, block=False, track=True)
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 ar._tracker.wait()
192 ar._tracker.wait()
193 self.assertEquals(ar.sent, ar._tracker.done)
193 self.assertEquals(ar.sent, ar._tracker.done)
194 self.assertTrue(ar.sent)
194 self.assertTrue(ar.sent)
195 ar.get()
195 ar.get()
196
196
197 def test_scatter_tracked(self):
197 def test_scatter_tracked(self):
198 t = self.client.ids
198 t = self.client.ids
199 x='x'*1024*1024
199 x='x'*1024*1024
200 ar = self.client[t].scatter('x', x, block=False, track=False)
200 ar = self.client[t].scatter('x', x, block=False, track=False)
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 self.assertTrue(ar.sent)
202 self.assertTrue(ar.sent)
203
203
204 ar = self.client[t].scatter('x', x, block=False, track=True)
204 ar = self.client[t].scatter('x', x, block=False, track=True)
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 self.assertEquals(ar.sent, ar._tracker.done)
206 self.assertEquals(ar.sent, ar._tracker.done)
207 ar._tracker.wait()
207 ar._tracker.wait()
208 self.assertTrue(ar.sent)
208 self.assertTrue(ar.sent)
209 ar.get()
209 ar.get()
210
210
211 def test_remote_reference(self):
211 def test_remote_reference(self):
212 v = self.client[-1]
212 v = self.client[-1]
213 v['a'] = 123
213 v['a'] = 123
214 ra = pmod.Reference('a')
214 ra = pmod.Reference('a')
215 b = v.apply_sync(lambda x: x, ra)
215 b = v.apply_sync(lambda x: x, ra)
216 self.assertEquals(b, 123)
216 self.assertEquals(b, 123)
217
217
218
218
219 def test_scatter_gather(self):
219 def test_scatter_gather(self):
220 view = self.client[:]
220 view = self.client[:]
221 seq1 = range(16)
221 seq1 = range(16)
222 view.scatter('a', seq1)
222 view.scatter('a', seq1)
223 seq2 = view.gather('a', block=True)
223 seq2 = view.gather('a', block=True)
224 self.assertEquals(seq2, seq1)
224 self.assertEquals(seq2, seq1)
225 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
225 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
226
226
227 @skip_without('numpy')
227 @skip_without('numpy')
228 def test_scatter_gather_numpy(self):
228 def test_scatter_gather_numpy(self):
229 import numpy
229 import numpy
230 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
230 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
231 view = self.client[:]
231 view = self.client[:]
232 a = numpy.arange(64)
232 a = numpy.arange(64)
233 view.scatter('a', a)
233 view.scatter('a', a)
234 b = view.gather('a', block=True)
234 b = view.gather('a', block=True)
235 assert_array_equal(b, a)
235 assert_array_equal(b, a)
236
236
237 def test_map(self):
237 def test_map(self):
238 view = self.client[:]
238 view = self.client[:]
239 def f(x):
239 def f(x):
240 return x**2
240 return x**2
241 data = range(16)
241 data = range(16)
242 r = view.map_sync(f, data)
242 r = view.map_sync(f, data)
243 self.assertEquals(r, map(f, data))
243 self.assertEquals(r, map(f, data))
244
244
245 def test_map_iterable(self):
245 def test_map_iterable(self):
246 """test map on iterables (direct)"""
246 """test map on iterables (direct)"""
247 view = self.client[:]
247 view = self.client[:]
248 # 101 is prime, so it won't be evenly distributed
248 # 101 is prime, so it won't be evenly distributed
249 arr = range(101)
249 arr = range(101)
250 # ensure it will be an iterator, even in Python 3
250 # ensure it will be an iterator, even in Python 3
251 it = iter(arr)
251 it = iter(arr)
252 r = view.map_sync(lambda x:x, arr)
252 r = view.map_sync(lambda x:x, arr)
253 self.assertEquals(r, list(arr))
253 self.assertEquals(r, list(arr))
254
254
255 def test_scatterGatherNonblocking(self):
255 def test_scatterGatherNonblocking(self):
256 data = range(16)
256 data = range(16)
257 view = self.client[:]
257 view = self.client[:]
258 view.scatter('a', data, block=False)
258 view.scatter('a', data, block=False)
259 ar = view.gather('a', block=False)
259 ar = view.gather('a', block=False)
260 self.assertEquals(ar.get(), data)
260 self.assertEquals(ar.get(), data)
261
261
262 @skip_without('numpy')
262 @skip_without('numpy')
263 def test_scatter_gather_numpy_nonblocking(self):
263 def test_scatter_gather_numpy_nonblocking(self):
264 import numpy
264 import numpy
265 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
265 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
266 a = numpy.arange(64)
266 a = numpy.arange(64)
267 view = self.client[:]
267 view = self.client[:]
268 ar = view.scatter('a', a, block=False)
268 ar = view.scatter('a', a, block=False)
269 self.assertTrue(isinstance(ar, AsyncResult))
269 self.assertTrue(isinstance(ar, AsyncResult))
270 amr = view.gather('a', block=False)
270 amr = view.gather('a', block=False)
271 self.assertTrue(isinstance(amr, AsyncMapResult))
271 self.assertTrue(isinstance(amr, AsyncMapResult))
272 assert_array_equal(amr.get(), a)
272 assert_array_equal(amr.get(), a)
273
273
274 def test_execute(self):
274 def test_execute(self):
275 view = self.client[:]
275 view = self.client[:]
276 # self.client.debug=True
276 # self.client.debug=True
277 execute = view.execute
277 execute = view.execute
278 ar = execute('c=30', block=False)
278 ar = execute('c=30', block=False)
279 self.assertTrue(isinstance(ar, AsyncResult))
279 self.assertTrue(isinstance(ar, AsyncResult))
280 ar = execute('d=[0,1,2]', block=False)
280 ar = execute('d=[0,1,2]', block=False)
281 self.client.wait(ar, 1)
281 self.client.wait(ar, 1)
282 self.assertEquals(len(ar.get()), len(self.client))
282 self.assertEquals(len(ar.get()), len(self.client))
283 for c in view['c']:
283 for c in view['c']:
284 self.assertEquals(c, 30)
284 self.assertEquals(c, 30)
285
285
286 def test_abort(self):
286 def test_abort(self):
287 view = self.client[-1]
287 view = self.client[-1]
288 ar = view.execute('import time; time.sleep(1)', block=False)
288 ar = view.execute('import time; time.sleep(1)', block=False)
289 ar2 = view.apply_async(lambda : 2)
289 ar2 = view.apply_async(lambda : 2)
290 ar3 = view.apply_async(lambda : 3)
290 ar3 = view.apply_async(lambda : 3)
291 view.abort(ar2)
291 view.abort(ar2)
292 view.abort(ar3.msg_ids)
292 view.abort(ar3.msg_ids)
293 self.assertRaises(error.TaskAborted, ar2.get)
293 self.assertRaises(error.TaskAborted, ar2.get)
294 self.assertRaises(error.TaskAborted, ar3.get)
294 self.assertRaises(error.TaskAborted, ar3.get)
295
295
296 def test_abort_all(self):
296 def test_abort_all(self):
297 """view.abort() aborts all outstanding tasks"""
297 """view.abort() aborts all outstanding tasks"""
298 view = self.client[-1]
298 view = self.client[-1]
299 ars = [ view.apply_async(time.sleep, 1) for i in range(10) ]
299 ars = [ view.apply_async(time.sleep, 1) for i in range(10) ]
300 view.abort()
300 view.abort()
301 view.wait(timeout=5)
301 view.wait(timeout=5)
302 for ar in ars[5:]:
302 for ar in ars[5:]:
303 self.assertRaises(error.TaskAborted, ar.get)
303 self.assertRaises(error.TaskAborted, ar.get)
304
304
305 def test_temp_flags(self):
305 def test_temp_flags(self):
306 view = self.client[-1]
306 view = self.client[-1]
307 view.block=True
307 view.block=True
308 with view.temp_flags(block=False):
308 with view.temp_flags(block=False):
309 self.assertFalse(view.block)
309 self.assertFalse(view.block)
310 self.assertTrue(view.block)
310 self.assertTrue(view.block)
311
311
312 @dec.known_failure_py3
312 @dec.known_failure_py3
313 def test_importer(self):
313 def test_importer(self):
314 view = self.client[-1]
314 view = self.client[-1]
315 view.clear(block=True)
315 view.clear(block=True)
316 with view.importer:
316 with view.importer:
317 import re
317 import re
318
318
319 @interactive
319 @interactive
320 def findall(pat, s):
320 def findall(pat, s):
321 # this globals() step isn't necessary in real code
321 # this globals() step isn't necessary in real code
322 # only to prevent a closure in the test
322 # only to prevent a closure in the test
323 re = globals()['re']
323 re = globals()['re']
324 return re.findall(pat, s)
324 return re.findall(pat, s)
325
325
326 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
326 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
327
327
328 # parallel magic tests
328 # parallel magic tests
329
329
330 def test_magic_px_blocking(self):
330 def test_magic_px_blocking(self):
331 ip = get_ipython()
331 ip = get_ipython()
332 v = self.client[-1]
332 v = self.client[-1]
333 v.activate()
333 v.activate()
334 v.block=True
334 v.block=True
335
335
336 ip.magic_px('a=5')
336 ip.magic_px('a=5')
337 self.assertEquals(v['a'], 5)
337 self.assertEquals(v['a'], 5)
338 ip.magic_px('a=10')
338 ip.magic_px('a=10')
339 self.assertEquals(v['a'], 10)
339 self.assertEquals(v['a'], 10)
340 sio = StringIO()
340 sio = StringIO()
341 savestdout = sys.stdout
341 savestdout = sys.stdout
342 sys.stdout = sio
342 sys.stdout = sio
343 # just 'print a' worst ~99% of the time, but this ensures that
343 # just 'print a' worst ~99% of the time, but this ensures that
344 # the stdout message has arrived when the result is finished:
344 # the stdout message has arrived when the result is finished:
345 ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
345 ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
346 sys.stdout = savestdout
346 sys.stdout = savestdout
347 buf = sio.getvalue()
347 buf = sio.getvalue()
348 self.assertTrue('[stdout:' in buf, buf)
348 self.assertTrue('[stdout:' in buf, buf)
349 self.assertTrue(buf.rstrip().endswith('10'))
349 self.assertTrue(buf.rstrip().endswith('10'))
350 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
350 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
351
351
352 def test_magic_px_nonblocking(self):
352 def test_magic_px_nonblocking(self):
353 ip = get_ipython()
353 ip = get_ipython()
354 v = self.client[-1]
354 v = self.client[-1]
355 v.activate()
355 v.activate()
356 v.block=False
356 v.block=False
357
357
358 ip.magic_px('a=5')
358 ip.magic_px('a=5')
359 self.assertEquals(v['a'], 5)
359 self.assertEquals(v['a'], 5)
360 ip.magic_px('a=10')
360 ip.magic_px('a=10')
361 self.assertEquals(v['a'], 10)
361 self.assertEquals(v['a'], 10)
362 sio = StringIO()
362 sio = StringIO()
363 savestdout = sys.stdout
363 savestdout = sys.stdout
364 sys.stdout = sio
364 sys.stdout = sio
365 ip.magic_px('print a')
365 ip.magic_px('print a')
366 sys.stdout = savestdout
366 sys.stdout = savestdout
367 buf = sio.getvalue()
367 buf = sio.getvalue()
368 self.assertFalse('[stdout:%i]'%v.targets in buf)
368 self.assertFalse('[stdout:%i]'%v.targets in buf)
369 ip.magic_px('1/0')
369 ip.magic_px('1/0')
370 ar = v.get_result(-1)
370 ar = v.get_result(-1)
371 self.assertRaisesRemote(ZeroDivisionError, ar.get)
371 self.assertRaisesRemote(ZeroDivisionError, ar.get)
372
372
373 def test_magic_autopx_blocking(self):
373 def test_magic_autopx_blocking(self):
374 ip = get_ipython()
374 ip = get_ipython()
375 v = self.client[-1]
375 v = self.client[-1]
376 v.activate()
376 v.activate()
377 v.block=True
377 v.block=True
378
378
379 sio = StringIO()
379 sio = StringIO()
380 savestdout = sys.stdout
380 savestdout = sys.stdout
381 sys.stdout = sio
381 sys.stdout = sio
382 ip.magic_autopx()
382 ip.magic_autopx()
383 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
383 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
384 ip.run_cell('print b')
384 ip.run_cell('print b')
385 ip.run_cell("b/c")
385 ip.run_cell("b/c")
386 ip.run_code(compile('b*=2', '', 'single'))
386 ip.run_code(compile('b*=2', '', 'single'))
387 ip.magic_autopx()
387 ip.magic_autopx()
388 sys.stdout = savestdout
388 sys.stdout = savestdout
389 output = sio.getvalue().strip()
389 output = sio.getvalue().strip()
390 self.assertTrue(output.startswith('%autopx enabled'))
390 self.assertTrue(output.startswith('%autopx enabled'))
391 self.assertTrue(output.endswith('%autopx disabled'))
391 self.assertTrue(output.endswith('%autopx disabled'))
392 self.assertTrue('RemoteError: ZeroDivisionError' in output)
392 self.assertTrue('RemoteError: ZeroDivisionError' in output)
393 ar = v.get_result(-2)
393 ar = v.get_result(-2)
394 self.assertEquals(v['a'], 5)
394 self.assertEquals(v['a'], 5)
395 self.assertEquals(v['b'], 20)
395 self.assertEquals(v['b'], 20)
396 self.assertRaisesRemote(ZeroDivisionError, ar.get)
396 self.assertRaisesRemote(ZeroDivisionError, ar.get)
397
397
398 def test_magic_autopx_nonblocking(self):
398 def test_magic_autopx_nonblocking(self):
399 ip = get_ipython()
399 ip = get_ipython()
400 v = self.client[-1]
400 v = self.client[-1]
401 v.activate()
401 v.activate()
402 v.block=False
402 v.block=False
403
403
404 sio = StringIO()
404 sio = StringIO()
405 savestdout = sys.stdout
405 savestdout = sys.stdout
406 sys.stdout = sio
406 sys.stdout = sio
407 ip.magic_autopx()
407 ip.magic_autopx()
408 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
408 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
409 ip.run_cell('print b')
409 ip.run_cell('print b')
410 ip.run_cell("b/c")
410 ip.run_cell("b/c")
411 ip.run_code(compile('b*=2', '', 'single'))
411 ip.run_code(compile('b*=2', '', 'single'))
412 ip.magic_autopx()
412 ip.magic_autopx()
413 sys.stdout = savestdout
413 sys.stdout = savestdout
414 output = sio.getvalue().strip()
414 output = sio.getvalue().strip()
415 self.assertTrue(output.startswith('%autopx enabled'))
415 self.assertTrue(output.startswith('%autopx enabled'))
416 self.assertTrue(output.endswith('%autopx disabled'))
416 self.assertTrue(output.endswith('%autopx disabled'))
417 self.assertFalse('ZeroDivisionError' in output)
417 self.assertFalse('ZeroDivisionError' in output)
418 ar = v.get_result(-2)
418 ar = v.get_result(-2)
419 self.assertEquals(v['a'], 5)
419 self.assertEquals(v['a'], 5)
420 self.assertEquals(v['b'], 20)
420 self.assertEquals(v['b'], 20)
421 self.assertRaisesRemote(ZeroDivisionError, ar.get)
421 self.assertRaisesRemote(ZeroDivisionError, ar.get)
422
422
423 def test_magic_result(self):
423 def test_magic_result(self):
424 ip = get_ipython()
424 ip = get_ipython()
425 v = self.client[-1]
425 v = self.client[-1]
426 v.activate()
426 v.activate()
427 v['a'] = 111
427 v['a'] = 111
428 ra = v['a']
428 ra = v['a']
429
429
430 ar = ip.magic_result()
430 ar = ip.magic_result()
431 self.assertEquals(ar.msg_ids, [v.history[-1]])
431 self.assertEquals(ar.msg_ids, [v.history[-1]])
432 self.assertEquals(ar.get(), 111)
432 self.assertEquals(ar.get(), 111)
433 ar = ip.magic_result('-2')
433 ar = ip.magic_result('-2')
434 self.assertEquals(ar.msg_ids, [v.history[-2]])
434 self.assertEquals(ar.msg_ids, [v.history[-2]])
435
435
436 def test_unicode_execute(self):
436 def test_unicode_execute(self):
437 """test executing unicode strings"""
437 """test executing unicode strings"""
438 v = self.client[-1]
438 v = self.client[-1]
439 v.block=True
439 v.block=True
440 if sys.version_info[0] >= 3:
440 if sys.version_info[0] >= 3:
441 code="a='é'"
441 code="a='é'"
442 else:
442 else:
443 code=u"a=u'é'"
443 code=u"a=u'é'"
444 v.execute(code)
444 v.execute(code)
445 self.assertEquals(v['a'], u'é')
445 self.assertEquals(v['a'], u'é')
446
446
447 def test_unicode_apply_result(self):
447 def test_unicode_apply_result(self):
448 """test unicode apply results"""
448 """test unicode apply results"""
449 v = self.client[-1]
449 v = self.client[-1]
450 r = v.apply_sync(lambda : u'é')
450 r = v.apply_sync(lambda : u'é')
451 self.assertEquals(r, u'é')
451 self.assertEquals(r, u'é')
452
452
453 def test_unicode_apply_arg(self):
453 def test_unicode_apply_arg(self):
454 """test passing unicode arguments to apply"""
454 """test passing unicode arguments to apply"""
455 v = self.client[-1]
455 v = self.client[-1]
456
456
457 @interactive
457 @interactive
458 def check_unicode(a, check):
458 def check_unicode(a, check):
459 assert isinstance(a, unicode), "%r is not unicode"%a
459 assert isinstance(a, unicode), "%r is not unicode"%a
460 assert isinstance(check, bytes), "%r is not bytes"%check
460 assert isinstance(check, bytes), "%r is not bytes"%check
461 assert a.encode('utf8') == check, "%s != %s"%(a,check)
461 assert a.encode('utf8') == check, "%s != %s"%(a,check)
462
462
463 for s in [ u'é', u'ßø®∫',u'asdf' ]:
463 for s in [ u'é', u'ßø®∫',u'asdf' ]:
464 try:
464 try:
465 v.apply_sync(check_unicode, s, s.encode('utf8'))
465 v.apply_sync(check_unicode, s, s.encode('utf8'))
466 except error.RemoteError as e:
466 except error.RemoteError as e:
467 if e.ename == 'AssertionError':
467 if e.ename == 'AssertionError':
468 self.fail(e.evalue)
468 self.fail(e.evalue)
469 else:
469 else:
470 raise e
470 raise e
471
471
472 def test_map_reference(self):
473 """view.map(<Reference>, *seqs) should work"""
474 v = self.client[:]
475 v.scatter('n', self.client.ids, flatten=True)
476 v.execute("f = lambda x,y: x*y")
477 rf = pmod.Reference('f')
478 nlist = list(range(10))
479 mlist = nlist[::-1]
480 expected = [ m*n for m,n in zip(mlist, nlist) ]
481 result = v.map_sync(rf, mlist, nlist)
482 self.assertEquals(result, expected)
483
484 def test_apply_reference(self):
485 """view.apply(<Reference>, *args) should work"""
486 v = self.client[:]
487 v.scatter('n', self.client.ids, flatten=True)
488 v.execute("f = lambda x: n*x")
489 rf = pmod.Reference('f')
490 result = v.apply_sync(rf, 5)
491 expected = [ 5*id for id in self.client.ids ]
492 self.assertEquals(result, expected)
472
493
General Comments 0
You need to be logged in to leave comments. Login now