# NOTE(review): scraper/page residue removed; provenance kept as a comment.
# Commit: "update client.get_result to match AsyncResult behavior" -- MinRK
"""A semi-synchronous Client for the ZMQ cluster

Authors:

* MinRK
"""
#-----------------------------------------------------------------------------
#  Copyright (C) 2010-2011  The IPython Development Team
#
#  Distributed under the terms of the BSD License.  The full license is in
#  the file COPYING, distributed as part of this software.
#-----------------------------------------------------------------------------

#-----------------------------------------------------------------------------
# Imports
#-----------------------------------------------------------------------------

import os
import json
import sys
from threading import Thread, Event
import time
import warnings
from datetime import datetime
from getpass import getpass
from pprint import pprint

pjoin = os.path.join

import zmq
# from zmq.eventloop import ioloop, zmqstream

from IPython.config.configurable import MultipleInstanceError
from IPython.core.application import BaseIPythonApplication
from IPython.core.profiledir import ProfileDir, ProfileDirError

from IPython.utils.coloransi import TermColors
from IPython.utils.jsonutil import rekey
from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
from IPython.utils.path import get_ipython_dir
from IPython.utils.py3compat import cast_bytes
from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
                                    Dict, List, Bool, Set, Any)
from IPython.external.decorator import decorator
from IPython.external.ssh import tunnel

from IPython.parallel import Reference
from IPython.parallel import error
from IPython.parallel import util

from IPython.kernel.zmq.session import Session, Message
from IPython.kernel.zmq import serialize

from .asyncresult import AsyncResult, AsyncHubResult
from .view import DirectView, LoadBalancedView

if sys.version_info[0] >= 3:
    # xrange is used in a couple 'isinstance' tests in py2
    # should be just 'range' in 3k
    xrange = range
61
61
#--------------------------------------------------------------------------
# Decorators for Client methods
#--------------------------------------------------------------------------
65
65
@decorator
def spin_first(f, self, *args, **kwargs):
    """Call spin() to sync state prior to calling the method.

    Used as a decorator on Client methods that must observe the latest
    results/registration state before running.
    """
    self.spin()
    return f(self, *args, **kwargs)
71
71
72
72
#--------------------------------------------------------------------------
# Classes
#--------------------------------------------------------------------------
76
76
77
77
class ExecuteReply(object):
    """wrapper for finished Execute results

    Provides dict-style and attribute-style access to the reply metadata,
    and the various ``_repr_*_`` display hooks for the pyout mime bundle.
    """
    def __init__(self, msg_id, content, metadata):
        self.msg_id = msg_id
        self._content = content
        self.execution_count = content['execution_count']
        self.metadata = metadata

    def _display_data(self):
        """Return the pyout mime-bundle data dict (empty dict if no pyout)."""
        pyout = self.metadata['pyout'] or {'data':{}}
        return pyout['data']

    def __getitem__(self, key):
        return self.metadata[key]

    def __getattr__(self, key):
        # attribute access aliases metadata keys
        if key not in self.metadata:
            raise AttributeError(key)
        return self.metadata[key]

    def __repr__(self):
        text_out = self._display_data().get('text/plain', '')
        if len(text_out) > 32:
            # truncate long reprs to keep the summary short
            text_out = text_out[:29] + '...'

        return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)

    def _repr_pretty_(self, p, cycle):
        """IPython pretty-print hook: mimic an ``Out[engine:count]:`` prompt."""
        text_out = self._display_data().get('text/plain', '')

        if not text_out:
            return

        try:
            ip = get_ipython()
        except NameError:
            # not running inside IPython; no color information available
            colors = "NoColor"
        else:
            colors = ip.colors

        if colors == "NoColor":
            out = normal = ""
        else:
            out = TermColors.Red
            normal = TermColors.Normal

        if '\n' in text_out and not text_out.startswith('\n'):
            # add newline for multiline reprs
            text_out = '\n' + text_out

        p.text(
            out + u'Out[%i:%i]: ' % (
                self.metadata['engine_id'], self.execution_count
            ) + normal + text_out
        )

    def _repr_html_(self):
        return self._display_data().get("text/html")

    def _repr_latex_(self):
        return self._display_data().get("text/latex")

    def _repr_json_(self):
        return self._display_data().get("application/json")

    def _repr_javascript_(self):
        return self._display_data().get("application/javascript")

    def _repr_png_(self):
        return self._display_data().get("image/png")

    def _repr_jpeg_(self):
        return self._display_data().get("image/jpeg")

    def _repr_svg_(self):
        return self._display_data().get("image/svg+xml")
159
159
160
160
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # the full, closed set of allowed keys with their defaults
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
              'outputs' : [],
              'data': {},
              'outputs_ready' : False,
            }
        self.update(md)
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # `key in self` instead of py2-only iterkeys(); on Python 3 the
        # missing iterkeys attribute would recurse back into __getattr__.
        if key in self:
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        if key in self:
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self:
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
214
214
215
215
class Client(HasTraits):
    """A semi-synchronous client to the IPython ZMQ cluster

    Parameters
    ----------

    url_file : str/unicode; path to ipcontroller-client.json
        This JSON file should contain all the information needed to connect to a cluster,
        and is likely the only argument needed.
        Connection information for the Hub's registration.  If a json connector
        file is given, then likely no further configuration is necessary.
        [Default: use profile]
    profile : bytes
        The name of the Cluster profile to be used to find connector information.
        If run from an IPython application, the default profile will be the same
        as the running application, otherwise it will be 'default'.
    cluster_id : str
        String id to added to runtime files, to prevent name collisions when using
        multiple clusters with a single profile simultaneously.
        When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
        Since this is text inserted into filenames, typical recommendations apply:
        Simple character strings are ideal, and spaces are not recommended (but
        should generally work)
    context : zmq.Context
        Pass an existing zmq.Context instance, otherwise the client will create its own.
    debug : bool
        flag for lots of message printing for debug purposes
    timeout : int/float
        time (in seconds) to wait for connection replies from the Hub
        [Default: 10]

    #-------------- session related args ----------------

    config : Config object
        If specified, this will be relayed to the Session for configuration
    username : str
        set username for the session object

    #-------------- ssh related args ----------------
    # These are args for configuring the ssh tunnel to be used
    # credentials are used to forward connections over ssh to the Controller
    # Note that the ip given in `addr` needs to be relative to sshserver
    # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
    # and set sshserver as the same machine the Controller is on. However,
    # the only requirement is that sshserver is able to see the Controller
    # (i.e. is within the same trusted network).

    sshserver : str
        A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
        If keyfile or password is specified, and this is not, it will default to
        the ip given in addr.
    sshkey : str; path to ssh private key file
        This specifies a key to be used in ssh login, default None.
        Regular default ssh keys will be used without specifying this argument.
    password : str
        Your ssh password to sshserver. Note that if this is left None,
        you will be prompted for it if passwordless key based login is unavailable.
    paramiko : bool
        flag for whether to use paramiko instead of shell ssh for tunneling.
        [default: True on win32, False else]


    Attributes
    ----------

    ids : list of int engine IDs
        requesting the ids attribute always synchronizes
        the registration state. To request ids without synchronization,
        use semi-private _ids attributes.

    history : list of msg_ids
        a list of msg_ids, keeping track of all the execution
        messages you have submitted in order.

    outstanding : set of msg_ids
        a set of msg_ids that have been submitted, but whose
        results have not yet been received.

    results : dict
        a dict of all our results, keyed by msg_id

    block : bool
        determines default behavior when block not specified
        in execution methods

    Methods
    -------

    spin
        flushes incoming results and registration state changes
        control methods spin, and requesting `ids` also ensures up to date

    wait
        wait on one or more msg_ids

    execution methods
        apply
        legacy: execute, run

    data movement
        push, pull, scatter, gather

    query methods
        queue_status, get_result, purge, result_status

    control methods
        abort, shutdown

    """
325
325
326
326
327 block = Bool(False)
327 block = Bool(False)
328 outstanding = Set()
328 outstanding = Set()
329 results = Instance('collections.defaultdict', (dict,))
329 results = Instance('collections.defaultdict', (dict,))
330 metadata = Instance('collections.defaultdict', (Metadata,))
330 metadata = Instance('collections.defaultdict', (Metadata,))
331 history = List()
331 history = List()
332 debug = Bool(False)
332 debug = Bool(False)
333 _spin_thread = Any()
333 _spin_thread = Any()
334 _stop_spinning = Any()
334 _stop_spinning = Any()
335
335
336 profile=Unicode()
336 profile=Unicode()
337 def _profile_default(self):
337 def _profile_default(self):
338 if BaseIPythonApplication.initialized():
338 if BaseIPythonApplication.initialized():
339 # an IPython app *might* be running, try to get its profile
339 # an IPython app *might* be running, try to get its profile
340 try:
340 try:
341 return BaseIPythonApplication.instance().profile
341 return BaseIPythonApplication.instance().profile
342 except (AttributeError, MultipleInstanceError):
342 except (AttributeError, MultipleInstanceError):
343 # could be a *different* subclass of config.Application,
343 # could be a *different* subclass of config.Application,
344 # which would raise one of these two errors.
344 # which would raise one of these two errors.
345 return u'default'
345 return u'default'
346 else:
346 else:
347 return u'default'
347 return u'default'
348
348
349
349
350 _outstanding_dict = Instance('collections.defaultdict', (set,))
350 _outstanding_dict = Instance('collections.defaultdict', (set,))
351 _ids = List()
351 _ids = List()
352 _connected=Bool(False)
352 _connected=Bool(False)
353 _ssh=Bool(False)
353 _ssh=Bool(False)
354 _context = Instance('zmq.Context')
354 _context = Instance('zmq.Context')
355 _config = Dict()
355 _config = Dict()
356 _engines=Instance(util.ReverseDict, (), {})
356 _engines=Instance(util.ReverseDict, (), {})
357 # _hub_socket=Instance('zmq.Socket')
357 # _hub_socket=Instance('zmq.Socket')
358 _query_socket=Instance('zmq.Socket')
358 _query_socket=Instance('zmq.Socket')
359 _control_socket=Instance('zmq.Socket')
359 _control_socket=Instance('zmq.Socket')
360 _iopub_socket=Instance('zmq.Socket')
360 _iopub_socket=Instance('zmq.Socket')
361 _notification_socket=Instance('zmq.Socket')
361 _notification_socket=Instance('zmq.Socket')
362 _mux_socket=Instance('zmq.Socket')
362 _mux_socket=Instance('zmq.Socket')
363 _task_socket=Instance('zmq.Socket')
363 _task_socket=Instance('zmq.Socket')
364 _task_scheme=Unicode()
364 _task_scheme=Unicode()
365 _closed = False
365 _closed = False
366 _ignored_control_replies=Integer(0)
366 _ignored_control_replies=Integer(0)
367 _ignored_hub_replies=Integer(0)
367 _ignored_hub_replies=Integer(0)
368
368
369 def __new__(self, *args, **kw):
369 def __new__(self, *args, **kw):
370 # don't raise on positional args
370 # don't raise on positional args
371 return HasTraits.__new__(self, **kw)
371 return HasTraits.__new__(self, **kw)
372
372
373 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
373 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
374 context=None, debug=False,
374 context=None, debug=False,
375 sshserver=None, sshkey=None, password=None, paramiko=None,
375 sshserver=None, sshkey=None, password=None, paramiko=None,
376 timeout=10, cluster_id=None, **extra_args
376 timeout=10, cluster_id=None, **extra_args
377 ):
377 ):
378 if profile:
378 if profile:
379 super(Client, self).__init__(debug=debug, profile=profile)
379 super(Client, self).__init__(debug=debug, profile=profile)
380 else:
380 else:
381 super(Client, self).__init__(debug=debug)
381 super(Client, self).__init__(debug=debug)
382 if context is None:
382 if context is None:
383 context = zmq.Context.instance()
383 context = zmq.Context.instance()
384 self._context = context
384 self._context = context
385 self._stop_spinning = Event()
385 self._stop_spinning = Event()
386
386
387 if 'url_or_file' in extra_args:
387 if 'url_or_file' in extra_args:
388 url_file = extra_args['url_or_file']
388 url_file = extra_args['url_or_file']
389 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
389 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
390
390
391 if url_file and util.is_url(url_file):
391 if url_file and util.is_url(url_file):
392 raise ValueError("single urls cannot be specified, url-files must be used.")
392 raise ValueError("single urls cannot be specified, url-files must be used.")
393
393
394 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
394 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
395
395
396 if self._cd is not None:
396 if self._cd is not None:
397 if url_file is None:
397 if url_file is None:
398 if not cluster_id:
398 if not cluster_id:
399 client_json = 'ipcontroller-client.json'
399 client_json = 'ipcontroller-client.json'
400 else:
400 else:
401 client_json = 'ipcontroller-%s-client.json' % cluster_id
401 client_json = 'ipcontroller-%s-client.json' % cluster_id
402 url_file = pjoin(self._cd.security_dir, client_json)
402 url_file = pjoin(self._cd.security_dir, client_json)
403 if url_file is None:
403 if url_file is None:
404 raise ValueError(
404 raise ValueError(
405 "I can't find enough information to connect to a hub!"
405 "I can't find enough information to connect to a hub!"
406 " Please specify at least one of url_file or profile."
406 " Please specify at least one of url_file or profile."
407 )
407 )
408
408
409 with open(url_file) as f:
409 with open(url_file) as f:
410 cfg = json.load(f)
410 cfg = json.load(f)
411
411
412 self._task_scheme = cfg['task_scheme']
412 self._task_scheme = cfg['task_scheme']
413
413
414 # sync defaults from args, json:
414 # sync defaults from args, json:
415 if sshserver:
415 if sshserver:
416 cfg['ssh'] = sshserver
416 cfg['ssh'] = sshserver
417
417
418 location = cfg.setdefault('location', None)
418 location = cfg.setdefault('location', None)
419
419
420 proto,addr = cfg['interface'].split('://')
420 proto,addr = cfg['interface'].split('://')
421 addr = util.disambiguate_ip_address(addr, location)
421 addr = util.disambiguate_ip_address(addr, location)
422 cfg['interface'] = "%s://%s" % (proto, addr)
422 cfg['interface'] = "%s://%s" % (proto, addr)
423
423
424 # turn interface,port into full urls:
424 # turn interface,port into full urls:
425 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
425 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
426 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
426 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
427
427
428 url = cfg['registration']
428 url = cfg['registration']
429
429
430 if location is not None and addr == LOCALHOST:
430 if location is not None and addr == LOCALHOST:
431 # location specified, and connection is expected to be local
431 # location specified, and connection is expected to be local
432 if location not in LOCAL_IPS and not sshserver:
432 if location not in LOCAL_IPS and not sshserver:
433 # load ssh from JSON *only* if the controller is not on
433 # load ssh from JSON *only* if the controller is not on
434 # this machine
434 # this machine
435 sshserver=cfg['ssh']
435 sshserver=cfg['ssh']
436 if location not in LOCAL_IPS and not sshserver:
436 if location not in LOCAL_IPS and not sshserver:
437 # warn if no ssh specified, but SSH is probably needed
437 # warn if no ssh specified, but SSH is probably needed
438 # This is only a warning, because the most likely cause
438 # This is only a warning, because the most likely cause
439 # is a local Controller on a laptop whose IP is dynamic
439 # is a local Controller on a laptop whose IP is dynamic
440 warnings.warn("""
440 warnings.warn("""
441 Controller appears to be listening on localhost, but not on this machine.
441 Controller appears to be listening on localhost, but not on this machine.
442 If this is true, you should specify Client(...,sshserver='you@%s')
442 If this is true, you should specify Client(...,sshserver='you@%s')
443 or instruct your controller to listen on an external IP."""%location,
443 or instruct your controller to listen on an external IP."""%location,
444 RuntimeWarning)
444 RuntimeWarning)
445 elif not sshserver:
445 elif not sshserver:
446 # otherwise sync with cfg
446 # otherwise sync with cfg
447 sshserver = cfg['ssh']
447 sshserver = cfg['ssh']
448
448
449 self._config = cfg
449 self._config = cfg
450
450
451 self._ssh = bool(sshserver or sshkey or password)
451 self._ssh = bool(sshserver or sshkey or password)
452 if self._ssh and sshserver is None:
452 if self._ssh and sshserver is None:
453 # default to ssh via localhost
453 # default to ssh via localhost
454 sshserver = addr
454 sshserver = addr
455 if self._ssh and password is None:
455 if self._ssh and password is None:
456 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
456 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
457 password=False
457 password=False
458 else:
458 else:
459 password = getpass("SSH Password for %s: "%sshserver)
459 password = getpass("SSH Password for %s: "%sshserver)
460 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
460 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
461
461
462 # configure and construct the session
462 # configure and construct the session
463 extra_args['packer'] = cfg['pack']
463 extra_args['packer'] = cfg['pack']
464 extra_args['unpacker'] = cfg['unpack']
464 extra_args['unpacker'] = cfg['unpack']
465 extra_args['key'] = cast_bytes(cfg['exec_key'])
465 extra_args['key'] = cast_bytes(cfg['exec_key'])
466
466
467 self.session = Session(**extra_args)
467 self.session = Session(**extra_args)
468
468
469 self._query_socket = self._context.socket(zmq.DEALER)
469 self._query_socket = self._context.socket(zmq.DEALER)
470
470
471 if self._ssh:
471 if self._ssh:
472 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
472 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
473 else:
473 else:
474 self._query_socket.connect(cfg['registration'])
474 self._query_socket.connect(cfg['registration'])
475
475
476 self.session.debug = self.debug
476 self.session.debug = self.debug
477
477
478 self._notification_handlers = {'registration_notification' : self._register_engine,
478 self._notification_handlers = {'registration_notification' : self._register_engine,
479 'unregistration_notification' : self._unregister_engine,
479 'unregistration_notification' : self._unregister_engine,
480 'shutdown_notification' : lambda msg: self.close(),
480 'shutdown_notification' : lambda msg: self.close(),
481 }
481 }
482 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
482 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
483 'apply_reply' : self._handle_apply_reply}
483 'apply_reply' : self._handle_apply_reply}
484 self._connect(sshserver, ssh_kwargs, timeout)
484 self._connect(sshserver, ssh_kwargs, timeout)
485
485
486 # last step: setup magics, if we are in IPython:
486 # last step: setup magics, if we are in IPython:
487
487
488 try:
488 try:
489 ip = get_ipython()
489 ip = get_ipython()
490 except NameError:
490 except NameError:
491 return
491 return
492 else:
492 else:
493 if 'px' not in ip.magics_manager.magics:
493 if 'px' not in ip.magics_manager.magics:
494 # in IPython but we are the first Client.
494 # in IPython but we are the first Client.
495 # activate a default view for parallel magics.
495 # activate a default view for parallel magics.
496 self.activate()
496 self.activate()
497
497
498 def __del__(self):
498 def __del__(self):
499 """cleanup sockets, but _not_ context."""
499 """cleanup sockets, but _not_ context."""
500 self.close()
500 self.close()
501
501
502 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
502 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
503 if ipython_dir is None:
503 if ipython_dir is None:
504 ipython_dir = get_ipython_dir()
504 ipython_dir = get_ipython_dir()
505 if profile_dir is not None:
505 if profile_dir is not None:
506 try:
506 try:
507 self._cd = ProfileDir.find_profile_dir(profile_dir)
507 self._cd = ProfileDir.find_profile_dir(profile_dir)
508 return
508 return
509 except ProfileDirError:
509 except ProfileDirError:
510 pass
510 pass
511 elif profile is not None:
511 elif profile is not None:
512 try:
512 try:
513 self._cd = ProfileDir.find_profile_dir_by_name(
513 self._cd = ProfileDir.find_profile_dir_by_name(
514 ipython_dir, profile)
514 ipython_dir, profile)
515 return
515 return
516 except ProfileDirError:
516 except ProfileDirError:
517 pass
517 pass
518 self._cd = None
518 self._cd = None
519
519
520 def _update_engines(self, engines):
520 def _update_engines(self, engines):
521 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
521 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
522 for k,v in engines.iteritems():
522 for k,v in engines.iteritems():
523 eid = int(k)
523 eid = int(k)
524 if eid not in self._engines:
524 if eid not in self._engines:
525 self._ids.append(eid)
525 self._ids.append(eid)
526 self._engines[eid] = v
526 self._engines[eid] = v
527 self._ids = sorted(self._ids)
527 self._ids = sorted(self._ids)
528 if sorted(self._engines.keys()) != range(len(self._engines)) and \
528 if sorted(self._engines.keys()) != range(len(self._engines)) and \
529 self._task_scheme == 'pure' and self._task_socket:
529 self._task_scheme == 'pure' and self._task_socket:
530 self._stop_scheduling_tasks()
530 self._stop_scheduling_tasks()
531
531
532 def _stop_scheduling_tasks(self):
532 def _stop_scheduling_tasks(self):
533 """Stop scheduling tasks because an engine has been unregistered
533 """Stop scheduling tasks because an engine has been unregistered
534 from a pure ZMQ scheduler.
534 from a pure ZMQ scheduler.
535 """
535 """
536 self._task_socket.close()
536 self._task_socket.close()
537 self._task_socket = None
537 self._task_socket = None
538 msg = "An engine has been unregistered, and we are using pure " +\
538 msg = "An engine has been unregistered, and we are using pure " +\
539 "ZMQ task scheduling. Task farming will be disabled."
539 "ZMQ task scheduling. Task farming will be disabled."
540 if self.outstanding:
540 if self.outstanding:
541 msg += " If you were running tasks when this happened, " +\
541 msg += " If you were running tasks when this happened, " +\
542 "some `outstanding` msg_ids may never resolve."
542 "some `outstanding` msg_ids may never resolve."
543 warnings.warn(msg, RuntimeWarning)
543 warnings.warn(msg, RuntimeWarning)
544
544
545 def _build_targets(self, targets):
545 def _build_targets(self, targets):
546 """Turn valid target IDs or 'all' into two lists:
546 """Turn valid target IDs or 'all' into two lists:
547 (int_ids, uuids).
547 (int_ids, uuids).
548 """
548 """
549 if not self._ids:
549 if not self._ids:
550 # flush notification socket if no engines yet, just in case
550 # flush notification socket if no engines yet, just in case
551 if not self.ids:
551 if not self.ids:
552 raise error.NoEnginesRegistered("Can't build targets without any engines")
552 raise error.NoEnginesRegistered("Can't build targets without any engines")
553
553
554 if targets is None:
554 if targets is None:
555 targets = self._ids
555 targets = self._ids
556 elif isinstance(targets, basestring):
556 elif isinstance(targets, basestring):
557 if targets.lower() == 'all':
557 if targets.lower() == 'all':
558 targets = self._ids
558 targets = self._ids
559 else:
559 else:
560 raise TypeError("%r not valid str target, must be 'all'"%(targets))
560 raise TypeError("%r not valid str target, must be 'all'"%(targets))
561 elif isinstance(targets, int):
561 elif isinstance(targets, int):
562 if targets < 0:
562 if targets < 0:
563 targets = self.ids[targets]
563 targets = self.ids[targets]
564 if targets not in self._ids:
564 if targets not in self._ids:
565 raise IndexError("No such engine: %i"%targets)
565 raise IndexError("No such engine: %i"%targets)
566 targets = [targets]
566 targets = [targets]
567
567
568 if isinstance(targets, slice):
568 if isinstance(targets, slice):
569 indices = range(len(self._ids))[targets]
569 indices = range(len(self._ids))[targets]
570 ids = self.ids
570 ids = self.ids
571 targets = [ ids[i] for i in indices ]
571 targets = [ ids[i] for i in indices ]
572
572
573 if not isinstance(targets, (tuple, list, xrange)):
573 if not isinstance(targets, (tuple, list, xrange)):
574 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
574 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
575
575
576 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
576 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
577
577
    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__.

        Sends a connection_request to the Hub over the query socket, waits up
        to `timeout` seconds for the reply, then connects the mux, task,
        notification, control, and iopub sockets (optionally through an SSH
        tunnel) and seeds the engine table from the Hub's reply.
        """

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            # helper: connect directly, or via the SSH tunnel when one is in use
            # url = util.disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        self.session.send(self._query_socket, 'connection_request')
        # use Poller because zmq.select has wrong units in pyzmq 2.1.7
        poller = zmq.Poller()
        poller.register(self._query_socket, zmq.POLLIN)
        # poll expects milliseconds, timeout is seconds
        evts = poller.poll(timeout*1000)
        if not evts:
            raise error.TimeoutError("Hub connection request timed out")
        idents,msg = self.session.recv(self._query_socket,mode=0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        # self._config['registration'] = dict(content)
        cfg = self._config
        if content['status'] == 'ok':
            self._mux_socket = self._context.socket(zmq.DEALER)
            connect_socket(self._mux_socket, cfg['mux'])

            self._task_socket = self._context.socket(zmq.DEALER)
            connect_socket(self._task_socket, cfg['task'])

            # SUB sockets subscribe to everything before connecting
            self._notification_socket = self._context.socket(zmq.SUB)
            self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            connect_socket(self._notification_socket, cfg['notification'])

            self._control_socket = self._context.socket(zmq.DEALER)
            connect_socket(self._control_socket, cfg['control'])

            self._iopub_socket = self._context.socket(zmq.SUB)
            self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
            connect_socket(self._iopub_socket, cfg['iopub'])

            # record the engines the Hub already knows about
            self._update_engines(dict(content['engines']))
        else:
            # reset the flag so a later connection attempt is possible
            self._connected = False
            raise Exception("Failed to connect!")
630
630
631 #--------------------------------------------------------------------------
631 #--------------------------------------------------------------------------
632 # handlers and callbacks for incoming messages
632 # handlers and callbacks for incoming messages
633 #--------------------------------------------------------------------------
633 #--------------------------------------------------------------------------
634
634
    def _unwrap_exception(self, content):
        """unwrap exception, and remap engine_id to int."""
        e = error.unwrap_exception(content)
        # print e.traceback
        if e.engine_info:
            # attach the engine's id alongside its uuid for nicer reporting.
            # NOTE(review): _update_engines populates self._engines as
            # {int_id: uuid}; indexing it by uuid here looks inverted --
            # confirm how self._engines is keyed before relying on this.
            e_uuid = e.engine_info['engine_uuid']
            eid = self._engines[e_uuid]
            e.engine_info['engine_id'] = eid
        return e
644
644
645 def _extract_metadata(self, msg):
645 def _extract_metadata(self, msg):
646 header = msg['header']
646 header = msg['header']
647 parent = msg['parent_header']
647 parent = msg['parent_header']
648 msg_meta = msg['metadata']
648 msg_meta = msg['metadata']
649 content = msg['content']
649 content = msg['content']
650 md = {'msg_id' : parent['msg_id'],
650 md = {'msg_id' : parent['msg_id'],
651 'received' : datetime.now(),
651 'received' : datetime.now(),
652 'engine_uuid' : msg_meta.get('engine', None),
652 'engine_uuid' : msg_meta.get('engine', None),
653 'follow' : msg_meta.get('follow', []),
653 'follow' : msg_meta.get('follow', []),
654 'after' : msg_meta.get('after', []),
654 'after' : msg_meta.get('after', []),
655 'status' : content['status'],
655 'status' : content['status'],
656 }
656 }
657
657
658 if md['engine_uuid'] is not None:
658 if md['engine_uuid'] is not None:
659 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
659 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
660
660
661 if 'date' in parent:
661 if 'date' in parent:
662 md['submitted'] = parent['date']
662 md['submitted'] = parent['date']
663 if 'started' in msg_meta:
663 if 'started' in msg_meta:
664 md['started'] = msg_meta['started']
664 md['started'] = msg_meta['started']
665 if 'date' in header:
665 if 'date' in header:
666 md['completed'] = header['date']
666 md['completed'] = header['date']
667 return md
667 return md
668
668
669 def _register_engine(self, msg):
669 def _register_engine(self, msg):
670 """Register a new engine, and update our connection info."""
670 """Register a new engine, and update our connection info."""
671 content = msg['content']
671 content = msg['content']
672 eid = content['id']
672 eid = content['id']
673 d = {eid : content['uuid']}
673 d = {eid : content['uuid']}
674 self._update_engines(d)
674 self._update_engines(d)
675
675
676 def _unregister_engine(self, msg):
676 def _unregister_engine(self, msg):
677 """Unregister an engine that has died."""
677 """Unregister an engine that has died."""
678 content = msg['content']
678 content = msg['content']
679 eid = int(content['id'])
679 eid = int(content['id'])
680 if eid in self._ids:
680 if eid in self._ids:
681 self._ids.remove(eid)
681 self._ids.remove(eid)
682 uuid = self._engines.pop(eid)
682 uuid = self._engines.pop(eid)
683
683
684 self._handle_stranded_msgs(eid, uuid)
684 self._handle_stranded_msgs(eid, uuid)
685
685
686 if self._task_socket and self._task_scheme == 'pure':
686 if self._task_socket and self._task_scheme == 'pure':
687 self._stop_scheduling_tasks()
687 self._stop_scheduling_tasks()
688
688
    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.
        """

        outstanding = self._outstanding_dict[uuid]

        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already have this result; nothing to do
                continue
            try:
                # raise/except so wrap_exception can capture a real traceback
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake message:
            # an apply_reply carrying the EngineError, attributed to the
            # dead engine, fed through the normal reply handler
            msg = self.session.msg('apply_reply', content=content)
            msg['parent_header']['msg_id'] = msg_id
            msg['metadata']['engine'] = uuid
            self._handle_apply_reply(msg)
