##// END OF EJS Templates
pyzmq-2.1.3 related testing adjustments
MinRK -
Show More
@@ -1,18 +1,23 b''
1 """The IPython ZMQ-based parallel computing interface."""
1 """The IPython ZMQ-based parallel computing interface."""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2011 The IPython Development Team
3 # Copyright (C) 2011 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 # from .asyncresult import *
13 # from .asyncresult import *
14 # from .client import Client
14 # from .client import Client
15 # from .dependency import *
15 # from .dependency import *
16 # from .remotefunction import *
16 # from .remotefunction import *
17 # from .view import *
17 # from .view import *
18
18
19 import zmq
20
21 if zmq.__version__ < '2.1.3':
22 raise ImportError("IPython.zmq.parallel requires pyzmq/0MQ >= 2.1.3, you appear to have %s"%zmq.__version__)
23
@@ -1,1584 +1,1591 b''
1 """A semi-synchronous Client for the ZMQ controller"""
1 """A semi-synchronous Client for the ZMQ controller"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
3 # Copyright (C) 2010 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 import os
13 import os
14 import json
14 import json
15 import time
15 import time
16 import warnings
16 import warnings
17 from datetime import datetime
17 from datetime import datetime
18 from getpass import getpass
18 from getpass import getpass
19 from pprint import pprint
19 from pprint import pprint
20
20
21 pjoin = os.path.join
21 pjoin = os.path.join
22
22
23 import zmq
23 import zmq
24 # from zmq.eventloop import ioloop, zmqstream
24 # from zmq.eventloop import ioloop, zmqstream
25
25
26 from IPython.utils.path import get_ipython_dir
26 from IPython.utils.path import get_ipython_dir
27 from IPython.utils.pickleutil import Reference
27 from IPython.utils.pickleutil import Reference
28 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
28 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
29 Dict, List, Bool, Str, Set)
29 Dict, List, Bool, Str, Set)
30 from IPython.external.decorator import decorator
30 from IPython.external.decorator import decorator
31 from IPython.external.ssh import tunnel
31 from IPython.external.ssh import tunnel
32
32
33 from . import error
33 from . import error
34 from . import map as Map
34 from . import map as Map
35 from . import util
35 from . import util
36 from . import streamsession as ss
36 from . import streamsession as ss
37 from .asyncresult import AsyncResult, AsyncMapResult, AsyncHubResult
37 from .asyncresult import AsyncResult, AsyncMapResult, AsyncHubResult
38 from .clusterdir import ClusterDir, ClusterDirError
38 from .clusterdir import ClusterDir, ClusterDirError
39 from .dependency import Dependency, depend, require, dependent
39 from .dependency import Dependency, depend, require, dependent
40 from .remotefunction import remote, parallel, ParallelFunction, RemoteFunction
40 from .remotefunction import remote, parallel, ParallelFunction, RemoteFunction
41 from .util import ReverseDict, validate_url, disambiguate_url
41 from .util import ReverseDict, validate_url, disambiguate_url
42 from .view import DirectView, LoadBalancedView
42 from .view import DirectView, LoadBalancedView
43
43
44 #--------------------------------------------------------------------------
44 #--------------------------------------------------------------------------
45 # helpers for implementing old MEC API via client.apply
45 # helpers for implementing old MEC API via client.apply
46 #--------------------------------------------------------------------------
46 #--------------------------------------------------------------------------
47
47
48 def _push(user_ns, **ns):
48 def _push(user_ns, **ns):
49 """helper method for implementing `client.push` via `client.apply`"""
49 """helper method for implementing `client.push` via `client.apply`"""
50 user_ns.update(ns)
50 user_ns.update(ns)
51
51
52 def _pull(user_ns, keys):
52 def _pull(user_ns, keys):
53 """helper method for implementing `client.pull` via `client.apply`"""
53 """helper method for implementing `client.pull` via `client.apply`"""
54 if isinstance(keys, (list,tuple, set)):
54 if isinstance(keys, (list,tuple, set)):
55 for key in keys:
55 for key in keys:
56 if not user_ns.has_key(key):
56 if not user_ns.has_key(key):
57 raise NameError("name '%s' is not defined"%key)
57 raise NameError("name '%s' is not defined"%key)
58 return map(user_ns.get, keys)
58 return map(user_ns.get, keys)
59 else:
59 else:
60 if not user_ns.has_key(keys):
60 if not user_ns.has_key(keys):
61 raise NameError("name '%s' is not defined"%keys)
61 raise NameError("name '%s' is not defined"%keys)
62 return user_ns.get(keys)
62 return user_ns.get(keys)
63
63
64 def _clear(user_ns):
64 def _clear(user_ns):
65 """helper method for implementing `client.clear` via `client.apply`"""
65 """helper method for implementing `client.clear` via `client.apply`"""
66 user_ns.clear()
66 user_ns.clear()
67
67
68 def _execute(user_ns, code):
68 def _execute(user_ns, code):
69 """helper method for implementing `client.execute` via `client.apply`"""
69 """helper method for implementing `client.execute` via `client.apply`"""
70 exec code in user_ns
70 exec code in user_ns
71
71
72
72
73 #--------------------------------------------------------------------------
73 #--------------------------------------------------------------------------
74 # Decorators for Client methods
74 # Decorators for Client methods
75 #--------------------------------------------------------------------------
75 #--------------------------------------------------------------------------
76
76
@decorator
def spinfirst(f, self, *args, **kwargs):
    """Synchronize client state via spin() before invoking the wrapped method."""
    self.spin()
    result = f(self, *args, **kwargs)
    return result
82
82
@decorator
def defaultblock(f, self, *args, **kwargs):
    """Run the wrapped method with self.block defaulted from kwargs.

    If the caller passed block=None (or nothing), self.block keeps its
    current value; otherwise it is temporarily set to the requested value.
    The original self.block is always restored afterwards.
    """
    explicit = kwargs.get('block', None)
    previous = self.block
    self.block = previous if explicit is None else explicit
    try:
        return f(self, *args, **kwargs)
    finally:
        self.block = previous
95
95
96
96
97 #--------------------------------------------------------------------------
97 #--------------------------------------------------------------------------
98 # Classes
98 # Classes
99 #--------------------------------------------------------------------------
99 #--------------------------------------------------------------------------
100
100
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # the complete, fixed key set, with its default values
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
            }
        self.update(md)
        # NOTE: dict.update bypasses our strict __setitem__, so constructor
        # arguments are not key-checked (preserves original behavior)
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # membership test on the dict itself: identical to the old
        # `key in self.iterkeys()` but portable to Python 3
        if key in self:
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict key enforcement"""
        if key in self:
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self:
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
151
151
152
152
153 class Client(HasTraits):
153 class Client(HasTraits):
154 """A semi-synchronous client to the IPython ZMQ controller
154 """A semi-synchronous client to the IPython ZMQ controller
155
155
156 Parameters
156 Parameters
157 ----------
157 ----------
158
158
159 url_or_file : bytes; zmq url or path to ipcontroller-client.json
159 url_or_file : bytes; zmq url or path to ipcontroller-client.json
160 Connection information for the Hub's registration. If a json connector
160 Connection information for the Hub's registration. If a json connector
161 file is given, then likely no further configuration is necessary.
161 file is given, then likely no further configuration is necessary.
162 [Default: use profile]
162 [Default: use profile]
163 profile : bytes
163 profile : bytes
164 The name of the Cluster profile to be used to find connector information.
164 The name of the Cluster profile to be used to find connector information.
165 [Default: 'default']
165 [Default: 'default']
166 context : zmq.Context
166 context : zmq.Context
167 Pass an existing zmq.Context instance, otherwise the client will create its own.
167 Pass an existing zmq.Context instance, otherwise the client will create its own.
168 username : bytes
168 username : bytes
169 set username to be passed to the Session object
169 set username to be passed to the Session object
170 debug : bool
170 debug : bool
171 flag for lots of message printing for debug purposes
171 flag for lots of message printing for debug purposes
172
172
173 #-------------- ssh related args ----------------
173 #-------------- ssh related args ----------------
174 # These are args for configuring the ssh tunnel to be used
174 # These are args for configuring the ssh tunnel to be used
175 # credentials are used to forward connections over ssh to the Controller
175 # credentials are used to forward connections over ssh to the Controller
176 # Note that the ip given in `addr` needs to be relative to sshserver
176 # Note that the ip given in `addr` needs to be relative to sshserver
177 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
177 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
178 # and set sshserver as the same machine the Controller is on. However,
178 # and set sshserver as the same machine the Controller is on. However,
179 # the only requirement is that sshserver is able to see the Controller
179 # the only requirement is that sshserver is able to see the Controller
180 # (i.e. is within the same trusted network).
180 # (i.e. is within the same trusted network).
181
181
182 sshserver : str
182 sshserver : str
183 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
183 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
184 If keyfile or password is specified, and this is not, it will default to
184 If keyfile or password is specified, and this is not, it will default to
185 the ip given in addr.
185 the ip given in addr.
186 sshkey : str; path to public ssh key file
186 sshkey : str; path to public ssh key file
187 This specifies a key to be used in ssh login, default None.
187 This specifies a key to be used in ssh login, default None.
188 Regular default ssh keys will be used without specifying this argument.
188 Regular default ssh keys will be used without specifying this argument.
189 password : str
189 password : str
190 Your ssh password to sshserver. Note that if this is left None,
190 Your ssh password to sshserver. Note that if this is left None,
191 you will be prompted for it if passwordless key based login is unavailable.
191 you will be prompted for it if passwordless key based login is unavailable.
192 paramiko : bool
192 paramiko : bool
193 flag for whether to use paramiko instead of shell ssh for tunneling.
193 flag for whether to use paramiko instead of shell ssh for tunneling.
194 [default: True on win32, False else]
194 [default: True on win32, False else]
195
195
196 #------- exec authentication args -------
196 #------- exec authentication args -------
197 # If even localhost is untrusted, you can have some protection against
197 # If even localhost is untrusted, you can have some protection against
198 # unauthorized execution by using a key. Messages are still sent
198 # unauthorized execution by using a key. Messages are still sent
199 # as cleartext, so if someone can snoop your loopback traffic this will
199 # as cleartext, so if someone can snoop your loopback traffic this will
200 # not help against malicious attacks.
200 # not help against malicious attacks.
201
201
202 exec_key : str
202 exec_key : str
203 an authentication key or file containing a key
203 an authentication key or file containing a key
204 default: None
204 default: None
205
205
206
206
207 Attributes
207 Attributes
208 ----------
208 ----------
209
209
210 ids : set of int engine IDs
210 ids : set of int engine IDs
211 requesting the ids attribute always synchronizes
211 requesting the ids attribute always synchronizes
212 the registration state. To request ids without synchronization,
212 the registration state. To request ids without synchronization,
213 use semi-private _ids attributes.
213 use semi-private _ids attributes.
214
214
215 history : list of msg_ids
215 history : list of msg_ids
216 a list of msg_ids, keeping track of all the execution
216 a list of msg_ids, keeping track of all the execution
217 messages you have submitted in order.
217 messages you have submitted in order.
218
218
219 outstanding : set of msg_ids
219 outstanding : set of msg_ids
220 a set of msg_ids that have been submitted, but whose
220 a set of msg_ids that have been submitted, but whose
221 results have not yet been received.
221 results have not yet been received.
222
222
223 results : dict
223 results : dict
224 a dict of all our results, keyed by msg_id
224 a dict of all our results, keyed by msg_id
225
225
226 block : bool
226 block : bool
227 determines default behavior when block not specified
227 determines default behavior when block not specified
228 in execution methods
228 in execution methods
229
229
230 Methods
230 Methods
231 -------
231 -------
232
232
233 spin
233 spin
234 flushes incoming results and registration state changes
234 flushes incoming results and registration state changes
235 control methods spin, and requesting `ids` also ensures up to date
235 control methods spin, and requesting `ids` also ensures up to date
236
236
237 barrier
237 barrier
238 wait on one or more msg_ids
238 wait on one or more msg_ids
239
239
240 execution methods
240 execution methods
241 apply
241 apply
242 legacy: execute, run
242 legacy: execute, run
243
243
244 query methods
244 query methods
245 queue_status, get_result, purge
245 queue_status, get_result, purge
246
246
247 control methods
247 control methods
248 abort, shutdown
248 abort, shutdown
249
249
250 """
250 """
251
251
252
252
253 block = Bool(False)
253 block = Bool(False)
254 outstanding = Set()
254 outstanding = Set()
255 results = Instance('collections.defaultdict', (dict,))
255 results = Instance('collections.defaultdict', (dict,))
256 metadata = Instance('collections.defaultdict', (Metadata,))
256 metadata = Instance('collections.defaultdict', (Metadata,))
257 history = List()
257 history = List()
258 debug = Bool(False)
258 debug = Bool(False)
259 profile=CUnicode('default')
259 profile=CUnicode('default')
260
260
261 _outstanding_dict = Instance('collections.defaultdict', (set,))
261 _outstanding_dict = Instance('collections.defaultdict', (set,))
262 _ids = List()
262 _ids = List()
263 _connected=Bool(False)
263 _connected=Bool(False)
264 _ssh=Bool(False)
264 _ssh=Bool(False)
265 _context = Instance('zmq.Context')
265 _context = Instance('zmq.Context')
266 _config = Dict()
266 _config = Dict()
267 _engines=Instance(ReverseDict, (), {})
267 _engines=Instance(ReverseDict, (), {})
268 # _hub_socket=Instance('zmq.Socket')
268 # _hub_socket=Instance('zmq.Socket')
269 _query_socket=Instance('zmq.Socket')
269 _query_socket=Instance('zmq.Socket')
270 _control_socket=Instance('zmq.Socket')
270 _control_socket=Instance('zmq.Socket')
271 _iopub_socket=Instance('zmq.Socket')
271 _iopub_socket=Instance('zmq.Socket')
272 _notification_socket=Instance('zmq.Socket')
272 _notification_socket=Instance('zmq.Socket')
273 _apply_socket=Instance('zmq.Socket')
273 _apply_socket=Instance('zmq.Socket')
274 _mux_ident=Str()
274 _mux_ident=Str()
275 _task_ident=Str()
275 _task_ident=Str()
276 _task_scheme=Str()
276 _task_scheme=Str()
277 _balanced_views=Dict()
277 _balanced_views=Dict()
278 _direct_views=Dict()
278 _direct_views=Dict()
279 _closed = False
279 _closed = False
280
280
    def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
            context=None, username=None, debug=False, exec_key=None,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            ):
        """Set up the client: resolve connection info and connect the query socket.

        Connection info comes from `url_or_file` (a zmq url, or a path to an
        ipcontroller-client.json connector file) or, failing that, from the
        cluster profile.  Explicit sshserver/exec_key arguments override
        values found in the json file.
        """
        super(Client, self).__init__(debug=debug, profile=profile)
        if context is None:
            # share the process-global zmq context instead of creating a
            # private one (requires pyzmq >= 2.1.3's Context.instance())
            context = zmq.Context.instance()
        self._context = context


        # locate the cluster dir so the connector file can be found by profile
        self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a controller!"\
            " Please specify at least one of url_or_file or profile."

        try:
            validate_url(url_or_file)
        except AssertionError:
            # not a zmq url: treat it as a path to a json connector file,
            # resolving relative paths against the cluster's security dir
            if not os.path.exists(url_or_file):
                if self._cd:
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}

        # sync defaults from args, json:
        # NOTE(review): when only a plain url is given (no json file), the
        # cfg['exec_key']/cfg['ssh'] reads below depend on the corresponding
        # keyword arguments having been passed -- otherwise KeyError.
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg['exec_key']
        sshserver=cfg['ssh']
        url = cfg['url']
        location = cfg.setdefault('location', None)
        # rewrite wildcard/ambiguous addresses into something connectable
        cfg['url'] = disambiguate_url(cfg['url'], location)
        url = cfg['url']

        self._config = cfg

        # any ssh-related argument implies tunneled connections
        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
        # exec_key may be a path to a keyfile, or the literal key itself
        if exec_key is not None and os.path.isfile(exec_key):
            arg = 'keyfile'
        else:
            arg = 'key'
        key_arg = {arg:exec_key}
        if username is None:
            self.session = ss.StreamSession(**key_arg)
        else:
            self.session = ss.StreamSession(username, **key_arg)
        # only the query (registration) socket is connected here; the rest
        # are wired up in _connect() from addresses the Hub replies with
        self._query_socket = self._context.socket(zmq.XREQ)
        self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        self._notification_handlers = {'registration_notification' : self._register_engine,
                                    'unregistration_notification' : self._unregister_engine,
                                    }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs)
