##// END OF EJS Templates
add DirectView.importer contextmanager, demote targets to mutable flag...
MinRK -
Show More
@@ -1,23 +1,24 b''
1 """The IPython ZMQ-based parallel computing interface."""
1 """The IPython ZMQ-based parallel computing interface."""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2011 The IPython Development Team
3 # Copyright (C) 2011 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
import zmq

# NOTE: compare parsed version tuples, not raw strings. The previous check
# (`zmq.__version__ < '2.1.3'`) was a lexicographic string comparison, which
# misorders multi-digit components -- e.g. '13.0' < '2.1.3' is True as strings,
# so a sufficiently new pyzmq would wrongly fail this ImportError gate.
from distutils.version import LooseVersion

if LooseVersion(zmq.__version__) < LooseVersion('2.1.3'):
    raise ImportError("IPython.zmq.parallel requires pyzmq/0MQ >= 2.1.3, you appear to have %s"%zmq.__version__)

from .asyncresult import *
from .client import Client
from .dependency import *
from .remotefunction import *
from .view import *
24
@@ -1,1343 +1,1293 b''
1 """A semi-synchronous Client for the ZMQ cluster"""
1 """A semi-synchronous Client for the ZMQ cluster"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
3 # Copyright (C) 2010 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 import os
13 import os
14 import json
14 import json
15 import time
15 import time
16 import warnings
16 import warnings
17 from datetime import datetime
17 from datetime import datetime
18 from getpass import getpass
18 from getpass import getpass
19 from pprint import pprint
19 from pprint import pprint
20
20
21 pjoin = os.path.join
21 pjoin = os.path.join
22
22
23 import zmq
23 import zmq
24 # from zmq.eventloop import ioloop, zmqstream
24 # from zmq.eventloop import ioloop, zmqstream
25
25
26 from IPython.utils.path import get_ipython_dir
26 from IPython.utils.path import get_ipython_dir
27 from IPython.utils.pickleutil import Reference
27 from IPython.utils.pickleutil import Reference
28 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
28 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
29 Dict, List, Bool, Str, Set)
29 Dict, List, Bool, Str, Set)
30 from IPython.external.decorator import decorator
30 from IPython.external.decorator import decorator
31 from IPython.external.ssh import tunnel
31 from IPython.external.ssh import tunnel
32
32
33 from . import error
33 from . import error
34 from . import util
34 from . import util
35 from . import streamsession as ss
35 from . import streamsession as ss
36 from .asyncresult import AsyncResult, AsyncMapResult, AsyncHubResult
36 from .asyncresult import AsyncResult, AsyncMapResult, AsyncHubResult
37 from .clusterdir import ClusterDir, ClusterDirError
37 from .clusterdir import ClusterDir, ClusterDirError
38 from .dependency import Dependency, depend, require, dependent
38 from .dependency import Dependency, depend, require, dependent
39 from .remotefunction import remote, parallel, ParallelFunction, RemoteFunction
39 from .remotefunction import remote, parallel, ParallelFunction, RemoteFunction
40 from .view import DirectView, LoadBalancedView
40 from .view import DirectView, LoadBalancedView
41
41
42 #--------------------------------------------------------------------------
42 #--------------------------------------------------------------------------
43 # Decorators for Client methods
43 # Decorators for Client methods
44 #--------------------------------------------------------------------------
44 #--------------------------------------------------------------------------
45
45
@decorator
def spin_first(f, self, *args, **kwargs):
    """Call spin() to sync state prior to calling the method."""
    # Flush pending incoming results/registration messages so that `f`
    # operates on up-to-date client state.
    self.spin()
    return f(self, *args, **kwargs)
51
51
@decorator
def default_block(f, self, *args, **kwargs):
    """Default to self.block; preserve self.block."""
    # Resolve the effective blocking flag: an explicit block=... kwarg wins,
    # otherwise fall back to the client's current default.
    block = kwargs.get('block',None)
    block = self.block if block is None else block
    # Temporarily install the resolved flag for the duration of the call,
    # restoring the caller's setting even if `f` raises.
    saveblock = self.block
    self.block = block
    try:
        ret = f(self, *args, **kwargs)
    finally:
        self.block = saveblock
    return ret
64
64
65
65
66 #--------------------------------------------------------------------------
66 #--------------------------------------------------------------------------
67 # Classes
67 # Classes
68 #--------------------------------------------------------------------------
68 #--------------------------------------------------------------------------
69
69
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # Pre-populate the full, fixed key schema; later item assignment is
        # restricted to exactly these keys by __setitem__.
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
            }
        self.update(md)
        # NOTE: dict.update bypasses __setitem__ at the C level, so
        # constructor args are not key-checked -- same behavior as before.
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # `key in self` is the O(1) membership test; the previous
        # `key in self.iterkeys()` scanned an iterator (O(n)) and the
        # iterkeys method does not exist on Python 3 dicts.
        if key in self:
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        if key in self:
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self:
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
120
120
121
121
122 class Client(HasTraits):
122 class Client(HasTraits):
123 """A semi-synchronous client to the IPython ZMQ cluster
123 """A semi-synchronous client to the IPython ZMQ cluster
124
124
125 Parameters
125 Parameters
126 ----------
126 ----------
127
127
128 url_or_file : bytes; zmq url or path to ipcontroller-client.json
128 url_or_file : bytes; zmq url or path to ipcontroller-client.json
129 Connection information for the Hub's registration. If a json connector
129 Connection information for the Hub's registration. If a json connector
130 file is given, then likely no further configuration is necessary.
130 file is given, then likely no further configuration is necessary.
131 [Default: use profile]
131 [Default: use profile]
132 profile : bytes
132 profile : bytes
133 The name of the Cluster profile to be used to find connector information.
133 The name of the Cluster profile to be used to find connector information.
134 [Default: 'default']
134 [Default: 'default']
135 context : zmq.Context
135 context : zmq.Context
136 Pass an existing zmq.Context instance, otherwise the client will create its own.
136 Pass an existing zmq.Context instance, otherwise the client will create its own.
137 username : bytes
137 username : bytes
138 set username to be passed to the Session object
138 set username to be passed to the Session object
139 debug : bool
139 debug : bool
140 flag for lots of message printing for debug purposes
140 flag for lots of message printing for debug purposes
141
141
142 #-------------- ssh related args ----------------
142 #-------------- ssh related args ----------------
143 # These are args for configuring the ssh tunnel to be used
143 # These are args for configuring the ssh tunnel to be used
144 # credentials are used to forward connections over ssh to the Controller
144 # credentials are used to forward connections over ssh to the Controller
145 # Note that the ip given in `addr` needs to be relative to sshserver
145 # Note that the ip given in `addr` needs to be relative to sshserver
146 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
146 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
147 # and set sshserver as the same machine the Controller is on. However,
147 # and set sshserver as the same machine the Controller is on. However,
148 # the only requirement is that sshserver is able to see the Controller
148 # the only requirement is that sshserver is able to see the Controller
149 # (i.e. is within the same trusted network).
149 # (i.e. is within the same trusted network).
150
150
151 sshserver : str
151 sshserver : str
152 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
152 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
153 If keyfile or password is specified, and this is not, it will default to
153 If keyfile or password is specified, and this is not, it will default to
154 the ip given in addr.
154 the ip given in addr.
155 sshkey : str; path to public ssh key file
155 sshkey : str; path to public ssh key file
156 This specifies a key to be used in ssh login, default None.
156 This specifies a key to be used in ssh login, default None.
157 Regular default ssh keys will be used without specifying this argument.
157 Regular default ssh keys will be used without specifying this argument.
158 password : str
158 password : str
159 Your ssh password to sshserver. Note that if this is left None,
159 Your ssh password to sshserver. Note that if this is left None,
160 you will be prompted for it if passwordless key based login is unavailable.
160 you will be prompted for it if passwordless key based login is unavailable.
161 paramiko : bool
161 paramiko : bool
162 flag for whether to use paramiko instead of shell ssh for tunneling.
162 flag for whether to use paramiko instead of shell ssh for tunneling.
163 [default: True on win32, False else]
163 [default: True on win32, False else]
164
164
165 ------- exec authentication args -------
165 ------- exec authentication args -------
166 If even localhost is untrusted, you can have some protection against
166 If even localhost is untrusted, you can have some protection against
167 unauthorized execution by using a key. Messages are still sent
167 unauthorized execution by using a key. Messages are still sent
168 as cleartext, so if someone can snoop your loopback traffic this will
168 as cleartext, so if someone can snoop your loopback traffic this will
169 not help against malicious attacks.
169 not help against malicious attacks.
170
170
171 exec_key : str
171 exec_key : str
172 an authentication key or file containing a key
172 an authentication key or file containing a key
173 default: None
173 default: None
174
174
175
175
176 Attributes
176 Attributes
177 ----------
177 ----------
178
178
179 ids : list of int engine IDs
179 ids : list of int engine IDs
180 requesting the ids attribute always synchronizes
180 requesting the ids attribute always synchronizes
181 the registration state. To request ids without synchronization,
181 the registration state. To request ids without synchronization,
182 use semi-private _ids attributes.
182 use semi-private _ids attributes.
183
183
184 history : list of msg_ids
184 history : list of msg_ids
185 a list of msg_ids, keeping track of all the execution
185 a list of msg_ids, keeping track of all the execution
186 messages you have submitted in order.
186 messages you have submitted in order.
187
187
188 outstanding : set of msg_ids
188 outstanding : set of msg_ids
189 a set of msg_ids that have been submitted, but whose
189 a set of msg_ids that have been submitted, but whose
190 results have not yet been received.
190 results have not yet been received.
191
191
192 results : dict
192 results : dict
193 a dict of all our results, keyed by msg_id
193 a dict of all our results, keyed by msg_id
194
194
195 block : bool
195 block : bool
196 determines default behavior when block not specified
196 determines default behavior when block not specified
197 in execution methods
197 in execution methods
198
198
199 Methods
199 Methods
200 -------
200 -------
201
201
202 spin
202 spin
203 flushes incoming results and registration state changes
203 flushes incoming results and registration state changes
204 control methods spin, and requesting `ids` also ensures up to date
204 control methods spin, and requesting `ids` also ensures up to date
205
205
206 wait
206 wait
207 wait on one or more msg_ids
207 wait on one or more msg_ids
208
208
209 execution methods
209 execution methods
210 apply
210 apply
211 legacy: execute, run
211 legacy: execute, run
212
212
213 data movement
213 data movement
214 push, pull, scatter, gather
214 push, pull, scatter, gather
215
215
216 query methods
216 query methods
217 queue_status, get_result, purge, result_status
217 queue_status, get_result, purge, result_status
218
218
219 control methods
219 control methods
220 abort, shutdown
220 abort, shutdown
221
221
222 """
222 """
223
223
224
224
225 block = Bool(False)
225 block = Bool(False)
226 outstanding = Set()
226 outstanding = Set()
227 results = Instance('collections.defaultdict', (dict,))
227 results = Instance('collections.defaultdict', (dict,))
228 metadata = Instance('collections.defaultdict', (Metadata,))
228 metadata = Instance('collections.defaultdict', (Metadata,))
229 history = List()
229 history = List()
230 debug = Bool(False)
230 debug = Bool(False)
231 profile=CUnicode('default')
231 profile=CUnicode('default')
232
232
233 _outstanding_dict = Instance('collections.defaultdict', (set,))
233 _outstanding_dict = Instance('collections.defaultdict', (set,))
234 _ids = List()
234 _ids = List()
235 _connected=Bool(False)
235 _connected=Bool(False)
236 _ssh=Bool(False)
236 _ssh=Bool(False)
237 _context = Instance('zmq.Context')
237 _context = Instance('zmq.Context')
238 _config = Dict()
238 _config = Dict()
239 _engines=Instance(util.ReverseDict, (), {})
239 _engines=Instance(util.ReverseDict, (), {})
240 # _hub_socket=Instance('zmq.Socket')
240 # _hub_socket=Instance('zmq.Socket')
241 _query_socket=Instance('zmq.Socket')
241 _query_socket=Instance('zmq.Socket')
242 _control_socket=Instance('zmq.Socket')
242 _control_socket=Instance('zmq.Socket')
243 _iopub_socket=Instance('zmq.Socket')
243 _iopub_socket=Instance('zmq.Socket')
244 _notification_socket=Instance('zmq.Socket')
244 _notification_socket=Instance('zmq.Socket')
245 _mux_socket=Instance('zmq.Socket')
245 _mux_socket=Instance('zmq.Socket')
246 _task_socket=Instance('zmq.Socket')
246 _task_socket=Instance('zmq.Socket')
247 _task_scheme=Str()
247 _task_scheme=Str()
248 _balanced_views=Dict()
249 _direct_views=Dict()
250 _closed = False
248 _closed = False
251 _ignored_control_replies=Int(0)
249 _ignored_control_replies=Int(0)
252 _ignored_hub_replies=Int(0)
250 _ignored_hub_replies=Int(0)
253
251
    def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
            context=None, username=None, debug=False, exec_key=None,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            timeout=10
            ):
        """Connect to the Hub: locate connection info, optionally set up ssh
        tunnels, build the StreamSession and query socket, then register
        message handlers and call _connect()."""
        super(Client, self).__init__(debug=debug, profile=profile)
        if context is None:
            context = zmq.Context.instance()
        self._context = context


        # Resolve connection info: an explicit url/file wins, otherwise look
        # in the cluster profile's security dir for ipcontroller-client.json.
        self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
            " Please specify at least one of url_or_file or profile."

        # url_or_file may be a zmq url or a path to a json connector file;
        # validate_url raises AssertionError for non-urls, in which case we
        # treat the argument as a (possibly profile-relative) file path.
        try:
            util.validate_url(url_or_file)
        except AssertionError:
            if not os.path.exists(url_or_file):
                if self._cd:
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}

        # sync defaults from args, json:
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg['exec_key']
        sshserver=cfg['ssh']
        url = cfg['url']
        location = cfg.setdefault('location', None)
        # rewrite the url relative to `location` (e.g. replace 0.0.0.0 with a
        # reachable address)
        cfg['url'] = util.disambiguate_url(cfg['url'], location)
        url = cfg['url']

        self._config = cfg

        # Any ssh credential implies tunneling; default the ssh server to the
        # host portion of the controller url when not given explicitly.
        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
        # exec_key may be a literal key or a path to a keyfile on disk.
        if exec_key is not None and os.path.isfile(exec_key):
            arg = 'keyfile'
        else:
            arg = 'key'
        key_arg = {arg:exec_key}
        if username is None:
            self.session = ss.StreamSession(**key_arg)
        else:
            self.session = ss.StreamSession(username, **key_arg)
        self._query_socket = self._context.socket(zmq.XREQ)
        self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        # Dispatch tables for incoming hub messages, consumed by the
        # spin/flush machinery elsewhere in this class.
        self._notification_handlers = {'registration_notification' : self._register_engine,
                                    'unregistration_notification' : self._unregister_engine,
                                    'shutdown_notification' : lambda msg: self.close(),
                                    }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs, timeout)
333
331
    def __del__(self):
        """cleanup sockets, but _not_ context."""
        # The context may be shared (zmq.Context.instance()), so only our
        # own sockets are closed here.
        self.close()
337
335
    def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
        """Locate the ClusterDir and store it as self._cd.

        An explicit cluster_dir is tried first, then the named profile under
        ipython_dir. On failure self._cd is left as None, in which case the
        caller must supply url_or_file explicitly.
        """
        if ipython_dir is None:
            ipython_dir = get_ipython_dir()
        if cluster_dir is not None:
            try:
                self._cd = ClusterDir.find_cluster_dir(cluster_dir)
                return
            except ClusterDirError:
                pass
        elif profile is not None:
            try:
                self._cd = ClusterDir.find_cluster_dir_by_profile(
                    ipython_dir, profile)
                return
            except ClusterDirError:
                pass
        # lookup failed (or neither argument given): no cluster dir
        self._cd = None
355
353
    def _update_engines(self, engines):
        """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
        for k,v in engines.iteritems():
            eid = int(k)
            self._engines[eid] = bytes(v) # force not unicode
            self._ids.append(eid)
        self._ids = sorted(self._ids)
        # A pure-ZMQ task scheduler requires the engine ids to be exactly
        # 0..n-1 with no gaps; if that no longer holds, task farming over the
        # pure scheduler must be disabled.
        if sorted(self._engines.keys()) != range(len(self._engines)) and \
                        self._task_scheme == 'pure' and self._task_socket:
            self._stop_scheduling_tasks()
366
364
    def _stop_scheduling_tasks(self):
        """Stop scheduling tasks because an engine has been unregistered
        from a pure ZMQ scheduler.
        """
        # Close and drop the task socket; a None _task_socket is how the rest
        # of the class detects that task farming is disabled.
        self._task_socket.close()
        self._task_socket = None
        msg = "An engine has been unregistered, and we are using pure " +\
              "ZMQ task scheduling. Task farming will be disabled."
        if self.outstanding:
            msg += " If you were running tasks when this happened, " +\
                   "some `outstanding` msg_ids may never resolve."
        warnings.warn(msg, RuntimeWarning)
379
377
    def _build_targets(self, targets):
        """Turn valid target IDs or 'all' into two lists:
        (int_ids, uuids).

        Accepts None (all engines), 'all', a single int (negative allowed,
        indexing from the end), a slice, or a collection of ints.
        Raises TypeError/IndexError for anything else.
        """
        if targets is None:
            targets = self._ids
        elif isinstance(targets, str):
            if targets.lower() == 'all':
                targets = self._ids
            else:
                raise TypeError("%r not valid str target, must be 'all'"%(targets))
        elif isinstance(targets, int):
            # allow negative ints to index from the end of the id list,
            # like regular sequence indexing
            # NOTE(review): self.ids (unlike self._ids) presumably syncs
            # registration state first -- confirm against the ids property.
            if targets < 0:
                targets = self.ids[targets]
            if targets not in self.ids:
                raise IndexError("No such engine: %i"%targets)
            targets = [targets]

        if isinstance(targets, slice):
            # resolve the slice against positions, then map back to ids
            indices = range(len(self._ids))[targets]
            ids = self.ids
            targets = [ ids[i] for i in indices ]

        if not isinstance(targets, (tuple, list, xrange)):
            raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))

        return [self._engines[t] for t in targets], list(targets)
394
405
    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__."""

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            # helper: disambiguate the url for our location, then connect
            # either directly or through the ssh tunnel
            url = util.disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        # Ask the Hub for the full set of connection endpoints, waiting at
        # most `timeout` seconds for the reply.
        self.session.send(self._query_socket, 'connection_request')
        r,w,x = zmq.select([self._query_socket],[],[], timeout)
        if not r:
            raise error.TimeoutError("Hub connection request timed out")
        idents,msg = self.session.recv(self._query_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = ss.Message(msg)
        content = msg.content
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            # Connect each socket the Hub advertised; any endpoint may be
            # absent, in which case the corresponding socket stays None.
            if content.mux:
                self._mux_socket = self._context.socket(zmq.XREQ)
                self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._mux_socket, content.mux)
            if content.task:
                # task endpoint also carries the scheduler scheme name
                self._task_scheme, task_addr = content.task
                self._task_socket = self._context.socket(zmq.XREQ)
                self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._task_socket, task_addr)
            if content.notification:
                self._notification_socket = self._context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            # if content.query:
            #     self._query_socket = self._context.socket(zmq.XREQ)
            #     self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
            #     connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self._context.socket(zmq.XREQ)
                self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self._context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
                self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._iopub_socket, content.iopub)
            # seed our engine table from the Hub's current registration state
            self._update_engines(dict(content.engines))
        else:
            self._connected = False
            raise Exception("Failed to connect!")
452
463
453 #--------------------------------------------------------------------------
464 #--------------------------------------------------------------------------
454 # handlers and callbacks for incoming messages
465 # handlers and callbacks for incoming messages
455 #--------------------------------------------------------------------------
466 #--------------------------------------------------------------------------
456
467
457 def _unwrap_exception(self, content):
468 def _unwrap_exception(self, content):
458 """unwrap exception, and remap engine_id to int."""
469 """unwrap exception, and remap engine_id to int."""
459 e = error.unwrap_exception(content)
470 e = error.unwrap_exception(content)
460 # print e.traceback
471 # print e.traceback
461 if e.engine_info:
472 if e.engine_info:
462 e_uuid = e.engine_info['engine_uuid']
473 e_uuid = e.engine_info['engine_uuid']
463 eid = self._engines[e_uuid]
474 eid = self._engines[e_uuid]
464 e.engine_info['engine_id'] = eid
475 e.engine_info['engine_id'] = eid
465 return e
476 return e
466
477
467 def _extract_metadata(self, header, parent, content):
478 def _extract_metadata(self, header, parent, content):
468 md = {'msg_id' : parent['msg_id'],
479 md = {'msg_id' : parent['msg_id'],
469 'received' : datetime.now(),
480 'received' : datetime.now(),
470 'engine_uuid' : header.get('engine', None),
481 'engine_uuid' : header.get('engine', None),
471 'follow' : parent.get('follow', []),
482 'follow' : parent.get('follow', []),
472 'after' : parent.get('after', []),
483 'after' : parent.get('after', []),
473 'status' : content['status'],
484 'status' : content['status'],
474 }
485 }
475
486
476 if md['engine_uuid'] is not None:
487 if md['engine_uuid'] is not None:
477 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
488 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
478
489
479 if 'date' in parent:
490 if 'date' in parent:
480 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
491 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
481 if 'started' in header:
492 if 'started' in header:
482 md['started'] = datetime.strptime(header['started'], util.ISO8601)
493 md['started'] = datetime.strptime(header['started'], util.ISO8601)
483 if 'date' in header:
494 if 'date' in header:
484 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
495 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
485 return md
496 return md
486
497
487 def _register_engine(self, msg):
498 def _register_engine(self, msg):
488 """Register a new engine, and update our connection info."""
499 """Register a new engine, and update our connection info."""
489 content = msg['content']
500 content = msg['content']
490 eid = content['id']
501 eid = content['id']
491 d = {eid : content['queue']}
502 d = {eid : content['queue']}
492 self._update_engines(d)
503 self._update_engines(d)
493
504
    def _unregister_engine(self, msg):
        """Unregister an engine that has died."""
        content = msg['content']
        eid = int(content['id'])
        if eid in self._ids:
            self._ids.remove(eid)
            # drop the id->uuid mapping for the dead engine
            uuid = self._engines.pop(eid)

            # any tasks still outstanding on that engine will never return
            self._handle_stranded_msgs(eid, uuid)

        # the pure-ZMQ task scheduler cannot cope with engine death
        if self._task_socket and self._task_scheme == 'pure':
            self._stop_scheduling_tasks()
506
517
    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.
        """

        outstanding = self._outstanding_dict[uuid]

        # iterate a copy, since _handle_apply_reply mutates the set
        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already have the result; nothing stranded here
                continue
            try:
                # raise/except so wrap_exception can capture a real traceback
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake message:
            parent = {}
            header = {}
            parent['msg_id'] = msg_id
            header['engine'] = uuid
            header['date'] = datetime.now().strftime(util.ISO8601)
            msg = dict(parent_header=parent, header=header, content=content)
            # deliver the failure through the normal apply-reply path
            self._handle_apply_reply(msg)
