store whole content, instead of just display data, in metadata.outputs
MinRK
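With this change, `metadata['pyout']` and each entry of `metadata['outputs']` hold the full IOPub message content (including keys such as `execution_count` and `metadata`), rather than just the `data` mime-bundle, so consumers now read display reprs via the `data` key; the `ExecuteReply._repr_*_` helpers in the diff below are updated accordingly. A minimal sketch of the before/after access pattern (the sample content dict is made up for illustration, not taken from a real run):

# Hypothetical pyout message content, shaped like what _flush_iopub receives.
pyout_content = {
    'execution_count': 3,
    'data': {'text/plain': '6', 'text/html': '<b>6</b>'},
    'metadata': {},
}

# Old behaviour: only the mime-bundle was kept.
md_old = {'pyout': pyout_content.get('data')}
assert md_old['pyout']['text/plain'] == '6'

# New behaviour: the whole content dict is kept, so reprs live under 'data'
# and the rest of the message (e.g. execution_count) is preserved as well.
md_new = {'pyout': pyout_content}
assert md_new['pyout']['data']['text/plain'] == '6'
assert md_new['pyout']['execution_count'] == 3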
@@ -1,1628 +1,1628 @@
1 """A semi-synchronous Client for the ZMQ cluster
1 """A semi-synchronous Client for the ZMQ cluster
2
2
3 Authors:
3 Authors:
4
4
5 * MinRK
5 * MinRK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import os
18 import os
19 import json
19 import json
20 import sys
20 import sys
21 from threading import Thread, Event
21 from threading import Thread, Event
22 import time
22 import time
23 import warnings
23 import warnings
24 from datetime import datetime
24 from datetime import datetime
25 from getpass import getpass
25 from getpass import getpass
26 from pprint import pprint
26 from pprint import pprint
27
27
28 pjoin = os.path.join
28 pjoin = os.path.join
29
29
30 import zmq
30 import zmq
31 # from zmq.eventloop import ioloop, zmqstream
31 # from zmq.eventloop import ioloop, zmqstream
32
32
33 from IPython.config.configurable import MultipleInstanceError
33 from IPython.config.configurable import MultipleInstanceError
34 from IPython.core.application import BaseIPythonApplication
34 from IPython.core.application import BaseIPythonApplication
35
35
36 from IPython.utils.jsonutil import rekey
36 from IPython.utils.jsonutil import rekey
37 from IPython.utils.localinterfaces import LOCAL_IPS
37 from IPython.utils.localinterfaces import LOCAL_IPS
38 from IPython.utils.path import get_ipython_dir
38 from IPython.utils.path import get_ipython_dir
39 from IPython.utils.py3compat import cast_bytes
39 from IPython.utils.py3compat import cast_bytes
40 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
40 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
41 Dict, List, Bool, Set, Any)
41 Dict, List, Bool, Set, Any)
42 from IPython.external.decorator import decorator
42 from IPython.external.decorator import decorator
43 from IPython.external.ssh import tunnel
43 from IPython.external.ssh import tunnel
44
44
45 from IPython.parallel import Reference
45 from IPython.parallel import Reference
46 from IPython.parallel import error
46 from IPython.parallel import error
47 from IPython.parallel import util
47 from IPython.parallel import util
48
48
49 from IPython.zmq.session import Session, Message
49 from IPython.zmq.session import Session, Message
50
50
51 from .asyncresult import AsyncResult, AsyncHubResult
51 from .asyncresult import AsyncResult, AsyncHubResult
52 from IPython.core.profiledir import ProfileDir, ProfileDirError
52 from IPython.core.profiledir import ProfileDir, ProfileDirError
53 from .view import DirectView, LoadBalancedView
53 from .view import DirectView, LoadBalancedView
54
54
55 if sys.version_info[0] >= 3:
55 if sys.version_info[0] >= 3:
56 # xrange is used in a couple 'isinstance' tests in py2
56 # xrange is used in a couple 'isinstance' tests in py2
57 # should be just 'range' in 3k
57 # should be just 'range' in 3k
58 xrange = range
58 xrange = range
59
59
60 #--------------------------------------------------------------------------
60 #--------------------------------------------------------------------------
61 # Decorators for Client methods
61 # Decorators for Client methods
62 #--------------------------------------------------------------------------
62 #--------------------------------------------------------------------------
63
63
64 @decorator
64 @decorator
65 def spin_first(f, self, *args, **kwargs):
65 def spin_first(f, self, *args, **kwargs):
66 """Call spin() to sync state prior to calling the method."""
66 """Call spin() to sync state prior to calling the method."""
67 self.spin()
67 self.spin()
68 return f(self, *args, **kwargs)
68 return f(self, *args, **kwargs)
69
69
70
70
71 #--------------------------------------------------------------------------
71 #--------------------------------------------------------------------------
72 # Classes
72 # Classes
73 #--------------------------------------------------------------------------
73 #--------------------------------------------------------------------------
74
74
75
75
76 class ExecuteReply(object):
76 class ExecuteReply(object):
77 """wrapper for finished Execute results"""
77 """wrapper for finished Execute results"""
78 def __init__(self, msg_id, content, metadata):
78 def __init__(self, msg_id, content, metadata):
79 self.msg_id = msg_id
79 self.msg_id = msg_id
80 self._content = content
80 self._content = content
81 self.execution_count = content['execution_count']
81 self.execution_count = content['execution_count']
82 self.metadata = metadata
82 self.metadata = metadata
83
83
84 def __getitem__(self, key):
84 def __getitem__(self, key):
85 return self.metadata[key]
85 return self.metadata[key]
86
86
87 def __getattr__(self, key):
87 def __getattr__(self, key):
88 if key not in self.metadata:
88 if key not in self.metadata:
89 raise AttributeError(key)
89 raise AttributeError(key)
90 return self.metadata[key]
90 return self.metadata[key]
91
91
92 def __repr__(self):
92 def __repr__(self):
93 pyout = self.metadata['pyout'] or {}
93 pyout = self.metadata['pyout'] or {}
94 text_out = pyout.get('text/plain', '')
94 text_out = pyout.get('data', {}).get('text/plain', '')
95 if len(text_out) > 32:
95 if len(text_out) > 32:
96 text_out = text_out[:29] + '...'
96 text_out = text_out[:29] + '...'
97
97
98 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
98 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
99
99
100 def _repr_html_(self):
100 def _repr_html_(self):
101 pyout = self.metadata['pyout'] or {}
101 pyout = self.metadata['pyout'] or {'data':{}}
102 return pyout.get("text/html")
102 return pyout['data'].get("text/html")
103
103
104 def _repr_latex_(self):
104 def _repr_latex_(self):
105 pyout = self.metadata['pyout'] or {}
105 pyout = self.metadata['pyout'] or {'data':{}}
106 return pyout.get("text/latex")
106 return pyout['data'].get("text/latex")
107
107
108 def _repr_json_(self):
108 def _repr_json_(self):
109 pyout = self.metadata['pyout'] or {}
109 pyout = self.metadata['pyout'] or {'data':{}}
110 return pyout.get("application/json")
110 return pyout['data'].get("application/json")
111
111
112 def _repr_javascript_(self):
112 def _repr_javascript_(self):
113 pyout = self.metadata['pyout'] or {}
113 pyout = self.metadata['pyout'] or {'data':{}}
114 return pyout.get("application/javascript")
114 return pyout['data'].get("application/javascript")
115
115
116 def _repr_png_(self):
116 def _repr_png_(self):
117 pyout = self.metadata['pyout'] or {}
117 pyout = self.metadata['pyout'] or {'data':{}}
118 return pyout.get("image/png")
118 return pyout['data'].get("image/png")
119
119
120 def _repr_jpeg_(self):
120 def _repr_jpeg_(self):
121 pyout = self.metadata['pyout'] or {}
121 pyout = self.metadata['pyout'] or {'data':{}}
122 return pyout.get("image/jpeg")
122 return pyout['data'].get("image/jpeg")
123
123
124 def _repr_svg_(self):
124 def _repr_svg_(self):
125 pyout = self.metadata['pyout'] or {}
125 pyout = self.metadata['pyout'] or {'data':{}}
126 return pyout.get("image/svg+xml")
126 return pyout['data'].get("image/svg+xml")
127
127
128
128
129 class Metadata(dict):
129 class Metadata(dict):
130 """Subclass of dict for initializing metadata values.
130 """Subclass of dict for initializing metadata values.
131
131
132 Attribute access works on keys.
132 Attribute access works on keys.
133
133
134 These objects have a strict set of keys - errors will raise if you try
134 These objects have a strict set of keys - errors will raise if you try
135 to add new keys.
135 to add new keys.
136 """
136 """
137 def __init__(self, *args, **kwargs):
137 def __init__(self, *args, **kwargs):
138 dict.__init__(self)
138 dict.__init__(self)
139 md = {'msg_id' : None,
139 md = {'msg_id' : None,
140 'submitted' : None,
140 'submitted' : None,
141 'started' : None,
141 'started' : None,
142 'completed' : None,
142 'completed' : None,
143 'received' : None,
143 'received' : None,
144 'engine_uuid' : None,
144 'engine_uuid' : None,
145 'engine_id' : None,
145 'engine_id' : None,
146 'follow' : None,
146 'follow' : None,
147 'after' : None,
147 'after' : None,
148 'status' : None,
148 'status' : None,
149
149
150 'pyin' : None,
150 'pyin' : None,
151 'pyout' : None,
151 'pyout' : None,
152 'pyerr' : None,
152 'pyerr' : None,
153 'stdout' : '',
153 'stdout' : '',
154 'stderr' : '',
154 'stderr' : '',
155 'outputs' : [],
155 'outputs' : [],
156 }
156 }
157 self.update(md)
157 self.update(md)
158 self.update(dict(*args, **kwargs))
158 self.update(dict(*args, **kwargs))
159
159
160 def __getattr__(self, key):
160 def __getattr__(self, key):
161 """getattr aliased to getitem"""
161 """getattr aliased to getitem"""
162 if key in self.iterkeys():
162 if key in self.iterkeys():
163 return self[key]
163 return self[key]
164 else:
164 else:
165 raise AttributeError(key)
165 raise AttributeError(key)
166
166
167 def __setattr__(self, key, value):
167 def __setattr__(self, key, value):
168 """setattr aliased to setitem, with strict"""
168 """setattr aliased to setitem, with strict"""
169 if key in self.iterkeys():
169 if key in self.iterkeys():
170 self[key] = value
170 self[key] = value
171 else:
171 else:
172 raise AttributeError(key)
172 raise AttributeError(key)
173
173
174 def __setitem__(self, key, value):
174 def __setitem__(self, key, value):
175 """strict static key enforcement"""
175 """strict static key enforcement"""
176 if key in self.iterkeys():
176 if key in self.iterkeys():
177 dict.__setitem__(self, key, value)
177 dict.__setitem__(self, key, value)
178 else:
178 else:
179 raise KeyError(key)
179 raise KeyError(key)
180
180
181
181
182 class Client(HasTraits):
182 class Client(HasTraits):
183 """A semi-synchronous client to the IPython ZMQ cluster
183 """A semi-synchronous client to the IPython ZMQ cluster
184
184
185 Parameters
185 Parameters
186 ----------
186 ----------
187
187
188 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
188 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
189 Connection information for the Hub's registration. If a json connector
189 Connection information for the Hub's registration. If a json connector
190 file is given, then likely no further configuration is necessary.
190 file is given, then likely no further configuration is necessary.
191 [Default: use profile]
191 [Default: use profile]
192 profile : bytes
192 profile : bytes
193 The name of the Cluster profile to be used to find connector information.
193 The name of the Cluster profile to be used to find connector information.
194 If run from an IPython application, the default profile will be the same
194 If run from an IPython application, the default profile will be the same
195 as the running application, otherwise it will be 'default'.
195 as the running application, otherwise it will be 'default'.
196 context : zmq.Context
196 context : zmq.Context
197 Pass an existing zmq.Context instance, otherwise the client will create its own.
197 Pass an existing zmq.Context instance, otherwise the client will create its own.
198 debug : bool
198 debug : bool
199 flag for lots of message printing for debug purposes
199 flag for lots of message printing for debug purposes
200 timeout : int/float
200 timeout : int/float
201 time (in seconds) to wait for connection replies from the Hub
201 time (in seconds) to wait for connection replies from the Hub
202 [Default: 10]
202 [Default: 10]
203
203
204 #-------------- session related args ----------------
204 #-------------- session related args ----------------
205
205
206 config : Config object
206 config : Config object
207 If specified, this will be relayed to the Session for configuration
207 If specified, this will be relayed to the Session for configuration
208 username : str
208 username : str
209 set username for the session object
209 set username for the session object
210 packer : str (import_string) or callable
210 packer : str (import_string) or callable
211 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
211 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
212 function to serialize messages. Must support same input as
212 function to serialize messages. Must support same input as
213 JSON, and output must be bytes.
213 JSON, and output must be bytes.
214 You can pass a callable directly as `pack`
214 You can pass a callable directly as `pack`
215 unpacker : str (import_string) or callable
215 unpacker : str (import_string) or callable
216 The inverse of packer. Only necessary if packer is specified as *not* one
216 The inverse of packer. Only necessary if packer is specified as *not* one
217 of 'json' or 'pickle'.
217 of 'json' or 'pickle'.
218
218
219 #-------------- ssh related args ----------------
219 #-------------- ssh related args ----------------
220 # These are args for configuring the ssh tunnel to be used
220 # These are args for configuring the ssh tunnel to be used
221 # credentials are used to forward connections over ssh to the Controller
221 # credentials are used to forward connections over ssh to the Controller
222 # Note that the ip given in `addr` needs to be relative to sshserver
222 # Note that the ip given in `addr` needs to be relative to sshserver
223 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
223 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
224 # and set sshserver as the same machine the Controller is on. However,
224 # and set sshserver as the same machine the Controller is on. However,
225 # the only requirement is that sshserver is able to see the Controller
225 # the only requirement is that sshserver is able to see the Controller
226 # (i.e. is within the same trusted network).
226 # (i.e. is within the same trusted network).
227
227
228 sshserver : str
228 sshserver : str
229 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
229 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
230 If keyfile or password is specified, and this is not, it will default to
230 If keyfile or password is specified, and this is not, it will default to
231 the ip given in addr.
231 the ip given in addr.
232 sshkey : str; path to ssh private key file
232 sshkey : str; path to ssh private key file
233 This specifies a key to be used in ssh login, default None.
233 This specifies a key to be used in ssh login, default None.
234 Regular default ssh keys will be used without specifying this argument.
234 Regular default ssh keys will be used without specifying this argument.
235 password : str
235 password : str
236 Your ssh password to sshserver. Note that if this is left None,
236 Your ssh password to sshserver. Note that if this is left None,
237 you will be prompted for it if passwordless key based login is unavailable.
237 you will be prompted for it if passwordless key based login is unavailable.
238 paramiko : bool
238 paramiko : bool
239 flag for whether to use paramiko instead of shell ssh for tunneling.
239 flag for whether to use paramiko instead of shell ssh for tunneling.
240 [default: True on win32, False else]
240 [default: True on win32, False else]
241
241
242 ------- exec authentication args -------
242 ------- exec authentication args -------
243 If even localhost is untrusted, you can have some protection against
243 If even localhost is untrusted, you can have some protection against
244 unauthorized execution by signing messages with HMAC digests.
244 unauthorized execution by signing messages with HMAC digests.
245 Messages are still sent as cleartext, so if someone can snoop your
245 Messages are still sent as cleartext, so if someone can snoop your
246 loopback traffic this will not protect your privacy, but will prevent
246 loopback traffic this will not protect your privacy, but will prevent
247 unauthorized execution.
247 unauthorized execution.
248
248
249 exec_key : str
249 exec_key : str
250 an authentication key or file containing a key
250 an authentication key or file containing a key
251 default: None
251 default: None
252
252
253
253
254 Attributes
254 Attributes
255 ----------
255 ----------
256
256
257 ids : list of int engine IDs
257 ids : list of int engine IDs
258 requesting the ids attribute always synchronizes
258 requesting the ids attribute always synchronizes
259 the registration state. To request ids without synchronization,
259 the registration state. To request ids without synchronization,
260 use semi-private _ids attributes.
260 use semi-private _ids attributes.
261
261
262 history : list of msg_ids
262 history : list of msg_ids
263 a list of msg_ids, keeping track of all the execution
263 a list of msg_ids, keeping track of all the execution
264 messages you have submitted in order.
264 messages you have submitted in order.
265
265
266 outstanding : set of msg_ids
266 outstanding : set of msg_ids
267 a set of msg_ids that have been submitted, but whose
267 a set of msg_ids that have been submitted, but whose
268 results have not yet been received.
268 results have not yet been received.
269
269
270 results : dict
270 results : dict
271 a dict of all our results, keyed by msg_id
271 a dict of all our results, keyed by msg_id
272
272
273 block : bool
273 block : bool
274 determines default behavior when block not specified
274 determines default behavior when block not specified
275 in execution methods
275 in execution methods
276
276
277 Methods
277 Methods
278 -------
278 -------
279
279
280 spin
280 spin
281 flushes incoming results and registration state changes
281 flushes incoming results and registration state changes
282 control methods spin, and requesting `ids` also ensures up to date
282 control methods spin, and requesting `ids` also ensures up to date
283
283
284 wait
284 wait
285 wait on one or more msg_ids
285 wait on one or more msg_ids
286
286
287 execution methods
287 execution methods
288 apply
288 apply
289 legacy: execute, run
289 legacy: execute, run
290
290
291 data movement
291 data movement
292 push, pull, scatter, gather
292 push, pull, scatter, gather
293
293
294 query methods
294 query methods
295 queue_status, get_result, purge, result_status
295 queue_status, get_result, purge, result_status
296
296
297 control methods
297 control methods
298 abort, shutdown
298 abort, shutdown
299
299
300 """
300 """
301
301
302
302
303 block = Bool(False)
303 block = Bool(False)
304 outstanding = Set()
304 outstanding = Set()
305 results = Instance('collections.defaultdict', (dict,))
305 results = Instance('collections.defaultdict', (dict,))
306 metadata = Instance('collections.defaultdict', (Metadata,))
306 metadata = Instance('collections.defaultdict', (Metadata,))
307 history = List()
307 history = List()
308 debug = Bool(False)
308 debug = Bool(False)
309 _spin_thread = Any()
309 _spin_thread = Any()
310 _stop_spinning = Any()
310 _stop_spinning = Any()
311
311
312 profile=Unicode()
312 profile=Unicode()
313 def _profile_default(self):
313 def _profile_default(self):
314 if BaseIPythonApplication.initialized():
314 if BaseIPythonApplication.initialized():
315 # an IPython app *might* be running, try to get its profile
315 # an IPython app *might* be running, try to get its profile
316 try:
316 try:
317 return BaseIPythonApplication.instance().profile
317 return BaseIPythonApplication.instance().profile
318 except (AttributeError, MultipleInstanceError):
318 except (AttributeError, MultipleInstanceError):
319 # could be a *different* subclass of config.Application,
319 # could be a *different* subclass of config.Application,
320 # which would raise one of these two errors.
320 # which would raise one of these two errors.
321 return u'default'
321 return u'default'
322 else:
322 else:
323 return u'default'
323 return u'default'
324
324
325
325
326 _outstanding_dict = Instance('collections.defaultdict', (set,))
326 _outstanding_dict = Instance('collections.defaultdict', (set,))
327 _ids = List()
327 _ids = List()
328 _connected=Bool(False)
328 _connected=Bool(False)
329 _ssh=Bool(False)
329 _ssh=Bool(False)
330 _context = Instance('zmq.Context')
330 _context = Instance('zmq.Context')
331 _config = Dict()
331 _config = Dict()
332 _engines=Instance(util.ReverseDict, (), {})
332 _engines=Instance(util.ReverseDict, (), {})
333 # _hub_socket=Instance('zmq.Socket')
333 # _hub_socket=Instance('zmq.Socket')
334 _query_socket=Instance('zmq.Socket')
334 _query_socket=Instance('zmq.Socket')
335 _control_socket=Instance('zmq.Socket')
335 _control_socket=Instance('zmq.Socket')
336 _iopub_socket=Instance('zmq.Socket')
336 _iopub_socket=Instance('zmq.Socket')
337 _notification_socket=Instance('zmq.Socket')
337 _notification_socket=Instance('zmq.Socket')
338 _mux_socket=Instance('zmq.Socket')
338 _mux_socket=Instance('zmq.Socket')
339 _task_socket=Instance('zmq.Socket')
339 _task_socket=Instance('zmq.Socket')
340 _task_scheme=Unicode()
340 _task_scheme=Unicode()
341 _closed = False
341 _closed = False
342 _ignored_control_replies=Integer(0)
342 _ignored_control_replies=Integer(0)
343 _ignored_hub_replies=Integer(0)
343 _ignored_hub_replies=Integer(0)
344
344
345 def __new__(self, *args, **kw):
345 def __new__(self, *args, **kw):
346 # don't raise on positional args
346 # don't raise on positional args
347 return HasTraits.__new__(self, **kw)
347 return HasTraits.__new__(self, **kw)
348
348
349 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
349 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
350 context=None, debug=False, exec_key=None,
350 context=None, debug=False, exec_key=None,
351 sshserver=None, sshkey=None, password=None, paramiko=None,
351 sshserver=None, sshkey=None, password=None, paramiko=None,
352 timeout=10, **extra_args
352 timeout=10, **extra_args
353 ):
353 ):
354 if profile:
354 if profile:
355 super(Client, self).__init__(debug=debug, profile=profile)
355 super(Client, self).__init__(debug=debug, profile=profile)
356 else:
356 else:
357 super(Client, self).__init__(debug=debug)
357 super(Client, self).__init__(debug=debug)
358 if context is None:
358 if context is None:
359 context = zmq.Context.instance()
359 context = zmq.Context.instance()
360 self._context = context
360 self._context = context
361 self._stop_spinning = Event()
361 self._stop_spinning = Event()
362
362
363 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
363 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
364 if self._cd is not None:
364 if self._cd is not None:
365 if url_or_file is None:
365 if url_or_file is None:
366 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
366 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
367 if url_or_file is None:
367 if url_or_file is None:
368 raise ValueError(
368 raise ValueError(
369 "I can't find enough information to connect to a hub!"
369 "I can't find enough information to connect to a hub!"
370 " Please specify at least one of url_or_file or profile."
370 " Please specify at least one of url_or_file or profile."
371 )
371 )
372
372
373 if not util.is_url(url_or_file):
373 if not util.is_url(url_or_file):
374 # it's not a url, try for a file
374 # it's not a url, try for a file
375 if not os.path.exists(url_or_file):
375 if not os.path.exists(url_or_file):
376 if self._cd:
376 if self._cd:
377 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
377 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
378 if not os.path.exists(url_or_file):
378 if not os.path.exists(url_or_file):
379 raise IOError("Connection file not found: %r" % url_or_file)
379 raise IOError("Connection file not found: %r" % url_or_file)
380 with open(url_or_file) as f:
380 with open(url_or_file) as f:
381 cfg = json.loads(f.read())
381 cfg = json.loads(f.read())
382 else:
382 else:
383 cfg = {'url':url_or_file}
383 cfg = {'url':url_or_file}
384
384
385 # sync defaults from args, json:
385 # sync defaults from args, json:
386 if sshserver:
386 if sshserver:
387 cfg['ssh'] = sshserver
387 cfg['ssh'] = sshserver
388 if exec_key:
388 if exec_key:
389 cfg['exec_key'] = exec_key
389 cfg['exec_key'] = exec_key
390 exec_key = cfg['exec_key']
390 exec_key = cfg['exec_key']
391 location = cfg.setdefault('location', None)
391 location = cfg.setdefault('location', None)
392 cfg['url'] = util.disambiguate_url(cfg['url'], location)
392 cfg['url'] = util.disambiguate_url(cfg['url'], location)
393 url = cfg['url']
393 url = cfg['url']
394 proto,addr,port = util.split_url(url)
394 proto,addr,port = util.split_url(url)
395 if location is not None and addr == '127.0.0.1':
395 if location is not None and addr == '127.0.0.1':
396 # location specified, and connection is expected to be local
396 # location specified, and connection is expected to be local
397 if location not in LOCAL_IPS and not sshserver:
397 if location not in LOCAL_IPS and not sshserver:
398 # load ssh from JSON *only* if the controller is not on
398 # load ssh from JSON *only* if the controller is not on
399 # this machine
399 # this machine
400 sshserver=cfg['ssh']
400 sshserver=cfg['ssh']
401 if location not in LOCAL_IPS and not sshserver:
401 if location not in LOCAL_IPS and not sshserver:
402 # warn if no ssh specified, but SSH is probably needed
402 # warn if no ssh specified, but SSH is probably needed
403 # This is only a warning, because the most likely cause
403 # This is only a warning, because the most likely cause
404 # is a local Controller on a laptop whose IP is dynamic
404 # is a local Controller on a laptop whose IP is dynamic
405 warnings.warn("""
405 warnings.warn("""
406 Controller appears to be listening on localhost, but not on this machine.
406 Controller appears to be listening on localhost, but not on this machine.
407 If this is true, you should specify Client(...,sshserver='you@%s')
407 If this is true, you should specify Client(...,sshserver='you@%s')
408 or instruct your controller to listen on an external IP."""%location,
408 or instruct your controller to listen on an external IP."""%location,
409 RuntimeWarning)
409 RuntimeWarning)
410 elif not sshserver:
410 elif not sshserver:
411 # otherwise sync with cfg
411 # otherwise sync with cfg
412 sshserver = cfg['ssh']
412 sshserver = cfg['ssh']
413
413
414 self._config = cfg
414 self._config = cfg
415
415
416 self._ssh = bool(sshserver or sshkey or password)
416 self._ssh = bool(sshserver or sshkey or password)
417 if self._ssh and sshserver is None:
417 if self._ssh and sshserver is None:
418 # default to ssh via localhost
418 # default to ssh via localhost
419 sshserver = url.split('://')[1].split(':')[0]
419 sshserver = url.split('://')[1].split(':')[0]
420 if self._ssh and password is None:
420 if self._ssh and password is None:
421 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
421 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
422 password=False
422 password=False
423 else:
423 else:
424 password = getpass("SSH Password for %s: "%sshserver)
424 password = getpass("SSH Password for %s: "%sshserver)
425 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
425 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
426
426
427 # configure and construct the session
427 # configure and construct the session
428 if exec_key is not None:
428 if exec_key is not None:
429 if os.path.isfile(exec_key):
429 if os.path.isfile(exec_key):
430 extra_args['keyfile'] = exec_key
430 extra_args['keyfile'] = exec_key
431 else:
431 else:
432 exec_key = cast_bytes(exec_key)
432 exec_key = cast_bytes(exec_key)
433 extra_args['key'] = exec_key
433 extra_args['key'] = exec_key
434 self.session = Session(**extra_args)
434 self.session = Session(**extra_args)
435
435
436 self._query_socket = self._context.socket(zmq.DEALER)
436 self._query_socket = self._context.socket(zmq.DEALER)
437 self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
437 self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
438 if self._ssh:
438 if self._ssh:
439 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
439 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
440 else:
440 else:
441 self._query_socket.connect(url)
441 self._query_socket.connect(url)
442
442
443 self.session.debug = self.debug
443 self.session.debug = self.debug
444
444
445 self._notification_handlers = {'registration_notification' : self._register_engine,
445 self._notification_handlers = {'registration_notification' : self._register_engine,
446 'unregistration_notification' : self._unregister_engine,
446 'unregistration_notification' : self._unregister_engine,
447 'shutdown_notification' : lambda msg: self.close(),
447 'shutdown_notification' : lambda msg: self.close(),
448 }
448 }
449 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
449 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
450 'apply_reply' : self._handle_apply_reply}
450 'apply_reply' : self._handle_apply_reply}
451 self._connect(sshserver, ssh_kwargs, timeout)
451 self._connect(sshserver, ssh_kwargs, timeout)
452
452
453 def __del__(self):
453 def __del__(self):
454 """cleanup sockets, but _not_ context."""
454 """cleanup sockets, but _not_ context."""
455 self.close()
455 self.close()
456
456
457 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
457 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
458 if ipython_dir is None:
458 if ipython_dir is None:
459 ipython_dir = get_ipython_dir()
459 ipython_dir = get_ipython_dir()
460 if profile_dir is not None:
460 if profile_dir is not None:
461 try:
461 try:
462 self._cd = ProfileDir.find_profile_dir(profile_dir)
462 self._cd = ProfileDir.find_profile_dir(profile_dir)
463 return
463 return
464 except ProfileDirError:
464 except ProfileDirError:
465 pass
465 pass
466 elif profile is not None:
466 elif profile is not None:
467 try:
467 try:
468 self._cd = ProfileDir.find_profile_dir_by_name(
468 self._cd = ProfileDir.find_profile_dir_by_name(
469 ipython_dir, profile)
469 ipython_dir, profile)
470 return
470 return
471 except ProfileDirError:
471 except ProfileDirError:
472 pass
472 pass
473 self._cd = None
473 self._cd = None
474
474
475 def _update_engines(self, engines):
475 def _update_engines(self, engines):
476 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
476 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
477 for k,v in engines.iteritems():
477 for k,v in engines.iteritems():
478 eid = int(k)
478 eid = int(k)
479 self._engines[eid] = v
479 self._engines[eid] = v
480 self._ids.append(eid)
480 self._ids.append(eid)
481 self._ids = sorted(self._ids)
481 self._ids = sorted(self._ids)
482 if sorted(self._engines.keys()) != range(len(self._engines)) and \
482 if sorted(self._engines.keys()) != range(len(self._engines)) and \
483 self._task_scheme == 'pure' and self._task_socket:
483 self._task_scheme == 'pure' and self._task_socket:
484 self._stop_scheduling_tasks()
484 self._stop_scheduling_tasks()
485
485
486 def _stop_scheduling_tasks(self):
486 def _stop_scheduling_tasks(self):
487 """Stop scheduling tasks because an engine has been unregistered
487 """Stop scheduling tasks because an engine has been unregistered
488 from a pure ZMQ scheduler.
488 from a pure ZMQ scheduler.
489 """
489 """
490 self._task_socket.close()
490 self._task_socket.close()
491 self._task_socket = None
491 self._task_socket = None
492 msg = "An engine has been unregistered, and we are using pure " +\
492 msg = "An engine has been unregistered, and we are using pure " +\
493 "ZMQ task scheduling. Task farming will be disabled."
493 "ZMQ task scheduling. Task farming will be disabled."
494 if self.outstanding:
494 if self.outstanding:
495 msg += " If you were running tasks when this happened, " +\
495 msg += " If you were running tasks when this happened, " +\
496 "some `outstanding` msg_ids may never resolve."
496 "some `outstanding` msg_ids may never resolve."
497 warnings.warn(msg, RuntimeWarning)
497 warnings.warn(msg, RuntimeWarning)
498
498
499 def _build_targets(self, targets):
499 def _build_targets(self, targets):
500 """Turn valid target IDs or 'all' into two lists:
500 """Turn valid target IDs or 'all' into two lists:
501 (int_ids, uuids).
501 (int_ids, uuids).
502 """
502 """
503 if not self._ids:
503 if not self._ids:
504 # flush notification socket if no engines yet, just in case
504 # flush notification socket if no engines yet, just in case
505 if not self.ids:
505 if not self.ids:
506 raise error.NoEnginesRegistered("Can't build targets without any engines")
506 raise error.NoEnginesRegistered("Can't build targets without any engines")
507
507
508 if targets is None:
508 if targets is None:
509 targets = self._ids
509 targets = self._ids
510 elif isinstance(targets, basestring):
510 elif isinstance(targets, basestring):
511 if targets.lower() == 'all':
511 if targets.lower() == 'all':
512 targets = self._ids
512 targets = self._ids
513 else:
513 else:
514 raise TypeError("%r not valid str target, must be 'all'"%(targets))
514 raise TypeError("%r not valid str target, must be 'all'"%(targets))
515 elif isinstance(targets, int):
515 elif isinstance(targets, int):
516 if targets < 0:
516 if targets < 0:
517 targets = self.ids[targets]
517 targets = self.ids[targets]
518 if targets not in self._ids:
518 if targets not in self._ids:
519 raise IndexError("No such engine: %i"%targets)
519 raise IndexError("No such engine: %i"%targets)
520 targets = [targets]
520 targets = [targets]
521
521
522 if isinstance(targets, slice):
522 if isinstance(targets, slice):
523 indices = range(len(self._ids))[targets]
523 indices = range(len(self._ids))[targets]
524 ids = self.ids
524 ids = self.ids
525 targets = [ ids[i] for i in indices ]
525 targets = [ ids[i] for i in indices ]
526
526
527 if not isinstance(targets, (tuple, list, xrange)):
527 if not isinstance(targets, (tuple, list, xrange)):
528 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
528 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
529
529
530 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
530 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
531
531
532 def _connect(self, sshserver, ssh_kwargs, timeout):
532 def _connect(self, sshserver, ssh_kwargs, timeout):
533 """setup all our socket connections to the cluster. This is called from
533 """setup all our socket connections to the cluster. This is called from
534 __init__."""
534 __init__."""
535
535
536 # Maybe allow reconnecting?
536 # Maybe allow reconnecting?
537 if self._connected:
537 if self._connected:
538 return
538 return
539 self._connected=True
539 self._connected=True
540
540
541 def connect_socket(s, url):
541 def connect_socket(s, url):
542 url = util.disambiguate_url(url, self._config['location'])
542 url = util.disambiguate_url(url, self._config['location'])
543 if self._ssh:
543 if self._ssh:
544 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
544 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
545 else:
545 else:
546 return s.connect(url)
546 return s.connect(url)
547
547
548 self.session.send(self._query_socket, 'connection_request')
548 self.session.send(self._query_socket, 'connection_request')
549 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
549 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
550 poller = zmq.Poller()
550 poller = zmq.Poller()
551 poller.register(self._query_socket, zmq.POLLIN)
551 poller.register(self._query_socket, zmq.POLLIN)
552 # poll expects milliseconds, timeout is seconds
552 # poll expects milliseconds, timeout is seconds
553 evts = poller.poll(timeout*1000)
553 evts = poller.poll(timeout*1000)
554 if not evts:
554 if not evts:
555 raise error.TimeoutError("Hub connection request timed out")
555 raise error.TimeoutError("Hub connection request timed out")
556 idents,msg = self.session.recv(self._query_socket,mode=0)
556 idents,msg = self.session.recv(self._query_socket,mode=0)
557 if self.debug:
557 if self.debug:
558 pprint(msg)
558 pprint(msg)
559 msg = Message(msg)
559 msg = Message(msg)
560 content = msg.content
560 content = msg.content
561 self._config['registration'] = dict(content)
561 self._config['registration'] = dict(content)
562 if content.status == 'ok':
562 if content.status == 'ok':
563 ident = self.session.bsession
563 ident = self.session.bsession
564 if content.mux:
564 if content.mux:
565 self._mux_socket = self._context.socket(zmq.DEALER)
565 self._mux_socket = self._context.socket(zmq.DEALER)
566 self._mux_socket.setsockopt(zmq.IDENTITY, ident)
566 self._mux_socket.setsockopt(zmq.IDENTITY, ident)
567 connect_socket(self._mux_socket, content.mux)
567 connect_socket(self._mux_socket, content.mux)
568 if content.task:
568 if content.task:
569 self._task_scheme, task_addr = content.task
569 self._task_scheme, task_addr = content.task
570 self._task_socket = self._context.socket(zmq.DEALER)
570 self._task_socket = self._context.socket(zmq.DEALER)
571 self._task_socket.setsockopt(zmq.IDENTITY, ident)
571 self._task_socket.setsockopt(zmq.IDENTITY, ident)
572 connect_socket(self._task_socket, task_addr)
572 connect_socket(self._task_socket, task_addr)
573 if content.notification:
573 if content.notification:
574 self._notification_socket = self._context.socket(zmq.SUB)
574 self._notification_socket = self._context.socket(zmq.SUB)
575 connect_socket(self._notification_socket, content.notification)
575 connect_socket(self._notification_socket, content.notification)
576 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
576 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
577 # if content.query:
577 # if content.query:
578 # self._query_socket = self._context.socket(zmq.DEALER)
578 # self._query_socket = self._context.socket(zmq.DEALER)
579 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
579 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
580 # connect_socket(self._query_socket, content.query)
580 # connect_socket(self._query_socket, content.query)
581 if content.control:
581 if content.control:
582 self._control_socket = self._context.socket(zmq.DEALER)
582 self._control_socket = self._context.socket(zmq.DEALER)
583 self._control_socket.setsockopt(zmq.IDENTITY, ident)
583 self._control_socket.setsockopt(zmq.IDENTITY, ident)
584 connect_socket(self._control_socket, content.control)
584 connect_socket(self._control_socket, content.control)
585 if content.iopub:
585 if content.iopub:
586 self._iopub_socket = self._context.socket(zmq.SUB)
586 self._iopub_socket = self._context.socket(zmq.SUB)
587 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
587 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
588 self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
588 self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
589 connect_socket(self._iopub_socket, content.iopub)
589 connect_socket(self._iopub_socket, content.iopub)
590 self._update_engines(dict(content.engines))
590 self._update_engines(dict(content.engines))
591 else:
591 else:
592 self._connected = False
592 self._connected = False
593 raise Exception("Failed to connect!")
593 raise Exception("Failed to connect!")
594
594
595 #--------------------------------------------------------------------------
595 #--------------------------------------------------------------------------
596 # handlers and callbacks for incoming messages
596 # handlers and callbacks for incoming messages
597 #--------------------------------------------------------------------------
597 #--------------------------------------------------------------------------
598
598
599 def _unwrap_exception(self, content):
599 def _unwrap_exception(self, content):
600 """unwrap exception, and remap engine_id to int."""
600 """unwrap exception, and remap engine_id to int."""
601 e = error.unwrap_exception(content)
601 e = error.unwrap_exception(content)
602 # print e.traceback
602 # print e.traceback
603 if e.engine_info:
603 if e.engine_info:
604 e_uuid = e.engine_info['engine_uuid']
604 e_uuid = e.engine_info['engine_uuid']
605 eid = self._engines[e_uuid]
605 eid = self._engines[e_uuid]
606 e.engine_info['engine_id'] = eid
606 e.engine_info['engine_id'] = eid
607 return e
607 return e
608
608
609 def _extract_metadata(self, header, parent, content):
609 def _extract_metadata(self, header, parent, content):
610 md = {'msg_id' : parent['msg_id'],
610 md = {'msg_id' : parent['msg_id'],
611 'received' : datetime.now(),
611 'received' : datetime.now(),
612 'engine_uuid' : header.get('engine', None),
612 'engine_uuid' : header.get('engine', None),
613 'follow' : parent.get('follow', []),
613 'follow' : parent.get('follow', []),
614 'after' : parent.get('after', []),
614 'after' : parent.get('after', []),
615 'status' : content['status'],
615 'status' : content['status'],
616 }
616 }
617
617
618 if md['engine_uuid'] is not None:
618 if md['engine_uuid'] is not None:
619 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
619 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
620
620
621 if 'date' in parent:
621 if 'date' in parent:
622 md['submitted'] = parent['date']
622 md['submitted'] = parent['date']
623 if 'started' in header:
623 if 'started' in header:
624 md['started'] = header['started']
624 md['started'] = header['started']
625 if 'date' in header:
625 if 'date' in header:
626 md['completed'] = header['date']
626 md['completed'] = header['date']
627 return md
627 return md
628
628
629 def _register_engine(self, msg):
629 def _register_engine(self, msg):
630 """Register a new engine, and update our connection info."""
630 """Register a new engine, and update our connection info."""
631 content = msg['content']
631 content = msg['content']
632 eid = content['id']
632 eid = content['id']
633 d = {eid : content['queue']}
633 d = {eid : content['queue']}
634 self._update_engines(d)
634 self._update_engines(d)
635
635
636 def _unregister_engine(self, msg):
636 def _unregister_engine(self, msg):
637 """Unregister an engine that has died."""
637 """Unregister an engine that has died."""
638 content = msg['content']
638 content = msg['content']
639 eid = int(content['id'])
639 eid = int(content['id'])
640 if eid in self._ids:
640 if eid in self._ids:
641 self._ids.remove(eid)
641 self._ids.remove(eid)
642 uuid = self._engines.pop(eid)
642 uuid = self._engines.pop(eid)
643
643
644 self._handle_stranded_msgs(eid, uuid)
644 self._handle_stranded_msgs(eid, uuid)
645
645
646 if self._task_socket and self._task_scheme == 'pure':
646 if self._task_socket and self._task_scheme == 'pure':
647 self._stop_scheduling_tasks()
647 self._stop_scheduling_tasks()
648
648
649 def _handle_stranded_msgs(self, eid, uuid):
649 def _handle_stranded_msgs(self, eid, uuid):
650 """Handle messages known to be on an engine when the engine unregisters.
650 """Handle messages known to be on an engine when the engine unregisters.
651
651
652 It is possible that this will fire prematurely - that is, an engine will
652 It is possible that this will fire prematurely - that is, an engine will
653 go down after completing a result, and the client will be notified
653 go down after completing a result, and the client will be notified
654 of the unregistration and later receive the successful result.
654 of the unregistration and later receive the successful result.
655 """
655 """
656
656
657 outstanding = self._outstanding_dict[uuid]
657 outstanding = self._outstanding_dict[uuid]
658
658
659 for msg_id in list(outstanding):
659 for msg_id in list(outstanding):
660 if msg_id in self.results:
660 if msg_id in self.results:
661 # we already have this result
661 # we already have this result
662 continue
662 continue
663 try:
663 try:
664 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
664 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
665 except:
665 except:
666 content = error.wrap_exception()
666 content = error.wrap_exception()
667 # build a fake message:
667 # build a fake message:
668 parent = {}
668 parent = {}
669 header = {}
669 header = {}
670 parent['msg_id'] = msg_id
670 parent['msg_id'] = msg_id
671 header['engine'] = uuid
671 header['engine'] = uuid
672 header['date'] = datetime.now()
672 header['date'] = datetime.now()
673 msg = dict(parent_header=parent, header=header, content=content)
673 msg = dict(parent_header=parent, header=header, content=content)
674 self._handle_apply_reply(msg)
674 self._handle_apply_reply(msg)
675
675
676 def _handle_execute_reply(self, msg):
676 def _handle_execute_reply(self, msg):
677 """Save the reply to an execute_request into our results.
677 """Save the reply to an execute_request into our results.
678
678
679 execute messages are never actually used. apply is used instead.
679 execute messages are never actually used. apply is used instead.
680 """
680 """
681
681
682 parent = msg['parent_header']
682 parent = msg['parent_header']
683 msg_id = parent['msg_id']
683 msg_id = parent['msg_id']
684 if msg_id not in self.outstanding:
684 if msg_id not in self.outstanding:
685 if msg_id in self.history:
685 if msg_id in self.history:
686 print ("got stale result: %s"%msg_id)
686 print ("got stale result: %s"%msg_id)
687 else:
687 else:
688 print ("got unknown result: %s"%msg_id)
688 print ("got unknown result: %s"%msg_id)
689 else:
689 else:
690 self.outstanding.remove(msg_id)
690 self.outstanding.remove(msg_id)
691
691
692 content = msg['content']
692 content = msg['content']
693 header = msg['header']
693 header = msg['header']
694
694
695 # construct metadata:
695 # construct metadata:
696 md = self.metadata[msg_id]
696 md = self.metadata[msg_id]
697 md.update(self._extract_metadata(header, parent, content))
697 md.update(self._extract_metadata(header, parent, content))
698 # is this redundant?
698 # is this redundant?
699 self.metadata[msg_id] = md
699 self.metadata[msg_id] = md
700
700
701 e_outstanding = self._outstanding_dict[md['engine_uuid']]
701 e_outstanding = self._outstanding_dict[md['engine_uuid']]
702 if msg_id in e_outstanding:
702 if msg_id in e_outstanding:
703 e_outstanding.remove(msg_id)
703 e_outstanding.remove(msg_id)
704
704
705 # construct result:
705 # construct result:
706 if content['status'] == 'ok':
706 if content['status'] == 'ok':
707 self.results[msg_id] = ExecuteReply(msg_id, content, md)
707 self.results[msg_id] = ExecuteReply(msg_id, content, md)
708 elif content['status'] == 'aborted':
708 elif content['status'] == 'aborted':
709 self.results[msg_id] = error.TaskAborted(msg_id)
709 self.results[msg_id] = error.TaskAborted(msg_id)
710 elif content['status'] == 'resubmitted':
710 elif content['status'] == 'resubmitted':
711 # TODO: handle resubmission
711 # TODO: handle resubmission
712 pass
712 pass
713 else:
713 else:
714 self.results[msg_id] = self._unwrap_exception(content)
714 self.results[msg_id] = self._unwrap_exception(content)
715
715
716 def _handle_apply_reply(self, msg):
716 def _handle_apply_reply(self, msg):
717 """Save the reply to an apply_request into our results."""
717 """Save the reply to an apply_request into our results."""
718 parent = msg['parent_header']
718 parent = msg['parent_header']
719 msg_id = parent['msg_id']
719 msg_id = parent['msg_id']
720 if msg_id not in self.outstanding:
720 if msg_id not in self.outstanding:
721 if msg_id in self.history:
721 if msg_id in self.history:
722 print ("got stale result: %s"%msg_id)
722 print ("got stale result: %s"%msg_id)
723 print self.results[msg_id]
723 print self.results[msg_id]
724 print msg
724 print msg
725 else:
725 else:
726 print ("got unknown result: %s"%msg_id)
726 print ("got unknown result: %s"%msg_id)
727 else:
727 else:
728 self.outstanding.remove(msg_id)
728 self.outstanding.remove(msg_id)
729 content = msg['content']
729 content = msg['content']
730 header = msg['header']
730 header = msg['header']
731
731
732 # construct metadata:
732 # construct metadata:
733 md = self.metadata[msg_id]
733 md = self.metadata[msg_id]
734 md.update(self._extract_metadata(header, parent, content))
734 md.update(self._extract_metadata(header, parent, content))
735 # is this redundant?
735 # is this redundant?
736 self.metadata[msg_id] = md
736 self.metadata[msg_id] = md
737
737
738 e_outstanding = self._outstanding_dict[md['engine_uuid']]
738 e_outstanding = self._outstanding_dict[md['engine_uuid']]
739 if msg_id in e_outstanding:
739 if msg_id in e_outstanding:
740 e_outstanding.remove(msg_id)
740 e_outstanding.remove(msg_id)
741
741
742 # construct result:
742 # construct result:
743 if content['status'] == 'ok':
743 if content['status'] == 'ok':
744 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
744 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
745 elif content['status'] == 'aborted':
745 elif content['status'] == 'aborted':
746 self.results[msg_id] = error.TaskAborted(msg_id)
746 self.results[msg_id] = error.TaskAborted(msg_id)
747 elif content['status'] == 'resubmitted':
747 elif content['status'] == 'resubmitted':
748 # TODO: handle resubmission
748 # TODO: handle resubmission
749 pass
749 pass
750 else:
750 else:
751 self.results[msg_id] = self._unwrap_exception(content)
751 self.results[msg_id] = self._unwrap_exception(content)
752
752
753 def _flush_notifications(self):
753 def _flush_notifications(self):
754 """Flush notifications of engine registrations waiting
754 """Flush notifications of engine registrations waiting
755 in ZMQ queue."""
755 in ZMQ queue."""
756 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
756 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
757 while msg is not None:
757 while msg is not None:
758 if self.debug:
758 if self.debug:
759 pprint(msg)
759 pprint(msg)
760 msg_type = msg['header']['msg_type']
760 msg_type = msg['header']['msg_type']
761 handler = self._notification_handlers.get(msg_type, None)
761 handler = self._notification_handlers.get(msg_type, None)
762 if handler is None:
762 if handler is None:
763 raise Exception("Unhandled message type: %s"%msg.msg_type)
763 raise Exception("Unhandled message type: %s"%msg.msg_type)
764 else:
764 else:
765 handler(msg)
765 handler(msg)
766 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
766 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
767
767
768 def _flush_results(self, sock):
768 def _flush_results(self, sock):
769 """Flush task or queue results waiting in ZMQ queue."""
769 """Flush task or queue results waiting in ZMQ queue."""
770 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
770 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
771 while msg is not None:
771 while msg is not None:
772 if self.debug:
772 if self.debug:
773 pprint(msg)
773 pprint(msg)
774 msg_type = msg['header']['msg_type']
774 msg_type = msg['header']['msg_type']
775 handler = self._queue_handlers.get(msg_type, None)
775 handler = self._queue_handlers.get(msg_type, None)
776 if handler is None:
776 if handler is None:
777 raise Exception("Unhandled message type: %s"%msg.msg_type)
777 raise Exception("Unhandled message type: %s"%msg.msg_type)
778 else:
778 else:
779 handler(msg)
779 handler(msg)
780 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
780 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
781
781
782 def _flush_control(self, sock):
782 def _flush_control(self, sock):
783 """Flush replies from the control channel waiting
783 """Flush replies from the control channel waiting
784 in the ZMQ queue.
784 in the ZMQ queue.
785
785
786 Currently: ignore them."""
786 Currently: ignore them."""
787 if self._ignored_control_replies <= 0:
787 if self._ignored_control_replies <= 0:
788 return
788 return
789 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
789 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
790 while msg is not None:
790 while msg is not None:
791 self._ignored_control_replies -= 1
791 self._ignored_control_replies -= 1
792 if self.debug:
792 if self.debug:
793 pprint(msg)
793 pprint(msg)
794 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
794 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
795
795
796 def _flush_ignored_control(self):
796 def _flush_ignored_control(self):
797 """flush ignored control replies"""
797 """flush ignored control replies"""
798 while self._ignored_control_replies > 0:
798 while self._ignored_control_replies > 0:
799 self.session.recv(self._control_socket)
799 self.session.recv(self._control_socket)
800 self._ignored_control_replies -= 1
800 self._ignored_control_replies -= 1
801
801
802 def _flush_ignored_hub_replies(self):
802 def _flush_ignored_hub_replies(self):
803 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
803 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
804 while msg is not None:
804 while msg is not None:
805 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
805 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
806
806
807 def _flush_iopub(self, sock):
807 def _flush_iopub(self, sock):
808 """Flush replies from the iopub channel waiting
808 """Flush replies from the iopub channel waiting
809 in the ZMQ queue.
809 in the ZMQ queue.
810 """
810 """
811 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
811 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
812 while msg is not None:
812 while msg is not None:
813 if self.debug:
813 if self.debug:
814 pprint(msg)
814 pprint(msg)
815 parent = msg['parent_header']
815 parent = msg['parent_header']
816 # ignore IOPub messages with no parent.
816 # ignore IOPub messages with no parent.
817 # Caused by print statements or warnings from before the first execution.
817 # Caused by print statements or warnings from before the first execution.
818 if not parent:
818 if not parent:
819 continue
819 continue
820 msg_id = parent['msg_id']
820 msg_id = parent['msg_id']
821 content = msg['content']
821 content = msg['content']
822 header = msg['header']
822 header = msg['header']
823 msg_type = msg['header']['msg_type']
823 msg_type = msg['header']['msg_type']
824
824
825 # init metadata:
825 # init metadata:
826 md = self.metadata[msg_id]
826 md = self.metadata[msg_id]
827
827
828 if msg_type == 'stream':
828 if msg_type == 'stream':
829 name = content['name']
829 name = content['name']
830 s = md[name] or ''
830 s = md[name] or ''
831 md[name] = s + content['data']
831 md[name] = s + content['data']
832 elif msg_type == 'pyerr':
832 elif msg_type == 'pyerr':
833 md.update({'pyerr' : self._unwrap_exception(content)})
833 md.update({'pyerr' : self._unwrap_exception(content)})
834 elif msg_type == 'pyin':
834 elif msg_type == 'pyin':
835 md.update({'pyin' : content['code']})
835 md.update({'pyin' : content['code']})
836 elif msg_type == 'display_data':
836 elif msg_type == 'display_data':
837 md['outputs'].append(content.get('data'))
837 md['outputs'].append(content)
838 elif msg_type == 'pyout':
838 elif msg_type == 'pyout':
839 md['pyout'] = content.get('data')
839 md['pyout'] = content
840 else:
840 else:
841 # unhandled msg_type (status, etc.)
841 # unhandled msg_type (status, etc.)
842 pass
842 pass
843
843
844 # redundant?
844 # redundant?
845 self.metadata[msg_id] = md
845 self.metadata[msg_id] = md
846
846
847 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
847 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
848
848
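# Sketch: reading metadata['outputs'] after this change. Each entry is now the
# full display_data content dict, so a consumer finds the mime bundle one level
# down under its 'data' key (previously the bundle itself was stored). This is a
# minimal illustration, assuming a connected Client and at least one completed
# request that produced display output.
from IPython.parallel import Client

rc = Client()
msg_id = rc.history[-1]                    # most recent request from this client
for output in rc.metadata[msg_id]['outputs']:
    bundle = output.get('data', {})        # mime-type -> representation
    if 'text/plain' in bundle:
        print bundle['text/plain']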
849 #--------------------------------------------------------------------------
849 #--------------------------------------------------------------------------
850 # len, getitem
850 # len, getitem
851 #--------------------------------------------------------------------------
851 #--------------------------------------------------------------------------
852
852
853 def __len__(self):
853 def __len__(self):
854 """len(client) returns # of engines."""
854 """len(client) returns # of engines."""
855 return len(self.ids)
855 return len(self.ids)
856
856
857 def __getitem__(self, key):
857 def __getitem__(self, key):
858 """index access returns DirectView multiplexer objects
858 """index access returns DirectView multiplexer objects
859
859
860 Must be int, slice, or list/tuple/xrange of ints"""
860 Must be int, slice, or list/tuple/xrange of ints"""
861 if not isinstance(key, (int, slice, tuple, list, xrange)):
861 if not isinstance(key, (int, slice, tuple, list, xrange)):
862 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
862 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
863 else:
863 else:
864 return self.direct_view(key)
864 return self.direct_view(key)
865
865
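# Sketch: the index access described above; every form returns a DirectView
# (assumes a connected Client with a few registered engines).
from IPython.parallel import Client

rc = Client()
e0 = rc[0]          # a single engine
pair = rc[:2]       # a slice of engines
odds = rc[1::2]     # any int, slice, or list of ints works
print e0.targets, pair.targets, odds.targets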
866 #--------------------------------------------------------------------------
866 #--------------------------------------------------------------------------
867 # Begin public methods
867 # Begin public methods
868 #--------------------------------------------------------------------------
868 #--------------------------------------------------------------------------
869
869
870 @property
870 @property
871 def ids(self):
871 def ids(self):
872 """Always up-to-date ids property."""
872 """Always up-to-date ids property."""
873 self._flush_notifications()
873 self._flush_notifications()
874 # always copy:
874 # always copy:
875 return list(self._ids)
875 return list(self._ids)
876
876
877 def close(self):
877 def close(self):
878 if self._closed:
878 if self._closed:
879 return
879 return
880 self.stop_spin_thread()
880 self.stop_spin_thread()
881 snames = filter(lambda n: n.endswith('socket'), dir(self))
881 snames = filter(lambda n: n.endswith('socket'), dir(self))
882 for socket in map(lambda name: getattr(self, name), snames):
882 for socket in map(lambda name: getattr(self, name), snames):
883 if isinstance(socket, zmq.Socket) and not socket.closed:
883 if isinstance(socket, zmq.Socket) and not socket.closed:
884 socket.close()
884 socket.close()
885 self._closed = True
885 self._closed = True
886
886
887 def _spin_every(self, interval=1):
887 def _spin_every(self, interval=1):
888 """target func for use in spin_thread"""
888 """target func for use in spin_thread"""
889 while True:
889 while True:
890 if self._stop_spinning.is_set():
890 if self._stop_spinning.is_set():
891 return
891 return
892 time.sleep(interval)
892 time.sleep(interval)
893 self.spin()
893 self.spin()
894
894
895 def spin_thread(self, interval=1):
895 def spin_thread(self, interval=1):
896 """call Client.spin() in a background thread on some regular interval
896 """call Client.spin() in a background thread on some regular interval
897
897
898 This helps ensure that messages don't pile up too much in the zmq queue
898 This helps ensure that messages don't pile up too much in the zmq queue
899 while you are working on other things, or have left the terminal idle.
899 while you are working on other things, or have left the terminal idle.
900
900
901 It also helps limit the potential lag in the `received` timestamp
901 It also helps limit the potential lag in the `received` timestamp
902 on AsyncResult objects, used for timings.
902 on AsyncResult objects, used for timings.
903
903
904 Parameters
904 Parameters
905 ----------
905 ----------
906
906
907 interval : float, optional
907 interval : float, optional
908 The interval on which to spin the client in the background thread
908 The interval on which to spin the client in the background thread
909 (simply passed to time.sleep).
909 (simply passed to time.sleep).
910
910
911 Notes
911 Notes
912 -----
912 -----
913
913
914 For precision timing, you may want to use this method to put a bound
914 For precision timing, you may want to use this method to put a bound
915 on the jitter (in seconds) in `received` timestamps used
915 on the jitter (in seconds) in `received` timestamps used
916 in AsyncResult.wall_time.
916 in AsyncResult.wall_time.
917
917
918 """
918 """
919 if self._spin_thread is not None:
919 if self._spin_thread is not None:
920 self.stop_spin_thread()
920 self.stop_spin_thread()
921 self._stop_spinning.clear()
921 self._stop_spinning.clear()
922 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
922 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
923 self._spin_thread.daemon = True
923 self._spin_thread.daemon = True
924 self._spin_thread.start()
924 self._spin_thread.start()
925
925
926 def stop_spin_thread(self):
926 def stop_spin_thread(self):
927 """stop background spin_thread, if any"""
927 """stop background spin_thread, if any"""
928 if self._spin_thread is not None:
928 if self._spin_thread is not None:
929 self._stop_spinning.set()
929 self._stop_spinning.set()
930 self._spin_thread.join()
930 self._spin_thread.join()
931 self._spin_thread = None
931 self._spin_thread = None
932
932
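# Sketch: bounding `received`-timestamp jitter with a background spin thread,
# as described in the docstring above (the 0.1 s interval is illustrative;
# assumes a connected Client).
import time
from IPython.parallel import Client

rc = Client()
rc.spin_thread(interval=0.1)        # poll incoming messages every 100 ms
try:
    ar = rc[:].apply_async(time.sleep, 1)
    ar.get()
    print ar.wall_time              # timing jitter bounded by the spin interval
finally:
    rc.stop_spin_thread()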
933 def spin(self):
933 def spin(self):
934 """Flush any registration notifications and execution results
934 """Flush any registration notifications and execution results
935 waiting in the ZMQ queue.
935 waiting in the ZMQ queue.
936 """
936 """
937 if self._notification_socket:
937 if self._notification_socket:
938 self._flush_notifications()
938 self._flush_notifications()
939 if self._iopub_socket:
939 if self._iopub_socket:
940 self._flush_iopub(self._iopub_socket)
940 self._flush_iopub(self._iopub_socket)
941 if self._mux_socket:
941 if self._mux_socket:
942 self._flush_results(self._mux_socket)
942 self._flush_results(self._mux_socket)
943 if self._task_socket:
943 if self._task_socket:
944 self._flush_results(self._task_socket)
944 self._flush_results(self._task_socket)
945 if self._control_socket:
945 if self._control_socket:
946 self._flush_control(self._control_socket)
946 self._flush_control(self._control_socket)
947 if self._query_socket:
947 if self._query_socket:
948 self._flush_ignored_hub_replies()
948 self._flush_ignored_hub_replies()
949
949
950 def wait(self, jobs=None, timeout=-1):
950 def wait(self, jobs=None, timeout=-1):
951 """waits on one or more `jobs`, for up to `timeout` seconds.
951 """waits on one or more `jobs`, for up to `timeout` seconds.
952
952
953 Parameters
953 Parameters
954 ----------
954 ----------
955
955
956 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
956 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
957 ints are indices to self.history
957 ints are indices to self.history
958 strs are msg_ids
958 strs are msg_ids
959 default: wait on all outstanding messages
959 default: wait on all outstanding messages
960 timeout : float
960 timeout : float
961 a time in seconds, after which to give up.
961 a time in seconds, after which to give up.
962 default is -1, which means no timeout
962 default is -1, which means no timeout
963
963
964 Returns
964 Returns
965 -------
965 -------
966
966
967 True : when all msg_ids are done
967 True : when all msg_ids are done
968 False : timeout reached, some msg_ids still outstanding
968 False : timeout reached, some msg_ids still outstanding
969 """
969 """
970 tic = time.time()
970 tic = time.time()
971 if jobs is None:
971 if jobs is None:
972 theids = self.outstanding
972 theids = self.outstanding
973 else:
973 else:
974 if isinstance(jobs, (int, basestring, AsyncResult)):
974 if isinstance(jobs, (int, basestring, AsyncResult)):
975 jobs = [jobs]
975 jobs = [jobs]
976 theids = set()
976 theids = set()
977 for job in jobs:
977 for job in jobs:
978 if isinstance(job, int):
978 if isinstance(job, int):
979 # index access
979 # index access
980 job = self.history[job]
980 job = self.history[job]
981 elif isinstance(job, AsyncResult):
981 elif isinstance(job, AsyncResult):
982 map(theids.add, job.msg_ids)
982 map(theids.add, job.msg_ids)
983 continue
983 continue
984 theids.add(job)
984 theids.add(job)
985 if not theids.intersection(self.outstanding):
985 if not theids.intersection(self.outstanding):
986 return True
986 return True
987 self.spin()
987 self.spin()
988 while theids.intersection(self.outstanding):
988 while theids.intersection(self.outstanding):
989 if timeout >= 0 and ( time.time()-tic ) > timeout:
989 if timeout >= 0 and ( time.time()-tic ) > timeout:
990 break
990 break
991 time.sleep(1e-3)
991 time.sleep(1e-3)
992 self.spin()
992 self.spin()
993 return len(theids.intersection(self.outstanding)) == 0
993 return len(theids.intersection(self.outstanding)) == 0
994
994
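# Sketch: waiting on a specific subset of jobs with a timeout (assumes a
# connected Client; os.getpid is just a cheap remote call).
import os
from IPython.parallel import Client

rc = Client()
ar_one = rc[0].apply_async(os.getpid)
ar_all = rc[:].apply_async(os.getpid)
done = rc.wait([ar_one, ar_all], timeout=5)   # AsyncResults, msg_ids, or ints
print done                                    # True if finished, False on timeout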
995 #--------------------------------------------------------------------------
995 #--------------------------------------------------------------------------
996 # Control methods
996 # Control methods
997 #--------------------------------------------------------------------------
997 #--------------------------------------------------------------------------
998
998
999 @spin_first
999 @spin_first
1000 def clear(self, targets=None, block=None):
1000 def clear(self, targets=None, block=None):
1001 """Clear the namespace in target(s)."""
1001 """Clear the namespace in target(s)."""
1002 block = self.block if block is None else block
1002 block = self.block if block is None else block
1003 targets = self._build_targets(targets)[0]
1003 targets = self._build_targets(targets)[0]
1004 for t in targets:
1004 for t in targets:
1005 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1005 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1006 error = False
1006 error = False
1007 if block:
1007 if block:
1008 self._flush_ignored_control()
1008 self._flush_ignored_control()
1009 for i in range(len(targets)):
1009 for i in range(len(targets)):
1010 idents,msg = self.session.recv(self._control_socket,0)
1010 idents,msg = self.session.recv(self._control_socket,0)
1011 if self.debug:
1011 if self.debug:
1012 pprint(msg)
1012 pprint(msg)
1013 if msg['content']['status'] != 'ok':
1013 if msg['content']['status'] != 'ok':
1014 error = self._unwrap_exception(msg['content'])
1014 error = self._unwrap_exception(msg['content'])
1015 else:
1015 else:
1016 self._ignored_control_replies += len(targets)
1016 self._ignored_control_replies += len(targets)
1017 if error:
1017 if error:
1018 raise error
1018 raise error
1019
1019
1020
1020
1021 @spin_first
1021 @spin_first
1022 def abort(self, jobs=None, targets=None, block=None):
1022 def abort(self, jobs=None, targets=None, block=None):
1023 """Abort specific jobs from the execution queues of target(s).
1023 """Abort specific jobs from the execution queues of target(s).
1024
1024
1025 This is a mechanism to prevent jobs that have already been submitted
1025 This is a mechanism to prevent jobs that have already been submitted
1026 from executing.
1026 from executing.
1027
1027
1028 Parameters
1028 Parameters
1029 ----------
1029 ----------
1030
1030
1031 jobs : msg_id, list of msg_ids, or AsyncResult
1031 jobs : msg_id, list of msg_ids, or AsyncResult
1032 The jobs to be aborted
1032 The jobs to be aborted
1033
1033
1034 If unspecified/None: abort all outstanding jobs.
1034 If unspecified/None: abort all outstanding jobs.
1035
1035
1036 """
1036 """
1037 block = self.block if block is None else block
1037 block = self.block if block is None else block
1038 jobs = jobs if jobs is not None else list(self.outstanding)
1038 jobs = jobs if jobs is not None else list(self.outstanding)
1039 targets = self._build_targets(targets)[0]
1039 targets = self._build_targets(targets)[0]
1040
1040
1041 msg_ids = []
1041 msg_ids = []
1042 if isinstance(jobs, (basestring,AsyncResult)):
1042 if isinstance(jobs, (basestring,AsyncResult)):
1043 jobs = [jobs]
1043 jobs = [jobs]
1044 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1044 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1045 if bad_ids:
1045 if bad_ids:
1046 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1046 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1047 for j in jobs:
1047 for j in jobs:
1048 if isinstance(j, AsyncResult):
1048 if isinstance(j, AsyncResult):
1049 msg_ids.extend(j.msg_ids)
1049 msg_ids.extend(j.msg_ids)
1050 else:
1050 else:
1051 msg_ids.append(j)
1051 msg_ids.append(j)
1052 content = dict(msg_ids=msg_ids)
1052 content = dict(msg_ids=msg_ids)
1053 for t in targets:
1053 for t in targets:
1054 self.session.send(self._control_socket, 'abort_request',
1054 self.session.send(self._control_socket, 'abort_request',
1055 content=content, ident=t)
1055 content=content, ident=t)
1056 error = False
1056 error = False
1057 if block:
1057 if block:
1058 self._flush_ignored_control()
1058 self._flush_ignored_control()
1059 for i in range(len(targets)):
1059 for i in range(len(targets)):
1060 idents,msg = self.session.recv(self._control_socket,0)
1060 idents,msg = self.session.recv(self._control_socket,0)
1061 if self.debug:
1061 if self.debug:
1062 pprint(msg)
1062 pprint(msg)
1063 if msg['content']['status'] != 'ok':
1063 if msg['content']['status'] != 'ok':
1064 error = self._unwrap_exception(msg['content'])
1064 error = self._unwrap_exception(msg['content'])
1065 else:
1065 else:
1066 self._ignored_control_replies += len(targets)
1066 self._ignored_control_replies += len(targets)
1067 if error:
1067 if error:
1068 raise error
1068 raise error
1069
1069
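# Sketch: aborting jobs that are still waiting in an engine's queue; a job that
# has already started cannot be interrupted (assumes a connected Client with an
# engine 0; the sleep durations are illustrative).
import time
from IPython.parallel import Client

rc = Client()
e0 = rc[0]                                    # a single-engine DirectView
first = e0.apply_async(time.sleep, 10)        # starts running immediately
rest = [e0.apply_async(time.sleep, 10) for _ in range(3)]
rc.abort(jobs=rest, targets=0, block=True)    # the queued jobs never execute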
1070 @spin_first
1070 @spin_first
1071 def shutdown(self, targets=None, restart=False, hub=False, block=None):
1071 def shutdown(self, targets=None, restart=False, hub=False, block=None):
1072 """Terminates one or more engine processes, optionally including the hub."""
1072 """Terminates one or more engine processes, optionally including the hub."""
1073 block = self.block if block is None else block
1073 block = self.block if block is None else block
1074 if hub:
1074 if hub:
1075 targets = 'all'
1075 targets = 'all'
1076 targets = self._build_targets(targets)[0]
1076 targets = self._build_targets(targets)[0]
1077 for t in targets:
1077 for t in targets:
1078 self.session.send(self._control_socket, 'shutdown_request',
1078 self.session.send(self._control_socket, 'shutdown_request',
1079 content={'restart':restart},ident=t)
1079 content={'restart':restart},ident=t)
1080 error = False
1080 error = False
1081 if block or hub:
1081 if block or hub:
1082 self._flush_ignored_control()
1082 self._flush_ignored_control()
1083 for i in range(len(targets)):
1083 for i in range(len(targets)):
1084 idents,msg = self.session.recv(self._control_socket, 0)
1084 idents,msg = self.session.recv(self._control_socket, 0)
1085 if self.debug:
1085 if self.debug:
1086 pprint(msg)
1086 pprint(msg)
1087 if msg['content']['status'] != 'ok':
1087 if msg['content']['status'] != 'ok':
1088 error = self._unwrap_exception(msg['content'])
1088 error = self._unwrap_exception(msg['content'])
1089 else:
1089 else:
1090 self._ignored_control_replies += len(targets)
1090 self._ignored_control_replies += len(targets)
1091
1091
1092 if hub:
1092 if hub:
1093 time.sleep(0.25)
1093 time.sleep(0.25)
1094 self.session.send(self._query_socket, 'shutdown_request')
1094 self.session.send(self._query_socket, 'shutdown_request')
1095 idents,msg = self.session.recv(self._query_socket, 0)
1095 idents,msg = self.session.recv(self._query_socket, 0)
1096 if self.debug:
1096 if self.debug:
1097 pprint(msg)
1097 pprint(msg)
1098 if msg['content']['status'] != 'ok':
1098 if msg['content']['status'] != 'ok':
1099 error = self._unwrap_exception(msg['content'])
1099 error = self._unwrap_exception(msg['content'])
1100
1100
1101 if error:
1101 if error:
1102 raise error
1102 raise error
1103
1103
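# Sketch: stopping engines, and optionally the whole cluster (assumes a
# connected Client; the engine ids 0 and 1 are illustrative).
from IPython.parallel import Client

rc = Client()
rc.shutdown(targets=[0, 1], block=True)   # stop two specific engines
# rc.shutdown(hub=True)                   # stop all engines and the Hub itself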
1104 #--------------------------------------------------------------------------
1104 #--------------------------------------------------------------------------
1105 # Execution related methods
1105 # Execution related methods
1106 #--------------------------------------------------------------------------
1106 #--------------------------------------------------------------------------
1107
1107
1108 def _maybe_raise(self, result):
1108 def _maybe_raise(self, result):
1109 """wrapper for maybe raising an exception if apply failed."""
1109 """wrapper for maybe raising an exception if apply failed."""
1110 if isinstance(result, error.RemoteError):
1110 if isinstance(result, error.RemoteError):
1111 raise result
1111 raise result
1112
1112
1113 return result
1113 return result
1114
1114
1115 def send_apply_request(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
1115 def send_apply_request(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
1116 ident=None):
1116 ident=None):
1117 """construct and send an apply message via a socket.
1117 """construct and send an apply message via a socket.
1118
1118
1119 This is the principal method with which all engine execution is performed by views.
1119 This is the principal method with which all engine execution is performed by views.
1120 """
1120 """
1121
1121
1122 if self._closed:
1122 if self._closed:
1123 raise RuntimeError("Client cannot be used after its sockets have been closed")
1123 raise RuntimeError("Client cannot be used after its sockets have been closed")
1124
1124
1125 # defaults:
1125 # defaults:
1126 args = args if args is not None else []
1126 args = args if args is not None else []
1127 kwargs = kwargs if kwargs is not None else {}
1127 kwargs = kwargs if kwargs is not None else {}
1128 subheader = subheader if subheader is not None else {}
1128 subheader = subheader if subheader is not None else {}
1129
1129
1130 # validate arguments
1130 # validate arguments
1131 if not callable(f) and not isinstance(f, Reference):
1131 if not callable(f) and not isinstance(f, Reference):
1132 raise TypeError("f must be callable, not %s"%type(f))
1132 raise TypeError("f must be callable, not %s"%type(f))
1133 if not isinstance(args, (tuple, list)):
1133 if not isinstance(args, (tuple, list)):
1134 raise TypeError("args must be tuple or list, not %s"%type(args))
1134 raise TypeError("args must be tuple or list, not %s"%type(args))
1135 if not isinstance(kwargs, dict):
1135 if not isinstance(kwargs, dict):
1136 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1136 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1137 if not isinstance(subheader, dict):
1137 if not isinstance(subheader, dict):
1138 raise TypeError("subheader must be dict, not %s"%type(subheader))
1138 raise TypeError("subheader must be dict, not %s"%type(subheader))
1139
1139
1140 bufs = util.pack_apply_message(f,args,kwargs)
1140 bufs = util.pack_apply_message(f,args,kwargs)
1141
1141
1142 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1142 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1143 subheader=subheader, track=track)
1143 subheader=subheader, track=track)
1144
1144
1145 msg_id = msg['header']['msg_id']
1145 msg_id = msg['header']['msg_id']
1146 self.outstanding.add(msg_id)
1146 self.outstanding.add(msg_id)
1147 if ident:
1147 if ident:
1148 # possibly routed to a specific engine
1148 # possibly routed to a specific engine
1149 if isinstance(ident, list):
1149 if isinstance(ident, list):
1150 ident = ident[-1]
1150 ident = ident[-1]
1151 if ident in self._engines.values():
1151 if ident in self._engines.values():
1152 # save for later, in case of engine death
1152 # save for later, in case of engine death
1153 self._outstanding_dict[ident].add(msg_id)
1153 self._outstanding_dict[ident].add(msg_id)
1154 self.history.append(msg_id)
1154 self.history.append(msg_id)
1155 self.metadata[msg_id]['submitted'] = datetime.now()
1155 self.metadata[msg_id]['submitted'] = datetime.now()
1156
1156
1157 return msg
1157 return msg
1158
1158
1159 def send_execute_request(self, socket, code, silent=True, subheader=None, ident=None):
1159 def send_execute_request(self, socket, code, silent=True, subheader=None, ident=None):
1160 """construct and send an execute request via a socket.
1160 """construct and send an execute request via a socket.
1161
1161
1162 """
1162 """
1163
1163
1164 if self._closed:
1164 if self._closed:
1165 raise RuntimeError("Client cannot be used after its sockets have been closed")
1165 raise RuntimeError("Client cannot be used after its sockets have been closed")
1166
1166
1167 # defaults:
1167 # defaults:
1168 subheader = subheader if subheader is not None else {}
1168 subheader = subheader if subheader is not None else {}
1169
1169
1170 # validate arguments
1170 # validate arguments
1171 if not isinstance(code, basestring):
1171 if not isinstance(code, basestring):
1172 raise TypeError("code must be text, not %s" % type(code))
1172 raise TypeError("code must be text, not %s" % type(code))
1173 if not isinstance(subheader, dict):
1173 if not isinstance(subheader, dict):
1174 raise TypeError("subheader must be dict, not %s" % type(subheader))
1174 raise TypeError("subheader must be dict, not %s" % type(subheader))
1175
1175
1176 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1176 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1177
1177
1178
1178
1179 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1179 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1180 subheader=subheader)
1180 subheader=subheader)
1181
1181
1182 msg_id = msg['header']['msg_id']
1182 msg_id = msg['header']['msg_id']
1183 self.outstanding.add(msg_id)
1183 self.outstanding.add(msg_id)
1184 if ident:
1184 if ident:
1185 # possibly routed to a specific engine
1185 # possibly routed to a specific engine
1186 if isinstance(ident, list):
1186 if isinstance(ident, list):
1187 ident = ident[-1]
1187 ident = ident[-1]
1188 if ident in self._engines.values():
1188 if ident in self._engines.values():
1189 # save for later, in case of engine death
1189 # save for later, in case of engine death
1190 self._outstanding_dict[ident].add(msg_id)
1190 self._outstanding_dict[ident].add(msg_id)
1191 self.history.append(msg_id)
1191 self.history.append(msg_id)
1192 self.metadata[msg_id]['submitted'] = datetime.now()
1192 self.metadata[msg_id]['submitted'] = datetime.now()
1193
1193
1194 return msg
1194 return msg
1195
1195
1196 #--------------------------------------------------------------------------
1196 #--------------------------------------------------------------------------
1197 # construct a View object
1197 # construct a View object
1198 #--------------------------------------------------------------------------
1198 #--------------------------------------------------------------------------
1199
1199
1200 def load_balanced_view(self, targets=None):
1200 def load_balanced_view(self, targets=None):
1201 """construct a LoadBalancedView object.
1201 """construct a LoadBalancedView object.
1202
1202
1203 If no arguments are specified, create a LoadBalancedView
1203 If no arguments are specified, create a LoadBalancedView
1204 using all engines.
1204 using all engines.
1205
1205
1206 Parameters
1206 Parameters
1207 ----------
1207 ----------
1208
1208
1209 targets: list, slice, int, etc. [default: use all engines]
1209 targets: list, slice, int, etc. [default: use all engines]
1210 The subset of engines across which to load-balance
1210 The subset of engines across which to load-balance
1211 """
1211 """
1212 if targets == 'all':
1212 if targets == 'all':
1213 targets = None
1213 targets = None
1214 if targets is not None:
1214 if targets is not None:
1215 targets = self._build_targets(targets)[1]
1215 targets = self._build_targets(targets)[1]
1216 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1216 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1217
1217
1218 def direct_view(self, targets='all'):
1218 def direct_view(self, targets='all'):
1219 """construct a DirectView object.
1219 """construct a DirectView object.
1220
1220
1221 If no targets are specified, create a DirectView using all engines.
1221 If no targets are specified, create a DirectView using all engines.
1222
1222
1223 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1223 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1224 evaluate the target engines at each execution, whereas rc[:] will connect to
1224 evaluate the target engines at each execution, whereas rc[:] will connect to
1225 all *current* engines, and that list will not change.
1225 all *current* engines, and that list will not change.
1226
1226
1227 That is, 'all' will always use all engines, whereas rc[:] will not use
1227 That is, 'all' will always use all engines, whereas rc[:] will not use
1228 engines added after the DirectView is constructed.
1228 engines added after the DirectView is constructed.
1229
1229
1230 Parameters
1230 Parameters
1231 ----------
1231 ----------
1232
1232
1233 targets: list, slice, int, etc. [default: use all engines]
1233 targets: list, slice, int, etc. [default: use all engines]
1234 The engines to use for the View
1234 The engines to use for the View
1235 """
1235 """
1236 single = isinstance(targets, int)
1236 single = isinstance(targets, int)
1237 # allow 'all' to be lazily evaluated at each execution
1237 # allow 'all' to be lazily evaluated at each execution
1238 if targets != 'all':
1238 if targets != 'all':
1239 targets = self._build_targets(targets)[1]
1239 targets = self._build_targets(targets)[1]
1240 if single:
1240 if single:
1241 targets = targets[0]
1241 targets = targets[0]
1242 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1242 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1243
1243
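# Sketch: the difference described above between a lazy 'all' view and a
# snapshot of the current engines (assumes a connected Client).
from IPython.parallel import Client

rc = Client()
v_all = rc.direct_view('all')   # targets resolved again at each execution
v_now = rc[:]                   # fixed to the engines registered right now
print v_all.targets, v_now.targets
# after more engines register, v_all will use them; v_now will not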
1244 #--------------------------------------------------------------------------
1244 #--------------------------------------------------------------------------
1245 # Query methods
1245 # Query methods
1246 #--------------------------------------------------------------------------
1246 #--------------------------------------------------------------------------
1247
1247
1248 @spin_first
1248 @spin_first
1249 def get_result(self, indices_or_msg_ids=None, block=None):
1249 def get_result(self, indices_or_msg_ids=None, block=None):
1250 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1250 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1251
1251
1252 If the client already has the results, no request to the Hub will be made.
1252 If the client already has the results, no request to the Hub will be made.
1253
1253
1254 This is a convenient way to construct AsyncResult objects, which are wrappers
1254 This is a convenient way to construct AsyncResult objects, which are wrappers
1255 that include metadata about execution, and allow for awaiting results that
1255 that include metadata about execution, and allow for awaiting results that
1256 were not submitted by this Client.
1256 were not submitted by this Client.
1257
1257
1258 It can also be a convenient way to retrieve the metadata associated with
1258 It can also be a convenient way to retrieve the metadata associated with
1259 blocking execution, since it always retrieves the metadata, not just the return value.
1259 blocking execution, since it always retrieves the metadata, not just the return value.
1260
1260
1261 Examples
1261 Examples
1262 --------
1262 --------
1263 ::
1263 ::
1264
1264
1265 In [10]: ar = client.get_result(-1)  # most recent request, by history index
1265 In [10]: ar = client.get_result(-1)  # most recent request, by history index
1266
1266
1267 Parameters
1267 Parameters
1268 ----------
1268 ----------
1269
1269
1270 indices_or_msg_ids : integer history index, str msg_id, or list of either
1270 indices_or_msg_ids : integer history index, str msg_id, or list of either
1271 The indices or msg_ids of the results to be retrieved
1271 The indices or msg_ids of the results to be retrieved
1272
1272
1273 block : bool
1273 block : bool
1274 Whether to wait for the result to be done
1274 Whether to wait for the result to be done
1275
1275
1276 Returns
1276 Returns
1277 -------
1277 -------
1278
1278
1279 AsyncResult
1279 AsyncResult
1280 A single AsyncResult object will always be returned.
1280 A single AsyncResult object will always be returned.
1281
1281
1282 AsyncHubResult
1282 AsyncHubResult
1283 A subclass of AsyncResult that retrieves results from the Hub
1283 A subclass of AsyncResult that retrieves results from the Hub
1284
1284
1285 """
1285 """
1286 block = self.block if block is None else block
1286 block = self.block if block is None else block
1287 if indices_or_msg_ids is None:
1287 if indices_or_msg_ids is None:
1288 indices_or_msg_ids = -1
1288 indices_or_msg_ids = -1
1289
1289
1290 if not isinstance(indices_or_msg_ids, (list,tuple)):
1290 if not isinstance(indices_or_msg_ids, (list,tuple)):
1291 indices_or_msg_ids = [indices_or_msg_ids]
1291 indices_or_msg_ids = [indices_or_msg_ids]
1292
1292
1293 theids = []
1293 theids = []
1294 for id in indices_or_msg_ids:
1294 for id in indices_or_msg_ids:
1295 if isinstance(id, int):
1295 if isinstance(id, int):
1296 id = self.history[id]
1296 id = self.history[id]
1297 if not isinstance(id, basestring):
1297 if not isinstance(id, basestring):
1298 raise TypeError("indices must be str or int, not %r"%id)
1298 raise TypeError("indices must be str or int, not %r"%id)
1299 theids.append(id)
1299 theids.append(id)
1300
1300
1301 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1301 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1302 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1302 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1303
1303
1304 if remote_ids:
1304 if remote_ids:
1305 ar = AsyncHubResult(self, msg_ids=theids)
1305 ar = AsyncHubResult(self, msg_ids=theids)
1306 else:
1306 else:
1307 ar = AsyncResult(self, msg_ids=theids)
1307 ar = AsyncResult(self, msg_ids=theids)
1308
1308
1309 if block:
1309 if block:
1310 ar.wait()
1310 ar.wait()
1311
1311
1312 return ar
1312 return ar
1313
1313
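# Sketch: rebuilding an AsyncResult (or AsyncHubResult) from msg_ids, e.g. to
# collect timing metadata after a blocking call (assumes a connected Client).
import os
from IPython.parallel import Client

rc = Client()
rc[:].apply_sync(os.getpid)              # a blocking call
ar = rc.get_result(-1)                   # -1: most recent entry in rc.history
print ar.get(), ar.elapsed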
1314 @spin_first
1314 @spin_first
1315 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1315 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1316 """Resubmit one or more tasks.
1316 """Resubmit one or more tasks.
1317
1317
1318 In-flight tasks may not be resubmitted.
1318 In-flight tasks may not be resubmitted.
1319
1319
1320 Parameters
1320 Parameters
1321 ----------
1321 ----------
1322
1322
1323 indices_or_msg_ids : integer history index, str msg_id, or list of either
1323 indices_or_msg_ids : integer history index, str msg_id, or list of either
1324 The indices or msg_ids of the tasks to be resubmitted
1324 The indices or msg_ids of the tasks to be resubmitted
1325
1325
1326 block : bool
1326 block : bool
1327 Whether to wait for the result to be done
1327 Whether to wait for the result to be done
1328
1328
1329 Returns
1329 Returns
1330 -------
1330 -------
1331
1331
1332 AsyncHubResult
1332 AsyncHubResult
1333 A subclass of AsyncResult that retrieves results from the Hub
1333 A subclass of AsyncResult that retrieves results from the Hub
1334
1334
1335 """
1335 """
1336 block = self.block if block is None else block
1336 block = self.block if block is None else block
1337 if indices_or_msg_ids is None:
1337 if indices_or_msg_ids is None:
1338 indices_or_msg_ids = -1
1338 indices_or_msg_ids = -1
1339
1339
1340 if not isinstance(indices_or_msg_ids, (list,tuple)):
1340 if not isinstance(indices_or_msg_ids, (list,tuple)):
1341 indices_or_msg_ids = [indices_or_msg_ids]
1341 indices_or_msg_ids = [indices_or_msg_ids]
1342
1342
1343 theids = []
1343 theids = []
1344 for id in indices_or_msg_ids:
1344 for id in indices_or_msg_ids:
1345 if isinstance(id, int):
1345 if isinstance(id, int):
1346 id = self.history[id]
1346 id = self.history[id]
1347 if not isinstance(id, basestring):
1347 if not isinstance(id, basestring):
1348 raise TypeError("indices must be str or int, not %r"%id)
1348 raise TypeError("indices must be str or int, not %r"%id)
1349 theids.append(id)
1349 theids.append(id)
1350
1350
1351 content = dict(msg_ids = theids)
1351 content = dict(msg_ids = theids)
1352
1352
1353 self.session.send(self._query_socket, 'resubmit_request', content)
1353 self.session.send(self._query_socket, 'resubmit_request', content)
1354
1354
1355 zmq.select([self._query_socket], [], [])
1355 zmq.select([self._query_socket], [], [])
1356 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1356 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1357 if self.debug:
1357 if self.debug:
1358 pprint(msg)
1358 pprint(msg)
1359 content = msg['content']
1359 content = msg['content']
1360 if content['status'] != 'ok':
1360 if content['status'] != 'ok':
1361 raise self._unwrap_exception(content)
1361 raise self._unwrap_exception(content)
1362 mapping = content['resubmitted']
1362 mapping = content['resubmitted']
1363 new_ids = [ mapping[msg_id] for msg_id in theids ]
1363 new_ids = [ mapping[msg_id] for msg_id in theids ]
1364
1364
1365 ar = AsyncHubResult(self, msg_ids=new_ids)
1365 ar = AsyncHubResult(self, msg_ids=new_ids)
1366
1366
1367 if block:
1367 if block:
1368 ar.wait()
1368 ar.wait()
1369
1369
1370 return ar
1370 return ar
1371
1371
1372 @spin_first
1372 @spin_first
1373 def result_status(self, msg_ids, status_only=True):
1373 def result_status(self, msg_ids, status_only=True):
1374 """Check on the status of the result(s) of the apply request with `msg_ids`.
1374 """Check on the status of the result(s) of the apply request with `msg_ids`.
1375
1375
1376 If status_only is False, then the actual results will be retrieved, else
1376 If status_only is False, then the actual results will be retrieved, else
1377 only the status of the results will be checked.
1377 only the status of the results will be checked.
1378
1378
1379 Parameters
1379 Parameters
1380 ----------
1380 ----------
1381
1381
1382 msg_ids : list of msg_ids
1382 msg_ids : list of msg_ids
1383 if int:
1383 if int:
1384 Passed as index to self.history for convenience.
1384 Passed as index to self.history for convenience.
1385 status_only : bool (default: True)
1385 status_only : bool (default: True)
1386 if False:
1386 if False:
1387 Retrieve the actual results of completed tasks.
1387 Retrieve the actual results of completed tasks.
1388
1388
1389 Returns
1389 Returns
1390 -------
1390 -------
1391
1391
1392 results : dict
1392 results : dict
1393 There will always be the keys 'pending' and 'completed', which will
1393 There will always be the keys 'pending' and 'completed', which will
1394 be lists of msg_ids that are incomplete or complete. If `status_only`
1394 be lists of msg_ids that are incomplete or complete. If `status_only`
1395 is False, then completed results will be keyed by their `msg_id`.
1395 is False, then completed results will be keyed by their `msg_id`.
1396 """
1396 """
1397 if not isinstance(msg_ids, (list,tuple)):
1397 if not isinstance(msg_ids, (list,tuple)):
1398 msg_ids = [msg_ids]
1398 msg_ids = [msg_ids]
1399
1399
1400 theids = []
1400 theids = []
1401 for msg_id in msg_ids:
1401 for msg_id in msg_ids:
1402 if isinstance(msg_id, int):
1402 if isinstance(msg_id, int):
1403 msg_id = self.history[msg_id]
1403 msg_id = self.history[msg_id]
1404 if not isinstance(msg_id, basestring):
1404 if not isinstance(msg_id, basestring):
1405 raise TypeError("msg_ids must be str, not %r"%msg_id)
1405 raise TypeError("msg_ids must be str, not %r"%msg_id)
1406 theids.append(msg_id)
1406 theids.append(msg_id)
1407
1407
1408 completed = []
1408 completed = []
1409 local_results = {}
1409 local_results = {}
1410
1410
1411 # comment this block out to temporarily disable local shortcut:
1411 # comment this block out to temporarily disable local shortcut:
1412 for msg_id in theids:
1412 for msg_id in theids:
1413 if msg_id in self.results:
1413 if msg_id in self.results:
1414 completed.append(msg_id)
1414 completed.append(msg_id)
1415 local_results[msg_id] = self.results[msg_id]
1415 local_results[msg_id] = self.results[msg_id]
1416 theids.remove(msg_id)
1416 theids.remove(msg_id)
1417
1417
1418 if theids: # some not locally cached
1418 if theids: # some not locally cached
1419 content = dict(msg_ids=theids, status_only=status_only)
1419 content = dict(msg_ids=theids, status_only=status_only)
1420 msg = self.session.send(self._query_socket, "result_request", content=content)
1420 msg = self.session.send(self._query_socket, "result_request", content=content)
1421 zmq.select([self._query_socket], [], [])
1421 zmq.select([self._query_socket], [], [])
1422 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1422 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1423 if self.debug:
1423 if self.debug:
1424 pprint(msg)
1424 pprint(msg)
1425 content = msg['content']
1425 content = msg['content']
1426 if content['status'] != 'ok':
1426 if content['status'] != 'ok':
1427 raise self._unwrap_exception(content)
1427 raise self._unwrap_exception(content)
1428 buffers = msg['buffers']
1428 buffers = msg['buffers']
1429 else:
1429 else:
1430 content = dict(completed=[],pending=[])
1430 content = dict(completed=[],pending=[])
1431
1431
1432 content['completed'].extend(completed)
1432 content['completed'].extend(completed)
1433
1433
1434 if status_only:
1434 if status_only:
1435 return content
1435 return content
1436
1436
1437 failures = []
1437 failures = []
1438 # load cached results into result:
1438 # load cached results into result:
1439 content.update(local_results)
1439 content.update(local_results)
1440
1440
1441 # update cache with results:
1441 # update cache with results:
1442 for msg_id in sorted(theids):
1442 for msg_id in sorted(theids):
1443 if msg_id in content['completed']:
1443 if msg_id in content['completed']:
1444 rec = content[msg_id]
1444 rec = content[msg_id]
1445 parent = rec['header']
1445 parent = rec['header']
1446 header = rec['result_header']
1446 header = rec['result_header']
1447 rcontent = rec['result_content']
1447 rcontent = rec['result_content']
1448 iodict = rec['io']
1448 iodict = rec['io']
1449 if isinstance(rcontent, str):
1449 if isinstance(rcontent, str):
1450 rcontent = self.session.unpack(rcontent)
1450 rcontent = self.session.unpack(rcontent)
1451
1451
1452 md = self.metadata[msg_id]
1452 md = self.metadata[msg_id]
1453 md.update(self._extract_metadata(header, parent, rcontent))
1453 md.update(self._extract_metadata(header, parent, rcontent))
1454 if rec.get('received'):
1454 if rec.get('received'):
1455 md['received'] = rec['received']
1455 md['received'] = rec['received']
1456 md.update(iodict)
1456 md.update(iodict)
1457
1457
1458 if rcontent['status'] == 'ok':
1458 if rcontent['status'] == 'ok':
1459 res,buffers = util.unserialize_object(buffers)
1459 res,buffers = util.unserialize_object(buffers)
1460 else:
1460 else:
1461 print rcontent
1461 print rcontent
1462 res = self._unwrap_exception(rcontent)
1462 res = self._unwrap_exception(rcontent)
1463 failures.append(res)
1463 failures.append(res)
1464
1464
1465 self.results[msg_id] = res
1465 self.results[msg_id] = res
1466 content[msg_id] = res
1466 content[msg_id] = res
1467
1467
1468 if len(theids) == 1 and failures:
1468 if len(theids) == 1 and failures:
1469 raise failures[0]
1469 raise failures[0]
1470
1470
1471 error.collect_exceptions(failures, "result_status")
1471 error.collect_exceptions(failures, "result_status")
1472 return content
1472 return content
1473
1473
1474 @spin_first
1474 @spin_first
1475 def queue_status(self, targets='all', verbose=False):
1475 def queue_status(self, targets='all', verbose=False):
1476 """Fetch the status of engine queues.
1476 """Fetch the status of engine queues.
1477
1477
1478 Parameters
1478 Parameters
1479 ----------
1479 ----------
1480
1480
1481 targets : int/str/list of ints/strs
1481 targets : int/str/list of ints/strs
1482 the engines whose states are to be queried.
1482 the engines whose states are to be queried.
1483 default : all
1483 default : all
1484 verbose : bool
1484 verbose : bool
1485 Whether to return lengths only, or lists of ids for each element
1485 Whether to return lengths only, or lists of ids for each element
1486 """
1486 """
1487 if targets == 'all':
1487 if targets == 'all':
1488 # allow 'all' to be evaluated on the engine
1488 # allow 'all' to be evaluated on the engine
1489 engine_ids = None
1489 engine_ids = None
1490 else:
1490 else:
1491 engine_ids = self._build_targets(targets)[1]
1491 engine_ids = self._build_targets(targets)[1]
1492 content = dict(targets=engine_ids, verbose=verbose)
1492 content = dict(targets=engine_ids, verbose=verbose)
1493 self.session.send(self._query_socket, "queue_request", content=content)
1493 self.session.send(self._query_socket, "queue_request", content=content)
1494 idents,msg = self.session.recv(self._query_socket, 0)
1494 idents,msg = self.session.recv(self._query_socket, 0)
1495 if self.debug:
1495 if self.debug:
1496 pprint(msg)
1496 pprint(msg)
1497 content = msg['content']
1497 content = msg['content']
1498 status = content.pop('status')
1498 status = content.pop('status')
1499 if status != 'ok':
1499 if status != 'ok':
1500 raise self._unwrap_exception(content)
1500 raise self._unwrap_exception(content)
1501 content = rekey(content)
1501 content = rekey(content)
1502 if isinstance(targets, int):
1502 if isinstance(targets, int):
1503 return content[targets]
1503 return content[targets]
1504 else:
1504 else:
1505 return content
1505 return content
1506
1506
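# Sketch: inspecting the engine queues (assumes a connected Client). The exact
# keys come from the Hub; per-engine counts such as 'queue' and 'completed' are
# typical.
from IPython.parallel import Client

rc = Client()
print rc.queue_status()                         # counts for every engine
print rc.queue_status(targets=0, verbose=True)  # msg_id lists for engine 0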
1507 @spin_first
1507 @spin_first
1508 def purge_results(self, jobs=[], targets=[]):
1508 def purge_results(self, jobs=[], targets=[]):
1509 """Tell the Hub to forget results.
1509 """Tell the Hub to forget results.
1510
1510
1511 Individual results can be purged by msg_id, or the entire
1511 Individual results can be purged by msg_id, or the entire
1512 history of specific targets can be purged.
1512 history of specific targets can be purged.
1513
1513
1514 Use `purge_results('all')` to scrub everything from the Hub's db.
1514 Use `purge_results('all')` to scrub everything from the Hub's db.
1515
1515
1516 Parameters
1516 Parameters
1517 ----------
1517 ----------
1518
1518
1519 jobs : str or list of str or AsyncResult objects
1519 jobs : str or list of str or AsyncResult objects
1520 the msg_ids whose results should be forgotten.
1520 the msg_ids whose results should be forgotten.
1521 targets : int/str/list of ints/strs
1521 targets : int/str/list of ints/strs
1522 The targets, by int_id, whose entire history is to be purged.
1522 The targets, by int_id, whose entire history is to be purged.
1523
1523
1524 default : None
1524 default : None
1525 """
1525 """
1526 if not targets and not jobs:
1526 if not targets and not jobs:
1527 raise ValueError("Must specify at least one of `targets` and `jobs`")
1527 raise ValueError("Must specify at least one of `targets` and `jobs`")
1528 if targets:
1528 if targets:
1529 targets = self._build_targets(targets)[1]
1529 targets = self._build_targets(targets)[1]
1530
1530
1531 # construct msg_ids from jobs
1531 # construct msg_ids from jobs
1532 if jobs == 'all':
1532 if jobs == 'all':
1533 msg_ids = jobs
1533 msg_ids = jobs
1534 else:
1534 else:
1535 msg_ids = []
1535 msg_ids = []
1536 if isinstance(jobs, (basestring,AsyncResult)):
1536 if isinstance(jobs, (basestring,AsyncResult)):
1537 jobs = [jobs]
1537 jobs = [jobs]
1538 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1538 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1539 if bad_ids:
1539 if bad_ids:
1540 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1540 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1541 for j in jobs:
1541 for j in jobs:
1542 if isinstance(j, AsyncResult):
1542 if isinstance(j, AsyncResult):
1543 msg_ids.extend(j.msg_ids)
1543 msg_ids.extend(j.msg_ids)
1544 else:
1544 else:
1545 msg_ids.append(j)
1545 msg_ids.append(j)
1546
1546
1547 content = dict(engine_ids=targets, msg_ids=msg_ids)
1547 content = dict(engine_ids=targets, msg_ids=msg_ids)
1548 self.session.send(self._query_socket, "purge_request", content=content)
1548 self.session.send(self._query_socket, "purge_request", content=content)
1549 idents, msg = self.session.recv(self._query_socket, 0)
1549 idents, msg = self.session.recv(self._query_socket, 0)
1550 if self.debug:
1550 if self.debug:
1551 pprint(msg)
1551 pprint(msg)
1552 content = msg['content']
1552 content = msg['content']
1553 if content['status'] != 'ok':
1553 if content['status'] != 'ok':
1554 raise self._unwrap_exception(content)
1554 raise self._unwrap_exception(content)
1555
1555
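# Sketch: trimming the Hub's result database (assumes a connected Client with
# at least one engine).
import os
from IPython.parallel import Client

rc = Client()
ar = rc[:].apply_async(os.getpid)
ar.get()
rc.purge_results(jobs=ar)          # forget just these msg_ids
rc.purge_results(targets=[0])      # or an engine's entire history
rc.purge_results('all')            # or everything in the Hub's db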
1556 @spin_first
1556 @spin_first
1557 def hub_history(self):
1557 def hub_history(self):
1558 """Get the Hub's history
1558 """Get the Hub's history
1559
1559
1560 Just like the Client, the Hub has a history, which is a list of msg_ids.
1560 Just like the Client, the Hub has a history, which is a list of msg_ids.
1561 This will contain the history of all clients, and, depending on configuration,
1561 This will contain the history of all clients, and, depending on configuration,
1562 may contain history across multiple cluster sessions.
1562 may contain history across multiple cluster sessions.
1563
1563
1564 Any msg_id returned here is a valid argument to `get_result`.
1564 Any msg_id returned here is a valid argument to `get_result`.
1565
1565
1566 Returns
1566 Returns
1567 -------
1567 -------
1568
1568
1569 msg_ids : list of strs
1569 msg_ids : list of strs
1570 list of all msg_ids, ordered by task submission time.
1570 list of all msg_ids, ordered by task submission time.
1571 """
1571 """
1572
1572
1573 self.session.send(self._query_socket, "history_request", content={})
1573 self.session.send(self._query_socket, "history_request", content={})
1574 idents, msg = self.session.recv(self._query_socket, 0)
1574 idents, msg = self.session.recv(self._query_socket, 0)
1575
1575
1576 if self.debug:
1576 if self.debug:
1577 pprint(msg)
1577 pprint(msg)
1578 content = msg['content']
1578 content = msg['content']
1579 if content['status'] != 'ok':
1579 if content['status'] != 'ok':
1580 raise self._unwrap_exception(content)
1580 raise self._unwrap_exception(content)
1581 else:
1581 else:
1582 return content['history']
1582 return content['history']
1583
1583
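# Sketch: combining hub_history with get_result to inspect work submitted by
# any client on the cluster (assumes a connected Client and some prior tasks).
from IPython.parallel import Client

rc = Client()
recent = rc.hub_history()[-5:]      # five most recently submitted msg_ids
ar = rc.get_result(recent)          # valid even for other clients' tasks
print ar.get()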
1584 @spin_first
1584 @spin_first
1585 def db_query(self, query, keys=None):
1585 def db_query(self, query, keys=None):
1586 """Query the Hub's TaskRecord database
1586 """Query the Hub's TaskRecord database
1587
1587
1588 This will return a list of task record dicts that match `query`
1588 This will return a list of task record dicts that match `query`
1589
1589
1590 Parameters
1590 Parameters
1591 ----------
1591 ----------
1592
1592
1593 query : mongodb query dict
1593 query : mongodb query dict
1594 The search dict. See mongodb query docs for details.
1594 The search dict. See mongodb query docs for details.
1595 keys : list of strs [optional]
1595 keys : list of strs [optional]
1596 The subset of keys to be returned. The default is to fetch everything but buffers.
1596 The subset of keys to be returned. The default is to fetch everything but buffers.
1597 'msg_id' will *always* be included.
1597 'msg_id' will *always* be included.
1598 """
1598 """
1599 if isinstance(keys, basestring):
1599 if isinstance(keys, basestring):
1600 keys = [keys]
1600 keys = [keys]
1601 content = dict(query=query, keys=keys)
1601 content = dict(query=query, keys=keys)
1602 self.session.send(self._query_socket, "db_request", content=content)
1602 self.session.send(self._query_socket, "db_request", content=content)
1603 idents, msg = self.session.recv(self._query_socket, 0)
1603 idents, msg = self.session.recv(self._query_socket, 0)
1604 if self.debug:
1604 if self.debug:
1605 pprint(msg)
1605 pprint(msg)
1606 content = msg['content']
1606 content = msg['content']
1607 if content['status'] != 'ok':
1607 if content['status'] != 'ok':
1608 raise self._unwrap_exception(content)
1608 raise self._unwrap_exception(content)
1609
1609
1610 records = content['records']
1610 records = content['records']
1611
1611
1612 buffer_lens = content['buffer_lens']
1612 buffer_lens = content['buffer_lens']
1613 result_buffer_lens = content['result_buffer_lens']
1613 result_buffer_lens = content['result_buffer_lens']
1614 buffers = msg['buffers']
1614 buffers = msg['buffers']
1615 has_bufs = buffer_lens is not None
1615 has_bufs = buffer_lens is not None
1616 has_rbufs = result_buffer_lens is not None
1616 has_rbufs = result_buffer_lens is not None
1617 for i,rec in enumerate(records):
1617 for i,rec in enumerate(records):
1618 # relink buffers
1618 # relink buffers
1619 if has_bufs:
1619 if has_bufs:
1620 blen = buffer_lens[i]
1620 blen = buffer_lens[i]
1621 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1621 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1622 if has_rbufs:
1622 if has_rbufs:
1623 blen = result_buffer_lens[i]
1623 blen = result_buffer_lens[i]
1624 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1624 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1625
1625
1626 return records
1626 return records
1627
1627
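# Sketch: a TaskRecord query (assumes a connected Client; the key names shown
# are examples of typical TaskRecord fields, not an exhaustive list).
from datetime import datetime, timedelta
from IPython.parallel import Client

rc = Client()
yesterday = datetime.now() - timedelta(days=1)
recs = rc.db_query({'submitted': {'$gte': yesterday}},
                   keys=['msg_id', 'completed', 'engine_uuid'])
print len(recs), 'tasks submitted in the last 24 hours'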
1628 __all__ = [ 'Client' ]
1628 __all__ = [ 'Client' ]
@@ -1,689 +1,694 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import time
20 import time
21 from tempfile import mktemp
21 from tempfile import mktemp
22 from StringIO import StringIO
22 from StringIO import StringIO
23
23
24 import zmq
24 import zmq
25 from nose import SkipTest
25 from nose import SkipTest
26
26
27 from IPython.testing import decorators as dec
27 from IPython.testing import decorators as dec
28 from IPython.testing.ipunittest import ParametricTestCase
28 from IPython.testing.ipunittest import ParametricTestCase
29
29
30 from IPython import parallel as pmod
30 from IPython import parallel as pmod
31 from IPython.parallel import error
31 from IPython.parallel import error
32 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
32 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
33 from IPython.parallel import DirectView
33 from IPython.parallel import DirectView
34 from IPython.parallel.util import interactive
34 from IPython.parallel.util import interactive
35
35
36 from IPython.parallel.tests import add_engines
36 from IPython.parallel.tests import add_engines
37
37
38 from .clienttest import ClusterTestCase, crash, wait, skip_without
38 from .clienttest import ClusterTestCase, crash, wait, skip_without
39
39
40 def setup():
40 def setup():
41 add_engines(3, total=True)
41 add_engines(3, total=True)
42
42
43 class TestView(ClusterTestCase, ParametricTestCase):
43 class TestView(ClusterTestCase, ParametricTestCase):
44
44
45 def test_z_crash_mux(self):
45 def test_z_crash_mux(self):
46 """test graceful handling of engine death (direct)"""
46 """test graceful handling of engine death (direct)"""
47 raise SkipTest("crash tests disabled, due to undesirable crash reports")
47 raise SkipTest("crash tests disabled, due to undesirable crash reports")
48 # self.add_engines(1)
48 # self.add_engines(1)
49 eid = self.client.ids[-1]
49 eid = self.client.ids[-1]
50 ar = self.client[eid].apply_async(crash)
50 ar = self.client[eid].apply_async(crash)
51 self.assertRaisesRemote(error.EngineError, ar.get, 10)
51 self.assertRaisesRemote(error.EngineError, ar.get, 10)
52 eid = ar.engine_id
52 eid = ar.engine_id
53 tic = time.time()
53 tic = time.time()
54 while eid in self.client.ids and time.time()-tic < 5:
54 while eid in self.client.ids and time.time()-tic < 5:
55 time.sleep(.01)
55 time.sleep(.01)
56 self.client.spin()
56 self.client.spin()
57 self.assertFalse(eid in self.client.ids, "Engine should have died")
57 self.assertFalse(eid in self.client.ids, "Engine should have died")
58
58
59 def test_push_pull(self):
59 def test_push_pull(self):
60 """test pushing and pulling"""
60 """test pushing and pulling"""
61 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
61 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
62 t = self.client.ids[-1]
62 t = self.client.ids[-1]
63 v = self.client[t]
63 v = self.client[t]
64 push = v.push
64 push = v.push
65 pull = v.pull
65 pull = v.pull
66 v.block=True
66 v.block=True
67 nengines = len(self.client)
67 nengines = len(self.client)
68 push({'data':data})
68 push({'data':data})
69 d = pull('data')
69 d = pull('data')
70 self.assertEquals(d, data)
70 self.assertEquals(d, data)
71 self.client[:].push({'data':data})
71 self.client[:].push({'data':data})
72 d = self.client[:].pull('data', block=True)
72 d = self.client[:].pull('data', block=True)
73 self.assertEquals(d, nengines*[data])
73 self.assertEquals(d, nengines*[data])
74 ar = push({'data':data}, block=False)
74 ar = push({'data':data}, block=False)
75 self.assertTrue(isinstance(ar, AsyncResult))
75 self.assertTrue(isinstance(ar, AsyncResult))
76 r = ar.get()
76 r = ar.get()
77 ar = self.client[:].pull('data', block=False)
77 ar = self.client[:].pull('data', block=False)
78 self.assertTrue(isinstance(ar, AsyncResult))
78 self.assertTrue(isinstance(ar, AsyncResult))
79 r = ar.get()
79 r = ar.get()
80 self.assertEquals(r, nengines*[data])
80 self.assertEquals(r, nengines*[data])
81 self.client[:].push(dict(a=10,b=20))
81 self.client[:].push(dict(a=10,b=20))
82 r = self.client[:].pull(('a','b'), block=True)
82 r = self.client[:].pull(('a','b'), block=True)
83 self.assertEquals(r, nengines*[[10,20]])
83 self.assertEquals(r, nengines*[[10,20]])
84
84
85 def test_push_pull_function(self):
85 def test_push_pull_function(self):
86 "test pushing and pulling functions"
86 "test pushing and pulling functions"
87 def testf(x):
87 def testf(x):
88 return 2.0*x
88 return 2.0*x
89
89
90 t = self.client.ids[-1]
90 t = self.client.ids[-1]
91 v = self.client[t]
91 v = self.client[t]
92 v.block=True
92 v.block=True
93 push = v.push
93 push = v.push
94 pull = v.pull
94 pull = v.pull
95 execute = v.execute
95 execute = v.execute
96 push({'testf':testf})
96 push({'testf':testf})
97 r = pull('testf')
97 r = pull('testf')
98 self.assertEqual(r(1.0), testf(1.0))
98 self.assertEqual(r(1.0), testf(1.0))
99 execute('r = testf(10)')
99 execute('r = testf(10)')
100 r = pull('r')
100 r = pull('r')
101 self.assertEquals(r, testf(10))
101 self.assertEquals(r, testf(10))
102 ar = self.client[:].push({'testf':testf}, block=False)
102 ar = self.client[:].push({'testf':testf}, block=False)
103 ar.get()
103 ar.get()
104 ar = self.client[:].pull('testf', block=False)
104 ar = self.client[:].pull('testf', block=False)
105 rlist = ar.get()
105 rlist = ar.get()
106 for r in rlist:
106 for r in rlist:
107 self.assertEqual(r(1.0), testf(1.0))
107 self.assertEqual(r(1.0), testf(1.0))
108 execute("def g(x): return x*x")
108 execute("def g(x): return x*x")
109 r = pull(('testf','g'))
109 r = pull(('testf','g'))
110 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
110 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
111
111
112 def test_push_function_globals(self):
112 def test_push_function_globals(self):
113 """test that pushed functions have access to globals"""
113 """test that pushed functions have access to globals"""
114 @interactive
114 @interactive
115 def geta():
115 def geta():
116 return a
116 return a
117 # self.add_engines(1)
117 # self.add_engines(1)
118 v = self.client[-1]
118 v = self.client[-1]
119 v.block=True
119 v.block=True
120 v['f'] = geta
120 v['f'] = geta
121 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
121 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
122 v.execute('a=5')
122 v.execute('a=5')
123 v.execute('b=f()')
123 v.execute('b=f()')
124 self.assertEquals(v['b'], 5)
124 self.assertEquals(v['b'], 5)
125
125
126 def test_push_function_defaults(self):
126 def test_push_function_defaults(self):
127 """test that pushed functions preserve default args"""
127 """test that pushed functions preserve default args"""
128 def echo(a=10):
128 def echo(a=10):
129 return a
129 return a
130 v = self.client[-1]
130 v = self.client[-1]
131 v.block=True
131 v.block=True
132 v['f'] = echo
132 v['f'] = echo
133 v.execute('b=f()')
133 v.execute('b=f()')
134 self.assertEquals(v['b'], 10)
134 self.assertEquals(v['b'], 10)
135
135
136 def test_get_result(self):
136 def test_get_result(self):
137 """test getting results from the Hub."""
137 """test getting results from the Hub."""
138 c = pmod.Client(profile='iptest')
138 c = pmod.Client(profile='iptest')
139 # self.add_engines(1)
139 # self.add_engines(1)
140 t = c.ids[-1]
140 t = c.ids[-1]
141 v = c[t]
141 v = c[t]
142 v2 = self.client[t]
142 v2 = self.client[t]
143 ar = v.apply_async(wait, 1)
143 ar = v.apply_async(wait, 1)
144 # give the monitor time to notice the message
144 # give the monitor time to notice the message
145 time.sleep(.25)
145 time.sleep(.25)
146 ahr = v2.get_result(ar.msg_ids)
146 ahr = v2.get_result(ar.msg_ids)
147 self.assertTrue(isinstance(ahr, AsyncHubResult))
147 self.assertTrue(isinstance(ahr, AsyncHubResult))
148 self.assertEquals(ahr.get(), ar.get())
148 self.assertEquals(ahr.get(), ar.get())
149 ar2 = v2.get_result(ar.msg_ids)
149 ar2 = v2.get_result(ar.msg_ids)
150 self.assertFalse(isinstance(ar2, AsyncHubResult))
150 self.assertFalse(isinstance(ar2, AsyncHubResult))
151 c.spin()
151 c.spin()
152 c.close()
152 c.close()
153
153
154 def test_run_newline(self):
154 def test_run_newline(self):
155 """test that run appends newline to files"""
155 """test that run appends newline to files"""
156 tmpfile = mktemp()
156 tmpfile = mktemp()
157 with open(tmpfile, 'w') as f:
157 with open(tmpfile, 'w') as f:
158 f.write("""def g():
158 f.write("""def g():
159 return 5
159 return 5
160 """)
160 """)
161 v = self.client[-1]
161 v = self.client[-1]
162 v.run(tmpfile, block=True)
162 v.run(tmpfile, block=True)
163 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
163 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
164
164
165 def test_apply_tracked(self):
165 def test_apply_tracked(self):
166 """test tracking for apply"""
166 """test tracking for apply"""
167 # self.add_engines(1)
167 # self.add_engines(1)
168 t = self.client.ids[-1]
168 t = self.client.ids[-1]
169 v = self.client[t]
169 v = self.client[t]
170 v.block=False
170 v.block=False
171 def echo(n=1024*1024, **kwargs):
171 def echo(n=1024*1024, **kwargs):
172 with v.temp_flags(**kwargs):
172 with v.temp_flags(**kwargs):
173 return v.apply(lambda x: x, 'x'*n)
173 return v.apply(lambda x: x, 'x'*n)
174 ar = echo(1, track=False)
174 ar = echo(1, track=False)
175 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
175 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
176 self.assertTrue(ar.sent)
176 self.assertTrue(ar.sent)
177 ar = echo(track=True)
177 ar = echo(track=True)
178 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
178 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
179 self.assertEquals(ar.sent, ar._tracker.done)
179 self.assertEquals(ar.sent, ar._tracker.done)
180 ar._tracker.wait()
180 ar._tracker.wait()
181 self.assertTrue(ar.sent)
181 self.assertTrue(ar.sent)
182
182
183 def test_push_tracked(self):
183 def test_push_tracked(self):
184 t = self.client.ids[-1]
184 t = self.client.ids[-1]
185 ns = dict(x='x'*1024*1024)
185 ns = dict(x='x'*1024*1024)
186 v = self.client[t]
186 v = self.client[t]
187 ar = v.push(ns, block=False, track=False)
187 ar = v.push(ns, block=False, track=False)
188 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
189 self.assertTrue(ar.sent)
189 self.assertTrue(ar.sent)
190
190
191 ar = v.push(ns, block=False, track=True)
191 ar = v.push(ns, block=False, track=True)
192 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
193 ar._tracker.wait()
193 ar._tracker.wait()
194 self.assertEquals(ar.sent, ar._tracker.done)
194 self.assertEquals(ar.sent, ar._tracker.done)
195 self.assertTrue(ar.sent)
195 self.assertTrue(ar.sent)
196 ar.get()
196 ar.get()
197
197
198 def test_scatter_tracked(self):
198 def test_scatter_tracked(self):
199 t = self.client.ids
199 t = self.client.ids
200 x='x'*1024*1024
200 x='x'*1024*1024
201 ar = self.client[t].scatter('x', x, block=False, track=False)
201 ar = self.client[t].scatter('x', x, block=False, track=False)
202 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
203 self.assertTrue(ar.sent)
203 self.assertTrue(ar.sent)
204
204
205 ar = self.client[t].scatter('x', x, block=False, track=True)
205 ar = self.client[t].scatter('x', x, block=False, track=True)
206 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
207 self.assertEquals(ar.sent, ar._tracker.done)
207 self.assertEquals(ar.sent, ar._tracker.done)
208 ar._tracker.wait()
208 ar._tracker.wait()
209 self.assertTrue(ar.sent)
209 self.assertTrue(ar.sent)
210 ar.get()
210 ar.get()
211
211
212 def test_remote_reference(self):
212 def test_remote_reference(self):
213 v = self.client[-1]
213 v = self.client[-1]
214 v['a'] = 123
214 v['a'] = 123
215 ra = pmod.Reference('a')
215 ra = pmod.Reference('a')
216 b = v.apply_sync(lambda x: x, ra)
216 b = v.apply_sync(lambda x: x, ra)
217 self.assertEquals(b, 123)
217 self.assertEquals(b, 123)
218
218
219
219
220 def test_scatter_gather(self):
220 def test_scatter_gather(self):
221 view = self.client[:]
221 view = self.client[:]
222 seq1 = range(16)
222 seq1 = range(16)
223 view.scatter('a', seq1)
223 view.scatter('a', seq1)
224 seq2 = view.gather('a', block=True)
224 seq2 = view.gather('a', block=True)
225 self.assertEquals(seq2, seq1)
225 self.assertEquals(seq2, seq1)
226 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
226 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
227
227
228 @skip_without('numpy')
228 @skip_without('numpy')
229 def test_scatter_gather_numpy(self):
229 def test_scatter_gather_numpy(self):
230 import numpy
230 import numpy
231 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
231 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
232 view = self.client[:]
232 view = self.client[:]
233 a = numpy.arange(64)
233 a = numpy.arange(64)
234 view.scatter('a', a)
234 view.scatter('a', a)
235 b = view.gather('a', block=True)
235 b = view.gather('a', block=True)
236 assert_array_equal(b, a)
236 assert_array_equal(b, a)
237
237
238 def test_scatter_gather_lazy(self):
238 def test_scatter_gather_lazy(self):
239 """scatter/gather with targets='all'"""
239 """scatter/gather with targets='all'"""
240 view = self.client.direct_view(targets='all')
240 view = self.client.direct_view(targets='all')
241 x = range(64)
241 x = range(64)
242 view.scatter('x', x)
242 view.scatter('x', x)
243 gathered = view.gather('x', block=True)
243 gathered = view.gather('x', block=True)
244 self.assertEquals(gathered, x)
244 self.assertEquals(gathered, x)
245
245
246
246
247 @dec.known_failure_py3
247 @dec.known_failure_py3
248 @skip_without('numpy')
248 @skip_without('numpy')
249 def test_push_numpy_nocopy(self):
249 def test_push_numpy_nocopy(self):
250 import numpy
250 import numpy
251 view = self.client[:]
251 view = self.client[:]
252 a = numpy.arange(64)
252 a = numpy.arange(64)
253 view['A'] = a
253 view['A'] = a
254 @interactive
254 @interactive
255 def check_writeable(x):
255 def check_writeable(x):
256 return x.flags.writeable
256 return x.flags.writeable
257
257
258 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
258 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
259 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
259 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
260
260
261 view.push(dict(B=a))
261 view.push(dict(B=a))
262 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
262 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
263 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
263 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
264
264
265 @skip_without('numpy')
265 @skip_without('numpy')
266 def test_apply_numpy(self):
266 def test_apply_numpy(self):
267 """view.apply(f, ndarray)"""
267 """view.apply(f, ndarray)"""
268 import numpy
268 import numpy
269 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
269 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
270
270
271 A = numpy.random.random((100,100))
271 A = numpy.random.random((100,100))
272 view = self.client[-1]
272 view = self.client[-1]
273 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
273 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
274 B = A.astype(dt)
274 B = A.astype(dt)
275 C = view.apply_sync(lambda x:x, B)
275 C = view.apply_sync(lambda x:x, B)
276 assert_array_equal(B,C)
276 assert_array_equal(B,C)
277
277
278 def test_map(self):
278 def test_map(self):
279 view = self.client[:]
279 view = self.client[:]
280 def f(x):
280 def f(x):
281 return x**2
281 return x**2
282 data = range(16)
282 data = range(16)
283 r = view.map_sync(f, data)
283 r = view.map_sync(f, data)
284 self.assertEquals(r, map(f, data))
284 self.assertEquals(r, map(f, data))
285
285
286 def test_map_iterable(self):
286 def test_map_iterable(self):
287 """test map on iterables (direct)"""
287 """test map on iterables (direct)"""
288 view = self.client[:]
288 view = self.client[:]
289 # 101 is prime, so it won't be evenly distributed
289 # 101 is prime, so it won't be evenly distributed
290 arr = range(101)
290 arr = range(101)
291 # ensure it will be an iterator, even in Python 3
291 # ensure it will be an iterator, even in Python 3
292 it = iter(arr)
292 it = iter(arr)
293 r = view.map_sync(lambda x:x, arr)
293 r = view.map_sync(lambda x:x, arr)
294 self.assertEquals(r, list(arr))
294 self.assertEquals(r, list(arr))
295
295
296 def test_scatterGatherNonblocking(self):
296 def test_scatterGatherNonblocking(self):
297 data = range(16)
297 data = range(16)
298 view = self.client[:]
298 view = self.client[:]
299 view.scatter('a', data, block=False)
299 view.scatter('a', data, block=False)
300 ar = view.gather('a', block=False)
300 ar = view.gather('a', block=False)
301 self.assertEquals(ar.get(), data)
301 self.assertEquals(ar.get(), data)
302
302
303 @skip_without('numpy')
303 @skip_without('numpy')
304 def test_scatter_gather_numpy_nonblocking(self):
304 def test_scatter_gather_numpy_nonblocking(self):
305 import numpy
305 import numpy
306 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
306 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
307 a = numpy.arange(64)
307 a = numpy.arange(64)
308 view = self.client[:]
308 view = self.client[:]
309 ar = view.scatter('a', a, block=False)
309 ar = view.scatter('a', a, block=False)
310 self.assertTrue(isinstance(ar, AsyncResult))
310 self.assertTrue(isinstance(ar, AsyncResult))
311 amr = view.gather('a', block=False)
311 amr = view.gather('a', block=False)
312 self.assertTrue(isinstance(amr, AsyncMapResult))
312 self.assertTrue(isinstance(amr, AsyncMapResult))
313 assert_array_equal(amr.get(), a)
313 assert_array_equal(amr.get(), a)
314
314
315 def test_execute(self):
315 def test_execute(self):
316 view = self.client[:]
316 view = self.client[:]
317 # self.client.debug=True
317 # self.client.debug=True
318 execute = view.execute
318 execute = view.execute
319 ar = execute('c=30', block=False)
319 ar = execute('c=30', block=False)
320 self.assertTrue(isinstance(ar, AsyncResult))
320 self.assertTrue(isinstance(ar, AsyncResult))
321 ar = execute('d=[0,1,2]', block=False)
321 ar = execute('d=[0,1,2]', block=False)
322 self.client.wait(ar, 1)
322 self.client.wait(ar, 1)
323 self.assertEquals(len(ar.get()), len(self.client))
323 self.assertEquals(len(ar.get()), len(self.client))
324 for c in view['c']:
324 for c in view['c']:
325 self.assertEquals(c, 30)
325 self.assertEquals(c, 30)
326
326
327 def test_abort(self):
327 def test_abort(self):
328 view = self.client[-1]
328 view = self.client[-1]
329 ar = view.execute('import time; time.sleep(1)', block=False)
329 ar = view.execute('import time; time.sleep(1)', block=False)
330 ar2 = view.apply_async(lambda : 2)
330 ar2 = view.apply_async(lambda : 2)
331 ar3 = view.apply_async(lambda : 3)
331 ar3 = view.apply_async(lambda : 3)
332 view.abort(ar2)
332 view.abort(ar2)
333 view.abort(ar3.msg_ids)
333 view.abort(ar3.msg_ids)
334 self.assertRaises(error.TaskAborted, ar2.get)
334 self.assertRaises(error.TaskAborted, ar2.get)
335 self.assertRaises(error.TaskAborted, ar3.get)
335 self.assertRaises(error.TaskAborted, ar3.get)
336
336
337 def test_abort_all(self):
337 def test_abort_all(self):
338 """view.abort() aborts all outstanding tasks"""
338 """view.abort() aborts all outstanding tasks"""
339 view = self.client[-1]
339 view = self.client[-1]
340 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
340 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
341 view.abort()
341 view.abort()
342 view.wait(timeout=5)
342 view.wait(timeout=5)
343 for ar in ars[5:]:
343 for ar in ars[5:]:
344 self.assertRaises(error.TaskAborted, ar.get)
344 self.assertRaises(error.TaskAborted, ar.get)
345
345
346 def test_temp_flags(self):
346 def test_temp_flags(self):
347 view = self.client[-1]
347 view = self.client[-1]
348 view.block=True
348 view.block=True
349 with view.temp_flags(block=False):
349 with view.temp_flags(block=False):
350 self.assertFalse(view.block)
350 self.assertFalse(view.block)
351 self.assertTrue(view.block)
351 self.assertTrue(view.block)
352
352
353 @dec.known_failure_py3
353 @dec.known_failure_py3
354 def test_importer(self):
354 def test_importer(self):
355 view = self.client[-1]
355 view = self.client[-1]
356 view.clear(block=True)
356 view.clear(block=True)
357 with view.importer:
357 with view.importer:
358 import re
358 import re
359
359
360 @interactive
360 @interactive
361 def findall(pat, s):
361 def findall(pat, s):
362 # this globals() step isn't necessary in real code
362 # this globals() step isn't necessary in real code
363 # only to prevent a closure in the test
363 # only to prevent a closure in the test
364 re = globals()['re']
364 re = globals()['re']
365 return re.findall(pat, s)
365 return re.findall(pat, s)
366
366
367 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
367 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
368
368
369 # parallel magic tests
369 # parallel magic tests
370
370
371 def test_magic_px_blocking(self):
371 def test_magic_px_blocking(self):
372 ip = get_ipython()
372 ip = get_ipython()
373 v = self.client[-1]
373 v = self.client[-1]
374 v.activate()
374 v.activate()
375 v.block=True
375 v.block=True
376
376
377 ip.magic_px('a=5')
377 ip.magic_px('a=5')
378 self.assertEquals(v['a'], 5)
378 self.assertEquals(v['a'], 5)
379 ip.magic_px('a=10')
379 ip.magic_px('a=10')
380 self.assertEquals(v['a'], 10)
380 self.assertEquals(v['a'], 10)
381 sio = StringIO()
381 sio = StringIO()
382 savestdout = sys.stdout
382 savestdout = sys.stdout
383 sys.stdout = sio
383 sys.stdout = sio
384 # just 'print a' works ~99% of the time, but this ensures that
384 # just 'print a' works ~99% of the time, but this ensures that
385 # the stdout message has arrived when the result is finished:
385 # the stdout message has arrived when the result is finished:
386 ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
386 ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
387 sys.stdout = savestdout
387 sys.stdout = savestdout
388 buf = sio.getvalue()
388 buf = sio.getvalue()
389 self.assertTrue('[stdout:' in buf, buf)
389 self.assertTrue('[stdout:' in buf, buf)
390 self.assertTrue(buf.rstrip().endswith('10'))
390 self.assertTrue(buf.rstrip().endswith('10'))
391 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
391 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
392
392
393 def test_magic_px_nonblocking(self):
393 def test_magic_px_nonblocking(self):
394 ip = get_ipython()
394 ip = get_ipython()
395 v = self.client[-1]
395 v = self.client[-1]
396 v.activate()
396 v.activate()
397 v.block=False
397 v.block=False
398
398
399 ip.magic_px('a=5')
399 ip.magic_px('a=5')
400 self.assertEquals(v['a'], 5)
400 self.assertEquals(v['a'], 5)
401 ip.magic_px('a=10')
401 ip.magic_px('a=10')
402 self.assertEquals(v['a'], 10)
402 self.assertEquals(v['a'], 10)
403 sio = StringIO()
403 sio = StringIO()
404 savestdout = sys.stdout
404 savestdout = sys.stdout
405 sys.stdout = sio
405 sys.stdout = sio
406 ip.magic_px('print a')
406 ip.magic_px('print a')
407 sys.stdout = savestdout
407 sys.stdout = savestdout
408 buf = sio.getvalue()
408 buf = sio.getvalue()
409 self.assertFalse('[stdout:%i]'%v.targets in buf)
409 self.assertFalse('[stdout:%i]'%v.targets in buf)
410 ip.magic_px('1/0')
410 ip.magic_px('1/0')
411 ar = v.get_result(-1)
411 ar = v.get_result(-1)
412 self.assertRaisesRemote(ZeroDivisionError, ar.get)
412 self.assertRaisesRemote(ZeroDivisionError, ar.get)
413
413
414 def test_magic_autopx_blocking(self):
414 def test_magic_autopx_blocking(self):
415 ip = get_ipython()
415 ip = get_ipython()
416 v = self.client[-1]
416 v = self.client[-1]
417 v.activate()
417 v.activate()
418 v.block=True
418 v.block=True
419
419
420 sio = StringIO()
420 sio = StringIO()
421 savestdout = sys.stdout
421 savestdout = sys.stdout
422 sys.stdout = sio
422 sys.stdout = sio
423 ip.magic_autopx()
423 ip.magic_autopx()
424 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
424 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
425 ip.run_cell('b*=2')
425 ip.run_cell('b*=2')
426 ip.run_cell('print (b)')
426 ip.run_cell('print (b)')
427 ip.run_cell("b/c")
427 ip.run_cell("b/c")
428 ip.magic_autopx()
428 ip.magic_autopx()
429 sys.stdout = savestdout
429 sys.stdout = savestdout
430 output = sio.getvalue().strip()
430 output = sio.getvalue().strip()
431 self.assertTrue(output.startswith('%autopx enabled'))
431 self.assertTrue(output.startswith('%autopx enabled'))
432 self.assertTrue(output.endswith('%autopx disabled'))
432 self.assertTrue(output.endswith('%autopx disabled'))
433 self.assertTrue('RemoteError: ZeroDivisionError' in output)
433 self.assertTrue('RemoteError: ZeroDivisionError' in output)
434 ar = v.get_result(-1)
434 ar = v.get_result(-1)
435 self.assertEquals(v['a'], 5)
435 self.assertEquals(v['a'], 5)
436 self.assertEquals(v['b'], 20)
436 self.assertEquals(v['b'], 20)
437 self.assertRaisesRemote(ZeroDivisionError, ar.get)
437 self.assertRaisesRemote(ZeroDivisionError, ar.get)
438
438
439 def test_magic_autopx_nonblocking(self):
439 def test_magic_autopx_nonblocking(self):
440 ip = get_ipython()
440 ip = get_ipython()
441 v = self.client[-1]
441 v = self.client[-1]
442 v.activate()
442 v.activate()
443 v.block=False
443 v.block=False
444
444
445 sio = StringIO()
445 sio = StringIO()
446 savestdout = sys.stdout
446 savestdout = sys.stdout
447 sys.stdout = sio
447 sys.stdout = sio
448 ip.magic_autopx()
448 ip.magic_autopx()
449 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
449 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
450 ip.run_cell('print (b)')
450 ip.run_cell('print (b)')
451 ip.run_cell('import time; time.sleep(0.1)')
451 ip.run_cell('import time; time.sleep(0.1)')
452 ip.run_cell("b/c")
452 ip.run_cell("b/c")
453 ip.run_cell('b*=2')
453 ip.run_cell('b*=2')
454 ip.magic_autopx()
454 ip.magic_autopx()
455 sys.stdout = savestdout
455 sys.stdout = savestdout
456 output = sio.getvalue().strip()
456 output = sio.getvalue().strip()
457 self.assertTrue(output.startswith('%autopx enabled'))
457 self.assertTrue(output.startswith('%autopx enabled'))
458 self.assertTrue(output.endswith('%autopx disabled'))
458 self.assertTrue(output.endswith('%autopx disabled'))
459 self.assertFalse('ZeroDivisionError' in output)
459 self.assertFalse('ZeroDivisionError' in output)
460 ar = v.get_result(-2)
460 ar = v.get_result(-2)
461 self.assertRaisesRemote(ZeroDivisionError, ar.get)
461 self.assertRaisesRemote(ZeroDivisionError, ar.get)
462 # prevent TaskAborted on pulls, due to ZeroDivisionError
462 # prevent TaskAborted on pulls, due to ZeroDivisionError
463 time.sleep(0.5)
463 time.sleep(0.5)
464 self.assertEquals(v['a'], 5)
464 self.assertEquals(v['a'], 5)
465 # b*=2 will not fire, due to abort
465 # b*=2 will not fire, due to abort
466 self.assertEquals(v['b'], 10)
466 self.assertEquals(v['b'], 10)
467
467
468 def test_magic_result(self):
468 def test_magic_result(self):
469 ip = get_ipython()
469 ip = get_ipython()
470 v = self.client[-1]
470 v = self.client[-1]
471 v.activate()
471 v.activate()
472 v['a'] = 111
472 v['a'] = 111
473 ra = v['a']
473 ra = v['a']
474
474
475 ar = ip.magic_result()
475 ar = ip.magic_result()
476 self.assertEquals(ar.msg_ids, [v.history[-1]])
476 self.assertEquals(ar.msg_ids, [v.history[-1]])
477 self.assertEquals(ar.get(), 111)
477 self.assertEquals(ar.get(), 111)
478 ar = ip.magic_result('-2')
478 ar = ip.magic_result('-2')
479 self.assertEquals(ar.msg_ids, [v.history[-2]])
479 self.assertEquals(ar.msg_ids, [v.history[-2]])
480
480
481 def test_unicode_execute(self):
481 def test_unicode_execute(self):
482 """test executing unicode strings"""
482 """test executing unicode strings"""
483 v = self.client[-1]
483 v = self.client[-1]
484 v.block=True
484 v.block=True
485 if sys.version_info[0] >= 3:
485 if sys.version_info[0] >= 3:
486 code="a='é'"
486 code="a='é'"
487 else:
487 else:
488 code=u"a=u'é'"
488 code=u"a=u'é'"
489 v.execute(code)
489 v.execute(code)
490 self.assertEquals(v['a'], u'é')
490 self.assertEquals(v['a'], u'é')
491
491
492 def test_unicode_apply_result(self):
492 def test_unicode_apply_result(self):
493 """test unicode apply results"""
493 """test unicode apply results"""
494 v = self.client[-1]
494 v = self.client[-1]
495 r = v.apply_sync(lambda : u'é')
495 r = v.apply_sync(lambda : u'é')
496 self.assertEquals(r, u'é')
496 self.assertEquals(r, u'é')
497
497
498 def test_unicode_apply_arg(self):
498 def test_unicode_apply_arg(self):
499 """test passing unicode arguments to apply"""
499 """test passing unicode arguments to apply"""
500 v = self.client[-1]
500 v = self.client[-1]
501
501
502 @interactive
502 @interactive
503 def check_unicode(a, check):
503 def check_unicode(a, check):
504 assert isinstance(a, unicode), "%r is not unicode"%a
504 assert isinstance(a, unicode), "%r is not unicode"%a
505 assert isinstance(check, bytes), "%r is not bytes"%check
505 assert isinstance(check, bytes), "%r is not bytes"%check
506 assert a.encode('utf8') == check, "%s != %s"%(a,check)
506 assert a.encode('utf8') == check, "%s != %s"%(a,check)
507
507
508 for s in [ u'é', u'ßø®∫',u'asdf' ]:
508 for s in [ u'é', u'ßø®∫',u'asdf' ]:
509 try:
509 try:
510 v.apply_sync(check_unicode, s, s.encode('utf8'))
510 v.apply_sync(check_unicode, s, s.encode('utf8'))
511 except error.RemoteError as e:
511 except error.RemoteError as e:
512 if e.ename == 'AssertionError':
512 if e.ename == 'AssertionError':
513 self.fail(e.evalue)
513 self.fail(e.evalue)
514 else:
514 else:
515 raise e
515 raise e
516
516
517 def test_map_reference(self):
517 def test_map_reference(self):
518 """view.map(<Reference>, *seqs) should work"""
518 """view.map(<Reference>, *seqs) should work"""
519 v = self.client[:]
519 v = self.client[:]
520 v.scatter('n', self.client.ids, flatten=True)
520 v.scatter('n', self.client.ids, flatten=True)
521 v.execute("f = lambda x,y: x*y")
521 v.execute("f = lambda x,y: x*y")
522 rf = pmod.Reference('f')
522 rf = pmod.Reference('f')
523 nlist = list(range(10))
523 nlist = list(range(10))
524 mlist = nlist[::-1]
524 mlist = nlist[::-1]
525 expected = [ m*n for m,n in zip(mlist, nlist) ]
525 expected = [ m*n for m,n in zip(mlist, nlist) ]
526 result = v.map_sync(rf, mlist, nlist)
526 result = v.map_sync(rf, mlist, nlist)
527 self.assertEquals(result, expected)
527 self.assertEquals(result, expected)
528
528
529 def test_apply_reference(self):
529 def test_apply_reference(self):
530 """view.apply(<Reference>, *args) should work"""
530 """view.apply(<Reference>, *args) should work"""
531 v = self.client[:]
531 v = self.client[:]
532 v.scatter('n', self.client.ids, flatten=True)
532 v.scatter('n', self.client.ids, flatten=True)
533 v.execute("f = lambda x: n*x")
533 v.execute("f = lambda x: n*x")
534 rf = pmod.Reference('f')
534 rf = pmod.Reference('f')
535 result = v.apply_sync(rf, 5)
535 result = v.apply_sync(rf, 5)
536 expected = [ 5*id for id in self.client.ids ]
536 expected = [ 5*id for id in self.client.ids ]
537 self.assertEquals(result, expected)
537 self.assertEquals(result, expected)
538
538
539 def test_eval_reference(self):
539 def test_eval_reference(self):
540 v = self.client[self.client.ids[0]]
540 v = self.client[self.client.ids[0]]
541 v['g'] = range(5)
541 v['g'] = range(5)
542 rg = pmod.Reference('g[0]')
542 rg = pmod.Reference('g[0]')
543 echo = lambda x:x
543 echo = lambda x:x
544 self.assertEquals(v.apply_sync(echo, rg), 0)
544 self.assertEquals(v.apply_sync(echo, rg), 0)
545
545
546 def test_reference_nameerror(self):
546 def test_reference_nameerror(self):
547 v = self.client[self.client.ids[0]]
547 v = self.client[self.client.ids[0]]
548 r = pmod.Reference('elvis_has_left')
548 r = pmod.Reference('elvis_has_left')
549 echo = lambda x:x
549 echo = lambda x:x
550 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
550 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
551
551
552 def test_single_engine_map(self):
552 def test_single_engine_map(self):
553 e0 = self.client[self.client.ids[0]]
553 e0 = self.client[self.client.ids[0]]
554 r = range(5)
554 r = range(5)
555 check = [ -1*i for i in r ]
555 check = [ -1*i for i in r ]
556 result = e0.map_sync(lambda x: -1*x, r)
556 result = e0.map_sync(lambda x: -1*x, r)
557 self.assertEquals(result, check)
557 self.assertEquals(result, check)
558
558
559 def test_len(self):
559 def test_len(self):
560 """len(view) makes sense"""
560 """len(view) makes sense"""
561 e0 = self.client[self.client.ids[0]]
561 e0 = self.client[self.client.ids[0]]
562 yield self.assertEquals(len(e0), 1)
562 yield self.assertEquals(len(e0), 1)
563 v = self.client[:]
563 v = self.client[:]
564 yield self.assertEquals(len(v), len(self.client.ids))
564 yield self.assertEquals(len(v), len(self.client.ids))
565 v = self.client.direct_view('all')
565 v = self.client.direct_view('all')
566 yield self.assertEquals(len(v), len(self.client.ids))
566 yield self.assertEquals(len(v), len(self.client.ids))
567 v = self.client[:2]
567 v = self.client[:2]
568 yield self.assertEquals(len(v), 2)
568 yield self.assertEquals(len(v), 2)
569 v = self.client[:1]
569 v = self.client[:1]
570 yield self.assertEquals(len(v), 1)
570 yield self.assertEquals(len(v), 1)
571 v = self.client.load_balanced_view()
571 v = self.client.load_balanced_view()
572 yield self.assertEquals(len(v), len(self.client.ids))
572 yield self.assertEquals(len(v), len(self.client.ids))
573 # parametric tests seem to require manual closing?
573 # parametric tests seem to require manual closing?
574 self.client.close()
574 self.client.close()
575
575
576
576
577 # begin execute tests
577 # begin execute tests
578 def _wait_for(self, f, timeout=10):
578 def _wait_for(self, f, timeout=10):
579 tic = time.time()
579 tic = time.time()
580 while time.time() <= tic + timeout:
580 while time.time() <= tic + timeout:
581 if f():
581 if f():
582 return
582 return
583 time.sleep(0.1)
583 time.sleep(0.1)
584 self.client.spin()
584 self.client.spin()
585 if not f():
585 if not f():
586 print "Warning: Awaited condition never arrived"
586 print "Warning: Awaited condition never arrived"
587
587
588
588
589 def test_execute_reply(self):
589 def test_execute_reply(self):
590 e0 = self.client[self.client.ids[0]]
590 e0 = self.client[self.client.ids[0]]
591 e0.block = True
591 e0.block = True
592 ar = e0.execute("5", silent=False)
592 ar = e0.execute("5", silent=False)
593 er = ar.get()
593 er = ar.get()
594 self._wait_for(lambda : bool(er.pyout))
594 self._wait_for(lambda : bool(er.pyout))
595 self.assertEquals(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
595 self.assertEquals(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
596 self.assertEquals(er.pyout['text/plain'], '5')
596 self.assertEquals(er.pyout['data']['text/plain'], '5')
597
597
598 def test_execute_reply_stdout(self):
598 def test_execute_reply_stdout(self):
599 e0 = self.client[self.client.ids[0]]
599 e0 = self.client[self.client.ids[0]]
600 e0.block = True
600 e0.block = True
601 ar = e0.execute("print (5)", silent=False)
601 ar = e0.execute("print (5)", silent=False)
602 er = ar.get()
602 er = ar.get()
603 self._wait_for(lambda : bool(er.stdout))
603 self._wait_for(lambda : bool(er.stdout))
604 self.assertEquals(er.stdout.strip(), '5')
604 self.assertEquals(er.stdout.strip(), '5')
605
605
606 def test_execute_pyout(self):
606 def test_execute_pyout(self):
607 """execute triggers pyout with silent=False"""
607 """execute triggers pyout with silent=False"""
608 view = self.client[:]
608 view = self.client[:]
609 ar = view.execute("5", silent=False, block=True)
609 ar = view.execute("5", silent=False, block=True)
610 self._wait_for(lambda : all(ar.pyout))
610 self._wait_for(lambda : all(ar.pyout))
611
611
612 expected = [{'text/plain' : '5'}] * len(view)
612 expected = [{'text/plain' : '5'}] * len(view)
613 self.assertEquals(ar.pyout, expected)
613 mimes = [ out['data'] for out in ar.pyout ]
614 self.assertEquals(mimes, expected)
614
615
615 def test_execute_silent(self):
616 def test_execute_silent(self):
616 """execute does not trigger pyout with silent=True"""
617 """execute does not trigger pyout with silent=True"""
617 view = self.client[:]
618 view = self.client[:]
618 ar = view.execute("5", block=True)
619 ar = view.execute("5", block=True)
619 expected = [None] * len(view)
620 expected = [None] * len(view)
620 self.assertEquals(ar.pyout, expected)
621 self.assertEquals(ar.pyout, expected)
621
622
622 def test_execute_magic(self):
623 def test_execute_magic(self):
623 """execute accepts IPython commands"""
624 """execute accepts IPython commands"""
624 view = self.client[:]
625 view = self.client[:]
625 view.execute("a = 5")
626 view.execute("a = 5")
626 ar = view.execute("%whos", block=True)
627 ar = view.execute("%whos", block=True)
627 # this will raise, if that failed
628 # this will raise, if that failed
628 ar.get(5)
629 ar.get(5)
629 self._wait_for(lambda : all(ar.stdout))
630 self._wait_for(lambda : all(ar.stdout))
630 for stdout in ar.stdout:
631 for stdout in ar.stdout:
631 lines = stdout.splitlines()
632 lines = stdout.splitlines()
632 self.assertEquals(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
633 self.assertEquals(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
633 found = False
634 found = False
634 for line in lines[2:]:
635 for line in lines[2:]:
635 split = line.split()
636 split = line.split()
636 if split == ['a', 'int', '5']:
637 if split == ['a', 'int', '5']:
637 found = True
638 found = True
638 break
639 break
639 self.assertTrue(found, "whos output wrong: %s" % stdout)
640 self.assertTrue(found, "whos output wrong: %s" % stdout)
640
641
641 def test_execute_displaypub(self):
642 def test_execute_displaypub(self):
642 """execute tracks display_pub output"""
643 """execute tracks display_pub output"""
643 view = self.client[:]
644 view = self.client[:]
644 view.execute("from IPython.core.display import *")
645 view.execute("from IPython.core.display import *")
645 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
646 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
646
647
647 self._wait_for(lambda : all(len(er.outputs) >= 5 for er in ar))
648 self._wait_for(lambda : all(len(er.outputs) >= 5 for er in ar))
648 outs = [ {u'text/plain' : unicode(i)} for i in range(5) ]
649 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
649 expected = [outs] * len(view)
650 for outputs in ar.outputs:
650 self.assertEquals(ar.outputs, expected)
651 mimes = [ out['data'] for out in outputs ]
652 self.assertEquals(mimes, expected)
651
653
652 def test_apply_displaypub(self):
654 def test_apply_displaypub(self):
653 """apply tracks display_pub output"""
655 """apply tracks display_pub output"""
654 view = self.client[:]
656 view = self.client[:]
655 view.execute("from IPython.core.display import *")
657 view.execute("from IPython.core.display import *")
656
658
657 @interactive
659 @interactive
658 def publish():
660 def publish():
659 [ display(i) for i in range(5) ]
661 [ display(i) for i in range(5) ]
660
662
661 ar = view.apply_async(publish)
663 ar = view.apply_async(publish)
662 ar.get(5)
664 ar.get(5)
663 self._wait_for(lambda : all(len(out) >= 5 for out in ar.outputs))
665 self._wait_for(lambda : all(len(out) >= 5 for out in ar.outputs))
664 outs = [ {u'text/plain' : unicode(j)} for j in range(5) ]
666 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
665 expected = [outs] * len(view)
667 for outputs in ar.outputs:
666 self.assertEquals(ar.outputs, expected)
668 mimes = [ out['data'] for out in outputs ]
669 self.assertEquals(mimes, expected)
667
670
668 def test_execute_raises(self):
671 def test_execute_raises(self):
669 """exceptions in execute requests raise appropriately"""
672 """exceptions in execute requests raise appropriately"""
670 view = self.client[-1]
673 view = self.client[-1]
671 ar = view.execute("1/0")
674 ar = view.execute("1/0")
672 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
675 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
673
676
674 @dec.skipif_not_matplotlib
677 @dec.skipif_not_matplotlib
675 def test_amagic_pylab(self):
678 def test_magic_pylab(self):
676 """%pylab works on engines"""
679 """%pylab works on engines"""
677 view = self.client[-1]
680 view = self.client[-1]
678 ar = view.execute("%pylab inline")
681 ar = view.execute("%pylab inline")
679 # at least check if this raised:
682 # at least check if this raised:
680 reply = ar.get(5)
683 reply = ar.get(5)
681 # include imports, in case user config
684 # include imports, in case user config
682 ar = view.execute("plot(rand(100))", silent=False)
685 ar = view.execute("plot(rand(100))", silent=False)
683 reply = ar.get(5)
686 reply = ar.get(5)
684 self._wait_for(lambda : all(ar.outputs))
687 self._wait_for(lambda : all(ar.outputs))
685 self.assertEquals(len(reply.outputs), 1)
688 self.assertEquals(len(reply.outputs), 1)
686 output = reply.outputs[0]
689 output = reply.outputs[0]
687 self.assertTrue("image/png" in output)
690 self.assertTrue("data" in output)
691 data = output['data']
692 self.assertTrue("image/png" in data)
688
693
689
694
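A minimal usage sketch of the output shape these tests exercise: each pyout and display_pub record is the whole message dict, with the mime bundle stored under its 'data' key. The snippet assumes a running cluster under the 'iptest' profile used in the tests above; the sleep/spin calls stand in for the tests' _wait_for helper and the printed keys are illustrative, not prescriptive.

    # Rough sketch (not part of the diff): reading full output records.
    # Assumes a cluster is running under the 'iptest' profile, as above.
    import time
    from IPython import parallel

    rc = parallel.Client(profile='iptest')
    view = rc[:]

    ar = view.execute("5", silent=False, block=True)
    time.sleep(0.5)   # pyout arrives over IOPub after the reply; cf. _wait_for
    rc.spin()

    # each pyout entry is a whole message; the mime bundle lives under 'data'
    for pyout in ar.pyout:
        if pyout:
            print(pyout['data']['text/plain'])

    # display_pub outputs have the same shape
    view.execute("from IPython.core.display import display")
    ar = view.execute("[ display(i) for i in range(3) ]", block=True)
    time.sleep(0.5)
    rc.spin()
    for outputs in ar.outputs:
        for out in outputs:
            print(out['data']['text/plain'])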