358
358
359 def __del__(self):
359 def __del__(self):
360 """cleanup sockets, but _not_ context."""
360 """cleanup sockets, but _not_ context."""
361 self.close()
361 self.close()
362
362
363 def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
363 def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
364 if ipython_dir is None:
364 if ipython_dir is None:
365 ipython_dir = get_ipython_dir()
365 ipython_dir = get_ipython_dir()
366 if cluster_dir is not None:
366 if cluster_dir is not None:
367 try:
367 try:
368 self._cd = ClusterDir.find_cluster_dir(cluster_dir)
368 self._cd = ClusterDir.find_cluster_dir(cluster_dir)
369 return
369 return
370 except ClusterDirError:
370 except ClusterDirError:
371 pass
371 pass
372 elif profile is not None:
372 elif profile is not None:
373 try:
373 try:
374 self._cd = ClusterDir.find_cluster_dir_by_profile(
374 self._cd = ClusterDir.find_cluster_dir_by_profile(
375 ipython_dir, profile)
375 ipython_dir, profile)
376 return
376 return
377 except ClusterDirError:
377 except ClusterDirError:
378 pass
378 pass
379 self._cd = None
379 self._cd = None
380
380
381 @property
381 @property
382 def ids(self):
382 def ids(self):
383 """Always up-to-date ids property."""
383 """Always up-to-date ids property."""
384 self._flush_notifications()
384 self._flush_notifications()
385 # always copy:
385 # always copy:
386 return list(self._ids)
386 return list(self._ids)
387
387
388 def close(self):
388 def close(self):
389 if self._closed:
389 if self._closed:
390 return
390 return
391 snames = filter(lambda n: n.endswith('socket'), dir(self))
391 snames = filter(lambda n: n.endswith('socket'), dir(self))
392 for socket in map(lambda name: getattr(self, name), snames):
392 for socket in map(lambda name: getattr(self, name), snames):
393 if isinstance(socket, zmq.Socket) and not socket.closed:
393 if isinstance(socket, zmq.Socket) and not socket.closed:
394 socket.close()
394 socket.close()
395 self._closed = True
395 self._closed = True
396
396
397 def _update_engines(self, engines):
397 def _update_engines(self, engines):
398 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
398 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
399 for k,v in engines.iteritems():
399 for k,v in engines.iteritems():
400 eid = int(k)
400 eid = int(k)
401 self._engines[eid] = bytes(v) # force not unicode
401 self._engines[eid] = bytes(v) # force not unicode
402 self._ids.append(eid)
402 self._ids.append(eid)
403 self._ids = sorted(self._ids)
403 self._ids = sorted(self._ids)
404 if sorted(self._engines.keys()) != range(len(self._engines)) and \
404 if sorted(self._engines.keys()) != range(len(self._engines)) and \
405 self._task_scheme == 'pure' and self._task_ident:
405 self._task_scheme == 'pure' and self._task_ident:
406 self._stop_scheduling_tasks()
406 self._stop_scheduling_tasks()
407
407
408 def _stop_scheduling_tasks(self):
408 def _stop_scheduling_tasks(self):
409 """Stop scheduling tasks because an engine has been unregistered
409 """Stop scheduling tasks because an engine has been unregistered
410 from a pure ZMQ scheduler.
410 from a pure ZMQ scheduler.
411 """
411 """
412 self._task_ident = ''
412 self._task_ident = ''
413 # self._task_socket.close()
413 # self._task_socket.close()
414 # self._task_socket = None
414 # self._task_socket = None
415 msg = "An engine has been unregistered, and we are using pure " +\
415 msg = "An engine has been unregistered, and we are using pure " +\
416 "ZMQ task scheduling. Task farming will be disabled."
416 "ZMQ task scheduling. Task farming will be disabled."
417 if self.outstanding:
417 if self.outstanding:
418 msg += " If you were running tasks when this happened, " +\
418 msg += " If you were running tasks when this happened, " +\
419 "some `outstanding` msg_ids may never resolve."
419 "some `outstanding` msg_ids may never resolve."
420 warnings.warn(msg, RuntimeWarning)
420 warnings.warn(msg, RuntimeWarning)
421
421
422 def _build_targets(self, targets):
422 def _build_targets(self, targets):
423 """Turn valid target IDs or 'all' into two lists:
423 """Turn valid target IDs or 'all' into two lists:
424 (int_ids, uuids).
424 (int_ids, uuids).
425 """
425 """
426 if targets is None:
426 if targets is None:
427 targets = self._ids
427 targets = self._ids
428 elif isinstance(targets, str):
428 elif isinstance(targets, str):
429 if targets.lower() == 'all':
429 if targets.lower() == 'all':
430 targets = self._ids
430 targets = self._ids
431 else:
431 else:
432 raise TypeError("%r not valid str target, must be 'all'"%(targets))
432 raise TypeError("%r not valid str target, must be 'all'"%(targets))
433 elif isinstance(targets, int):
433 elif isinstance(targets, int):
434 targets = [targets]
434 targets = [targets]
435 return [self._engines[t] for t in targets], list(targets)
435 return [self._engines[t] for t in targets], list(targets)
436
436
437 def _connect(self, sshserver, ssh_kwargs):
437 def _connect(self, sshserver, ssh_kwargs):
438 """setup all our socket connections to the controller. This is called from
438 """setup all our socket connections to the controller. This is called from
439 __init__."""
439 __init__."""
440
440
441 # Maybe allow reconnecting?
441 # Maybe allow reconnecting?
442 if self._connected:
442 if self._connected:
443 return
443 return
444 self._connected=True
444 self._connected=True
445
445
446 def connect_socket(s, url):
446 def connect_socket(s, url):
447 url = disambiguate_url(url, self._config['location'])
447 url = disambiguate_url(url, self._config['location'])
448 if self._ssh:
448 if self._ssh:
449 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
449 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
450 else:
450 else:
451 return s.connect(url)
451 return s.connect(url)
452
452
453 self.session.send(self._query_socket, 'connection_request')
453 self.session.send(self._query_socket, 'connection_request')
454 idents,msg = self.session.recv(self._query_socket,mode=0)
454 idents,msg = self.session.recv(self._query_socket,mode=0)
455 if self.debug:
455 if self.debug:
456 pprint(msg)
456 pprint(msg)
457 msg = ss.Message(msg)
457 msg = ss.Message(msg)
458 content = msg.content
458 content = msg.content
459 self._config['registration'] = dict(content)
459 self._config['registration'] = dict(content)
460 if content.status == 'ok':
460 if content.status == 'ok':
461 self._apply_socket = self._context.socket(zmq.XREP)
461 self._apply_socket = self._context.socket(zmq.XREP)
462 self._apply_socket.setsockopt(zmq.IDENTITY, self.session.session)
462 self._apply_socket.setsockopt(zmq.IDENTITY, self.session.session)
463 if content.mux:
463 if content.mux:
464 # self._mux_socket = self._context.socket(zmq.XREQ)
464 # self._mux_socket = self._context.socket(zmq.XREQ)
465 self._mux_ident = 'mux'
465 self._mux_ident = 'mux'
466 connect_socket(self._apply_socket, content.mux)
466 connect_socket(self._apply_socket, content.mux)
467 if content.task:
467 if content.task:
468 self._task_scheme, task_addr = content.task
468 self._task_scheme, task_addr = content.task
469 # self._task_socket = self._context.socket(zmq.XREQ)
469 # self._task_socket = self._context.socket(zmq.XREQ)
470 # self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
470 # self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
471 connect_socket(self._apply_socket, task_addr)
471 connect_socket(self._apply_socket, task_addr)
472 self._task_ident = 'task'
472 self._task_ident = 'task'
473 if content.notification:
473 if content.notification:
474 self._notification_socket = self._context.socket(zmq.SUB)
474 self._notification_socket = self._context.socket(zmq.SUB)
475 connect_socket(self._notification_socket, content.notification)
475 connect_socket(self._notification_socket, content.notification)
476 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
476 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
477 # if content.query:
477 # if content.query:
478 # self._query_socket = self._context.socket(zmq.XREQ)
478 # self._query_socket = self._context.socket(zmq.XREQ)
479 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
479 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
480 # connect_socket(self._query_socket, content.query)
480 # connect_socket(self._query_socket, content.query)
481 if content.control:
481 if content.control:
482 self._control_socket = self._context.socket(zmq.XREQ)
482 self._control_socket = self._context.socket(zmq.XREQ)
483 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
483 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
484 connect_socket(self._control_socket, content.control)
484 connect_socket(self._control_socket, content.control)
485 if content.iopub:
485 if content.iopub:
486 self._iopub_socket = self._context.socket(zmq.SUB)
486 self._iopub_socket = self._context.socket(zmq.SUB)
487 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
487 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
488 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
488 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
489 connect_socket(self._iopub_socket, content.iopub)
489 connect_socket(self._iopub_socket, content.iopub)
490 self._update_engines(dict(content.engines))
490 self._update_engines(dict(content.engines))
491 # give XREP apply_socket some time to connect
491 # give XREP apply_socket some time to connect
492 time.sleep(0.25)
492 time.sleep(0.25)
493 else:
493 else:
494 self._connected = False
494 self._connected = False
495 raise Exception("Failed to connect!")
495 raise Exception("Failed to connect!")
496
496
497 #--------------------------------------------------------------------------
497 #--------------------------------------------------------------------------
498 # handlers and callbacks for incoming messages
498 # handlers and callbacks for incoming messages
499 #--------------------------------------------------------------------------
499 #--------------------------------------------------------------------------
500
500
501 def _unwrap_exception(self, content):
501 def _unwrap_exception(self, content):
502 """unwrap exception, and remap engineid to int."""
502 """unwrap exception, and remap engineid to int."""
503 e = error.unwrap_exception(content)
503 e = error.unwrap_exception(content)
504 # print e.traceback
504 # print e.traceback
505 if e.engine_info:
505 if e.engine_info:
506 e_uuid = e.engine_info['engine_uuid']
506 e_uuid = e.engine_info['engine_uuid']
507 eid = self._engines[e_uuid]
507 eid = self._engines[e_uuid]
508 e.engine_info['engine_id'] = eid
508 e.engine_info['engine_id'] = eid
509 return e
509 return e
510
510
511 def _extract_metadata(self, header, parent, content):
511 def _extract_metadata(self, header, parent, content):
512 md = {'msg_id' : parent['msg_id'],
512 md = {'msg_id' : parent['msg_id'],
513 'received' : datetime.now(),
513 'received' : datetime.now(),
514 'engine_uuid' : header.get('engine', None),
514 'engine_uuid' : header.get('engine', None),
515 'follow' : parent.get('follow', []),
515 'follow' : parent.get('follow', []),
516 'after' : parent.get('after', []),
516 'after' : parent.get('after', []),
517 'status' : content['status'],
517 'status' : content['status'],
518 }
518 }
519
519
520 if md['engine_uuid'] is not None:
520 if md['engine_uuid'] is not None:
521 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
521 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
522
522
523 if 'date' in parent:
523 if 'date' in parent:
524 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
524 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
525 if 'started' in header:
525 if 'started' in header:
526 md['started'] = datetime.strptime(header['started'], util.ISO8601)
526 md['started'] = datetime.strptime(header['started'], util.ISO8601)
527 if 'date' in header:
527 if 'date' in header:
528 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
528 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
529 return md
529 return md
530
530
531 def _register_engine(self, msg):
531 def _register_engine(self, msg):
532 """Register a new engine, and update our connection info."""
532 """Register a new engine, and update our connection info."""
533 content = msg['content']
533 content = msg['content']
534 eid = content['id']
534 eid = content['id']
535 d = {eid : content['queue']}
535 d = {eid : content['queue']}
536 self._update_engines(d)
536 self._update_engines(d)
537
537
538 def _unregister_engine(self, msg):
538 def _unregister_engine(self, msg):
539 """Unregister an engine that has died."""
539 """Unregister an engine that has died."""
540 content = msg['content']
540 content = msg['content']
541 eid = int(content['id'])
541 eid = int(content['id'])
542 if eid in self._ids:
542 if eid in self._ids:
543 self._ids.remove(eid)
543 self._ids.remove(eid)
544 uuid = self._engines.pop(eid)
544 uuid = self._engines.pop(eid)
545
545
546 self._handle_stranded_msgs(eid, uuid)
546 self._handle_stranded_msgs(eid, uuid)
547
547
548 if self._task_ident and self._task_scheme == 'pure':
548 if self._task_ident and self._task_scheme == 'pure':
549 self._stop_scheduling_tasks()
549 self._stop_scheduling_tasks()
550
550
551 def _handle_stranded_msgs(self, eid, uuid):
551 def _handle_stranded_msgs(self, eid, uuid):
552 """Handle messages known to be on an engine when the engine unregisters.
552 """Handle messages known to be on an engine when the engine unregisters.
553
553
554 It is possible that this will fire prematurely - that is, an engine will
554 It is possible that this will fire prematurely - that is, an engine will
555 go down after completing a result, and the client will be notified
555 go down after completing a result, and the client will be notified
556 of the unregistration and later receive the successful result.
556 of the unregistration and later receive the successful result.
557 """
557 """
558
558
559 outstanding = self._outstanding_dict[uuid]
559 outstanding = self._outstanding_dict[uuid]
560
560
561 for msg_id in list(outstanding):
561 for msg_id in list(outstanding):
562 if msg_id in self.results:
562 if msg_id in self.results:
563 # we already
563 # we already
564 continue
564 continue
565 try:
565 try:
566 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
566 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
567 except:
567 except:
568 content = error.wrap_exception()
568 content = error.wrap_exception()
569 # build a fake message:
569 # build a fake message:
570 parent = {}
570 parent = {}
571 header = {}
571 header = {}
572 parent['msg_id'] = msg_id
572 parent['msg_id'] = msg_id
573 header['engine'] = uuid
573 header['engine'] = uuid
574 header['date'] = datetime.now().strftime(util.ISO8601)
574 header['date'] = datetime.now().strftime(util.ISO8601)
575 msg = dict(parent_header=parent, header=header, content=content)
575 msg = dict(parent_header=parent, header=header, content=content)
576 self._handle_apply_reply(msg)
576 self._handle_apply_reply(msg)
577
577
578 def _handle_execute_reply(self, msg):
578 def _handle_execute_reply(self, msg):
579 """Save the reply to an execute_request into our results.
579 """Save the reply to an execute_request into our results.
580
580
581 execute messages are never actually used. apply is used instead.
581 execute messages are never actually used. apply is used instead.
582 """
582 """
583
583
584 parent = msg['parent_header']
584 parent = msg['parent_header']
585 msg_id = parent['msg_id']
585 msg_id = parent['msg_id']
586 if msg_id not in self.outstanding:
586 if msg_id not in self.outstanding:
587 if msg_id in self.history:
587 if msg_id in self.history:
588 print ("got stale result: %s"%msg_id)
588 print ("got stale result: %s"%msg_id)
589 else:
589 else:
590 print ("got unknown result: %s"%msg_id)
590 print ("got unknown result: %s"%msg_id)
591 else:
591 else:
592 self.outstanding.remove(msg_id)
592 self.outstanding.remove(msg_id)
593 self.results[msg_id] = self._unwrap_exception(msg['content'])
593 self.results[msg_id] = self._unwrap_exception(msg['content'])
594
594
595 def _handle_apply_reply(self, msg):
595 def _handle_apply_reply(self, msg):
596 """Save the reply to an apply_request into our results."""
596 """Save the reply to an apply_request into our results."""
597 parent = msg['parent_header']
597 parent = msg['parent_header']
598 msg_id = parent['msg_id']
598 msg_id = parent['msg_id']
599 if msg_id not in self.outstanding:
599 if msg_id not in self.outstanding:
600 if msg_id in self.history:
600 if msg_id in self.history:
601 print ("got stale result: %s"%msg_id)
601 print ("got stale result: %s"%msg_id)
602 print self.results[msg_id]
602 print self.results[msg_id]
603 print msg
603 print msg
604 else:
604 else:
605 print ("got unknown result: %s"%msg_id)
605 print ("got unknown result: %s"%msg_id)
606 else:
606 else:
607 self.outstanding.remove(msg_id)
607 self.outstanding.remove(msg_id)
608 content = msg['content']
608 content = msg['content']
609 header = msg['header']
609 header = msg['header']
610
610
611 # construct metadata:
611 # construct metadata:
612 md = self.metadata[msg_id]
612 md = self.metadata[msg_id]
613 md.update(self._extract_metadata(header, parent, content))
613 md.update(self._extract_metadata(header, parent, content))
614 # is this redundant?
614 # is this redundant?
615 self.metadata[msg_id] = md
615 self.metadata[msg_id] = md
616
616
617 e_outstanding = self._outstanding_dict[md['engine_uuid']]
617 e_outstanding = self._outstanding_dict[md['engine_uuid']]
618 if msg_id in e_outstanding:
618 if msg_id in e_outstanding:
619 e_outstanding.remove(msg_id)
619 e_outstanding.remove(msg_id)
620
620
621 # construct result:
621 # construct result:
622 if content['status'] == 'ok':
622 if content['status'] == 'ok':
623 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
623 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
624 elif content['status'] == 'aborted':
624 elif content['status'] == 'aborted':
625 self.results[msg_id] = error.AbortedTask(msg_id)
625 self.results[msg_id] = error.AbortedTask(msg_id)
626 elif content['status'] == 'resubmitted':
626 elif content['status'] == 'resubmitted':
627 # TODO: handle resubmission
627 # TODO: handle resubmission
628 pass
628 pass
629 else:
629 else:
630 self.results[msg_id] = self._unwrap_exception(content)
630 self.results[msg_id] = self._unwrap_exception(content)
631
631
632 def _flush_notifications(self):
632 def _flush_notifications(self):
633 """Flush notifications of engine registrations waiting
633 """Flush notifications of engine registrations waiting
634 in ZMQ queue."""
634 in ZMQ queue."""
635 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
635 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
636 while msg is not None:
636 while msg is not None:
637 if self.debug:
637 if self.debug:
638 pprint(msg)
638 pprint(msg)
639 msg = msg[-1]
639 msg = msg[-1]
640 msg_type = msg['msg_type']
640 msg_type = msg['msg_type']
641 handler = self._notification_handlers.get(msg_type, None)
641 handler = self._notification_handlers.get(msg_type, None)
642 if handler is None:
642 if handler is None:
643 raise Exception("Unhandled message type: %s"%msg.msg_type)
643 raise Exception("Unhandled message type: %s"%msg.msg_type)
644 else:
644 else:
645 handler(msg)
645 handler(msg)
646 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
646 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
647
647
648 def _flush_results(self, sock):
648 def _flush_results(self, sock):
649 """Flush task or queue results waiting in ZMQ queue."""
649 """Flush task or queue results waiting in ZMQ queue."""
650 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
650 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
651 while msg is not None:
651 while msg is not None:
652 if self.debug:
652 if self.debug:
653 pprint(msg)
653 pprint(msg)
654 msg = msg[-1]
654 msg = msg[-1]
655 msg_type = msg['msg_type']
655 msg_type = msg['msg_type']
656 handler = self._queue_handlers.get(msg_type, None)
656 handler = self._queue_handlers.get(msg_type, None)
657 if handler is None:
657 if handler is None:
658 raise Exception("Unhandled message type: %s"%msg.msg_type)
658 raise Exception("Unhandled message type: %s"%msg.msg_type)
659 else:
659 else:
660 handler(msg)
660 handler(msg)
661 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
661 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
662
662
663 def _flush_control(self, sock):
663 def _flush_control(self, sock):
664 """Flush replies from the control channel waiting
664 """Flush replies from the control channel waiting
665 in the ZMQ queue.
665 in the ZMQ queue.
666
666
667 Currently: ignore them."""
667 Currently: ignore them."""
668 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
668 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
669 while msg is not None:
669 while msg is not None:
670 if self.debug:
670 if self.debug:
671 pprint(msg)
671 pprint(msg)
672 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
672 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
673
673
674 def _flush_iopub(self, sock):
674 def _flush_iopub(self, sock):
675 """Flush replies from the iopub channel waiting
675 """Flush replies from the iopub channel waiting
676 in the ZMQ queue.
676 in the ZMQ queue.
677 """
677 """
678 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
678 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
679 while msg is not None:
679 while msg is not None:
680 if self.debug:
680 if self.debug:
681 pprint(msg)
681 pprint(msg)
682 msg = msg[-1]
682 msg = msg[-1]
683 parent = msg['parent_header']
683 parent = msg['parent_header']
684 msg_id = parent['msg_id']
684 msg_id = parent['msg_id']
685 content = msg['content']
685 content = msg['content']
686 header = msg['header']
686 header = msg['header']
687 msg_type = msg['msg_type']
687 msg_type = msg['msg_type']
688
688
689 # init metadata:
689 # init metadata:
690 md = self.metadata[msg_id]
690 md = self.metadata[msg_id]
691
691
692 if msg_type == 'stream':
692 if msg_type == 'stream':
693 name = content['name']
693 name = content['name']
694 s = md[name] or ''
694 s = md[name] or ''
695 md[name] = s + content['data']
695 md[name] = s + content['data']
696 elif msg_type == 'pyerr':
696 elif msg_type == 'pyerr':
697 md.update({'pyerr' : self._unwrap_exception(content)})
697 md.update({'pyerr' : self._unwrap_exception(content)})
698 else:
698 else:
699 md.update({msg_type : content['data']})
699 md.update({msg_type : content['data']})
700
700
701 # reduntant?
701 # reduntant?
702 self.metadata[msg_id] = md
702 self.metadata[msg_id] = md
703
703
704 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
704 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
705
705
706 #--------------------------------------------------------------------------
706 #--------------------------------------------------------------------------
707 # len, getitem
707 # len, getitem
708 #--------------------------------------------------------------------------
708 #--------------------------------------------------------------------------
709
709
710 def __len__(self):
710 def __len__(self):
711 """len(client) returns # of engines."""
711 """len(client) returns # of engines."""
712 return len(self.ids)
712 return len(self.ids)
713
713
714 def __getitem__(self, key):
714 def __getitem__(self, key):
715 """index access returns DirectView multiplexer objects
715 """index access returns DirectView multiplexer objects
716
716
717 Must be int, slice, or list/tuple/xrange of ints"""
717 Must be int, slice, or list/tuple/xrange of ints"""
718 if not isinstance(key, (int, slice, tuple, list, xrange)):
718 if not isinstance(key, (int, slice, tuple, list, xrange)):
719 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
719 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
720 else:
720 else:
721 return self.view(key, balanced=False)
721 return self.view(key, balanced=False)
722
722
723 #--------------------------------------------------------------------------
723 #--------------------------------------------------------------------------
724 # Begin public methods
724 # Begin public methods
725 #--------------------------------------------------------------------------
725 #--------------------------------------------------------------------------
726
726
727 def spin(self):
727 def spin(self):
728 """Flush any registration notifications and execution results
728 """Flush any registration notifications and execution results
729 waiting in the ZMQ queue.
729 waiting in the ZMQ queue.
730 """
730 """
731 if self._notification_socket:
731 if self._notification_socket:
732 self._flush_notifications()
732 self._flush_notifications()
733 if self._apply_socket:
733 if self._apply_socket:
734 self._flush_results(self._apply_socket)
734 self._flush_results(self._apply_socket)
735 if self._control_socket:
735 if self._control_socket:
736 self._flush_control(self._control_socket)
736 self._flush_control(self._control_socket)
737 if self._iopub_socket:
737 if self._iopub_socket:
738 self._flush_iopub(self._iopub_socket)
738 self._flush_iopub(self._iopub_socket)
739
739
740 def barrier(self, jobs=None, timeout=-1):
740 def barrier(self, jobs=None, timeout=-1):
741 """waits on one or more `jobs`, for up to `timeout` seconds.
741 """waits on one or more `jobs`, for up to `timeout` seconds.
742
742
743 Parameters
743 Parameters
744 ----------
744 ----------
745
745
746 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
746 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
747 ints are indices to self.history
747 ints are indices to self.history
748 strs are msg_ids
748 strs are msg_ids
749 default: wait on all outstanding messages
749 default: wait on all outstanding messages
750 timeout : float
750 timeout : float
751 a time in seconds, after which to give up.
751 a time in seconds, after which to give up.
752 default is -1, which means no timeout
752 default is -1, which means no timeout
753
753
754 Returns
754 Returns
755 -------
755 -------
756
756
757 True : when all msg_ids are done
757 True : when all msg_ids are done
758 False : timeout reached, some msg_ids still outstanding
758 False : timeout reached, some msg_ids still outstanding
759 """
759 """
760 tic = time.time()
760 tic = time.time()
761 if jobs is None:
761 if jobs is None:
762 theids = self.outstanding
762 theids = self.outstanding
763 else:
763 else:
764 if isinstance(jobs, (int, str, AsyncResult)):
764 if isinstance(jobs, (int, str, AsyncResult)):
765 jobs = [jobs]
765 jobs = [jobs]
766 theids = set()
766 theids = set()
767 for job in jobs:
767 for job in jobs:
768 if isinstance(job, int):
768 if isinstance(job, int):
769 # index access
769 # index access
770 job = self.history[job]
770 job = self.history[job]
771 elif isinstance(job, AsyncResult):
771 elif isinstance(job, AsyncResult):
772 map(theids.add, job.msg_ids)
772 map(theids.add, job.msg_ids)
773 continue
773 continue
774 theids.add(job)
774 theids.add(job)
775 if not theids.intersection(self.outstanding):
775 if not theids.intersection(self.outstanding):
776 return True
776 return True
777 self.spin()
777 self.spin()
778 while theids.intersection(self.outstanding):
778 while theids.intersection(self.outstanding):
779 if timeout >= 0 and ( time.time()-tic ) > timeout:
779 if timeout >= 0 and ( time.time()-tic ) > timeout:
780 break
780 break
781 time.sleep(1e-3)
781 time.sleep(1e-3)
782 self.spin()
782 self.spin()
783 return len(theids.intersection(self.outstanding)) == 0
783 return len(theids.intersection(self.outstanding)) == 0
784
784
785 #--------------------------------------------------------------------------
785 #--------------------------------------------------------------------------
786 # Control methods
786 # Control methods
787 #--------------------------------------------------------------------------
787 #--------------------------------------------------------------------------
788
788
789 @spinfirst
789 @spinfirst
790 @defaultblock
790 @defaultblock
791 def clear(self, targets=None, block=None):
791 def clear(self, targets=None, block=None):
792 """Clear the namespace in target(s)."""
792 """Clear the namespace in target(s)."""
793 targets = self._build_targets(targets)[0]
793 targets = self._build_targets(targets)[0]
794 for t in targets:
794 for t in targets:
795 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
795 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
796 error = False
796 error = False
797 if self.block:
797 if self.block:
798 for i in range(len(targets)):
798 for i in range(len(targets)):
799 idents,msg = self.session.recv(self._control_socket,0)
799 idents,msg = self.session.recv(self._control_socket,0)
800 if self.debug:
800 if self.debug:
801 pprint(msg)
801 pprint(msg)
802 if msg['content']['status'] != 'ok':
802 if msg['content']['status'] != 'ok':
803 error = self._unwrap_exception(msg['content'])
803 error = self._unwrap_exception(msg['content'])
804 if error:
804 if error:
805 raise error
805 raise error
806
806
807
807
808 @spinfirst
808 @spinfirst
809 @defaultblock
809 @defaultblock
810 def abort(self, jobs=None, targets=None, block=None):
810 def abort(self, jobs=None, targets=None, block=None):
811 """Abort specific jobs from the execution queues of target(s).
811 """Abort specific jobs from the execution queues of target(s).
812
812
813 This is a mechanism to prevent jobs that have already been submitted
813 This is a mechanism to prevent jobs that have already been submitted
814 from executing.
814 from executing.
815
815
816 Parameters
816 Parameters
817 ----------
817 ----------
818
818
819 jobs : msg_id, list of msg_ids, or AsyncResult
819 jobs : msg_id, list of msg_ids, or AsyncResult
820 The jobs to be aborted
820 The jobs to be aborted
821
821
822
822
823 """
823 """
824 targets = self._build_targets(targets)[0]
824 targets = self._build_targets(targets)[0]
825 msg_ids = []
825 msg_ids = []
826 if isinstance(jobs, (basestring,AsyncResult)):
826 if isinstance(jobs, (basestring,AsyncResult)):
827 jobs = [jobs]
827 jobs = [jobs]
828 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
828 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
829 if bad_ids:
829 if bad_ids:
830 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
830 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
831 for j in jobs:
831 for j in jobs:
832 if isinstance(j, AsyncResult):
832 if isinstance(j, AsyncResult):
833 msg_ids.extend(j.msg_ids)
833 msg_ids.extend(j.msg_ids)
834 else:
834 else:
835 msg_ids.append(j)
835 msg_ids.append(j)
836 content = dict(msg_ids=msg_ids)
836 content = dict(msg_ids=msg_ids)
837 for t in targets:
837 for t in targets:
838 self.session.send(self._control_socket, 'abort_request',
838 self.session.send(self._control_socket, 'abort_request',
839 content=content, ident=t)
839 content=content, ident=t)
840 error = False
840 error = False
841 if self.block:
841 if self.block:
842 for i in range(len(targets)):
842 for i in range(len(targets)):
843 idents,msg = self.session.recv(self._control_socket,0)
843 idents,msg = self.session.recv(self._control_socket,0)
844 if self.debug:
844 if self.debug:
845 pprint(msg)
845 pprint(msg)
846 if msg['content']['status'] != 'ok':
846 if msg['content']['status'] != 'ok':
847 error = self._unwrap_exception(msg['content'])
847 error = self._unwrap_exception(msg['content'])
848 if error:
848 if error:
849 raise error
849 raise error
850
850
851 @spinfirst
851 @spinfirst
852 @defaultblock
852 @defaultblock
853 def shutdown(self, targets=None, restart=False, controller=False, block=None):
853 def shutdown(self, targets=None, restart=False, controller=False, block=None):
854 """Terminates one or more engine processes, optionally including the controller."""
854 """Terminates one or more engine processes, optionally including the controller."""
855 if controller:
855 if controller:
856 targets = 'all'
856 targets = 'all'
857 targets = self._build_targets(targets)[0]
857 targets = self._build_targets(targets)[0]
858 for t in targets:
858 for t in targets:
859 self.session.send(self._control_socket, 'shutdown_request',
859 self.session.send(self._control_socket, 'shutdown_request',
860 content={'restart':restart},ident=t)
860 content={'restart':restart},ident=t)
861 error = False
861 error = False
862 if block or controller:
862 if block or controller:
863 for i in range(len(targets)):
863 for i in range(len(targets)):
864 idents,msg = self.session.recv(self._control_socket,0)
864 idents,msg = self.session.recv(self._control_socket,0)
865 if self.debug:
865 if self.debug:
866 pprint(msg)
866 pprint(msg)
867 if msg['content']['status'] != 'ok':
867 if msg['content']['status'] != 'ok':
868 error = self._unwrap_exception(msg['content'])
868 error = self._unwrap_exception(msg['content'])
869
869
870 if controller:
870 if controller:
871 time.sleep(0.25)
871 time.sleep(0.25)
872 self.session.send(self._query_socket, 'shutdown_request')
872 self.session.send(self._query_socket, 'shutdown_request')
873 idents,msg = self.session.recv(self._query_socket, 0)
873 idents,msg = self.session.recv(self._query_socket, 0)
874 if self.debug:
874 if self.debug:
875 pprint(msg)
875 pprint(msg)
876 if msg['content']['status'] != 'ok':
876 if msg['content']['status'] != 'ok':
877 error = self._unwrap_exception(msg['content'])
877 error = self._unwrap_exception(msg['content'])
878
878
879 if error:
879 if error:
880 raise error
880 raise error
881
881
882 #--------------------------------------------------------------------------
882 #--------------------------------------------------------------------------
883 # Execution methods
883 # Execution methods
884 #--------------------------------------------------------------------------
884 #--------------------------------------------------------------------------
885
885
886 @defaultblock
886 @defaultblock
887 def execute(self, code, targets='all', block=None):
887 def execute(self, code, targets='all', block=None):
888 """Executes `code` on `targets` in blocking or nonblocking manner.
888 """Executes `code` on `targets` in blocking or nonblocking manner.
889
889
890 ``execute`` is always `bound` (affects engine namespace)
890 ``execute`` is always `bound` (affects engine namespace)
891
891
892 Parameters
892 Parameters
893 ----------
893 ----------
894
894
895 code : str
895 code : str
896 the code string to be executed
896 the code string to be executed
897 targets : int/str/list of ints/strs
897 targets : int/str/list of ints/strs
898 the engines on which to execute
898 the engines on which to execute
899 default : all
899 default : all
900 block : bool
900 block : bool
901 whether or not to wait until done to return
901 whether or not to wait until done to return
902 default: self.block
902 default: self.block
903 """
903 """
904 result = self.apply(_execute, (code,), targets=targets, block=block, bound=True, balanced=False)
904 result = self.apply(_execute, (code,), targets=targets, block=block, bound=True, balanced=False)
905 if not block:
905 if not block:
906 return result
906 return result
907
907
908 def run(self, filename, targets='all', block=None):
908 def run(self, filename, targets='all', block=None):
909 """Execute contents of `filename` on engine(s).
909 """Execute contents of `filename` on engine(s).
910
910
911 This simply reads the contents of the file and calls `execute`.
911 This simply reads the contents of the file and calls `execute`.
912
912
913 Parameters
913 Parameters
914 ----------
914 ----------
915
915
916 filename : str
916 filename : str
917 The path to the file
917 The path to the file
918 targets : int/str/list of ints/strs
918 targets : int/str/list of ints/strs
919 the engines on which to execute
919 the engines on which to execute
920 default : all
920 default : all
921 block : bool
921 block : bool
922 whether or not to wait until done
922 whether or not to wait until done
923 default: self.block
923 default: self.block
924
924
925 """
925 """
926 with open(filename, 'r') as f:
926 with open(filename, 'r') as f:
927 # add newline in case of trailing indented whitespace
927 # add newline in case of trailing indented whitespace
928 # which will cause SyntaxError
928 # which will cause SyntaxError
929 code = f.read()+'\n'
929 code = f.read()+'\n'
930 return self.execute(code, targets=targets, block=block)
930 return self.execute(code, targets=targets, block=block)
931
931
932 def _maybe_raise(self, result):
932 def _maybe_raise(self, result):
933 """wrapper for maybe raising an exception if apply failed."""
933 """wrapper for maybe raising an exception if apply failed."""
934 if isinstance(result, error.RemoteError):
934 if isinstance(result, error.RemoteError):
935 raise result
935 raise result
936
936
937 return result
937 return result
938
938
939 def _build_dependency(self, dep):
939 def _build_dependency(self, dep):
940 """helper for building jsonable dependencies from various input forms"""
940 """helper for building jsonable dependencies from various input forms"""
941 if isinstance(dep, Dependency):
941 if isinstance(dep, Dependency):
942 return dep.as_dict()
942 return dep.as_dict()
943 elif isinstance(dep, AsyncResult):
943 elif isinstance(dep, AsyncResult):
944 return dep.msg_ids
944 return dep.msg_ids
945 elif dep is None:
945 elif dep is None:
946 return []
946 return []
947 else:
947 else:
948 # pass to Dependency constructor
948 # pass to Dependency constructor
949 return list(Dependency(dep))
949 return list(Dependency(dep))
950
950
951 @defaultblock
951 @defaultblock
952 def apply(self, f, args=None, kwargs=None, bound=False, block=None,
952 def apply(self, f, args=None, kwargs=None, bound=False, block=None,
953 targets=None, balanced=None,
953 targets=None, balanced=None,
954 after=None, follow=None, timeout=None,
954 after=None, follow=None, timeout=None,
955 track=False):
955 track=False):
956 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
956 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
957
957
958 This is the central execution command for the client.
958 This is the central execution command for the client.
959
959
960 Parameters
960 Parameters
961 ----------
961 ----------
962
962
963 f : function
963 f : function
964 The fuction to be called remotely
964 The fuction to be called remotely
965 args : tuple/list
965 args : tuple/list
966 The positional arguments passed to `f`
966 The positional arguments passed to `f`
967 kwargs : dict
967 kwargs : dict
968 The keyword arguments passed to `f`
968 The keyword arguments passed to `f`
969 bound : bool (default: False)
969 bound : bool (default: False)
970 Whether to pass the Engine(s) Namespace as the first argument to `f`.
970 Whether to pass the Engine(s) Namespace as the first argument to `f`.
971 block : bool (default: self.block)
971 block : bool (default: self.block)
972 Whether to wait for the result, or return immediately.
972 Whether to wait for the result, or return immediately.
973 False:
973 False:
974 returns AsyncResult
974 returns AsyncResult
975 True:
975 True:
976 returns actual result(s) of f(*args, **kwargs)
976 returns actual result(s) of f(*args, **kwargs)
977 if multiple targets:
977 if multiple targets:
978 list of results, matching `targets`
978 list of results, matching `targets`
979 track : bool
980 whether to track non-copying sends.
981 [default False]
982
979 targets : int,list of ints, 'all', None
983 targets : int,list of ints, 'all', None
980 Specify the destination of the job.
984 Specify the destination of the job.
981 if None:
985 if None:
982 Submit via Task queue for load-balancing.
986 Submit via Task queue for load-balancing.
983 if 'all':
987 if 'all':
984 Run on all active engines
988 Run on all active engines
985 if list:
989 if list:
986 Run on each specified engine
990 Run on each specified engine
987 if int:
991 if int:
988 Run on single engine
992 Run on single engine
989
993 Note:
994 that if `balanced=True`, and `targets` is specified,
995 then the load-balancing will be limited to balancing
996 among `targets`.
997
990 balanced : bool, default None
998 balanced : bool, default None
991 whether to load-balance. This will default to True
999 whether to load-balance. This will default to True
992 if targets is unspecified, or False if targets is specified.
1000 if targets is unspecified, or False if targets is specified.
993
1001
994 The following arguments are only used when balanced is True:
1002 If `balanced` and `targets` are both specified, the task will
1003 be assigne to *one* of the targets by the scheduler.
1004
1005 The following arguments are only used when balanced is True:
1006
995 after : Dependency or collection of msg_ids
1007 after : Dependency or collection of msg_ids
996 Only for load-balanced execution (targets=None)
1008 Only for load-balanced execution (targets=None)
997 Specify a list of msg_ids as a time-based dependency.
1009 Specify a list of msg_ids as a time-based dependency.
998 This job will only be run *after* the dependencies
1010 This job will only be run *after* the dependencies
999 have been met.
1011 have been met.
1000
1012
1001 follow : Dependency or collection of msg_ids
1013 follow : Dependency or collection of msg_ids
1002 Only for load-balanced execution (targets=None)
1014 Only for load-balanced execution (targets=None)
1003 Specify a list of msg_ids as a location-based dependency.
1015 Specify a list of msg_ids as a location-based dependency.
1004 This job will only be run on an engine where this dependency
1016 This job will only be run on an engine where this dependency
1005 is met.
1017 is met.
1006
1018
1007 timeout : float/int or None
1019 timeout : float/int or None
1008 Only for load-balanced execution (targets=None)
1020 Only for load-balanced execution (targets=None)
1009 Specify an amount of time (in seconds) for the scheduler to
1021 Specify an amount of time (in seconds) for the scheduler to
1010 wait for dependencies to be met before failing with a
1022 wait for dependencies to be met before failing with a
1011 DependencyTimeout.
1023 DependencyTimeout.
1012 track : bool
1013 whether to track non-copying sends.
1014 [default False]
1015
1016 after,follow,timeout only used if `balanced=True`.
1017
1024
1018 Returns
1025 Returns
1019 -------
1026 -------
1020
1027
1021 if block is False:
1028 if block is False:
1022 return AsyncResult wrapping msg_ids
1029 return AsyncResult wrapping msg_ids
1023 output of AsyncResult.get() is identical to that of `apply(...block=True)`
1030 output of AsyncResult.get() is identical to that of `apply(...block=True)`
1024 else:
1031 else:
1025 if single target:
1032 if single target (or balanced):
1026 return result of `f(*args, **kwargs)`
1033 return result of `f(*args, **kwargs)`
1027 else:
1034 else:
1028 return list of results, matching `targets`
1035 return list of results, matching `targets`
1029 """
1036 """
1030 assert not self._closed, "cannot use me anymore, I'm closed!"
1037 assert not self._closed, "cannot use me anymore, I'm closed!"
1031 # defaults:
1038 # defaults:
1032 block = block if block is not None else self.block
1039 block = block if block is not None else self.block
1033 args = args if args is not None else []
1040 args = args if args is not None else []
1034 kwargs = kwargs if kwargs is not None else {}
1041 kwargs = kwargs if kwargs is not None else {}
1035
1042
1036 if not self._ids:
1043 if not self._ids:
1037 # flush notification socket if no engines yet
1044 # flush notification socket if no engines yet
1038 any_ids = self.ids
1045 any_ids = self.ids
1039 if not any_ids:
1046 if not any_ids:
1040 raise error.NoEnginesRegistered("Can't execute without any connected engines.")
1047 raise error.NoEnginesRegistered("Can't execute without any connected engines.")
1041
1048
1042 if balanced is None:
1049 if balanced is None:
1043 if targets is None:
1050 if targets is None:
1044 # default to balanced if targets unspecified
1051 # default to balanced if targets unspecified
1045 balanced = True
1052 balanced = True
1046 else:
1053 else:
1047 # otherwise default to multiplexing
1054 # otherwise default to multiplexing
1048 balanced = False
1055 balanced = False
1049
1056
1050 if targets is None and balanced is False:
1057 if targets is None and balanced is False:
1051 # default to all if *not* balanced, and targets is unspecified
1058 # default to all if *not* balanced, and targets is unspecified
1052 targets = 'all'
1059 targets = 'all'
1053
1060
1054 # enforce types of f,args,kwrags
1061 # enforce types of f,args,kwrags
1055 if not callable(f):
1062 if not callable(f):
1056 raise TypeError("f must be callable, not %s"%type(f))
1063 raise TypeError("f must be callable, not %s"%type(f))
1057 if not isinstance(args, (tuple, list)):
1064 if not isinstance(args, (tuple, list)):
1058 raise TypeError("args must be tuple or list, not %s"%type(args))
1065 raise TypeError("args must be tuple or list, not %s"%type(args))
1059 if not isinstance(kwargs, dict):
1066 if not isinstance(kwargs, dict):
1060 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1067 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1061
1068
1062 options = dict(bound=bound, block=block, targets=targets, track=track)
1069 options = dict(bound=bound, block=block, targets=targets, track=track)
1063
1070
1064 if balanced:
1071 if balanced:
1065 return self._apply_balanced(f, args, kwargs, timeout=timeout,
1072 return self._apply_balanced(f, args, kwargs, timeout=timeout,
1066 after=after, follow=follow, **options)
1073 after=after, follow=follow, **options)
1067 elif follow or after or timeout:
1074 elif follow or after or timeout:
1068 msg = "follow, after, and timeout args are only used for"
1075 msg = "follow, after, and timeout args are only used for"
1069 msg += " load-balanced execution."
1076 msg += " load-balanced execution."
1070 raise ValueError(msg)
1077 raise ValueError(msg)
1071 else:
1078 else:
1072 return self._apply_direct(f, args, kwargs, **options)
1079 return self._apply_direct(f, args, kwargs, **options)
1073
1080
1074 def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
1081 def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
1075 after=None, follow=None, timeout=None, track=None):
1082 after=None, follow=None, timeout=None, track=None):
1076 """call f(*args, **kwargs) remotely in a load-balanced manner.
1083 """call f(*args, **kwargs) remotely in a load-balanced manner.
1077
1084
1078 This is a private method, see `apply` for details.
1085 This is a private method, see `apply` for details.
1079 Not to be called directly!
1086 Not to be called directly!
1080 """
1087 """
1081
1088
1082 loc = locals()
1089 loc = locals()
1083 for name in ('bound', 'block', 'track'):
1090 for name in ('bound', 'block', 'track'):
1084 assert loc[name] is not None, "kwarg %r must be specified!"%name
1091 assert loc[name] is not None, "kwarg %r must be specified!"%name
1085
1092
1086 if not self._task_ident:
1093 if not self._task_ident:
1087 msg = "Task farming is disabled"
1094 msg = "Task farming is disabled"
1088 if self._task_scheme == 'pure':
1095 if self._task_scheme == 'pure':
1089 msg += " because the pure ZMQ scheduler cannot handle"
1096 msg += " because the pure ZMQ scheduler cannot handle"
1090 msg += " disappearing engines."
1097 msg += " disappearing engines."
1091 raise RuntimeError(msg)
1098 raise RuntimeError(msg)
1092
1099
1093 if self._task_scheme == 'pure':
1100 if self._task_scheme == 'pure':
1094 # pure zmq scheme doesn't support dependencies
1101 # pure zmq scheme doesn't support dependencies
1095 msg = "Pure ZMQ scheduler doesn't support dependencies"
1102 msg = "Pure ZMQ scheduler doesn't support dependencies"
1096 if (follow or after):
1103 if (follow or after):
1097 # hard fail on DAG dependencies
1104 # hard fail on DAG dependencies
1098 raise RuntimeError(msg)
1105 raise RuntimeError(msg)
1099 if isinstance(f, dependent):
1106 if isinstance(f, dependent):
1100 # soft warn on functional dependencies
1107 # soft warn on functional dependencies
1101 warnings.warn(msg, RuntimeWarning)
1108 warnings.warn(msg, RuntimeWarning)
1102
1109
1103 # defaults:
1110 # defaults:
1104 args = args if args is not None else []
1111 args = args if args is not None else []
1105 kwargs = kwargs if kwargs is not None else {}
1112 kwargs = kwargs if kwargs is not None else {}
1106
1113
1107 if targets:
1114 if targets:
1108 idents,_ = self._build_targets(targets)
1115 idents,_ = self._build_targets(targets)
1109 else:
1116 else:
1110 idents = []
1117 idents = []
1111
1118
1112 after = self._build_dependency(after)
1119 after = self._build_dependency(after)
1113 follow = self._build_dependency(follow)
1120 follow = self._build_dependency(follow)
1114 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
1121 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
1115 bufs = util.pack_apply_message(f,args,kwargs)
1122 bufs = util.pack_apply_message(f,args,kwargs)
1116 content = dict(bound=bound)
1123 content = dict(bound=bound)
1117
1124
1118 msg = self.session.send(self._apply_socket, "apply_request", ident=self._task_ident,
1125 msg = self.session.send(self._apply_socket, "apply_request", ident=self._task_ident,
1119 content=content, buffers=bufs, subheader=subheader, track=track)
1126 content=content, buffers=bufs, subheader=subheader, track=track)
1120 msg_id = msg['msg_id']
1127 msg_id = msg['msg_id']
1121 self.outstanding.add(msg_id)
1128 self.outstanding.add(msg_id)
1122 self.history.append(msg_id)
1129 self.history.append(msg_id)
1123 self.metadata[msg_id]['submitted'] = datetime.now()
1130 self.metadata[msg_id]['submitted'] = datetime.now()
1124 tracker = None if track is False else msg['tracker']
1131 tracker = None if track is False else msg['tracker']
1125 ar = AsyncResult(self, [msg_id], fname=f.__name__, targets=targets, tracker=tracker)
1132 ar = AsyncResult(self, [msg_id], fname=f.__name__, targets=targets, tracker=tracker)
1126 if block:
1133 if block:
1127 try:
1134 try:
1128 return ar.get()
1135 return ar.get()
1129 except KeyboardInterrupt:
1136 except KeyboardInterrupt:
1130 return ar
1137 return ar
1131 else:
1138 else:
1132 return ar
1139 return ar
1133
1140
1134 def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None,
1141 def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None,
1135 track=None):
1142 track=None):
1136 """Then underlying method for applying functions to specific engines
1143 """Then underlying method for applying functions to specific engines
1137 via the MUX queue.
1144 via the MUX queue.
1138
1145
1139 This is a private method, see `apply` for details.
1146 This is a private method, see `apply` for details.
1140 Not to be called directly!
1147 Not to be called directly!
1141 """
1148 """
1142
1149
1143 if not self._mux_ident:
1150 if not self._mux_ident:
1144 msg = "Multiplexing is disabled"
1151 msg = "Multiplexing is disabled"
1145 raise RuntimeError(msg)
1152 raise RuntimeError(msg)
1146
1153
1147 loc = locals()
1154 loc = locals()
1148 for name in ('bound', 'block', 'targets', 'track'):
1155 for name in ('bound', 'block', 'targets', 'track'):
1149 assert loc[name] is not None, "kwarg %r must be specified!"%name
1156 assert loc[name] is not None, "kwarg %r must be specified!"%name
1150
1157
1151 idents,targets = self._build_targets(targets)
1158 idents,targets = self._build_targets(targets)
1152
1159
1153 subheader = {}
1160 subheader = {}
1154 content = dict(bound=bound)
1161 content = dict(bound=bound)
1155 bufs = util.pack_apply_message(f,args,kwargs)
1162 bufs = util.pack_apply_message(f,args,kwargs)
1156
1163
1157 msg_ids = []
1164 msg_ids = []
1158 trackers = []
1165 trackers = []
1159 for ident in idents:
1166 for ident in idents:
1160 msg = self.session.send(self._apply_socket, "apply_request",
1167 msg = self.session.send(self._apply_socket, "apply_request",
1161 content=content, buffers=bufs, ident=[self._mux_ident, ident], subheader=subheader,
1168 content=content, buffers=bufs, ident=[self._mux_ident, ident], subheader=subheader,
1162 track=track)
1169 track=track)
1163 if track:
1170 if track:
1164 trackers.append(msg['tracker'])
1171 trackers.append(msg['tracker'])
1165 msg_id = msg['msg_id']
1172 msg_id = msg['msg_id']
1166 self.outstanding.add(msg_id)
1173 self.outstanding.add(msg_id)
1167 self._outstanding_dict[ident].add(msg_id)
1174 self._outstanding_dict[ident].add(msg_id)
1168 self.history.append(msg_id)
1175 self.history.append(msg_id)
1169 msg_ids.append(msg_id)
1176 msg_ids.append(msg_id)
1170
1177
1171 tracker = None if track is False else zmq.MessageTracker(*trackers)
1178 tracker = None if track is False else zmq.MessageTracker(*trackers)
1172 ar = AsyncResult(self, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
1179 ar = AsyncResult(self, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
1173
1180
1174 if block:
1181 if block:
1175 try:
1182 try:
1176 return ar.get()
1183 return ar.get()
1177 except KeyboardInterrupt:
1184 except KeyboardInterrupt:
1178 return ar
1185 return ar
1179 else:
1186 else:
1180 return ar
1187 return ar
1181
1188
1182 #--------------------------------------------------------------------------
1189 #--------------------------------------------------------------------------
1183 # construct a View object
1190 # construct a View object
1184 #--------------------------------------------------------------------------
1191 #--------------------------------------------------------------------------
1185
1192
1186 @defaultblock
1193 @defaultblock
1187 def remote(self, bound=False, block=None, targets=None, balanced=None):
1194 def remote(self, bound=False, block=None, targets=None, balanced=None):
1188 """Decorator for making a RemoteFunction"""
1195 """Decorator for making a RemoteFunction"""
1189 return remote(self, bound=bound, targets=targets, block=block, balanced=balanced)
1196 return remote(self, bound=bound, targets=targets, block=block, balanced=balanced)
1190
1197
1191 @defaultblock
1198 @defaultblock
1192 def parallel(self, dist='b', bound=False, block=None, targets=None, balanced=None):
1199 def parallel(self, dist='b', bound=False, block=None, targets=None, balanced=None):
1193 """Decorator for making a ParallelFunction"""
1200 """Decorator for making a ParallelFunction"""
1194 return parallel(self, bound=bound, targets=targets, block=block, balanced=balanced)
1201 return parallel(self, bound=bound, targets=targets, block=block, balanced=balanced)
1195
1202
1196 def _cache_view(self, targets, balanced):
1203 def _cache_view(self, targets, balanced):
1197 """save views, so subsequent requests don't create new objects."""
1204 """save views, so subsequent requests don't create new objects."""
1198 if balanced:
1205 if balanced:
1199 view_class = LoadBalancedView
1206 view_class = LoadBalancedView
1200 view_cache = self._balanced_views
1207 view_cache = self._balanced_views
1201 else:
1208 else:
1202 view_class = DirectView
1209 view_class = DirectView
1203 view_cache = self._direct_views
1210 view_cache = self._direct_views
1204
1211
1205 # use str, since often targets will be a list
1212 # use str, since often targets will be a list
1206 key = str(targets)
1213 key = str(targets)
1207 if key not in view_cache:
1214 if key not in view_cache:
1208 view_cache[key] = view_class(client=self, targets=targets)
1215 view_cache[key] = view_class(client=self, targets=targets)
1209
1216
1210 return view_cache[key]
1217 return view_cache[key]
1211
1218
1212 def view(self, targets=None, balanced=None):
1219 def view(self, targets=None, balanced=None):
1213 """Method for constructing View objects.
1220 """Method for constructing View objects.
1214
1221
1215 If no arguments are specified, create a LoadBalancedView
1222 If no arguments are specified, create a LoadBalancedView
1216 using all engines. If only `targets` specified, it will
1223 using all engines. If only `targets` specified, it will
1217 be a DirectView. This method is the underlying implementation
1224 be a DirectView. This method is the underlying implementation
1218 of ``client.__getitem__``.
1225 of ``client.__getitem__``.
1219
1226
1220 Parameters
1227 Parameters
1221 ----------
1228 ----------
1222
1229
1223 targets: list,slice,int,etc. [default: use all engines]
1230 targets: list,slice,int,etc. [default: use all engines]
1224 The engines to use for the View
1231 The engines to use for the View
1225 balanced : bool [default: False if targets specified, True else]
1232 balanced : bool [default: False if targets specified, True else]
1226 whether to build a LoadBalancedView or a DirectView
1233 whether to build a LoadBalancedView or a DirectView
1227
1234
1228 """
1235 """
1229
1236
1230 balanced = (targets is None) if balanced is None else balanced
1237 balanced = (targets is None) if balanced is None else balanced
1231
1238
1232 if targets is None:
1239 if targets is None:
1233 if balanced:
1240 if balanced:
1234 return self._cache_view(None,True)
1241 return self._cache_view(None,True)
1235 else:
1242 else:
1236 targets = slice(None)
1243 targets = slice(None)
1237
1244
1238 if isinstance(targets, int):
1245 if isinstance(targets, int):
1239 if targets < 0:
1246 if targets < 0:
1240 targets = self.ids[targets]
1247 targets = self.ids[targets]
1241 if targets not in self.ids:
1248 if targets not in self.ids:
1242 raise IndexError("No such engine: %i"%targets)
1249 raise IndexError("No such engine: %i"%targets)
1243 return self._cache_view(targets, balanced)
1250 return self._cache_view(targets, balanced)
1244
1251
1245 if isinstance(targets, slice):
1252 if isinstance(targets, slice):
1246 indices = range(len(self.ids))[targets]
1253 indices = range(len(self.ids))[targets]
1247 ids = sorted(self._ids)
1254 ids = sorted(self._ids)
1248 targets = [ ids[i] for i in indices ]
1255 targets = [ ids[i] for i in indices ]
1249
1256
1250 if isinstance(targets, (tuple, list, xrange)):
1257 if isinstance(targets, (tuple, list, xrange)):
1251 _,targets = self._build_targets(list(targets))
1258 _,targets = self._build_targets(list(targets))
1252 return self._cache_view(targets, balanced)
1259 return self._cache_view(targets, balanced)
1253 else:
1260 else:
1254 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
1261 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
1255
1262
1256 #--------------------------------------------------------------------------
1263 #--------------------------------------------------------------------------
1257 # Data movement
1264 # Data movement
1258 #--------------------------------------------------------------------------
1265 #--------------------------------------------------------------------------
1259
1266
1260 @defaultblock
1267 @defaultblock
1261 def push(self, ns, targets='all', block=None, track=False):
1268 def push(self, ns, targets='all', block=None, track=False):
1262 """Push the contents of `ns` into the namespace on `target`"""
1269 """Push the contents of `ns` into the namespace on `target`"""
1263 if not isinstance(ns, dict):
1270 if not isinstance(ns, dict):
1264 raise TypeError("Must be a dict, not %s"%type(ns))
1271 raise TypeError("Must be a dict, not %s"%type(ns))
1265 result = self.apply(_push, kwargs=ns, targets=targets, block=block, bound=True, balanced=False, track=track)
1272 result = self.apply(_push, kwargs=ns, targets=targets, block=block, bound=True, balanced=False, track=track)
1266 if not block:
1273 if not block:
1267 return result
1274 return result
1268
1275
1269 @defaultblock
1276 @defaultblock
1270 def pull(self, keys, targets='all', block=None):
1277 def pull(self, keys, targets='all', block=None):
1271 """Pull objects from `target`'s namespace by `keys`"""
1278 """Pull objects from `target`'s namespace by `keys`"""
1272 if isinstance(keys, basestring):
1279 if isinstance(keys, basestring):
1273 pass
1280 pass
1274 elif isinstance(keys, (list,tuple,set)):
1281 elif isinstance(keys, (list,tuple,set)):
1275 for key in keys:
1282 for key in keys:
1276 if not isinstance(key, basestring):
1283 if not isinstance(key, basestring):
1277 raise TypeError("keys must be str, not type %r"%type(key))
1284 raise TypeError("keys must be str, not type %r"%type(key))
1278 else:
1285 else:
1279 raise TypeError("keys must be strs, not %r"%keys)
1286 raise TypeError("keys must be strs, not %r"%keys)
1280 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True, balanced=False)
1287 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True, balanced=False)
1281 return result
1288 return result
1282
1289
1283 @defaultblock
1290 @defaultblock
1284 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None, track=False):
1291 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None, track=False):
1285 """
1292 """
1286 Partition a Python sequence and send the partitions to a set of engines.
1293 Partition a Python sequence and send the partitions to a set of engines.
1287 """
1294 """
1288 targets = self._build_targets(targets)[-1]
1295 targets = self._build_targets(targets)[-1]
1289 mapObject = Map.dists[dist]()
1296 mapObject = Map.dists[dist]()
1290 nparts = len(targets)
1297 nparts = len(targets)
1291 msg_ids = []
1298 msg_ids = []
1292 trackers = []
1299 trackers = []
1293 for index, engineid in enumerate(targets):
1300 for index, engineid in enumerate(targets):
1294 partition = mapObject.getPartition(seq, index, nparts)
1301 partition = mapObject.getPartition(seq, index, nparts)
1295 if flatten and len(partition) == 1:
1302 if flatten and len(partition) == 1:
1296 r = self.push({key: partition[0]}, targets=engineid, block=False, track=track)
1303 r = self.push({key: partition[0]}, targets=engineid, block=False, track=track)
1297 else:
1304 else:
1298 r = self.push({key: partition}, targets=engineid, block=False, track=track)
1305 r = self.push({key: partition}, targets=engineid, block=False, track=track)
1299 msg_ids.extend(r.msg_ids)
1306 msg_ids.extend(r.msg_ids)
1300 if track:
1307 if track:
1301 trackers.append(r._tracker)
1308 trackers.append(r._tracker)
1302
1309
1303 if track:
1310 if track:
1304 tracker = zmq.MessageTracker(*trackers)
1311 tracker = zmq.MessageTracker(*trackers)
1305 else:
1312 else:
1306 tracker = None
1313 tracker = None
1307
1314
1308 r = AsyncResult(self, msg_ids, fname='scatter', targets=targets, tracker=tracker)
1315 r = AsyncResult(self, msg_ids, fname='scatter', targets=targets, tracker=tracker)
1309 if block:
1316 if block:
1310 r.wait()
1317 r.wait()
1311 else:
1318 else:
1312 return r
1319 return r
1313
1320
1314 @defaultblock
1321 @defaultblock
1315 def gather(self, key, dist='b', targets='all', block=None):
1322 def gather(self, key, dist='b', targets='all', block=None):
1316 """
1323 """
1317 Gather a partitioned sequence on a set of engines as a single local seq.
1324 Gather a partitioned sequence on a set of engines as a single local seq.
1318 """
1325 """
1319
1326
1320 targets = self._build_targets(targets)[-1]
1327 targets = self._build_targets(targets)[-1]
1321 mapObject = Map.dists[dist]()
1328 mapObject = Map.dists[dist]()
1322 msg_ids = []
1329 msg_ids = []
1323 for index, engineid in enumerate(targets):
1330 for index, engineid in enumerate(targets):
1324 msg_ids.extend(self.pull(key, targets=engineid,block=False).msg_ids)
1331 msg_ids.extend(self.pull(key, targets=engineid,block=False).msg_ids)
1325
1332
1326 r = AsyncMapResult(self, msg_ids, mapObject, fname='gather')
1333 r = AsyncMapResult(self, msg_ids, mapObject, fname='gather')
1327 if block:
1334 if block:
1328 return r.get()
1335 return r.get()
1329 else:
1336 else:
1330 return r
1337 return r
1331
1338
1332 #--------------------------------------------------------------------------
1339 #--------------------------------------------------------------------------
1333 # Query methods
1340 # Query methods
1334 #--------------------------------------------------------------------------
1341 #--------------------------------------------------------------------------
1335
1342
1336 @spinfirst
1343 @spinfirst
1337 @defaultblock
1344 @defaultblock
1338 def get_result(self, indices_or_msg_ids=None, block=None):
1345 def get_result(self, indices_or_msg_ids=None, block=None):
1339 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1346 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1340
1347
1341 If the client already has the results, no request to the Hub will be made.
1348 If the client already has the results, no request to the Hub will be made.
1342
1349
1343 This is a convenient way to construct AsyncResult objects, which are wrappers
1350 This is a convenient way to construct AsyncResult objects, which are wrappers
1344 that include metadata about execution, and allow for awaiting results that
1351 that include metadata about execution, and allow for awaiting results that
1345 were not submitted by this Client.
1352 were not submitted by this Client.
1346
1353
1347 It can also be a convenient way to retrieve the metadata associated with
1354 It can also be a convenient way to retrieve the metadata associated with
1348 blocking execution, since it always retrieves
1355 blocking execution, since it always retrieves
1349
1356
1350 Examples
1357 Examples
1351 --------
1358 --------
1352 ::
1359 ::
1353
1360
1354 In [10]: r = client.apply()
1361 In [10]: r = client.apply()
1355
1362
1356 Parameters
1363 Parameters
1357 ----------
1364 ----------
1358
1365
1359 indices_or_msg_ids : integer history index, str msg_id, or list of either
1366 indices_or_msg_ids : integer history index, str msg_id, or list of either
1360 The indices or msg_ids of indices to be retrieved
1367 The indices or msg_ids of indices to be retrieved
1361
1368
1362 block : bool
1369 block : bool
1363 Whether to wait for the result to be done
1370 Whether to wait for the result to be done
1364
1371
1365 Returns
1372 Returns
1366 -------
1373 -------
1367
1374
1368 AsyncResult
1375 AsyncResult
1369 A single AsyncResult object will always be returned.
1376 A single AsyncResult object will always be returned.
1370
1377
1371 AsyncHubResult
1378 AsyncHubResult
1372 A subclass of AsyncResult that retrieves results from the Hub
1379 A subclass of AsyncResult that retrieves results from the Hub
1373
1380
1374 """
1381 """
1375 if indices_or_msg_ids is None:
1382 if indices_or_msg_ids is None:
1376 indices_or_msg_ids = -1
1383 indices_or_msg_ids = -1
1377
1384
1378 if not isinstance(indices_or_msg_ids, (list,tuple)):
1385 if not isinstance(indices_or_msg_ids, (list,tuple)):
1379 indices_or_msg_ids = [indices_or_msg_ids]
1386 indices_or_msg_ids = [indices_or_msg_ids]
1380
1387
1381 theids = []
1388 theids = []
1382 for id in indices_or_msg_ids:
1389 for id in indices_or_msg_ids:
1383 if isinstance(id, int):
1390 if isinstance(id, int):
1384 id = self.history[id]
1391 id = self.history[id]
1385 if not isinstance(id, str):
1392 if not isinstance(id, str):
1386 raise TypeError("indices must be str or int, not %r"%id)
1393 raise TypeError("indices must be str or int, not %r"%id)
1387 theids.append(id)
1394 theids.append(id)
1388
1395
1389 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1396 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1390 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1397 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1391
1398
1392 if remote_ids:
1399 if remote_ids:
1393 ar = AsyncHubResult(self, msg_ids=theids)
1400 ar = AsyncHubResult(self, msg_ids=theids)
1394 else:
1401 else:
1395 ar = AsyncResult(self, msg_ids=theids)
1402 ar = AsyncResult(self, msg_ids=theids)
1396
1403
1397 if block:
1404 if block:
1398 ar.wait()
1405 ar.wait()
1399
1406
1400 return ar
1407 return ar
1401
1408
1402 @spinfirst
1409 @spinfirst
1403 def result_status(self, msg_ids, status_only=True):
1410 def result_status(self, msg_ids, status_only=True):
1404 """Check on the status of the result(s) of the apply request with `msg_ids`.
1411 """Check on the status of the result(s) of the apply request with `msg_ids`.
1405
1412
1406 If status_only is False, then the actual results will be retrieved, else
1413 If status_only is False, then the actual results will be retrieved, else
1407 only the status of the results will be checked.
1414 only the status of the results will be checked.
1408
1415
1409 Parameters
1416 Parameters
1410 ----------
1417 ----------
1411
1418
1412 msg_ids : list of msg_ids
1419 msg_ids : list of msg_ids
1413 if int:
1420 if int:
1414 Passed as index to self.history for convenience.
1421 Passed as index to self.history for convenience.
1415 status_only : bool (default: True)
1422 status_only : bool (default: True)
1416 if False:
1423 if False:
1417 Retrieve the actual results of completed tasks.
1424 Retrieve the actual results of completed tasks.
1418
1425
1419 Returns
1426 Returns
1420 -------
1427 -------
1421
1428
1422 results : dict
1429 results : dict
1423 There will always be the keys 'pending' and 'completed', which will
1430 There will always be the keys 'pending' and 'completed', which will
1424 be lists of msg_ids that are incomplete or complete. If `status_only`
1431 be lists of msg_ids that are incomplete or complete. If `status_only`
1425 is False, then completed results will be keyed by their `msg_id`.
1432 is False, then completed results will be keyed by their `msg_id`.
1426 """
1433 """
1427 if not isinstance(msg_ids, (list,tuple)):
1434 if not isinstance(msg_ids, (list,tuple)):
1428 msg_ids = [msg_ids]
1435 msg_ids = [msg_ids]
1429
1436
1430 theids = []
1437 theids = []
1431 for msg_id in msg_ids:
1438 for msg_id in msg_ids:
1432 if isinstance(msg_id, int):
1439 if isinstance(msg_id, int):
1433 msg_id = self.history[msg_id]
1440 msg_id = self.history[msg_id]
1434 if not isinstance(msg_id, basestring):
1441 if not isinstance(msg_id, basestring):
1435 raise TypeError("msg_ids must be str, not %r"%msg_id)
1442 raise TypeError("msg_ids must be str, not %r"%msg_id)
1436 theids.append(msg_id)
1443 theids.append(msg_id)
1437
1444
1438 completed = []
1445 completed = []
1439 local_results = {}
1446 local_results = {}
1440
1447
1441 # comment this block out to temporarily disable local shortcut:
1448 # comment this block out to temporarily disable local shortcut:
1442 for msg_id in theids:
1449 for msg_id in theids:
1443 if msg_id in self.results:
1450 if msg_id in self.results:
1444 completed.append(msg_id)
1451 completed.append(msg_id)
1445 local_results[msg_id] = self.results[msg_id]
1452 local_results[msg_id] = self.results[msg_id]
1446 theids.remove(msg_id)
1453 theids.remove(msg_id)
1447
1454
1448 if theids: # some not locally cached
1455 if theids: # some not locally cached
1449 content = dict(msg_ids=theids, status_only=status_only)
1456 content = dict(msg_ids=theids, status_only=status_only)
1450 msg = self.session.send(self._query_socket, "result_request", content=content)
1457 msg = self.session.send(self._query_socket, "result_request", content=content)
1451 zmq.select([self._query_socket], [], [])
1458 zmq.select([self._query_socket], [], [])
1452 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1459 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1453 if self.debug:
1460 if self.debug:
1454 pprint(msg)
1461 pprint(msg)
1455 content = msg['content']
1462 content = msg['content']
1456 if content['status'] != 'ok':
1463 if content['status'] != 'ok':
1457 raise self._unwrap_exception(content)
1464 raise self._unwrap_exception(content)
1458 buffers = msg['buffers']
1465 buffers = msg['buffers']
1459 else:
1466 else:
1460 content = dict(completed=[],pending=[])
1467 content = dict(completed=[],pending=[])
1461
1468
1462 content['completed'].extend(completed)
1469 content['completed'].extend(completed)
1463
1470
1464 if status_only:
1471 if status_only:
1465 return content
1472 return content
1466
1473
1467 failures = []
1474 failures = []
1468 # load cached results into result:
1475 # load cached results into result:
1469 content.update(local_results)
1476 content.update(local_results)
1470 # update cache with results:
1477 # update cache with results:
1471 for msg_id in sorted(theids):
1478 for msg_id in sorted(theids):
1472 if msg_id in content['completed']:
1479 if msg_id in content['completed']:
1473 rec = content[msg_id]
1480 rec = content[msg_id]
1474 parent = rec['header']
1481 parent = rec['header']
1475 header = rec['result_header']
1482 header = rec['result_header']
1476 rcontent = rec['result_content']
1483 rcontent = rec['result_content']
1477 iodict = rec['io']
1484 iodict = rec['io']
1478 if isinstance(rcontent, str):
1485 if isinstance(rcontent, str):
1479 rcontent = self.session.unpack(rcontent)
1486 rcontent = self.session.unpack(rcontent)
1480
1487
1481 md = self.metadata[msg_id]
1488 md = self.metadata[msg_id]
1482 md.update(self._extract_metadata(header, parent, rcontent))
1489 md.update(self._extract_metadata(header, parent, rcontent))
1483 md.update(iodict)
1490 md.update(iodict)
1484
1491
1485 if rcontent['status'] == 'ok':
1492 if rcontent['status'] == 'ok':
1486 res,buffers = util.unserialize_object(buffers)
1493 res,buffers = util.unserialize_object(buffers)
1487 else:
1494 else:
1488 print rcontent
1495 print rcontent
1489 res = self._unwrap_exception(rcontent)
1496 res = self._unwrap_exception(rcontent)
1490 failures.append(res)
1497 failures.append(res)
1491
1498
1492 self.results[msg_id] = res
1499 self.results[msg_id] = res
1493 content[msg_id] = res
1500 content[msg_id] = res
1494
1501
1495 if len(theids) == 1 and failures:
1502 if len(theids) == 1 and failures:
1496 raise failures[0]
1503 raise failures[0]
1497
1504
1498 error.collect_exceptions(failures, "result_status")
1505 error.collect_exceptions(failures, "result_status")
1499 return content
1506 return content
1500
1507
1501 @spinfirst
1508 @spinfirst
1502 def queue_status(self, targets='all', verbose=False):
1509 def queue_status(self, targets='all', verbose=False):
1503 """Fetch the status of engine queues.
1510 """Fetch the status of engine queues.
1504
1511
1505 Parameters
1512 Parameters
1506 ----------
1513 ----------
1507
1514
1508 targets : int/str/list of ints/strs
1515 targets : int/str/list of ints/strs
1509 the engines whose states are to be queried.
1516 the engines whose states are to be queried.
1510 default : all
1517 default : all
1511 verbose : bool
1518 verbose : bool
1512 Whether to return lengths only, or lists of ids for each element
1519 Whether to return lengths only, or lists of ids for each element
1513 """
1520 """
1514 targets = self._build_targets(targets)[1]
1521 targets = self._build_targets(targets)[1]
1515 content = dict(targets=targets, verbose=verbose)
1522 content = dict(targets=targets, verbose=verbose)
1516 self.session.send(self._query_socket, "queue_request", content=content)
1523 self.session.send(self._query_socket, "queue_request", content=content)
1517 idents,msg = self.session.recv(self._query_socket, 0)
1524 idents,msg = self.session.recv(self._query_socket, 0)
1518 if self.debug:
1525 if self.debug:
1519 pprint(msg)
1526 pprint(msg)
1520 content = msg['content']
1527 content = msg['content']
1521 status = content.pop('status')
1528 status = content.pop('status')
1522 if status != 'ok':
1529 if status != 'ok':
1523 raise self._unwrap_exception(content)
1530 raise self._unwrap_exception(content)
1524 return util.rekey(content)
1531 return util.rekey(content)
1525
1532
1526 @spinfirst
1533 @spinfirst
1527 def purge_results(self, jobs=[], targets=[]):
1534 def purge_results(self, jobs=[], targets=[]):
1528 """Tell the controller to forget results.
1535 """Tell the controller to forget results.
1529
1536
1530 Individual results can be purged by msg_id, or the entire
1537 Individual results can be purged by msg_id, or the entire
1531 history of specific targets can be purged.
1538 history of specific targets can be purged.
1532
1539
1533 Parameters
1540 Parameters
1534 ----------
1541 ----------
1535
1542
1536 jobs : str or list of strs or AsyncResult objects
1543 jobs : str or list of strs or AsyncResult objects
1537 the msg_ids whose results should be forgotten.
1544 the msg_ids whose results should be forgotten.
1538 targets : int/str/list of ints/strs
1545 targets : int/str/list of ints/strs
1539 The targets, by uuid or int_id, whose entire history is to be purged.
1546 The targets, by uuid or int_id, whose entire history is to be purged.
1540 Use `targets='all'` to scrub everything from the controller's memory.
1547 Use `targets='all'` to scrub everything from the controller's memory.
1541
1548
1542 default : None
1549 default : None
1543 """
1550 """
1544 if not targets and not jobs:
1551 if not targets and not jobs:
1545 raise ValueError("Must specify at least one of `targets` and `jobs`")
1552 raise ValueError("Must specify at least one of `targets` and `jobs`")
1546 if targets:
1553 if targets:
1547 targets = self._build_targets(targets)[1]
1554 targets = self._build_targets(targets)[1]
1548
1555
1549 # construct msg_ids from jobs
1556 # construct msg_ids from jobs
1550 msg_ids = []
1557 msg_ids = []
1551 if isinstance(jobs, (basestring,AsyncResult)):
1558 if isinstance(jobs, (basestring,AsyncResult)):
1552 jobs = [jobs]
1559 jobs = [jobs]
1553 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1560 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1554 if bad_ids:
1561 if bad_ids:
1555 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1562 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1556 for j in jobs:
1563 for j in jobs:
1557 if isinstance(j, AsyncResult):
1564 if isinstance(j, AsyncResult):
1558 msg_ids.extend(j.msg_ids)
1565 msg_ids.extend(j.msg_ids)
1559 else:
1566 else:
1560 msg_ids.append(j)
1567 msg_ids.append(j)
1561
1568
1562 content = dict(targets=targets, msg_ids=msg_ids)
1569 content = dict(targets=targets, msg_ids=msg_ids)
1563 self.session.send(self._query_socket, "purge_request", content=content)
1570 self.session.send(self._query_socket, "purge_request", content=content)
1564 idents, msg = self.session.recv(self._query_socket, 0)
1571 idents, msg = self.session.recv(self._query_socket, 0)
1565 if self.debug:
1572 if self.debug:
1566 pprint(msg)
1573 pprint(msg)
1567 content = msg['content']
1574 content = msg['content']
1568 if content['status'] != 'ok':
1575 if content['status'] != 'ok':
1569 raise self._unwrap_exception(content)
1576 raise self._unwrap_exception(content)
1570
1577
1571
1578
# Public API of this module (same names as before, regrouped by category).
__all__ = [
    'Client',
    'depend', 'require',
    'remote', 'parallel',
    'RemoteFunction', 'ParallelFunction',
    'DirectView', 'LoadBalancedView',
    'AsyncResult', 'AsyncMapResult',
    'Reference',
]
@@ -1,105 +1,106 b''
1 import sys
1 import sys
2 import tempfile
2 import tempfile
3 import time
3 import time
4 from signal import SIGINT
4 from signal import SIGINT
5 from multiprocessing import Process
5 from multiprocessing import Process
6
6
7 from nose import SkipTest
7 from nose import SkipTest
8
8
9 from zmq.tests import BaseZMQTestCase
9 from zmq.tests import BaseZMQTestCase
10
10
11 from IPython.external.decorator import decorator
11 from IPython.external.decorator import decorator
12
12
13 from IPython.zmq.parallel import error
13 from IPython.zmq.parallel import error
14 from IPython.zmq.parallel.client import Client
14 from IPython.zmq.parallel.client import Client
15 from IPython.zmq.parallel.ipcluster import launch_process
15 from IPython.zmq.parallel.ipcluster import launch_process
16 from IPython.zmq.parallel.entry_point import select_random_ports
16 from IPython.zmq.parallel.entry_point import select_random_ports
17 from IPython.zmq.parallel.tests import processes,add_engine
17 from IPython.zmq.parallel.tests import processes,add_engine
18
18
19 # simple tasks for use in apply tests
19 # simple tasks for use in apply tests
20
20
def segfault():
    """this will segfault"""
    # deliberately crash the engine process by writing through an
    # invalid pointer -- used to test graceful handling of engine death
    from ctypes import memset
    memset(-1, 0, 1)