533
544
534 def _handle_execute_reply(self, msg):
545 def _handle_execute_reply(self, msg):
535 """Save the reply to an execute_request into our results.
546 """Save the reply to an execute_request into our results.
536
547
537 execute messages are never actually used. apply is used instead.
548 execute messages are never actually used. apply is used instead.
538 """
549 """
539
550
540 parent = msg['parent_header']
551 parent = msg['parent_header']
541 msg_id = parent['msg_id']
552 msg_id = parent['msg_id']
542 if msg_id not in self.outstanding:
553 if msg_id not in self.outstanding:
543 if msg_id in self.history:
554 if msg_id in self.history:
544 print ("got stale result: %s"%msg_id)
555 print ("got stale result: %s"%msg_id)
545 else:
556 else:
546 print ("got unknown result: %s"%msg_id)
557 print ("got unknown result: %s"%msg_id)
547 else:
558 else:
548 self.outstanding.remove(msg_id)
559 self.outstanding.remove(msg_id)
549 self.results[msg_id] = self._unwrap_exception(msg['content'])
560 self.results[msg_id] = self._unwrap_exception(msg['content'])
550
561
    def _handle_apply_reply(self, msg):
        """Save the reply to an apply_request into our results."""
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            # reply for something we are not waiting on; log and fall through
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
                print self.results[msg_id]
                print msg
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
        content = msg['content']
        header = msg['header']

        # construct metadata:
        md = self.metadata[msg_id]
        md.update(self._extract_metadata(header, parent, content))
        # is this redundant?
        self.metadata[msg_id] = md

        # clear the per-engine outstanding record as well
        e_outstanding = self._outstanding_dict[md['engine_uuid']]
        if msg_id in e_outstanding:
            e_outstanding.remove(msg_id)

        # construct result:
        if content['status'] == 'ok':
            self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
        elif content['status'] == 'aborted':
            self.results[msg_id] = error.TaskAborted(msg_id)
        elif content['status'] == 'resubmitted':
            # TODO: handle resubmission
            pass
        else:
            # error reply: store the remapped remote exception as the result
            self.results[msg_id] = self._unwrap_exception(content)
587
598
588 def _flush_notifications(self):
599 def _flush_notifications(self):
589 """Flush notifications of engine registrations waiting
600 """Flush notifications of engine registrations waiting
590 in ZMQ queue."""
601 in ZMQ queue."""
591 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
602 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
592 while msg is not None:
603 while msg is not None:
593 if self.debug:
604 if self.debug:
594 pprint(msg)
605 pprint(msg)
595 msg = msg[-1]
606 msg = msg[-1]
596 msg_type = msg['msg_type']
607 msg_type = msg['msg_type']
597 handler = self._notification_handlers.get(msg_type, None)
608 handler = self._notification_handlers.get(msg_type, None)
598 if handler is None:
609 if handler is None:
599 raise Exception("Unhandled message type: %s"%msg.msg_type)
610 raise Exception("Unhandled message type: %s"%msg.msg_type)
600 else:
611 else:
601 handler(msg)
612 handler(msg)
602 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
613 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
603
614
604 def _flush_results(self, sock):
615 def _flush_results(self, sock):
605 """Flush task or queue results waiting in ZMQ queue."""
616 """Flush task or queue results waiting in ZMQ queue."""
606 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
617 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
607 while msg is not None:
618 while msg is not None:
608 if self.debug:
619 if self.debug:
609 pprint(msg)
620 pprint(msg)
610 msg = msg[-1]
621 msg = msg[-1]
611 msg_type = msg['msg_type']
622 msg_type = msg['msg_type']
612 handler = self._queue_handlers.get(msg_type, None)
623 handler = self._queue_handlers.get(msg_type, None)
613 if handler is None:
624 if handler is None:
614 raise Exception("Unhandled message type: %s"%msg.msg_type)
625 raise Exception("Unhandled message type: %s"%msg.msg_type)
615 else:
626 else:
616 handler(msg)
627 handler(msg)
617 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
628 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
618
629
619 def _flush_control(self, sock):
630 def _flush_control(self, sock):
620 """Flush replies from the control channel waiting
631 """Flush replies from the control channel waiting
621 in the ZMQ queue.
632 in the ZMQ queue.
622
633
623 Currently: ignore them."""
634 Currently: ignore them."""
624 if self._ignored_control_replies <= 0:
635 if self._ignored_control_replies <= 0:
625 return
636 return
626 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
637 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
627 while msg is not None:
638 while msg is not None:
628 self._ignored_control_replies -= 1
639 self._ignored_control_replies -= 1
629 if self.debug:
640 if self.debug:
630 pprint(msg)
641 pprint(msg)
631 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
642 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
632
643
    def _flush_ignored_control(self):
        """flush ignored control replies"""
        # blocking recv: each previously-ignored request owes us exactly one reply
        while self._ignored_control_replies > 0:
            self.session.recv(self._control_socket)
            self._ignored_control_replies -= 1
638
649
639 def _flush_ignored_hub_replies(self):
650 def _flush_ignored_hub_replies(self):
640 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
651 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
641 while msg is not None:
652 while msg is not None:
642 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
653 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
643
654
    def _flush_iopub(self, sock):
        """Flush replies from the iopub channel waiting
        in the ZMQ queue.
        """
        msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg = msg[-1]
            parent = msg['parent_header']
            msg_id = parent['msg_id']
            content = msg['content']
            header = msg['header']
            msg_type = msg['msg_type']

            # init metadata:
            md = self.metadata[msg_id]

            if msg_type == 'stream':
                # append stream output to any text already accumulated
                name = content['name']
                s = md[name] or ''
                md[name] = s + content['data']
            elif msg_type == 'pyerr':
                md.update({'pyerr' : self._unwrap_exception(content)})
            else:
                # other iopub types (e.g. pyin/pyout) stored under msg_type key
                md.update({msg_type : content['data']})

            # reduntant?
            self.metadata[msg_id] = md

            msg = self.session.recv(sock, mode=zmq.NOBLOCK)
675
686
676 #--------------------------------------------------------------------------
687 #--------------------------------------------------------------------------
677 # len, getitem
688 # len, getitem
678 #--------------------------------------------------------------------------
689 #--------------------------------------------------------------------------
679
690
    def __len__(self):
        """len(client) returns # of engines."""
        # self.ids flushes notifications, so this is always current
        return len(self.ids)
683
694
684 def __getitem__(self, key):
695 def __getitem__(self, key):
685 """index access returns DirectView multiplexer objects
696 """index access returns DirectView multiplexer objects
686
697
687 Must be int, slice, or list/tuple/xrange of ints"""
698 Must be int, slice, or list/tuple/xrange of ints"""
688 if not isinstance(key, (int, slice, tuple, list, xrange)):
699 if not isinstance(key, (int, slice, tuple, list, xrange)):
689 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
700 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
690 else:
701 else:
691 return self._get_view(key, balanced=False)
702 return self.direct_view(key)
692
703
693 #--------------------------------------------------------------------------
704 #--------------------------------------------------------------------------
694 # Begin public methods
705 # Begin public methods
695 #--------------------------------------------------------------------------
706 #--------------------------------------------------------------------------
696
707
    @property
    def ids(self):
        """Always up-to-date ids property."""
        # process pending (un)registration notifications first
        self._flush_notifications()
        # always copy:
        return list(self._ids)
703
714
704 def close(self):
715 def close(self):
705 if self._closed:
716 if self._closed:
706 return
717 return
707 snames = filter(lambda n: n.endswith('socket'), dir(self))
718 snames = filter(lambda n: n.endswith('socket'), dir(self))
708 for socket in map(lambda name: getattr(self, name), snames):
719 for socket in map(lambda name: getattr(self, name), snames):
709 if isinstance(socket, zmq.Socket) and not socket.closed:
720 if isinstance(socket, zmq.Socket) and not socket.closed:
710 socket.close()
721 socket.close()
711 self._closed = True
722 self._closed = True
712
723
    def spin(self):
        """Flush any registration notifications and execution results
        waiting in the ZMQ queue.
        """
        # each socket may be unset depending on the connection's capabilities
        if self._notification_socket:
            self._flush_notifications()
        if self._mux_socket:
            self._flush_results(self._mux_socket)
        if self._task_socket:
            self._flush_results(self._task_socket)
        if self._control_socket:
            self._flush_control(self._control_socket)
        if self._iopub_socket:
            self._flush_iopub(self._iopub_socket)
        if self._query_socket:
            self._flush_ignored_hub_replies()
729
740
    def wait(self, jobs=None, timeout=-1):
        """waits on one or more `jobs`, for up to `timeout` seconds.

        Parameters
        ----------

        jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
            ints are indices to self.history
            strs are msg_ids
            default: wait on all outstanding messages
        timeout : float
            a time in seconds, after which to give up.
            default is -1, which means no timeout

        Returns
        -------

        True : when all msg_ids are done
        False : timeout reached, some msg_ids still outstanding
        """
        tic = time.time()
        if jobs is None:
            theids = self.outstanding
        else:
            if isinstance(jobs, (int, str, AsyncResult)):
                jobs = [jobs]
            # normalize each job spec to a set of msg_id strings
            theids = set()
            for job in jobs:
                if isinstance(job, int):
                    # index access
                    job = self.history[job]
                elif isinstance(job, AsyncResult):
                    map(theids.add, job.msg_ids)
                    continue
                theids.add(job)
        if not theids.intersection(self.outstanding):
            # nothing requested is still outstanding
            return True
        self.spin()
        # poll with a short sleep until done or timed out
        while theids.intersection(self.outstanding):
            if timeout >= 0 and ( time.time()-tic ) > timeout:
                break
            time.sleep(1e-3)
            self.spin()
        return len(theids.intersection(self.outstanding)) == 0
774
785
775 #--------------------------------------------------------------------------
786 #--------------------------------------------------------------------------
776 # Control methods
787 # Control methods
777 #--------------------------------------------------------------------------
788 #--------------------------------------------------------------------------
778
789
    @spin_first
    @default_block
    def clear(self, targets=None, block=None):
        """Clear the namespace in target(s)."""
        targets = self._build_targets(targets)[0]
        for t in targets:
            self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
        # NOTE: local name shadows the `error` module within this method
        error = False
        if self.block:
            self._flush_ignored_control()
            # one reply per target; remember the last failure, if any
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket,0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            # non-blocking: replies will be drained later
            self._ignored_control_replies += len(targets)
        if error:
            raise error
799
810
800
811
    @spin_first
    @default_block
    def abort(self, jobs=None, targets=None, block=None):
        """Abort specific jobs from the execution queues of target(s).

        This is a mechanism to prevent jobs that have already been submitted
        from executing.

        Parameters
        ----------

        jobs : msg_id, list of msg_ids, or AsyncResult
            The jobs to be aborted


        """
        targets = self._build_targets(targets)[0]
        msg_ids = []
        if isinstance(jobs, (basestring,AsyncResult)):
            jobs = [jobs]
        # reject anything that is not a msg_id or AsyncResult up front
        bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
        if bad_ids:
            raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
        # flatten AsyncResults into their msg_ids
        for j in jobs:
            if isinstance(j, AsyncResult):
                msg_ids.extend(j.msg_ids)
            else:
                msg_ids.append(j)
        content = dict(msg_ids=msg_ids)
        for t in targets:
            self.session.send(self._control_socket, 'abort_request',
                    content=content, ident=t)
        # NOTE: local name shadows the `error` module within this method
        error = False
        if self.block:
            self._flush_ignored_control()
            # one reply per target; remember the last failure, if any
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket,0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            # non-blocking: replies will be drained later
            self._ignored_control_replies += len(targets)
        if error:
            raise error
846
857
    @spin_first
    @default_block
    def shutdown(self, targets=None, restart=False, hub=False, block=None):
        """Terminates one or more engine processes, optionally including the hub."""
        if hub:
            # shutting down the hub implies shutting down all engines
            targets = 'all'
        targets = self._build_targets(targets)[0]
        for t in targets:
            self.session.send(self._control_socket, 'shutdown_request',
                        content={'restart':restart},ident=t)
        # NOTE: local name shadows the `error` module within this method
        error = False
        # hub shutdown forces blocking so engine replies arrive first
        if block or hub:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket, 0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)

        if hub:
            # give engines a moment to receive their shutdown before the hub goes away
            time.sleep(0.25)
            self.session.send(self._query_socket, 'shutdown_request')
            idents,msg = self.session.recv(self._query_socket, 0)
            if self.debug:
                pprint(msg)
            if msg['content']['status'] != 'ok':
                error = self._unwrap_exception(msg['content'])

        if error:
            raise error
880
891
881 #--------------------------------------------------------------------------
892 #--------------------------------------------------------------------------
882 # Execution methods
893 # Execution methods
883 #--------------------------------------------------------------------------
894 #--------------------------------------------------------------------------
884
895
885 @default_block
896 @default_block
886 def _execute(self, code, targets='all', block=None):
897 def _execute(self, code, targets='all', block=None):
887 """Executes `code` on `targets` in blocking or nonblocking manner.
898 """Executes `code` on `targets` in blocking or nonblocking manner.
888
899
889 ``execute`` is always `bound` (affects engine namespace)
900 ``execute`` is always `bound` (affects engine namespace)
890
901
891 Parameters
902 Parameters
892 ----------
903 ----------
893
904
894 code : str
905 code : str
895 the code string to be executed
906 the code string to be executed
896 targets : int/str/list of ints/strs
907 targets : int/str/list of ints/strs
897 the engines on which to execute
908 the engines on which to execute
898 default : all
909 default : all
899 block : bool
910 block : bool
900 whether or not to wait until done to return
911 whether or not to wait until done to return
901 default: self.block
912 default: self.block
902 """
913 """
903 return self[targets].execute(code, block=block)
914 return self[targets].execute(code, block=block)
904
915
905 def _maybe_raise(self, result):
916 def _maybe_raise(self, result):
906 """wrapper for maybe raising an exception if apply failed."""
917 """wrapper for maybe raising an exception if apply failed."""
907 if isinstance(result, error.RemoteError):
918 if isinstance(result, error.RemoteError):
908 raise result
919 raise result
909
920
910 return result
921 return result
911
922
    def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
                            ident=None):
        """construct and send an apply message via a socket.

        This is the principal method with which all engine execution is performed by views.
        """

        assert not self._closed, "cannot use me anymore, I'm closed!"
        # defaults:
        args = args if args is not None else []
        kwargs = kwargs if kwargs is not None else {}
        subheader = subheader if subheader is not None else {}

        # validate arguments
        if not callable(f):
            raise TypeError("f must be callable, not %s"%type(f))
        if not isinstance(args, (tuple, list)):
            raise TypeError("args must be tuple or list, not %s"%type(args))
        if not isinstance(kwargs, dict):
            raise TypeError("kwargs must be dict, not %s"%type(kwargs))
        if not isinstance(subheader, dict):
            raise TypeError("subheader must be dict, not %s"%type(subheader))

        if not self._ids:
            # flush notification socket if no engines yet
            any_ids = self.ids
            if not any_ids:
                raise error.NoEnginesRegistered("Can't execute without any connected engines.")
        # enforce types of f,args,kwargs

        # serialize f/args/kwargs into wire-format buffers
        bufs = util.pack_apply_message(f,args,kwargs)

        msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
                            subheader=subheader, track=track)

        # record the submission before returning
        msg_id = msg['msg_id']
        self.outstanding.add(msg_id)
        if ident:
            # possibly routed to a specific engine
            if isinstance(ident, list):
                ident = ident[-1]
            if ident in self._engines.values():
                # save for later, in case of engine death
                self._outstanding_dict[ident].add(msg_id)
        self.history.append(msg_id)
        self.metadata[msg_id]['submitted'] = datetime.now()

        return msg