712
712
    def _handle_execute_reply(self, msg):
        """Save the reply to an execute_request into our results.

        execute messages are never actually used. apply is used instead.
        """

        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)

        content = msg['content']
        header = msg['header']

        # construct metadata:
        md = self.metadata[msg_id]
        md.update(self._extract_metadata(msg))
        # is this redundant?
        self.metadata[msg_id] = md

        # also clear the per-engine-uuid outstanding set
        e_outstanding = self._outstanding_dict[md['engine_uuid']]
        if msg_id in e_outstanding:
            e_outstanding.remove(msg_id)

        # construct result:
        if content['status'] == 'ok':
            self.results[msg_id] = ExecuteReply(msg_id, content, md)
        elif content['status'] == 'aborted':
            self.results[msg_id] = error.TaskAborted(msg_id)
        elif content['status'] == 'resubmitted':
            # TODO: handle resubmission
            pass
        else:
            # error status: store the remote exception as the result
            self.results[msg_id] = self._unwrap_exception(content)
752
752
    def _handle_apply_reply(self, msg):
        """Save the reply to an apply_request into our results."""
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
                print self.results[msg_id]
                print msg
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
        content = msg['content']
        header = msg['header']

        # construct metadata:
        md = self.metadata[msg_id]
        md.update(self._extract_metadata(msg))
        # is this redundant?
        self.metadata[msg_id] = md

        # also clear the per-engine-uuid outstanding set
        e_outstanding = self._outstanding_dict[md['engine_uuid']]
        if msg_id in e_outstanding:
            e_outstanding.remove(msg_id)

        # construct result:
        if content['status'] == 'ok':
            # the deserialized return value of the apply call
            self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
        elif content['status'] == 'aborted':
            self.results[msg_id] = error.TaskAborted(msg_id)
        elif content['status'] == 'resubmitted':
            # TODO: handle resubmission
            pass
        else:
            self.results[msg_id] = self._unwrap_exception(content)
789
789
    def _flush_notifications(self):
        """Flush notifications of engine registrations waiting
        in ZMQ queue."""
        # non-blocking drain: msg is None once the queue is empty
        idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg_type = msg['header']['msg_type']
            # dispatch to _register_engine / _unregister_engine / shutdown
            handler = self._notification_handlers.get(msg_type, None)
            if handler is None:
                raise Exception("Unhandled message type: %s" % msg_type)
            else:
                handler(msg)
            idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
804
804
    def _flush_results(self, sock):
        """Flush task or queue results waiting in ZMQ queue."""
        # non-blocking drain: msg is None once the queue is empty
        idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg_type = msg['header']['msg_type']
            # dispatch to _handle_execute_reply / _handle_apply_reply
            handler = self._queue_handlers.get(msg_type, None)
            if handler is None:
                raise Exception("Unhandled message type: %s" % msg_type)
            else:
                handler(msg)
            idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
818
818
    def _flush_control(self, sock):
        """Flush replies from the control channel waiting
        in the ZMQ queue.

        Currently: ignore them."""
        if self._ignored_control_replies <= 0:
            return
        # non-blocking drain, counting down the replies we owe
        idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            self._ignored_control_replies -= 1
            if self.debug:
                pprint(msg)
            idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
832
832
833 def _flush_ignored_control(self):
833 def _flush_ignored_control(self):
834 """flush ignored control replies"""
834 """flush ignored control replies"""
835 while self._ignored_control_replies > 0:
835 while self._ignored_control_replies > 0:
836 self.session.recv(self._control_socket)
836 self.session.recv(self._control_socket)
837 self._ignored_control_replies -= 1
837 self._ignored_control_replies -= 1
838
838
    def _flush_ignored_hub_replies(self):
        """Drain (and discard) any replies from the Hub waiting on the query socket."""
        ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
        while msg is not None:
            ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
843
843
    def _flush_iopub(self, sock):
        """Flush replies from the iopub channel waiting
        in the ZMQ queue.

        Accumulates stream/display output into self.metadata[msg_id] as
        messages arrive.
        """
        # non-blocking drain: msg is None once the queue is empty
        idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            parent = msg['parent_header']
            # ignore IOPub messages with no parent.
            # Caused by print statements or warnings from before the first execution.
            if not parent:
                idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
                continue
            msg_id = parent['msg_id']
            content = msg['content']
            header = msg['header']
            msg_type = msg['header']['msg_type']

            # init metadata:
            md = self.metadata[msg_id]

            if msg_type == 'stream':
                # append to any output already captured for this stream name
                name = content['name']
                s = md[name] or ''
                md[name] = s + content['data']
            elif msg_type == 'pyerr':
                md.update({'pyerr' : self._unwrap_exception(content)})
            elif msg_type == 'pyin':
                md.update({'pyin' : content['code']})
            elif msg_type == 'display_data':
                md['outputs'].append(content)
            elif msg_type == 'pyout':
                md['pyout'] = content
            elif msg_type == 'data_message':
                data, remainder = serialize.unserialize_object(msg['buffers'])
                md['data'].update(data)
            elif msg_type == 'status':
                # idle message comes after all outputs
                if content['execution_state'] == 'idle':
                    md['outputs_ready'] = True
            else:
                # unhandled msg_type (status, etc.)
                pass

            # redundant? md may already be the object stored in self.metadata
            self.metadata[msg_id] = md

            idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