25
25
def wait(n):
    """sleep for a time"""
    # block for `n` seconds, then echo the argument back so callers
    # can assert on the round-tripped value
    from time import sleep
    sleep(n)
    return n
31
31
def raiser(eclass):
    """raise an exception"""
    # instantiate the given exception class and raise it immediately
    instance = eclass()
    raise instance
35
35
# test decorator for skipping tests when libraries are unavailable
def skip_without(*names):
    """skip a test if some names are not importable"""
    @decorator
    def skip_without_names(f, *args, **kwargs):
        """Skip the wrapped test if any of `names` fails to import.

        (Docstring fixed: it previously claimed this was specific to numpy,
        but the skip applies to whichever module names were given.)
        """
        for name in names:
            try:
                __import__(name)
            except ImportError:
                # any missing dependency aborts the test as a skip
                raise SkipTest
        return f(*args, **kwargs)
    return skip_without_names
49
49
50
50
class ClusterTestCase(BaseZMQTestCase):
    """Base class for tests that run against a live parallel cluster.

    Extends pyzmq's BaseZMQTestCase with helpers for spawning engines,
    connecting a Client, and asserting on exceptions raised remotely.
    """

    def add_engines(self, n=1, block=True):
        """add multiple engines to our cluster"""
        # add_engine() presumably launches one engine process and returns
        # its handle -- see IPython.zmq.parallel.tests.add_engine
        for i in range(n):
            self.engines.append(add_engine())
        if block:
            self.wait_on_engines()

    def wait_on_engines(self, timeout=5):
        """wait for our engines to connect."""
        # total expected = engines spawned by this test + pre-existing ones
        n = len(self.engines)+self.base_engine_count
        tic = time.time()
        # poll the client's registration list until complete or timed out
        while time.time()-tic < timeout and len(self.client.ids) < n:
            time.sleep(0.1)

        assert not len(self.client.ids) < n, "waiting for engines timed out"

    def connect_client(self):
        """connect a client with my Context, and track its sockets for cleanup"""
        c = Client(profile='iptest',context=self.context)

        # socket tracking disabled (pyzmq-2.1.3 adjustments): the Client now
        # manages its own sockets, so the test case must not close them.
        # for name in filter(lambda n:n.endswith('socket'), dir(c)):
        #     self.sockets.append(getattr(c, name))
        return c

    def assertRaisesRemote(self, etype, f, *args, **kwargs):
        # Assert that f(*args, **kwargs) raises a RemoteError whose remote
        # exception name matches etype; CompositeErrors are unpacked first.
        try:
            try:
                f(*args, **kwargs)
            except error.CompositeError as e:
                e.raise_exception()
        except error.RemoteError as e:
            # fix: the failure message had its format arguments swapped --
            # the expected type goes first, the actual remote name second
            self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
        else:
            self.fail("should have raised a RemoteError")

    def setUp(self):
        BaseZMQTestCase.setUp(self)
        self.client = self.connect_client()
        # engines already registered before this test started
        self.base_engine_count=len(self.client.ids)
        # engines spawned by this test via add_engines()
        self.engines=[]

    def tearDown(self):

        # close fds:
        # drop engine processes that have already exited from the shared
        # `processes` list; poll() is not None means the process finished.
        # filter() builds a new list first, so removing while looping is safe.
        for e in filter(lambda e: e.poll() is not None, processes):
            processes.remove(e)

        self.client.close()
        BaseZMQTestCase.tearDown(self)
        # this will be superfluous when pyzmq merges PR #88
        self.context.term()
        # print tempfile.TemporaryFile().fileno(),
        # sys.stdout.flush()