960
971
961 #--------------------------------------------------------------------------
972 #--------------------------------------------------------------------------
962 # construct a View object
973 # construct a View object
963 #--------------------------------------------------------------------------
974 #--------------------------------------------------------------------------
964
975
965 def _cache_view(self, targets, balanced):
966 """save views, so subsequent requests don't create new objects."""
967 if balanced:
968 # validate whether we can run
969 if not self._task_socket:
970 msg = "Task farming is disabled"
971 if self._task_scheme == 'pure':
972 msg += " because the pure ZMQ scheduler cannot handle"
973 msg += " disappearing engines."
974 raise RuntimeError(msg)
975 socket = self._task_socket
976 view_class = LoadBalancedView
977 view_cache = self._balanced_views
978 else:
979 socket = self._mux_socket
980 view_class = DirectView
981 view_cache = self._direct_views
982
983 # use str, since often targets will be a list
984 key = str(targets)
985 if key not in view_cache:
986 view_cache[key] = view_class(client=self, socket=socket, targets=targets)
987
988 return view_cache[key]
989
990 def load_balanced_view(self, targets=None):
976 def load_balanced_view(self, targets=None):
991 """construct a DirectView object.
977 """construct a DirectView object.
992
978
993 If no arguments are specified, create a LoadBalancedView
979 If no arguments are specified, create a LoadBalancedView
994 using all engines.
980 using all engines.
995
981
996 Parameters
982 Parameters
997 ----------
983 ----------
998
984
999 targets: list,slice,int,etc. [default: use all engines]
985 targets: list,slice,int,etc. [default: use all engines]
1000 The subset of engines across which to load-balance
986 The subset of engines across which to load-balance
1001 """
987 """
1002 return self._get_view(targets, balanced=True)
988 if targets is None:
989 targets = self._build_targets(targets)[1]
990 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1003
991
1004 def direct_view(self, targets='all'):
992 def direct_view(self, targets='all'):
1005 """construct a DirectView object.
993 """construct a DirectView object.
1006
994
1007 If no targets are specified, create a DirectView
995 If no targets are specified, create a DirectView
1008 using all engines.
996 using all engines.
1009
997
1010 Parameters
998 Parameters
1011 ----------
999 ----------
1012
1000
1013 targets: list,slice,int,etc. [default: use all engines]
1001 targets: list,slice,int,etc. [default: use all engines]
1014 The engines to use for the View
1002 The engines to use for the View
1015 """
1003 """
1016 return self._get_view(targets, balanced=False)
1004 single = isinstance(targets, int)
1017
1005 targets = self._build_targets(targets)[1]
1018 def _get_view(self, targets, balanced):
1006 if single:
1019 """Method for constructing View objects.
1007 targets = targets[0]
1020
1008 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1021 If no arguments are specified, create a LoadBalancedView
1022 using all engines. If only `targets` specified, it will
1023 be a DirectView. This method is the underlying implementation
1024 of ``client.__getitem__``.
1025
1026 Parameters
1027 ----------
1028
1029 targets: list,slice,int,etc. [default: use all engines]
1030 The engines to use for the View
1031 balanced : bool [default: False if targets specified, True else]
1032 whether to build a LoadBalancedView or a DirectView
1033
1034 """
1035
1036 if targets in (None,'all'):
1037 if balanced:
1038 return self._cache_view(None,True)
1039 else:
1040 targets = slice(None)
1041
1042 if isinstance(targets, int):
1043 if targets < 0:
1044 targets = self.ids[targets]
1045 if targets not in self.ids:
1046 raise IndexError("No such engine: %i"%targets)
1047 return self._cache_view(targets, balanced)
1048
1049 if isinstance(targets, slice):
1050 indices = range(len(self.ids))[targets]
1051 ids = sorted(self._ids)
1052 targets = [ ids[i] for i in indices ]
1053
1054 if isinstance(targets, (tuple, list, xrange)):
1055 _,targets = self._build_targets(list(targets))
1056 return self._cache_view(targets, balanced)
1057 else:
1058 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
1059
1009
1060 #--------------------------------------------------------------------------
1010 #--------------------------------------------------------------------------
1061 # Data movement (TO BE REMOVED)
1011 # Data movement (TO BE REMOVED)
1062 #--------------------------------------------------------------------------
1012 #--------------------------------------------------------------------------
1063
1013
1064 @default_block
1014 @default_block
1065 def _push(self, ns, targets='all', block=None, track=False):
1015 def _push(self, ns, targets='all', block=None, track=False):
1066 """Push the contents of `ns` into the namespace on `target`"""
1016 """Push the contents of `ns` into the namespace on `target`"""
1067 if not isinstance(ns, dict):
1017 if not isinstance(ns, dict):
1068 raise TypeError("Must be a dict, not %s"%type(ns))
1018 raise TypeError("Must be a dict, not %s"%type(ns))
1069 result = self.apply(util._push, kwargs=ns, targets=targets, block=block, bound=True, balanced=False, track=track)
1019 result = self.apply(util._push, kwargs=ns, targets=targets, block=block, bound=True, balanced=False, track=track)
1070 if not block:
1020 if not block:
1071 return result
1021 return result
1072
1022
1073 @default_block
1023 @default_block
1074 def _pull(self, keys, targets='all', block=None):
1024 def _pull(self, keys, targets='all', block=None):
1075 """Pull objects from `target`'s namespace by `keys`"""
1025 """Pull objects from `target`'s namespace by `keys`"""
1076 if isinstance(keys, basestring):
1026 if isinstance(keys, basestring):
1077 pass
1027 pass
1078 elif isinstance(keys, (list,tuple,set)):
1028 elif isinstance(keys, (list,tuple,set)):
1079 for key in keys:
1029 for key in keys:
1080 if not isinstance(key, basestring):
1030 if not isinstance(key, basestring):
1081 raise TypeError("keys must be str, not type %r"%type(key))
1031 raise TypeError("keys must be str, not type %r"%type(key))
1082 else:
1032 else:
1083 raise TypeError("keys must be strs, not %r"%keys)
1033 raise TypeError("keys must be strs, not %r"%keys)
1084 result = self.apply(util._pull, (keys,), targets=targets, block=block, bound=True, balanced=False)
1034 result = self.apply(util._pull, (keys,), targets=targets, block=block, bound=True, balanced=False)
1085 return result
1035 return result
1086
1036
1087 #--------------------------------------------------------------------------
1037 #--------------------------------------------------------------------------
1088 # Query methods
1038 # Query methods
1089 #--------------------------------------------------------------------------
1039 #--------------------------------------------------------------------------
1090
1040
1091 @spin_first
1041 @spin_first
1092 @default_block
1042 @default_block
1093 def get_result(self, indices_or_msg_ids=None, block=None):
1043 def get_result(self, indices_or_msg_ids=None, block=None):
1094 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1044 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1095
1045
1096 If the client already has the results, no request to the Hub will be made.
1046 If the client already has the results, no request to the Hub will be made.
1097
1047
1098 This is a convenient way to construct AsyncResult objects, which are wrappers
1048 This is a convenient way to construct AsyncResult objects, which are wrappers
1099 that include metadata about execution, and allow for awaiting results that
1049 that include metadata about execution, and allow for awaiting results that
1100 were not submitted by this Client.
1050 were not submitted by this Client.
1101
1051
1102 It can also be a convenient way to retrieve the metadata associated with
1052 It can also be a convenient way to retrieve the metadata associated with
1103 blocking execution, since it always retrieves
1053 blocking execution, since it always retrieves
1104
1054
1105 Examples
1055 Examples
1106 --------
1056 --------
1107 ::
1057 ::
1108
1058
1109 In [10]: r = client.apply()
1059 In [10]: r = client.apply()
1110
1060
1111 Parameters
1061 Parameters
1112 ----------
1062 ----------
1113
1063
1114 indices_or_msg_ids : integer history index, str msg_id, or list of either
1064 indices_or_msg_ids : integer history index, str msg_id, or list of either
1115 The indices or msg_ids of indices to be retrieved
1065 The indices or msg_ids of indices to be retrieved
1116
1066
1117 block : bool
1067 block : bool
1118 Whether to wait for the result to be done
1068 Whether to wait for the result to be done
1119
1069
1120 Returns
1070 Returns
1121 -------
1071 -------
1122
1072
1123 AsyncResult
1073 AsyncResult
1124 A single AsyncResult object will always be returned.
1074 A single AsyncResult object will always be returned.
1125
1075
1126 AsyncHubResult
1076 AsyncHubResult
1127 A subclass of AsyncResult that retrieves results from the Hub
1077 A subclass of AsyncResult that retrieves results from the Hub
1128
1078
1129 """
1079 """
1130 if indices_or_msg_ids is None:
1080 if indices_or_msg_ids is None:
1131 indices_or_msg_ids = -1
1081 indices_or_msg_ids = -1
1132
1082
1133 if not isinstance(indices_or_msg_ids, (list,tuple)):
1083 if not isinstance(indices_or_msg_ids, (list,tuple)):
1134 indices_or_msg_ids = [indices_or_msg_ids]
1084 indices_or_msg_ids = [indices_or_msg_ids]
1135
1085
1136 theids = []
1086 theids = []
1137 for id in indices_or_msg_ids:
1087 for id in indices_or_msg_ids:
1138 if isinstance(id, int):
1088 if isinstance(id, int):
1139 id = self.history[id]
1089 id = self.history[id]
1140 if not isinstance(id, str):
1090 if not isinstance(id, str):
1141 raise TypeError("indices must be str or int, not %r"%id)
1091 raise TypeError("indices must be str or int, not %r"%id)
1142 theids.append(id)
1092 theids.append(id)
1143
1093
1144 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1094 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1145 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1095 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1146
1096
1147 if remote_ids:
1097 if remote_ids:
1148 ar = AsyncHubResult(self, msg_ids=theids)
1098 ar = AsyncHubResult(self, msg_ids=theids)
1149 else:
1099 else:
1150 ar = AsyncResult(self, msg_ids=theids)
1100 ar = AsyncResult(self, msg_ids=theids)
1151
1101
1152 if block:
1102 if block:
1153 ar.wait()
1103 ar.wait()
1154
1104
1155 return ar
1105 return ar
1156
1106
1157 @spin_first
1107 @spin_first
1158 def result_status(self, msg_ids, status_only=True):
1108 def result_status(self, msg_ids, status_only=True):
1159 """Check on the status of the result(s) of the apply request with `msg_ids`.
1109 """Check on the status of the result(s) of the apply request with `msg_ids`.
1160
1110
1161 If status_only is False, then the actual results will be retrieved, else
1111 If status_only is False, then the actual results will be retrieved, else
1162 only the status of the results will be checked.
1112 only the status of the results will be checked.
1163
1113
1164 Parameters
1114 Parameters
1165 ----------
1115 ----------
1166
1116
1167 msg_ids : list of msg_ids
1117 msg_ids : list of msg_ids
1168 if int:
1118 if int:
1169 Passed as index to self.history for convenience.
1119 Passed as index to self.history for convenience.
1170 status_only : bool (default: True)
1120 status_only : bool (default: True)
1171 if False:
1121 if False:
1172 Retrieve the actual results of completed tasks.
1122 Retrieve the actual results of completed tasks.
1173
1123
1174 Returns
1124 Returns
1175 -------
1125 -------
1176
1126
1177 results : dict
1127 results : dict
1178 There will always be the keys 'pending' and 'completed', which will
1128 There will always be the keys 'pending' and 'completed', which will
1179 be lists of msg_ids that are incomplete or complete. If `status_only`
1129 be lists of msg_ids that are incomplete or complete. If `status_only`
1180 is False, then completed results will be keyed by their `msg_id`.
1130 is False, then completed results will be keyed by their `msg_id`.
1181 """
1131 """
1182 if not isinstance(msg_ids, (list,tuple)):
1132 if not isinstance(msg_ids, (list,tuple)):
1183 msg_ids = [msg_ids]
1133 msg_ids = [msg_ids]
1184
1134
1185 theids = []
1135 theids = []
1186 for msg_id in msg_ids:
1136 for msg_id in msg_ids:
1187 if isinstance(msg_id, int):
1137 if isinstance(msg_id, int):
1188 msg_id = self.history[msg_id]
1138 msg_id = self.history[msg_id]
1189 if not isinstance(msg_id, basestring):
1139 if not isinstance(msg_id, basestring):
1190 raise TypeError("msg_ids must be str, not %r"%msg_id)
1140 raise TypeError("msg_ids must be str, not %r"%msg_id)
1191 theids.append(msg_id)
1141 theids.append(msg_id)
1192
1142
1193 completed = []
1143 completed = []
1194 local_results = {}
1144 local_results = {}
1195
1145
1196 # comment this block out to temporarily disable local shortcut:
1146 # comment this block out to temporarily disable local shortcut:
1197 for msg_id in theids:
1147 for msg_id in theids:
1198 if msg_id in self.results:
1148 if msg_id in self.results:
1199 completed.append(msg_id)
1149 completed.append(msg_id)
1200 local_results[msg_id] = self.results[msg_id]
1150 local_results[msg_id] = self.results[msg_id]
1201 theids.remove(msg_id)
1151 theids.remove(msg_id)
1202
1152
1203 if theids: # some not locally cached
1153 if theids: # some not locally cached
1204 content = dict(msg_ids=theids, status_only=status_only)
1154 content = dict(msg_ids=theids, status_only=status_only)
1205 msg = self.session.send(self._query_socket, "result_request", content=content)
1155 msg = self.session.send(self._query_socket, "result_request", content=content)
1206 zmq.select([self._query_socket], [], [])
1156 zmq.select([self._query_socket], [], [])
1207 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1157 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1208 if self.debug:
1158 if self.debug:
1209 pprint(msg)
1159 pprint(msg)
1210 content = msg['content']
1160 content = msg['content']
1211 if content['status'] != 'ok':
1161 if content['status'] != 'ok':
1212 raise self._unwrap_exception(content)
1162 raise self._unwrap_exception(content)
1213 buffers = msg['buffers']
1163 buffers = msg['buffers']
1214 else:
1164 else:
1215 content = dict(completed=[],pending=[])
1165 content = dict(completed=[],pending=[])
1216
1166
1217 content['completed'].extend(completed)
1167 content['completed'].extend(completed)
1218
1168
1219 if status_only:
1169 if status_only:
1220 return content
1170 return content
1221
1171
1222 failures = []
1172 failures = []
1223 # load cached results into result:
1173 # load cached results into result:
1224 content.update(local_results)
1174 content.update(local_results)
1225 # update cache with results:
1175 # update cache with results:
1226 for msg_id in sorted(theids):
1176 for msg_id in sorted(theids):
1227 if msg_id in content['completed']:
1177 if msg_id in content['completed']:
1228 rec = content[msg_id]
1178 rec = content[msg_id]
1229 parent = rec['header']
1179 parent = rec['header']
1230 header = rec['result_header']
1180 header = rec['result_header']
1231 rcontent = rec['result_content']
1181 rcontent = rec['result_content']
1232 iodict = rec['io']
1182 iodict = rec['io']
1233 if isinstance(rcontent, str):
1183 if isinstance(rcontent, str):
1234 rcontent = self.session.unpack(rcontent)
1184 rcontent = self.session.unpack(rcontent)
1235
1185
1236 md = self.metadata[msg_id]
1186 md = self.metadata[msg_id]
1237 md.update(self._extract_metadata(header, parent, rcontent))
1187 md.update(self._extract_metadata(header, parent, rcontent))
1238 md.update(iodict)
1188 md.update(iodict)
1239
1189
1240 if rcontent['status'] == 'ok':
1190 if rcontent['status'] == 'ok':
1241 res,buffers = util.unserialize_object(buffers)
1191 res,buffers = util.unserialize_object(buffers)
1242 else:
1192 else:
1243 print rcontent
1193 print rcontent
1244 res = self._unwrap_exception(rcontent)
1194 res = self._unwrap_exception(rcontent)
1245 failures.append(res)
1195 failures.append(res)
1246
1196
1247 self.results[msg_id] = res
1197 self.results[msg_id] = res
1248 content[msg_id] = res
1198 content[msg_id] = res
1249
1199
1250 if len(theids) == 1 and failures:
1200 if len(theids) == 1 and failures:
1251 raise failures[0]
1201 raise failures[0]
1252
1202
1253 error.collect_exceptions(failures, "result_status")
1203 error.collect_exceptions(failures, "result_status")
1254 return content
1204 return content
1255
1205
1256 @spin_first
1206 @spin_first
1257 def queue_status(self, targets='all', verbose=False):
1207 def queue_status(self, targets='all', verbose=False):
1258 """Fetch the status of engine queues.
1208 """Fetch the status of engine queues.
1259
1209
1260 Parameters
1210 Parameters
1261 ----------
1211 ----------
1262
1212
1263 targets : int/str/list of ints/strs
1213 targets : int/str/list of ints/strs
1264 the engines whose states are to be queried.
1214 the engines whose states are to be queried.
1265 default : all
1215 default : all
1266 verbose : bool
1216 verbose : bool
1267 Whether to return lengths only, or lists of ids for each element
1217 Whether to return lengths only, or lists of ids for each element
1268 """
1218 """
1269 engine_ids = self._build_targets(targets)[1]
1219 engine_ids = self._build_targets(targets)[1]
1270 content = dict(targets=engine_ids, verbose=verbose)
1220 content = dict(targets=engine_ids, verbose=verbose)
1271 self.session.send(self._query_socket, "queue_request", content=content)
1221 self.session.send(self._query_socket, "queue_request", content=content)
1272 idents,msg = self.session.recv(self._query_socket, 0)
1222 idents,msg = self.session.recv(self._query_socket, 0)
1273 if self.debug:
1223 if self.debug:
1274 pprint(msg)
1224 pprint(msg)
1275 content = msg['content']
1225 content = msg['content']
1276 status = content.pop('status')
1226 status = content.pop('status')
1277 if status != 'ok':
1227 if status != 'ok':
1278 raise self._unwrap_exception(content)
1228 raise self._unwrap_exception(content)
1279 content = util.rekey(content)
1229 content = util.rekey(content)
1280 if isinstance(targets, int):
1230 if isinstance(targets, int):
1281 return content[targets]
1231 return content[targets]
1282 else:
1232 else:
1283 return content
1233 return content
1284
1234
1285 @spin_first
1235 @spin_first
1286 def purge_results(self, jobs=[], targets=[]):
1236 def purge_results(self, jobs=[], targets=[]):
1287 """Tell the Hub to forget results.
1237 """Tell the Hub to forget results.
1288
1238
1289 Individual results can be purged by msg_id, or the entire
1239 Individual results can be purged by msg_id, or the entire
1290 history of specific targets can be purged.
1240 history of specific targets can be purged.
1291
1241
1292 Parameters
1242 Parameters
1293 ----------
1243 ----------
1294
1244
1295 jobs : str or list of str or AsyncResult objects
1245 jobs : str or list of str or AsyncResult objects
1296 the msg_ids whose results should be forgotten.
1246 the msg_ids whose results should be forgotten.
1297 targets : int/str/list of ints/strs
1247 targets : int/str/list of ints/strs
1298 The targets, by uuid or int_id, whose entire history is to be purged.
1248 The targets, by uuid or int_id, whose entire history is to be purged.
1299 Use `targets='all'` to scrub everything from the Hub's memory.
1249 Use `targets='all'` to scrub everything from the Hub's memory.
1300
1250
1301 default : None
1251 default : None
1302 """
1252 """
1303 if not targets and not jobs:
1253 if not targets and not jobs:
1304 raise ValueError("Must specify at least one of `targets` and `jobs`")
1254 raise ValueError("Must specify at least one of `targets` and `jobs`")
1305 if targets:
1255 if targets:
1306 targets = self._build_targets(targets)[1]
1256 targets = self._build_targets(targets)[1]
1307
1257
1308 # construct msg_ids from jobs
1258 # construct msg_ids from jobs
1309 msg_ids = []
1259 msg_ids = []
1310 if isinstance(jobs, (basestring,AsyncResult)):
1260 if isinstance(jobs, (basestring,AsyncResult)):
1311 jobs = [jobs]
1261 jobs = [jobs]
1312 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1262 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1313 if bad_ids:
1263 if bad_ids:
1314 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1264 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1315 for j in jobs:
1265 for j in jobs:
1316 if isinstance(j, AsyncResult):
1266 if isinstance(j, AsyncResult):
1317 msg_ids.extend(j.msg_ids)
1267 msg_ids.extend(j.msg_ids)
1318 else:
1268 else:
1319 msg_ids.append(j)
1269 msg_ids.append(j)
1320
1270
1321 content = dict(targets=targets, msg_ids=msg_ids)
1271 content = dict(targets=targets, msg_ids=msg_ids)
1322 self.session.send(self._query_socket, "purge_request", content=content)
1272 self.session.send(self._query_socket, "purge_request", content=content)
1323 idents, msg = self.session.recv(self._query_socket, 0)
1273 idents, msg = self.session.recv(self._query_socket, 0)
1324 if self.debug:
1274 if self.debug:
1325 pprint(msg)
1275 pprint(msg)
1326 content = msg['content']
1276 content = msg['content']
1327 if content['status'] != 'ok':
1277 if content['status'] != 'ok':
1328 raise self._unwrap_exception(content)
1278 raise self._unwrap_exception(content)
1329
1279
1330
1280
1331 __all__ = [ 'Client',
1281 __all__ = [ 'Client',
1332 'depend',
1282 'depend',
1333 'require',
1283 'require',
1334 'remote',
1284 'remote',
1335 'parallel',
1285 'parallel',
1336 'RemoteFunction',
1286 'RemoteFunction',
1337 'ParallelFunction',
1287 'ParallelFunction',
1338 'DirectView',
1288 'DirectView',
1339 'LoadBalancedView',
1289 'LoadBalancedView',
1340 'AsyncResult',
1290 'AsyncResult',
1341 'AsyncMapResult',
1291 'AsyncMapResult',
1342 'Reference'
1292 'Reference'
1343 ]
1293 ]
@@ -1,184 +1,196 b''
1 """Dependency utilities"""
1 """Dependency utilities"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010-2011 The IPython Development Team
3 # Copyright (C) 2010-2011 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 from types import ModuleType
10
9 from .asyncresult import AsyncResult
11 from .asyncresult import AsyncResult
10 from .error import UnmetDependency
12 from .error import UnmetDependency
11 from .util import interactive
13 from .util import interactive
12
14
13 class depend(object):
15 class depend(object):
14 """Dependency decorator, for use with tasks.
16 """Dependency decorator, for use with tasks.
15
17
16 `@depend` lets you define a function for engine dependencies
18 `@depend` lets you define a function for engine dependencies
17 just like you use `apply` for tasks.
19 just like you use `apply` for tasks.
18
20
19
21
20 Examples
22 Examples
21 --------
23 --------
22 ::
24 ::
23
25
24 @depend(df, a,b, c=5)
26 @depend(df, a,b, c=5)
25 def f(m,n,p)
27 def f(m,n,p)
26
28
27 view.apply(f, 1,2,3)
29 view.apply(f, 1,2,3)
28
30
29 will call df(a,b,c=5) on the engine, and if it returns False or
31 will call df(a,b,c=5) on the engine, and if it returns False or
30 raises an UnmetDependency error, then the task will not be run
32 raises an UnmetDependency error, then the task will not be run
31 and another engine will be tried.
33 and another engine will be tried.
32 """
34 """
33 def __init__(self, f, *args, **kwargs):
35 def __init__(self, f, *args, **kwargs):
34 self.f = f
36 self.f = f
35 self.args = args
37 self.args = args
36 self.kwargs = kwargs
38 self.kwargs = kwargs
37
39
38 def __call__(self, f):
40 def __call__(self, f):
39 return dependent(f, self.f, *self.args, **self.kwargs)
41 return dependent(f, self.f, *self.args, **self.kwargs)
40
42
41 class dependent(object):
43 class dependent(object):
42 """A function that depends on another function.
44 """A function that depends on another function.
43 This is an object to prevent the closure used
45 This is an object to prevent the closure used
44 in traditional decorators, which are not picklable.
46 in traditional decorators, which are not picklable.
45 """
47 """
46
48
47 def __init__(self, f, df, *dargs, **dkwargs):
49 def __init__(self, f, df, *dargs, **dkwargs):
48 self.f = f
50 self.f = f
49 self.func_name = getattr(f, '__name__', 'f')
51 self.func_name = getattr(f, '__name__', 'f')
50 self.df = df
52 self.df = df
51 self.dargs = dargs
53 self.dargs = dargs
52 self.dkwargs = dkwargs
54 self.dkwargs = dkwargs
53
55
54 def __call__(self, *args, **kwargs):
56 def __call__(self, *args, **kwargs):
55 # if hasattr(self.f, 'func_globals') and hasattr(self.df, 'func_globals'):
57 # if hasattr(self.f, 'func_globals') and hasattr(self.df, 'func_globals'):
56 # self.df.func_globals = self.f.func_globals
58 # self.df.func_globals = self.f.func_globals
57 if self.df(*self.dargs, **self.dkwargs) is False:
59 if self.df(*self.dargs, **self.dkwargs) is False:
58 raise UnmetDependency()
60 raise UnmetDependency()
59 return self.f(*args, **kwargs)
61 return self.f(*args, **kwargs)
60
62
61 @property
63 @property
62 def __name__(self):
64 def __name__(self):
63 return self.func_name
65 return self.func_name
64
66
65 @interactive
67 @interactive
66 def _require(*names):
68 def _require(*names):
67 """Helper for @require decorator."""
69 """Helper for @require decorator."""
68 from IPython.zmq.parallel.error import UnmetDependency
70 from IPython.zmq.parallel.error import UnmetDependency
69 user_ns = globals()
71 user_ns = globals()
70 for name in names:
72 for name in names:
71 if name in user_ns:
73 if name in user_ns:
72 continue
74 continue
73 try:
75 try:
74 exec 'import %s'%name in user_ns
76 exec 'import %s'%name in user_ns
75 except ImportError:
77 except ImportError:
76 raise UnmetDependency(name)
78 raise UnmetDependency(name)
77 return True
79 return True
78
80
79 def require(*names):
81 def require(*mods):
80 """Simple decorator for requiring names to be importable.
82 """Simple decorator for requiring names to be importable.
81
83
82 Examples
84 Examples
83 --------
85 --------
84
86
85 In [1]: @require('numpy')
87 In [1]: @require('numpy')
86 ...: def norm(a):
88 ...: def norm(a):
87 ...: import numpy
89 ...: import numpy
88 ...: return numpy.linalg.norm(a,2)
90 ...: return numpy.linalg.norm(a,2)
89 """
91 """
92 names = []
93 for mod in mods:
94 if isinstance(mod, ModuleType):
95 mod = mod.__name__
96
97 if isinstance(mod, basestring):
98 names.append(mod)
99 else:
100 raise TypeError("names must be modules or module names, not %s"%type(mod))
101
90 return depend(_require, *names)
102 return depend(_require, *names)
91
103
92 class Dependency(set):
104 class Dependency(set):
93 """An object for representing a set of msg_id dependencies.
105 """An object for representing a set of msg_id dependencies.
94
106
95 Subclassed from set().
107 Subclassed from set().
96
108
97 Parameters
109 Parameters
98 ----------
110 ----------
99 dependencies: list/set of msg_ids or AsyncResult objects or output of Dependency.as_dict()
111 dependencies: list/set of msg_ids or AsyncResult objects or output of Dependency.as_dict()
100 The msg_ids to depend on
112 The msg_ids to depend on
101 all : bool [default True]
113 all : bool [default True]
102 Whether the dependency should be considered met when *all* depending tasks have completed
114 Whether the dependency should be considered met when *all* depending tasks have completed
103 or only when *any* have been completed.
115 or only when *any* have been completed.
104 success : bool [default True]
116 success : bool [default True]
105 Whether to consider successes as fulfilling dependencies.
117 Whether to consider successes as fulfilling dependencies.
106 failure : bool [default False]
118 failure : bool [default False]
107 Whether to consider failures as fulfilling dependencies.
119 Whether to consider failures as fulfilling dependencies.
108
120
109 If `all=success=True` and `failure=False`, then the task will fail with an ImpossibleDependency
121 If `all=success=True` and `failure=False`, then the task will fail with an ImpossibleDependency
110 as soon as the first depended-upon task fails.
122 as soon as the first depended-upon task fails.
111 """
123 """
112
124
113 all=True
125 all=True
114 success=True
126 success=True
115 failure=True
127 failure=True
116
128
117 def __init__(self, dependencies=[], all=True, success=True, failure=False):
129 def __init__(self, dependencies=[], all=True, success=True, failure=False):
118 if isinstance(dependencies, dict):
130 if isinstance(dependencies, dict):
119 # load from dict
131 # load from dict
120 all = dependencies.get('all', True)
132 all = dependencies.get('all', True)
121 success = dependencies.get('success', success)
133 success = dependencies.get('success', success)
122 failure = dependencies.get('failure', failure)
134 failure = dependencies.get('failure', failure)
123 dependencies = dependencies.get('dependencies', [])
135 dependencies = dependencies.get('dependencies', [])
124 ids = []
136 ids = []
125
137
126 # extract ids from various sources:
138 # extract ids from various sources:
127 if isinstance(dependencies, (basestring, AsyncResult)):
139 if isinstance(dependencies, (basestring, AsyncResult)):
128 dependencies = [dependencies]
140 dependencies = [dependencies]
129 for d in dependencies:
141 for d in dependencies:
130 if isinstance(d, basestring):
142 if isinstance(d, basestring):
131 ids.append(d)
143 ids.append(d)
132 elif isinstance(d, AsyncResult):
144 elif isinstance(d, AsyncResult):
133 ids.extend(d.msg_ids)
145 ids.extend(d.msg_ids)
134 else:
146 else:
135 raise TypeError("invalid dependency type: %r"%type(d))
147 raise TypeError("invalid dependency type: %r"%type(d))
136
148
137 set.__init__(self, ids)
149 set.__init__(self, ids)
138 self.all = all
150 self.all = all
139 if not (success or failure):
151 if not (success or failure):
140 raise ValueError("Must depend on at least one of successes or failures!")
152 raise ValueError("Must depend on at least one of successes or failures!")
141 self.success=success
153 self.success=success
142 self.failure = failure
154 self.failure = failure
143
155
144 def check(self, completed, failed=None):
156 def check(self, completed, failed=None):
145 """check whether our dependencies have been met."""
157 """check whether our dependencies have been met."""
146 if len(self) == 0:
158 if len(self) == 0:
147 return True
159 return True
148 against = set()
160 against = set()
149 if self.success:
161 if self.success:
150 against = completed
162 against = completed
151 if failed is not None and self.failure:
163 if failed is not None and self.failure:
152 against = against.union(failed)
164 against = against.union(failed)
153 if self.all:
165 if self.all:
154 return self.issubset(against)
166 return self.issubset(against)
155 else:
167 else:
156 return not self.isdisjoint(against)
168 return not self.isdisjoint(against)
157
169
158 def unreachable(self, completed, failed=None):
170 def unreachable(self, completed, failed=None):
159 """return whether this dependency has become impossible."""
171 """return whether this dependency has become impossible."""
160 if len(self) == 0:
172 if len(self) == 0:
161 return False
173 return False
162 against = set()
174 against = set()
163 if not self.success:
175 if not self.success:
164 against = completed
176 against = completed
165 if failed is not None and not self.failure:
177 if failed is not None and not self.failure:
166 against = against.union(failed)
178 against = against.union(failed)
167 if self.all:
179 if self.all:
168 return not self.isdisjoint(against)
180 return not self.isdisjoint(against)
169 else:
181 else:
170 return self.issubset(against)
182 return self.issubset(against)
171
183
172
184
173 def as_dict(self):
185 def as_dict(self):
174 """Represent this dependency as a dict. For json compatibility."""
186 """Represent this dependency as a dict. For json compatibility."""
175 return dict(
187 return dict(
176 dependencies=list(self),
188 dependencies=list(self),
177 all=self.all,
189 all=self.all,
178 success=self.success,
190 success=self.success,
179 failure=self.failure
191 failure=self.failure
180 )
192 )
181
193
182
194
183 __all__ = ['depend', 'require', 'dependent', 'Dependency']
195 __all__ = ['depend', 'require', 'dependent', 'Dependency']
184
196
@@ -1,971 +1,971 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
2 # encoding: utf-8
3 """
3 """
4 Facilities for launching IPython processes asynchronously.
4 Facilities for launching IPython processes asynchronously.
5 """
5 """
6
6
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2009 The IPython Development Team
8 # Copyright (C) 2008-2009 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import copy
18 import copy
19 import logging
19 import logging
20 import os
20 import os
21 import re
21 import re
22 import stat
22 import stat
23
23
from signal import SIGINT, SIGTERM
try:
    from signal import SIGKILL
except ImportError:
    # Windows has no SIGKILL; fall back to SIGTERM.
    SIGKILL = SIGTERM

from subprocess import Popen, PIPE, STDOUT
try:
    from subprocess import check_output
except ImportError:
    # Python < 2.7: provide a minimal check_output built on Popen.
    def check_output(*args, **kwargs):
        """Run a command and return its captured stdout."""
        kwargs['stdout'] = PIPE
        proc = Popen(*args, **kwargs)
        stdout, _ = proc.communicate()
        return stdout
40
40
41 from zmq.eventloop import ioloop
41 from zmq.eventloop import ioloop
42
42
43 from IPython.external import Itpl
43 from IPython.external import Itpl
44 # from IPython.config.configurable import Configurable
44 # from IPython.config.configurable import Configurable
45 from IPython.utils.traitlets import Any, Str, Int, List, Unicode, Dict, Instance, CUnicode
45 from IPython.utils.traitlets import Any, Str, Int, List, Unicode, Dict, Instance, CUnicode
46 from IPython.utils.path import get_ipython_module_path
46 from IPython.utils.path import get_ipython_module_path
47 from IPython.utils.process import find_cmd, pycmd2argv, FindCmdError
47 from IPython.utils.process import find_cmd, pycmd2argv, FindCmdError
48
48
49 from .factory import LoggingFactory
49 from .factory import LoggingFactory
50
50
51 # load winhpcjob from IPython.kernel
51 # load winhpcjob only on Windows
52 try:
52 try:
53 from IPython.kernel.winhpcjob import (
53 from .winhpcjob import (
54 IPControllerTask, IPEngineTask,
54 IPControllerTask, IPEngineTask,
55 IPControllerJob, IPEngineSetJob
55 IPControllerJob, IPEngineSetJob
56 )
56 )
57 except ImportError:
57 except ImportError:
58 pass
58 pass
59
59
60
60
61 #-----------------------------------------------------------------------------
61 #-----------------------------------------------------------------------------
62 # Paths to the kernel apps
62 # Paths to the kernel apps
63 #-----------------------------------------------------------------------------
63 #-----------------------------------------------------------------------------
64
64
65
65
# Absolute argv (interpreter + script path) for launching each of the
# parallel apps as a subprocess; built once at import time.

ipclusterz_cmd_argv = pycmd2argv(get_ipython_module_path(
    'IPython.zmq.parallel.ipclusterapp'
))

ipenginez_cmd_argv = pycmd2argv(get_ipython_module_path(
    'IPython.zmq.parallel.ipengineapp'
))

ipcontrollerz_cmd_argv = pycmd2argv(get_ipython_module_path(
    'IPython.zmq.parallel.ipcontrollerapp'
))
77
77
78 #-----------------------------------------------------------------------------
78 #-----------------------------------------------------------------------------
79 # Base launchers and errors
79 # Base launchers and errors
80 #-----------------------------------------------------------------------------
80 #-----------------------------------------------------------------------------
81
81
82
82
class LauncherError(Exception):
    """Base class for all launcher-related errors."""


class ProcessStateError(LauncherError):
    """Raised when an operation is invalid for the launcher's current state."""