893
893
894 #--------------------------------------------------------------------------
894 #--------------------------------------------------------------------------
895 # len, getitem
895 # len, getitem
896 #--------------------------------------------------------------------------
896 #--------------------------------------------------------------------------
897
897
898 def __len__(self):
898 def __len__(self):
899 """len(client) returns # of engines."""
899 """len(client) returns # of engines."""
900 return len(self.ids)
900 return len(self.ids)
901
901
902 def __getitem__(self, key):
902 def __getitem__(self, key):
903 """index access returns DirectView multiplexer objects
903 """index access returns DirectView multiplexer objects
904
904
905 Must be int, slice, or list/tuple/xrange of ints"""
905 Must be int, slice, or list/tuple/xrange of ints"""
906 if not isinstance(key, (int, slice, tuple, list, xrange)):
906 if not isinstance(key, (int, slice, tuple, list, xrange)):
907 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
907 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
908 else:
908 else:
909 return self.direct_view(key)
909 return self.direct_view(key)
910
910
911 #--------------------------------------------------------------------------
911 #--------------------------------------------------------------------------
912 # Begin public methods
912 # Begin public methods
913 #--------------------------------------------------------------------------
913 #--------------------------------------------------------------------------
914
914
915 @property
915 @property
916 def ids(self):
916 def ids(self):
917 """Always up-to-date ids property."""
917 """Always up-to-date ids property."""
918 self._flush_notifications()
918 self._flush_notifications()
919 # always copy:
919 # always copy:
920 return list(self._ids)
920 return list(self._ids)
921
921
922 def activate(self, targets='all', suffix=''):
922 def activate(self, targets='all', suffix=''):
923 """Create a DirectView and register it with IPython magics
923 """Create a DirectView and register it with IPython magics
924
924
925 Defines the magics `%px, %autopx, %pxresult, %%px`
925 Defines the magics `%px, %autopx, %pxresult, %%px`
926
926
927 Parameters
927 Parameters
928 ----------
928 ----------
929
929
930 targets: int, list of ints, or 'all'
930 targets: int, list of ints, or 'all'
931 The engines on which the view's magics will run
931 The engines on which the view's magics will run
932 suffix: str [default: '']
932 suffix: str [default: '']
933 The suffix, if any, for the magics. This allows you to have
933 The suffix, if any, for the magics. This allows you to have
934 multiple views associated with parallel magics at the same time.
934 multiple views associated with parallel magics at the same time.
935
935
936 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
936 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
937 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
937 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
938 on engine 0.
938 on engine 0.
939 """
939 """
940 view = self.direct_view(targets)
940 view = self.direct_view(targets)
941 view.block = True
941 view.block = True
942 view.activate(suffix)
942 view.activate(suffix)
943 return view
943 return view
944
944
945 def close(self):
945 def close(self):
946 if self._closed:
946 if self._closed:
947 return
947 return
948 self.stop_spin_thread()
948 self.stop_spin_thread()
949 snames = filter(lambda n: n.endswith('socket'), dir(self))
949 snames = filter(lambda n: n.endswith('socket'), dir(self))
950 for socket in map(lambda name: getattr(self, name), snames):
950 for socket in map(lambda name: getattr(self, name), snames):
951 if isinstance(socket, zmq.Socket) and not socket.closed:
951 if isinstance(socket, zmq.Socket) and not socket.closed:
952 socket.close()
952 socket.close()
953 self._closed = True
953 self._closed = True
954
954
955 def _spin_every(self, interval=1):
955 def _spin_every(self, interval=1):
956 """target func for use in spin_thread"""
956 """target func for use in spin_thread"""
957 while True:
957 while True:
958 if self._stop_spinning.is_set():
958 if self._stop_spinning.is_set():
959 return
959 return
960 time.sleep(interval)
960 time.sleep(interval)
961 self.spin()
961 self.spin()
962
962
963 def spin_thread(self, interval=1):
963 def spin_thread(self, interval=1):
964 """call Client.spin() in a background thread on some regular interval
964 """call Client.spin() in a background thread on some regular interval
965
965
966 This helps ensure that messages don't pile up too much in the zmq queue
966 This helps ensure that messages don't pile up too much in the zmq queue
967 while you are working on other things, or just leaving an idle terminal.
967 while you are working on other things, or just leaving an idle terminal.
968
968
969 It also helps limit potential padding of the `received` timestamp
969 It also helps limit potential padding of the `received` timestamp
970 on AsyncResult objects, used for timings.
970 on AsyncResult objects, used for timings.
971
971
972 Parameters
972 Parameters
973 ----------
973 ----------
974
974
975 interval : float, optional
975 interval : float, optional
976 The interval on which to spin the client in the background thread
976 The interval on which to spin the client in the background thread
977 (simply passed to time.sleep).
977 (simply passed to time.sleep).
978
978
979 Notes
979 Notes
980 -----
980 -----
981
981
982 For precision timing, you may want to use this method to put a bound
982 For precision timing, you may want to use this method to put a bound
983 on the jitter (in seconds) in `received` timestamps used
983 on the jitter (in seconds) in `received` timestamps used
984 in AsyncResult.wall_time.
984 in AsyncResult.wall_time.
985
985
986 """
986 """
987 if self._spin_thread is not None:
987 if self._spin_thread is not None:
988 self.stop_spin_thread()
988 self.stop_spin_thread()
989 self._stop_spinning.clear()
989 self._stop_spinning.clear()
990 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
990 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
991 self._spin_thread.daemon = True
991 self._spin_thread.daemon = True
992 self._spin_thread.start()
992 self._spin_thread.start()
993
993
994 def stop_spin_thread(self):
994 def stop_spin_thread(self):
995 """stop background spin_thread, if any"""
995 """stop background spin_thread, if any"""
996 if self._spin_thread is not None:
996 if self._spin_thread is not None:
997 self._stop_spinning.set()
997 self._stop_spinning.set()
998 self._spin_thread.join()
998 self._spin_thread.join()
999 self._spin_thread = None
999 self._spin_thread = None
1000
1000
1001 def spin(self):
1001 def spin(self):
1002 """Flush any registration notifications and execution results
1002 """Flush any registration notifications and execution results
1003 waiting in the ZMQ queue.
1003 waiting in the ZMQ queue.
1004 """
1004 """
1005 if self._notification_socket:
1005 if self._notification_socket:
1006 self._flush_notifications()
1006 self._flush_notifications()
1007 if self._iopub_socket:
1007 if self._iopub_socket:
1008 self._flush_iopub(self._iopub_socket)
1008 self._flush_iopub(self._iopub_socket)
1009 if self._mux_socket:
1009 if self._mux_socket:
1010 self._flush_results(self._mux_socket)
1010 self._flush_results(self._mux_socket)
1011 if self._task_socket:
1011 if self._task_socket:
1012 self._flush_results(self._task_socket)
1012 self._flush_results(self._task_socket)
1013 if self._control_socket:
1013 if self._control_socket:
1014 self._flush_control(self._control_socket)
1014 self._flush_control(self._control_socket)
1015 if self._query_socket:
1015 if self._query_socket:
1016 self._flush_ignored_hub_replies()
1016 self._flush_ignored_hub_replies()
1017
1017
1018 def wait(self, jobs=None, timeout=-1):
1018 def wait(self, jobs=None, timeout=-1):
1019 """waits on one or more `jobs`, for up to `timeout` seconds.
1019 """waits on one or more `jobs`, for up to `timeout` seconds.
1020
1020
1021 Parameters
1021 Parameters
1022 ----------
1022 ----------
1023
1023
1024 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1024 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1025 ints are indices to self.history
1025 ints are indices to self.history
1026 strs are msg_ids
1026 strs are msg_ids
1027 default: wait on all outstanding messages
1027 default: wait on all outstanding messages
1028 timeout : float
1028 timeout : float
1029 a time in seconds, after which to give up.
1029 a time in seconds, after which to give up.
1030 default is -1, which means no timeout
1030 default is -1, which means no timeout
1031
1031
1032 Returns
1032 Returns
1033 -------
1033 -------
1034
1034
1035 True : when all msg_ids are done
1035 True : when all msg_ids are done
1036 False : timeout reached, some msg_ids still outstanding
1036 False : timeout reached, some msg_ids still outstanding
1037 """
1037 """
1038 tic = time.time()
1038 tic = time.time()
1039 if jobs is None:
1039 if jobs is None:
1040 theids = self.outstanding
1040 theids = self.outstanding
1041 else:
1041 else:
1042 if isinstance(jobs, (int, basestring, AsyncResult)):
1042 if isinstance(jobs, (int, basestring, AsyncResult)):
1043 jobs = [jobs]
1043 jobs = [jobs]
1044 theids = set()
1044 theids = set()
1045 for job in jobs:
1045 for job in jobs:
1046 if isinstance(job, int):
1046 if isinstance(job, int):
1047 # index access
1047 # index access
1048 job = self.history[job]
1048 job = self.history[job]
1049 elif isinstance(job, AsyncResult):
1049 elif isinstance(job, AsyncResult):
1050 map(theids.add, job.msg_ids)
1050 map(theids.add, job.msg_ids)
1051 continue
1051 continue
1052 theids.add(job)
1052 theids.add(job)
1053 if not theids.intersection(self.outstanding):
1053 if not theids.intersection(self.outstanding):
1054 return True
1054 return True
1055 self.spin()
1055 self.spin()
1056 while theids.intersection(self.outstanding):
1056 while theids.intersection(self.outstanding):
1057 if timeout >= 0 and ( time.time()-tic ) > timeout:
1057 if timeout >= 0 and ( time.time()-tic ) > timeout:
1058 break
1058 break
1059 time.sleep(1e-3)
1059 time.sleep(1e-3)
1060 self.spin()
1060 self.spin()
1061 return len(theids.intersection(self.outstanding)) == 0
1061 return len(theids.intersection(self.outstanding)) == 0
1062
1062
1063 #--------------------------------------------------------------------------
1063 #--------------------------------------------------------------------------
1064 # Control methods
1064 # Control methods
1065 #--------------------------------------------------------------------------
1065 #--------------------------------------------------------------------------
1066
1066
1067 @spin_first
1067 @spin_first
1068 def clear(self, targets=None, block=None):
1068 def clear(self, targets=None, block=None):
1069 """Clear the namespace in target(s)."""
1069 """Clear the namespace in target(s)."""
1070 block = self.block if block is None else block
1070 block = self.block if block is None else block
1071 targets = self._build_targets(targets)[0]
1071 targets = self._build_targets(targets)[0]
1072 for t in targets:
1072 for t in targets:
1073 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1073 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1074 error = False
1074 error = False
1075 if block:
1075 if block:
1076 self._flush_ignored_control()
1076 self._flush_ignored_control()
1077 for i in range(len(targets)):
1077 for i in range(len(targets)):
1078 idents,msg = self.session.recv(self._control_socket,0)
1078 idents,msg = self.session.recv(self._control_socket,0)
1079 if self.debug:
1079 if self.debug:
1080 pprint(msg)
1080 pprint(msg)
1081 if msg['content']['status'] != 'ok':
1081 if msg['content']['status'] != 'ok':
1082 error = self._unwrap_exception(msg['content'])
1082 error = self._unwrap_exception(msg['content'])
1083 else:
1083 else:
1084 self._ignored_control_replies += len(targets)
1084 self._ignored_control_replies += len(targets)
1085 if error:
1085 if error:
1086 raise error
1086 raise error
1087
1087
1088
1088
1089 @spin_first
1089 @spin_first
1090 def abort(self, jobs=None, targets=None, block=None):
1090 def abort(self, jobs=None, targets=None, block=None):
1091 """Abort specific jobs from the execution queues of target(s).
1091 """Abort specific jobs from the execution queues of target(s).
1092
1092
1093 This is a mechanism to prevent jobs that have already been submitted
1093 This is a mechanism to prevent jobs that have already been submitted
1094 from executing.
1094 from executing.
1095
1095
1096 Parameters
1096 Parameters
1097 ----------
1097 ----------
1098
1098
1099 jobs : msg_id, list of msg_ids, or AsyncResult
1099 jobs : msg_id, list of msg_ids, or AsyncResult
1100 The jobs to be aborted
1100 The jobs to be aborted
1101
1101
1102 If unspecified/None: abort all outstanding jobs.
1102 If unspecified/None: abort all outstanding jobs.
1103
1103
1104 """
1104 """
1105 block = self.block if block is None else block
1105 block = self.block if block is None else block
1106 jobs = jobs if jobs is not None else list(self.outstanding)
1106 jobs = jobs if jobs is not None else list(self.outstanding)
1107 targets = self._build_targets(targets)[0]
1107 targets = self._build_targets(targets)[0]
1108
1108
1109 msg_ids = []
1109 msg_ids = []
1110 if isinstance(jobs, (basestring,AsyncResult)):
1110 if isinstance(jobs, (basestring,AsyncResult)):
1111 jobs = [jobs]
1111 jobs = [jobs]
1112 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1112 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1113 if bad_ids:
1113 if bad_ids:
1114 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1114 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1115 for j in jobs:
1115 for j in jobs:
1116 if isinstance(j, AsyncResult):
1116 if isinstance(j, AsyncResult):
1117 msg_ids.extend(j.msg_ids)
1117 msg_ids.extend(j.msg_ids)
1118 else:
1118 else:
1119 msg_ids.append(j)
1119 msg_ids.append(j)
1120 content = dict(msg_ids=msg_ids)
1120 content = dict(msg_ids=msg_ids)
1121 for t in targets:
1121 for t in targets:
1122 self.session.send(self._control_socket, 'abort_request',
1122 self.session.send(self._control_socket, 'abort_request',
1123 content=content, ident=t)
1123 content=content, ident=t)
1124 error = False
1124 error = False
1125 if block:
1125 if block:
1126 self._flush_ignored_control()
1126 self._flush_ignored_control()
1127 for i in range(len(targets)):
1127 for i in range(len(targets)):
1128 idents,msg = self.session.recv(self._control_socket,0)
1128 idents,msg = self.session.recv(self._control_socket,0)
1129 if self.debug:
1129 if self.debug:
1130 pprint(msg)
1130 pprint(msg)
1131 if msg['content']['status'] != 'ok':
1131 if msg['content']['status'] != 'ok':
1132 error = self._unwrap_exception(msg['content'])
1132 error = self._unwrap_exception(msg['content'])
1133 else:
1133 else:
1134 self._ignored_control_replies += len(targets)
1134 self._ignored_control_replies += len(targets)
1135 if error:
1135 if error:
1136 raise error
1136 raise error
1137
1137
1138 @spin_first
1138 @spin_first
1139 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1139 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1140 """Terminates one or more engine processes, optionally including the hub.
1140 """Terminates one or more engine processes, optionally including the hub.
1141
1141
1142 Parameters
1142 Parameters
1143 ----------
1143 ----------
1144
1144
1145 targets: list of ints or 'all' [default: all]
1145 targets: list of ints or 'all' [default: all]
1146 Which engines to shutdown.
1146 Which engines to shutdown.
1147 hub: bool [default: False]
1147 hub: bool [default: False]
1148 Whether to include the Hub. hub=True implies targets='all'.
1148 Whether to include the Hub. hub=True implies targets='all'.
1149 block: bool [default: self.block]
1149 block: bool [default: self.block]
1150 Whether to wait for clean shutdown replies or not.
1150 Whether to wait for clean shutdown replies or not.
1151 restart: bool [default: False]
1151 restart: bool [default: False]
1152 NOT IMPLEMENTED
1152 NOT IMPLEMENTED
1153 whether to restart engines after shutting them down.
1153 whether to restart engines after shutting them down.
1154 """
1154 """
1155 from IPython.parallel.error import NoEnginesRegistered
1155 from IPython.parallel.error import NoEnginesRegistered
1156 if restart:
1156 if restart:
1157 raise NotImplementedError("Engine restart is not yet implemented")
1157 raise NotImplementedError("Engine restart is not yet implemented")
1158
1158
1159 block = self.block if block is None else block
1159 block = self.block if block is None else block
1160 if hub:
1160 if hub:
1161 targets = 'all'
1161 targets = 'all'
1162 try:
1162 try:
1163 targets = self._build_targets(targets)[0]
1163 targets = self._build_targets(targets)[0]
1164 except NoEnginesRegistered:
1164 except NoEnginesRegistered:
1165 targets = []
1165 targets = []
1166 for t in targets:
1166 for t in targets:
1167 self.session.send(self._control_socket, 'shutdown_request',
1167 self.session.send(self._control_socket, 'shutdown_request',
1168 content={'restart':restart},ident=t)
1168 content={'restart':restart},ident=t)
1169 error = False
1169 error = False
1170 if block or hub:
1170 if block or hub:
1171 self._flush_ignored_control()
1171 self._flush_ignored_control()
1172 for i in range(len(targets)):
1172 for i in range(len(targets)):
1173 idents,msg = self.session.recv(self._control_socket, 0)
1173 idents,msg = self.session.recv(self._control_socket, 0)
1174 if self.debug:
1174 if self.debug:
1175 pprint(msg)
1175 pprint(msg)
1176 if msg['content']['status'] != 'ok':
1176 if msg['content']['status'] != 'ok':
1177 error = self._unwrap_exception(msg['content'])
1177 error = self._unwrap_exception(msg['content'])
1178 else:
1178 else:
1179 self._ignored_control_replies += len(targets)
1179 self._ignored_control_replies += len(targets)
1180
1180
1181 if hub:
1181 if hub:
1182 time.sleep(0.25)
1182 time.sleep(0.25)
1183 self.session.send(self._query_socket, 'shutdown_request')
1183 self.session.send(self._query_socket, 'shutdown_request')
1184 idents,msg = self.session.recv(self._query_socket, 0)
1184 idents,msg = self.session.recv(self._query_socket, 0)
1185 if self.debug:
1185 if self.debug:
1186 pprint(msg)
1186 pprint(msg)
1187 if msg['content']['status'] != 'ok':
1187 if msg['content']['status'] != 'ok':
1188 error = self._unwrap_exception(msg['content'])
1188 error = self._unwrap_exception(msg['content'])
1189
1189
1190 if error:
1190 if error:
1191 raise error
1191 raise error
1192
1192
1193 #--------------------------------------------------------------------------
1193 #--------------------------------------------------------------------------
1194 # Execution related methods
1194 # Execution related methods
1195 #--------------------------------------------------------------------------
1195 #--------------------------------------------------------------------------
1196
1196
1197 def _maybe_raise(self, result):
1197 def _maybe_raise(self, result):
1198 """wrapper for maybe raising an exception if apply failed."""
1198 """wrapper for maybe raising an exception if apply failed."""
1199 if isinstance(result, error.RemoteError):
1199 if isinstance(result, error.RemoteError):
1200 raise result
1200 raise result
1201
1201
1202 return result
1202 return result
1203
1203
1204 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1204 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1205 ident=None):
1205 ident=None):
1206 """construct and send an apply message via a socket.
1206 """construct and send an apply message via a socket.
1207
1207
1208 This is the principal method with which all engine execution is performed by views.
1208 This is the principal method with which all engine execution is performed by views.
1209 """
1209 """
1210
1210
1211 if self._closed:
1211 if self._closed:
1212 raise RuntimeError("Client cannot be used after its sockets have been closed")
1212 raise RuntimeError("Client cannot be used after its sockets have been closed")
1213
1213
1214 # defaults:
1214 # defaults:
1215 args = args if args is not None else []
1215 args = args if args is not None else []
1216 kwargs = kwargs if kwargs is not None else {}
1216 kwargs = kwargs if kwargs is not None else {}
1217 metadata = metadata if metadata is not None else {}
1217 metadata = metadata if metadata is not None else {}
1218
1218
1219 # validate arguments
1219 # validate arguments
1220 if not callable(f) and not isinstance(f, Reference):
1220 if not callable(f) and not isinstance(f, Reference):
1221 raise TypeError("f must be callable, not %s"%type(f))
1221 raise TypeError("f must be callable, not %s"%type(f))
1222 if not isinstance(args, (tuple, list)):
1222 if not isinstance(args, (tuple, list)):
1223 raise TypeError("args must be tuple or list, not %s"%type(args))
1223 raise TypeError("args must be tuple or list, not %s"%type(args))
1224 if not isinstance(kwargs, dict):
1224 if not isinstance(kwargs, dict):
1225 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1225 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1226 if not isinstance(metadata, dict):
1226 if not isinstance(metadata, dict):
1227 raise TypeError("metadata must be dict, not %s"%type(metadata))
1227 raise TypeError("metadata must be dict, not %s"%type(metadata))
1228
1228
1229 bufs = serialize.pack_apply_message(f, args, kwargs,
1229 bufs = serialize.pack_apply_message(f, args, kwargs,
1230 buffer_threshold=self.session.buffer_threshold,
1230 buffer_threshold=self.session.buffer_threshold,
1231 item_threshold=self.session.item_threshold,
1231 item_threshold=self.session.item_threshold,
1232 )
1232 )
1233
1233
1234 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1234 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1235 metadata=metadata, track=track)
1235 metadata=metadata, track=track)
1236
1236
1237 msg_id = msg['header']['msg_id']
1237 msg_id = msg['header']['msg_id']
1238 self.outstanding.add(msg_id)
1238 self.outstanding.add(msg_id)
1239 if ident:
1239 if ident:
1240 # possibly routed to a specific engine
1240 # possibly routed to a specific engine
1241 if isinstance(ident, list):
1241 if isinstance(ident, list):
1242 ident = ident[-1]
1242 ident = ident[-1]
1243 if ident in self._engines.values():
1243 if ident in self._engines.values():
1244 # save for later, in case of engine death
1244 # save for later, in case of engine death
1245 self._outstanding_dict[ident].add(msg_id)
1245 self._outstanding_dict[ident].add(msg_id)
1246 self.history.append(msg_id)
1246 self.history.append(msg_id)
1247 self.metadata[msg_id]['submitted'] = datetime.now()
1247 self.metadata[msg_id]['submitted'] = datetime.now()
1248
1248
1249 return msg
1249 return msg
1250
1250
1251 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1251 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1252 """construct and send an execute request via a socket.
1252 """construct and send an execute request via a socket.
1253
1253
1254 """
1254 """
1255
1255
1256 if self._closed:
1256 if self._closed:
1257 raise RuntimeError("Client cannot be used after its sockets have been closed")
1257 raise RuntimeError("Client cannot be used after its sockets have been closed")
1258
1258
1259 # defaults:
1259 # defaults:
1260 metadata = metadata if metadata is not None else {}
1260 metadata = metadata if metadata is not None else {}
1261
1261
1262 # validate arguments
1262 # validate arguments
1263 if not isinstance(code, basestring):
1263 if not isinstance(code, basestring):
1264 raise TypeError("code must be text, not %s" % type(code))
1264 raise TypeError("code must be text, not %s" % type(code))
1265 if not isinstance(metadata, dict):
1265 if not isinstance(metadata, dict):
1266 raise TypeError("metadata must be dict, not %s" % type(metadata))
1266 raise TypeError("metadata must be dict, not %s" % type(metadata))
1267
1267
1268 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1268 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1269
1269
1270
1270
1271 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1271 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1272 metadata=metadata)
1272 metadata=metadata)
1273
1273
1274 msg_id = msg['header']['msg_id']
1274 msg_id = msg['header']['msg_id']
1275 self.outstanding.add(msg_id)
1275 self.outstanding.add(msg_id)
1276 if ident:
1276 if ident:
1277 # possibly routed to a specific engine
1277 # possibly routed to a specific engine
1278 if isinstance(ident, list):
1278 if isinstance(ident, list):
1279 ident = ident[-1]
1279 ident = ident[-1]
1280 if ident in self._engines.values():
1280 if ident in self._engines.values():
1281 # save for later, in case of engine death
1281 # save for later, in case of engine death
1282 self._outstanding_dict[ident].add(msg_id)
1282 self._outstanding_dict[ident].add(msg_id)
1283 self.history.append(msg_id)
1283 self.history.append(msg_id)
1284 self.metadata[msg_id]['submitted'] = datetime.now()
1284 self.metadata[msg_id]['submitted'] = datetime.now()
1285
1285
1286 return msg
1286 return msg
1287
1287
1288 #--------------------------------------------------------------------------
1288 #--------------------------------------------------------------------------
1289 # construct a View object
1289 # construct a View object
1290 #--------------------------------------------------------------------------
1290 #--------------------------------------------------------------------------
1291
1291
1292 def load_balanced_view(self, targets=None):
1292 def load_balanced_view(self, targets=None):
1293 """construct a DirectView object.
1293 """construct a DirectView object.
1294
1294
1295 If no arguments are specified, create a LoadBalancedView
1295 If no arguments are specified, create a LoadBalancedView
1296 using all engines.
1296 using all engines.
1297
1297
1298 Parameters
1298 Parameters
1299 ----------
1299 ----------
1300
1300
1301 targets: list,slice,int,etc. [default: use all engines]
1301 targets: list,slice,int,etc. [default: use all engines]
1302 The subset of engines across which to load-balance
1302 The subset of engines across which to load-balance
1303 """
1303 """
1304 if targets == 'all':
1304 if targets == 'all':
1305 targets = None
1305 targets = None
1306 if targets is not None:
1306 if targets is not None:
1307 targets = self._build_targets(targets)[1]
1307 targets = self._build_targets(targets)[1]
1308 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1308 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1309
1309
1310 def direct_view(self, targets='all'):
1310 def direct_view(self, targets='all'):
1311 """construct a DirectView object.
1311 """construct a DirectView object.
1312
1312
1313 If no targets are specified, create a DirectView using all engines.
1313 If no targets are specified, create a DirectView using all engines.
1314
1314
1315 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1315 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1316 evaluate the target engines at each execution, whereas rc[:] will connect to
1316 evaluate the target engines at each execution, whereas rc[:] will connect to
1317 all *current* engines, and that list will not change.
1317 all *current* engines, and that list will not change.
1318
1318
1319 That is, 'all' will always use all engines, whereas rc[:] will not use
1319 That is, 'all' will always use all engines, whereas rc[:] will not use
1320 engines added after the DirectView is constructed.
1320 engines added after the DirectView is constructed.
1321
1321
1322 Parameters
1322 Parameters
1323 ----------
1323 ----------
1324
1324
1325 targets: list,slice,int,etc. [default: use all engines]
1325 targets: list,slice,int,etc. [default: use all engines]
1326 The engines to use for the View
1326 The engines to use for the View
1327 """
1327 """
1328 single = isinstance(targets, int)
1328 single = isinstance(targets, int)
1329 # allow 'all' to be lazily evaluated at each execution
1329 # allow 'all' to be lazily evaluated at each execution
1330 if targets != 'all':
1330 if targets != 'all':
1331 targets = self._build_targets(targets)[1]
1331 targets = self._build_targets(targets)[1]
1332 if single:
1332 if single:
1333 targets = targets[0]
1333 targets = targets[0]
1334 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1334 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1335
1335
1336 #--------------------------------------------------------------------------
1336 #--------------------------------------------------------------------------
1337 # Query methods
1337 # Query methods
1338 #--------------------------------------------------------------------------
1338 #--------------------------------------------------------------------------
1339
1339
1340 @spin_first
1340 @spin_first
1341 def get_result(self, indices_or_msg_ids=None, block=None):
1341 def get_result(self, indices_or_msg_ids=None, block=None):
1342 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1342 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1343
1343
1344 If the client already has the results, no request to the Hub will be made.
1344 If the client already has the results, no request to the Hub will be made.
1345
1345
1346 This is a convenient way to construct AsyncResult objects, which are wrappers
1346 This is a convenient way to construct AsyncResult objects, which are wrappers
1347 that include metadata about execution, and allow for awaiting results that
1347 that include metadata about execution, and allow for awaiting results that
1348 were not submitted by this Client.
1348 were not submitted by this Client.
1349
1349
1350 It can also be a convenient way to retrieve the metadata associated with
1350 It can also be a convenient way to retrieve the metadata associated with
1351 blocking execution, since it always retrieves
1351 blocking execution, since it always retrieves
1352
1352
1353 Examples
1353 Examples
1354 --------
1354 --------
1355 ::
1355 ::
1356
1356
1357 In [10]: r = client.apply()
1357 In [10]: r = client.apply()
1358
1358
1359 Parameters
1359 Parameters
1360 ----------
1360 ----------
1361
1361
1362 indices_or_msg_ids : integer history index, str msg_id, or list of either
1362 indices_or_msg_ids : integer history index, str msg_id, or list of either
1363 The indices or msg_ids of indices to be retrieved
1363 The indices or msg_ids of indices to be retrieved
1364
1364
1365 block : bool
1365 block : bool
1366 Whether to wait for the result to be done
1366 Whether to wait for the result to be done
1367
1367
1368 Returns
1368 Returns
1369 -------
1369 -------
1370
1370
1371 AsyncResult
1371 AsyncResult
1372 A single AsyncResult object will always be returned.
1372 A single AsyncResult object will always be returned.
1373
1373
1374 AsyncHubResult
1374 AsyncHubResult
1375 A subclass of AsyncResult that retrieves results from the Hub
1375 A subclass of AsyncResult that retrieves results from the Hub
1376
1376
1377 """
1377 """
1378 block = self.block if block is None else block
1378 block = self.block if block is None else block
1379 if indices_or_msg_ids is None:
1379 if indices_or_msg_ids is None:
1380 indices_or_msg_ids = -1
1380 indices_or_msg_ids = -1
1381
1381
1382 single_result = False
1382 if not isinstance(indices_or_msg_ids, (list,tuple)):
1383 if not isinstance(indices_or_msg_ids, (list,tuple)):
1383 indices_or_msg_ids = [indices_or_msg_ids]
1384 indices_or_msg_ids = [indices_or_msg_ids]
1385 single_result = True
1384
1386
1385 theids = []
1387 theids = []
1386 for id in indices_or_msg_ids:
1388 for id in indices_or_msg_ids:
1387 if isinstance(id, int):
1389 if isinstance(id, int):
1388 id = self.history[id]
1390 id = self.history[id]
1389 if not isinstance(id, basestring):
1391 if not isinstance(id, basestring):
1390 raise TypeError("indices must be str or int, not %r"%id)
1392 raise TypeError("indices must be str or int, not %r"%id)
1391 theids.append(id)
1393 theids.append(id)
1392
1394
1393 local_ids = filter(lambda msg_id: msg_id in self.outstanding or msg_id in self.results, theids)
1395 local_ids = filter(lambda msg_id: msg_id in self.outstanding or msg_id in self.results, theids)
1394 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1396 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1397
1398 # given single msg_id initially, get_result shot get the result itself,
1399 # not a length-one list
1400 if single_result:
1401 theids = theids[0]
1395
1402
1396 if remote_ids:
1403 if remote_ids:
1397 ar = AsyncHubResult(self, msg_ids=theids)
1404 ar = AsyncHubResult(self, msg_ids=theids)
1398 else:
1405 else:
1399 ar = AsyncResult(self, msg_ids=theids)
1406 ar = AsyncResult(self, msg_ids=theids)
1400
1407
1401 if block:
1408 if block:
1402 ar.wait()
1409 ar.wait()
1403
1410
1404 return ar
1411 return ar
1405
1412
1406 @spin_first
1413 @spin_first
1407 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1414 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1408 """Resubmit one or more tasks.
1415 """Resubmit one or more tasks.
1409
1416
1410 in-flight tasks may not be resubmitted.
1417 in-flight tasks may not be resubmitted.
1411
1418
1412 Parameters
1419 Parameters
1413 ----------
1420 ----------
1414
1421
1415 indices_or_msg_ids : integer history index, str msg_id, or list of either
1422 indices_or_msg_ids : integer history index, str msg_id, or list of either
1416 The indices or msg_ids of indices to be retrieved
1423 The indices or msg_ids of indices to be retrieved
1417
1424
1418 block : bool
1425 block : bool
1419 Whether to wait for the result to be done
1426 Whether to wait for the result to be done
1420
1427
1421 Returns
1428 Returns
1422 -------
1429 -------
1423
1430
1424 AsyncHubResult
1431 AsyncHubResult
1425 A subclass of AsyncResult that retrieves results from the Hub
1432 A subclass of AsyncResult that retrieves results from the Hub
1426
1433
1427 """
1434 """
1428 block = self.block if block is None else block
1435 block = self.block if block is None else block
1429 if indices_or_msg_ids is None:
1436 if indices_or_msg_ids is None:
1430 indices_or_msg_ids = -1
1437 indices_or_msg_ids = -1
1431
1438
1432 if not isinstance(indices_or_msg_ids, (list,tuple)):
1439 if not isinstance(indices_or_msg_ids, (list,tuple)):
1433 indices_or_msg_ids = [indices_or_msg_ids]
1440 indices_or_msg_ids = [indices_or_msg_ids]
1434
1441
1435 theids = []
1442 theids = []
1436 for id in indices_or_msg_ids:
1443 for id in indices_or_msg_ids:
1437 if isinstance(id, int):
1444 if isinstance(id, int):
1438 id = self.history[id]
1445 id = self.history[id]
1439 if not isinstance(id, basestring):
1446 if not isinstance(id, basestring):
1440 raise TypeError("indices must be str or int, not %r"%id)
1447 raise TypeError("indices must be str or int, not %r"%id)
1441 theids.append(id)
1448 theids.append(id)
1442
1449
1443 content = dict(msg_ids = theids)
1450 content = dict(msg_ids = theids)
1444
1451
1445 self.session.send(self._query_socket, 'resubmit_request', content)
1452 self.session.send(self._query_socket, 'resubmit_request', content)
1446
1453
1447 zmq.select([self._query_socket], [], [])
1454 zmq.select([self._query_socket], [], [])
1448 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1455 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1449 if self.debug:
1456 if self.debug:
1450 pprint(msg)
1457 pprint(msg)
1451 content = msg['content']
1458 content = msg['content']
1452 if content['status'] != 'ok':
1459 if content['status'] != 'ok':
1453 raise self._unwrap_exception(content)
1460 raise self._unwrap_exception(content)
1454 mapping = content['resubmitted']
1461 mapping = content['resubmitted']
1455 new_ids = [ mapping[msg_id] for msg_id in theids ]
1462 new_ids = [ mapping[msg_id] for msg_id in theids ]
1456
1463
1457 ar = AsyncHubResult(self, msg_ids=new_ids)
1464 ar = AsyncHubResult(self, msg_ids=new_ids)
1458
1465
1459 if block:
1466 if block:
1460 ar.wait()
1467 ar.wait()
1461
1468
1462 return ar
1469 return ar
1463
1470
1464 @spin_first
1471 @spin_first
1465 def result_status(self, msg_ids, status_only=True):
1472 def result_status(self, msg_ids, status_only=True):
1466 """Check on the status of the result(s) of the apply request with `msg_ids`.
1473 """Check on the status of the result(s) of the apply request with `msg_ids`.
1467
1474
1468 If status_only is False, then the actual results will be retrieved, else
1475 If status_only is False, then the actual results will be retrieved, else
1469 only the status of the results will be checked.
1476 only the status of the results will be checked.
1470
1477
1471 Parameters
1478 Parameters
1472 ----------
1479 ----------
1473
1480
1474 msg_ids : list of msg_ids
1481 msg_ids : list of msg_ids
1475 if int:
1482 if int:
1476 Passed as index to self.history for convenience.
1483 Passed as index to self.history for convenience.
1477 status_only : bool (default: True)
1484 status_only : bool (default: True)
1478 if False:
1485 if False:
1479 Retrieve the actual results of completed tasks.
1486 Retrieve the actual results of completed tasks.
1480
1487
1481 Returns
1488 Returns
1482 -------
1489 -------
1483
1490
1484 results : dict
1491 results : dict
1485 There will always be the keys 'pending' and 'completed', which will
1492 There will always be the keys 'pending' and 'completed', which will
1486 be lists of msg_ids that are incomplete or complete. If `status_only`
1493 be lists of msg_ids that are incomplete or complete. If `status_only`
1487 is False, then completed results will be keyed by their `msg_id`.
1494 is False, then completed results will be keyed by their `msg_id`.
1488 """
1495 """
1489 if not isinstance(msg_ids, (list,tuple)):
1496 if not isinstance(msg_ids, (list,tuple)):
1490 msg_ids = [msg_ids]
1497 msg_ids = [msg_ids]
1491
1498
1492 theids = []
1499 theids = []
1493 for msg_id in msg_ids:
1500 for msg_id in msg_ids:
1494 if isinstance(msg_id, int):
1501 if isinstance(msg_id, int):
1495 msg_id = self.history[msg_id]
1502 msg_id = self.history[msg_id]
1496 if not isinstance(msg_id, basestring):
1503 if not isinstance(msg_id, basestring):
1497 raise TypeError("msg_ids must be str, not %r"%msg_id)
1504 raise TypeError("msg_ids must be str, not %r"%msg_id)
1498 theids.append(msg_id)
1505 theids.append(msg_id)
1499
1506
1500 completed = []
1507 completed = []
1501 local_results = {}
1508 local_results = {}
1502
1509
1503 # comment this block out to temporarily disable local shortcut:
1510 # comment this block out to temporarily disable local shortcut:
1504 for msg_id in theids:
1511 for msg_id in theids:
1505 if msg_id in self.results:
1512 if msg_id in self.results:
1506 completed.append(msg_id)
1513 completed.append(msg_id)
1507 local_results[msg_id] = self.results[msg_id]
1514 local_results[msg_id] = self.results[msg_id]
1508 theids.remove(msg_id)
1515 theids.remove(msg_id)
1509
1516
1510 if theids: # some not locally cached
1517 if theids: # some not locally cached
1511 content = dict(msg_ids=theids, status_only=status_only)
1518 content = dict(msg_ids=theids, status_only=status_only)
1512 msg = self.session.send(self._query_socket, "result_request", content=content)
1519 msg = self.session.send(self._query_socket, "result_request", content=content)
1513 zmq.select([self._query_socket], [], [])
1520 zmq.select([self._query_socket], [], [])
1514 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1521 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1515 if self.debug:
1522 if self.debug:
1516 pprint(msg)
1523 pprint(msg)
1517 content = msg['content']
1524 content = msg['content']
1518 if content['status'] != 'ok':
1525 if content['status'] != 'ok':
1519 raise self._unwrap_exception(content)
1526 raise self._unwrap_exception(content)
1520 buffers = msg['buffers']
1527 buffers = msg['buffers']
1521 else:
1528 else:
1522 content = dict(completed=[],pending=[])
1529 content = dict(completed=[],pending=[])
1523
1530
1524 content['completed'].extend(completed)
1531 content['completed'].extend(completed)
1525
1532
1526 if status_only:
1533 if status_only:
1527 return content
1534 return content
1528
1535
1529 failures = []
1536 failures = []
1530 # load cached results into result:
1537 # load cached results into result:
1531 content.update(local_results)
1538 content.update(local_results)
1532
1539
1533 # update cache with results:
1540 # update cache with results:
1534 for msg_id in sorted(theids):
1541 for msg_id in sorted(theids):
1535 if msg_id in content['completed']:
1542 if msg_id in content['completed']:
1536 rec = content[msg_id]
1543 rec = content[msg_id]
1537 parent = rec['header']
1544 parent = rec['header']
1538 header = rec['result_header']
1545 header = rec['result_header']
1539 rcontent = rec['result_content']
1546 rcontent = rec['result_content']
1540 iodict = rec['io']
1547 iodict = rec['io']
1541 if isinstance(rcontent, str):
1548 if isinstance(rcontent, str):
1542 rcontent = self.session.unpack(rcontent)
1549 rcontent = self.session.unpack(rcontent)
1543
1550
1544 md = self.metadata[msg_id]
1551 md = self.metadata[msg_id]
1545 md_msg = dict(
1552 md_msg = dict(
1546 content=rcontent,
1553 content=rcontent,
1547 parent_header=parent,
1554 parent_header=parent,
1548 header=header,
1555 header=header,
1549 metadata=rec['result_metadata'],
1556 metadata=rec['result_metadata'],
1550 )
1557 )
1551 md.update(self._extract_metadata(md_msg))
1558 md.update(self._extract_metadata(md_msg))
1552 if rec.get('received'):
1559 if rec.get('received'):
1553 md['received'] = rec['received']
1560 md['received'] = rec['received']
1554 md.update(iodict)
1561 md.update(iodict)
1555
1562
1556 if rcontent['status'] == 'ok':
1563 if rcontent['status'] == 'ok':
1557 if header['msg_type'] == 'apply_reply':
1564 if header['msg_type'] == 'apply_reply':
1558 res,buffers = serialize.unserialize_object(buffers)
1565 res,buffers = serialize.unserialize_object(buffers)
1559 elif header['msg_type'] == 'execute_reply':
1566 elif header['msg_type'] == 'execute_reply':
1560 res = ExecuteReply(msg_id, rcontent, md)
1567 res = ExecuteReply(msg_id, rcontent, md)
1561 else:
1568 else:
1562 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1569 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1563 else:
1570 else:
1564 res = self._unwrap_exception(rcontent)
1571 res = self._unwrap_exception(rcontent)
1565 failures.append(res)
1572 failures.append(res)
1566
1573
1567 self.results[msg_id] = res
1574 self.results[msg_id] = res
1568 content[msg_id] = res
1575 content[msg_id] = res
1569
1576
1570 if len(theids) == 1 and failures:
1577 if len(theids) == 1 and failures:
1571 raise failures[0]
1578 raise failures[0]
1572
1579
1573 error.collect_exceptions(failures, "result_status")
1580 error.collect_exceptions(failures, "result_status")
1574 return content
1581 return content
1575
1582
1576 @spin_first
1583 @spin_first
1577 def queue_status(self, targets='all', verbose=False):
1584 def queue_status(self, targets='all', verbose=False):
1578 """Fetch the status of engine queues.
1585 """Fetch the status of engine queues.
1579
1586
1580 Parameters
1587 Parameters
1581 ----------
1588 ----------
1582
1589
1583 targets : int/str/list of ints/strs
1590 targets : int/str/list of ints/strs
1584 the engines whose states are to be queried.
1591 the engines whose states are to be queried.
1585 default : all
1592 default : all
1586 verbose : bool
1593 verbose : bool
1587 Whether to return lengths only, or lists of ids for each element
1594 Whether to return lengths only, or lists of ids for each element
1588 """
1595 """
1589 if targets == 'all':
1596 if targets == 'all':
1590 # allow 'all' to be evaluated on the engine
1597 # allow 'all' to be evaluated on the engine
1591 engine_ids = None
1598 engine_ids = None
1592 else:
1599 else:
1593 engine_ids = self._build_targets(targets)[1]
1600 engine_ids = self._build_targets(targets)[1]
1594 content = dict(targets=engine_ids, verbose=verbose)
1601 content = dict(targets=engine_ids, verbose=verbose)
1595 self.session.send(self._query_socket, "queue_request", content=content)
1602 self.session.send(self._query_socket, "queue_request", content=content)
1596 idents,msg = self.session.recv(self._query_socket, 0)
1603 idents,msg = self.session.recv(self._query_socket, 0)
1597 if self.debug:
1604 if self.debug:
1598 pprint(msg)
1605 pprint(msg)
1599 content = msg['content']
1606 content = msg['content']
1600 status = content.pop('status')
1607 status = content.pop('status')
1601 if status != 'ok':
1608 if status != 'ok':
1602 raise self._unwrap_exception(content)
1609 raise self._unwrap_exception(content)
1603 content = rekey(content)
1610 content = rekey(content)
1604 if isinstance(targets, int):
1611 if isinstance(targets, int):
1605 return content[targets]
1612 return content[targets]
1606 else:
1613 else:
1607 return content
1614 return content
1608
1615
1609 def _build_msgids_from_target(self, targets=None):
1616 def _build_msgids_from_target(self, targets=None):
1610 """Build a list of msg_ids from the list of engine targets"""
1617 """Build a list of msg_ids from the list of engine targets"""
1611 if not targets: # needed as _build_targets otherwise uses all engines
1618 if not targets: # needed as _build_targets otherwise uses all engines
1612 return []
1619 return []
1613 target_ids = self._build_targets(targets)[0]
1620 target_ids = self._build_targets(targets)[0]
1614 return filter(lambda md_id: self.metadata[md_id]["engine_uuid"] in target_ids, self.metadata)
1621 return filter(lambda md_id: self.metadata[md_id]["engine_uuid"] in target_ids, self.metadata)
1615
1622
1616 def _build_msgids_from_jobs(self, jobs=None):
1623 def _build_msgids_from_jobs(self, jobs=None):
1617 """Build a list of msg_ids from "jobs" """
1624 """Build a list of msg_ids from "jobs" """
1618 if not jobs:
1625 if not jobs:
1619 return []
1626 return []
1620 msg_ids = []
1627 msg_ids = []
1621 if isinstance(jobs, (basestring,AsyncResult)):
1628 if isinstance(jobs, (basestring,AsyncResult)):
1622 jobs = [jobs]
1629 jobs = [jobs]
1623 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1630 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1624 if bad_ids:
1631 if bad_ids:
1625 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1632 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1626 for j in jobs:
1633 for j in jobs:
1627 if isinstance(j, AsyncResult):
1634 if isinstance(j, AsyncResult):
1628 msg_ids.extend(j.msg_ids)
1635 msg_ids.extend(j.msg_ids)
1629 else:
1636 else:
1630 msg_ids.append(j)
1637 msg_ids.append(j)
1631 return msg_ids
1638 return msg_ids
1632
1639
1633 def purge_local_results(self, jobs=[], targets=[]):
1640 def purge_local_results(self, jobs=[], targets=[]):
1634 """Clears the client caches of results and frees such memory.
1641 """Clears the client caches of results and frees such memory.
1635
1642
1636 Individual results can be purged by msg_id, or the entire
1643 Individual results can be purged by msg_id, or the entire
1637 history of specific targets can be purged.
1644 history of specific targets can be purged.
1638
1645
1639 Use `purge_local_results('all')` to scrub everything from the Client's db.
1646 Use `purge_local_results('all')` to scrub everything from the Client's db.
1640
1647
1641 The client must have no outstanding tasks before purging the caches.
1648 The client must have no outstanding tasks before purging the caches.
1642 Raises `AssertionError` if there are still outstanding tasks.
1649 Raises `AssertionError` if there are still outstanding tasks.
1643
1650
1644 After this call all `AsyncResults` are invalid and should be discarded.
1651 After this call all `AsyncResults` are invalid and should be discarded.
1645
1652
1646 If you must "reget" the results, you can still do so by using
1653 If you must "reget" the results, you can still do so by using
1647 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1654 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1648 redownload the results from the hub if they are still available
1655 redownload the results from the hub if they are still available
1649 (i.e. `client.purge_hub_results(...)` has not been called).
1656 (i.e. `client.purge_hub_results(...)` has not been called).
1650
1657
1651 Parameters
1658 Parameters
1652 ----------
1659 ----------
1653
1660
1654 jobs : str or list of str or AsyncResult objects
1661 jobs : str or list of str or AsyncResult objects
1655 the msg_ids whose results should be purged.
1662 the msg_ids whose results should be purged.
1656 targets : int/str/list of ints/strs
1663 targets : int/str/list of ints/strs
1657 The targets, by int_id, whose entire results are to be purged.
1664 The targets, by int_id, whose entire results are to be purged.
1658
1665
1659 default : None
1666 default : None
1660 """
1667 """
1661 assert not self.outstanding, "Can't purge a client with outstanding tasks!"
1668 assert not self.outstanding, "Can't purge a client with outstanding tasks!"
1662
1669
1663 if not targets and not jobs:
1670 if not targets and not jobs:
1664 raise ValueError("Must specify at least one of `targets` and `jobs`")
1671 raise ValueError("Must specify at least one of `targets` and `jobs`")
1665
1672
1666 if jobs == 'all':
1673 if jobs == 'all':
1667 self.results.clear()
1674 self.results.clear()
1668 self.metadata.clear()
1675 self.metadata.clear()
1669 return
1676 return
1670 else:
1677 else:
1671 msg_ids = []
1678 msg_ids = []
1672 msg_ids.extend(self._build_msgids_from_target(targets))
1679 msg_ids.extend(self._build_msgids_from_target(targets))
1673 msg_ids.extend(self._build_msgids_from_jobs(jobs))
1680 msg_ids.extend(self._build_msgids_from_jobs(jobs))
1674 map(self.results.pop, msg_ids)
1681 map(self.results.pop, msg_ids)
1675 map(self.metadata.pop, msg_ids)
1682 map(self.metadata.pop, msg_ids)
1676
1683
1677
1684
1678 @spin_first
1685 @spin_first
1679 def purge_hub_results(self, jobs=[], targets=[]):
1686 def purge_hub_results(self, jobs=[], targets=[]):
1680 """Tell the Hub to forget results.
1687 """Tell the Hub to forget results.
1681
1688
1682 Individual results can be purged by msg_id, or the entire
1689 Individual results can be purged by msg_id, or the entire
1683 history of specific targets can be purged.
1690 history of specific targets can be purged.
1684
1691
1685 Use `purge_results('all')` to scrub everything from the Hub's db.
1692 Use `purge_results('all')` to scrub everything from the Hub's db.
1686
1693
1687 Parameters
1694 Parameters
1688 ----------
1695 ----------
1689
1696
1690 jobs : str or list of str or AsyncResult objects
1697 jobs : str or list of str or AsyncResult objects
1691 the msg_ids whose results should be forgotten.
1698 the msg_ids whose results should be forgotten.
1692 targets : int/str/list of ints/strs
1699 targets : int/str/list of ints/strs
1693 The targets, by int_id, whose entire history is to be purged.
1700 The targets, by int_id, whose entire history is to be purged.
1694
1701
1695 default : None
1702 default : None
1696 """
1703 """
1697 if not targets and not jobs:
1704 if not targets and not jobs:
1698 raise ValueError("Must specify at least one of `targets` and `jobs`")
1705 raise ValueError("Must specify at least one of `targets` and `jobs`")
1699 if targets:
1706 if targets:
1700 targets = self._build_targets(targets)[1]
1707 targets = self._build_targets(targets)[1]
1701
1708
1702 # construct msg_ids from jobs
1709 # construct msg_ids from jobs
1703 if jobs == 'all':
1710 if jobs == 'all':
1704 msg_ids = jobs
1711 msg_ids = jobs
1705 else:
1712 else:
1706 msg_ids = self._build_msgids_from_jobs(jobs)
1713 msg_ids = self._build_msgids_from_jobs(jobs)
1707
1714
1708 content = dict(engine_ids=targets, msg_ids=msg_ids)
1715 content = dict(engine_ids=targets, msg_ids=msg_ids)
1709 self.session.send(self._query_socket, "purge_request", content=content)
1716 self.session.send(self._query_socket, "purge_request", content=content)
1710 idents, msg = self.session.recv(self._query_socket, 0)
1717 idents, msg = self.session.recv(self._query_socket, 0)
1711 if self.debug:
1718 if self.debug:
1712 pprint(msg)
1719 pprint(msg)
1713 content = msg['content']
1720 content = msg['content']
1714 if content['status'] != 'ok':
1721 if content['status'] != 'ok':
1715 raise self._unwrap_exception(content)
1722 raise self._unwrap_exception(content)
1716
1723
1717 def purge_results(self, jobs=[], targets=[]):
1724 def purge_results(self, jobs=[], targets=[]):
1718 """Clears the cached results from both the hub and the local client
1725 """Clears the cached results from both the hub and the local client
1719
1726
1720 Individual results can be purged by msg_id, or the entire
1727 Individual results can be purged by msg_id, or the entire
1721 history of specific targets can be purged.
1728 history of specific targets can be purged.
1722
1729
1723 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1730 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1724 the Client's db.
1731 the Client's db.
1725
1732
1726 Equivalent to calling both `purge_hub_results()` and `purge_local_results()` with
1733 Equivalent to calling both `purge_hub_results()` and `purge_local_results()` with
1727 the same arguments.
1734 the same arguments.
1728
1735
1729 Parameters
1736 Parameters
1730 ----------
1737 ----------
1731
1738
1732 jobs : str or list of str or AsyncResult objects
1739 jobs : str or list of str or AsyncResult objects
1733 the msg_ids whose results should be forgotten.
1740 the msg_ids whose results should be forgotten.
1734 targets : int/str/list of ints/strs
1741 targets : int/str/list of ints/strs
1735 The targets, by int_id, whose entire history is to be purged.
1742 The targets, by int_id, whose entire history is to be purged.
1736
1743
1737 default : None
1744 default : None
1738 """
1745 """
1739 self.purge_local_results(jobs=jobs, targets=targets)
1746 self.purge_local_results(jobs=jobs, targets=targets)
1740 self.purge_hub_results(jobs=jobs, targets=targets)
1747 self.purge_hub_results(jobs=jobs, targets=targets)
1741
1748
1742 def purge_everything(self):
1749 def purge_everything(self):
1743 """Clears all content from previous Tasks from both the hub and the local client
1750 """Clears all content from previous Tasks from both the hub and the local client
1744
1751
1745 In addition to calling `purge_results("all")` it also deletes the history and
1752 In addition to calling `purge_results("all")` it also deletes the history and
1746 other bookkeeping lists.
1753 other bookkeeping lists.
1747 """
1754 """
1748 self.purge_results("all")
1755 self.purge_results("all")
1749 self.history = []
1756 self.history = []
1750 self.session.digest_history.clear()
1757 self.session.digest_history.clear()
1751
1758
1752 @spin_first
1759 @spin_first
1753 def hub_history(self):
1760 def hub_history(self):
1754 """Get the Hub's history
1761 """Get the Hub's history
1755
1762
1756 Just like the Client, the Hub has a history, which is a list of msg_ids.
1763 Just like the Client, the Hub has a history, which is a list of msg_ids.
1757 This will contain the history of all clients, and, depending on configuration,
1764 This will contain the history of all clients, and, depending on configuration,
1758 may contain history across multiple cluster sessions.
1765 may contain history across multiple cluster sessions.
1759
1766
1760 Any msg_id returned here is a valid argument to `get_result`.
1767 Any msg_id returned here is a valid argument to `get_result`.
1761
1768
1762 Returns
1769 Returns
1763 -------
1770 -------
1764
1771
1765 msg_ids : list of strs
1772 msg_ids : list of strs
1766 list of all msg_ids, ordered by task submission time.
1773 list of all msg_ids, ordered by task submission time.
1767 """
1774 """
1768
1775
1769 self.session.send(self._query_socket, "history_request", content={})
1776 self.session.send(self._query_socket, "history_request", content={})
1770 idents, msg = self.session.recv(self._query_socket, 0)
1777 idents, msg = self.session.recv(self._query_socket, 0)
1771
1778
1772 if self.debug:
1779 if self.debug:
1773 pprint(msg)
1780 pprint(msg)
1774 content = msg['content']
1781 content = msg['content']
1775 if content['status'] != 'ok':
1782 if content['status'] != 'ok':
1776 raise self._unwrap_exception(content)
1783 raise self._unwrap_exception(content)
1777 else:
1784 else:
1778 return content['history']
1785 return content['history']
1779
1786
1780 @spin_first
1787 @spin_first
1781 def db_query(self, query, keys=None):
1788 def db_query(self, query, keys=None):
1782 """Query the Hub's TaskRecord database
1789 """Query the Hub's TaskRecord database
1783
1790
1784 This will return a list of task record dicts that match `query`
1791 This will return a list of task record dicts that match `query`
1785
1792
1786 Parameters
1793 Parameters
1787 ----------
1794 ----------
1788
1795
1789 query : mongodb query dict
1796 query : mongodb query dict
1790 The search dict. See mongodb query docs for details.
1797 The search dict. See mongodb query docs for details.
1791 keys : list of strs [optional]
1798 keys : list of strs [optional]
1792 The subset of keys to be returned. The default is to fetch everything but buffers.
1799 The subset of keys to be returned. The default is to fetch everything but buffers.
1793 'msg_id' will *always* be included.
1800 'msg_id' will *always* be included.
1794 """
1801 """
1795 if isinstance(keys, basestring):
1802 if isinstance(keys, basestring):
1796 keys = [keys]
1803 keys = [keys]
1797 content = dict(query=query, keys=keys)
1804 content = dict(query=query, keys=keys)
1798 self.session.send(self._query_socket, "db_request", content=content)
1805 self.session.send(self._query_socket, "db_request", content=content)
1799 idents, msg = self.session.recv(self._query_socket, 0)
1806 idents, msg = self.session.recv(self._query_socket, 0)
1800 if self.debug:
1807 if self.debug:
1801 pprint(msg)
1808 pprint(msg)
1802 content = msg['content']
1809 content = msg['content']
1803 if content['status'] != 'ok':
1810 if content['status'] != 'ok':
1804 raise self._unwrap_exception(content)
1811 raise self._unwrap_exception(content)
1805
1812
1806 records = content['records']
1813 records = content['records']
1807
1814
1808 buffer_lens = content['buffer_lens']
1815 buffer_lens = content['buffer_lens']
1809 result_buffer_lens = content['result_buffer_lens']
1816 result_buffer_lens = content['result_buffer_lens']
1810 buffers = msg['buffers']
1817 buffers = msg['buffers']
1811 has_bufs = buffer_lens is not None
1818 has_bufs = buffer_lens is not None
1812 has_rbufs = result_buffer_lens is not None
1819 has_rbufs = result_buffer_lens is not None
1813 for i,rec in enumerate(records):
1820 for i,rec in enumerate(records):
1814 # relink buffers
1821 # relink buffers
1815 if has_bufs:
1822 if has_bufs:
1816 blen = buffer_lens[i]
1823 blen = buffer_lens[i]
1817 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1824 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1818 if has_rbufs:
1825 if has_rbufs:
1819 blen = result_buffer_lens[i]
1826 blen = result_buffer_lens[i]
1820 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1827 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1821
1828
1822 return records
1829 return records
1823
1830
1824 __all__ = [ 'Client' ]
1831 __all__ = [ 'Client' ]
@@ -1,517 +1,518 b''
1 """Tests for parallel client.py
1 """Tests for parallel client.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import time
21 import time
22 from datetime import datetime
22 from datetime import datetime
23 from tempfile import mktemp
23 from tempfile import mktemp
24
24
25 import zmq
25 import zmq
26
26
27 from IPython import parallel
27 from IPython import parallel
28 from IPython.parallel.client import client as clientmod
28 from IPython.parallel.client import client as clientmod
29 from IPython.parallel import error
29 from IPython.parallel import error
30 from IPython.parallel import AsyncResult, AsyncHubResult
30 from IPython.parallel import AsyncResult, AsyncHubResult
31 from IPython.parallel import LoadBalancedView, DirectView
31 from IPython.parallel import LoadBalancedView, DirectView
32
32
33 from clienttest import ClusterTestCase, segfault, wait, add_engines
33 from clienttest import ClusterTestCase, segfault, wait, add_engines
34
34
35 def setup():
35 def setup():
36 add_engines(4, total=True)
36 add_engines(4, total=True)
37
37
38 class TestClient(ClusterTestCase):
38 class TestClient(ClusterTestCase):
39
39
40 def test_ids(self):
40 def test_ids(self):
41 n = len(self.client.ids)
41 n = len(self.client.ids)
42 self.add_engines(2)
42 self.add_engines(2)
43 self.assertEqual(len(self.client.ids), n+2)
43 self.assertEqual(len(self.client.ids), n+2)
44
44
45 def test_view_indexing(self):
45 def test_view_indexing(self):
46 """test index access for views"""
46 """test index access for views"""
47 self.minimum_engines(4)
47 self.minimum_engines(4)
48 targets = self.client._build_targets('all')[-1]
48 targets = self.client._build_targets('all')[-1]
49 v = self.client[:]
49 v = self.client[:]
50 self.assertEqual(v.targets, targets)
50 self.assertEqual(v.targets, targets)
51 t = self.client.ids[2]
51 t = self.client.ids[2]
52 v = self.client[t]
52 v = self.client[t]
53 self.assertTrue(isinstance(v, DirectView))
53 self.assertTrue(isinstance(v, DirectView))
54 self.assertEqual(v.targets, t)
54 self.assertEqual(v.targets, t)
55 t = self.client.ids[2:4]
55 t = self.client.ids[2:4]
56 v = self.client[t]
56 v = self.client[t]
57 self.assertTrue(isinstance(v, DirectView))
57 self.assertTrue(isinstance(v, DirectView))
58 self.assertEqual(v.targets, t)
58 self.assertEqual(v.targets, t)
59 v = self.client[::2]
59 v = self.client[::2]
60 self.assertTrue(isinstance(v, DirectView))
60 self.assertTrue(isinstance(v, DirectView))
61 self.assertEqual(v.targets, targets[::2])
61 self.assertEqual(v.targets, targets[::2])
62 v = self.client[1::3]
62 v = self.client[1::3]
63 self.assertTrue(isinstance(v, DirectView))
63 self.assertTrue(isinstance(v, DirectView))
64 self.assertEqual(v.targets, targets[1::3])
64 self.assertEqual(v.targets, targets[1::3])
65 v = self.client[:-3]
65 v = self.client[:-3]
66 self.assertTrue(isinstance(v, DirectView))
66 self.assertTrue(isinstance(v, DirectView))
67 self.assertEqual(v.targets, targets[:-3])
67 self.assertEqual(v.targets, targets[:-3])
68 v = self.client[-1]
68 v = self.client[-1]
69 self.assertTrue(isinstance(v, DirectView))
69 self.assertTrue(isinstance(v, DirectView))
70 self.assertEqual(v.targets, targets[-1])
70 self.assertEqual(v.targets, targets[-1])
71 self.assertRaises(TypeError, lambda : self.client[None])
71 self.assertRaises(TypeError, lambda : self.client[None])
72
72
73 def test_lbview_targets(self):
73 def test_lbview_targets(self):
74 """test load_balanced_view targets"""
74 """test load_balanced_view targets"""
75 v = self.client.load_balanced_view()
75 v = self.client.load_balanced_view()
76 self.assertEqual(v.targets, None)
76 self.assertEqual(v.targets, None)
77 v = self.client.load_balanced_view(-1)
77 v = self.client.load_balanced_view(-1)
78 self.assertEqual(v.targets, [self.client.ids[-1]])
78 self.assertEqual(v.targets, [self.client.ids[-1]])
79 v = self.client.load_balanced_view('all')
79 v = self.client.load_balanced_view('all')
80 self.assertEqual(v.targets, None)
80 self.assertEqual(v.targets, None)
81
81
82 def test_dview_targets(self):
82 def test_dview_targets(self):
83 """test direct_view targets"""
83 """test direct_view targets"""
84 v = self.client.direct_view()
84 v = self.client.direct_view()
85 self.assertEqual(v.targets, 'all')
85 self.assertEqual(v.targets, 'all')
86 v = self.client.direct_view('all')
86 v = self.client.direct_view('all')
87 self.assertEqual(v.targets, 'all')
87 self.assertEqual(v.targets, 'all')
88 v = self.client.direct_view(-1)
88 v = self.client.direct_view(-1)
89 self.assertEqual(v.targets, self.client.ids[-1])
89 self.assertEqual(v.targets, self.client.ids[-1])
90
90
91 def test_lazy_all_targets(self):
91 def test_lazy_all_targets(self):
92 """test lazy evaluation of rc.direct_view('all')"""
92 """test lazy evaluation of rc.direct_view('all')"""
93 v = self.client.direct_view()
93 v = self.client.direct_view()
94 self.assertEqual(v.targets, 'all')
94 self.assertEqual(v.targets, 'all')
95
95
96 def double(x):
96 def double(x):
97 return x*2
97 return x*2
98 seq = range(100)
98 seq = range(100)
99 ref = [ double(x) for x in seq ]
99 ref = [ double(x) for x in seq ]
100
100
101 # add some engines, which should be used
101 # add some engines, which should be used
102 self.add_engines(1)
102 self.add_engines(1)
103 n1 = len(self.client.ids)
103 n1 = len(self.client.ids)
104
104
105 # simple apply
105 # simple apply
106 r = v.apply_sync(lambda : 1)
106 r = v.apply_sync(lambda : 1)
107 self.assertEqual(r, [1] * n1)
107 self.assertEqual(r, [1] * n1)
108
108
109 # map goes through remotefunction
109 # map goes through remotefunction
110 r = v.map_sync(double, seq)
110 r = v.map_sync(double, seq)
111 self.assertEqual(r, ref)
111 self.assertEqual(r, ref)
112
112
113 # add a couple more engines, and try again
113 # add a couple more engines, and try again
114 self.add_engines(2)
114 self.add_engines(2)
115 n2 = len(self.client.ids)
115 n2 = len(self.client.ids)
116 self.assertNotEqual(n2, n1)
116 self.assertNotEqual(n2, n1)
117
117
118 # apply
118 # apply
119 r = v.apply_sync(lambda : 1)
119 r = v.apply_sync(lambda : 1)
120 self.assertEqual(r, [1] * n2)
120 self.assertEqual(r, [1] * n2)
121
121
122 # map
122 # map
123 r = v.map_sync(double, seq)
123 r = v.map_sync(double, seq)
124 self.assertEqual(r, ref)
124 self.assertEqual(r, ref)
125
125
126 def test_targets(self):
126 def test_targets(self):
127 """test various valid targets arguments"""
127 """test various valid targets arguments"""
128 build = self.client._build_targets
128 build = self.client._build_targets
129 ids = self.client.ids
129 ids = self.client.ids
130 idents,targets = build(None)
130 idents,targets = build(None)
131 self.assertEqual(ids, targets)
131 self.assertEqual(ids, targets)
132
132
133 def test_clear(self):
133 def test_clear(self):
134 """test clear behavior"""
134 """test clear behavior"""
135 self.minimum_engines(2)
135 self.minimum_engines(2)
136 v = self.client[:]
136 v = self.client[:]
137 v.block=True
137 v.block=True
138 v.push(dict(a=5))
138 v.push(dict(a=5))
139 v.pull('a')
139 v.pull('a')
140 id0 = self.client.ids[-1]
140 id0 = self.client.ids[-1]
141 self.client.clear(targets=id0, block=True)
141 self.client.clear(targets=id0, block=True)
142 a = self.client[:-1].get('a')
142 a = self.client[:-1].get('a')
143 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
143 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
144 self.client.clear(block=True)
144 self.client.clear(block=True)
145 for i in self.client.ids:
145 for i in self.client.ids:
146 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
146 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
147
147
148 def test_get_result(self):
148 def test_get_result(self):
149 """test getting results from the Hub."""
149 """test getting results from the Hub."""
150 c = clientmod.Client(profile='iptest')
150 c = clientmod.Client(profile='iptest')
151 t = c.ids[-1]
151 t = c.ids[-1]
152 ar = c[t].apply_async(wait, 1)
152 ar = c[t].apply_async(wait, 1)
153 # give the monitor time to notice the message
153 # give the monitor time to notice the message
154 time.sleep(.25)
154 time.sleep(.25)
155 ahr = self.client.get_result(ar.msg_ids)
155 ahr = self.client.get_result(ar.msg_ids[0])
156 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 self.assertTrue(isinstance(ahr, AsyncHubResult))
157 self.assertEqual(ahr.get(), ar.get())
157 self.assertEqual(ahr.get(), ar.get())
158 ar2 = self.client.get_result(ar.msg_ids)
158 ar2 = self.client.get_result(ar.msg_ids[0])
159 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 self.assertFalse(isinstance(ar2, AsyncHubResult))
160 c.close()
160 c.close()
161
161
162 def test_get_execute_result(self):
162 def test_get_execute_result(self):
163 """test getting execute results from the Hub."""
163 """test getting execute results from the Hub."""
164 c = clientmod.Client(profile='iptest')
164 c = clientmod.Client(profile='iptest')
165 t = c.ids[-1]
165 t = c.ids[-1]
166 cell = '\n'.join([
166 cell = '\n'.join([
167 'import time',
167 'import time',
168 'time.sleep(0.25)',
168 'time.sleep(0.25)',
169 '5'
169 '5'
170 ])
170 ])
171 ar = c[t].execute("import time; time.sleep(1)", silent=False)
171 ar = c[t].execute("import time; time.sleep(1)", silent=False)
172 # give the monitor time to notice the message
172 # give the monitor time to notice the message
173 time.sleep(.25)
173 time.sleep(.25)
174 ahr = self.client.get_result(ar.msg_ids)
174 ahr = self.client.get_result(ar.msg_ids[0])
175 print ar.get(), ahr.get(), ar._single_result, ahr._single_result
175 self.assertTrue(isinstance(ahr, AsyncHubResult))
176 self.assertTrue(isinstance(ahr, AsyncHubResult))
176 self.assertEqual(ahr.get().pyout, ar.get().pyout)
177 self.assertEqual(ahr.get().pyout, ar.get().pyout)
177 ar2 = self.client.get_result(ar.msg_ids)
178 ar2 = self.client.get_result(ar.msg_ids[0])
178 self.assertFalse(isinstance(ar2, AsyncHubResult))
179 self.assertFalse(isinstance(ar2, AsyncHubResult))
179 c.close()
180 c.close()
180
181
181 def test_ids_list(self):
182 def test_ids_list(self):
182 """test client.ids"""
183 """test client.ids"""
183 ids = self.client.ids
184 ids = self.client.ids
184 self.assertEqual(ids, self.client._ids)
185 self.assertEqual(ids, self.client._ids)
185 self.assertFalse(ids is self.client._ids)
186 self.assertFalse(ids is self.client._ids)
186 ids.remove(ids[-1])
187 ids.remove(ids[-1])
187 self.assertNotEqual(ids, self.client._ids)
188 self.assertNotEqual(ids, self.client._ids)
188
189
189 def test_queue_status(self):
190 def test_queue_status(self):
190 ids = self.client.ids
191 ids = self.client.ids
191 id0 = ids[0]
192 id0 = ids[0]
192 qs = self.client.queue_status(targets=id0)
193 qs = self.client.queue_status(targets=id0)
193 self.assertTrue(isinstance(qs, dict))
194 self.assertTrue(isinstance(qs, dict))
194 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
195 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
195 allqs = self.client.queue_status()
196 allqs = self.client.queue_status()
196 self.assertTrue(isinstance(allqs, dict))
197 self.assertTrue(isinstance(allqs, dict))
197 intkeys = list(allqs.keys())
198 intkeys = list(allqs.keys())
198 intkeys.remove('unassigned')
199 intkeys.remove('unassigned')
199 self.assertEqual(sorted(intkeys), sorted(self.client.ids))
200 self.assertEqual(sorted(intkeys), sorted(self.client.ids))
200 unassigned = allqs.pop('unassigned')
201 unassigned = allqs.pop('unassigned')
201 for eid,qs in allqs.items():
202 for eid,qs in allqs.items():
202 self.assertTrue(isinstance(qs, dict))
203 self.assertTrue(isinstance(qs, dict))
203 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
204 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
204
205
205 def test_shutdown(self):
206 def test_shutdown(self):
206 ids = self.client.ids
207 ids = self.client.ids
207 id0 = ids[0]
208 id0 = ids[0]
208 self.client.shutdown(id0, block=True)
209 self.client.shutdown(id0, block=True)
209 while id0 in self.client.ids:
210 while id0 in self.client.ids:
210 time.sleep(0.1)
211 time.sleep(0.1)
211 self.client.spin()
212 self.client.spin()
212
213
213 self.assertRaises(IndexError, lambda : self.client[id0])
214 self.assertRaises(IndexError, lambda : self.client[id0])
214
215
215 def test_result_status(self):
216 def test_result_status(self):
216 pass
217 pass
217 # to be written
218 # to be written
218
219
219 def test_db_query_dt(self):
220 def test_db_query_dt(self):
220 """test db query by date"""
221 """test db query by date"""
221 hist = self.client.hub_history()
222 hist = self.client.hub_history()
222 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
223 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
223 tic = middle['submitted']
224 tic = middle['submitted']
224 before = self.client.db_query({'submitted' : {'$lt' : tic}})
225 before = self.client.db_query({'submitted' : {'$lt' : tic}})
225 after = self.client.db_query({'submitted' : {'$gte' : tic}})
226 after = self.client.db_query({'submitted' : {'$gte' : tic}})
226 self.assertEqual(len(before)+len(after),len(hist))
227 self.assertEqual(len(before)+len(after),len(hist))
227 for b in before:
228 for b in before:
228 self.assertTrue(b['submitted'] < tic)
229 self.assertTrue(b['submitted'] < tic)
229 for a in after:
230 for a in after:
230 self.assertTrue(a['submitted'] >= tic)
231 self.assertTrue(a['submitted'] >= tic)
231 same = self.client.db_query({'submitted' : tic})
232 same = self.client.db_query({'submitted' : tic})
232 for s in same:
233 for s in same:
233 self.assertTrue(s['submitted'] == tic)
234 self.assertTrue(s['submitted'] == tic)
234
235
235 def test_db_query_keys(self):
236 def test_db_query_keys(self):
236 """test extracting subset of record keys"""
237 """test extracting subset of record keys"""
237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
238 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
238 for rec in found:
239 for rec in found:
239 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
240 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
240
241
241 def test_db_query_default_keys(self):
242 def test_db_query_default_keys(self):
242 """default db_query excludes buffers"""
243 """default db_query excludes buffers"""
243 found = self.client.db_query({'msg_id': {'$ne' : ''}})
244 found = self.client.db_query({'msg_id': {'$ne' : ''}})
244 for rec in found:
245 for rec in found:
245 keys = set(rec.keys())
246 keys = set(rec.keys())
246 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
247 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
247 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
248 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
248
249
249 def test_db_query_msg_id(self):
250 def test_db_query_msg_id(self):
250 """ensure msg_id is always in db queries"""
251 """ensure msg_id is always in db queries"""
251 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
252 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
252 for rec in found:
253 for rec in found:
253 self.assertTrue('msg_id' in rec.keys())
254 self.assertTrue('msg_id' in rec.keys())
254 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
255 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
255 for rec in found:
256 for rec in found:
256 self.assertTrue('msg_id' in rec.keys())
257 self.assertTrue('msg_id' in rec.keys())
257 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
258 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
258 for rec in found:
259 for rec in found:
259 self.assertTrue('msg_id' in rec.keys())
260 self.assertTrue('msg_id' in rec.keys())
260
261
261 def test_db_query_get_result(self):
262 def test_db_query_get_result(self):
262 """pop in db_query shouldn't pop from result itself"""
263 """pop in db_query shouldn't pop from result itself"""
263 self.client[:].apply_sync(lambda : 1)
264 self.client[:].apply_sync(lambda : 1)
264 found = self.client.db_query({'msg_id': {'$ne' : ''}})
265 found = self.client.db_query({'msg_id': {'$ne' : ''}})
265 rc2 = clientmod.Client(profile='iptest')
266 rc2 = clientmod.Client(profile='iptest')
266 # If this bug is not fixed, this call will hang:
267 # If this bug is not fixed, this call will hang:
267 ar = rc2.get_result(self.client.history[-1])
268 ar = rc2.get_result(self.client.history[-1])
268 ar.wait(2)
269 ar.wait(2)
269 self.assertTrue(ar.ready())
270 self.assertTrue(ar.ready())
270 ar.get()
271 ar.get()
271 rc2.close()
272 rc2.close()
272
273
273 def test_db_query_in(self):
274 def test_db_query_in(self):
274 """test db query with '$in','$nin' operators"""
275 """test db query with '$in','$nin' operators"""
275 hist = self.client.hub_history()
276 hist = self.client.hub_history()
276 even = hist[::2]
277 even = hist[::2]
277 odd = hist[1::2]
278 odd = hist[1::2]
278 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
279 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
279 found = [ r['msg_id'] for r in recs ]
280 found = [ r['msg_id'] for r in recs ]
280 self.assertEqual(set(even), set(found))
281 self.assertEqual(set(even), set(found))
281 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
282 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
282 found = [ r['msg_id'] for r in recs ]
283 found = [ r['msg_id'] for r in recs ]
283 self.assertEqual(set(odd), set(found))
284 self.assertEqual(set(odd), set(found))
284
285
285 def test_hub_history(self):
286 def test_hub_history(self):
286 hist = self.client.hub_history()
287 hist = self.client.hub_history()
287 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
288 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
288 recdict = {}
289 recdict = {}
289 for rec in recs:
290 for rec in recs:
290 recdict[rec['msg_id']] = rec
291 recdict[rec['msg_id']] = rec
291
292
292 latest = datetime(1984,1,1)
293 latest = datetime(1984,1,1)
293 for msg_id in hist:
294 for msg_id in hist:
294 rec = recdict[msg_id]
295 rec = recdict[msg_id]
295 newt = rec['submitted']
296 newt = rec['submitted']
296 self.assertTrue(newt >= latest)
297 self.assertTrue(newt >= latest)
297 latest = newt
298 latest = newt
298 ar = self.client[-1].apply_async(lambda : 1)
299 ar = self.client[-1].apply_async(lambda : 1)
299 ar.get()
300 ar.get()
300 time.sleep(0.25)
301 time.sleep(0.25)
301 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
302 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
302
303
303 def _wait_for_idle(self):
304 def _wait_for_idle(self):
304 """wait for an engine to become idle, according to the Hub"""
305 """wait for an engine to become idle, according to the Hub"""
305 rc = self.client
306 rc = self.client
306
307
307 # step 1. wait for all requests to be noticed
308 # step 1. wait for all requests to be noticed
308 # timeout 5s, polling every 100ms
309 # timeout 5s, polling every 100ms
309 msg_ids = set(rc.history)
310 msg_ids = set(rc.history)
310 hub_hist = rc.hub_history()
311 hub_hist = rc.hub_history()
311 for i in range(50):
312 for i in range(50):
312 if msg_ids.difference(hub_hist):
313 if msg_ids.difference(hub_hist):
313 time.sleep(0.1)
314 time.sleep(0.1)
314 hub_hist = rc.hub_history()
315 hub_hist = rc.hub_history()
315 else:
316 else:
316 break
317 break
317
318
318 self.assertEqual(len(msg_ids.difference(hub_hist)), 0)
319 self.assertEqual(len(msg_ids.difference(hub_hist)), 0)
319
320
320 # step 2. wait for all requests to be done
321 # step 2. wait for all requests to be done
321 # timeout 5s, polling every 100ms
322 # timeout 5s, polling every 100ms
322 qs = rc.queue_status()
323 qs = rc.queue_status()
323 for i in range(50):
324 for i in range(50):
324 if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
325 if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
325 time.sleep(0.1)
326 time.sleep(0.1)
326 qs = rc.queue_status()
327 qs = rc.queue_status()
327 else:
328 else:
328 break
329 break
329
330
330 # ensure Hub up to date:
331 # ensure Hub up to date:
331 self.assertEqual(qs['unassigned'], 0)
332 self.assertEqual(qs['unassigned'], 0)
332 for eid in rc.ids:
333 for eid in rc.ids:
333 self.assertEqual(qs[eid]['tasks'], 0)
334 self.assertEqual(qs[eid]['tasks'], 0)
334
335
335
336
336 def test_resubmit(self):
337 def test_resubmit(self):
337 def f():
338 def f():
338 import random
339 import random
339 return random.random()
340 return random.random()
340 v = self.client.load_balanced_view()
341 v = self.client.load_balanced_view()
341 ar = v.apply_async(f)
342 ar = v.apply_async(f)
342 r1 = ar.get(1)
343 r1 = ar.get(1)
343 # give the Hub a chance to notice:
344 # give the Hub a chance to notice:
344 self._wait_for_idle()
345 self._wait_for_idle()
345 ahr = self.client.resubmit(ar.msg_ids)
346 ahr = self.client.resubmit(ar.msg_ids)
346 r2 = ahr.get(1)
347 r2 = ahr.get(1)
347 self.assertFalse(r1 == r2)
348 self.assertFalse(r1 == r2)
348
349
349 def test_resubmit_chain(self):
350 def test_resubmit_chain(self):
350 """resubmit resubmitted tasks"""
351 """resubmit resubmitted tasks"""
351 v = self.client.load_balanced_view()
352 v = self.client.load_balanced_view()
352 ar = v.apply_async(lambda x: x, 'x'*1024)
353 ar = v.apply_async(lambda x: x, 'x'*1024)
353 ar.get()
354 ar.get()
354 self._wait_for_idle()
355 self._wait_for_idle()
355 ars = [ar]
356 ars = [ar]
356
357
357 for i in range(10):
358 for i in range(10):
358 ar = ars[-1]
359 ar = ars[-1]
359 ar2 = self.client.resubmit(ar.msg_ids)
360 ar2 = self.client.resubmit(ar.msg_ids)
360
361
361 [ ar.get() for ar in ars ]
362 [ ar.get() for ar in ars ]
362
363
363 def test_resubmit_header(self):
364 def test_resubmit_header(self):
364 """resubmit shouldn't clobber the whole header"""
365 """resubmit shouldn't clobber the whole header"""
365 def f():
366 def f():
366 import random
367 import random
367 return random.random()
368 return random.random()
368 v = self.client.load_balanced_view()
369 v = self.client.load_balanced_view()
369 v.retries = 1
370 v.retries = 1
370 ar = v.apply_async(f)
371 ar = v.apply_async(f)
371 r1 = ar.get(1)
372 r1 = ar.get(1)
372 # give the Hub a chance to notice:
373 # give the Hub a chance to notice:
373 self._wait_for_idle()
374 self._wait_for_idle()
374 ahr = self.client.resubmit(ar.msg_ids)
375 ahr = self.client.resubmit(ar.msg_ids)
375 ahr.get(1)
376 ahr.get(1)
376 time.sleep(0.5)
377 time.sleep(0.5)
377 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
378 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
378 h1,h2 = [ r['header'] for r in records ]
379 h1,h2 = [ r['header'] for r in records ]
379 for key in set(h1.keys()).union(set(h2.keys())):
380 for key in set(h1.keys()).union(set(h2.keys())):
380 if key in ('msg_id', 'date'):
381 if key in ('msg_id', 'date'):
381 self.assertNotEqual(h1[key], h2[key])
382 self.assertNotEqual(h1[key], h2[key])
382 else:
383 else:
383 self.assertEqual(h1[key], h2[key])
384 self.assertEqual(h1[key], h2[key])
384
385
385 def test_resubmit_aborted(self):
386 def test_resubmit_aborted(self):
386 def f():
387 def f():
387 import random
388 import random
388 return random.random()
389 return random.random()
389 v = self.client.load_balanced_view()
390 v = self.client.load_balanced_view()
390 # restrict to one engine, so we can put a sleep
391 # restrict to one engine, so we can put a sleep
391 # ahead of the task, so it will get aborted
392 # ahead of the task, so it will get aborted
392 eid = self.client.ids[-1]
393 eid = self.client.ids[-1]
393 v.targets = [eid]
394 v.targets = [eid]
394 sleep = v.apply_async(time.sleep, 0.5)
395 sleep = v.apply_async(time.sleep, 0.5)
395 ar = v.apply_async(f)
396 ar = v.apply_async(f)
396 ar.abort()
397 ar.abort()
397 self.assertRaises(error.TaskAborted, ar.get)
398 self.assertRaises(error.TaskAborted, ar.get)
398 # Give the Hub a chance to get up to date:
399 # Give the Hub a chance to get up to date:
399 self._wait_for_idle()
400 self._wait_for_idle()
400 ahr = self.client.resubmit(ar.msg_ids)
401 ahr = self.client.resubmit(ar.msg_ids)
401 r2 = ahr.get(1)
402 r2 = ahr.get(1)
402
403
403 def test_resubmit_inflight(self):
404 def test_resubmit_inflight(self):
404 """resubmit of inflight task"""
405 """resubmit of inflight task"""
405 v = self.client.load_balanced_view()
406 v = self.client.load_balanced_view()
406 ar = v.apply_async(time.sleep,1)
407 ar = v.apply_async(time.sleep,1)
407 # give the message a chance to arrive
408 # give the message a chance to arrive
408 time.sleep(0.2)
409 time.sleep(0.2)
409 ahr = self.client.resubmit(ar.msg_ids)
410 ahr = self.client.resubmit(ar.msg_ids)
410 ar.get(2)
411 ar.get(2)
411 ahr.get(2)
412 ahr.get(2)
412
413
413 def test_resubmit_badkey(self):
414 def test_resubmit_badkey(self):
414 """ensure KeyError on resubmit of nonexistant task"""
415 """ensure KeyError on resubmit of nonexistant task"""
415 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
416 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
416
417
417 def test_purge_hub_results(self):
418 def test_purge_hub_results(self):
418 # ensure there are some tasks
419 # ensure there are some tasks
419 for i in range(5):
420 for i in range(5):
420 self.client[:].apply_sync(lambda : 1)
421 self.client[:].apply_sync(lambda : 1)
421 # Wait for the Hub to realise the result is done:
422 # Wait for the Hub to realise the result is done:
422 # This prevents a race condition, where we
423 # This prevents a race condition, where we
423 # might purge a result the Hub still thinks is pending.
424 # might purge a result the Hub still thinks is pending.
424 self._wait_for_idle()
425 self._wait_for_idle()
425 rc2 = clientmod.Client(profile='iptest')
426 rc2 = clientmod.Client(profile='iptest')
426 hist = self.client.hub_history()
427 hist = self.client.hub_history()
427 ahr = rc2.get_result([hist[-1]])
428 ahr = rc2.get_result([hist[-1]])
428 ahr.wait(10)
429 ahr.wait(10)
429 self.client.purge_hub_results(hist[-1])
430 self.client.purge_hub_results(hist[-1])
430 newhist = self.client.hub_history()
431 newhist = self.client.hub_history()
431 self.assertEqual(len(newhist)+1,len(hist))
432 self.assertEqual(len(newhist)+1,len(hist))
432 rc2.spin()
433 rc2.spin()
433 rc2.close()
434 rc2.close()
434
435
435 def test_purge_local_results(self):
436 def test_purge_local_results(self):
436 # ensure there are some tasks
437 # ensure there are some tasks
437 res = []
438 res = []
438 for i in range(5):
439 for i in range(5):
439 res.append(self.client[:].apply_async(lambda : 1))
440 res.append(self.client[:].apply_async(lambda : 1))
440 self._wait_for_idle()
441 self._wait_for_idle()
441 self.client.wait(10) # wait for the results to come back
442 self.client.wait(10) # wait for the results to come back
442 before = len(self.client.results)
443 before = len(self.client.results)
443 self.assertEqual(len(self.client.metadata),before)
444 self.assertEqual(len(self.client.metadata),before)
444 self.client.purge_local_results(res[-1])
445 self.client.purge_local_results(res[-1])
445 self.assertEqual(len(self.client.results),before-len(res[-1]), msg="Not removed from results")
446 self.assertEqual(len(self.client.results),before-len(res[-1]), msg="Not removed from results")
446 self.assertEqual(len(self.client.metadata),before-len(res[-1]), msg="Not removed from metadata")
447 self.assertEqual(len(self.client.metadata),before-len(res[-1]), msg="Not removed from metadata")
447
448
448 def test_purge_all_hub_results(self):
449 def test_purge_all_hub_results(self):
449 self.client.purge_hub_results('all')
450 self.client.purge_hub_results('all')
450 hist = self.client.hub_history()
451 hist = self.client.hub_history()
451 self.assertEqual(len(hist), 0)
452 self.assertEqual(len(hist), 0)
452
453
453 def test_purge_all_local_results(self):
454 def test_purge_all_local_results(self):
454 self.client.purge_local_results('all')
455 self.client.purge_local_results('all')
455 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
456 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
456 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
457 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
457
458
458 def test_purge_all_results(self):
459 def test_purge_all_results(self):
459 # ensure there are some tasks
460 # ensure there are some tasks
460 for i in range(5):
461 for i in range(5):
461 self.client[:].apply_sync(lambda : 1)
462 self.client[:].apply_sync(lambda : 1)
462 self.client.wait(10)
463 self.client.wait(10)
463 self._wait_for_idle()
464 self._wait_for_idle()
464 self.client.purge_results('all')
465 self.client.purge_results('all')
465 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
466 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
466 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
467 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
467 hist = self.client.hub_history()
468 hist = self.client.hub_history()
468 self.assertEqual(len(hist), 0, msg="hub history not empty")
469 self.assertEqual(len(hist), 0, msg="hub history not empty")
469
470
470 def test_purge_everything(self):
471 def test_purge_everything(self):
471 # ensure there are some tasks
472 # ensure there are some tasks
472 for i in range(5):
473 for i in range(5):
473 self.client[:].apply_sync(lambda : 1)
474 self.client[:].apply_sync(lambda : 1)
474 self.client.wait(10)
475 self.client.wait(10)
475 self._wait_for_idle()
476 self._wait_for_idle()
476 self.client.purge_everything()
477 self.client.purge_everything()
477 # The client results
478 # The client results
478 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
479 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
479 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
480 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
480 # The client "bookkeeping"
481 # The client "bookkeeping"
481 self.assertEqual(len(self.client.session.digest_history), 0, msg="session digest not empty")
482 self.assertEqual(len(self.client.session.digest_history), 0, msg="session digest not empty")
482 self.assertEqual(len(self.client.history), 0, msg="client history not empty")
483 self.assertEqual(len(self.client.history), 0, msg="client history not empty")
483 # the hub results
484 # the hub results
484 hist = self.client.hub_history()
485 hist = self.client.hub_history()
485 self.assertEqual(len(hist), 0, msg="hub history not empty")
486 self.assertEqual(len(hist), 0, msg="hub history not empty")
486
487
487
488
488 def test_spin_thread(self):
489 def test_spin_thread(self):
489 self.client.spin_thread(0.01)
490 self.client.spin_thread(0.01)
490 ar = self.client[-1].apply_async(lambda : 1)
491 ar = self.client[-1].apply_async(lambda : 1)
491 time.sleep(0.1)
492 time.sleep(0.1)
492 self.assertTrue(ar.wall_time < 0.1,
493 self.assertTrue(ar.wall_time < 0.1,
493 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
494 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
494 )
495 )
495
496
496 def test_stop_spin_thread(self):
497 def test_stop_spin_thread(self):
497 self.client.spin_thread(0.01)
498 self.client.spin_thread(0.01)
498 self.client.stop_spin_thread()
499 self.client.stop_spin_thread()
499 ar = self.client[-1].apply_async(lambda : 1)
500 ar = self.client[-1].apply_async(lambda : 1)
500 time.sleep(0.15)
501 time.sleep(0.15)
501 self.assertTrue(ar.wall_time > 0.1,
502 self.assertTrue(ar.wall_time > 0.1,
502 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
503 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
503 )
504 )
504
505
505 def test_activate(self):
506 def test_activate(self):
506 ip = get_ipython()
507 ip = get_ipython()
507 magics = ip.magics_manager.magics
508 magics = ip.magics_manager.magics
508 self.assertTrue('px' in magics['line'])
509 self.assertTrue('px' in magics['line'])
509 self.assertTrue('px' in magics['cell'])
510 self.assertTrue('px' in magics['cell'])
510 v0 = self.client.activate(-1, '0')
511 v0 = self.client.activate(-1, '0')
511 self.assertTrue('px0' in magics['line'])
512 self.assertTrue('px0' in magics['line'])
512 self.assertTrue('px0' in magics['cell'])
513 self.assertTrue('px0' in magics['cell'])
513 self.assertEqual(v0.targets, self.client.ids[-1])
514 self.assertEqual(v0.targets, self.client.ids[-1])
514 v0 = self.client.activate('all', 'all')
515 v0 = self.client.activate('all', 'all')
515 self.assertTrue('pxall' in magics['line'])
516 self.assertTrue('pxall' in magics['line'])
516 self.assertTrue('pxall' in magics['cell'])
517 self.assertTrue('pxall' in magics['cell'])
517 self.assertEqual(v0.targets, 'all')
518 self.assertEqual(v0.targets, 'all')
@@ -1,789 +1,789 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import platform
20 import platform
21 import time
21 import time
22 from collections import namedtuple
22 from collections import namedtuple
23 from tempfile import mktemp
23 from tempfile import mktemp
24 from StringIO import StringIO
24 from StringIO import StringIO
25
25
26 import zmq
26 import zmq
27 from nose import SkipTest
27 from nose import SkipTest
28 from nose.plugins.attrib import attr
28 from nose.plugins.attrib import attr
29
29
30 from IPython.testing import decorators as dec
30 from IPython.testing import decorators as dec
31 from IPython.testing.ipunittest import ParametricTestCase
31 from IPython.testing.ipunittest import ParametricTestCase
32 from IPython.utils.io import capture_output
32 from IPython.utils.io import capture_output
33
33
34 from IPython import parallel as pmod
34 from IPython import parallel as pmod
35 from IPython.parallel import error
35 from IPython.parallel import error
36 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
36 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
37 from IPython.parallel import DirectView
37 from IPython.parallel import DirectView
38 from IPython.parallel.util import interactive
38 from IPython.parallel.util import interactive
39
39
40 from IPython.parallel.tests import add_engines
40 from IPython.parallel.tests import add_engines
41
41
42 from .clienttest import ClusterTestCase, crash, wait, skip_without
42 from .clienttest import ClusterTestCase, crash, wait, skip_without
43
43
44 def setup():
44 def setup():
45 add_engines(3, total=True)
45 add_engines(3, total=True)
46
46
47 point = namedtuple("point", "x y")
47 point = namedtuple("point", "x y")
48
48
49 class TestView(ClusterTestCase, ParametricTestCase):
49 class TestView(ClusterTestCase, ParametricTestCase):
50
50
51 def setUp(self):
51 def setUp(self):
52 # On Win XP, wait for resource cleanup, else parallel test group fails
52 # On Win XP, wait for resource cleanup, else parallel test group fails
53 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
53 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
54 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
54 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
55 time.sleep(2)
55 time.sleep(2)
56 super(TestView, self).setUp()
56 super(TestView, self).setUp()
57
57
58 @attr('crash')
58 @attr('crash')
59 def test_z_crash_mux(self):
59 def test_z_crash_mux(self):
60 """test graceful handling of engine death (direct)"""
60 """test graceful handling of engine death (direct)"""
61 # self.add_engines(1)
61 # self.add_engines(1)
62 eid = self.client.ids[-1]
62 eid = self.client.ids[-1]
63 ar = self.client[eid].apply_async(crash)
63 ar = self.client[eid].apply_async(crash)
64 self.assertRaisesRemote(error.EngineError, ar.get, 10)
64 self.assertRaisesRemote(error.EngineError, ar.get, 10)
65 eid = ar.engine_id
65 eid = ar.engine_id
66 tic = time.time()
66 tic = time.time()
67 while eid in self.client.ids and time.time()-tic < 5:
67 while eid in self.client.ids and time.time()-tic < 5:
68 time.sleep(.01)
68 time.sleep(.01)
69 self.client.spin()
69 self.client.spin()
70 self.assertFalse(eid in self.client.ids, "Engine should have died")
70 self.assertFalse(eid in self.client.ids, "Engine should have died")
71
71
72 def test_push_pull(self):
72 def test_push_pull(self):
73 """test pushing and pulling"""
73 """test pushing and pulling"""
74 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
74 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
75 t = self.client.ids[-1]
75 t = self.client.ids[-1]
76 v = self.client[t]
76 v = self.client[t]
77 push = v.push
77 push = v.push
78 pull = v.pull
78 pull = v.pull
79 v.block=True
79 v.block=True
80 nengines = len(self.client)
80 nengines = len(self.client)
81 push({'data':data})
81 push({'data':data})
82 d = pull('data')
82 d = pull('data')
83 self.assertEqual(d, data)
83 self.assertEqual(d, data)
84 self.client[:].push({'data':data})
84 self.client[:].push({'data':data})
85 d = self.client[:].pull('data', block=True)
85 d = self.client[:].pull('data', block=True)
86 self.assertEqual(d, nengines*[data])
86 self.assertEqual(d, nengines*[data])
87 ar = push({'data':data}, block=False)
87 ar = push({'data':data}, block=False)
88 self.assertTrue(isinstance(ar, AsyncResult))
88 self.assertTrue(isinstance(ar, AsyncResult))
89 r = ar.get()
89 r = ar.get()
90 ar = self.client[:].pull('data', block=False)
90 ar = self.client[:].pull('data', block=False)
91 self.assertTrue(isinstance(ar, AsyncResult))
91 self.assertTrue(isinstance(ar, AsyncResult))
92 r = ar.get()
92 r = ar.get()
93 self.assertEqual(r, nengines*[data])
93 self.assertEqual(r, nengines*[data])
94 self.client[:].push(dict(a=10,b=20))
94 self.client[:].push(dict(a=10,b=20))
95 r = self.client[:].pull(('a','b'), block=True)
95 r = self.client[:].pull(('a','b'), block=True)
96 self.assertEqual(r, nengines*[[10,20]])
96 self.assertEqual(r, nengines*[[10,20]])
97
97
98 def test_push_pull_function(self):
98 def test_push_pull_function(self):
99 "test pushing and pulling functions"
99 "test pushing and pulling functions"
100 def testf(x):
100 def testf(x):
101 return 2.0*x
101 return 2.0*x
102
102
103 t = self.client.ids[-1]
103 t = self.client.ids[-1]
104 v = self.client[t]
104 v = self.client[t]
105 v.block=True
105 v.block=True
106 push = v.push
106 push = v.push
107 pull = v.pull
107 pull = v.pull
108 execute = v.execute
108 execute = v.execute
109 push({'testf':testf})
109 push({'testf':testf})
110 r = pull('testf')
110 r = pull('testf')
111 self.assertEqual(r(1.0), testf(1.0))
111 self.assertEqual(r(1.0), testf(1.0))
112 execute('r = testf(10)')
112 execute('r = testf(10)')
113 r = pull('r')
113 r = pull('r')
114 self.assertEqual(r, testf(10))
114 self.assertEqual(r, testf(10))
115 ar = self.client[:].push({'testf':testf}, block=False)
115 ar = self.client[:].push({'testf':testf}, block=False)
116 ar.get()
116 ar.get()
117 ar = self.client[:].pull('testf', block=False)
117 ar = self.client[:].pull('testf', block=False)
118 rlist = ar.get()
118 rlist = ar.get()
119 for r in rlist:
119 for r in rlist:
120 self.assertEqual(r(1.0), testf(1.0))
120 self.assertEqual(r(1.0), testf(1.0))
121 execute("def g(x): return x*x")
121 execute("def g(x): return x*x")
122 r = pull(('testf','g'))
122 r = pull(('testf','g'))
123 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
123 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
124
124
125 def test_push_function_globals(self):
125 def test_push_function_globals(self):
126 """test that pushed functions have access to globals"""
126 """test that pushed functions have access to globals"""
127 @interactive
127 @interactive
128 def geta():
128 def geta():
129 return a
129 return a
130 # self.add_engines(1)
130 # self.add_engines(1)
131 v = self.client[-1]
131 v = self.client[-1]
132 v.block=True
132 v.block=True
133 v['f'] = geta
133 v['f'] = geta
134 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
134 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
135 v.execute('a=5')
135 v.execute('a=5')
136 v.execute('b=f()')
136 v.execute('b=f()')
137 self.assertEqual(v['b'], 5)
137 self.assertEqual(v['b'], 5)
138
138
139 def test_push_function_defaults(self):
139 def test_push_function_defaults(self):
140 """test that pushed functions preserve default args"""
140 """test that pushed functions preserve default args"""
141 def echo(a=10):
141 def echo(a=10):
142 return a
142 return a
143 v = self.client[-1]
143 v = self.client[-1]
144 v.block=True
144 v.block=True
145 v['f'] = echo
145 v['f'] = echo
146 v.execute('b=f()')
146 v.execute('b=f()')
147 self.assertEqual(v['b'], 10)
147 self.assertEqual(v['b'], 10)
148
148
149 def test_get_result(self):
149 def test_get_result(self):
150 """test getting results from the Hub."""
150 """test getting results from the Hub."""
151 c = pmod.Client(profile='iptest')
151 c = pmod.Client(profile='iptest')
152 # self.add_engines(1)
152 # self.add_engines(1)
153 t = c.ids[-1]
153 t = c.ids[-1]
154 v = c[t]
154 v = c[t]
155 v2 = self.client[t]
155 v2 = self.client[t]
156 ar = v.apply_async(wait, 1)
156 ar = v.apply_async(wait, 1)
157 # give the monitor time to notice the message
157 # give the monitor time to notice the message
158 time.sleep(.25)
158 time.sleep(.25)
159 ahr = v2.get_result(ar.msg_ids)
159 ahr = v2.get_result(ar.msg_ids[0])
160 self.assertTrue(isinstance(ahr, AsyncHubResult))
160 self.assertTrue(isinstance(ahr, AsyncHubResult))
161 self.assertEqual(ahr.get(), ar.get())
161 self.assertEqual(ahr.get(), ar.get())
162 ar2 = v2.get_result(ar.msg_ids)
162 ar2 = v2.get_result(ar.msg_ids[0])
163 self.assertFalse(isinstance(ar2, AsyncHubResult))
163 self.assertFalse(isinstance(ar2, AsyncHubResult))
164 c.spin()
164 c.spin()
165 c.close()
165 c.close()
166
166
167 def test_run_newline(self):
167 def test_run_newline(self):
168 """test that run appends newline to files"""
168 """test that run appends newline to files"""
169 tmpfile = mktemp()
169 tmpfile = mktemp()
170 with open(tmpfile, 'w') as f:
170 with open(tmpfile, 'w') as f:
171 f.write("""def g():
171 f.write("""def g():
172 return 5
172 return 5
173 """)
173 """)
174 v = self.client[-1]
174 v = self.client[-1]
175 v.run(tmpfile, block=True)
175 v.run(tmpfile, block=True)
176 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
176 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
177
177
178 def test_apply_tracked(self):
178 def test_apply_tracked(self):
179 """test tracking for apply"""
179 """test tracking for apply"""
180 # self.add_engines(1)
180 # self.add_engines(1)
181 t = self.client.ids[-1]
181 t = self.client.ids[-1]
182 v = self.client[t]
182 v = self.client[t]
183 v.block=False
183 v.block=False
184 def echo(n=1024*1024, **kwargs):
184 def echo(n=1024*1024, **kwargs):
185 with v.temp_flags(**kwargs):
185 with v.temp_flags(**kwargs):
186 return v.apply(lambda x: x, 'x'*n)
186 return v.apply(lambda x: x, 'x'*n)
187 ar = echo(1, track=False)
187 ar = echo(1, track=False)
188 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
189 self.assertTrue(ar.sent)
189 self.assertTrue(ar.sent)
190 ar = echo(track=True)
190 ar = echo(track=True)
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 self.assertEqual(ar.sent, ar._tracker.done)
192 self.assertEqual(ar.sent, ar._tracker.done)
193 ar._tracker.wait()
193 ar._tracker.wait()
194 self.assertTrue(ar.sent)
194 self.assertTrue(ar.sent)
195
195
196 def test_push_tracked(self):
196 def test_push_tracked(self):
197 t = self.client.ids[-1]
197 t = self.client.ids[-1]
198 ns = dict(x='x'*1024*1024)
198 ns = dict(x='x'*1024*1024)
199 v = self.client[t]
199 v = self.client[t]
200 ar = v.push(ns, block=False, track=False)
200 ar = v.push(ns, block=False, track=False)
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 self.assertTrue(ar.sent)
202 self.assertTrue(ar.sent)
203
203
204 ar = v.push(ns, block=False, track=True)
204 ar = v.push(ns, block=False, track=True)
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 ar._tracker.wait()
206 ar._tracker.wait()
207 self.assertEqual(ar.sent, ar._tracker.done)
207 self.assertEqual(ar.sent, ar._tracker.done)
208 self.assertTrue(ar.sent)
208 self.assertTrue(ar.sent)
209 ar.get()
209 ar.get()
210
210
211 def test_scatter_tracked(self):
211 def test_scatter_tracked(self):
212 t = self.client.ids
212 t = self.client.ids
213 x='x'*1024*1024
213 x='x'*1024*1024
214 ar = self.client[t].scatter('x', x, block=False, track=False)
214 ar = self.client[t].scatter('x', x, block=False, track=False)
215 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
215 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
216 self.assertTrue(ar.sent)
216 self.assertTrue(ar.sent)
217
217
218 ar = self.client[t].scatter('x', x, block=False, track=True)
218 ar = self.client[t].scatter('x', x, block=False, track=True)
219 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
219 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
220 self.assertEqual(ar.sent, ar._tracker.done)
220 self.assertEqual(ar.sent, ar._tracker.done)
221 ar._tracker.wait()
221 ar._tracker.wait()
222 self.assertTrue(ar.sent)
222 self.assertTrue(ar.sent)
223 ar.get()
223 ar.get()
224
224
225 def test_remote_reference(self):
225 def test_remote_reference(self):
226 v = self.client[-1]
226 v = self.client[-1]
227 v['a'] = 123
227 v['a'] = 123
228 ra = pmod.Reference('a')
228 ra = pmod.Reference('a')
229 b = v.apply_sync(lambda x: x, ra)
229 b = v.apply_sync(lambda x: x, ra)
230 self.assertEqual(b, 123)
230 self.assertEqual(b, 123)
231
231
232
232
233 def test_scatter_gather(self):
233 def test_scatter_gather(self):
234 view = self.client[:]
234 view = self.client[:]
235 seq1 = range(16)
235 seq1 = range(16)
236 view.scatter('a', seq1)
236 view.scatter('a', seq1)
237 seq2 = view.gather('a', block=True)
237 seq2 = view.gather('a', block=True)
238 self.assertEqual(seq2, seq1)
238 self.assertEqual(seq2, seq1)
239 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
239 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
240
240
241 @skip_without('numpy')
241 @skip_without('numpy')
242 def test_scatter_gather_numpy(self):
242 def test_scatter_gather_numpy(self):
243 import numpy
243 import numpy
244 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
244 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
245 view = self.client[:]
245 view = self.client[:]
246 a = numpy.arange(64)
246 a = numpy.arange(64)
247 view.scatter('a', a, block=True)
247 view.scatter('a', a, block=True)
248 b = view.gather('a', block=True)
248 b = view.gather('a', block=True)
249 assert_array_equal(b, a)
249 assert_array_equal(b, a)
250
250
251 def test_scatter_gather_lazy(self):
251 def test_scatter_gather_lazy(self):
252 """scatter/gather with targets='all'"""
252 """scatter/gather with targets='all'"""
253 view = self.client.direct_view(targets='all')
253 view = self.client.direct_view(targets='all')
254 x = range(64)
254 x = range(64)
255 view.scatter('x', x)
255 view.scatter('x', x)
256 gathered = view.gather('x', block=True)
256 gathered = view.gather('x', block=True)
257 self.assertEqual(gathered, x)
257 self.assertEqual(gathered, x)
258
258
259
259
260 @dec.known_failure_py3
260 @dec.known_failure_py3
261 @skip_without('numpy')
261 @skip_without('numpy')
262 def test_push_numpy_nocopy(self):
262 def test_push_numpy_nocopy(self):
263 import numpy
263 import numpy
264 view = self.client[:]
264 view = self.client[:]
265 a = numpy.arange(64)
265 a = numpy.arange(64)
266 view['A'] = a
266 view['A'] = a
267 @interactive
267 @interactive
268 def check_writeable(x):
268 def check_writeable(x):
269 return x.flags.writeable
269 return x.flags.writeable
270
270
271 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
271 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
272 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
272 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
273
273
274 view.push(dict(B=a))
274 view.push(dict(B=a))
275 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
275 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
276 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
276 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
277
277
278 @skip_without('numpy')
278 @skip_without('numpy')
279 def test_apply_numpy(self):
279 def test_apply_numpy(self):
280 """view.apply(f, ndarray)"""
280 """view.apply(f, ndarray)"""
281 import numpy
281 import numpy
282 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
282 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
283
283
284 A = numpy.random.random((100,100))
284 A = numpy.random.random((100,100))
285 view = self.client[-1]
285 view = self.client[-1]
286 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
286 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
287 B = A.astype(dt)
287 B = A.astype(dt)
288 C = view.apply_sync(lambda x:x, B)
288 C = view.apply_sync(lambda x:x, B)
289 assert_array_equal(B,C)
289 assert_array_equal(B,C)
290
290
291 @skip_without('numpy')
291 @skip_without('numpy')
292 def test_push_pull_recarray(self):
292 def test_push_pull_recarray(self):
293 """push/pull recarrays"""
293 """push/pull recarrays"""
294 import numpy
294 import numpy
295 from numpy.testing.utils import assert_array_equal
295 from numpy.testing.utils import assert_array_equal
296
296
297 view = self.client[-1]
297 view = self.client[-1]
298
298
299 R = numpy.array([
299 R = numpy.array([
300 (1, 'hi', 0.),
300 (1, 'hi', 0.),
301 (2**30, 'there', 2.5),
301 (2**30, 'there', 2.5),
302 (-99999, 'world', -12345.6789),
302 (-99999, 'world', -12345.6789),
303 ], [('n', int), ('s', '|S10'), ('f', float)])
303 ], [('n', int), ('s', '|S10'), ('f', float)])
304
304
305 view['RR'] = R
305 view['RR'] = R
306 R2 = view['RR']
306 R2 = view['RR']
307
307
308 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
308 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
309 self.assertEqual(r_dtype, R.dtype)
309 self.assertEqual(r_dtype, R.dtype)
310 self.assertEqual(r_shape, R.shape)
310 self.assertEqual(r_shape, R.shape)
311 self.assertEqual(R2.dtype, R.dtype)
311 self.assertEqual(R2.dtype, R.dtype)
312 self.assertEqual(R2.shape, R.shape)
312 self.assertEqual(R2.shape, R.shape)
313 assert_array_equal(R2, R)
313 assert_array_equal(R2, R)
314
314
315 @skip_without('pandas')
315 @skip_without('pandas')
316 def test_push_pull_timeseries(self):
316 def test_push_pull_timeseries(self):
317 """push/pull pandas.TimeSeries"""
317 """push/pull pandas.TimeSeries"""
318 import pandas
318 import pandas
319
319
320 ts = pandas.TimeSeries(range(10))
320 ts = pandas.TimeSeries(range(10))
321
321
322 view = self.client[-1]
322 view = self.client[-1]
323
323
324 view.push(dict(ts=ts), block=True)
324 view.push(dict(ts=ts), block=True)
325 rts = view['ts']
325 rts = view['ts']
326
326
327 self.assertEqual(type(rts), type(ts))
327 self.assertEqual(type(rts), type(ts))
328 self.assertTrue((ts == rts).all())
328 self.assertTrue((ts == rts).all())
329
329
330 def test_map(self):
330 def test_map(self):
331 view = self.client[:]
331 view = self.client[:]
332 def f(x):
332 def f(x):
333 return x**2
333 return x**2
334 data = range(16)
334 data = range(16)
335 r = view.map_sync(f, data)
335 r = view.map_sync(f, data)
336 self.assertEqual(r, map(f, data))
336 self.assertEqual(r, map(f, data))
337
337
338 def test_map_iterable(self):
338 def test_map_iterable(self):
339 """test map on iterables (direct)"""
339 """test map on iterables (direct)"""
340 view = self.client[:]
340 view = self.client[:]
341 # 101 is prime, so it won't be evenly distributed
341 # 101 is prime, so it won't be evenly distributed
342 arr = range(101)
342 arr = range(101)
343 # ensure it will be an iterator, even in Python 3
343 # ensure it will be an iterator, even in Python 3
344 it = iter(arr)
344 it = iter(arr)
345 r = view.map_sync(lambda x:x, arr)
345 r = view.map_sync(lambda x:x, arr)
346 self.assertEqual(r, list(arr))
346 self.assertEqual(r, list(arr))
347
347
348 def test_scatter_gather_nonblocking(self):
348 def test_scatter_gather_nonblocking(self):
349 data = range(16)
349 data = range(16)
350 view = self.client[:]
350 view = self.client[:]
351 view.scatter('a', data, block=False)
351 view.scatter('a', data, block=False)
352 ar = view.gather('a', block=False)
352 ar = view.gather('a', block=False)
353 self.assertEqual(ar.get(), data)
353 self.assertEqual(ar.get(), data)
354
354
355 @skip_without('numpy')
355 @skip_without('numpy')
356 def test_scatter_gather_numpy_nonblocking(self):
356 def test_scatter_gather_numpy_nonblocking(self):
357 import numpy
357 import numpy
358 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
358 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
359 a = numpy.arange(64)
359 a = numpy.arange(64)
360 view = self.client[:]
360 view = self.client[:]
361 ar = view.scatter('a', a, block=False)
361 ar = view.scatter('a', a, block=False)
362 self.assertTrue(isinstance(ar, AsyncResult))
362 self.assertTrue(isinstance(ar, AsyncResult))
363 amr = view.gather('a', block=False)
363 amr = view.gather('a', block=False)
364 self.assertTrue(isinstance(amr, AsyncMapResult))
364 self.assertTrue(isinstance(amr, AsyncMapResult))
365 assert_array_equal(amr.get(), a)
365 assert_array_equal(amr.get(), a)
366
366
367 def test_execute(self):
367 def test_execute(self):
368 view = self.client[:]
368 view = self.client[:]
369 # self.client.debug=True
369 # self.client.debug=True
370 execute = view.execute
370 execute = view.execute
371 ar = execute('c=30', block=False)
371 ar = execute('c=30', block=False)
372 self.assertTrue(isinstance(ar, AsyncResult))
372 self.assertTrue(isinstance(ar, AsyncResult))
373 ar = execute('d=[0,1,2]', block=False)
373 ar = execute('d=[0,1,2]', block=False)
374 self.client.wait(ar, 1)
374 self.client.wait(ar, 1)
375 self.assertEqual(len(ar.get()), len(self.client))
375 self.assertEqual(len(ar.get()), len(self.client))
376 for c in view['c']:
376 for c in view['c']:
377 self.assertEqual(c, 30)
377 self.assertEqual(c, 30)
378
378
379 def test_abort(self):
379 def test_abort(self):
380 view = self.client[-1]
380 view = self.client[-1]
381 ar = view.execute('import time; time.sleep(1)', block=False)
381 ar = view.execute('import time; time.sleep(1)', block=False)
382 ar2 = view.apply_async(lambda : 2)
382 ar2 = view.apply_async(lambda : 2)
383 ar3 = view.apply_async(lambda : 3)
383 ar3 = view.apply_async(lambda : 3)
384 view.abort(ar2)
384 view.abort(ar2)
385 view.abort(ar3.msg_ids)
385 view.abort(ar3.msg_ids)
386 self.assertRaises(error.TaskAborted, ar2.get)
386 self.assertRaises(error.TaskAborted, ar2.get)
387 self.assertRaises(error.TaskAborted, ar3.get)
387 self.assertRaises(error.TaskAborted, ar3.get)
388
388
389 def test_abort_all(self):
389 def test_abort_all(self):
390 """view.abort() aborts all outstanding tasks"""
390 """view.abort() aborts all outstanding tasks"""
391 view = self.client[-1]
391 view = self.client[-1]
392 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
392 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
393 view.abort()
393 view.abort()
394 view.wait(timeout=5)
394 view.wait(timeout=5)
395 for ar in ars[5:]:
395 for ar in ars[5:]:
396 self.assertRaises(error.TaskAborted, ar.get)
396 self.assertRaises(error.TaskAborted, ar.get)
397
397
398 def test_temp_flags(self):
398 def test_temp_flags(self):
399 view = self.client[-1]
399 view = self.client[-1]
400 view.block=True
400 view.block=True
401 with view.temp_flags(block=False):
401 with view.temp_flags(block=False):
402 self.assertFalse(view.block)
402 self.assertFalse(view.block)
403 self.assertTrue(view.block)
403 self.assertTrue(view.block)
404
404
405 @dec.known_failure_py3
405 @dec.known_failure_py3
406 def test_importer(self):
406 def test_importer(self):
407 view = self.client[-1]
407 view = self.client[-1]
408 view.clear(block=True)
408 view.clear(block=True)
409 with view.importer:
409 with view.importer:
410 import re
410 import re
411
411
412 @interactive
412 @interactive
413 def findall(pat, s):
413 def findall(pat, s):
414 # this globals() step isn't necessary in real code
414 # this globals() step isn't necessary in real code
415 # only to prevent a closure in the test
415 # only to prevent a closure in the test
416 re = globals()['re']
416 re = globals()['re']
417 return re.findall(pat, s)
417 return re.findall(pat, s)
418
418
419 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
419 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
420
420
421 def test_unicode_execute(self):
421 def test_unicode_execute(self):
422 """test executing unicode strings"""
422 """test executing unicode strings"""
423 v = self.client[-1]
423 v = self.client[-1]
424 v.block=True
424 v.block=True
425 if sys.version_info[0] >= 3:
425 if sys.version_info[0] >= 3:
426 code="a='é'"
426 code="a='é'"
427 else:
427 else:
428 code=u"a=u'é'"
428 code=u"a=u'é'"
429 v.execute(code)
429 v.execute(code)
430 self.assertEqual(v['a'], u'é')
430 self.assertEqual(v['a'], u'é')
431
431
432 def test_unicode_apply_result(self):
432 def test_unicode_apply_result(self):
433 """test unicode apply results"""
433 """test unicode apply results"""
434 v = self.client[-1]
434 v = self.client[-1]
435 r = v.apply_sync(lambda : u'é')
435 r = v.apply_sync(lambda : u'é')
436 self.assertEqual(r, u'é')
436 self.assertEqual(r, u'é')
437
437
438 def test_unicode_apply_arg(self):
438 def test_unicode_apply_arg(self):
439 """test passing unicode arguments to apply"""
439 """test passing unicode arguments to apply"""
440 v = self.client[-1]
440 v = self.client[-1]
441
441
442 @interactive
442 @interactive
443 def check_unicode(a, check):
443 def check_unicode(a, check):
444 assert isinstance(a, unicode), "%r is not unicode"%a
444 assert isinstance(a, unicode), "%r is not unicode"%a
445 assert isinstance(check, bytes), "%r is not bytes"%check
445 assert isinstance(check, bytes), "%r is not bytes"%check
446 assert a.encode('utf8') == check, "%s != %s"%(a,check)
446 assert a.encode('utf8') == check, "%s != %s"%(a,check)
447
447
448 for s in [ u'é', u'ßø®∫',u'asdf' ]:
448 for s in [ u'é', u'ßø®∫',u'asdf' ]:
449 try:
449 try:
450 v.apply_sync(check_unicode, s, s.encode('utf8'))
450 v.apply_sync(check_unicode, s, s.encode('utf8'))
451 except error.RemoteError as e:
451 except error.RemoteError as e:
452 if e.ename == 'AssertionError':
452 if e.ename == 'AssertionError':
453 self.fail(e.evalue)
453 self.fail(e.evalue)
454 else:
454 else:
455 raise e
455 raise e
456
456
457 def test_map_reference(self):
457 def test_map_reference(self):
458 """view.map(<Reference>, *seqs) should work"""
458 """view.map(<Reference>, *seqs) should work"""
459 v = self.client[:]
459 v = self.client[:]
460 v.scatter('n', self.client.ids, flatten=True)
460 v.scatter('n', self.client.ids, flatten=True)
461 v.execute("f = lambda x,y: x*y")
461 v.execute("f = lambda x,y: x*y")
462 rf = pmod.Reference('f')
462 rf = pmod.Reference('f')
463 nlist = list(range(10))
463 nlist = list(range(10))
464 mlist = nlist[::-1]
464 mlist = nlist[::-1]
465 expected = [ m*n for m,n in zip(mlist, nlist) ]
465 expected = [ m*n for m,n in zip(mlist, nlist) ]
466 result = v.map_sync(rf, mlist, nlist)
466 result = v.map_sync(rf, mlist, nlist)
467 self.assertEqual(result, expected)
467 self.assertEqual(result, expected)
468
468
469 def test_apply_reference(self):
469 def test_apply_reference(self):
470 """view.apply(<Reference>, *args) should work"""
470 """view.apply(<Reference>, *args) should work"""
471 v = self.client[:]
471 v = self.client[:]
472 v.scatter('n', self.client.ids, flatten=True)
472 v.scatter('n', self.client.ids, flatten=True)
473 v.execute("f = lambda x: n*x")
473 v.execute("f = lambda x: n*x")
474 rf = pmod.Reference('f')
474 rf = pmod.Reference('f')
475 result = v.apply_sync(rf, 5)
475 result = v.apply_sync(rf, 5)
476 expected = [ 5*id for id in self.client.ids ]
476 expected = [ 5*id for id in self.client.ids ]
477 self.assertEqual(result, expected)
477 self.assertEqual(result, expected)
478
478
479 def test_eval_reference(self):
479 def test_eval_reference(self):
480 v = self.client[self.client.ids[0]]
480 v = self.client[self.client.ids[0]]
481 v['g'] = range(5)
481 v['g'] = range(5)
482 rg = pmod.Reference('g[0]')
482 rg = pmod.Reference('g[0]')
483 echo = lambda x:x
483 echo = lambda x:x
484 self.assertEqual(v.apply_sync(echo, rg), 0)
484 self.assertEqual(v.apply_sync(echo, rg), 0)
485
485
486 def test_reference_nameerror(self):
486 def test_reference_nameerror(self):
487 v = self.client[self.client.ids[0]]
487 v = self.client[self.client.ids[0]]
488 r = pmod.Reference('elvis_has_left')
488 r = pmod.Reference('elvis_has_left')
489 echo = lambda x:x
489 echo = lambda x:x
490 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
490 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
491
491
492 def test_single_engine_map(self):
492 def test_single_engine_map(self):
493 e0 = self.client[self.client.ids[0]]
493 e0 = self.client[self.client.ids[0]]
494 r = range(5)
494 r = range(5)
495 check = [ -1*i for i in r ]
495 check = [ -1*i for i in r ]
496 result = e0.map_sync(lambda x: -1*x, r)
496 result = e0.map_sync(lambda x: -1*x, r)
497 self.assertEqual(result, check)
497 self.assertEqual(result, check)
498
498
499 def test_len(self):
499 def test_len(self):
500 """len(view) makes sense"""
500 """len(view) makes sense"""
501 e0 = self.client[self.client.ids[0]]
501 e0 = self.client[self.client.ids[0]]
502 yield self.assertEqual(len(e0), 1)
502 yield self.assertEqual(len(e0), 1)
503 v = self.client[:]
503 v = self.client[:]
504 yield self.assertEqual(len(v), len(self.client.ids))
504 yield self.assertEqual(len(v), len(self.client.ids))
505 v = self.client.direct_view('all')
505 v = self.client.direct_view('all')
506 yield self.assertEqual(len(v), len(self.client.ids))
506 yield self.assertEqual(len(v), len(self.client.ids))
507 v = self.client[:2]
507 v = self.client[:2]
508 yield self.assertEqual(len(v), 2)
508 yield self.assertEqual(len(v), 2)
509 v = self.client[:1]
509 v = self.client[:1]
510 yield self.assertEqual(len(v), 1)
510 yield self.assertEqual(len(v), 1)
511 v = self.client.load_balanced_view()
511 v = self.client.load_balanced_view()
512 yield self.assertEqual(len(v), len(self.client.ids))
512 yield self.assertEqual(len(v), len(self.client.ids))
513 # parametric tests seem to require manual closing?
513 # parametric tests seem to require manual closing?
514 self.client.close()
514 self.client.close()
515
515
516
516
517 # begin execute tests
517 # begin execute tests
518
518
    def test_execute_reply(self):
        """Blocking execute returns an ExecuteReply whose repr and pyout match the expression."""
        e0 = self.client[self.client.ids[0]]
        e0.block = True
        ar = e0.execute("5", silent=False)
        er = ar.get()
        # repr embeds the execution counter and the pyout text
        self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
        # pyout carries a mime-bundle; text/plain is the repr of the expression
        self.assertEqual(er.pyout['data']['text/plain'], '5')