@@ -1,262 +1,262 b''
1 import time
1 import time
2 from tempfile import mktemp
2 from tempfile import mktemp
3
3
4 import nose.tools as nt
5 import zmq
4 import zmq
6
5
7 from IPython.zmq.parallel import client as clientmod
6 from IPython.zmq.parallel import client as clientmod
8 from IPython.zmq.parallel import error
7 from IPython.zmq.parallel import error
9 from IPython.zmq.parallel.asyncresult import AsyncResult, AsyncHubResult
8 from IPython.zmq.parallel.asyncresult import AsyncResult, AsyncHubResult
10 from IPython.zmq.parallel.view import LoadBalancedView, DirectView
9 from IPython.zmq.parallel.view import LoadBalancedView, DirectView
11
10
12 from clienttest import ClusterTestCase, segfault, wait
11 from clienttest import ClusterTestCase, segfault, wait
13
12
class TestClient(ClusterTestCase):
    """Integration tests for the parallel Client against a live cluster.

    Relies on ClusterTestCase for engine spawning (add_engines) and remote
    exception assertions (assertRaisesRemote).
    """

    def test_ids(self):
        # adding engines should grow client.ids by exactly that many
        n = len(self.client.ids)
        self.add_engines(3)
        self.assertEquals(len(self.client.ids), n+3)

    def test_segfault_task(self):
        """test graceful handling of engine death (balanced)"""
        self.add_engines(1)
        ar = self.client.apply(segfault, block=False)
        self.assertRaisesRemote(error.EngineError, ar.get)
        eid = ar.engine_id
        # spin until the dead engine is unregistered from the client
        while eid in self.client.ids:
            time.sleep(.01)
            self.client.spin()

    def test_segfault_mux(self):
        """test graceful handling of engine death (direct)"""
        self.add_engines(1)
        eid = self.client.ids[-1]
        ar = self.client[eid].apply_async(segfault)
        self.assertRaisesRemote(error.EngineError, ar.get)
        eid = ar.engine_id
        # spin until the dead engine is unregistered from the client
        while eid in self.client.ids:
            time.sleep(.01)
            self.client.spin()

    def test_view_indexing(self):
        """test index access for views"""
        self.add_engines(2)
        targets = self.client._build_targets('all')[-1]
        # int, slice, and negative indexing should all yield DirectViews
        # whose targets mirror the same indexing of the full target list
        v = self.client[:]
        self.assertEquals(v.targets, targets)
        t = self.client.ids[2]
        v = self.client[t]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, t)
        t = self.client.ids[2:4]
        v = self.client[t]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, t)
        v = self.client[::2]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[::2])
        v = self.client[1::3]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[1::3])
        v = self.client[:-3]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[:-3])
        v = self.client[-1]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[-1])
        self.assertRaises(TypeError, lambda : self.client[None])

    def test_view_cache(self):
        """test that multiple view requests return the same object"""
        v = self.client[:2]
        v2 =self.client[:2]
        self.assertTrue(v is v2)
        v = self.client.view()
        v2 = self.client.view(balanced=True)
        self.assertTrue(v is v2)

    def test_targets(self):
        """test various valid targets arguments"""
        build = self.client._build_targets
        ids = self.client.ids
        # None should resolve to all registered engine ids
        idents,targets = build(None)
        self.assertEquals(ids, targets)

    def test_clear(self):
        """test clear behavior"""
        self.add_engines(2)
        self.client.block=True
        self.client.push(dict(a=5))
        self.client.pull('a')
        id0 = self.client.ids[-1]
        # clearing one engine should leave the others' namespaces intact
        self.client.clear(targets=id0)
        self.client.pull('a', targets=self.client.ids[:-1])
        self.assertRaisesRemote(NameError, self.client.pull, 'a')
        self.client.clear()
        for i in self.client.ids:
            self.assertRaisesRemote(NameError, self.client.pull, 'a', targets=i)


    def test_push_pull(self):
        """test pushing and pulling"""
        data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
        t = self.client.ids[-1]
        self.add_engines(2)
        push = self.client.push
        pull = self.client.pull
        self.client.block=True
        nengines = len(self.client)
        push({'data':data}, targets=t)
        d = pull('data', targets=t)
        self.assertEquals(d, data)
        # untargeted pull returns one result per engine
        push({'data':data})
        d = pull('data')
        self.assertEquals(d, nengines*[data])
        ar = push({'data':data}, block=False)
        self.assertTrue(isinstance(ar, AsyncResult))
        r = ar.get()
        ar = pull('data', block=False)
        self.assertTrue(isinstance(ar, AsyncResult))
        r = ar.get()
        self.assertEquals(r, nengines*[data])
        push(dict(a=10,b=20))
        r = pull(('a','b'))
        self.assertEquals(r, nengines*[[10,20]])

    def test_push_pull_function(self):
        "test pushing and pulling functions"
        def testf(x):
            return 2.0*x

        self.add_engines(4)
        t = self.client.ids[-1]
        self.client.block=True
        push = self.client.push
        pull = self.client.pull
        execute = self.client.execute
        push({'testf':testf}, targets=t)
        r = pull('testf', targets=t)
        # pulled function should behave like the original
        self.assertEqual(r(1.0), testf(1.0))
        execute('r = testf(10)', targets=t)
        r = pull('r', targets=t)
        self.assertEquals(r, testf(10))
        ar = push({'testf':testf}, block=False)
        ar.get()
        ar = pull('testf', block=False)
        rlist = ar.get()
        for r in rlist:
            self.assertEqual(r(1.0), testf(1.0))
        execute("def g(x): return x*x", targets=t)
        r = pull(('testf','g'),targets=t)
        self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))

    def test_push_function_globals(self):
        """test that pushed functions have access to globals"""
        def geta():
            return a
        self.add_engines(1)
        v = self.client[-1]
        v.block=True
        v['f'] = geta
        # `a` undefined on the engine yet: remote NameError expected
        self.assertRaisesRemote(NameError, v.execute, 'b=f()')
        v.execute('a=5')
        v.execute('b=f()')
        self.assertEquals(v['b'], 5)

    def test_push_function_defaults(self):
        """test that pushed functions preserve default args"""
        def echo(a=10):
            return a
        self.add_engines(1)
        v = self.client[-1]
        v.block=True
        v['f'] = echo
        v.execute('b=f()')
        self.assertEquals(v['b'], 10)

    def test_get_result(self):
        """test getting results from the Hub."""
        # submit via a *second* client so the result is not cached locally
        c = clientmod.Client(profile='iptest')
        self.add_engines(1)
        t = c.ids[-1]
        ar = c.apply(wait, (1,), block=False, targets=t)
        # give the monitor time to notice the message
        time.sleep(.25)
        ahr = self.client.get_result(ar.msg_ids)
        self.assertTrue(isinstance(ahr, AsyncHubResult))
        self.assertEquals(ahr.get(), ar.get())
        # second fetch hits the local cache: plain AsyncResult, not Hub
        ar2 = self.client.get_result(ar.msg_ids)
        self.assertFalse(isinstance(ar2, AsyncHubResult))

    def test_ids_list(self):
        """test client.ids"""
        self.add_engines(2)
        ids = self.client.ids
        # client.ids is a copy: equal to but not identical with _ids
        self.assertEquals(ids, self.client._ids)
        self.assertFalse(ids is self.client._ids)
        ids.remove(ids[-1])
        self.assertNotEquals(ids, self.client._ids)

    def test_run_newline(self):
        """test that run appends newline to files"""
        tmpfile = mktemp()
        with open(tmpfile, 'w') as f:
            f.write("""def g():
    return 5
""")
        v = self.client[-1]
        v.run(tmpfile, block=True)
        self.assertEquals(v.apply_sync(lambda : g()), 5)

    def test_apply_tracked(self):
        """test tracking for apply"""
        # self.add_engines(1)
        t = self.client.ids[-1]
        self.client.block=False
        def echo(n=1024*1024, **kwargs):
            return self.client.apply(lambda x: x, args=('x'*n,), targets=t, **kwargs)
        # without track=True no MessageTracker is attached
        ar = echo(1)
        self.assertTrue(ar._tracker is None)
        self.assertTrue(ar.sent)
        ar = echo(track=True)
        self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
        self.assertEquals(ar.sent, ar._tracker.done)
        ar._tracker.wait()
        self.assertTrue(ar.sent)

    def test_push_tracked(self):
        # same tracking contract as apply, via push
        t = self.client.ids[-1]
        ns = dict(x='x'*1024*1024)
        ar = self.client.push(ns, targets=t, block=False)
        self.assertTrue(ar._tracker is None)
        self.assertTrue(ar.sent)

        ar = self.client.push(ns, targets=t, block=False, track=True)
        self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
        self.assertEquals(ar.sent, ar._tracker.done)
        ar._tracker.wait()
        self.assertTrue(ar.sent)
        ar.get()

    def test_scatter_tracked(self):
        # same tracking contract as apply, via scatter
        t = self.client.ids
        x='x'*1024*1024
        ar = self.client.scatter('x', x, targets=t, block=False)
        self.assertTrue(ar._tracker is None)
        self.assertTrue(ar.sent)

        ar = self.client.scatter('x', x, targets=t, block=False, track=True)
        self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
        self.assertEquals(ar.sent, ar._tracker.done)
        ar._tracker.wait()
        self.assertTrue(ar.sent)
        ar.get()

    def test_remote_reference(self):
        # a Reference('a') should dereference to the engine's `a`
        v = self.client[-1]
        v['a'] = 123
        ra = clientmod.Reference('a')
        b = v.apply_sync(lambda x: x, ra)
        self.assertEquals(b, 123)