class UnknownStatus(LauncherError):
    """Raised when a launcher reports a status that cannot be interpreted."""
93
93
94
94
class BaseLauncher(LoggingFactory):
    """An abstraction for starting, stopping and signaling a process."""

    # In all of the launchers, the work_dir is where child processes will be
    # run. This will usually be the cluster_dir, but may not be. Any work_dir
    # passed into the __init__ method will override the config value.
    # This should not be used to set the work_dir for the actual engine
    # and controller. Instead, use their own config files or the
    # controller_args, engine_args attributes of the launchers to add
    # the --work-dir option.
    work_dir = Unicode(u'.')
    loop = Instance('zmq.eventloop.ioloop.IOLoop')

    start_data = Any()
    stop_data = Any()

    def _loop_default(self):
        # Default to the process-wide IOLoop singleton.
        return ioloop.IOLoop.instance()

    def __init__(self, work_dir=u'.', config=None, **kwargs):
        super(BaseLauncher, self).__init__(work_dir=work_dir, config=config, **kwargs)
        # Lifecycle state: one of 'before', 'running', 'after'.
        self.state = 'before'
        self.stop_callbacks = []
        self.start_data = None
        self.stop_data = None

    @property
    def args(self):
        """A list of cmd and args that will be used to start the process.

        This is what is passed to :func:`spawnProcess` and the first element
        will be the process name.
        """
        return self.find_args()

    def find_args(self):
        """The ``.args`` property calls this to find the args list.

        Subclasses should implement this to construct the cmd and args.
        """
        raise NotImplementedError('find_args must be implemented in a subclass')

    @property
    def arg_str(self):
        """The string form of the program arguments."""
        return ' '.join(self.args)

    @property
    def running(self):
        """Whether the process is currently in the 'running' state."""
        return self.state == 'running'

    def start(self):
        """Start the process.

        This must return a deferred that fires with information about the
        process starting (like a pid, job id, etc.).
        """
        raise NotImplementedError('start must be implemented in a subclass')

    def stop(self):
        """Stop the process and notify observers of stopping.

        This must return a deferred that fires with information about the
        processing stopping, like errors that occur while the process is
        attempting to be shut down. This deferred won't fire when the process
        actually stops. To observe the actual process stopping, see
        :func:`observe_stop`.
        """
        raise NotImplementedError('stop must be implemented in a subclass')

    def on_stop(self, f):
        """Register *f* to be called with stop data when the process stops.

        If the process has already stopped, *f* fires immediately.
        """
        if self.state == 'after':
            return f(self.stop_data)
        self.stop_callbacks.append(f)

    def notify_start(self, data):
        """Call this to trigger startup actions.

        This logs the process startup and sets the state to 'running'. It is
        a pass-through so it can be used as a callback.
        """
        self.log.info('Process %r started: %r' % (self.args[0], data))
        self.start_data = data
        self.state = 'running'
        return data

    def notify_stop(self, data):
        """Call this to trigger process stop actions.

        This logs the process stopping and sets the state to 'after'. It
        fires every callback registered via :func:`on_stop`, exactly once.
        """
        self.log.info('Process %r stopped: %r' % (self.args[0], data))
        self.stop_data = data
        self.state = 'after'
        # Drain the callback list so each callback fires at most once.
        while self.stop_callbacks:
            callback = self.stop_callbacks.pop()
            callback(data)
        return data

    def signal(self, sig):
        """Signal the process.

        Return a semi-meaningless deferred after signaling the process.

        Parameters
        ----------
        sig : str or int
            'KILL', 'INT', etc., or any signal number
        """
        raise NotImplementedError('signal must be implemented in a subclass')
217
217
218
218
219 #-----------------------------------------------------------------------------
219 #-----------------------------------------------------------------------------
220 # Local process launchers
220 # Local process launchers
221 #-----------------------------------------------------------------------------
221 #-----------------------------------------------------------------------------
222
222
223
223
class LocalProcessLauncher(BaseLauncher):
    """Start and stop an external process in an asynchronous manner.

    This will launch the external process with a working directory of
    ``self.work_dir``.
    """

    # This is used to construct self.args, which is passed to spawnProcess.
    cmd_and_args = List([])
    poll_frequency = Int(100)  # in ms

    def __init__(self, work_dir=u'.', config=None, **kwargs):
        super(LocalProcessLauncher, self).__init__(
            work_dir=work_dir, config=config, **kwargs
        )
        self.process = None
        self.start_deferred = None
        self.poller = None

    def find_args(self):
        return self.cmd_and_args

    def start(self):
        """Spawn the child and wire its streams into the event loop."""
        if self.state != 'before':
            s = 'The process was already started and has state: %r' % self.state
            raise ProcessStateError(s)
        self.process = Popen(self.args,
            stdout=PIPE, stderr=PIPE, stdin=PIPE,
            env=os.environ,
            cwd=self.work_dir
        )
        # Watch the child's output streams and poll periodically for exit.
        self.loop.add_handler(self.process.stdout.fileno(), self.handle_stdout, self.loop.READ)
        self.loop.add_handler(self.process.stderr.fileno(), self.handle_stderr, self.loop.READ)
        self.poller = ioloop.PeriodicCallback(self.poll, self.poll_frequency, self.loop)
        self.poller.start()
        self.notify_start(self.process.pid)

    def stop(self):
        return self.interrupt_then_kill()

    def signal(self, sig):
        if self.state == 'running':
            self.process.send_signal(sig)

    def interrupt_then_kill(self, delay=2.0):
        """Send INT, wait a delay and then send KILL."""
        self.signal(SIGINT)
        self.killer = ioloop.DelayedCallback(lambda: self.signal(SIGKILL), delay * 1000, self.loop)
        self.killer.start()

    # callbacks, etc:

    def handle_stdout(self, fd, events):
        out_line = self.process.stdout.readline()
        # a stopped process will be readable but return empty strings
        if out_line:
            self.log.info(out_line[:-1])
        else:
            self.poll()

    def handle_stderr(self, fd, events):
        err_line = self.process.stderr.readline()
        # a stopped process will be readable but return empty strings
        if err_line:
            self.log.error(err_line[:-1])
        else:
            self.poll()

    def poll(self):
        """Check for child exit; on exit, detach handlers and notify."""
        returncode = self.process.poll()
        if returncode is not None:
            self.poller.stop()
            self.loop.remove_handler(self.process.stdout.fileno())
            self.loop.remove_handler(self.process.stderr.fileno())
            self.notify_stop(dict(exit_code=returncode, pid=self.process.pid))
        return returncode
303
303
class LocalControllerLauncher(LocalProcessLauncher):
    """Launch a controller as a regular external process."""

    # Full argv used to invoke the ipcontrollerz app.
    controller_cmd = List(ipcontrollerz_cmd_argv, config=True)
    # Command line arguments to ipcontroller.
    controller_args = List(['--log-to-file','--log-level', str(logging.INFO)], config=True)

    def find_args(self):
        # Final argv: the controller command plus its configured arguments.
        return self.controller_cmd + self.controller_args

    def start(self, cluster_dir):
        """Start the controller by cluster_dir."""
        # NOTE(review): extending controller_args in place means a second
        # start() would append --cluster-dir again; launchers appear to be
        # single-use — confirm.
        self.controller_args.extend(['--cluster-dir', cluster_dir])
        self.cluster_dir = unicode(cluster_dir)
        self.log.info("Starting LocalControllerLauncher: %r" % self.args)
        return super(LocalControllerLauncher, self).start()
320
320
321
321
class LocalEngineLauncher(LocalProcessLauncher):
    """Launch a single engine as a regular external process."""

    # Full argv used to invoke the ipenginez app.
    engine_cmd = List(ipenginez_cmd_argv, config=True)
    # Command line arguments for ipengine.
    engine_args = List(
        ['--log-to-file','--log-level', str(logging.INFO)], config=True
    )

    def find_args(self):
        # Final argv: the engine command plus its configured arguments.
        return self.engine_cmd + self.engine_args

    def start(self, cluster_dir):
        """Start the engine by cluster_dir."""
        # NOTE(review): extends engine_args in place, like the controller
        # launcher; assumed single-use.
        self.engine_args.extend(['--cluster-dir', cluster_dir])
        self.cluster_dir = unicode(cluster_dir)
        return super(LocalEngineLauncher, self).start()
339
339
340
340
class LocalEngineSetLauncher(BaseLauncher):
    """Launch a set of engines as regular external processes."""

    # Command line arguments for ipengine.
    engine_args = List(
        ['--log-to-file','--log-level', str(logging.INFO)], config=True
    )
    # The class used to launch each individual engine.
    launcher_class = LocalEngineLauncher

    launchers = Dict()
    stop_data = Dict()

    def __init__(self, work_dir=u'.', config=None, **kwargs):
        super(LocalEngineSetLauncher, self).__init__(
            work_dir=work_dir, config=config, **kwargs
        )
        self.stop_data = {}

    def start(self, n, cluster_dir):
        """Start n engines by profile or cluster_dir."""
        self.cluster_dir = unicode(cluster_dir)
        dlist = []
        for idx in range(n):
            launcher = self.launcher_class(work_dir=self.work_dir, config=self.config, logname=self.log.name)
            # Each engine launcher gets its own copy of the args list.
            launcher.engine_args = copy.deepcopy(self.engine_args)
            launcher.on_stop(self._notice_engine_stopped)
            d = launcher.start(cluster_dir)
            if idx == 0:
                self.log.info("Starting LocalEngineSetLauncher: %r" % launcher.args)
            self.launchers[idx] = launcher
            dlist.append(d)
        self.notify_start(dlist)
        return dlist

    def find_args(self):
        return ['engine set']

    def signal(self, sig):
        """Forward *sig* to every engine launcher."""
        return [el.signal(sig) for el in self.launchers.itervalues()]

    def interrupt_then_kill(self, delay=1.0):
        """Send INT then, after *delay* seconds, KILL to every engine."""
        return [el.interrupt_then_kill(delay) for el in self.launchers.itervalues()]

    def stop(self):
        return self.interrupt_then_kill()

    def _notice_engine_stopped(self, data):
        """Record a single engine's exit; notify once all engines are gone."""
        pid = data['pid']
        for idx, el in self.launchers.iteritems():
            if el.process.pid == pid:
                break
        self.launchers.pop(idx)
        self.stop_data[idx] = data
        if not self.launchers:
            self.notify_stop(self.stop_data)
411
411
412
412
413 #-----------------------------------------------------------------------------
413 #-----------------------------------------------------------------------------
414 # MPIExec launchers
414 # MPIExec launchers
415 #-----------------------------------------------------------------------------
415 #-----------------------------------------------------------------------------
416
416
417
417
class MPIExecLauncher(LocalProcessLauncher):
    """Launch an external process using mpiexec."""

    # The mpiexec command to use in starting the process.
    mpi_cmd = List(['mpiexec'], config=True)
    # The command line arguments to pass to mpiexec.
    mpi_args = List([], config=True)
    # The program to start using mpiexec.
    program = List(['date'], config=True)
    # The command line argument to the program.
    program_args = List([], config=True)
    # The number of instances of the program to start.
    n = Int(1, config=True)

    def find_args(self):
        """Build self.args using all the fields."""
        return (self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args
                + self.program + self.program_args)

    def start(self, n):
        """Start n instances of the program using mpiexec."""
        self.n = n
        return super(MPIExecLauncher, self).start()
441
441
442
442
class MPIExecControllerLauncher(MPIExecLauncher):
    """Launch a controller using mpiexec."""

    # Full argv used to invoke the ipcontrollerz app.
    controller_cmd = List(ipcontrollerz_cmd_argv, config=True)
    # Command line arguments to ipcontroller.
    controller_args = List(['--log-to-file','--log-level', str(logging.INFO)], config=True)
    # A controller is always exactly one process; not configurable.
    n = Int(1, config=False)

    def start(self, cluster_dir):
        """Start the controller by cluster_dir."""
        self.controller_args.extend(['--cluster-dir', cluster_dir])
        self.cluster_dir = unicode(cluster_dir)
        self.log.info("Starting MPIExecControllerLauncher: %r" % self.args)
        return super(MPIExecControllerLauncher, self).start(1)

    def find_args(self):
        """Build the full mpiexec argv for the controller.

        Fix: ``self.n`` is an int and must be stringified before joining
        into the argv — the previous ``['-n', self.n]`` broke
        ``arg_str``'s ``' '.join(self.args)`` and diverged from the base
        class, which already uses ``str(self.n)``.
        """
        return (self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args
                + self.controller_cmd + self.controller_args)
461
461
462
462
class MPIExecEngineSetLauncher(MPIExecLauncher):
    """Launch a set of engines using mpiexec."""

    # The engines are the mpiexec "program": n copies of ipenginez.
    program = List(ipenginez_cmd_argv, config=True)
    # Command line arguments for ipengine.
    program_args = List(
        ['--log-to-file','--log-level', str(logging.INFO)], config=True
    )
    # Number of engine instances to start.
    n = Int(1, config=True)

    def start(self, n, cluster_dir):
        """Start n engines by profile or cluster_dir."""
        # NOTE(review): extends program_args in place; assumed single-use.
        self.program_args.extend(['--cluster-dir', cluster_dir])
        self.cluster_dir = unicode(cluster_dir)
        self.n = n
        self.log.info('Starting MPIExecEngineSetLauncher: %r' % self.args)
        return super(MPIExecEngineSetLauncher, self).start(n)
479
479
480 #-----------------------------------------------------------------------------
480 #-----------------------------------------------------------------------------
481 # SSH launchers
481 # SSH launchers
482 #-----------------------------------------------------------------------------
482 #-----------------------------------------------------------------------------
483
483
484 # TODO: Get SSH Launcher working again.
484 # TODO: Get SSH Launcher working again.
485
485
class SSHLauncher(LocalProcessLauncher):
    """A minimal launcher for ssh.

    To be useful this will probably have to be extended to use the ``sshx``
    idea for environment variables.  There could be other things this needs
    as well.
    """

    ssh_cmd = List(['ssh'], config=True)
    ssh_args = List(['-tt'], config=True)
    program = List(['date'], config=True)
    program_args = List([], config=True)
    hostname = CUnicode('', config=True)
    user = CUnicode('', config=True)
    # derived "user@host" (or bare host) string used on the command line
    location = CUnicode('')

    def _hostname_changed(self, name, old, new):
        # keep `location` in sync whenever the hostname trait changes
        self.location = u'%s@%s' % (self.user, new) if self.user else new

    def _user_changed(self, name, old, new):
        # keep `location` in sync whenever the user trait changes
        self.location = u'%s@%s' % (new, self.hostname)

    def find_args(self):
        # full argv: ssh [args] user@host program [program_args]
        return (self.ssh_cmd + self.ssh_args + [self.location]
                + self.program + self.program_args)

    def start(self, cluster_dir, hostname=None, user=None):
        """Start the remote process, optionally overriding host/user."""
        self.cluster_dir = unicode(cluster_dir)
        if hostname is not None:
            self.hostname = hostname
        if user is not None:
            self.user = user

        return super(SSHLauncher, self).start()

    def signal(self, sig):
        # we cannot deliver a local signal to a remote process; instead
        # send the escaped ssh connection-closer sequence on stdin
        if self.state == 'running':
            self.process.stdin.write('~.')
            self.process.stdin.flush()
529
529
530
530
531
531
class SSHControllerLauncher(SSHLauncher):
    """Launch an ipcontroller over ssh."""

    program = List(ipcontrollerz_cmd_argv, config=True)
    # Command line arguments to ipcontroller.
    program_args = List(['-r', '--log-to-file','--log-level', str(logging.INFO)], config=True)
537
537
538
538
class SSHEngineLauncher(SSHLauncher):
    """Launch a single ipengine over ssh."""

    program = List(ipenginez_cmd_argv, config=True)
    # Command line arguments for ipengine.
    program_args = List(
        ['--log-to-file','--log-level', str(logging.INFO)], config=True
    )
545
545
class SSHEngineSetLauncher(LocalEngineSetLauncher):
    """Launch a set of engines over ssh.

    Hosts and per-host counts come from the `engines` config dict, which
    maps "[user@]host" to either an int n or an (n, args) pair.
    """
    launcher_class = SSHEngineLauncher
    engines = Dict(config=True)

    def start(self, n, cluster_dir):
        """Start engines by profile or cluster_dir.
        `n` is ignored, and the `engines` config property is used instead.
        """

        self.cluster_dir = unicode(cluster_dir)
        dlist = []
        for host, n in self.engines.iteritems():
            if isinstance(n, (tuple, list)):
                # entry is (count, per-host args)
                n, args = n
            else:
                args = copy.deepcopy(self.engine_args)

            if '@' in host:
                user, host = host.split('@', 1)
            else:
                user = None
            for i in range(n):
                el = self.launcher_class(work_dir=self.work_dir, config=self.config, logname=self.log.name)

                # Copy the engine args over to each engine launcher.
                # (removed a stray no-op `i` expression statement left here)
                el.program_args = args
                el.on_stop(self._notice_engine_stopped)
                d = el.start(cluster_dir, user=user, hostname=host)
                if i == 0:
                    self.log.info("Starting SSHEngineSetLauncher: %r" % el.args)
                self.launchers[host + str(i)] = el
                dlist.append(d)
        self.notify_start(dlist)
        return dlist