526
526
    def test_execute_reply_stdout(self):
        """stdout produced during execute is captured on the ExecuteReply."""
        e0 = self.client[self.client.ids[0]]
        e0.block = True
        ar = e0.execute("print (5)", silent=False)
        er = ar.get()
        # strip() because the engine's print appends a newline
        self.assertEqual(er.stdout.strip(), '5')
533
533
    def test_execute_pyout(self):
        """execute triggers pyout with silent=False"""
        view = self.client[:]
        ar = view.execute("5", silent=False, block=True)

        # one pyout mime-bundle per engine, each with the plain-text repr
        expected = [{'text/plain' : '5'}] * len(view)
        mimes = [ out['data'] for out in ar.pyout ]
        self.assertEqual(mimes, expected)
542
542
    def test_execute_silent(self):
        """execute does not trigger pyout with silent=True"""
        view = self.client[:]
        # silent defaults to True, so no pyout should be published
        ar = view.execute("5", block=True)
        expected = [None] * len(view)
        self.assertEqual(ar.pyout, expected)
549
549
    def test_execute_magic(self):
        """execute accepts IPython commands"""
        view = self.client[:]
        view.execute("a = 5")
        ar = view.execute("%whos", block=True)
        # this will raise, if that failed
        ar.get(5)
        # Parse each engine's %whos table: header row, separator row,
        # then one row per variable.
        for stdout in ar.stdout:
            lines = stdout.splitlines()
            self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
            found = False
            for line in lines[2:]:
                split = line.split()
                if split == ['a', 'int', '5']:
                    found = True
                    break
            self.assertTrue(found, "whos output wrong: %s" % stdout)