@@ -1,89 +1,87 b''
1 """test serialization with newserialized"""
1 """test serialization with newserialized"""
2
2
3 from unittest import TestCase
3 from unittest import TestCase
4
4
5 import nose.tools as nt
6
7 from IPython.testing.parametric import parametric
5 from IPython.testing.parametric import parametric
8 from IPython.utils import newserialized as ns
6 from IPython.utils import newserialized as ns
9 from IPython.utils.pickleutil import can, uncan, CannedObject, CannedFunction
7 from IPython.utils.pickleutil import can, uncan, CannedObject, CannedFunction
10 from IPython.zmq.parallel.tests.clienttest import skip_without
8 from IPython.zmq.parallel.tests.clienttest import skip_without
11
9
12
10
13 class CanningTestCase(TestCase):
11 class CanningTestCase(TestCase):
14 def test_canning(self):
12 def test_canning(self):
15 d = dict(a=5,b=6)
13 d = dict(a=5,b=6)
16 cd = can(d)
14 cd = can(d)
17 nt.assert_true(isinstance(cd, dict))
15 self.assertTrue(isinstance(cd, dict))
18
16
19 def test_canned_function(self):
17 def test_canned_function(self):
20 f = lambda : 7
18 f = lambda : 7
21 cf = can(f)
19 cf = can(f)
22 nt.assert_true(isinstance(cf, CannedFunction))
20 self.assertTrue(isinstance(cf, CannedFunction))
23
21
24 @parametric
22 @parametric
25 def test_can_roundtrip(cls):
23 def test_can_roundtrip(cls):
26 objs = [
24 objs = [
27 dict(),
25 dict(),
28 set(),
26 set(),
29 list(),
27 list(),
30 ['a',1,['a',1],u'e'],
28 ['a',1,['a',1],u'e'],
31 ]
29 ]
32 return map(cls.run_roundtrip, objs)
30 return map(cls.run_roundtrip, objs)
33
31
34 @classmethod
32 @classmethod
35 def run_roundtrip(cls, obj):
33 def run_roundtrip(self, obj):
36 o = uncan(can(obj))
34 o = uncan(can(obj))
37 nt.assert_equals(obj, o)
35 assert o == obj, "failed assertion: %r == %r"%(o,obj)
38
36
39 def test_serialized_interfaces(self):
37 def test_serialized_interfaces(self):
40
38
41 us = {'a':10, 'b':range(10)}
39 us = {'a':10, 'b':range(10)}
42 s = ns.serialize(us)
40 s = ns.serialize(us)
43 uus = ns.unserialize(s)
41 uus = ns.unserialize(s)
44 nt.assert_true(isinstance(s, ns.SerializeIt))
42 self.assertTrue(isinstance(s, ns.SerializeIt))
45 nt.assert_equals(uus, us)
43 self.assertEquals(uus, us)
46
44
47 def test_pickle_serialized(self):
45 def test_pickle_serialized(self):
48 obj = {'a':1.45345, 'b':'asdfsdf', 'c':10000L}
46 obj = {'a':1.45345, 'b':'asdfsdf', 'c':10000L}
49 original = ns.UnSerialized(obj)
47 original = ns.UnSerialized(obj)
50 originalSer = ns.SerializeIt(original)
48 originalSer = ns.SerializeIt(original)
51 firstData = originalSer.getData()
49 firstData = originalSer.getData()
52 firstTD = originalSer.getTypeDescriptor()
50 firstTD = originalSer.getTypeDescriptor()
53 firstMD = originalSer.getMetadata()
51 firstMD = originalSer.getMetadata()
54 nt.assert_equals(firstTD, 'pickle')
52 self.assertEquals(firstTD, 'pickle')
55 nt.assert_equals(firstMD, {})
53 self.assertEquals(firstMD, {})
56 unSerialized = ns.UnSerializeIt(originalSer)
54 unSerialized = ns.UnSerializeIt(originalSer)
57 secondObj = unSerialized.getObject()
55 secondObj = unSerialized.getObject()
58 for k, v in secondObj.iteritems():
56 for k, v in secondObj.iteritems():
59 nt.assert_equals(obj[k], v)
57 self.assertEquals(obj[k], v)
60 secondSer = ns.SerializeIt(ns.UnSerialized(secondObj))
58 secondSer = ns.SerializeIt(ns.UnSerialized(secondObj))
61 nt.assert_equals(firstData, secondSer.getData())
59 self.assertEquals(firstData, secondSer.getData())
62 nt.assert_equals(firstTD, secondSer.getTypeDescriptor() )
60 self.assertEquals(firstTD, secondSer.getTypeDescriptor() )
63 nt.assert_equals(firstMD, secondSer.getMetadata())
61 self.assertEquals(firstMD, secondSer.getMetadata())
64
62
65 @skip_without('numpy')
63 @skip_without('numpy')
66 def test_ndarray_serialized(self):
64 def test_ndarray_serialized(self):
67 import numpy
65 import numpy
68 a = numpy.linspace(0.0, 1.0, 1000)
66 a = numpy.linspace(0.0, 1.0, 1000)
69 unSer1 = ns.UnSerialized(a)
67 unSer1 = ns.UnSerialized(a)
70 ser1 = ns.SerializeIt(unSer1)
68 ser1 = ns.SerializeIt(unSer1)
71 td = ser1.getTypeDescriptor()
69 td = ser1.getTypeDescriptor()
72 nt.assert_equals(td, 'ndarray')
70 self.assertEquals(td, 'ndarray')
73 md = ser1.getMetadata()
71 md = ser1.getMetadata()
74 nt.assert_equals(md['shape'], a.shape)
72 self.assertEquals(md['shape'], a.shape)
75 nt.assert_equals(md['dtype'], a.dtype.str)
73 self.assertEquals(md['dtype'], a.dtype.str)
76 buff = ser1.getData()
74 buff = ser1.getData()
77 nt.assert_equals(buff, numpy.getbuffer(a))
75 self.assertEquals(buff, numpy.getbuffer(a))
78 s = ns.Serialized(buff, td, md)
76 s = ns.Serialized(buff, td, md)
79 final = ns.unserialize(s)
77 final = ns.unserialize(s)
80 nt.assert_equals(numpy.getbuffer(a), numpy.getbuffer(final))
78 self.assertEquals(numpy.getbuffer(a), numpy.getbuffer(final))
81 nt.assert_true((a==final).all())
79 self.assertTrue((a==final).all())
82 nt.assert_equals(a.dtype.str, final.dtype.str)
80 self.assertEquals(a.dtype.str, final.dtype.str)
83 nt.assert_equals(a.shape, final.shape)
81 self.assertEquals(a.shape, final.shape)
84 # test non-copying:
82 # test non-copying:
85 a[2] = 1e9
83 a[2] = 1e9
86 nt.assert_true((a==final).all())
84 self.assertTrue((a==final).all())
87
85
88
86
89 No newline at end of file
87
General Comments 0
You need to be logged in to leave comments. Login now