581
581
582
582
583
583
584 #-----------------------------------------------------------------------------
584 #-----------------------------------------------------------------------------
585 # Windows HPC Server 2008 scheduler launchers
585 # Windows HPC Server 2008 scheduler launchers
586 #-----------------------------------------------------------------------------
586 #-----------------------------------------------------------------------------
587
587
588
588
589 # This is only used on Windows.
589 # This is only used on Windows.
590 def find_job_cmd():
590 def find_job_cmd():
591 if os.name=='nt':
591 if os.name=='nt':
592 try:
592 try:
593 return find_cmd('job')
593 return find_cmd('job')
594 except FindCmdError:
594 except FindCmdError:
595 return 'job'
595 return 'job'
596 else:
596 else:
597 return 'job'
597 return 'job'
598
598
599
599
class WindowsHPCLauncher(BaseLauncher):
    """Base class for launchers that submit via the Windows HPC ``job`` command."""

    # A regular expression used to get the job id from the output of the
    # submit_command.
    job_id_regexp = Str(r'\d+', config=True)
    # The filename of the instantiated job script.
    job_file_name = CUnicode(u'ipython_job.xml', config=True)
    # NOTE(review): this trait is shadowed by the `job_file` property below,
    # so it is effectively dead; kept for backward compatibility.
    job_file = CUnicode(u'')
    # The hostname of the scheduler to submit the job to
    scheduler = CUnicode('', config=True)
    job_cmd = CUnicode(find_job_cmd(), config=True)

    def __init__(self, work_dir=u'.', config=None, **kwargs):
        super(WindowsHPCLauncher, self).__init__(
            work_dir=work_dir, config=config, **kwargs
        )

    @property
    def job_file(self):
        # The full path to the instantiated job script, made dynamically
        # by combining the work_dir with the job_file_name.
        return os.path.join(self.work_dir, self.job_file_name)

    def write_job_file(self, n):
        raise NotImplementedError("Implement write_job_file in a subclass.")

    def find_args(self):
        return [u'job.exe']

    def parse_job_id(self, output):
        """Take the output of the submit command and return the job id."""
        m = re.search(self.job_id_regexp, output)
        if m is None:
            raise LauncherError("Job id couldn't be determined: %s" % output)
        job_id = m.group()
        self.job_id = job_id
        self.log.info('Job started with job id: %r' % job_id)
        return job_id

    def start(self, n):
        """Start n copies of the process using the Win HPC job scheduler."""
        self.write_job_file(n)
        args = [
            'submit',
            '/jobfile:%s' % self.job_file,
            '/scheduler:%s' % self.scheduler
        ]
        self.log.info("Starting Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
        # Twisted will raise DeprecationWarnings if we try to pass unicode to this
        output = check_output([self.job_cmd] + args,
            env=os.environ,
            cwd=self.work_dir,
            stderr=STDOUT
        )
        job_id = self.parse_job_id(output)
        self.notify_start(job_id)
        return job_id

    def stop(self):
        """Cancel the submitted job; best-effort if it is already gone."""
        args = [
            'cancel',
            self.job_id,
            '/scheduler:%s' % self.scheduler
        ]
        self.log.info("Stopping Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
        try:
            output = check_output([self.job_cmd] + args,
                env=os.environ,
                cwd=self.work_dir,
                stderr=STDOUT
            )
        except:
            # fixed typo in message: "stoppped" -> "stopped"
            output = 'The job already appears to be stopped: %r' % self.job_id
        self.notify_stop(dict(job_id=self.job_id, output=output))  # Pass the output of the kill cmd
        return output
676
676
677
677
class WindowsHPCControllerLauncher(WindowsHPCLauncher):
    """Submit a single-controller job to the Win HPC scheduler."""

    job_file_name = CUnicode(u'ipcontroller_job.xml', config=True)
    extra_args = List([], config=False)

    def write_job_file(self, n):
        """Render the controller job XML to `job_file`."""
        job = IPControllerJob(config=self.config)

        t = IPControllerTask(config=self.config)
        # The tasks work directory is *not* the actual work directory of
        # the controller. It is used as the base path for the stdout/stderr
        # files that the scheduler redirects to.
        t.work_directory = self.cluster_dir
        # Add the --cluster-dir and from self.start().
        t.controller_args.extend(self.extra_args)
        job.add_task(t)

        self.log.info("Writing job description file: %s" % self.job_file)
        job.write(self.job_file)

    @property
    def job_file(self):
        # job file lives inside the cluster dir for the controller
        return os.path.join(self.cluster_dir, self.job_file_name)

    def start(self, cluster_dir):
        """Start the controller by cluster_dir."""
        self.extra_args = ['--cluster-dir', cluster_dir]
        self.cluster_dir = unicode(cluster_dir)
        return super(WindowsHPCControllerLauncher, self).start(1)
707
707
708
708
class WindowsHPCEngineSetLauncher(WindowsHPCLauncher):
    """Submit an n-engine job to the Win HPC scheduler."""

    job_file_name = CUnicode(u'ipengineset_job.xml', config=True)
    extra_args = List([], config=False)

    def write_job_file(self, n):
        """Render a job XML containing one task per engine."""
        job = IPEngineSetJob(config=self.config)

        for i in range(n):
            t = IPEngineTask(config=self.config)
            # The tasks work directory is *not* the actual work directory of
            # the engine. It is used as the base path for the stdout/stderr
            # files that the scheduler redirects to.
            t.work_directory = self.cluster_dir
            # Add the --cluster-dir and from self.start().
            t.engine_args.extend(self.extra_args)
            job.add_task(t)

        self.log.info("Writing job description file: %s" % self.job_file)
        job.write(self.job_file)

    @property
    def job_file(self):
        # job file lives inside the cluster dir for the engine set
        return os.path.join(self.cluster_dir, self.job_file_name)

    def start(self, n, cluster_dir):
        """Start the engines by cluster_dir."""
        self.extra_args = ['--cluster-dir', cluster_dir]
        self.cluster_dir = unicode(cluster_dir)
        return super(WindowsHPCEngineSetLauncher, self).start(n)
739
739
740
740
741 #-----------------------------------------------------------------------------
741 #-----------------------------------------------------------------------------
742 # Batch (PBS) system launchers
742 # Batch (PBS) system launchers
743 #-----------------------------------------------------------------------------
743 #-----------------------------------------------------------------------------
744
744
class BatchSystemLauncher(BaseLauncher):
    """Launch an external process using a batch system.

    This class is designed to work with UNIX batch systems like PBS, LSF,
    GridEngine, etc.  The overall model is that there are different commands
    like qsub, qdel, etc. that handle the starting and stopping of the process.

    This class also has the notion of a batch script. The ``batch_template``
    attribute can be set to a string that is a template for the batch script.
    This template is instantiated using Itpl. Thus the template can use
    ${n} fot the number of instances. Subclasses can add additional variables
    to the template dict.
    """

    # Subclasses must fill these in. See PBSEngineSet
    # The name of the command line program used to submit jobs.
    submit_command = List([''], config=True)
    # The name of the command line program used to delete jobs.
    delete_command = List([''], config=True)
    # A regular expression used to get the job id from the output of the
    # submit_command.
    job_id_regexp = CUnicode('', config=True)
    # The string that is the batch script template itself.
    batch_template = CUnicode('', config=True)
    # The file that contains the batch template
    batch_template_file = CUnicode(u'', config=True)
    # The filename of the instantiated batch script.
    batch_file_name = CUnicode(u'batch_script', config=True)
    # The PBS Queue
    queue = CUnicode(u'', config=True)

    # not configurable, override in subclasses
    # PBS Job Array regex
    job_array_regexp = CUnicode('')
    job_array_template = CUnicode('')
    # PBS Queue regex
    queue_regexp = CUnicode('')
    queue_template = CUnicode('')
    # The default batch template, override in subclasses
    default_template = CUnicode('')
    # The full path to the instantiated batch script.
    batch_file = CUnicode(u'')
    # the format dict used with batch_template:
    context = Dict()


    def find_args(self):
        return self.submit_command + [self.batch_file]

    def __init__(self, work_dir=u'.', config=None, **kwargs):
        super(BatchSystemLauncher, self).__init__(
            work_dir=work_dir, config=config, **kwargs
        )
        self.batch_file = os.path.join(self.work_dir, self.batch_file_name)

    def parse_job_id(self, output):
        """Take the output of the submit command and return the job id."""
        m = re.search(self.job_id_regexp, output)
        if m is not None:
            job_id = m.group()
        else:
            raise LauncherError("Job id couldn't be determined: %s" % output)
        self.job_id = job_id
        self.log.info('Job submitted with job id: %r' % job_id)
        return job_id

    def write_batch_script(self, n):
        """Instantiate and write the batch script to the work_dir."""
        self.context['n'] = n
        self.context['queue'] = self.queue
        # (removed leftover debug `print self.context`)
        # first priority is batch_template if set
        if self.batch_template_file and not self.batch_template:
            # second priority is batch_template_file
            with open(self.batch_template_file) as f:
                self.batch_template = f.read()
        if not self.batch_template:
            # third (last) priority is default_template
            self.batch_template = self.default_template

        # inject a job-array directive if the template lacks one
        regex = re.compile(self.job_array_regexp)
        if not regex.search(self.batch_template):
            self.log.info("adding job array settings to batch script")
            firstline, rest = self.batch_template.split('\n', 1)
            self.batch_template = u'\n'.join([firstline, self.job_array_template, rest])

        # inject a queue directive if one was requested but is missing
        regex = re.compile(self.queue_regexp)
        if self.queue and not regex.search(self.batch_template):
            self.log.info("adding PBS queue settings to batch script")
            firstline, rest = self.batch_template.split('\n', 1)
            self.batch_template = u'\n'.join([firstline, self.queue_template, rest])

        script_as_string = Itpl.itplns(self.batch_template, self.context)
        self.log.info('Writing instantiated batch script: %s' % self.batch_file)

        with open(self.batch_file, 'w') as f:
            f.write(script_as_string)
        os.chmod(self.batch_file, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)

    def start(self, n, cluster_dir):
        """Start n copies of the process using a batch system."""
        # Here we save profile and cluster_dir in the context so they
        # can be used in the batch script template as ${profile} and
        # ${cluster_dir}
        self.context['cluster_dir'] = cluster_dir
        self.cluster_dir = unicode(cluster_dir)
        self.write_batch_script(n)
        output = check_output(self.args, env=os.environ)

        job_id = self.parse_job_id(output)
        self.notify_start(job_id)
        return job_id

    def stop(self):
        output = check_output(self.delete_command + [self.job_id], env=os.environ)
        self.notify_stop(dict(job_id=self.job_id, output=output))  # Pass the output of the kill cmd
        return output
864
864
865
865
class PBSLauncher(BatchSystemLauncher):
    """A BatchSystemLauncher subclass for PBS."""

    submit_command = List(['qsub'], config=True)
    delete_command = List(['qdel'], config=True)
    job_id_regexp = CUnicode(r'\d+', config=True)

    batch_file = CUnicode(u'')
    # directives recognized / injected into PBS batch scripts
    job_array_regexp = CUnicode('#PBS\W+-t\W+[\w\d\-\$]+')
    job_array_template = CUnicode('#PBS -t 1-$n')
    queue_regexp = CUnicode('#PBS\W+-q\W+\$?\w+')
    queue_template = CUnicode('#PBS -q $queue')
878
878
879
879
class PBSControllerLauncher(PBSLauncher):
    """Launch a controller using PBS."""

    batch_file_name = CUnicode(u'pbs_controller', config=True)
    default_template= CUnicode("""#!/bin/sh
#PBS -V
#PBS -N ipcontrollerz
%s --log-to-file --cluster-dir $cluster_dir
"""%(' '.join(ipcontrollerz_cmd_argv)))

    def start(self, cluster_dir):
        """Start the controller by profile or cluster_dir."""
        self.log.info("Starting PBSControllerLauncher: %r" % self.args)
        # a controller is a one-process "set"
        return super(PBSControllerLauncher, self).start(1, cluster_dir)
894
894
895
895
class PBSEngineSetLauncher(PBSLauncher):
    """Launch Engines using PBS"""
    batch_file_name = CUnicode(u'pbs_engines', config=True)
    default_template= CUnicode(u"""#!/bin/sh
#PBS -V
#PBS -N ipenginez
%s --cluster-dir $cluster_dir
"""%(' '.join(ipenginez_cmd_argv)))

    def start(self, n, cluster_dir):
        """Start n engines by profile or cluster_dir."""
        # fixed format string: '%n' is not a valid conversion specifier
        # and raised ValueError at runtime; use '%i' for the count.
        self.log.info('Starting %i engines with PBSEngineSetLauncher: %r' % (n, self.args))
        return super(PBSEngineSetLauncher, self).start(n, cluster_dir)
909
909
910 #SGE is very similar to PBS
910 #SGE is very similar to PBS
911
911
class SGELauncher(PBSLauncher):
    """Sun GridEngine is a PBS clone with slightly different syntax"""
    # SGE directives use '#$' ('$$' escapes '$' for Itpl templating)
    job_array_regexp = CUnicode('#$$\W+-t\W+[\w\d\-\$]+')
    job_array_template = CUnicode('#$$ -t 1-$n')
    queue_regexp = CUnicode('#$$\W+-q\W+\$?\w+')
    queue_template = CUnicode('#$$ -q $queue')
918
918
class SGEControllerLauncher(SGELauncher):
    """Launch a controller using SGE."""

    batch_file_name = CUnicode(u'sge_controller', config=True)
    default_template= CUnicode(u"""#$$ -V
#$$ -S /bin/sh
#$$ -N ipcontrollerz
%s --log-to-file --cluster-dir $cluster_dir
"""%(' '.join(ipcontrollerz_cmd_argv)))

    def start(self, cluster_dir):
        """Start the controller by profile or cluster_dir."""
        # fixed copy-paste bug: super() was called with PBSControllerLauncher,
        # which is not in this class's MRO and raised TypeError; the log
        # message also named the wrong launcher.
        self.log.info("Starting SGEControllerLauncher: %r" % self.args)
        return super(SGEControllerLauncher, self).start(1, cluster_dir)
933
933
class SGEEngineSetLauncher(SGELauncher):
    """Launch Engines with SGE"""
    batch_file_name = CUnicode(u'sge_engines', config=True)
    default_template = CUnicode("""#$$ -V
#$$ -S /bin/sh
#$$ -N ipenginez
%s --cluster-dir $cluster_dir
"""%(' '.join(ipenginez_cmd_argv)))

    def start(self, n, cluster_dir):
        """Start n engines by profile or cluster_dir."""
        # fixed format string: '%n' is not a valid conversion specifier
        # and raised ValueError at runtime; use '%i' for the count.
        self.log.info('Starting %i engines with SGEEngineSetLauncher: %r' % (n, self.args))
        return super(SGEEngineSetLauncher, self).start(n, cluster_dir)
947
947
948
948
949 #-----------------------------------------------------------------------------
949 #-----------------------------------------------------------------------------
950 # A launcher for ipcluster itself!
950 # A launcher for ipcluster itself!
951 #-----------------------------------------------------------------------------
951 #-----------------------------------------------------------------------------
952
952
953
953
class IPClusterLauncher(LocalProcessLauncher):
    """Launch the ipcluster program in an external process."""

    ipcluster_cmd = List(ipclusterz_cmd_argv, config=True)
    # Command line arguments to pass to ipcluster.
    ipcluster_args = List(
        ['--clean-logs', '--log-to-file', '--log-level', str(logging.INFO)], config=True)
    ipcluster_subcommand = Str('start')
    ipcluster_n = Int(2)

    def find_args(self):
        # argv: ipcluster <subcommand> -n <n> <args...>
        return (self.ipcluster_cmd
                + [self.ipcluster_subcommand]
                + ['-n', repr(self.ipcluster_n)]
                + self.ipcluster_args)

    def start(self):
        self.log.info("Starting ipcluster: %r" % self.args)
        return super(IPClusterLauncher, self).start()
971
971
@@ -1,147 +1,138 b''
1 """Tests for parallel client.py"""
1 """Tests for parallel client.py"""
2
2
3 #-------------------------------------------------------------------------------
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
4 # Copyright (C) 2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14 import time
14 import time
15 from tempfile import mktemp
15 from tempfile import mktemp
16
16
17 import zmq
17 import zmq
18
18
19 from IPython.zmq.parallel import client as clientmod
19 from IPython.zmq.parallel import client as clientmod
20 from IPython.zmq.parallel import error
20 from IPython.zmq.parallel import error
21 from IPython.zmq.parallel.asyncresult import AsyncResult, AsyncHubResult
21 from IPython.zmq.parallel.asyncresult import AsyncResult, AsyncHubResult
22 from IPython.zmq.parallel.view import LoadBalancedView, DirectView
22 from IPython.zmq.parallel.view import LoadBalancedView, DirectView
23
23
24 from clienttest import ClusterTestCase, segfault, wait, add_engines
24 from clienttest import ClusterTestCase, segfault, wait, add_engines
25
25
26 def setup():
26 def setup():
27 add_engines(4)
27 add_engines(4)
28
28
29 class TestClient(ClusterTestCase):
29 class TestClient(ClusterTestCase):
30
30
31 def test_ids(self):
31 def test_ids(self):
32 n = len(self.client.ids)
32 n = len(self.client.ids)
33 self.add_engines(3)
33 self.add_engines(3)
34 self.assertEquals(len(self.client.ids), n+3)
34 self.assertEquals(len(self.client.ids), n+3)
35
35
36 def test_view_indexing(self):
36 def test_view_indexing(self):
37 """test index access for views"""
37 """test index access for views"""
38 self.add_engines(2)
38 self.add_engines(2)
39 targets = self.client._build_targets('all')[-1]
39 targets = self.client._build_targets('all')[-1]
40 v = self.client[:]
40 v = self.client[:]
41 self.assertEquals(v.targets, targets)
41 self.assertEquals(v.targets, targets)
42 t = self.client.ids[2]
42 t = self.client.ids[2]
43 v = self.client[t]
43 v = self.client[t]
44 self.assert_(isinstance(v, DirectView))
44 self.assert_(isinstance(v, DirectView))
45 self.assertEquals(v.targets, t)
45 self.assertEquals(v.targets, t)
46 t = self.client.ids[2:4]
46 t = self.client.ids[2:4]
47 v = self.client[t]
47 v = self.client[t]
48 self.assert_(isinstance(v, DirectView))
48 self.assert_(isinstance(v, DirectView))
49 self.assertEquals(v.targets, t)
49 self.assertEquals(v.targets, t)
50 v = self.client[::2]
50 v = self.client[::2]
51 self.assert_(isinstance(v, DirectView))
51 self.assert_(isinstance(v, DirectView))
52 self.assertEquals(v.targets, targets[::2])
52 self.assertEquals(v.targets, targets[::2])
53 v = self.client[1::3]
53 v = self.client[1::3]
54 self.assert_(isinstance(v, DirectView))
54 self.assert_(isinstance(v, DirectView))
55 self.assertEquals(v.targets, targets[1::3])
55 self.assertEquals(v.targets, targets[1::3])
56 v = self.client[:-3]
56 v = self.client[:-3]
57 self.assert_(isinstance(v, DirectView))
57 self.assert_(isinstance(v, DirectView))
58 self.assertEquals(v.targets, targets[:-3])
58 self.assertEquals(v.targets, targets[:-3])
59 v = self.client[-1]
59 v = self.client[-1]
60 self.assert_(isinstance(v, DirectView))
60 self.assert_(isinstance(v, DirectView))
61 self.assertEquals(v.targets, targets[-1])
61 self.assertEquals(v.targets, targets[-1])
62 self.assertRaises(TypeError, lambda : self.client[None])
62 self.assertRaises(TypeError, lambda : self.client[None])
63
63
64 def test_view_cache(self):
65 """test that multiple view requests return the same object"""
66 v = self.client[:2]
67 v2 =self.client[:2]
68 self.assertTrue(v is v2)
69 v = self.client.load_balanced_view()
70 v2 = self.client.load_balanced_view(targets=None)
71 self.assertTrue(v is v2)
72
73 def test_targets(self):
64 def test_targets(self):
74 """test various valid targets arguments"""
65 """test various valid targets arguments"""
75 build = self.client._build_targets
66 build = self.client._build_targets
76 ids = self.client.ids
67 ids = self.client.ids
77 idents,targets = build(None)
68 idents,targets = build(None)
78 self.assertEquals(ids, targets)
69 self.assertEquals(ids, targets)
79
70
80 def test_clear(self):
71 def test_clear(self):
81 """test clear behavior"""
72 """test clear behavior"""
82 # self.add_engines(2)
73 # self.add_engines(2)
83 v = self.client[:]
74 v = self.client[:]
84 v.block=True
75 v.block=True
85 v.push(dict(a=5))
76 v.push(dict(a=5))
86 v.pull('a')
77 v.pull('a')
87 id0 = self.client.ids[-1]
78 id0 = self.client.ids[-1]
88 self.client.clear(targets=id0)
79 self.client.clear(targets=id0)
89 self.client[:-1].pull('a')
80 self.client[:-1].pull('a')
90 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
81 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
91 self.client.clear(block=True)
82 self.client.clear(block=True)
92 for i in self.client.ids:
83 for i in self.client.ids:
93 # print i
84 # print i
94 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
85 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
95
86
96 def test_get_result(self):
87 def test_get_result(self):
97 """test getting results from the Hub."""
88 """test getting results from the Hub."""
98 c = clientmod.Client(profile='iptest')
89 c = clientmod.Client(profile='iptest')
99 # self.add_engines(1)
90 # self.add_engines(1)
100 t = c.ids[-1]
91 t = c.ids[-1]
101 ar = c[t].apply_async(wait, 1)
92 ar = c[t].apply_async(wait, 1)
102 # give the monitor time to notice the message
93 # give the monitor time to notice the message
103 time.sleep(.25)
94 time.sleep(.25)
104 ahr = self.client.get_result(ar.msg_ids)
95 ahr = self.client.get_result(ar.msg_ids)
105 self.assertTrue(isinstance(ahr, AsyncHubResult))
96 self.assertTrue(isinstance(ahr, AsyncHubResult))
106 self.assertEquals(ahr.get(), ar.get())
97 self.assertEquals(ahr.get(), ar.get())
107 ar2 = self.client.get_result(ar.msg_ids)
98 ar2 = self.client.get_result(ar.msg_ids)
108 self.assertFalse(isinstance(ar2, AsyncHubResult))
99 self.assertFalse(isinstance(ar2, AsyncHubResult))
109 c.close()
100 c.close()
110
101
111 def test_ids_list(self):
102 def test_ids_list(self):
112 """test client.ids"""
103 """test client.ids"""
113 # self.add_engines(2)
104 # self.add_engines(2)
114 ids = self.client.ids
105 ids = self.client.ids
115 self.assertEquals(ids, self.client._ids)
106 self.assertEquals(ids, self.client._ids)
116 self.assertFalse(ids is self.client._ids)
107 self.assertFalse(ids is self.client._ids)
117 ids.remove(ids[-1])
108 ids.remove(ids[-1])
118 self.assertNotEquals(ids, self.client._ids)
109 self.assertNotEquals(ids, self.client._ids)
119
110
120 def test_queue_status(self):
111 def test_queue_status(self):
121 # self.addEngine(4)
112 # self.addEngine(4)
122 ids = self.client.ids
113 ids = self.client.ids
123 id0 = ids[0]
114 id0 = ids[0]
124 qs = self.client.queue_status(targets=id0)
115 qs = self.client.queue_status(targets=id0)
125 self.assertTrue(isinstance(qs, dict))
116 self.assertTrue(isinstance(qs, dict))
126 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
117 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
127 allqs = self.client.queue_status()
118 allqs = self.client.queue_status()
128 self.assertTrue(isinstance(allqs, dict))
119 self.assertTrue(isinstance(allqs, dict))
129 self.assertEquals(sorted(allqs.keys()), self.client.ids)
120 self.assertEquals(sorted(allqs.keys()), self.client.ids)
130 for eid,qs in allqs.items():
121 for eid,qs in allqs.items():
131 self.assertTrue(isinstance(qs, dict))
122 self.assertTrue(isinstance(qs, dict))
132 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
123 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
133
124
134 def test_shutdown(self):
125 def test_shutdown(self):
135 # self.addEngine(4)
126 # self.addEngine(4)
136 ids = self.client.ids
127 ids = self.client.ids
137 id0 = ids[0]
128 id0 = ids[0]
138 self.client.shutdown(id0, block=True)
129 self.client.shutdown(id0, block=True)
139 while id0 in self.client.ids:
130 while id0 in self.client.ids:
140 time.sleep(0.1)
131 time.sleep(0.1)
141 self.client.spin()
132 self.client.spin()
142
133
143 self.assertRaises(IndexError, lambda : self.client[id0])
134 self.assertRaises(IndexError, lambda : self.client[id0])
144
135
145 def test_result_status(self):
136 def test_result_status(self):
146 pass
137 pass
147 # to be written
138 # to be written
@@ -1,287 +1,301 b''
1 """test View objects"""
1 """test View objects"""
2 #-------------------------------------------------------------------------------
2 #-------------------------------------------------------------------------------
3 # Copyright (C) 2011 The IPython Development Team
3 # Copyright (C) 2011 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-------------------------------------------------------------------------------
7 #-------------------------------------------------------------------------------
8
8
9 #-------------------------------------------------------------------------------
9 #-------------------------------------------------------------------------------
10 # Imports
10 # Imports
11 #-------------------------------------------------------------------------------
11 #-------------------------------------------------------------------------------
12
12
13 import time
13 import time
14 from tempfile import mktemp
14 from tempfile import mktemp
15
15
16 import zmq
16 import zmq
17
17
18 from IPython.zmq.parallel import client as clientmod
18 from IPython.zmq.parallel import client as clientmod
19 from IPython.zmq.parallel import error
19 from IPython.zmq.parallel import error
20 from IPython.zmq.parallel.asyncresult import AsyncResult, AsyncHubResult, AsyncMapResult
20 from IPython.zmq.parallel.asyncresult import AsyncResult, AsyncHubResult, AsyncMapResult
21 from IPython.zmq.parallel.view import LoadBalancedView, DirectView
21 from IPython.zmq.parallel.view import LoadBalancedView, DirectView
22 from IPython.zmq.parallel.util import interactive
22 from IPython.zmq.parallel.util import interactive
23
23
24 from IPython.zmq.parallel.tests import add_engines
24 from IPython.zmq.parallel.tests import add_engines
25
25
26 from .clienttest import ClusterTestCase, segfault, wait, skip_without
26 from .clienttest import ClusterTestCase, segfault, wait, skip_without
27
27
28 def setup():
28 def setup():
29 add_engines(3)
29 add_engines(3)
30
30
31 class TestView(ClusterTestCase):
31 class TestView(ClusterTestCase):
32
32
33 def test_segfault_task(self):
33 def test_segfault_task(self):
34 """test graceful handling of engine death (balanced)"""
34 """test graceful handling of engine death (balanced)"""
35 # self.add_engines(1)
35 # self.add_engines(1)
36 ar = self.client[-1].apply_async(segfault)
36 ar = self.client[-1].apply_async(segfault)
37 self.assertRaisesRemote(error.EngineError, ar.get)
37 self.assertRaisesRemote(error.EngineError, ar.get)
38 eid = ar.engine_id
38 eid = ar.engine_id
39 while eid in self.client.ids:
39 while eid in self.client.ids:
40 time.sleep(.01)
40 time.sleep(.01)
41 self.client.spin()
41 self.client.spin()
42
42
43 def test_segfault_mux(self):
43 def test_segfault_mux(self):
44 """test graceful handling of engine death (direct)"""
44 """test graceful handling of engine death (direct)"""
45 # self.add_engines(1)
45 # self.add_engines(1)
46 eid = self.client.ids[-1]
46 eid = self.client.ids[-1]
47 ar = self.client[eid].apply_async(segfault)
47 ar = self.client[eid].apply_async(segfault)
48 self.assertRaisesRemote(error.EngineError, ar.get)
48 self.assertRaisesRemote(error.EngineError, ar.get)
49 eid = ar.engine_id
49 eid = ar.engine_id
50 while eid in self.client.ids:
50 while eid in self.client.ids:
51 time.sleep(.01)
51 time.sleep(.01)
52 self.client.spin()
52 self.client.spin()
53
53
54 def test_push_pull(self):
54 def test_push_pull(self):
55 """test pushing and pulling"""
55 """test pushing and pulling"""
56 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
56 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
57 t = self.client.ids[-1]
57 t = self.client.ids[-1]
58 v = self.client[t]
58 v = self.client[t]
59 push = v.push
59 push = v.push
60 pull = v.pull
60 pull = v.pull
61 v.block=True
61 v.block=True
62 nengines = len(self.client)
62 nengines = len(self.client)
63 push({'data':data})
63 push({'data':data})
64 d = pull('data')
64 d = pull('data')
65 self.assertEquals(d, data)
65 self.assertEquals(d, data)
66 self.client[:].push({'data':data})
66 self.client[:].push({'data':data})
67 d = self.client[:].pull('data', block=True)
67 d = self.client[:].pull('data', block=True)
68 self.assertEquals(d, nengines*[data])
68 self.assertEquals(d, nengines*[data])
69 ar = push({'data':data}, block=False)
69 ar = push({'data':data}, block=False)
70 self.assertTrue(isinstance(ar, AsyncResult))
70 self.assertTrue(isinstance(ar, AsyncResult))
71 r = ar.get()
71 r = ar.get()
72 ar = self.client[:].pull('data', block=False)
72 ar = self.client[:].pull('data', block=False)
73 self.assertTrue(isinstance(ar, AsyncResult))
73 self.assertTrue(isinstance(ar, AsyncResult))
74 r = ar.get()
74 r = ar.get()
75 self.assertEquals(r, nengines*[data])
75 self.assertEquals(r, nengines*[data])
76 self.client[:].push(dict(a=10,b=20))
76 self.client[:].push(dict(a=10,b=20))
77 r = self.client[:].pull(('a','b'))
77 r = self.client[:].pull(('a','b'))
78 self.assertEquals(r, nengines*[[10,20]])
78 self.assertEquals(r, nengines*[[10,20]])
79
79
80 def test_push_pull_function(self):
80 def test_push_pull_function(self):
81 "test pushing and pulling functions"
81 "test pushing and pulling functions"
82 def testf(x):
82 def testf(x):
83 return 2.0*x
83 return 2.0*x
84
84
85 t = self.client.ids[-1]
85 t = self.client.ids[-1]
86 self.client[t].block=True
86 self.client[t].block=True
87 push = self.client[t].push
87 push = self.client[t].push
88 pull = self.client[t].pull
88 pull = self.client[t].pull
89 execute = self.client[t].execute
89 execute = self.client[t].execute
90 push({'testf':testf})
90 push({'testf':testf})
91 r = pull('testf')
91 r = pull('testf')
92 self.assertEqual(r(1.0), testf(1.0))
92 self.assertEqual(r(1.0), testf(1.0))
93 execute('r = testf(10)')
93 execute('r = testf(10)')
94 r = pull('r')
94 r = pull('r')
95 self.assertEquals(r, testf(10))
95 self.assertEquals(r, testf(10))
96 ar = self.client[:].push({'testf':testf}, block=False)
96 ar = self.client[:].push({'testf':testf}, block=False)
97 ar.get()
97 ar.get()
98 ar = self.client[:].pull('testf', block=False)
98 ar = self.client[:].pull('testf', block=False)
99 rlist = ar.get()
99 rlist = ar.get()
100 for r in rlist:
100 for r in rlist:
101 self.assertEqual(r(1.0), testf(1.0))
101 self.assertEqual(r(1.0), testf(1.0))
102 execute("def g(x): return x*x")
102 execute("def g(x): return x*x")
103 r = pull(('testf','g'))
103 r = pull(('testf','g'))
104 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
104 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
105
105
106 def test_push_function_globals(self):
106 def test_push_function_globals(self):
107 """test that pushed functions have access to globals"""
107 """test that pushed functions have access to globals"""
108 @interactive
108 @interactive
109 def geta():
109 def geta():
110 return a
110 return a
111 # self.add_engines(1)
111 # self.add_engines(1)
112 v = self.client[-1]
112 v = self.client[-1]
113 v.block=True
113 v.block=True
114 v['f'] = geta
114 v['f'] = geta
115 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
115 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
116 v.execute('a=5')
116 v.execute('a=5')
117 v.execute('b=f()')
117 v.execute('b=f()')
118 self.assertEquals(v['b'], 5)
118 self.assertEquals(v['b'], 5)
119
119
120 def test_push_function_defaults(self):
120 def test_push_function_defaults(self):
121 """test that pushed functions preserve default args"""
121 """test that pushed functions preserve default args"""
122 def echo(a=10):
122 def echo(a=10):
123 return a
123 return a
124 v = self.client[-1]
124 v = self.client[-1]
125 v.block=True
125 v.block=True
126 v['f'] = echo
126 v['f'] = echo
127 v.execute('b=f()')
127 v.execute('b=f()')
128 self.assertEquals(v['b'], 10)
128 self.assertEquals(v['b'], 10)
129
129
130 def test_get_result(self):
130 def test_get_result(self):
131 """test getting results from the Hub."""
131 """test getting results from the Hub."""
132 c = clientmod.Client(profile='iptest')
132 c = clientmod.Client(profile='iptest')
133 # self.add_engines(1)
133 # self.add_engines(1)
134 t = c.ids[-1]
134 t = c.ids[-1]
135 v = c[t]
135 v = c[t]
136 v2 = self.client[t]
136 v2 = self.client[t]
137 ar = v.apply_async(wait, 1)
137 ar = v.apply_async(wait, 1)
138 # give the monitor time to notice the message
138 # give the monitor time to notice the message
139 time.sleep(.25)
139 time.sleep(.25)
140 ahr = v2.get_result(ar.msg_ids)
140 ahr = v2.get_result(ar.msg_ids)
141 self.assertTrue(isinstance(ahr, AsyncHubResult))
141 self.assertTrue(isinstance(ahr, AsyncHubResult))
142 self.assertEquals(ahr.get(), ar.get())
142 self.assertEquals(ahr.get(), ar.get())
143 ar2 = v2.get_result(ar.msg_ids)
143 ar2 = v2.get_result(ar.msg_ids)
144 self.assertFalse(isinstance(ar2, AsyncHubResult))
144 self.assertFalse(isinstance(ar2, AsyncHubResult))
145 c.spin()
145 c.spin()
146 c.close()
146 c.close()
147
147
148 def test_run_newline(self):
148 def test_run_newline(self):
149 """test that run appends newline to files"""
149 """test that run appends newline to files"""
150 tmpfile = mktemp()
150 tmpfile = mktemp()
151 with open(tmpfile, 'w') as f:
151 with open(tmpfile, 'w') as f:
152 f.write("""def g():
152 f.write("""def g():
153 return 5
153 return 5
154 """)
154 """)
155 v = self.client[-1]
155 v = self.client[-1]
156 v.run(tmpfile, block=True)
156 v.run(tmpfile, block=True)
157 self.assertEquals(v.apply_sync(lambda f: f(), clientmod.Reference('g')), 5)
157 self.assertEquals(v.apply_sync(lambda f: f(), clientmod.Reference('g')), 5)
158
158
159 def test_apply_tracked(self):
159 def test_apply_tracked(self):
160 """test tracking for apply"""
160 """test tracking for apply"""
161 # self.add_engines(1)
161 # self.add_engines(1)
162 t = self.client.ids[-1]
162 t = self.client.ids[-1]
163 v = self.client[t]
163 v = self.client[t]
164 v.block=False
164 v.block=False
165 def echo(n=1024*1024, **kwargs):
165 def echo(n=1024*1024, **kwargs):
166 with v.temp_flags(**kwargs):
166 with v.temp_flags(**kwargs):
167 return v.apply(lambda x: x, 'x'*n)
167 return v.apply(lambda x: x, 'x'*n)
168 ar = echo(1, track=False)
168 ar = echo(1, track=False)
169 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
169 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
170 self.assertTrue(ar.sent)
170 self.assertTrue(ar.sent)
171 ar = echo(track=True)
171 ar = echo(track=True)
172 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
172 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
173 self.assertEquals(ar.sent, ar._tracker.done)
173 self.assertEquals(ar.sent, ar._tracker.done)
174 ar._tracker.wait()
174 ar._tracker.wait()
175 self.assertTrue(ar.sent)
175 self.assertTrue(ar.sent)
176
176
177 def test_push_tracked(self):
177 def test_push_tracked(self):
178 t = self.client.ids[-1]
178 t = self.client.ids[-1]
179 ns = dict(x='x'*1024*1024)
179 ns = dict(x='x'*1024*1024)
180 v = self.client[t]
180 v = self.client[t]
181 ar = v.push(ns, block=False, track=False)
181 ar = v.push(ns, block=False, track=False)
182 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
182 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
183 self.assertTrue(ar.sent)
183 self.assertTrue(ar.sent)
184
184
185 ar = v.push(ns, block=False, track=True)
185 ar = v.push(ns, block=False, track=True)
186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertEquals(ar.sent, ar._tracker.done)
187 self.assertEquals(ar.sent, ar._tracker.done)
188 ar._tracker.wait()
188 ar._tracker.wait()
189 self.assertTrue(ar.sent)
189 self.assertTrue(ar.sent)
190 ar.get()
190 ar.get()
191
191
192 def test_scatter_tracked(self):
192 def test_scatter_tracked(self):
193 t = self.client.ids
193 t = self.client.ids
194 x='x'*1024*1024
194 x='x'*1024*1024
195 ar = self.client[t].scatter('x', x, block=False, track=False)
195 ar = self.client[t].scatter('x', x, block=False, track=False)
196 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
196 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
197 self.assertTrue(ar.sent)
197 self.assertTrue(ar.sent)
198
198
199 ar = self.client[t].scatter('x', x, block=False, track=True)
199 ar = self.client[t].scatter('x', x, block=False, track=True)
200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 self.assertEquals(ar.sent, ar._tracker.done)
201 self.assertEquals(ar.sent, ar._tracker.done)
202 ar._tracker.wait()
202 ar._tracker.wait()
203 self.assertTrue(ar.sent)
203 self.assertTrue(ar.sent)
204 ar.get()
204 ar.get()
205
205
206 def test_remote_reference(self):
206 def test_remote_reference(self):
207 v = self.client[-1]
207 v = self.client[-1]
208 v['a'] = 123
208 v['a'] = 123
209 ra = clientmod.Reference('a')
209 ra = clientmod.Reference('a')
210 b = v.apply_sync(lambda x: x, ra)
210 b = v.apply_sync(lambda x: x, ra)
211 self.assertEquals(b, 123)
211 self.assertEquals(b, 123)
212
212
213
213
214 def test_scatter_gather(self):
214 def test_scatter_gather(self):
215 view = self.client[:]
215 view = self.client[:]
216 seq1 = range(16)
216 seq1 = range(16)
217 view.scatter('a', seq1)
217 view.scatter('a', seq1)
218 seq2 = view.gather('a', block=True)
218 seq2 = view.gather('a', block=True)
219 self.assertEquals(seq2, seq1)
219 self.assertEquals(seq2, seq1)
220 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
220 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
221
221
222 @skip_without('numpy')
222 @skip_without('numpy')
223 def test_scatter_gather_numpy(self):
223 def test_scatter_gather_numpy(self):
224 import numpy
224 import numpy
225 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
225 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
226 view = self.client[:]
226 view = self.client[:]
227 a = numpy.arange(64)
227 a = numpy.arange(64)
228 view.scatter('a', a)
228 view.scatter('a', a)
229 b = view.gather('a', block=True)
229 b = view.gather('a', block=True)
230 assert_array_equal(b, a)
230 assert_array_equal(b, a)
231
231
232 def test_map(self):
232 def test_map(self):
233 view = self.client[:]
233 view = self.client[:]
234 def f(x):
234 def f(x):
235 return x**2
235 return x**2
236 data = range(16)
236 data = range(16)
237 r = view.map_sync(f, data)
237 r = view.map_sync(f, data)
238 self.assertEquals(r, map(f, data))
238 self.assertEquals(r, map(f, data))
239
239
240 def test_scatterGatherNonblocking(self):
240 def test_scatterGatherNonblocking(self):
241 data = range(16)
241 data = range(16)
242 view = self.client[:]
242 view = self.client[:]
243 view.scatter('a', data, block=False)
243 view.scatter('a', data, block=False)
244 ar = view.gather('a', block=False)
244 ar = view.gather('a', block=False)
245 self.assertEquals(ar.get(), data)
245 self.assertEquals(ar.get(), data)
246
246
247 @skip_without('numpy')
247 @skip_without('numpy')
248 def test_scatter_gather_numpy_nonblocking(self):
248 def test_scatter_gather_numpy_nonblocking(self):
249 import numpy
249 import numpy
250 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
250 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
251 a = numpy.arange(64)
251 a = numpy.arange(64)
252 view = self.client[:]
252 view = self.client[:]
253 ar = view.scatter('a', a, block=False)
253 ar = view.scatter('a', a, block=False)
254 self.assertTrue(isinstance(ar, AsyncResult))
254 self.assertTrue(isinstance(ar, AsyncResult))
255 amr = view.gather('a', block=False)
255 amr = view.gather('a', block=False)
256 self.assertTrue(isinstance(amr, AsyncMapResult))
256 self.assertTrue(isinstance(amr, AsyncMapResult))
257 assert_array_equal(amr.get(), a)
257 assert_array_equal(amr.get(), a)
258
258
259 def test_execute(self):
259 def test_execute(self):
260 view = self.client[:]
260 view = self.client[:]
261 # self.client.debug=True
261 # self.client.debug=True
262 execute = view.execute
262 execute = view.execute
263 ar = execute('c=30', block=False)
263 ar = execute('c=30', block=False)
264 self.assertTrue(isinstance(ar, AsyncResult))
264 self.assertTrue(isinstance(ar, AsyncResult))
265 ar = execute('d=[0,1,2]', block=False)
265 ar = execute('d=[0,1,2]', block=False)
266 self.client.wait(ar, 1)
266 self.client.wait(ar, 1)
267 self.assertEquals(len(ar.get()), len(self.client))
267 self.assertEquals(len(ar.get()), len(self.client))
268 for c in view['c']:
268 for c in view['c']:
269 self.assertEquals(c, 30)
269 self.assertEquals(c, 30)
270
270
271 def test_abort(self):
271 def test_abort(self):
272 view = self.client[-1]
272 view = self.client[-1]
273 ar = view.execute('import time; time.sleep(0.25)', block=False)
273 ar = view.execute('import time; time.sleep(0.25)', block=False)
274 ar2 = view.apply_async(lambda : 2)
274 ar2 = view.apply_async(lambda : 2)
275 ar3 = view.apply_async(lambda : 3)
275 ar3 = view.apply_async(lambda : 3)
276 view.abort(ar2)
276 view.abort(ar2)
277 view.abort(ar3.msg_ids)
277 view.abort(ar3.msg_ids)
278 self.assertRaises(error.TaskAborted, ar2.get)
278 self.assertRaises(error.TaskAborted, ar2.get)
279 self.assertRaises(error.TaskAborted, ar3.get)
279 self.assertRaises(error.TaskAborted, ar3.get)
280
280
281 def test_temp_flags(self):
281 def test_temp_flags(self):
282 view = self.client[-1]
282 view = self.client[-1]
283 view.block=True
283 view.block=True
284 with view.temp_flags(block=False):
284 with view.temp_flags(block=False):
285 self.assertFalse(view.block)
285 self.assertFalse(view.block)
286 self.assertTrue(view.block)
286 self.assertTrue(view.block)
287
287
288 def test_importer(self):
289 view = self.client[-1]
290 view.clear(block=True)
291 with view.importer:
292 import re
293
294 @interactive
295 def findall(pat, s):
296 # this globals() step isn't necessary in real code
297 # only to prevent a closure in the test
298 return globals()['re'].findall(pat, s)
299
300 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
301
@@ -1,920 +1,1028 b''
1 """Views of remote engines."""
1 """Views of remote engines."""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
3 # Copyright (C) 2010 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 import imp
14 import sys
13 import warnings
15 import warnings
14 from contextlib import contextmanager
16 from contextlib import contextmanager
17 from types import ModuleType
15
18
16 import zmq
19 import zmq
17
20
18 from IPython.testing import decorators as testdec
21 from IPython.testing import decorators as testdec
19 from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance
22 from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat
20
23
21 from IPython.external.decorator import decorator
24 from IPython.external.decorator import decorator
22
25
23 from . import map as Map
26 from . import map as Map
24 from . import util
27 from . import util
25 from .asyncresult import AsyncResult, AsyncMapResult
28 from .asyncresult import AsyncResult, AsyncMapResult
26 from .dependency import Dependency, dependent
29 from .dependency import Dependency, dependent
27 from .remotefunction import ParallelFunction, parallel, remote
30 from .remotefunction import ParallelFunction, parallel, remote
28
31
29 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
30 # Decorators
33 # Decorators
31 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
32
35
33 @decorator
36 @decorator
34 def save_ids(f, self, *args, **kwargs):
37 def save_ids(f, self, *args, **kwargs):
35 """Keep our history and outstanding attributes up to date after a method call."""
38 """Keep our history and outstanding attributes up to date after a method call."""
36 n_previous = len(self.client.history)
39 n_previous = len(self.client.history)
37 try:
40 try:
38 ret = f(self, *args, **kwargs)
41 ret = f(self, *args, **kwargs)
39 finally:
42 finally:
40 nmsgs = len(self.client.history) - n_previous
43 nmsgs = len(self.client.history) - n_previous
41 msg_ids = self.client.history[-nmsgs:]
44 msg_ids = self.client.history[-nmsgs:]
42 self.history.extend(msg_ids)
45 self.history.extend(msg_ids)
43 map(self.outstanding.add, msg_ids)
46 map(self.outstanding.add, msg_ids)
44 return ret
47 return ret
45
48
46 @decorator
49 @decorator
47 def sync_results(f, self, *args, **kwargs):
50 def sync_results(f, self, *args, **kwargs):
48 """sync relevant results from self.client to our results attribute."""
51 """sync relevant results from self.client to our results attribute."""
49 ret = f(self, *args, **kwargs)
52 ret = f(self, *args, **kwargs)
50 delta = self.outstanding.difference(self.client.outstanding)
53 delta = self.outstanding.difference(self.client.outstanding)
51 completed = self.outstanding.intersection(delta)
54 completed = self.outstanding.intersection(delta)
52 self.outstanding = self.outstanding.difference(completed)
55 self.outstanding = self.outstanding.difference(completed)
53 for msg_id in completed:
56 for msg_id in completed:
54 self.results[msg_id] = self.client.results[msg_id]
57 self.results[msg_id] = self.client.results[msg_id]
55 return ret
58 return ret
56
59
57 @decorator
60 @decorator
58 def spin_after(f, self, *args, **kwargs):
61 def spin_after(f, self, *args, **kwargs):
59 """call spin after the method."""
62 """call spin after the method."""
60 ret = f(self, *args, **kwargs)
63 ret = f(self, *args, **kwargs)
61 self.spin()
64 self.spin()
62 return ret
65 return ret
63
66
64 #-----------------------------------------------------------------------------
67 #-----------------------------------------------------------------------------
65 # Classes
68 # Classes
66 #-----------------------------------------------------------------------------
69 #-----------------------------------------------------------------------------
67
70
68 class View(HasTraits):
71 class View(HasTraits):
69 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
72 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
70
73
71 Don't use this class, use subclasses.
74 Don't use this class, use subclasses.
72
75
73 Methods
76 Methods
74 -------
77 -------
75
78
76 spin
79 spin
77 flushes incoming results and registration state changes
80 flushes incoming results and registration state changes
78 control methods spin, and requesting `ids` also ensures up to date
81 control methods spin, and requesting `ids` also ensures up to date
79
82
80 wait
83 wait
81 wait on one or more msg_ids
84 wait on one or more msg_ids
82
85
83 execution methods
86 execution methods
84 apply
87 apply
85 legacy: execute, run
88 legacy: execute, run
86
89
87 data movement
90 data movement
88 push, pull, scatter, gather
91 push, pull, scatter, gather
89
92
90 query methods
93 query methods
91 get_result, queue_status, purge_results, result_status
94 get_result, queue_status, purge_results, result_status
92
95
93 control methods
96 control methods
94 abort, shutdown
97 abort, shutdown
95
98
96 """
99 """
100 # flags
97 block=Bool(False)
101 block=Bool(False)
98 track=Bool(True)
102 track=Bool(True)
103 targets = Any()
104
99 history=List()
105 history=List()
100 outstanding = Set()
106 outstanding = Set()
101 results = Dict()
107 results = Dict()
102 client = Instance('IPython.zmq.parallel.client.Client')
108 client = Instance('IPython.zmq.parallel.client.Client')
103
109
104 _socket = Instance('zmq.Socket')
110 _socket = Instance('zmq.Socket')
105 _ntargets = Int(1)
111 _flag_names = List(['targets', 'block', 'track'])
106 _flag_names = List(['block', 'track'])
107 _targets = Any()
112 _targets = Any()
108 _idents = Any()
113 _idents = Any()
109
114
110 def __init__(self, client=None, socket=None, targets=None):
115 def __init__(self, client=None, socket=None, **flags):
111 super(View, self).__init__(client=client, _socket=socket)
116 super(View, self).__init__(client=client, _socket=socket)
112 self._ntargets = 1 if isinstance(targets, (int,type(None))) else len(targets)
113 self.block = client.block
117 self.block = client.block
114
118
115 self._idents, self._targets = self.client._build_targets(targets)
119 self.set_flags(**flags)
116 if targets is None or isinstance(targets, int):
117 self._targets = targets
118 for name in self._flag_names:
119 # set flags, if they haven't been set yet
120 setattr(self, name, getattr(self, name, None))
121
120
122 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
121 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
123
122
124
123
125 def __repr__(self):
124 def __repr__(self):
126 strtargets = str(self._targets)
125 strtargets = str(self.targets)
127 if len(strtargets) > 16:
126 if len(strtargets) > 16:
128 strtargets = strtargets[:12]+'...]'
127 strtargets = strtargets[:12]+'...]'
129 return "<%s %s>"%(self.__class__.__name__, strtargets)
128 return "<%s %s>"%(self.__class__.__name__, strtargets)
130
129
131 @property
132 def targets(self):
133 return self._targets
134
135 @targets.setter
136 def targets(self, value):
137 raise AttributeError("Cannot set View `targets` after construction!")
138
139 def set_flags(self, **kwargs):
130 def set_flags(self, **kwargs):
140 """set my attribute flags by keyword.
131 """set my attribute flags by keyword.
141
132
142 Views determine behavior with a few attributes (`block`, `track`, etc.).
133 Views determine behavior with a few attributes (`block`, `track`, etc.).
143 These attributes can be set all at once by name with this method.
134 These attributes can be set all at once by name with this method.
144
135
145 Parameters
136 Parameters
146 ----------
137 ----------
147
138
148 block : bool
139 block : bool
149 whether to wait for results
140 whether to wait for results
150 track : bool
141 track : bool
151 whether to create a MessageTracker to allow the user to
142 whether to create a MessageTracker to allow the user to
152 safely edit after arrays and buffers during non-copying
143 safely edit after arrays and buffers during non-copying
153 sends.
144 sends.
154 """
145 """
155 for name, value in kwargs.iteritems():
146 for name, value in kwargs.iteritems():
156 if name not in self._flag_names:
147 if name not in self._flag_names:
157 raise KeyError("Invalid name: %r"%name)
148 raise KeyError("Invalid name: %r"%name)
158 else:
149 else:
159 setattr(self, name, value)
150 setattr(self, name, value)
160
151
161 @contextmanager
152 @contextmanager
162 def temp_flags(self, **kwargs):
153 def temp_flags(self, **kwargs):
163 """temporarily set flags, for use in `with` statements.
154 """temporarily set flags, for use in `with` statements.
164
155
165 See set_flags for permanent setting of flags
156 See set_flags for permanent setting of flags
166
157
167 Examples
158 Examples
168 --------
159 --------
169
160
170 >>> view.track=False
161 >>> view.track=False
171 ...
162 ...
172 >>> with view.temp_flags(track=True):
163 >>> with view.temp_flags(track=True):
173 ... ar = view.apply(dostuff, my_big_array)
164 ... ar = view.apply(dostuff, my_big_array)
174 ... ar.tracker.wait() # wait for send to finish
165 ... ar.tracker.wait() # wait for send to finish
175 >>> view.track
166 >>> view.track
176 False
167 False
177
168
178 """
169 """
179 # preflight: save flags, and set temporaries
170 # preflight: save flags, and set temporaries
180 saved_flags = {}
171 saved_flags = {}
181 for f in self._flag_names:
172 for f in self._flag_names:
182 saved_flags[f] = getattr(self, f)
173 saved_flags[f] = getattr(self, f)
183 self.set_flags(**kwargs)
174 self.set_flags(**kwargs)
184 # yield to the with-statement block
175 # yield to the with-statement block
185 yield
176 try:
186 # postflight: restore saved flags
177 yield
187 self.set_flags(**saved_flags)
178 finally:
179 # postflight: restore saved flags
180 self.set_flags(**saved_flags)
188
181
189
182
190 #----------------------------------------------------------------
183 #----------------------------------------------------------------
191 # apply
184 # apply
192 #----------------------------------------------------------------
185 #----------------------------------------------------------------
193
186
194 @sync_results
187 @sync_results
195 @save_ids
188 @save_ids
196 def _really_apply(self, f, args, kwargs, block=None, **options):
189 def _really_apply(self, f, args, kwargs, block=None, **options):
197 """wrapper for client.send_apply_message"""
190 """wrapper for client.send_apply_message"""
198 raise NotImplementedError("Implement in subclasses")
191 raise NotImplementedError("Implement in subclasses")
199
192
200 def apply(self, f, *args, **kwargs):
193 def apply(self, f, *args, **kwargs):
201 """calls f(*args, **kwargs) on remote engines, returning the result.
194 """calls f(*args, **kwargs) on remote engines, returning the result.
202
195
203 This method sets all apply flags via this View's attributes.
196 This method sets all apply flags via this View's attributes.
204
197
205 if self.block is False:
198 if self.block is False:
206 returns AsyncResult
199 returns AsyncResult
207 else:
200 else:
208 returns actual result of f(*args, **kwargs)
201 returns actual result of f(*args, **kwargs)
209 """
202 """
210 return self._really_apply(f, args, kwargs)
203 return self._really_apply(f, args, kwargs)
211
204
212 def apply_async(self, f, *args, **kwargs):
205 def apply_async(self, f, *args, **kwargs):
213 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
206 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
214
207
215 returns AsyncResult
208 returns AsyncResult
216 """
209 """
217 return self._really_apply(f, args, kwargs, block=False)
210 return self._really_apply(f, args, kwargs, block=False)
218
211
219 @spin_after
212 @spin_after
220 def apply_sync(self, f, *args, **kwargs):
213 def apply_sync(self, f, *args, **kwargs):
221 """calls f(*args, **kwargs) on remote engines in a blocking manner,
214 """calls f(*args, **kwargs) on remote engines in a blocking manner,
222 returning the result.
215 returning the result.
223
216
224 returns: actual result of f(*args, **kwargs)
217 returns: actual result of f(*args, **kwargs)
225 """
218 """
226 return self._really_apply(f, args, kwargs, block=True)
219 return self._really_apply(f, args, kwargs, block=True)
227
220
228 #----------------------------------------------------------------
221 #----------------------------------------------------------------
229 # wrappers for client and control methods
222 # wrappers for client and control methods
230 #----------------------------------------------------------------
223 #----------------------------------------------------------------
231 @sync_results
224 @sync_results
232 def spin(self):
225 def spin(self):
233 """spin the client, and sync"""
226 """spin the client, and sync"""
234 self.client.spin()
227 self.client.spin()
235
228
236 @sync_results
229 @sync_results
237 def wait(self, jobs=None, timeout=-1):
230 def wait(self, jobs=None, timeout=-1):
238 """waits on one or more `jobs`, for up to `timeout` seconds.
231 """waits on one or more `jobs`, for up to `timeout` seconds.
239
232
240 Parameters
233 Parameters
241 ----------
234 ----------
242
235
243 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
236 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
244 ints are indices to self.history
237 ints are indices to self.history
245 strs are msg_ids
238 strs are msg_ids
246 default: wait on all outstanding messages
239 default: wait on all outstanding messages
247 timeout : float
240 timeout : float
248 a time in seconds, after which to give up.
241 a time in seconds, after which to give up.
249 default is -1, which means no timeout
242 default is -1, which means no timeout
250
243
251 Returns
244 Returns
252 -------
245 -------
253
246
254 True : when all msg_ids are done
247 True : when all msg_ids are done
255 False : timeout reached, some msg_ids still outstanding
248 False : timeout reached, some msg_ids still outstanding
256 """
249 """
257 if jobs is None:
250 if jobs is None:
258 jobs = self.history
251 jobs = self.history
259 return self.client.wait(jobs, timeout)
252 return self.client.wait(jobs, timeout)
260
253
261 def abort(self, jobs=None, block=None):
254 def abort(self, jobs=None, targets=None, block=None):
262 """Abort jobs on my engines.
255 """Abort jobs on my engines.
263
256
264 Parameters
257 Parameters
265 ----------
258 ----------
266
259
267 jobs : None, str, list of strs, optional
260 jobs : None, str, list of strs, optional
268 if None: abort all jobs.
261 if None: abort all jobs.
269 else: abort specific msg_id(s).
262 else: abort specific msg_id(s).
270 """
263 """
271 block = block if block is not None else self.block
264 block = block if block is not None else self.block
272 return self.client.abort(jobs=jobs, targets=self._targets, block=block)
265 targets = targets if targets is not None else self.targets
266 return self.client.abort(jobs=jobs, targets=targets, block=block)
273
267
274 def queue_status(self, verbose=False):
268 def queue_status(self, targets=None, verbose=False):
275 """Fetch the Queue status of my engines"""
269 """Fetch the Queue status of my engines"""
276 return self.client.queue_status(targets=self._targets, verbose=verbose)
270 targets = targets if targets is not None else self.targets
271 return self.client.queue_status(targets=targets, verbose=verbose)
277
272
278 def purge_results(self, jobs=[], targets=[]):
273 def purge_results(self, jobs=[], targets=[]):
279 """Instruct the controller to forget specific results."""
274 """Instruct the controller to forget specific results."""
280 if targets is None or targets == 'all':
275 if targets is None or targets == 'all':
281 targets = self._targets
276 targets = self.targets
282 return self.client.purge_results(jobs=jobs, targets=targets)
277 return self.client.purge_results(jobs=jobs, targets=targets)
283
278
284 @spin_after
279 @spin_after
285 def get_result(self, indices_or_msg_ids=None):
280 def get_result(self, indices_or_msg_ids=None):
286 """return one or more results, specified by history index or msg_id.
281 """return one or more results, specified by history index or msg_id.
287
282
288 See client.get_result for details.
283 See client.get_result for details.
289
284
290 """
285 """
291
286
292 if indices_or_msg_ids is None:
287 if indices_or_msg_ids is None:
293 indices_or_msg_ids = -1
288 indices_or_msg_ids = -1
294 if isinstance(indices_or_msg_ids, int):
289 if isinstance(indices_or_msg_ids, int):
295 indices_or_msg_ids = self.history[indices_or_msg_ids]
290 indices_or_msg_ids = self.history[indices_or_msg_ids]
296 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
291 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
297 indices_or_msg_ids = list(indices_or_msg_ids)
292 indices_or_msg_ids = list(indices_or_msg_ids)
298 for i,index in enumerate(indices_or_msg_ids):
293 for i,index in enumerate(indices_or_msg_ids):
299 if isinstance(index, int):
294 if isinstance(index, int):
300 indices_or_msg_ids[i] = self.history[index]
295 indices_or_msg_ids[i] = self.history[index]
301 return self.client.get_result(indices_or_msg_ids)
296 return self.client.get_result(indices_or_msg_ids)
302
297
303 #-------------------------------------------------------------------
298 #-------------------------------------------------------------------
304 # Map
299 # Map
305 #-------------------------------------------------------------------
300 #-------------------------------------------------------------------
306
301
307 def map(self, f, *sequences, **kwargs):
302 def map(self, f, *sequences, **kwargs):
308 """override in subclasses"""
303 """override in subclasses"""
309 raise NotImplementedError
304 raise NotImplementedError
310
305
311 def map_async(self, f, *sequences, **kwargs):
306 def map_async(self, f, *sequences, **kwargs):
312 """Parallel version of builtin `map`, using this view's engines.
307 """Parallel version of builtin `map`, using this view's engines.
313
308
314 This is equivalent to map(...block=False)
309 This is equivalent to map(...block=False)
315
310
316 See `self.map` for details.
311 See `self.map` for details.
317 """
312 """
318 if 'block' in kwargs:
313 if 'block' in kwargs:
319 raise TypeError("map_async doesn't take a `block` keyword argument.")
314 raise TypeError("map_async doesn't take a `block` keyword argument.")
320 kwargs['block'] = False
315 kwargs['block'] = False
321 return self.map(f,*sequences,**kwargs)
316 return self.map(f,*sequences,**kwargs)
322
317
323 def map_sync(self, f, *sequences, **kwargs):
318 def map_sync(self, f, *sequences, **kwargs):
324 """Parallel version of builtin `map`, using this view's engines.
319 """Parallel version of builtin `map`, using this view's engines.
325
320
326 This is equivalent to map(...block=True)
321 This is equivalent to map(...block=True)
327
322
328 See `self.map` for details.
323 See `self.map` for details.
329 """
324 """
330 if 'block' in kwargs:
325 if 'block' in kwargs:
331 raise TypeError("map_sync doesn't take a `block` keyword argument.")
326 raise TypeError("map_sync doesn't take a `block` keyword argument.")
332 kwargs['block'] = True
327 kwargs['block'] = True
333 return self.map(f,*sequences,**kwargs)
328 return self.map(f,*sequences,**kwargs)
334
329
335 def imap(self, f, *sequences, **kwargs):
330 def imap(self, f, *sequences, **kwargs):
336 """Parallel version of `itertools.imap`.
331 """Parallel version of `itertools.imap`.
337
332
338 See `self.map` for details.
333 See `self.map` for details.
339
334
340 """
335 """
341
336
342 return iter(self.map_async(f,*sequences, **kwargs))
337 return iter(self.map_async(f,*sequences, **kwargs))
343
338
344 #-------------------------------------------------------------------
339 #-------------------------------------------------------------------
345 # Decorators
340 # Decorators
346 #-------------------------------------------------------------------
341 #-------------------------------------------------------------------
347
342
348 def remote(self, block=True, **flags):
343 def remote(self, block=True, **flags):
349 """Decorator for making a RemoteFunction"""
344 """Decorator for making a RemoteFunction"""
350 block = self.block if block is None else block
345 block = self.block if block is None else block
351 return remote(self, block=block, **flags)
346 return remote(self, block=block, **flags)
352
347
353 def parallel(self, dist='b', block=None, **flags):
348 def parallel(self, dist='b', block=None, **flags):
354 """Decorator for making a ParallelFunction"""
349 """Decorator for making a ParallelFunction"""
355 block = self.block if block is None else block
350 block = self.block if block is None else block
356 return parallel(self, dist=dist, block=block, **flags)
351 return parallel(self, dist=dist, block=block, **flags)
357
352
358 @testdec.skip_doctest
353 @testdec.skip_doctest
359 class DirectView(View):
354 class DirectView(View):
360 """Direct Multiplexer View of one or more engines.
355 """Direct Multiplexer View of one or more engines.
361
356
362 These are created via indexed access to a client:
357 These are created via indexed access to a client:
363
358
364 >>> dv_1 = client[1]
359 >>> dv_1 = client[1]
365 >>> dv_all = client[:]
360 >>> dv_all = client[:]
366 >>> dv_even = client[::2]
361 >>> dv_even = client[::2]
367 >>> dv_some = client[1:3]
362 >>> dv_some = client[1:3]
368
363
369 This object provides dictionary access to engine namespaces:
364 This object provides dictionary access to engine namespaces:
370
365
371 # push a=5:
366 # push a=5:
372 >>> dv['a'] = 5
367 >>> dv['a'] = 5
373 # pull 'foo':
368 # pull 'foo':
374 >>> db['foo']
369 >>> db['foo']
375
370
376 """
371 """
377
372
378 def __init__(self, client=None, socket=None, targets=None):
373 def __init__(self, client=None, socket=None, targets=None):
379 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
374 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
375
376 @property
377 def importer(self):
378 """sync_imports(local=True) as a property.
380
379
380 See sync_imports for details.
381
382 In [10]: with v.importer:
383 ....: import numpy
384 ....:
385 importing numpy on engine(s)
386
387 """
388 return self.sync_imports(True)
389
390 @contextmanager
391 def sync_imports(self, local=True):
392 """Context Manager for performing simultaneous local and remote imports.
393
394 'import x as y' will *not* work. The 'as y' part will simply be ignored.
395
396 >>> with view.sync_imports():
397 ... from numpy import recarray
398 importing recarray from numpy on engine(s)
399
400 """
401 import __builtin__
402 local_import = __builtin__.__import__
403 modules = set()
404 results = []
405 @util.interactive
406 def remote_import(name, fromlist, level):
407 """the function to be passed to apply, that actually performs the import
408 on the engine, and loads up the user namespace.
409 """
410 import sys
411 user_ns = globals()
412 mod = __import__(name, fromlist=fromlist, level=level)
413 if fromlist:
414 for key in fromlist:
415 user_ns[key] = getattr(mod, key)
416 else:
417 user_ns[name] = sys.modules[name]
418
419 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
420 """the drop-in replacement for __import__, that optionally imports
421 locally as well.
422 """
423 # don't override nested imports
424 save_import = __builtin__.__import__
425 __builtin__.__import__ = local_import
426
427 if imp.lock_held():
428 # this is a side-effect import, don't do it remotely, or even
429 # ignore the local effects
430 return local_import(name, globals, locals, fromlist, level)
431
432 imp.acquire_lock()
433 if local:
434 mod = local_import(name, globals, locals, fromlist, level)
435 else:
436 raise NotImplementedError("remote-only imports not yet implemented")
437 imp.release_lock()
438
439 key = name+':'+','.join(fromlist or [])
440 if level == -1 and key not in modules:
441 modules.add(key)
442 if fromlist:
443 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
444 else:
445 print "importing %s on engine(s)"%name
446 results.append(self.apply_async(remote_import, name, fromlist, level))
447 # restore override
448 __builtin__.__import__ = save_import
449
450 return mod
451
452 # override __import__
453 __builtin__.__import__ = view_import
454 try:
455 # enter the block
456 yield
457 except ImportError:
458 if not local:
459 # ignore import errors if not doing local imports
460 pass
461 finally:
462 # always restore __import__
463 __builtin__.__import__ = local_import
464
465 for r in results:
466 # raise possible remote ImportErrors here
467 r.get()
468
381
469
382 @sync_results
470 @sync_results
383 @save_ids
471 @save_ids
384 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None):
472 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
385 """calls f(*args, **kwargs) on remote engines, returning the result.
473 """calls f(*args, **kwargs) on remote engines, returning the result.
386
474
387 This method sets all of `apply`'s flags via this View's attributes.
475 This method sets all of `apply`'s flags via this View's attributes.
388
476
389 Parameters
477 Parameters
390 ----------
478 ----------
391
479
392 f : callable
480 f : callable
393
481
394 args : list [default: empty]
482 args : list [default: empty]
395
483
396 kwargs : dict [default: empty]
484 kwargs : dict [default: empty]
397
485
486 targets : target list [default: self.targets]
487 where to run
398 block : bool [default: self.block]
488 block : bool [default: self.block]
399 whether to block
489 whether to block
400 track : bool [default: self.track]
490 track : bool [default: self.track]
401 whether to ask zmq to track the message, for safe non-copying sends
491 whether to ask zmq to track the message, for safe non-copying sends
402
492
403 Returns
493 Returns
404 -------
494 -------
405
495
406 if self.block is False:
496 if self.block is False:
407 returns AsyncResult
497 returns AsyncResult
408 else:
498 else:
409 returns actual result of f(*args, **kwargs) on the engine(s)
499 returns actual result of f(*args, **kwargs) on the engine(s)
410 This will be a list of self.targets is also a list (even length 1), or
500 This will be a list of self.targets is also a list (even length 1), or
411 the single result if self.targets is an integer engine id
501 the single result if self.targets is an integer engine id
412 """
502 """
413 args = [] if args is None else args
503 args = [] if args is None else args
414 kwargs = {} if kwargs is None else kwargs
504 kwargs = {} if kwargs is None else kwargs
415 block = self.block if block is None else block
505 block = self.block if block is None else block
416 track = self.track if track is None else track
506 track = self.track if track is None else track
507 targets = self.targets if targets is None else targets
508
509 _idents = self.client._build_targets(targets)[0]
417 msg_ids = []
510 msg_ids = []
418 trackers = []
511 trackers = []
419 for ident in self._idents:
512 for ident in _idents:
420 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
513 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
421 ident=ident)
514 ident=ident)
422 if track:
515 if track:
423 trackers.append(msg['tracker'])
516 trackers.append(msg['tracker'])
424 msg_ids.append(msg['msg_id'])
517 msg_ids.append(msg['msg_id'])
425 tracker = None if track is False else zmq.MessageTracker(*trackers)
518 tracker = None if track is False else zmq.MessageTracker(*trackers)
426 ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=self._targets, tracker=tracker)
519 ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
427 if block:
520 if block:
428 try:
521 try:
429 return ar.get()
522 return ar.get()
430 except KeyboardInterrupt:
523 except KeyboardInterrupt:
431 pass
524 pass
432 return ar
525 return ar
433
526
434 @spin_after
527 @spin_after
435 def map(self, f, *sequences, **kwargs):
528 def map(self, f, *sequences, **kwargs):
436 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
529 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
437
530
438 Parallel version of builtin `map`, using this View's `targets`.
531 Parallel version of builtin `map`, using this View's `targets`.
439
532
440 There will be one task per target, so work will be chunked
533 There will be one task per target, so work will be chunked
441 if the sequences are longer than `targets`.
534 if the sequences are longer than `targets`.
442
535
443 Results can be iterated as they are ready, but will become available in chunks.
536 Results can be iterated as they are ready, but will become available in chunks.
444
537
445 Parameters
538 Parameters
446 ----------
539 ----------
447
540
448 f : callable
541 f : callable
449 function to be mapped
542 function to be mapped
450 *sequences: one or more sequences of matching length
543 *sequences: one or more sequences of matching length
451 the sequences to be distributed and passed to `f`
544 the sequences to be distributed and passed to `f`
452 block : bool
545 block : bool
453 whether to wait for the result or not [default self.block]
546 whether to wait for the result or not [default self.block]
454
547
455 Returns
548 Returns
456 -------
549 -------
457
550
458 if block=False:
551 if block=False:
459 AsyncMapResult
552 AsyncMapResult
460 An object like AsyncResult, but which reassembles the sequence of results
553 An object like AsyncResult, but which reassembles the sequence of results
461 into a single list. AsyncMapResults can be iterated through before all
554 into a single list. AsyncMapResults can be iterated through before all
462 results are complete.
555 results are complete.
463 else:
556 else:
464 list
557 list
465 the result of map(f,*sequences)
558 the result of map(f,*sequences)
466 """
559 """
467
560
468 block = kwargs.pop('block', self.block)
561 block = kwargs.pop('block', self.block)
469 for k in kwargs.keys():
562 for k in kwargs.keys():
470 if k not in ['block', 'track']:
563 if k not in ['block', 'track']:
471 raise TypeError("invalid keyword arg, %r"%k)
564 raise TypeError("invalid keyword arg, %r"%k)
472
565
473 assert len(sequences) > 0, "must have some sequences to map onto!"
566 assert len(sequences) > 0, "must have some sequences to map onto!"
474 pf = ParallelFunction(self, f, block=block, **kwargs)
567 pf = ParallelFunction(self, f, block=block, **kwargs)
475 return pf.map(*sequences)
568 return pf.map(*sequences)
476
569
477 def execute(self, code, block=None):
570 def execute(self, code, targets=None, block=None):
478 """Executes `code` on `targets` in blocking or nonblocking manner.
571 """Executes `code` on `targets` in blocking or nonblocking manner.
479
572
480 ``execute`` is always `bound` (affects engine namespace)
573 ``execute`` is always `bound` (affects engine namespace)
481
574
482 Parameters
575 Parameters
483 ----------
576 ----------
484
577
485 code : str
578 code : str
486 the code string to be executed
579 the code string to be executed
487 block : bool
580 block : bool
488 whether or not to wait until done to return
581 whether or not to wait until done to return
489 default: self.block
582 default: self.block
490 """
583 """
491 return self._really_apply(util._execute, args=(code,), block=block)
584 return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
492
585
493 def run(self, filename, block=None):
586 def run(self, filename, targets=None, block=None):
494 """Execute contents of `filename` on my engine(s).
587 """Execute contents of `filename` on my engine(s).
495
588
496 This simply reads the contents of the file and calls `execute`.
589 This simply reads the contents of the file and calls `execute`.
497
590
498 Parameters
591 Parameters
499 ----------
592 ----------
500
593
501 filename : str
594 filename : str
502 The path to the file
595 The path to the file
503 targets : int/str/list of ints/strs
596 targets : int/str/list of ints/strs
504 the engines on which to execute
597 the engines on which to execute
505 default : all
598 default : all
506 block : bool
599 block : bool
507 whether or not to wait until done
600 whether or not to wait until done
508 default: self.block
601 default: self.block
509
602
510 """
603 """
511 with open(filename, 'r') as f:
604 with open(filename, 'r') as f:
512 # add newline in case of trailing indented whitespace
605 # add newline in case of trailing indented whitespace
513 # which will cause SyntaxError
606 # which will cause SyntaxError
514 code = f.read()+'\n'
607 code = f.read()+'\n'
515 return self.execute(code, block=block)
608 return self.execute(code, block=block, targets=targets)
516
609
517 def update(self, ns):
610 def update(self, ns):
518 """update remote namespace with dict `ns`
611 """update remote namespace with dict `ns`
519
612
520 See `push` for details.
613 See `push` for details.
521 """
614 """
522 return self.push(ns, block=self.block, track=self.track)
615 return self.push(ns, block=self.block, track=self.track)
523
616
524 def push(self, ns, block=None, track=None):
617 def push(self, ns, targets=None, block=None, track=None):
525 """update remote namespace with dict `ns`
618 """update remote namespace with dict `ns`
526
619
527 Parameters
620 Parameters
528 ----------
621 ----------
529
622
530 ns : dict
623 ns : dict
531 dict of keys with which to update engine namespace(s)
624 dict of keys with which to update engine namespace(s)
532 block : bool [default : self.block]
625 block : bool [default : self.block]
533 whether to wait to be notified of engine receipt
626 whether to wait to be notified of engine receipt
534
627
535 """
628 """
536
629
537 block = block if block is not None else self.block
630 block = block if block is not None else self.block
538 track = track if track is not None else self.track
631 track = track if track is not None else self.track
632 targets = targets if targets is not None else self.targets
539 # applier = self.apply_sync if block else self.apply_async
633 # applier = self.apply_sync if block else self.apply_async
540 if not isinstance(ns, dict):
634 if not isinstance(ns, dict):
541 raise TypeError("Must be a dict, not %s"%type(ns))
635 raise TypeError("Must be a dict, not %s"%type(ns))
542 return self._really_apply(util._push, (ns,),block=block, track=track)
636 return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets)
543
637
544 def get(self, key_s):
638 def get(self, key_s):
545 """get object(s) by `key_s` from remote namespace
639 """get object(s) by `key_s` from remote namespace
546
640
547 see `pull` for details.
641 see `pull` for details.
548 """
642 """
549 # block = block if block is not None else self.block
643 # block = block if block is not None else self.block
550 return self.pull(key_s, block=True)
644 return self.pull(key_s, block=True)
551
645
552 def pull(self, names, block=True):
646 def pull(self, names, targets=None, block=True):
553 """get object(s) by `name` from remote namespace
647 """get object(s) by `name` from remote namespace
554
648
555 will return one object if it is a key.
649 will return one object if it is a key.
556 can also take a list of keys, in which case it will return a list of objects.
650 can also take a list of keys, in which case it will return a list of objects.
557 """
651 """
558 block = block if block is not None else self.block
652 block = block if block is not None else self.block
653 targets = targets if targets is not None else self.targets
559 applier = self.apply_sync if block else self.apply_async
654 applier = self.apply_sync if block else self.apply_async
560 if isinstance(names, basestring):
655 if isinstance(names, basestring):
561 pass
656 pass
562 elif isinstance(names, (list,tuple,set)):
657 elif isinstance(names, (list,tuple,set)):
563 for key in names:
658 for key in names:
564 if not isinstance(key, basestring):
659 if not isinstance(key, basestring):
565 raise TypeError("keys must be str, not type %r"%type(key))
660 raise TypeError("keys must be str, not type %r"%type(key))
566 else:
661 else:
567 raise TypeError("names must be strs, not %r"%names)
662 raise TypeError("names must be strs, not %r"%names)
568 return applier(util._pull, names)
663 return self._really_apply(util._pull, (names,), block=block, targets=targets)
569
664
570 def scatter(self, key, seq, dist='b', flatten=False, block=None, track=None):
665 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
571 """
666 """
572 Partition a Python sequence and send the partitions to a set of engines.
667 Partition a Python sequence and send the partitions to a set of engines.
573 """
668 """
574 block = block if block is not None else self.block
669 block = block if block is not None else self.block
575 track = track if track is not None else self.track
670 track = track if track is not None else self.track
576 targets = self._targets
671 targets = targets if targets is not None else self.targets
672
577 mapObject = Map.dists[dist]()
673 mapObject = Map.dists[dist]()
578 nparts = len(targets)
674 nparts = len(targets)
579 msg_ids = []
675 msg_ids = []
580 trackers = []
676 trackers = []
581 for index, engineid in enumerate(targets):
677 for index, engineid in enumerate(targets):
582 push = self.client[engineid].push
583 partition = mapObject.getPartition(seq, index, nparts)
678 partition = mapObject.getPartition(seq, index, nparts)
584 if flatten and len(partition) == 1:
679 if flatten and len(partition) == 1:
585 r = push({key: partition[0]}, block=False, track=track)
680 ns = {key: partition[0]}
586 else:
681 else:
587 r = push({key: partition},block=False, track=track)
682 ns = {key: partition}
683 r = self.push(ns, block=False, track=track, targets=engineid)
588 msg_ids.extend(r.msg_ids)
684 msg_ids.extend(r.msg_ids)
589 if track:
685 if track:
590 trackers.append(r._tracker)
686 trackers.append(r._tracker)
591
687
592 if track:
688 if track:
593 tracker = zmq.MessageTracker(*trackers)
689 tracker = zmq.MessageTracker(*trackers)
594 else:
690 else:
595 tracker = None
691 tracker = None
596
692
597 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
693 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
598 if block:
694 if block:
599 r.wait()
695 r.wait()
600 else:
696 else:
601 return r
697 return r
602
698
603 @sync_results
699 @sync_results
604 @save_ids
700 @save_ids
605 def gather(self, key, dist='b', block=None):
701 def gather(self, key, dist='b', targets=None, block=None):
606 """
702 """
607 Gather a partitioned sequence on a set of engines as a single local seq.
703 Gather a partitioned sequence on a set of engines as a single local seq.
608 """
704 """
609 block = block if block is not None else self.block
705 block = block if block is not None else self.block
706 targets = targets if targets is not None else self.targets
610 mapObject = Map.dists[dist]()
707 mapObject = Map.dists[dist]()
611 msg_ids = []
708 msg_ids = []
612 for index, engineid in enumerate(self._targets):
709
613
710 for index, engineid in enumerate(targets):
614 msg_ids.extend(self.client[engineid].pull(key, block=False).msg_ids)
711 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
615
712
616 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
713 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
617
714
618 if block:
715 if block:
619 try:
716 try:
620 return r.get()
717 return r.get()
621 except KeyboardInterrupt:
718 except KeyboardInterrupt:
622 pass
719 pass
623 return r
720 return r
624
721
625 def __getitem__(self, key):
722 def __getitem__(self, key):
626 return self.get(key)
723 return self.get(key)
627
724
628 def __setitem__(self,key, value):
725 def __setitem__(self,key, value):
629 self.update({key:value})
726 self.update({key:value})
630
727
631 def clear(self, block=False):
728 def clear(self, targets=None, block=False):
632 """Clear the remote namespaces on my engines."""
729 """Clear the remote namespaces on my engines."""
633 block = block if block is not None else self.block
730 block = block if block is not None else self.block
634 return self.client.clear(targets=self._targets, block=block)
731 targets = targets if targets is not None else self.targets
732 return self.client.clear(targets=targets, block=block)
635
733
636 def kill(self, block=True):
734 def kill(self, targets=None, block=True):
637 """Kill my engines."""
735 """Kill my engines."""
638 block = block if block is not None else self.block
736 block = block if block is not None else self.block
639 return self.client.kill(targets=self._targets, block=block)
737 targets = targets if targets is not None else self.targets
738 return self.client.kill(targets=targets, block=block)
640
739
641 #----------------------------------------
740 #----------------------------------------
642 # activate for %px,%autopx magics
741 # activate for %px,%autopx magics
643 #----------------------------------------
742 #----------------------------------------
644 def activate(self):
743 def activate(self):
645 """Make this `View` active for parallel magic commands.
744 """Make this `View` active for parallel magic commands.
646
745
647 IPython has a magic command syntax to work with `MultiEngineClient` objects.
746 IPython has a magic command syntax to work with `MultiEngineClient` objects.
648 In a given IPython session there is a single active one. While
747 In a given IPython session there is a single active one. While
649 there can be many `Views` created and used by the user,
748 there can be many `Views` created and used by the user,
650 there is only one active one. The active `View` is used whenever
749 there is only one active one. The active `View` is used whenever
651 the magic commands %px and %autopx are used.
750 the magic commands %px and %autopx are used.
652
751
653 The activate() method is called on a given `View` to make it
752 The activate() method is called on a given `View` to make it
654 active. Once this has been done, the magic commands can be used.
753 active. Once this has been done, the magic commands can be used.
655 """
754 """
656
755
657 try:
756 try:
658 # This is injected into __builtins__.
757 # This is injected into __builtins__.
659 ip = get_ipython()
758 ip = get_ipython()
660 except NameError:
759 except NameError:
661 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
760 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
662 else:
761 else:
663 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
762 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
664 if pmagic is not None:
763 if pmagic is not None:
665 pmagic.active_multiengine_client = self
764 pmagic.active_multiengine_client = self
666 else:
765 else:
667 print "You must first load the parallelmagic extension " \
766 print "You must first load the parallelmagic extension " \
668 "by doing '%load_ext parallelmagic'"
767 "by doing '%load_ext parallelmagic'"
669
768
670
769
@testdec.skip_doctest
class LoadBalancedView(View):
    """A load-balancing View that only executes via the Task scheduler.

    Load-balanced views can be created with the client's `view` method:

    >>> v = client.load_balanced_view()

    or targets can be specified, to restrict the potential destinations:

    >>> v = client.load_balanced_view([1,3])

    which would restrict loadbalancing to between engines 1 and 3.

    """

    # per-call scheduler flags, settable via set_flags
    follow = Any()
    after = Any()
    timeout = CFloat()

    # scheduling scheme in use by the hub ('lru', 'pure', ...)
    _task_scheme = Any()
    _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout'])

    def __init__(self, client=None, socket=None, **flags):
        super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
        self._task_scheme = client._task_scheme