567
567
    def test_execute_displaypub(self):
        """execute tracks display_pub output"""
        view = self.client[:]
        view.execute("from IPython.core.display import *")
        ar = view.execute("[ display(i) for i in range(5) ]", block=True)

        # each engine publishes five display outputs, 0..4, as mime-bundles
        expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
        for outputs in ar.outputs:
            mimes = [ out['data'] for out in outputs ]
            self.assertEqual(mimes, expected)
578
578
    def test_apply_displaypub(self):
        """apply tracks display_pub output"""
        view = self.client[:]
        view.execute("from IPython.core.display import *")

        @interactive
        def publish():
            # runs remotely; relies on display() imported on the engines above
            [ display(i) for i in range(5) ]

        ar = view.apply_async(publish)
        ar.get(5)
        # same expectation as the execute-based displaypub test:
        # five published mime-bundles per engine
        expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
        for outputs in ar.outputs:
            mimes = [ out['data'] for out in outputs ]
            self.assertEqual(mimes, expected)
594
594
595 def test_execute_raises(self):
595 def test_execute_raises(self):
596 """exceptions in execute requests raise appropriately"""
596 """exceptions in execute requests raise appropriately"""
597 view = self.client[-1]
597 view = self.client[-1]
598 ar = view.execute("1/0")
598 ar = view.execute("1/0")
599 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
599 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
600
600
    def test_remoteerror_render_exception(self):
        """RemoteErrors get nice tracebacks"""
        view = self.client[-1]
        ar = view.execute("1/0")
        ip = get_ipython()
        # run ar.get inside IPython so its exception-rendering hooks fire
        ip.user_ns['ar'] = ar
        with capture_output() as io:
            ip.run_cell("ar.get(2)")

        # the rendered traceback should name the remote exception
        self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