796
697 def _validate_dependency(self, dep):
797 def _validate_dependency(self, dep):
698 """validate a dependency.
798 """validate a dependency.
699
799
700 For use in `set_flags`.
800 For use in `set_flags`.
701 """
801 """
702 if dep is None or isinstance(dep, (str, AsyncResult, Dependency)):
802 if dep is None or isinstance(dep, (str, AsyncResult, Dependency)):
703 return True
803 return True
704 elif isinstance(dep, (list,set, tuple)):
804 elif isinstance(dep, (list,set, tuple)):
705 for d in dep:
805 for d in dep:
706 if not isinstance(d, (str, AsyncResult)):
806 if not isinstance(d, (str, AsyncResult)):
707 return False
807 return False
708 elif isinstance(dep, dict):
808 elif isinstance(dep, dict):
709 if set(dep.keys()) != set(Dependency().as_dict().keys()):
809 if set(dep.keys()) != set(Dependency().as_dict().keys()):
710 return False
810 return False
711 if not isinstance(dep['msg_ids'], list):
811 if not isinstance(dep['msg_ids'], list):
712 return False
812 return False
713 for d in dep['msg_ids']:
813 for d in dep['msg_ids']:
714 if not isinstance(d, str):
814 if not isinstance(d, str):
715 return False
815 return False
716 else:
816 else:
717 return False
817 return False
718
818
719 return True
819 return True
720
820
721 def _render_dependency(self, dep):
821 def _render_dependency(self, dep):
722 """helper for building jsonable dependencies from various input forms."""
822 """helper for building jsonable dependencies from various input forms."""
723 if isinstance(dep, Dependency):
823 if isinstance(dep, Dependency):
724 return dep.as_dict()
824 return dep.as_dict()
725 elif isinstance(dep, AsyncResult):
825 elif isinstance(dep, AsyncResult):
726 return dep.msg_ids
826 return dep.msg_ids
727 elif dep is None:
827 elif dep is None:
728 return []
828 return []
729 else:
829 else:
730 # pass to Dependency constructor
830 # pass to Dependency constructor
731 return list(Dependency(dep))
831 return list(Dependency(dep))
732
832
733 def set_flags(self, **kwargs):
833 def set_flags(self, **kwargs):
734 """set my attribute flags by keyword.
834 """set my attribute flags by keyword.
735
835
736 A View is a wrapper for the Client's apply method, but with attributes
836 A View is a wrapper for the Client's apply method, but with attributes
737 that specify keyword arguments, those attributes can be set by keyword
837 that specify keyword arguments, those attributes can be set by keyword
738 argument with this method.
838 argument with this method.
739
839
740 Parameters
840 Parameters
741 ----------
841 ----------
742
842
743 block : bool
843 block : bool
744 whether to wait for results
844 whether to wait for results
745 track : bool
845 track : bool
746 whether to create a MessageTracker to allow the user to
846 whether to create a MessageTracker to allow the user to
747 safely edit after arrays and buffers during non-copying
847 safely edit after arrays and buffers during non-copying
748 sends.
848 sends.
749 #
849 #
750 after : Dependency or collection of msg_ids
850 after : Dependency or collection of msg_ids
751 Only for load-balanced execution (targets=None)
851 Only for load-balanced execution (targets=None)
752 Specify a list of msg_ids as a time-based dependency.
852 Specify a list of msg_ids as a time-based dependency.
753 This job will only be run *after* the dependencies
853 This job will only be run *after* the dependencies
754 have been met.
854 have been met.
755
855
756 follow : Dependency or collection of msg_ids
856 follow : Dependency or collection of msg_ids
757 Only for load-balanced execution (targets=None)
857 Only for load-balanced execution (targets=None)
758 Specify a list of msg_ids as a location-based dependency.
858 Specify a list of msg_ids as a location-based dependency.
759 This job will only be run on an engine where this dependency
859 This job will only be run on an engine where this dependency
760 is met.
860 is met.
761
861
762 timeout : float/int or None
862 timeout : float/int or None
763 Only for load-balanced execution (targets=None)
863 Only for load-balanced execution (targets=None)
764 Specify an amount of time (in seconds) for the scheduler to
864 Specify an amount of time (in seconds) for the scheduler to
765 wait for dependencies to be met before failing with a
865 wait for dependencies to be met before failing with a
766 DependencyTimeout.
866 DependencyTimeout.
767 """
867 """
768
868
769 super(LoadBalancedView, self).set_flags(**kwargs)
869 super(LoadBalancedView, self).set_flags(**kwargs)
770 for name in ('follow', 'after'):
870 for name in ('follow', 'after'):
771 if name in kwargs:
871 if name in kwargs:
772 value = kwargs[name]
872 value = kwargs[name]
773 if self._validate_dependency(value):
873 if self._validate_dependency(value):
774 setattr(self, name, value)
874 setattr(self, name, value)
775 else:
875 else:
776 raise ValueError("Invalid dependency: %r"%value)
876 raise ValueError("Invalid dependency: %r"%value)
777 if 'timeout' in kwargs:
877 if 'timeout' in kwargs:
778 t = kwargs['timeout']
878 t = kwargs['timeout']
779 if not isinstance(t, (int, long, float, type(None))):
879 if not isinstance(t, (int, long, float, type(None))):
780 raise TypeError("Invalid type for timeout: %r"%type(t))
880 raise TypeError("Invalid type for timeout: %r"%type(t))
781 if t is not None:
881 if t is not None:
782 if t < 0:
882 if t < 0:
783 raise ValueError("Invalid timeout: %s"%t)
883 raise ValueError("Invalid timeout: %s"%t)
784 self.timeout = t
884 self.timeout = t
785
885
786 @sync_results
886 @sync_results
787 @save_ids
887 @save_ids
788 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
888 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
789 after=None, follow=None, timeout=None):
889 after=None, follow=None, timeout=None,
890 targets=None):
790 """calls f(*args, **kwargs) on a remote engine, returning the result.
891 """calls f(*args, **kwargs) on a remote engine, returning the result.
791
892
792 This method temporarily sets all of `apply`'s flags for a single call.
893 This method temporarily sets all of `apply`'s flags for a single call.
793
894
794 Parameters
895 Parameters
795 ----------
896 ----------
796
897
797 f : callable
898 f : callable
798
899
799 args : list [default: empty]
900 args : list [default: empty]
800
901
801 kwargs : dict [default: empty]
902 kwargs : dict [default: empty]
802
903
803 block : bool [default: self.block]
904 block : bool [default: self.block]
804 whether to block
905 whether to block
805 track : bool [default: self.track]
906 track : bool [default: self.track]
806 whether to ask zmq to track the message, for safe non-copying sends
907 whether to ask zmq to track the message, for safe non-copying sends
807
908
808 !!!!!! TODO: THE REST HERE !!!!
909 !!!!!! TODO: THE REST HERE !!!!
809
910
810 Returns
911 Returns
811 -------
912 -------
812
913
813 if self.block is False:
914 if self.block is False:
814 returns AsyncResult
915 returns AsyncResult
815 else:
916 else:
816 returns actual result of f(*args, **kwargs) on the engine(s)
917 returns actual result of f(*args, **kwargs) on the engine(s)
817 This will be a list of self.targets is also a list (even length 1), or
918 This will be a list of self.targets is also a list (even length 1), or
818 the single result if self.targets is an integer engine id
919 the single result if self.targets is an integer engine id
819 """
920 """
820
921
821 # validate whether we can run
922 # validate whether we can run
822 if self._socket.closed:
923 if self._socket.closed:
823 msg = "Task farming is disabled"
924 msg = "Task farming is disabled"
824 if self._task_scheme == 'pure':
925 if self._task_scheme == 'pure':
825 msg += " because the pure ZMQ scheduler cannot handle"
926 msg += " because the pure ZMQ scheduler cannot handle"
826 msg += " disappearing engines."
927 msg += " disappearing engines."
827 raise RuntimeError(msg)
928 raise RuntimeError(msg)
828
929
829 if self._task_scheme == 'pure':
930 if self._task_scheme == 'pure':
830 # pure zmq scheme doesn't support dependencies
931 # pure zmq scheme doesn't support dependencies
831 msg = "Pure ZMQ scheduler doesn't support dependencies"
932 msg = "Pure ZMQ scheduler doesn't support dependencies"
832 if (follow or after):
933 if (follow or after):
833 # hard fail on DAG dependencies
934 # hard fail on DAG dependencies
834 raise RuntimeError(msg)
935 raise RuntimeError(msg)
835 if isinstance(f, dependent):
936 if isinstance(f, dependent):
836 # soft warn on functional dependencies
937 # soft warn on functional dependencies
837 warnings.warn(msg, RuntimeWarning)
938 warnings.warn(msg, RuntimeWarning)
838
939
839 # build args
940 # build args
840 args = [] if args is None else args
941 args = [] if args is None else args
841 kwargs = {} if kwargs is None else kwargs
942 kwargs = {} if kwargs is None else kwargs
842 block = self.block if block is None else block
943 block = self.block if block is None else block
843 track = self.track if track is None else track
944 track = self.track if track is None else track
844 after = self.after if after is None else after
945 after = self.after if after is None else after
845 follow = self.follow if follow is None else follow
946 follow = self.follow if follow is None else follow
846 timeout = self.timeout if timeout is None else timeout
947 timeout = self.timeout if timeout is None else timeout
948 targets = self.targets if targets is None else targets
949
950 if targets is None:
951 idents = []
952 else:
953 idents = self.client._build_targets(targets)[0]
954
847 after = self._render_dependency(after)
955 after = self._render_dependency(after)
848 follow = self._render_dependency(follow)
956 follow = self._render_dependency(follow)
849 subheader = dict(after=after, follow=follow, timeout=timeout, targets=self._idents)
957 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
850
958
851 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
959 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
852 subheader=subheader)
960 subheader=subheader)
853 tracker = None if track is False else msg['tracker']
961 tracker = None if track is False else msg['tracker']
854
962
855 ar = AsyncResult(self.client, msg['msg_id'], fname=f.__name__, targets=None, tracker=tracker)
963 ar = AsyncResult(self.client, msg['msg_id'], fname=f.__name__, targets=None, tracker=tracker)
856
964
857 if block:
965 if block:
858 try:
966 try:
859 return ar.get()
967 return ar.get()
860 except KeyboardInterrupt:
968 except KeyboardInterrupt:
861 pass
969 pass
862 return ar
970 return ar
863
971
864 @spin_after
972 @spin_after
865 @save_ids
973 @save_ids
866 def map(self, f, *sequences, **kwargs):
974 def map(self, f, *sequences, **kwargs):
867 """view.map(f, *sequences, block=self.block, chunksize=1) => list|AsyncMapResult
975 """view.map(f, *sequences, block=self.block, chunksize=1) => list|AsyncMapResult
868
976
869 Parallel version of builtin `map`, load-balanced by this View.
977 Parallel version of builtin `map`, load-balanced by this View.
870
978
871 `block`, and `chunksize` can be specified by keyword only.
979 `block`, and `chunksize` can be specified by keyword only.
872
980
873 Each `chunksize` elements will be a separate task, and will be
981 Each `chunksize` elements will be a separate task, and will be
874 load-balanced. This lets individual elements be available for iteration
982 load-balanced. This lets individual elements be available for iteration
875 as soon as they arrive.
983 as soon as they arrive.
876
984
877 Parameters
985 Parameters
878 ----------
986 ----------
879
987
880 f : callable
988 f : callable
881 function to be mapped
989 function to be mapped
882 *sequences: one or more sequences of matching length
990 *sequences: one or more sequences of matching length
883 the sequences to be distributed and passed to `f`
991 the sequences to be distributed and passed to `f`
884 block : bool
992 block : bool
885 whether to wait for the result or not [default self.block]
993 whether to wait for the result or not [default self.block]
886 track : bool
994 track : bool
887 whether to create a MessageTracker to allow the user to
995 whether to create a MessageTracker to allow the user to
888 safely edit after arrays and buffers during non-copying
996 safely edit after arrays and buffers during non-copying
889 sends.
997 sends.
890 chunksize : int
998 chunksize : int
891 how many elements should be in each task [default 1]
999 how many elements should be in each task [default 1]
892
1000
893 Returns
1001 Returns
894 -------
1002 -------
895
1003
896 if block=False:
1004 if block=False:
897 AsyncMapResult
1005 AsyncMapResult
898 An object like AsyncResult, but which reassembles the sequence of results
1006 An object like AsyncResult, but which reassembles the sequence of results
899 into a single list. AsyncMapResults can be iterated through before all
1007 into a single list. AsyncMapResults can be iterated through before all
900 results are complete.
1008 results are complete.
901 else:
1009 else:
902 the result of map(f,*sequences)
1010 the result of map(f,*sequences)
903
1011
904 """
1012 """
905
1013
906 # default
1014 # default
907 block = kwargs.get('block', self.block)
1015 block = kwargs.get('block', self.block)
908 chunksize = kwargs.get('chunksize', 1)
1016 chunksize = kwargs.get('chunksize', 1)
909
1017
910 keyset = set(kwargs.keys())
1018 keyset = set(kwargs.keys())
911 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1019 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
912 if extra_keys:
1020 if extra_keys:
913 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1021 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
914
1022
915 assert len(sequences) > 0, "must have some sequences to map onto!"
1023 assert len(sequences) > 0, "must have some sequences to map onto!"
916
1024
917 pf = ParallelFunction(self, f, block=block, chunksize=chunksize)
1025 pf = ParallelFunction(self, f, block=block, chunksize=chunksize)
918 return pf.map(*sequences)
1026 return pf.map(*sequences)
919
1027
# public API of this module
__all__ = ['LoadBalancedView', 'DirectView']
General Comments 0
You need to be logged in to leave comments. Login now