611
611
    def test_compositeerror_render_exception(self):
        """CompositeErrors get nice tracebacks"""
        view = self.client[:]
        ar = view.execute("1/0")
        ip = get_ipython()
        # run ar.get inside IPython so CompositeError's renderer is used
        ip.user_ns['ar'] = ar

        with capture_output() as io:
            ip.run_cell("ar.get(2)")

        # rendering truncates at tb_limit entries even with more engines
        count = min(error.CompositeError.tb_limit, len(view))

        # each rendered entry names the exception twice (header + traceback)
        self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
        self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
        self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
627
627
    def test_compositeerror_truncate(self):
        """Truncate CompositeErrors with many exceptions"""
        view = self.client[:]
        # accumulate many failing requests so the error count exceeds tb_limit
        msg_ids = []
        for i in range(10):
            ar = view.execute("1/0")
            msg_ids.extend(ar.msg_ids)

        ar = self.client.get_result(msg_ids)
        try:
            ar.get()
        except error.CompositeError as _e:
            e = _e
        else:
            self.fail("Should have raised CompositeError")

        lines = e.render_traceback()
        with capture_output() as io:
            e.print_traceback()

        # the rendering must announce that it truncated the list
        self.assertTrue("more exceptions" in lines[-1])
        count = e.tb_limit

        # only tb_limit entries are printed despite 10 * len(view) failures
        self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
        self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
        self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
654
654
    @dec.skipif_not_matplotlib
    def test_magic_pylab(self):
        """%pylab works on engines"""
        view = self.client[-1]
        ar = view.execute("%pylab inline")
        # at least check if this raised:
        reply = ar.get(5)
        # include imports, in case user config
        ar = view.execute("plot(rand(100))", silent=False)
        reply = ar.get(5)
        # inline backend should publish exactly one output: a PNG figure
        self.assertEqual(len(reply.outputs), 1)
        output = reply.outputs[0]
        self.assertTrue("data" in output)
        data = output['data']
        self.assertTrue("image/png" in data)
670
670
    def test_func_default_func(self):
        """interactively defined function as apply func default"""
        def foo():
            return 'foo'

        # foo is captured as bar's default argument and must be canned with it
        def bar(f=foo):
            return f()

        view = self.client[-1]
        ar = view.apply_async(bar)
        r = ar.get(10)
        self.assertEqual(r, 'foo')
    def test_data_pub_single(self):
        """publish_data on one engine streams dicts onto AsyncResult.data."""
        view = self.client[-1]
        ar = view.execute('\n'.join([
            'from IPython.kernel.zmq.datapub import publish_data',
            'for i in range(5):',
            '    publish_data(dict(i=i))'
        ]), block=False)
        # data is available (as a dict) even before the request completes
        self.assertTrue(isinstance(ar.data, dict))
        ar.get(5)
        # after completion, data holds the last published payload
        self.assertEqual(ar.data, dict(i=4))
693
693
    def test_data_pub(self):
        """publish_data on all engines yields one data dict per engine."""
        view = self.client[:]
        ar = view.execute('\n'.join([
            'from IPython.kernel.zmq.datapub import publish_data',
            'for i in range(5):',
            '    publish_data(dict(i=i))'
        ]), block=False)
        # multi-engine result: data is a list of per-engine dicts
        self.assertTrue(all(isinstance(d, dict) for d in ar.data))
        ar.get(5)
        # every engine ends with the last published payload
        self.assertEqual(ar.data, [dict(i=4)] * len(ar))
704
704
705 def test_can_list_arg(self):
705 def test_can_list_arg(self):
706 """args in lists are canned"""
706 """args in lists are canned"""
707 view = self.client[-1]
707 view = self.client[-1]
708 view['a'] = 128
708 view['a'] = 128
709 rA = pmod.Reference('a')
709 rA = pmod.Reference('a')
710 ar = view.apply_async(lambda x: x, [rA])
710 ar = view.apply_async(lambda x: x, [rA])
711 r = ar.get(5)
711 r = ar.get(5)
712 self.assertEqual(r, [128])
712 self.assertEqual(r, [128])
713
713
714 def test_can_dict_arg(self):
714 def test_can_dict_arg(self):
715 """args in dicts are canned"""
715 """args in dicts are canned"""
716 view = self.client[-1]
716 view = self.client[-1]
717 view['a'] = 128
717 view['a'] = 128
718 rA = pmod.Reference('a')
718 rA = pmod.Reference('a')
719 ar = view.apply_async(lambda x: x, dict(foo=rA))
719 ar = view.apply_async(lambda x: x, dict(foo=rA))
720 r = ar.get(5)
720 r = ar.get(5)
721 self.assertEqual(r, dict(foo=128))
721 self.assertEqual(r, dict(foo=128))
722
722
723 def test_can_list_kwarg(self):
723 def test_can_list_kwarg(self):
724 """kwargs in lists are canned"""
724 """kwargs in lists are canned"""
725 view = self.client[-1]
725 view = self.client[-1]
726 view['a'] = 128
726 view['a'] = 128
727 rA = pmod.Reference('a')
727 rA = pmod.Reference('a')
728 ar = view.apply_async(lambda x=5: x, x=[rA])
728 ar = view.apply_async(lambda x=5: x, x=[rA])
729 r = ar.get(5)
729 r = ar.get(5)
730 self.assertEqual(r, [128])
730 self.assertEqual(r, [128])
731
731
732 def test_can_dict_kwarg(self):
732 def test_can_dict_kwarg(self):
733 """kwargs in dicts are canned"""
733 """kwargs in dicts are canned"""
734 view = self.client[-1]
734 view = self.client[-1]
735 view['a'] = 128
735 view['a'] = 128
736 rA = pmod.Reference('a')
736 rA = pmod.Reference('a')
737 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
737 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
738 r = ar.get(5)
738 r = ar.get(5)
739 self.assertEqual(r, dict(foo=128))
739 self.assertEqual(r, dict(foo=128))
740
740
741 def test_map_ref(self):
741 def test_map_ref(self):
742 """view.map works with references"""
742 """view.map works with references"""
743 view = self.client[:]
743 view = self.client[:]
744 ranks = sorted(self.client.ids)
744 ranks = sorted(self.client.ids)
745 view.scatter('rank', ranks, flatten=True)
745 view.scatter('rank', ranks, flatten=True)
746 rrank = pmod.Reference('rank')
746 rrank = pmod.Reference('rank')
747
747
748 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
748 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
749 drank = amr.get(5)
749 drank = amr.get(5)
750 self.assertEqual(drank, [ r*2 for r in ranks ])
750 self.assertEqual(drank, [ r*2 for r in ranks ])
751
751
    def test_nested_getitem_setitem(self):
        """get and set with view['a.b']"""
        view = self.client[-1]
        # build a remote object with a nested attribute to address as 'a.b'
        view.execute('\n'.join([
            'class A(object): pass',
            'a = A()',
            'a.b = 128',
        ]), block=True)
        ra = pmod.Reference('a')

        # check both access paths agree: apply with a Reference, and view getitem
        r = view.apply_sync(lambda x: x.b, ra)
        self.assertEqual(r, 128)
        self.assertEqual(view['a.b'], 128)

        # nested setitem writes through to the remote attribute
        view['a.b'] = 0

        r = view.apply_sync(lambda x: x.b, ra)
        self.assertEqual(r, 0)
        self.assertEqual(view['a.b'], 0)
771
771
    def test_return_namedtuple(self):
        """A namedtuple returned from an engine round-trips with fields intact."""
        def namedtuplify(x, y):
            # import on the engine so the namedtuple class resolves remotely
            from IPython.parallel.tests.test_view import point
            return point(x, y)

        view = self.client[-1]
        p = view.apply_sync(namedtuplify, 1, 2)
        self.assertEqual(p.x, 1)
        self.assertEqual(p.y, 2)
781
781
    def test_apply_namedtuple(self):
        """A namedtuple argument arrives on the engine with fields intact."""
        def echoxy(p):
            # swap the fields so the test detects field-order corruption
            return p.y, p.x

        view = self.client[-1]
        tup = view.apply_sync(echoxy, point(1, 2))
        self.assertEqual(tup, (2,1))
789
789
General Comments 0
You need to be logged in to leave comments. Login now