##// END OF EJS Templates
subclass RichOutput in ExecuteReply...
MinRK -
Show More
@@ -1,1839 +1,1845 b''
1 """A semi-synchronous Client for the ZMQ cluster
1 """A semi-synchronous Client for the ZMQ cluster
2
2
3 Authors:
3 Authors:
4
4
5 * MinRK
5 * MinRK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import os
18 import os
19 import json
19 import json
20 import sys
20 import sys
21 from threading import Thread, Event
21 from threading import Thread, Event
22 import time
22 import time
23 import warnings
23 import warnings
24 from datetime import datetime
24 from datetime import datetime
25 from getpass import getpass
25 from getpass import getpass
26 from pprint import pprint
26 from pprint import pprint
27
27
28 pjoin = os.path.join
28 pjoin = os.path.join
29
29
30 import zmq
30 import zmq
31 # from zmq.eventloop import ioloop, zmqstream
31 # from zmq.eventloop import ioloop, zmqstream
32
32
33 from IPython.config.configurable import MultipleInstanceError
33 from IPython.config.configurable import MultipleInstanceError
34 from IPython.core.application import BaseIPythonApplication
34 from IPython.core.application import BaseIPythonApplication
35 from IPython.core.profiledir import ProfileDir, ProfileDirError
35 from IPython.core.profiledir import ProfileDir, ProfileDirError
36
36
37 from IPython.utils.capture import RichOutput
37 from IPython.utils.coloransi import TermColors
38 from IPython.utils.coloransi import TermColors
38 from IPython.utils.jsonutil import rekey
39 from IPython.utils.jsonutil import rekey
39 from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
40 from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
40 from IPython.utils.path import get_ipython_dir
41 from IPython.utils.path import get_ipython_dir
41 from IPython.utils.py3compat import cast_bytes
42 from IPython.utils.py3compat import cast_bytes
42 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
43 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
43 Dict, List, Bool, Set, Any)
44 Dict, List, Bool, Set, Any)
44 from IPython.external.decorator import decorator
45 from IPython.external.decorator import decorator
45 from IPython.external.ssh import tunnel
46 from IPython.external.ssh import tunnel
46
47
47 from IPython.parallel import Reference
48 from IPython.parallel import Reference
48 from IPython.parallel import error
49 from IPython.parallel import error
49 from IPython.parallel import util
50 from IPython.parallel import util
50
51
51 from IPython.kernel.zmq.session import Session, Message
52 from IPython.kernel.zmq.session import Session, Message
52 from IPython.kernel.zmq import serialize
53 from IPython.kernel.zmq import serialize
53
54
54 from .asyncresult import AsyncResult, AsyncHubResult
55 from .asyncresult import AsyncResult, AsyncHubResult
55 from .view import DirectView, LoadBalancedView
56 from .view import DirectView, LoadBalancedView
56
57
57 if sys.version_info[0] >= 3:
58 if sys.version_info[0] >= 3:
58 # xrange is used in a couple 'isinstance' tests in py2
59 # xrange is used in a couple 'isinstance' tests in py2
59 # should be just 'range' in 3k
60 # should be just 'range' in 3k
60 xrange = range
61 xrange = range
61
62
62 #--------------------------------------------------------------------------
63 #--------------------------------------------------------------------------
63 # Decorators for Client methods
64 # Decorators for Client methods
64 #--------------------------------------------------------------------------
65 #--------------------------------------------------------------------------
65
66
66 @decorator
67 @decorator
67 def spin_first(f, self, *args, **kwargs):
68 def spin_first(f, self, *args, **kwargs):
68 """Call spin() to sync state prior to calling the method."""
69 """Call spin() to sync state prior to calling the method."""
69 self.spin()
70 self.spin()
70 return f(self, *args, **kwargs)
71 return f(self, *args, **kwargs)
71
72
72
73
73 #--------------------------------------------------------------------------
74 #--------------------------------------------------------------------------
74 # Classes
75 # Classes
75 #--------------------------------------------------------------------------
76 #--------------------------------------------------------------------------
76
77
77
78
78 class ExecuteReply(object):
79 class ExecuteReply(RichOutput):
79 """wrapper for finished Execute results"""
80 """wrapper for finished Execute results"""
80 def __init__(self, msg_id, content, metadata):
81 def __init__(self, msg_id, content, metadata):
81 self.msg_id = msg_id
82 self.msg_id = msg_id
82 self._content = content
83 self._content = content
83 self.execution_count = content['execution_count']
84 self.execution_count = content['execution_count']
84 self.metadata = metadata
85 self.metadata = metadata
85
86
87 # RichOutput overrides
88
89 @property
90 def source(self):
91 pyout = self.metadata['pyout']
92 if pyout:
93 return pyout.get('source', '')
94
95 @property
96 def data(self):
97 pyout = self.metadata['pyout']
98 if pyout:
99 return pyout.get('data', {})
100
101 @property
102 def _metadata(self):
103 pyout = self.metadata['pyout']
104 if pyout:
105 return pyout.get('metadata', {})
106
107 def display(self):
108 from IPython.display import publish_display_data
109 publish_display_data(self.source, self.data, self.metadata)
110
111 def _repr_mime_(self, mime):
112 if mime not in self.data:
113 return
114 data = self.data[mime]
115 if mime in self._metadata:
116 return data, self._metadata[mime]
117 else:
118 return data
119
86 def __getitem__(self, key):
120 def __getitem__(self, key):
87 return self.metadata[key]
121 return self.metadata[key]
88
122
89 def __getattr__(self, key):
123 def __getattr__(self, key):
90 if key not in self.metadata:
124 if key not in self.metadata:
91 raise AttributeError(key)
125 raise AttributeError(key)
92 return self.metadata[key]
126 return self.metadata[key]
93
127
94 def __repr__(self):
128 def __repr__(self):
95 pyout = self.metadata['pyout'] or {'data':{}}
129 pyout = self.metadata['pyout'] or {'data':{}}
96 text_out = pyout['data'].get('text/plain', '')
130 text_out = pyout['data'].get('text/plain', '')
97 if len(text_out) > 32:
131 if len(text_out) > 32:
98 text_out = text_out[:29] + '...'
132 text_out = text_out[:29] + '...'
99
133
100 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
134 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
101
135
102 def _repr_pretty_(self, p, cycle):
136 def _repr_pretty_(self, p, cycle):
103 pyout = self.metadata['pyout'] or {'data':{}}
137 pyout = self.metadata['pyout'] or {'data':{}}
104 text_out = pyout['data'].get('text/plain', '')
138 text_out = pyout['data'].get('text/plain', '')
105
139
106 if not text_out:
140 if not text_out:
107 return
141 return
108
142
109 try:
143 try:
110 ip = get_ipython()
144 ip = get_ipython()
111 except NameError:
145 except NameError:
112 colors = "NoColor"
146 colors = "NoColor"
113 else:
147 else:
114 colors = ip.colors
148 colors = ip.colors
115
149
116 if colors == "NoColor":
150 if colors == "NoColor":
117 out = normal = ""
151 out = normal = ""
118 else:
152 else:
119 out = TermColors.Red
153 out = TermColors.Red
120 normal = TermColors.Normal
154 normal = TermColors.Normal
121
155
122 if '\n' in text_out and not text_out.startswith('\n'):
156 if '\n' in text_out and not text_out.startswith('\n'):
123 # add newline for multiline reprs
157 # add newline for multiline reprs
124 text_out = '\n' + text_out
158 text_out = '\n' + text_out
125
159
126 p.text(
160 p.text(
127 out + u'Out[%i:%i]: ' % (
161 out + u'Out[%i:%i]: ' % (
128 self.metadata['engine_id'], self.execution_count
162 self.metadata['engine_id'], self.execution_count
129 ) + normal + text_out
163 ) + normal + text_out
130 )
164 )
131
165
132 def _repr_html_(self):
133 pyout = self.metadata['pyout'] or {'data':{}}
134 return pyout['data'].get("text/html")
135
136 def _repr_latex_(self):
137 pyout = self.metadata['pyout'] or {'data':{}}
138 return pyout['data'].get("text/latex")
139
140 def _repr_json_(self):
141 pyout = self.metadata['pyout'] or {'data':{}}
142 return pyout['data'].get("application/json")
143
144 def _repr_javascript_(self):
145 pyout = self.metadata['pyout'] or {'data':{}}
146 return pyout['data'].get("application/javascript")
147
148 def _repr_png_(self):
149 pyout = self.metadata['pyout'] or {'data':{}}
150 return pyout['data'].get("image/png")
151
152 def _repr_jpeg_(self):
153 pyout = self.metadata['pyout'] or {'data':{}}
154 return pyout['data'].get("image/jpeg")
155
156 def _repr_svg_(self):
157 pyout = self.metadata['pyout'] or {'data':{}}
158 return pyout['data'].get("image/svg+xml")
159
160
166
161 class Metadata(dict):
167 class Metadata(dict):
162 """Subclass of dict for initializing metadata values.
168 """Subclass of dict for initializing metadata values.
163
169
164 Attribute access works on keys.
170 Attribute access works on keys.
165
171
166 These objects have a strict set of keys - errors will raise if you try
172 These objects have a strict set of keys - errors will raise if you try
167 to add new keys.
173 to add new keys.
168 """
174 """
169 def __init__(self, *args, **kwargs):
175 def __init__(self, *args, **kwargs):
170 dict.__init__(self)
176 dict.__init__(self)
171 md = {'msg_id' : None,
177 md = {'msg_id' : None,
172 'submitted' : None,
178 'submitted' : None,
173 'started' : None,
179 'started' : None,
174 'completed' : None,
180 'completed' : None,
175 'received' : None,
181 'received' : None,
176 'engine_uuid' : None,
182 'engine_uuid' : None,
177 'engine_id' : None,
183 'engine_id' : None,
178 'follow' : None,
184 'follow' : None,
179 'after' : None,
185 'after' : None,
180 'status' : None,
186 'status' : None,
181
187
182 'pyin' : None,
188 'pyin' : None,
183 'pyout' : None,
189 'pyout' : None,
184 'pyerr' : None,
190 'pyerr' : None,
185 'stdout' : '',
191 'stdout' : '',
186 'stderr' : '',
192 'stderr' : '',
187 'outputs' : [],
193 'outputs' : [],
188 'data': {},
194 'data': {},
189 'outputs_ready' : False,
195 'outputs_ready' : False,
190 }
196 }
191 self.update(md)
197 self.update(md)
192 self.update(dict(*args, **kwargs))
198 self.update(dict(*args, **kwargs))
193
199
194 def __getattr__(self, key):
200 def __getattr__(self, key):
195 """getattr aliased to getitem"""
201 """getattr aliased to getitem"""
196 if key in self.iterkeys():
202 if key in self.iterkeys():
197 return self[key]
203 return self[key]
198 else:
204 else:
199 raise AttributeError(key)
205 raise AttributeError(key)
200
206
201 def __setattr__(self, key, value):
207 def __setattr__(self, key, value):
202 """setattr aliased to setitem, with strict"""
208 """setattr aliased to setitem, with strict"""
203 if key in self.iterkeys():
209 if key in self.iterkeys():
204 self[key] = value
210 self[key] = value
205 else:
211 else:
206 raise AttributeError(key)
212 raise AttributeError(key)
207
213
208 def __setitem__(self, key, value):
214 def __setitem__(self, key, value):
209 """strict static key enforcement"""
215 """strict static key enforcement"""
210 if key in self.iterkeys():
216 if key in self.iterkeys():
211 dict.__setitem__(self, key, value)
217 dict.__setitem__(self, key, value)
212 else:
218 else:
213 raise KeyError(key)
219 raise KeyError(key)
214
220
215
221
216 class Client(HasTraits):
222 class Client(HasTraits):
217 """A semi-synchronous client to the IPython ZMQ cluster
223 """A semi-synchronous client to the IPython ZMQ cluster
218
224
219 Parameters
225 Parameters
220 ----------
226 ----------
221
227
222 url_file : str/unicode; path to ipcontroller-client.json
228 url_file : str/unicode; path to ipcontroller-client.json
223 This JSON file should contain all the information needed to connect to a cluster,
229 This JSON file should contain all the information needed to connect to a cluster,
224 and is likely the only argument needed.
230 and is likely the only argument needed.
225 Connection information for the Hub's registration. If a json connector
231 Connection information for the Hub's registration. If a json connector
226 file is given, then likely no further configuration is necessary.
232 file is given, then likely no further configuration is necessary.
227 [Default: use profile]
233 [Default: use profile]
228 profile : bytes
234 profile : bytes
229 The name of the Cluster profile to be used to find connector information.
235 The name of the Cluster profile to be used to find connector information.
230 If run from an IPython application, the default profile will be the same
236 If run from an IPython application, the default profile will be the same
231 as the running application, otherwise it will be 'default'.
237 as the running application, otherwise it will be 'default'.
232 cluster_id : str
238 cluster_id : str
233 String id to added to runtime files, to prevent name collisions when using
239 String id to added to runtime files, to prevent name collisions when using
234 multiple clusters with a single profile simultaneously.
240 multiple clusters with a single profile simultaneously.
235 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
241 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
236 Since this is text inserted into filenames, typical recommendations apply:
242 Since this is text inserted into filenames, typical recommendations apply:
237 Simple character strings are ideal, and spaces are not recommended (but
243 Simple character strings are ideal, and spaces are not recommended (but
238 should generally work)
244 should generally work)
239 context : zmq.Context
245 context : zmq.Context
240 Pass an existing zmq.Context instance, otherwise the client will create its own.
246 Pass an existing zmq.Context instance, otherwise the client will create its own.
241 debug : bool
247 debug : bool
242 flag for lots of message printing for debug purposes
248 flag for lots of message printing for debug purposes
243 timeout : int/float
249 timeout : int/float
244 time (in seconds) to wait for connection replies from the Hub
250 time (in seconds) to wait for connection replies from the Hub
245 [Default: 10]
251 [Default: 10]
246
252
247 #-------------- session related args ----------------
253 #-------------- session related args ----------------
248
254
249 config : Config object
255 config : Config object
250 If specified, this will be relayed to the Session for configuration
256 If specified, this will be relayed to the Session for configuration
251 username : str
257 username : str
252 set username for the session object
258 set username for the session object
253
259
254 #-------------- ssh related args ----------------
260 #-------------- ssh related args ----------------
255 # These are args for configuring the ssh tunnel to be used
261 # These are args for configuring the ssh tunnel to be used
256 # credentials are used to forward connections over ssh to the Controller
262 # credentials are used to forward connections over ssh to the Controller
257 # Note that the ip given in `addr` needs to be relative to sshserver
263 # Note that the ip given in `addr` needs to be relative to sshserver
258 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
264 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
259 # and set sshserver as the same machine the Controller is on. However,
265 # and set sshserver as the same machine the Controller is on. However,
260 # the only requirement is that sshserver is able to see the Controller
266 # the only requirement is that sshserver is able to see the Controller
261 # (i.e. is within the same trusted network).
267 # (i.e. is within the same trusted network).
262
268
263 sshserver : str
269 sshserver : str
264 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
270 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
265 If keyfile or password is specified, and this is not, it will default to
271 If keyfile or password is specified, and this is not, it will default to
266 the ip given in addr.
272 the ip given in addr.
267 sshkey : str; path to ssh private key file
273 sshkey : str; path to ssh private key file
268 This specifies a key to be used in ssh login, default None.
274 This specifies a key to be used in ssh login, default None.
269 Regular default ssh keys will be used without specifying this argument.
275 Regular default ssh keys will be used without specifying this argument.
270 password : str
276 password : str
271 Your ssh password to sshserver. Note that if this is left None,
277 Your ssh password to sshserver. Note that if this is left None,
272 you will be prompted for it if passwordless key based login is unavailable.
278 you will be prompted for it if passwordless key based login is unavailable.
273 paramiko : bool
279 paramiko : bool
274 flag for whether to use paramiko instead of shell ssh for tunneling.
280 flag for whether to use paramiko instead of shell ssh for tunneling.
275 [default: True on win32, False else]
281 [default: True on win32, False else]
276
282
277
283
278 Attributes
284 Attributes
279 ----------
285 ----------
280
286
281 ids : list of int engine IDs
287 ids : list of int engine IDs
282 requesting the ids attribute always synchronizes
288 requesting the ids attribute always synchronizes
283 the registration state. To request ids without synchronization,
289 the registration state. To request ids without synchronization,
284 use semi-private _ids attributes.
290 use semi-private _ids attributes.
285
291
286 history : list of msg_ids
292 history : list of msg_ids
287 a list of msg_ids, keeping track of all the execution
293 a list of msg_ids, keeping track of all the execution
288 messages you have submitted in order.
294 messages you have submitted in order.
289
295
290 outstanding : set of msg_ids
296 outstanding : set of msg_ids
291 a set of msg_ids that have been submitted, but whose
297 a set of msg_ids that have been submitted, but whose
292 results have not yet been received.
298 results have not yet been received.
293
299
294 results : dict
300 results : dict
295 a dict of all our results, keyed by msg_id
301 a dict of all our results, keyed by msg_id
296
302
297 block : bool
303 block : bool
298 determines default behavior when block not specified
304 determines default behavior when block not specified
299 in execution methods
305 in execution methods
300
306
301 Methods
307 Methods
302 -------
308 -------
303
309
304 spin
310 spin
305 flushes incoming results and registration state changes
311 flushes incoming results and registration state changes
306 control methods spin, and requesting `ids` also ensures up to date
312 control methods spin, and requesting `ids` also ensures up to date
307
313
308 wait
314 wait
309 wait on one or more msg_ids
315 wait on one or more msg_ids
310
316
311 execution methods
317 execution methods
312 apply
318 apply
313 legacy: execute, run
319 legacy: execute, run
314
320
315 data movement
321 data movement
316 push, pull, scatter, gather
322 push, pull, scatter, gather
317
323
318 query methods
324 query methods
319 queue_status, get_result, purge, result_status
325 queue_status, get_result, purge, result_status
320
326
321 control methods
327 control methods
322 abort, shutdown
328 abort, shutdown
323
329
324 """
330 """
325
331
326
332
327 block = Bool(False)
333 block = Bool(False)
328 outstanding = Set()
334 outstanding = Set()
329 results = Instance('collections.defaultdict', (dict,))
335 results = Instance('collections.defaultdict', (dict,))
330 metadata = Instance('collections.defaultdict', (Metadata,))
336 metadata = Instance('collections.defaultdict', (Metadata,))
331 history = List()
337 history = List()
332 debug = Bool(False)
338 debug = Bool(False)
333 _spin_thread = Any()
339 _spin_thread = Any()
334 _stop_spinning = Any()
340 _stop_spinning = Any()
335
341
336 profile=Unicode()
342 profile=Unicode()
337 def _profile_default(self):
343 def _profile_default(self):
338 if BaseIPythonApplication.initialized():
344 if BaseIPythonApplication.initialized():
339 # an IPython app *might* be running, try to get its profile
345 # an IPython app *might* be running, try to get its profile
340 try:
346 try:
341 return BaseIPythonApplication.instance().profile
347 return BaseIPythonApplication.instance().profile
342 except (AttributeError, MultipleInstanceError):
348 except (AttributeError, MultipleInstanceError):
343 # could be a *different* subclass of config.Application,
349 # could be a *different* subclass of config.Application,
344 # which would raise one of these two errors.
350 # which would raise one of these two errors.
345 return u'default'
351 return u'default'
346 else:
352 else:
347 return u'default'
353 return u'default'
348
354
349
355
350 _outstanding_dict = Instance('collections.defaultdict', (set,))
356 _outstanding_dict = Instance('collections.defaultdict', (set,))
351 _ids = List()
357 _ids = List()
352 _connected=Bool(False)
358 _connected=Bool(False)
353 _ssh=Bool(False)
359 _ssh=Bool(False)
354 _context = Instance('zmq.Context')
360 _context = Instance('zmq.Context')
355 _config = Dict()
361 _config = Dict()
356 _engines=Instance(util.ReverseDict, (), {})
362 _engines=Instance(util.ReverseDict, (), {})
357 # _hub_socket=Instance('zmq.Socket')
363 # _hub_socket=Instance('zmq.Socket')
358 _query_socket=Instance('zmq.Socket')
364 _query_socket=Instance('zmq.Socket')
359 _control_socket=Instance('zmq.Socket')
365 _control_socket=Instance('zmq.Socket')
360 _iopub_socket=Instance('zmq.Socket')
366 _iopub_socket=Instance('zmq.Socket')
361 _notification_socket=Instance('zmq.Socket')
367 _notification_socket=Instance('zmq.Socket')
362 _mux_socket=Instance('zmq.Socket')
368 _mux_socket=Instance('zmq.Socket')
363 _task_socket=Instance('zmq.Socket')
369 _task_socket=Instance('zmq.Socket')
364 _task_scheme=Unicode()
370 _task_scheme=Unicode()
365 _closed = False
371 _closed = False
366 _ignored_control_replies=Integer(0)
372 _ignored_control_replies=Integer(0)
367 _ignored_hub_replies=Integer(0)
373 _ignored_hub_replies=Integer(0)
368
374
369 def __new__(self, *args, **kw):
375 def __new__(self, *args, **kw):
370 # don't raise on positional args
376 # don't raise on positional args
371 return HasTraits.__new__(self, **kw)
377 return HasTraits.__new__(self, **kw)
372
378
373 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
379 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
374 context=None, debug=False,
380 context=None, debug=False,
375 sshserver=None, sshkey=None, password=None, paramiko=None,
381 sshserver=None, sshkey=None, password=None, paramiko=None,
376 timeout=10, cluster_id=None, **extra_args
382 timeout=10, cluster_id=None, **extra_args
377 ):
383 ):
378 if profile:
384 if profile:
379 super(Client, self).__init__(debug=debug, profile=profile)
385 super(Client, self).__init__(debug=debug, profile=profile)
380 else:
386 else:
381 super(Client, self).__init__(debug=debug)
387 super(Client, self).__init__(debug=debug)
382 if context is None:
388 if context is None:
383 context = zmq.Context.instance()
389 context = zmq.Context.instance()
384 self._context = context
390 self._context = context
385 self._stop_spinning = Event()
391 self._stop_spinning = Event()
386
392
387 if 'url_or_file' in extra_args:
393 if 'url_or_file' in extra_args:
388 url_file = extra_args['url_or_file']
394 url_file = extra_args['url_or_file']
389 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
395 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
390
396
391 if url_file and util.is_url(url_file):
397 if url_file and util.is_url(url_file):
392 raise ValueError("single urls cannot be specified, url-files must be used.")
398 raise ValueError("single urls cannot be specified, url-files must be used.")
393
399
394 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
400 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
395
401
396 if self._cd is not None:
402 if self._cd is not None:
397 if url_file is None:
403 if url_file is None:
398 if not cluster_id:
404 if not cluster_id:
399 client_json = 'ipcontroller-client.json'
405 client_json = 'ipcontroller-client.json'
400 else:
406 else:
401 client_json = 'ipcontroller-%s-client.json' % cluster_id
407 client_json = 'ipcontroller-%s-client.json' % cluster_id
402 url_file = pjoin(self._cd.security_dir, client_json)
408 url_file = pjoin(self._cd.security_dir, client_json)
403 if url_file is None:
409 if url_file is None:
404 raise ValueError(
410 raise ValueError(
405 "I can't find enough information to connect to a hub!"
411 "I can't find enough information to connect to a hub!"
406 " Please specify at least one of url_file or profile."
412 " Please specify at least one of url_file or profile."
407 )
413 )
408
414
409 with open(url_file) as f:
415 with open(url_file) as f:
410 cfg = json.load(f)
416 cfg = json.load(f)
411
417
412 self._task_scheme = cfg['task_scheme']
418 self._task_scheme = cfg['task_scheme']
413
419
414 # sync defaults from args, json:
420 # sync defaults from args, json:
415 if sshserver:
421 if sshserver:
416 cfg['ssh'] = sshserver
422 cfg['ssh'] = sshserver
417
423
418 location = cfg.setdefault('location', None)
424 location = cfg.setdefault('location', None)
419
425
420 proto,addr = cfg['interface'].split('://')
426 proto,addr = cfg['interface'].split('://')
421 addr = util.disambiguate_ip_address(addr, location)
427 addr = util.disambiguate_ip_address(addr, location)
422 cfg['interface'] = "%s://%s" % (proto, addr)
428 cfg['interface'] = "%s://%s" % (proto, addr)
423
429
424 # turn interface,port into full urls:
430 # turn interface,port into full urls:
425 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
431 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
426 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
432 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
427
433
428 url = cfg['registration']
434 url = cfg['registration']
429
435
430 if location is not None and addr == LOCALHOST:
436 if location is not None and addr == LOCALHOST:
431 # location specified, and connection is expected to be local
437 # location specified, and connection is expected to be local
432 if location not in LOCAL_IPS and not sshserver:
438 if location not in LOCAL_IPS and not sshserver:
433 # load ssh from JSON *only* if the controller is not on
439 # load ssh from JSON *only* if the controller is not on
434 # this machine
440 # this machine
435 sshserver=cfg['ssh']
441 sshserver=cfg['ssh']
436 if location not in LOCAL_IPS and not sshserver:
442 if location not in LOCAL_IPS and not sshserver:
437 # warn if no ssh specified, but SSH is probably needed
443 # warn if no ssh specified, but SSH is probably needed
438 # This is only a warning, because the most likely cause
444 # This is only a warning, because the most likely cause
439 # is a local Controller on a laptop whose IP is dynamic
445 # is a local Controller on a laptop whose IP is dynamic
440 warnings.warn("""
446 warnings.warn("""
441 Controller appears to be listening on localhost, but not on this machine.
447 Controller appears to be listening on localhost, but not on this machine.
442 If this is true, you should specify Client(...,sshserver='you@%s')
448 If this is true, you should specify Client(...,sshserver='you@%s')
443 or instruct your controller to listen on an external IP."""%location,
449 or instruct your controller to listen on an external IP."""%location,
444 RuntimeWarning)
450 RuntimeWarning)
445 elif not sshserver:
451 elif not sshserver:
446 # otherwise sync with cfg
452 # otherwise sync with cfg
447 sshserver = cfg['ssh']
453 sshserver = cfg['ssh']
448
454
449 self._config = cfg
455 self._config = cfg
450
456
451 self._ssh = bool(sshserver or sshkey or password)
457 self._ssh = bool(sshserver or sshkey or password)
452 if self._ssh and sshserver is None:
458 if self._ssh and sshserver is None:
453 # default to ssh via localhost
459 # default to ssh via localhost
454 sshserver = addr
460 sshserver = addr
455 if self._ssh and password is None:
461 if self._ssh and password is None:
456 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
462 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
457 password=False
463 password=False
458 else:
464 else:
459 password = getpass("SSH Password for %s: "%sshserver)
465 password = getpass("SSH Password for %s: "%sshserver)
460 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
466 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
461
467
462 # configure and construct the session
468 # configure and construct the session
463 try:
469 try:
464 extra_args['packer'] = cfg['pack']
470 extra_args['packer'] = cfg['pack']
465 extra_args['unpacker'] = cfg['unpack']
471 extra_args['unpacker'] = cfg['unpack']
466 extra_args['key'] = cast_bytes(cfg['key'])
472 extra_args['key'] = cast_bytes(cfg['key'])
467 extra_args['signature_scheme'] = cfg['signature_scheme']
473 extra_args['signature_scheme'] = cfg['signature_scheme']
468 except KeyError as exc:
474 except KeyError as exc:
469 msg = '\n'.join([
475 msg = '\n'.join([
470 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
476 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
471 "If you are reusing connection files, remove them and start ipcontroller again."
477 "If you are reusing connection files, remove them and start ipcontroller again."
472 ])
478 ])
473 raise ValueError(msg.format(exc.message))
479 raise ValueError(msg.format(exc.message))
474
480
475 self.session = Session(**extra_args)
481 self.session = Session(**extra_args)
476
482
477 self._query_socket = self._context.socket(zmq.DEALER)
483 self._query_socket = self._context.socket(zmq.DEALER)
478
484
479 if self._ssh:
485 if self._ssh:
480 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
486 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
481 else:
487 else:
482 self._query_socket.connect(cfg['registration'])
488 self._query_socket.connect(cfg['registration'])
483
489
484 self.session.debug = self.debug
490 self.session.debug = self.debug
485
491
486 self._notification_handlers = {'registration_notification' : self._register_engine,
492 self._notification_handlers = {'registration_notification' : self._register_engine,
487 'unregistration_notification' : self._unregister_engine,
493 'unregistration_notification' : self._unregister_engine,
488 'shutdown_notification' : lambda msg: self.close(),
494 'shutdown_notification' : lambda msg: self.close(),
489 }
495 }
490 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
496 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
491 'apply_reply' : self._handle_apply_reply}
497 'apply_reply' : self._handle_apply_reply}
492 self._connect(sshserver, ssh_kwargs, timeout)
498 self._connect(sshserver, ssh_kwargs, timeout)
493
499
494 # last step: setup magics, if we are in IPython:
500 # last step: setup magics, if we are in IPython:
495
501
496 try:
502 try:
497 ip = get_ipython()
503 ip = get_ipython()
498 except NameError:
504 except NameError:
499 return
505 return
500 else:
506 else:
501 if 'px' not in ip.magics_manager.magics:
507 if 'px' not in ip.magics_manager.magics:
502 # in IPython but we are the first Client.
508 # in IPython but we are the first Client.
503 # activate a default view for parallel magics.
509 # activate a default view for parallel magics.
504 self.activate()
510 self.activate()
505
511
506 def __del__(self):
512 def __del__(self):
507 """cleanup sockets, but _not_ context."""
513 """cleanup sockets, but _not_ context."""
508 self.close()
514 self.close()
509
515
510 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
516 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
511 if ipython_dir is None:
517 if ipython_dir is None:
512 ipython_dir = get_ipython_dir()
518 ipython_dir = get_ipython_dir()
513 if profile_dir is not None:
519 if profile_dir is not None:
514 try:
520 try:
515 self._cd = ProfileDir.find_profile_dir(profile_dir)
521 self._cd = ProfileDir.find_profile_dir(profile_dir)
516 return
522 return
517 except ProfileDirError:
523 except ProfileDirError:
518 pass
524 pass
519 elif profile is not None:
525 elif profile is not None:
520 try:
526 try:
521 self._cd = ProfileDir.find_profile_dir_by_name(
527 self._cd = ProfileDir.find_profile_dir_by_name(
522 ipython_dir, profile)
528 ipython_dir, profile)
523 return
529 return
524 except ProfileDirError:
530 except ProfileDirError:
525 pass
531 pass
526 self._cd = None
532 self._cd = None
527
533
528 def _update_engines(self, engines):
534 def _update_engines(self, engines):
529 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
535 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
530 for k,v in engines.iteritems():
536 for k,v in engines.iteritems():
531 eid = int(k)
537 eid = int(k)
532 if eid not in self._engines:
538 if eid not in self._engines:
533 self._ids.append(eid)
539 self._ids.append(eid)
534 self._engines[eid] = v
540 self._engines[eid] = v
535 self._ids = sorted(self._ids)
541 self._ids = sorted(self._ids)
536 if sorted(self._engines.keys()) != range(len(self._engines)) and \
542 if sorted(self._engines.keys()) != range(len(self._engines)) and \
537 self._task_scheme == 'pure' and self._task_socket:
543 self._task_scheme == 'pure' and self._task_socket:
538 self._stop_scheduling_tasks()
544 self._stop_scheduling_tasks()
539
545
540 def _stop_scheduling_tasks(self):
546 def _stop_scheduling_tasks(self):
541 """Stop scheduling tasks because an engine has been unregistered
547 """Stop scheduling tasks because an engine has been unregistered
542 from a pure ZMQ scheduler.
548 from a pure ZMQ scheduler.
543 """
549 """
544 self._task_socket.close()
550 self._task_socket.close()
545 self._task_socket = None
551 self._task_socket = None
546 msg = "An engine has been unregistered, and we are using pure " +\
552 msg = "An engine has been unregistered, and we are using pure " +\
547 "ZMQ task scheduling. Task farming will be disabled."
553 "ZMQ task scheduling. Task farming will be disabled."
548 if self.outstanding:
554 if self.outstanding:
549 msg += " If you were running tasks when this happened, " +\
555 msg += " If you were running tasks when this happened, " +\
550 "some `outstanding` msg_ids may never resolve."
556 "some `outstanding` msg_ids may never resolve."
551 warnings.warn(msg, RuntimeWarning)
557 warnings.warn(msg, RuntimeWarning)
552
558
553 def _build_targets(self, targets):
559 def _build_targets(self, targets):
554 """Turn valid target IDs or 'all' into two lists:
560 """Turn valid target IDs or 'all' into two lists:
555 (int_ids, uuids).
561 (int_ids, uuids).
556 """
562 """
557 if not self._ids:
563 if not self._ids:
558 # flush notification socket if no engines yet, just in case
564 # flush notification socket if no engines yet, just in case
559 if not self.ids:
565 if not self.ids:
560 raise error.NoEnginesRegistered("Can't build targets without any engines")
566 raise error.NoEnginesRegistered("Can't build targets without any engines")
561
567
562 if targets is None:
568 if targets is None:
563 targets = self._ids
569 targets = self._ids
564 elif isinstance(targets, basestring):
570 elif isinstance(targets, basestring):
565 if targets.lower() == 'all':
571 if targets.lower() == 'all':
566 targets = self._ids
572 targets = self._ids
567 else:
573 else:
568 raise TypeError("%r not valid str target, must be 'all'"%(targets))
574 raise TypeError("%r not valid str target, must be 'all'"%(targets))
569 elif isinstance(targets, int):
575 elif isinstance(targets, int):
570 if targets < 0:
576 if targets < 0:
571 targets = self.ids[targets]
577 targets = self.ids[targets]
572 if targets not in self._ids:
578 if targets not in self._ids:
573 raise IndexError("No such engine: %i"%targets)
579 raise IndexError("No such engine: %i"%targets)
574 targets = [targets]
580 targets = [targets]
575
581
576 if isinstance(targets, slice):
582 if isinstance(targets, slice):
577 indices = range(len(self._ids))[targets]
583 indices = range(len(self._ids))[targets]
578 ids = self.ids
584 ids = self.ids
579 targets = [ ids[i] for i in indices ]
585 targets = [ ids[i] for i in indices ]
580
586
581 if not isinstance(targets, (tuple, list, xrange)):
587 if not isinstance(targets, (tuple, list, xrange)):
582 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
588 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
583
589
584 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
590 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
585
591
586 def _connect(self, sshserver, ssh_kwargs, timeout):
592 def _connect(self, sshserver, ssh_kwargs, timeout):
587 """setup all our socket connections to the cluster. This is called from
593 """setup all our socket connections to the cluster. This is called from
588 __init__."""
594 __init__."""
589
595
590 # Maybe allow reconnecting?
596 # Maybe allow reconnecting?
591 if self._connected:
597 if self._connected:
592 return
598 return
593 self._connected=True
599 self._connected=True
594
600
595 def connect_socket(s, url):
601 def connect_socket(s, url):
596 # url = util.disambiguate_url(url, self._config['location'])
602 # url = util.disambiguate_url(url, self._config['location'])
597 if self._ssh:
603 if self._ssh:
598 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
604 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
599 else:
605 else:
600 return s.connect(url)
606 return s.connect(url)
601
607
602 self.session.send(self._query_socket, 'connection_request')
608 self.session.send(self._query_socket, 'connection_request')
603 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
609 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
604 poller = zmq.Poller()
610 poller = zmq.Poller()
605 poller.register(self._query_socket, zmq.POLLIN)
611 poller.register(self._query_socket, zmq.POLLIN)
606 # poll expects milliseconds, timeout is seconds
612 # poll expects milliseconds, timeout is seconds
607 evts = poller.poll(timeout*1000)
613 evts = poller.poll(timeout*1000)
608 if not evts:
614 if not evts:
609 raise error.TimeoutError("Hub connection request timed out")
615 raise error.TimeoutError("Hub connection request timed out")
610 idents,msg = self.session.recv(self._query_socket,mode=0)
616 idents,msg = self.session.recv(self._query_socket,mode=0)
611 if self.debug:
617 if self.debug:
612 pprint(msg)
618 pprint(msg)
613 content = msg['content']
619 content = msg['content']
614 # self._config['registration'] = dict(content)
620 # self._config['registration'] = dict(content)
615 cfg = self._config
621 cfg = self._config
616 if content['status'] == 'ok':
622 if content['status'] == 'ok':
617 self._mux_socket = self._context.socket(zmq.DEALER)
623 self._mux_socket = self._context.socket(zmq.DEALER)
618 connect_socket(self._mux_socket, cfg['mux'])
624 connect_socket(self._mux_socket, cfg['mux'])
619
625
620 self._task_socket = self._context.socket(zmq.DEALER)
626 self._task_socket = self._context.socket(zmq.DEALER)
621 connect_socket(self._task_socket, cfg['task'])
627 connect_socket(self._task_socket, cfg['task'])
622
628
623 self._notification_socket = self._context.socket(zmq.SUB)
629 self._notification_socket = self._context.socket(zmq.SUB)
624 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
630 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
625 connect_socket(self._notification_socket, cfg['notification'])
631 connect_socket(self._notification_socket, cfg['notification'])
626
632
627 self._control_socket = self._context.socket(zmq.DEALER)
633 self._control_socket = self._context.socket(zmq.DEALER)
628 connect_socket(self._control_socket, cfg['control'])
634 connect_socket(self._control_socket, cfg['control'])
629
635
630 self._iopub_socket = self._context.socket(zmq.SUB)
636 self._iopub_socket = self._context.socket(zmq.SUB)
631 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
637 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
632 connect_socket(self._iopub_socket, cfg['iopub'])
638 connect_socket(self._iopub_socket, cfg['iopub'])
633
639
634 self._update_engines(dict(content['engines']))
640 self._update_engines(dict(content['engines']))
635 else:
641 else:
636 self._connected = False
642 self._connected = False
637 raise Exception("Failed to connect!")
643 raise Exception("Failed to connect!")
638
644
639 #--------------------------------------------------------------------------
645 #--------------------------------------------------------------------------
640 # handlers and callbacks for incoming messages
646 # handlers and callbacks for incoming messages
641 #--------------------------------------------------------------------------
647 #--------------------------------------------------------------------------
642
648
643 def _unwrap_exception(self, content):
649 def _unwrap_exception(self, content):
644 """unwrap exception, and remap engine_id to int."""
650 """unwrap exception, and remap engine_id to int."""
645 e = error.unwrap_exception(content)
651 e = error.unwrap_exception(content)
646 # print e.traceback
652 # print e.traceback
647 if e.engine_info:
653 if e.engine_info:
648 e_uuid = e.engine_info['engine_uuid']
654 e_uuid = e.engine_info['engine_uuid']
649 eid = self._engines[e_uuid]
655 eid = self._engines[e_uuid]
650 e.engine_info['engine_id'] = eid
656 e.engine_info['engine_id'] = eid
651 return e
657 return e
652
658
653 def _extract_metadata(self, msg):
659 def _extract_metadata(self, msg):
654 header = msg['header']
660 header = msg['header']
655 parent = msg['parent_header']
661 parent = msg['parent_header']
656 msg_meta = msg['metadata']
662 msg_meta = msg['metadata']
657 content = msg['content']
663 content = msg['content']
658 md = {'msg_id' : parent['msg_id'],
664 md = {'msg_id' : parent['msg_id'],
659 'received' : datetime.now(),
665 'received' : datetime.now(),
660 'engine_uuid' : msg_meta.get('engine', None),
666 'engine_uuid' : msg_meta.get('engine', None),
661 'follow' : msg_meta.get('follow', []),
667 'follow' : msg_meta.get('follow', []),
662 'after' : msg_meta.get('after', []),
668 'after' : msg_meta.get('after', []),
663 'status' : content['status'],
669 'status' : content['status'],
664 }
670 }
665
671
666 if md['engine_uuid'] is not None:
672 if md['engine_uuid'] is not None:
667 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
673 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
668
674
669 if 'date' in parent:
675 if 'date' in parent:
670 md['submitted'] = parent['date']
676 md['submitted'] = parent['date']
671 if 'started' in msg_meta:
677 if 'started' in msg_meta:
672 md['started'] = msg_meta['started']
678 md['started'] = msg_meta['started']
673 if 'date' in header:
679 if 'date' in header:
674 md['completed'] = header['date']
680 md['completed'] = header['date']
675 return md
681 return md
676
682
677 def _register_engine(self, msg):
683 def _register_engine(self, msg):
678 """Register a new engine, and update our connection info."""
684 """Register a new engine, and update our connection info."""
679 content = msg['content']
685 content = msg['content']
680 eid = content['id']
686 eid = content['id']
681 d = {eid : content['uuid']}
687 d = {eid : content['uuid']}
682 self._update_engines(d)
688 self._update_engines(d)
683
689
684 def _unregister_engine(self, msg):
690 def _unregister_engine(self, msg):
685 """Unregister an engine that has died."""
691 """Unregister an engine that has died."""
686 content = msg['content']
692 content = msg['content']
687 eid = int(content['id'])
693 eid = int(content['id'])
688 if eid in self._ids:
694 if eid in self._ids:
689 self._ids.remove(eid)
695 self._ids.remove(eid)
690 uuid = self._engines.pop(eid)
696 uuid = self._engines.pop(eid)
691
697
692 self._handle_stranded_msgs(eid, uuid)
698 self._handle_stranded_msgs(eid, uuid)
693
699
694 if self._task_socket and self._task_scheme == 'pure':
700 if self._task_socket and self._task_scheme == 'pure':
695 self._stop_scheduling_tasks()
701 self._stop_scheduling_tasks()
696
702
697 def _handle_stranded_msgs(self, eid, uuid):
703 def _handle_stranded_msgs(self, eid, uuid):
698 """Handle messages known to be on an engine when the engine unregisters.
704 """Handle messages known to be on an engine when the engine unregisters.
699
705
700 It is possible that this will fire prematurely - that is, an engine will
706 It is possible that this will fire prematurely - that is, an engine will
701 go down after completing a result, and the client will be notified
707 go down after completing a result, and the client will be notified
702 of the unregistration and later receive the successful result.
708 of the unregistration and later receive the successful result.
703 """
709 """
704
710
705 outstanding = self._outstanding_dict[uuid]
711 outstanding = self._outstanding_dict[uuid]
706
712
707 for msg_id in list(outstanding):
713 for msg_id in list(outstanding):
708 if msg_id in self.results:
714 if msg_id in self.results:
709 # we already
715 # we already
710 continue
716 continue
711 try:
717 try:
712 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
718 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
713 except:
719 except:
714 content = error.wrap_exception()
720 content = error.wrap_exception()
715 # build a fake message:
721 # build a fake message:
716 msg = self.session.msg('apply_reply', content=content)
722 msg = self.session.msg('apply_reply', content=content)
717 msg['parent_header']['msg_id'] = msg_id
723 msg['parent_header']['msg_id'] = msg_id
718 msg['metadata']['engine'] = uuid
724 msg['metadata']['engine'] = uuid
719 self._handle_apply_reply(msg)
725 self._handle_apply_reply(msg)
720
726
721 def _handle_execute_reply(self, msg):
727 def _handle_execute_reply(self, msg):
722 """Save the reply to an execute_request into our results.
728 """Save the reply to an execute_request into our results.
723
729
724 execute messages are never actually used. apply is used instead.
730 execute messages are never actually used. apply is used instead.
725 """
731 """
726
732
727 parent = msg['parent_header']
733 parent = msg['parent_header']
728 msg_id = parent['msg_id']
734 msg_id = parent['msg_id']
729 if msg_id not in self.outstanding:
735 if msg_id not in self.outstanding:
730 if msg_id in self.history:
736 if msg_id in self.history:
731 print ("got stale result: %s"%msg_id)
737 print ("got stale result: %s"%msg_id)
732 else:
738 else:
733 print ("got unknown result: %s"%msg_id)
739 print ("got unknown result: %s"%msg_id)
734 else:
740 else:
735 self.outstanding.remove(msg_id)
741 self.outstanding.remove(msg_id)
736
742
737 content = msg['content']
743 content = msg['content']
738 header = msg['header']
744 header = msg['header']
739
745
740 # construct metadata:
746 # construct metadata:
741 md = self.metadata[msg_id]
747 md = self.metadata[msg_id]
742 md.update(self._extract_metadata(msg))
748 md.update(self._extract_metadata(msg))
743 # is this redundant?
749 # is this redundant?
744 self.metadata[msg_id] = md
750 self.metadata[msg_id] = md
745
751
746 e_outstanding = self._outstanding_dict[md['engine_uuid']]
752 e_outstanding = self._outstanding_dict[md['engine_uuid']]
747 if msg_id in e_outstanding:
753 if msg_id in e_outstanding:
748 e_outstanding.remove(msg_id)
754 e_outstanding.remove(msg_id)
749
755
750 # construct result:
756 # construct result:
751 if content['status'] == 'ok':
757 if content['status'] == 'ok':
752 self.results[msg_id] = ExecuteReply(msg_id, content, md)
758 self.results[msg_id] = ExecuteReply(msg_id, content, md)
753 elif content['status'] == 'aborted':
759 elif content['status'] == 'aborted':
754 self.results[msg_id] = error.TaskAborted(msg_id)
760 self.results[msg_id] = error.TaskAborted(msg_id)
755 elif content['status'] == 'resubmitted':
761 elif content['status'] == 'resubmitted':
756 # TODO: handle resubmission
762 # TODO: handle resubmission
757 pass
763 pass
758 else:
764 else:
759 self.results[msg_id] = self._unwrap_exception(content)
765 self.results[msg_id] = self._unwrap_exception(content)
760
766
761 def _handle_apply_reply(self, msg):
767 def _handle_apply_reply(self, msg):
762 """Save the reply to an apply_request into our results."""
768 """Save the reply to an apply_request into our results."""
763 parent = msg['parent_header']
769 parent = msg['parent_header']
764 msg_id = parent['msg_id']
770 msg_id = parent['msg_id']
765 if msg_id not in self.outstanding:
771 if msg_id not in self.outstanding:
766 if msg_id in self.history:
772 if msg_id in self.history:
767 print ("got stale result: %s"%msg_id)
773 print ("got stale result: %s"%msg_id)
768 print self.results[msg_id]
774 print self.results[msg_id]
769 print msg
775 print msg
770 else:
776 else:
771 print ("got unknown result: %s"%msg_id)
777 print ("got unknown result: %s"%msg_id)
772 else:
778 else:
773 self.outstanding.remove(msg_id)
779 self.outstanding.remove(msg_id)
774 content = msg['content']
780 content = msg['content']
775 header = msg['header']
781 header = msg['header']
776
782
777 # construct metadata:
783 # construct metadata:
778 md = self.metadata[msg_id]
784 md = self.metadata[msg_id]
779 md.update(self._extract_metadata(msg))
785 md.update(self._extract_metadata(msg))
780 # is this redundant?
786 # is this redundant?
781 self.metadata[msg_id] = md
787 self.metadata[msg_id] = md
782
788
783 e_outstanding = self._outstanding_dict[md['engine_uuid']]
789 e_outstanding = self._outstanding_dict[md['engine_uuid']]
784 if msg_id in e_outstanding:
790 if msg_id in e_outstanding:
785 e_outstanding.remove(msg_id)
791 e_outstanding.remove(msg_id)
786
792
787 # construct result:
793 # construct result:
788 if content['status'] == 'ok':
794 if content['status'] == 'ok':
789 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
795 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
790 elif content['status'] == 'aborted':
796 elif content['status'] == 'aborted':
791 self.results[msg_id] = error.TaskAborted(msg_id)
797 self.results[msg_id] = error.TaskAborted(msg_id)
792 elif content['status'] == 'resubmitted':
798 elif content['status'] == 'resubmitted':
793 # TODO: handle resubmission
799 # TODO: handle resubmission
794 pass
800 pass
795 else:
801 else:
796 self.results[msg_id] = self._unwrap_exception(content)
802 self.results[msg_id] = self._unwrap_exception(content)
797
803
798 def _flush_notifications(self):
804 def _flush_notifications(self):
799 """Flush notifications of engine registrations waiting
805 """Flush notifications of engine registrations waiting
800 in ZMQ queue."""
806 in ZMQ queue."""
801 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
807 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
802 while msg is not None:
808 while msg is not None:
803 if self.debug:
809 if self.debug:
804 pprint(msg)
810 pprint(msg)
805 msg_type = msg['header']['msg_type']
811 msg_type = msg['header']['msg_type']
806 handler = self._notification_handlers.get(msg_type, None)
812 handler = self._notification_handlers.get(msg_type, None)
807 if handler is None:
813 if handler is None:
808 raise Exception("Unhandled message type: %s" % msg_type)
814 raise Exception("Unhandled message type: %s" % msg_type)
809 else:
815 else:
810 handler(msg)
816 handler(msg)
811 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
817 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
812
818
813 def _flush_results(self, sock):
819 def _flush_results(self, sock):
814 """Flush task or queue results waiting in ZMQ queue."""
820 """Flush task or queue results waiting in ZMQ queue."""
815 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
821 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
816 while msg is not None:
822 while msg is not None:
817 if self.debug:
823 if self.debug:
818 pprint(msg)
824 pprint(msg)
819 msg_type = msg['header']['msg_type']
825 msg_type = msg['header']['msg_type']
820 handler = self._queue_handlers.get(msg_type, None)
826 handler = self._queue_handlers.get(msg_type, None)
821 if handler is None:
827 if handler is None:
822 raise Exception("Unhandled message type: %s" % msg_type)
828 raise Exception("Unhandled message type: %s" % msg_type)
823 else:
829 else:
824 handler(msg)
830 handler(msg)
825 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
831 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
826
832
827 def _flush_control(self, sock):
833 def _flush_control(self, sock):
828 """Flush replies from the control channel waiting
834 """Flush replies from the control channel waiting
829 in the ZMQ queue.
835 in the ZMQ queue.
830
836
831 Currently: ignore them."""
837 Currently: ignore them."""
832 if self._ignored_control_replies <= 0:
838 if self._ignored_control_replies <= 0:
833 return
839 return
834 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
840 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
835 while msg is not None:
841 while msg is not None:
836 self._ignored_control_replies -= 1
842 self._ignored_control_replies -= 1
837 if self.debug:
843 if self.debug:
838 pprint(msg)
844 pprint(msg)
839 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
845 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
840
846
841 def _flush_ignored_control(self):
847 def _flush_ignored_control(self):
842 """flush ignored control replies"""
848 """flush ignored control replies"""
843 while self._ignored_control_replies > 0:
849 while self._ignored_control_replies > 0:
844 self.session.recv(self._control_socket)
850 self.session.recv(self._control_socket)
845 self._ignored_control_replies -= 1
851 self._ignored_control_replies -= 1
846
852
847 def _flush_ignored_hub_replies(self):
853 def _flush_ignored_hub_replies(self):
848 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
854 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
849 while msg is not None:
855 while msg is not None:
850 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
856 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
851
857
852 def _flush_iopub(self, sock):
858 def _flush_iopub(self, sock):
853 """Flush replies from the iopub channel waiting
859 """Flush replies from the iopub channel waiting
854 in the ZMQ queue.
860 in the ZMQ queue.
855 """
861 """
856 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
862 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
857 while msg is not None:
863 while msg is not None:
858 if self.debug:
864 if self.debug:
859 pprint(msg)
865 pprint(msg)
860 parent = msg['parent_header']
866 parent = msg['parent_header']
861 # ignore IOPub messages with no parent.
867 # ignore IOPub messages with no parent.
862 # Caused by print statements or warnings from before the first execution.
868 # Caused by print statements or warnings from before the first execution.
863 if not parent:
869 if not parent:
864 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
870 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
865 continue
871 continue
866 msg_id = parent['msg_id']
872 msg_id = parent['msg_id']
867 content = msg['content']
873 content = msg['content']
868 header = msg['header']
874 header = msg['header']
869 msg_type = msg['header']['msg_type']
875 msg_type = msg['header']['msg_type']
870
876
871 # init metadata:
877 # init metadata:
872 md = self.metadata[msg_id]
878 md = self.metadata[msg_id]
873
879
874 if msg_type == 'stream':
880 if msg_type == 'stream':
875 name = content['name']
881 name = content['name']
876 s = md[name] or ''
882 s = md[name] or ''
877 md[name] = s + content['data']
883 md[name] = s + content['data']
878 elif msg_type == 'pyerr':
884 elif msg_type == 'pyerr':
879 md.update({'pyerr' : self._unwrap_exception(content)})
885 md.update({'pyerr' : self._unwrap_exception(content)})
880 elif msg_type == 'pyin':
886 elif msg_type == 'pyin':
881 md.update({'pyin' : content['code']})
887 md.update({'pyin' : content['code']})
882 elif msg_type == 'display_data':
888 elif msg_type == 'display_data':
883 md['outputs'].append(content)
889 md['outputs'].append(content)
884 elif msg_type == 'pyout':
890 elif msg_type == 'pyout':
885 md['pyout'] = content
891 md['pyout'] = content
886 elif msg_type == 'data_message':
892 elif msg_type == 'data_message':
887 data, remainder = serialize.unserialize_object(msg['buffers'])
893 data, remainder = serialize.unserialize_object(msg['buffers'])
888 md['data'].update(data)
894 md['data'].update(data)
889 elif msg_type == 'status':
895 elif msg_type == 'status':
890 # idle message comes after all outputs
896 # idle message comes after all outputs
891 if content['execution_state'] == 'idle':
897 if content['execution_state'] == 'idle':
892 md['outputs_ready'] = True
898 md['outputs_ready'] = True
893 else:
899 else:
894 # unhandled msg_type (status, etc.)
900 # unhandled msg_type (status, etc.)
895 pass
901 pass
896
902
897 # reduntant?
903 # reduntant?
898 self.metadata[msg_id] = md
904 self.metadata[msg_id] = md
899
905
900 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
906 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
901
907
902 #--------------------------------------------------------------------------
908 #--------------------------------------------------------------------------
903 # len, getitem
909 # len, getitem
904 #--------------------------------------------------------------------------
910 #--------------------------------------------------------------------------
905
911
906 def __len__(self):
912 def __len__(self):
907 """len(client) returns # of engines."""
913 """len(client) returns # of engines."""
908 return len(self.ids)
914 return len(self.ids)
909
915
910 def __getitem__(self, key):
916 def __getitem__(self, key):
911 """index access returns DirectView multiplexer objects
917 """index access returns DirectView multiplexer objects
912
918
913 Must be int, slice, or list/tuple/xrange of ints"""
919 Must be int, slice, or list/tuple/xrange of ints"""
914 if not isinstance(key, (int, slice, tuple, list, xrange)):
920 if not isinstance(key, (int, slice, tuple, list, xrange)):
915 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
921 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
916 else:
922 else:
917 return self.direct_view(key)
923 return self.direct_view(key)
918
924
919 #--------------------------------------------------------------------------
925 #--------------------------------------------------------------------------
920 # Begin public methods
926 # Begin public methods
921 #--------------------------------------------------------------------------
927 #--------------------------------------------------------------------------
922
928
923 @property
929 @property
924 def ids(self):
930 def ids(self):
925 """Always up-to-date ids property."""
931 """Always up-to-date ids property."""
926 self._flush_notifications()
932 self._flush_notifications()
927 # always copy:
933 # always copy:
928 return list(self._ids)
934 return list(self._ids)
929
935
930 def activate(self, targets='all', suffix=''):
936 def activate(self, targets='all', suffix=''):
931 """Create a DirectView and register it with IPython magics
937 """Create a DirectView and register it with IPython magics
932
938
933 Defines the magics `%px, %autopx, %pxresult, %%px`
939 Defines the magics `%px, %autopx, %pxresult, %%px`
934
940
935 Parameters
941 Parameters
936 ----------
942 ----------
937
943
938 targets: int, list of ints, or 'all'
944 targets: int, list of ints, or 'all'
939 The engines on which the view's magics will run
945 The engines on which the view's magics will run
940 suffix: str [default: '']
946 suffix: str [default: '']
941 The suffix, if any, for the magics. This allows you to have
947 The suffix, if any, for the magics. This allows you to have
942 multiple views associated with parallel magics at the same time.
948 multiple views associated with parallel magics at the same time.
943
949
944 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
950 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
945 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
951 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
946 on engine 0.
952 on engine 0.
947 """
953 """
948 view = self.direct_view(targets)
954 view = self.direct_view(targets)
949 view.block = True
955 view.block = True
950 view.activate(suffix)
956 view.activate(suffix)
951 return view
957 return view
952
958
953 def close(self):
959 def close(self):
954 if self._closed:
960 if self._closed:
955 return
961 return
956 self.stop_spin_thread()
962 self.stop_spin_thread()
957 snames = filter(lambda n: n.endswith('socket'), dir(self))
963 snames = filter(lambda n: n.endswith('socket'), dir(self))
958 for socket in map(lambda name: getattr(self, name), snames):
964 for socket in map(lambda name: getattr(self, name), snames):
959 if isinstance(socket, zmq.Socket) and not socket.closed:
965 if isinstance(socket, zmq.Socket) and not socket.closed:
960 socket.close()
966 socket.close()
961 self._closed = True
967 self._closed = True
962
968
963 def _spin_every(self, interval=1):
969 def _spin_every(self, interval=1):
964 """target func for use in spin_thread"""
970 """target func for use in spin_thread"""
965 while True:
971 while True:
966 if self._stop_spinning.is_set():
972 if self._stop_spinning.is_set():
967 return
973 return
968 time.sleep(interval)
974 time.sleep(interval)
969 self.spin()
975 self.spin()
970
976
971 def spin_thread(self, interval=1):
977 def spin_thread(self, interval=1):
972 """call Client.spin() in a background thread on some regular interval
978 """call Client.spin() in a background thread on some regular interval
973
979
974 This helps ensure that messages don't pile up too much in the zmq queue
980 This helps ensure that messages don't pile up too much in the zmq queue
975 while you are working on other things, or just leaving an idle terminal.
981 while you are working on other things, or just leaving an idle terminal.
976
982
977 It also helps limit potential padding of the `received` timestamp
983 It also helps limit potential padding of the `received` timestamp
978 on AsyncResult objects, used for timings.
984 on AsyncResult objects, used for timings.
979
985
980 Parameters
986 Parameters
981 ----------
987 ----------
982
988
983 interval : float, optional
989 interval : float, optional
984 The interval on which to spin the client in the background thread
990 The interval on which to spin the client in the background thread
985 (simply passed to time.sleep).
991 (simply passed to time.sleep).
986
992
987 Notes
993 Notes
988 -----
994 -----
989
995
990 For precision timing, you may want to use this method to put a bound
996 For precision timing, you may want to use this method to put a bound
991 on the jitter (in seconds) in `received` timestamps used
997 on the jitter (in seconds) in `received` timestamps used
992 in AsyncResult.wall_time.
998 in AsyncResult.wall_time.
993
999
994 """
1000 """
995 if self._spin_thread is not None:
1001 if self._spin_thread is not None:
996 self.stop_spin_thread()
1002 self.stop_spin_thread()
997 self._stop_spinning.clear()
1003 self._stop_spinning.clear()
998 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1004 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
999 self._spin_thread.daemon = True
1005 self._spin_thread.daemon = True
1000 self._spin_thread.start()
1006 self._spin_thread.start()
1001
1007
1002 def stop_spin_thread(self):
1008 def stop_spin_thread(self):
1003 """stop background spin_thread, if any"""
1009 """stop background spin_thread, if any"""
1004 if self._spin_thread is not None:
1010 if self._spin_thread is not None:
1005 self._stop_spinning.set()
1011 self._stop_spinning.set()
1006 self._spin_thread.join()
1012 self._spin_thread.join()
1007 self._spin_thread = None
1013 self._spin_thread = None
1008
1014
1009 def spin(self):
1015 def spin(self):
1010 """Flush any registration notifications and execution results
1016 """Flush any registration notifications and execution results
1011 waiting in the ZMQ queue.
1017 waiting in the ZMQ queue.
1012 """
1018 """
1013 if self._notification_socket:
1019 if self._notification_socket:
1014 self._flush_notifications()
1020 self._flush_notifications()
1015 if self._iopub_socket:
1021 if self._iopub_socket:
1016 self._flush_iopub(self._iopub_socket)
1022 self._flush_iopub(self._iopub_socket)
1017 if self._mux_socket:
1023 if self._mux_socket:
1018 self._flush_results(self._mux_socket)
1024 self._flush_results(self._mux_socket)
1019 if self._task_socket:
1025 if self._task_socket:
1020 self._flush_results(self._task_socket)
1026 self._flush_results(self._task_socket)
1021 if self._control_socket:
1027 if self._control_socket:
1022 self._flush_control(self._control_socket)
1028 self._flush_control(self._control_socket)
1023 if self._query_socket:
1029 if self._query_socket:
1024 self._flush_ignored_hub_replies()
1030 self._flush_ignored_hub_replies()
1025
1031
1026 def wait(self, jobs=None, timeout=-1):
1032 def wait(self, jobs=None, timeout=-1):
1027 """waits on one or more `jobs`, for up to `timeout` seconds.
1033 """waits on one or more `jobs`, for up to `timeout` seconds.
1028
1034
1029 Parameters
1035 Parameters
1030 ----------
1036 ----------
1031
1037
1032 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1038 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1033 ints are indices to self.history
1039 ints are indices to self.history
1034 strs are msg_ids
1040 strs are msg_ids
1035 default: wait on all outstanding messages
1041 default: wait on all outstanding messages
1036 timeout : float
1042 timeout : float
1037 a time in seconds, after which to give up.
1043 a time in seconds, after which to give up.
1038 default is -1, which means no timeout
1044 default is -1, which means no timeout
1039
1045
1040 Returns
1046 Returns
1041 -------
1047 -------
1042
1048
1043 True : when all msg_ids are done
1049 True : when all msg_ids are done
1044 False : timeout reached, some msg_ids still outstanding
1050 False : timeout reached, some msg_ids still outstanding
1045 """
1051 """
1046 tic = time.time()
1052 tic = time.time()
1047 if jobs is None:
1053 if jobs is None:
1048 theids = self.outstanding
1054 theids = self.outstanding
1049 else:
1055 else:
1050 if isinstance(jobs, (int, basestring, AsyncResult)):
1056 if isinstance(jobs, (int, basestring, AsyncResult)):
1051 jobs = [jobs]
1057 jobs = [jobs]
1052 theids = set()
1058 theids = set()
1053 for job in jobs:
1059 for job in jobs:
1054 if isinstance(job, int):
1060 if isinstance(job, int):
1055 # index access
1061 # index access
1056 job = self.history[job]
1062 job = self.history[job]
1057 elif isinstance(job, AsyncResult):
1063 elif isinstance(job, AsyncResult):
1058 map(theids.add, job.msg_ids)
1064 map(theids.add, job.msg_ids)
1059 continue
1065 continue
1060 theids.add(job)
1066 theids.add(job)
1061 if not theids.intersection(self.outstanding):
1067 if not theids.intersection(self.outstanding):
1062 return True
1068 return True
1063 self.spin()
1069 self.spin()
1064 while theids.intersection(self.outstanding):
1070 while theids.intersection(self.outstanding):
1065 if timeout >= 0 and ( time.time()-tic ) > timeout:
1071 if timeout >= 0 and ( time.time()-tic ) > timeout:
1066 break
1072 break
1067 time.sleep(1e-3)
1073 time.sleep(1e-3)
1068 self.spin()
1074 self.spin()
1069 return len(theids.intersection(self.outstanding)) == 0
1075 return len(theids.intersection(self.outstanding)) == 0
1070
1076
1071 #--------------------------------------------------------------------------
1077 #--------------------------------------------------------------------------
1072 # Control methods
1078 # Control methods
1073 #--------------------------------------------------------------------------
1079 #--------------------------------------------------------------------------
1074
1080
1075 @spin_first
1081 @spin_first
1076 def clear(self, targets=None, block=None):
1082 def clear(self, targets=None, block=None):
1077 """Clear the namespace in target(s)."""
1083 """Clear the namespace in target(s)."""
1078 block = self.block if block is None else block
1084 block = self.block if block is None else block
1079 targets = self._build_targets(targets)[0]
1085 targets = self._build_targets(targets)[0]
1080 for t in targets:
1086 for t in targets:
1081 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1087 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1082 error = False
1088 error = False
1083 if block:
1089 if block:
1084 self._flush_ignored_control()
1090 self._flush_ignored_control()
1085 for i in range(len(targets)):
1091 for i in range(len(targets)):
1086 idents,msg = self.session.recv(self._control_socket,0)
1092 idents,msg = self.session.recv(self._control_socket,0)
1087 if self.debug:
1093 if self.debug:
1088 pprint(msg)
1094 pprint(msg)
1089 if msg['content']['status'] != 'ok':
1095 if msg['content']['status'] != 'ok':
1090 error = self._unwrap_exception(msg['content'])
1096 error = self._unwrap_exception(msg['content'])
1091 else:
1097 else:
1092 self._ignored_control_replies += len(targets)
1098 self._ignored_control_replies += len(targets)
1093 if error:
1099 if error:
1094 raise error
1100 raise error
1095
1101
1096
1102
1097 @spin_first
1103 @spin_first
1098 def abort(self, jobs=None, targets=None, block=None):
1104 def abort(self, jobs=None, targets=None, block=None):
1099 """Abort specific jobs from the execution queues of target(s).
1105 """Abort specific jobs from the execution queues of target(s).
1100
1106
1101 This is a mechanism to prevent jobs that have already been submitted
1107 This is a mechanism to prevent jobs that have already been submitted
1102 from executing.
1108 from executing.
1103
1109
1104 Parameters
1110 Parameters
1105 ----------
1111 ----------
1106
1112
1107 jobs : msg_id, list of msg_ids, or AsyncResult
1113 jobs : msg_id, list of msg_ids, or AsyncResult
1108 The jobs to be aborted
1114 The jobs to be aborted
1109
1115
1110 If unspecified/None: abort all outstanding jobs.
1116 If unspecified/None: abort all outstanding jobs.
1111
1117
1112 """
1118 """
1113 block = self.block if block is None else block
1119 block = self.block if block is None else block
1114 jobs = jobs if jobs is not None else list(self.outstanding)
1120 jobs = jobs if jobs is not None else list(self.outstanding)
1115 targets = self._build_targets(targets)[0]
1121 targets = self._build_targets(targets)[0]
1116
1122
1117 msg_ids = []
1123 msg_ids = []
1118 if isinstance(jobs, (basestring,AsyncResult)):
1124 if isinstance(jobs, (basestring,AsyncResult)):
1119 jobs = [jobs]
1125 jobs = [jobs]
1120 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1126 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1121 if bad_ids:
1127 if bad_ids:
1122 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1128 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1123 for j in jobs:
1129 for j in jobs:
1124 if isinstance(j, AsyncResult):
1130 if isinstance(j, AsyncResult):
1125 msg_ids.extend(j.msg_ids)
1131 msg_ids.extend(j.msg_ids)
1126 else:
1132 else:
1127 msg_ids.append(j)
1133 msg_ids.append(j)
1128 content = dict(msg_ids=msg_ids)
1134 content = dict(msg_ids=msg_ids)
1129 for t in targets:
1135 for t in targets:
1130 self.session.send(self._control_socket, 'abort_request',
1136 self.session.send(self._control_socket, 'abort_request',
1131 content=content, ident=t)
1137 content=content, ident=t)
1132 error = False
1138 error = False
1133 if block:
1139 if block:
1134 self._flush_ignored_control()
1140 self._flush_ignored_control()
1135 for i in range(len(targets)):
1141 for i in range(len(targets)):
1136 idents,msg = self.session.recv(self._control_socket,0)
1142 idents,msg = self.session.recv(self._control_socket,0)
1137 if self.debug:
1143 if self.debug:
1138 pprint(msg)
1144 pprint(msg)
1139 if msg['content']['status'] != 'ok':
1145 if msg['content']['status'] != 'ok':
1140 error = self._unwrap_exception(msg['content'])
1146 error = self._unwrap_exception(msg['content'])
1141 else:
1147 else:
1142 self._ignored_control_replies += len(targets)
1148 self._ignored_control_replies += len(targets)
1143 if error:
1149 if error:
1144 raise error
1150 raise error
1145
1151
1146 @spin_first
1152 @spin_first
1147 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1153 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1148 """Terminates one or more engine processes, optionally including the hub.
1154 """Terminates one or more engine processes, optionally including the hub.
1149
1155
1150 Parameters
1156 Parameters
1151 ----------
1157 ----------
1152
1158
1153 targets: list of ints or 'all' [default: all]
1159 targets: list of ints or 'all' [default: all]
1154 Which engines to shutdown.
1160 Which engines to shutdown.
1155 hub: bool [default: False]
1161 hub: bool [default: False]
1156 Whether to include the Hub. hub=True implies targets='all'.
1162 Whether to include the Hub. hub=True implies targets='all'.
1157 block: bool [default: self.block]
1163 block: bool [default: self.block]
1158 Whether to wait for clean shutdown replies or not.
1164 Whether to wait for clean shutdown replies or not.
1159 restart: bool [default: False]
1165 restart: bool [default: False]
1160 NOT IMPLEMENTED
1166 NOT IMPLEMENTED
1161 whether to restart engines after shutting them down.
1167 whether to restart engines after shutting them down.
1162 """
1168 """
1163 from IPython.parallel.error import NoEnginesRegistered
1169 from IPython.parallel.error import NoEnginesRegistered
1164 if restart:
1170 if restart:
1165 raise NotImplementedError("Engine restart is not yet implemented")
1171 raise NotImplementedError("Engine restart is not yet implemented")
1166
1172
1167 block = self.block if block is None else block
1173 block = self.block if block is None else block
1168 if hub:
1174 if hub:
1169 targets = 'all'
1175 targets = 'all'
1170 try:
1176 try:
1171 targets = self._build_targets(targets)[0]
1177 targets = self._build_targets(targets)[0]
1172 except NoEnginesRegistered:
1178 except NoEnginesRegistered:
1173 targets = []
1179 targets = []
1174 for t in targets:
1180 for t in targets:
1175 self.session.send(self._control_socket, 'shutdown_request',
1181 self.session.send(self._control_socket, 'shutdown_request',
1176 content={'restart':restart},ident=t)
1182 content={'restart':restart},ident=t)
1177 error = False
1183 error = False
1178 if block or hub:
1184 if block or hub:
1179 self._flush_ignored_control()
1185 self._flush_ignored_control()
1180 for i in range(len(targets)):
1186 for i in range(len(targets)):
1181 idents,msg = self.session.recv(self._control_socket, 0)
1187 idents,msg = self.session.recv(self._control_socket, 0)
1182 if self.debug:
1188 if self.debug:
1183 pprint(msg)
1189 pprint(msg)
1184 if msg['content']['status'] != 'ok':
1190 if msg['content']['status'] != 'ok':
1185 error = self._unwrap_exception(msg['content'])
1191 error = self._unwrap_exception(msg['content'])
1186 else:
1192 else:
1187 self._ignored_control_replies += len(targets)
1193 self._ignored_control_replies += len(targets)
1188
1194
1189 if hub:
1195 if hub:
1190 time.sleep(0.25)
1196 time.sleep(0.25)
1191 self.session.send(self._query_socket, 'shutdown_request')
1197 self.session.send(self._query_socket, 'shutdown_request')
1192 idents,msg = self.session.recv(self._query_socket, 0)
1198 idents,msg = self.session.recv(self._query_socket, 0)
1193 if self.debug:
1199 if self.debug:
1194 pprint(msg)
1200 pprint(msg)
1195 if msg['content']['status'] != 'ok':
1201 if msg['content']['status'] != 'ok':
1196 error = self._unwrap_exception(msg['content'])
1202 error = self._unwrap_exception(msg['content'])
1197
1203
1198 if error:
1204 if error:
1199 raise error
1205 raise error
1200
1206
1201 #--------------------------------------------------------------------------
1207 #--------------------------------------------------------------------------
1202 # Execution related methods
1208 # Execution related methods
1203 #--------------------------------------------------------------------------
1209 #--------------------------------------------------------------------------
1204
1210
1205 def _maybe_raise(self, result):
1211 def _maybe_raise(self, result):
1206 """wrapper for maybe raising an exception if apply failed."""
1212 """wrapper for maybe raising an exception if apply failed."""
1207 if isinstance(result, error.RemoteError):
1213 if isinstance(result, error.RemoteError):
1208 raise result
1214 raise result
1209
1215
1210 return result
1216 return result
1211
1217
1212 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1218 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1213 ident=None):
1219 ident=None):
1214 """construct and send an apply message via a socket.
1220 """construct and send an apply message via a socket.
1215
1221
1216 This is the principal method with which all engine execution is performed by views.
1222 This is the principal method with which all engine execution is performed by views.
1217 """
1223 """
1218
1224
1219 if self._closed:
1225 if self._closed:
1220 raise RuntimeError("Client cannot be used after its sockets have been closed")
1226 raise RuntimeError("Client cannot be used after its sockets have been closed")
1221
1227
1222 # defaults:
1228 # defaults:
1223 args = args if args is not None else []
1229 args = args if args is not None else []
1224 kwargs = kwargs if kwargs is not None else {}
1230 kwargs = kwargs if kwargs is not None else {}
1225 metadata = metadata if metadata is not None else {}
1231 metadata = metadata if metadata is not None else {}
1226
1232
1227 # validate arguments
1233 # validate arguments
1228 if not callable(f) and not isinstance(f, Reference):
1234 if not callable(f) and not isinstance(f, Reference):
1229 raise TypeError("f must be callable, not %s"%type(f))
1235 raise TypeError("f must be callable, not %s"%type(f))
1230 if not isinstance(args, (tuple, list)):
1236 if not isinstance(args, (tuple, list)):
1231 raise TypeError("args must be tuple or list, not %s"%type(args))
1237 raise TypeError("args must be tuple or list, not %s"%type(args))
1232 if not isinstance(kwargs, dict):
1238 if not isinstance(kwargs, dict):
1233 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1239 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1234 if not isinstance(metadata, dict):
1240 if not isinstance(metadata, dict):
1235 raise TypeError("metadata must be dict, not %s"%type(metadata))
1241 raise TypeError("metadata must be dict, not %s"%type(metadata))
1236
1242
1237 bufs = serialize.pack_apply_message(f, args, kwargs,
1243 bufs = serialize.pack_apply_message(f, args, kwargs,
1238 buffer_threshold=self.session.buffer_threshold,
1244 buffer_threshold=self.session.buffer_threshold,
1239 item_threshold=self.session.item_threshold,
1245 item_threshold=self.session.item_threshold,
1240 )
1246 )
1241
1247
1242 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1248 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1243 metadata=metadata, track=track)
1249 metadata=metadata, track=track)
1244
1250
1245 msg_id = msg['header']['msg_id']
1251 msg_id = msg['header']['msg_id']
1246 self.outstanding.add(msg_id)
1252 self.outstanding.add(msg_id)
1247 if ident:
1253 if ident:
1248 # possibly routed to a specific engine
1254 # possibly routed to a specific engine
1249 if isinstance(ident, list):
1255 if isinstance(ident, list):
1250 ident = ident[-1]
1256 ident = ident[-1]
1251 if ident in self._engines.values():
1257 if ident in self._engines.values():
1252 # save for later, in case of engine death
1258 # save for later, in case of engine death
1253 self._outstanding_dict[ident].add(msg_id)
1259 self._outstanding_dict[ident].add(msg_id)
1254 self.history.append(msg_id)
1260 self.history.append(msg_id)
1255 self.metadata[msg_id]['submitted'] = datetime.now()
1261 self.metadata[msg_id]['submitted'] = datetime.now()
1256
1262
1257 return msg
1263 return msg
1258
1264
1259 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1265 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1260 """construct and send an execute request via a socket.
1266 """construct and send an execute request via a socket.
1261
1267
1262 """
1268 """
1263
1269
1264 if self._closed:
1270 if self._closed:
1265 raise RuntimeError("Client cannot be used after its sockets have been closed")
1271 raise RuntimeError("Client cannot be used after its sockets have been closed")
1266
1272
1267 # defaults:
1273 # defaults:
1268 metadata = metadata if metadata is not None else {}
1274 metadata = metadata if metadata is not None else {}
1269
1275
1270 # validate arguments
1276 # validate arguments
1271 if not isinstance(code, basestring):
1277 if not isinstance(code, basestring):
1272 raise TypeError("code must be text, not %s" % type(code))
1278 raise TypeError("code must be text, not %s" % type(code))
1273 if not isinstance(metadata, dict):
1279 if not isinstance(metadata, dict):
1274 raise TypeError("metadata must be dict, not %s" % type(metadata))
1280 raise TypeError("metadata must be dict, not %s" % type(metadata))
1275
1281
1276 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1282 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1277
1283
1278
1284
1279 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1285 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1280 metadata=metadata)
1286 metadata=metadata)
1281
1287
1282 msg_id = msg['header']['msg_id']
1288 msg_id = msg['header']['msg_id']
1283 self.outstanding.add(msg_id)
1289 self.outstanding.add(msg_id)
1284 if ident:
1290 if ident:
1285 # possibly routed to a specific engine
1291 # possibly routed to a specific engine
1286 if isinstance(ident, list):
1292 if isinstance(ident, list):
1287 ident = ident[-1]
1293 ident = ident[-1]
1288 if ident in self._engines.values():
1294 if ident in self._engines.values():
1289 # save for later, in case of engine death
1295 # save for later, in case of engine death
1290 self._outstanding_dict[ident].add(msg_id)
1296 self._outstanding_dict[ident].add(msg_id)
1291 self.history.append(msg_id)
1297 self.history.append(msg_id)
1292 self.metadata[msg_id]['submitted'] = datetime.now()
1298 self.metadata[msg_id]['submitted'] = datetime.now()
1293
1299
1294 return msg
1300 return msg
1295
1301
1296 #--------------------------------------------------------------------------
1302 #--------------------------------------------------------------------------
1297 # construct a View object
1303 # construct a View object
1298 #--------------------------------------------------------------------------
1304 #--------------------------------------------------------------------------
1299
1305
1300 def load_balanced_view(self, targets=None):
1306 def load_balanced_view(self, targets=None):
1301 """construct a DirectView object.
1307 """construct a DirectView object.
1302
1308
1303 If no arguments are specified, create a LoadBalancedView
1309 If no arguments are specified, create a LoadBalancedView
1304 using all engines.
1310 using all engines.
1305
1311
1306 Parameters
1312 Parameters
1307 ----------
1313 ----------
1308
1314
1309 targets: list,slice,int,etc. [default: use all engines]
1315 targets: list,slice,int,etc. [default: use all engines]
1310 The subset of engines across which to load-balance
1316 The subset of engines across which to load-balance
1311 """
1317 """
1312 if targets == 'all':
1318 if targets == 'all':
1313 targets = None
1319 targets = None
1314 if targets is not None:
1320 if targets is not None:
1315 targets = self._build_targets(targets)[1]
1321 targets = self._build_targets(targets)[1]
1316 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1322 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1317
1323
1318 def direct_view(self, targets='all'):
1324 def direct_view(self, targets='all'):
1319 """construct a DirectView object.
1325 """construct a DirectView object.
1320
1326
1321 If no targets are specified, create a DirectView using all engines.
1327 If no targets are specified, create a DirectView using all engines.
1322
1328
1323 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1329 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1324 evaluate the target engines at each execution, whereas rc[:] will connect to
1330 evaluate the target engines at each execution, whereas rc[:] will connect to
1325 all *current* engines, and that list will not change.
1331 all *current* engines, and that list will not change.
1326
1332
1327 That is, 'all' will always use all engines, whereas rc[:] will not use
1333 That is, 'all' will always use all engines, whereas rc[:] will not use
1328 engines added after the DirectView is constructed.
1334 engines added after the DirectView is constructed.
1329
1335
1330 Parameters
1336 Parameters
1331 ----------
1337 ----------
1332
1338
1333 targets: list,slice,int,etc. [default: use all engines]
1339 targets: list,slice,int,etc. [default: use all engines]
1334 The engines to use for the View
1340 The engines to use for the View
1335 """
1341 """
1336 single = isinstance(targets, int)
1342 single = isinstance(targets, int)
1337 # allow 'all' to be lazily evaluated at each execution
1343 # allow 'all' to be lazily evaluated at each execution
1338 if targets != 'all':
1344 if targets != 'all':
1339 targets = self._build_targets(targets)[1]
1345 targets = self._build_targets(targets)[1]
1340 if single:
1346 if single:
1341 targets = targets[0]
1347 targets = targets[0]
1342 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1348 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1343
1349
1344 #--------------------------------------------------------------------------
1350 #--------------------------------------------------------------------------
1345 # Query methods
1351 # Query methods
1346 #--------------------------------------------------------------------------
1352 #--------------------------------------------------------------------------
1347
1353
1348 @spin_first
1354 @spin_first
1349 def get_result(self, indices_or_msg_ids=None, block=None):
1355 def get_result(self, indices_or_msg_ids=None, block=None):
1350 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1356 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1351
1357
1352 If the client already has the results, no request to the Hub will be made.
1358 If the client already has the results, no request to the Hub will be made.
1353
1359
1354 This is a convenient way to construct AsyncResult objects, which are wrappers
1360 This is a convenient way to construct AsyncResult objects, which are wrappers
1355 that include metadata about execution, and allow for awaiting results that
1361 that include metadata about execution, and allow for awaiting results that
1356 were not submitted by this Client.
1362 were not submitted by this Client.
1357
1363
1358 It can also be a convenient way to retrieve the metadata associated with
1364 It can also be a convenient way to retrieve the metadata associated with
1359 blocking execution, since it always retrieves
1365 blocking execution, since it always retrieves
1360
1366
1361 Examples
1367 Examples
1362 --------
1368 --------
1363 ::
1369 ::
1364
1370
1365 In [10]: r = client.apply()
1371 In [10]: r = client.apply()
1366
1372
1367 Parameters
1373 Parameters
1368 ----------
1374 ----------
1369
1375
1370 indices_or_msg_ids : integer history index, str msg_id, or list of either
1376 indices_or_msg_ids : integer history index, str msg_id, or list of either
1371 The indices or msg_ids of indices to be retrieved
1377 The indices or msg_ids of indices to be retrieved
1372
1378
1373 block : bool
1379 block : bool
1374 Whether to wait for the result to be done
1380 Whether to wait for the result to be done
1375
1381
1376 Returns
1382 Returns
1377 -------
1383 -------
1378
1384
1379 AsyncResult
1385 AsyncResult
1380 A single AsyncResult object will always be returned.
1386 A single AsyncResult object will always be returned.
1381
1387
1382 AsyncHubResult
1388 AsyncHubResult
1383 A subclass of AsyncResult that retrieves results from the Hub
1389 A subclass of AsyncResult that retrieves results from the Hub
1384
1390
1385 """
1391 """
1386 block = self.block if block is None else block
1392 block = self.block if block is None else block
1387 if indices_or_msg_ids is None:
1393 if indices_or_msg_ids is None:
1388 indices_or_msg_ids = -1
1394 indices_or_msg_ids = -1
1389
1395
1390 single_result = False
1396 single_result = False
1391 if not isinstance(indices_or_msg_ids, (list,tuple)):
1397 if not isinstance(indices_or_msg_ids, (list,tuple)):
1392 indices_or_msg_ids = [indices_or_msg_ids]
1398 indices_or_msg_ids = [indices_or_msg_ids]
1393 single_result = True
1399 single_result = True
1394
1400
1395 theids = []
1401 theids = []
1396 for id in indices_or_msg_ids:
1402 for id in indices_or_msg_ids:
1397 if isinstance(id, int):
1403 if isinstance(id, int):
1398 id = self.history[id]
1404 id = self.history[id]
1399 if not isinstance(id, basestring):
1405 if not isinstance(id, basestring):
1400 raise TypeError("indices must be str or int, not %r"%id)
1406 raise TypeError("indices must be str or int, not %r"%id)
1401 theids.append(id)
1407 theids.append(id)
1402
1408
1403 local_ids = filter(lambda msg_id: msg_id in self.outstanding or msg_id in self.results, theids)
1409 local_ids = filter(lambda msg_id: msg_id in self.outstanding or msg_id in self.results, theids)
1404 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1410 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1405
1411
1406 # given single msg_id initially, get_result shot get the result itself,
1412 # given single msg_id initially, get_result shot get the result itself,
1407 # not a length-one list
1413 # not a length-one list
1408 if single_result:
1414 if single_result:
1409 theids = theids[0]
1415 theids = theids[0]
1410
1416
1411 if remote_ids:
1417 if remote_ids:
1412 ar = AsyncHubResult(self, msg_ids=theids)
1418 ar = AsyncHubResult(self, msg_ids=theids)
1413 else:
1419 else:
1414 ar = AsyncResult(self, msg_ids=theids)
1420 ar = AsyncResult(self, msg_ids=theids)
1415
1421
1416 if block:
1422 if block:
1417 ar.wait()
1423 ar.wait()
1418
1424
1419 return ar
1425 return ar
1420
1426
1421 @spin_first
1427 @spin_first
1422 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1428 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1423 """Resubmit one or more tasks.
1429 """Resubmit one or more tasks.
1424
1430
1425 in-flight tasks may not be resubmitted.
1431 in-flight tasks may not be resubmitted.
1426
1432
1427 Parameters
1433 Parameters
1428 ----------
1434 ----------
1429
1435
1430 indices_or_msg_ids : integer history index, str msg_id, or list of either
1436 indices_or_msg_ids : integer history index, str msg_id, or list of either
1431 The indices or msg_ids of indices to be retrieved
1437 The indices or msg_ids of indices to be retrieved
1432
1438
1433 block : bool
1439 block : bool
1434 Whether to wait for the result to be done
1440 Whether to wait for the result to be done
1435
1441
1436 Returns
1442 Returns
1437 -------
1443 -------
1438
1444
1439 AsyncHubResult
1445 AsyncHubResult
1440 A subclass of AsyncResult that retrieves results from the Hub
1446 A subclass of AsyncResult that retrieves results from the Hub
1441
1447
1442 """
1448 """
1443 block = self.block if block is None else block
1449 block = self.block if block is None else block
1444 if indices_or_msg_ids is None:
1450 if indices_or_msg_ids is None:
1445 indices_or_msg_ids = -1
1451 indices_or_msg_ids = -1
1446
1452
1447 if not isinstance(indices_or_msg_ids, (list,tuple)):
1453 if not isinstance(indices_or_msg_ids, (list,tuple)):
1448 indices_or_msg_ids = [indices_or_msg_ids]
1454 indices_or_msg_ids = [indices_or_msg_ids]
1449
1455
1450 theids = []
1456 theids = []
1451 for id in indices_or_msg_ids:
1457 for id in indices_or_msg_ids:
1452 if isinstance(id, int):
1458 if isinstance(id, int):
1453 id = self.history[id]
1459 id = self.history[id]
1454 if not isinstance(id, basestring):
1460 if not isinstance(id, basestring):
1455 raise TypeError("indices must be str or int, not %r"%id)
1461 raise TypeError("indices must be str or int, not %r"%id)
1456 theids.append(id)
1462 theids.append(id)
1457
1463
1458 content = dict(msg_ids = theids)
1464 content = dict(msg_ids = theids)
1459
1465
1460 self.session.send(self._query_socket, 'resubmit_request', content)
1466 self.session.send(self._query_socket, 'resubmit_request', content)
1461
1467
1462 zmq.select([self._query_socket], [], [])
1468 zmq.select([self._query_socket], [], [])
1463 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1469 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1464 if self.debug:
1470 if self.debug:
1465 pprint(msg)
1471 pprint(msg)
1466 content = msg['content']
1472 content = msg['content']
1467 if content['status'] != 'ok':
1473 if content['status'] != 'ok':
1468 raise self._unwrap_exception(content)
1474 raise self._unwrap_exception(content)
1469 mapping = content['resubmitted']
1475 mapping = content['resubmitted']
1470 new_ids = [ mapping[msg_id] for msg_id in theids ]
1476 new_ids = [ mapping[msg_id] for msg_id in theids ]
1471
1477
1472 ar = AsyncHubResult(self, msg_ids=new_ids)
1478 ar = AsyncHubResult(self, msg_ids=new_ids)
1473
1479
1474 if block:
1480 if block:
1475 ar.wait()
1481 ar.wait()
1476
1482
1477 return ar
1483 return ar
1478
1484
1479 @spin_first
1485 @spin_first
1480 def result_status(self, msg_ids, status_only=True):
1486 def result_status(self, msg_ids, status_only=True):
1481 """Check on the status of the result(s) of the apply request with `msg_ids`.
1487 """Check on the status of the result(s) of the apply request with `msg_ids`.
1482
1488
1483 If status_only is False, then the actual results will be retrieved, else
1489 If status_only is False, then the actual results will be retrieved, else
1484 only the status of the results will be checked.
1490 only the status of the results will be checked.
1485
1491
1486 Parameters
1492 Parameters
1487 ----------
1493 ----------
1488
1494
1489 msg_ids : list of msg_ids
1495 msg_ids : list of msg_ids
1490 if int:
1496 if int:
1491 Passed as index to self.history for convenience.
1497 Passed as index to self.history for convenience.
1492 status_only : bool (default: True)
1498 status_only : bool (default: True)
1493 if False:
1499 if False:
1494 Retrieve the actual results of completed tasks.
1500 Retrieve the actual results of completed tasks.
1495
1501
1496 Returns
1502 Returns
1497 -------
1503 -------
1498
1504
1499 results : dict
1505 results : dict
1500 There will always be the keys 'pending' and 'completed', which will
1506 There will always be the keys 'pending' and 'completed', which will
1501 be lists of msg_ids that are incomplete or complete. If `status_only`
1507 be lists of msg_ids that are incomplete or complete. If `status_only`
1502 is False, then completed results will be keyed by their `msg_id`.
1508 is False, then completed results will be keyed by their `msg_id`.
1503 """
1509 """
1504 if not isinstance(msg_ids, (list,tuple)):
1510 if not isinstance(msg_ids, (list,tuple)):
1505 msg_ids = [msg_ids]
1511 msg_ids = [msg_ids]
1506
1512
1507 theids = []
1513 theids = []
1508 for msg_id in msg_ids:
1514 for msg_id in msg_ids:
1509 if isinstance(msg_id, int):
1515 if isinstance(msg_id, int):
1510 msg_id = self.history[msg_id]
1516 msg_id = self.history[msg_id]
1511 if not isinstance(msg_id, basestring):
1517 if not isinstance(msg_id, basestring):
1512 raise TypeError("msg_ids must be str, not %r"%msg_id)
1518 raise TypeError("msg_ids must be str, not %r"%msg_id)
1513 theids.append(msg_id)
1519 theids.append(msg_id)
1514
1520
1515 completed = []
1521 completed = []
1516 local_results = {}
1522 local_results = {}
1517
1523
1518 # comment this block out to temporarily disable local shortcut:
1524 # comment this block out to temporarily disable local shortcut:
1519 for msg_id in theids:
1525 for msg_id in theids:
1520 if msg_id in self.results:
1526 if msg_id in self.results:
1521 completed.append(msg_id)
1527 completed.append(msg_id)
1522 local_results[msg_id] = self.results[msg_id]
1528 local_results[msg_id] = self.results[msg_id]
1523 theids.remove(msg_id)
1529 theids.remove(msg_id)
1524
1530
1525 if theids: # some not locally cached
1531 if theids: # some not locally cached
1526 content = dict(msg_ids=theids, status_only=status_only)
1532 content = dict(msg_ids=theids, status_only=status_only)
1527 msg = self.session.send(self._query_socket, "result_request", content=content)
1533 msg = self.session.send(self._query_socket, "result_request", content=content)
1528 zmq.select([self._query_socket], [], [])
1534 zmq.select([self._query_socket], [], [])
1529 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1535 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1530 if self.debug:
1536 if self.debug:
1531 pprint(msg)
1537 pprint(msg)
1532 content = msg['content']
1538 content = msg['content']
1533 if content['status'] != 'ok':
1539 if content['status'] != 'ok':
1534 raise self._unwrap_exception(content)
1540 raise self._unwrap_exception(content)
1535 buffers = msg['buffers']
1541 buffers = msg['buffers']
1536 else:
1542 else:
1537 content = dict(completed=[],pending=[])
1543 content = dict(completed=[],pending=[])
1538
1544
1539 content['completed'].extend(completed)
1545 content['completed'].extend(completed)
1540
1546
1541 if status_only:
1547 if status_only:
1542 return content
1548 return content
1543
1549
1544 failures = []
1550 failures = []
1545 # load cached results into result:
1551 # load cached results into result:
1546 content.update(local_results)
1552 content.update(local_results)
1547
1553
1548 # update cache with results:
1554 # update cache with results:
1549 for msg_id in sorted(theids):
1555 for msg_id in sorted(theids):
1550 if msg_id in content['completed']:
1556 if msg_id in content['completed']:
1551 rec = content[msg_id]
1557 rec = content[msg_id]
1552 parent = rec['header']
1558 parent = rec['header']
1553 header = rec['result_header']
1559 header = rec['result_header']
1554 rcontent = rec['result_content']
1560 rcontent = rec['result_content']
1555 iodict = rec['io']
1561 iodict = rec['io']
1556 if isinstance(rcontent, str):
1562 if isinstance(rcontent, str):
1557 rcontent = self.session.unpack(rcontent)
1563 rcontent = self.session.unpack(rcontent)
1558
1564
1559 md = self.metadata[msg_id]
1565 md = self.metadata[msg_id]
1560 md_msg = dict(
1566 md_msg = dict(
1561 content=rcontent,
1567 content=rcontent,
1562 parent_header=parent,
1568 parent_header=parent,
1563 header=header,
1569 header=header,
1564 metadata=rec['result_metadata'],
1570 metadata=rec['result_metadata'],
1565 )
1571 )
1566 md.update(self._extract_metadata(md_msg))
1572 md.update(self._extract_metadata(md_msg))
1567 if rec.get('received'):
1573 if rec.get('received'):
1568 md['received'] = rec['received']
1574 md['received'] = rec['received']
1569 md.update(iodict)
1575 md.update(iodict)
1570
1576
1571 if rcontent['status'] == 'ok':
1577 if rcontent['status'] == 'ok':
1572 if header['msg_type'] == 'apply_reply':
1578 if header['msg_type'] == 'apply_reply':
1573 res,buffers = serialize.unserialize_object(buffers)
1579 res,buffers = serialize.unserialize_object(buffers)
1574 elif header['msg_type'] == 'execute_reply':
1580 elif header['msg_type'] == 'execute_reply':
1575 res = ExecuteReply(msg_id, rcontent, md)
1581 res = ExecuteReply(msg_id, rcontent, md)
1576 else:
1582 else:
1577 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1583 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1578 else:
1584 else:
1579 res = self._unwrap_exception(rcontent)
1585 res = self._unwrap_exception(rcontent)
1580 failures.append(res)
1586 failures.append(res)
1581
1587
1582 self.results[msg_id] = res
1588 self.results[msg_id] = res
1583 content[msg_id] = res
1589 content[msg_id] = res
1584
1590
1585 if len(theids) == 1 and failures:
1591 if len(theids) == 1 and failures:
1586 raise failures[0]
1592 raise failures[0]
1587
1593
1588 error.collect_exceptions(failures, "result_status")
1594 error.collect_exceptions(failures, "result_status")
1589 return content
1595 return content
1590
1596
1591 @spin_first
1597 @spin_first
1592 def queue_status(self, targets='all', verbose=False):
1598 def queue_status(self, targets='all', verbose=False):
1593 """Fetch the status of engine queues.
1599 """Fetch the status of engine queues.
1594
1600
1595 Parameters
1601 Parameters
1596 ----------
1602 ----------
1597
1603
1598 targets : int/str/list of ints/strs
1604 targets : int/str/list of ints/strs
1599 the engines whose states are to be queried.
1605 the engines whose states are to be queried.
1600 default : all
1606 default : all
1601 verbose : bool
1607 verbose : bool
1602 Whether to return lengths only, or lists of ids for each element
1608 Whether to return lengths only, or lists of ids for each element
1603 """
1609 """
1604 if targets == 'all':
1610 if targets == 'all':
1605 # allow 'all' to be evaluated on the engine
1611 # allow 'all' to be evaluated on the engine
1606 engine_ids = None
1612 engine_ids = None
1607 else:
1613 else:
1608 engine_ids = self._build_targets(targets)[1]
1614 engine_ids = self._build_targets(targets)[1]
1609 content = dict(targets=engine_ids, verbose=verbose)
1615 content = dict(targets=engine_ids, verbose=verbose)
1610 self.session.send(self._query_socket, "queue_request", content=content)
1616 self.session.send(self._query_socket, "queue_request", content=content)
1611 idents,msg = self.session.recv(self._query_socket, 0)
1617 idents,msg = self.session.recv(self._query_socket, 0)
1612 if self.debug:
1618 if self.debug:
1613 pprint(msg)
1619 pprint(msg)
1614 content = msg['content']
1620 content = msg['content']
1615 status = content.pop('status')
1621 status = content.pop('status')
1616 if status != 'ok':
1622 if status != 'ok':
1617 raise self._unwrap_exception(content)
1623 raise self._unwrap_exception(content)
1618 content = rekey(content)
1624 content = rekey(content)
1619 if isinstance(targets, int):
1625 if isinstance(targets, int):
1620 return content[targets]
1626 return content[targets]
1621 else:
1627 else:
1622 return content
1628 return content
1623
1629
1624 def _build_msgids_from_target(self, targets=None):
1630 def _build_msgids_from_target(self, targets=None):
1625 """Build a list of msg_ids from the list of engine targets"""
1631 """Build a list of msg_ids from the list of engine targets"""
1626 if not targets: # needed as _build_targets otherwise uses all engines
1632 if not targets: # needed as _build_targets otherwise uses all engines
1627 return []
1633 return []
1628 target_ids = self._build_targets(targets)[0]
1634 target_ids = self._build_targets(targets)[0]
1629 return filter(lambda md_id: self.metadata[md_id]["engine_uuid"] in target_ids, self.metadata)
1635 return filter(lambda md_id: self.metadata[md_id]["engine_uuid"] in target_ids, self.metadata)
1630
1636
1631 def _build_msgids_from_jobs(self, jobs=None):
1637 def _build_msgids_from_jobs(self, jobs=None):
1632 """Build a list of msg_ids from "jobs" """
1638 """Build a list of msg_ids from "jobs" """
1633 if not jobs:
1639 if not jobs:
1634 return []
1640 return []
1635 msg_ids = []
1641 msg_ids = []
1636 if isinstance(jobs, (basestring,AsyncResult)):
1642 if isinstance(jobs, (basestring,AsyncResult)):
1637 jobs = [jobs]
1643 jobs = [jobs]
1638 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1644 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1639 if bad_ids:
1645 if bad_ids:
1640 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1646 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1641 for j in jobs:
1647 for j in jobs:
1642 if isinstance(j, AsyncResult):
1648 if isinstance(j, AsyncResult):
1643 msg_ids.extend(j.msg_ids)
1649 msg_ids.extend(j.msg_ids)
1644 else:
1650 else:
1645 msg_ids.append(j)
1651 msg_ids.append(j)
1646 return msg_ids
1652 return msg_ids
1647
1653
1648 def purge_local_results(self, jobs=[], targets=[]):
1654 def purge_local_results(self, jobs=[], targets=[]):
1649 """Clears the client caches of results and frees such memory.
1655 """Clears the client caches of results and frees such memory.
1650
1656
1651 Individual results can be purged by msg_id, or the entire
1657 Individual results can be purged by msg_id, or the entire
1652 history of specific targets can be purged.
1658 history of specific targets can be purged.
1653
1659
1654 Use `purge_local_results('all')` to scrub everything from the Clients's db.
1660 Use `purge_local_results('all')` to scrub everything from the Clients's db.
1655
1661
1656 The client must have no outstanding tasks before purging the caches.
1662 The client must have no outstanding tasks before purging the caches.
1657 Raises `AssertionError` if there are still outstanding tasks.
1663 Raises `AssertionError` if there are still outstanding tasks.
1658
1664
1659 After this call all `AsyncResults` are invalid and should be discarded.
1665 After this call all `AsyncResults` are invalid and should be discarded.
1660
1666
1661 If you must "reget" the results, you can still do so by using
1667 If you must "reget" the results, you can still do so by using
1662 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1668 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1663 redownload the results from the hub if they are still available
1669 redownload the results from the hub if they are still available
1664 (i.e `client.purge_hub_results(...)` has not been called.
1670 (i.e `client.purge_hub_results(...)` has not been called.
1665
1671
1666 Parameters
1672 Parameters
1667 ----------
1673 ----------
1668
1674
1669 jobs : str or list of str or AsyncResult objects
1675 jobs : str or list of str or AsyncResult objects
1670 the msg_ids whose results should be purged.
1676 the msg_ids whose results should be purged.
1671 targets : int/str/list of ints/strs
1677 targets : int/str/list of ints/strs
1672 The targets, by int_id, whose entire results are to be purged.
1678 The targets, by int_id, whose entire results are to be purged.
1673
1679
1674 default : None
1680 default : None
1675 """
1681 """
1676 assert not self.outstanding, "Can't purge a client with outstanding tasks!"
1682 assert not self.outstanding, "Can't purge a client with outstanding tasks!"
1677
1683
1678 if not targets and not jobs:
1684 if not targets and not jobs:
1679 raise ValueError("Must specify at least one of `targets` and `jobs`")
1685 raise ValueError("Must specify at least one of `targets` and `jobs`")
1680
1686
1681 if jobs == 'all':
1687 if jobs == 'all':
1682 self.results.clear()
1688 self.results.clear()
1683 self.metadata.clear()
1689 self.metadata.clear()
1684 return
1690 return
1685 else:
1691 else:
1686 msg_ids = []
1692 msg_ids = []
1687 msg_ids.extend(self._build_msgids_from_target(targets))
1693 msg_ids.extend(self._build_msgids_from_target(targets))
1688 msg_ids.extend(self._build_msgids_from_jobs(jobs))
1694 msg_ids.extend(self._build_msgids_from_jobs(jobs))
1689 map(self.results.pop, msg_ids)
1695 map(self.results.pop, msg_ids)
1690 map(self.metadata.pop, msg_ids)
1696 map(self.metadata.pop, msg_ids)
1691
1697
1692
1698
1693 @spin_first
1699 @spin_first
1694 def purge_hub_results(self, jobs=[], targets=[]):
1700 def purge_hub_results(self, jobs=[], targets=[]):
1695 """Tell the Hub to forget results.
1701 """Tell the Hub to forget results.
1696
1702
1697 Individual results can be purged by msg_id, or the entire
1703 Individual results can be purged by msg_id, or the entire
1698 history of specific targets can be purged.
1704 history of specific targets can be purged.
1699
1705
1700 Use `purge_results('all')` to scrub everything from the Hub's db.
1706 Use `purge_results('all')` to scrub everything from the Hub's db.
1701
1707
1702 Parameters
1708 Parameters
1703 ----------
1709 ----------
1704
1710
1705 jobs : str or list of str or AsyncResult objects
1711 jobs : str or list of str or AsyncResult objects
1706 the msg_ids whose results should be forgotten.
1712 the msg_ids whose results should be forgotten.
1707 targets : int/str/list of ints/strs
1713 targets : int/str/list of ints/strs
1708 The targets, by int_id, whose entire history is to be purged.
1714 The targets, by int_id, whose entire history is to be purged.
1709
1715
1710 default : None
1716 default : None
1711 """
1717 """
1712 if not targets and not jobs:
1718 if not targets and not jobs:
1713 raise ValueError("Must specify at least one of `targets` and `jobs`")
1719 raise ValueError("Must specify at least one of `targets` and `jobs`")
1714 if targets:
1720 if targets:
1715 targets = self._build_targets(targets)[1]
1721 targets = self._build_targets(targets)[1]
1716
1722
1717 # construct msg_ids from jobs
1723 # construct msg_ids from jobs
1718 if jobs == 'all':
1724 if jobs == 'all':
1719 msg_ids = jobs
1725 msg_ids = jobs
1720 else:
1726 else:
1721 msg_ids = self._build_msgids_from_jobs(jobs)
1727 msg_ids = self._build_msgids_from_jobs(jobs)
1722
1728
1723 content = dict(engine_ids=targets, msg_ids=msg_ids)
1729 content = dict(engine_ids=targets, msg_ids=msg_ids)
1724 self.session.send(self._query_socket, "purge_request", content=content)
1730 self.session.send(self._query_socket, "purge_request", content=content)
1725 idents, msg = self.session.recv(self._query_socket, 0)
1731 idents, msg = self.session.recv(self._query_socket, 0)
1726 if self.debug:
1732 if self.debug:
1727 pprint(msg)
1733 pprint(msg)
1728 content = msg['content']
1734 content = msg['content']
1729 if content['status'] != 'ok':
1735 if content['status'] != 'ok':
1730 raise self._unwrap_exception(content)
1736 raise self._unwrap_exception(content)
1731
1737
1732 def purge_results(self, jobs=[], targets=[]):
1738 def purge_results(self, jobs=[], targets=[]):
1733 """Clears the cached results from both the hub and the local client
1739 """Clears the cached results from both the hub and the local client
1734
1740
1735 Individual results can be purged by msg_id, or the entire
1741 Individual results can be purged by msg_id, or the entire
1736 history of specific targets can be purged.
1742 history of specific targets can be purged.
1737
1743
1738 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1744 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1739 the Client's db.
1745 the Client's db.
1740
1746
1741 Equivalent to calling both `purge_hub_results()` and `purge_client_results()` with
1747 Equivalent to calling both `purge_hub_results()` and `purge_client_results()` with
1742 the same arguments.
1748 the same arguments.
1743
1749
1744 Parameters
1750 Parameters
1745 ----------
1751 ----------
1746
1752
1747 jobs : str or list of str or AsyncResult objects
1753 jobs : str or list of str or AsyncResult objects
1748 the msg_ids whose results should be forgotten.
1754 the msg_ids whose results should be forgotten.
1749 targets : int/str/list of ints/strs
1755 targets : int/str/list of ints/strs
1750 The targets, by int_id, whose entire history is to be purged.
1756 The targets, by int_id, whose entire history is to be purged.
1751
1757
1752 default : None
1758 default : None
1753 """
1759 """
1754 self.purge_local_results(jobs=jobs, targets=targets)
1760 self.purge_local_results(jobs=jobs, targets=targets)
1755 self.purge_hub_results(jobs=jobs, targets=targets)
1761 self.purge_hub_results(jobs=jobs, targets=targets)
1756
1762
1757 def purge_everything(self):
1763 def purge_everything(self):
1758 """Clears all content from previous Tasks from both the hub and the local client
1764 """Clears all content from previous Tasks from both the hub and the local client
1759
1765
1760 In addition to calling `purge_results("all")` it also deletes the history and
1766 In addition to calling `purge_results("all")` it also deletes the history and
1761 other bookkeeping lists.
1767 other bookkeeping lists.
1762 """
1768 """
1763 self.purge_results("all")
1769 self.purge_results("all")
1764 self.history = []
1770 self.history = []
1765 self.session.digest_history.clear()
1771 self.session.digest_history.clear()
1766
1772
1767 @spin_first
1773 @spin_first
1768 def hub_history(self):
1774 def hub_history(self):
1769 """Get the Hub's history
1775 """Get the Hub's history
1770
1776
1771 Just like the Client, the Hub has a history, which is a list of msg_ids.
1777 Just like the Client, the Hub has a history, which is a list of msg_ids.
1772 This will contain the history of all clients, and, depending on configuration,
1778 This will contain the history of all clients, and, depending on configuration,
1773 may contain history across multiple cluster sessions.
1779 may contain history across multiple cluster sessions.
1774
1780
1775 Any msg_id returned here is a valid argument to `get_result`.
1781 Any msg_id returned here is a valid argument to `get_result`.
1776
1782
1777 Returns
1783 Returns
1778 -------
1784 -------
1779
1785
1780 msg_ids : list of strs
1786 msg_ids : list of strs
1781 list of all msg_ids, ordered by task submission time.
1787 list of all msg_ids, ordered by task submission time.
1782 """
1788 """
1783
1789
1784 self.session.send(self._query_socket, "history_request", content={})
1790 self.session.send(self._query_socket, "history_request", content={})
1785 idents, msg = self.session.recv(self._query_socket, 0)
1791 idents, msg = self.session.recv(self._query_socket, 0)
1786
1792
1787 if self.debug:
1793 if self.debug:
1788 pprint(msg)
1794 pprint(msg)
1789 content = msg['content']
1795 content = msg['content']
1790 if content['status'] != 'ok':
1796 if content['status'] != 'ok':
1791 raise self._unwrap_exception(content)
1797 raise self._unwrap_exception(content)
1792 else:
1798 else:
1793 return content['history']
1799 return content['history']
1794
1800
1795 @spin_first
1801 @spin_first
1796 def db_query(self, query, keys=None):
1802 def db_query(self, query, keys=None):
1797 """Query the Hub's TaskRecord database
1803 """Query the Hub's TaskRecord database
1798
1804
1799 This will return a list of task record dicts that match `query`
1805 This will return a list of task record dicts that match `query`
1800
1806
1801 Parameters
1807 Parameters
1802 ----------
1808 ----------
1803
1809
1804 query : mongodb query dict
1810 query : mongodb query dict
1805 The search dict. See mongodb query docs for details.
1811 The search dict. See mongodb query docs for details.
1806 keys : list of strs [optional]
1812 keys : list of strs [optional]
1807 The subset of keys to be returned. The default is to fetch everything but buffers.
1813 The subset of keys to be returned. The default is to fetch everything but buffers.
1808 'msg_id' will *always* be included.
1814 'msg_id' will *always* be included.
1809 """
1815 """
1810 if isinstance(keys, basestring):
1816 if isinstance(keys, basestring):
1811 keys = [keys]
1817 keys = [keys]
1812 content = dict(query=query, keys=keys)
1818 content = dict(query=query, keys=keys)
1813 self.session.send(self._query_socket, "db_request", content=content)
1819 self.session.send(self._query_socket, "db_request", content=content)
1814 idents, msg = self.session.recv(self._query_socket, 0)
1820 idents, msg = self.session.recv(self._query_socket, 0)
1815 if self.debug:
1821 if self.debug:
1816 pprint(msg)
1822 pprint(msg)
1817 content = msg['content']
1823 content = msg['content']
1818 if content['status'] != 'ok':
1824 if content['status'] != 'ok':
1819 raise self._unwrap_exception(content)
1825 raise self._unwrap_exception(content)
1820
1826
1821 records = content['records']
1827 records = content['records']
1822
1828
1823 buffer_lens = content['buffer_lens']
1829 buffer_lens = content['buffer_lens']
1824 result_buffer_lens = content['result_buffer_lens']
1830 result_buffer_lens = content['result_buffer_lens']
1825 buffers = msg['buffers']
1831 buffers = msg['buffers']
1826 has_bufs = buffer_lens is not None
1832 has_bufs = buffer_lens is not None
1827 has_rbufs = result_buffer_lens is not None
1833 has_rbufs = result_buffer_lens is not None
1828 for i,rec in enumerate(records):
1834 for i,rec in enumerate(records):
1829 # relink buffers
1835 # relink buffers
1830 if has_bufs:
1836 if has_bufs:
1831 blen = buffer_lens[i]
1837 blen = buffer_lens[i]
1832 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1838 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1833 if has_rbufs:
1839 if has_rbufs:
1834 blen = result_buffer_lens[i]
1840 blen = result_buffer_lens[i]
1835 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1841 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1836
1842
1837 return records
1843 return records
1838
1844
1839 __all__ = [ 'Client' ]
1845 __all__ = [ 'Client' ]
@@ -1,801 +1,814 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import base64
19 import sys
20 import sys
20 import platform
21 import platform
21 import time
22 import time
22 from collections import namedtuple
23 from collections import namedtuple
23 from tempfile import mktemp
24 from tempfile import mktemp
24 from StringIO import StringIO
25 from StringIO import StringIO
25
26
26 import zmq
27 import zmq
27 from nose import SkipTest
28 from nose import SkipTest
28 from nose.plugins.attrib import attr
29 from nose.plugins.attrib import attr
29
30
30 from IPython.testing import decorators as dec
31 from IPython.testing import decorators as dec
31 from IPython.testing.ipunittest import ParametricTestCase
32 from IPython.testing.ipunittest import ParametricTestCase
32 from IPython.utils.io import capture_output
33 from IPython.utils.io import capture_output
33
34
34 from IPython import parallel as pmod
35 from IPython import parallel as pmod
35 from IPython.parallel import error
36 from IPython.parallel import error
36 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
37 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
37 from IPython.parallel import DirectView
38 from IPython.parallel import DirectView
38 from IPython.parallel.util import interactive
39 from IPython.parallel.util import interactive
39
40
40 from IPython.parallel.tests import add_engines
41 from IPython.parallel.tests import add_engines
41
42
42 from .clienttest import ClusterTestCase, crash, wait, skip_without
43 from .clienttest import ClusterTestCase, crash, wait, skip_without
43
44
44 def setup():
45 def setup():
45 add_engines(3, total=True)
46 add_engines(3, total=True)
46
47
47 point = namedtuple("point", "x y")
48 point = namedtuple("point", "x y")
48
49
49 class TestView(ClusterTestCase, ParametricTestCase):
50 class TestView(ClusterTestCase, ParametricTestCase):
50
51
51 def setUp(self):
52 def setUp(self):
52 # On Win XP, wait for resource cleanup, else parallel test group fails
53 # On Win XP, wait for resource cleanup, else parallel test group fails
53 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
54 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
54 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
55 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
55 time.sleep(2)
56 time.sleep(2)
56 super(TestView, self).setUp()
57 super(TestView, self).setUp()
57
58
58 @attr('crash')
59 @attr('crash')
59 def test_z_crash_mux(self):
60 def test_z_crash_mux(self):
60 """test graceful handling of engine death (direct)"""
61 """test graceful handling of engine death (direct)"""
61 # self.add_engines(1)
62 # self.add_engines(1)
62 eid = self.client.ids[-1]
63 eid = self.client.ids[-1]
63 ar = self.client[eid].apply_async(crash)
64 ar = self.client[eid].apply_async(crash)
64 self.assertRaisesRemote(error.EngineError, ar.get, 10)
65 self.assertRaisesRemote(error.EngineError, ar.get, 10)
65 eid = ar.engine_id
66 eid = ar.engine_id
66 tic = time.time()
67 tic = time.time()
67 while eid in self.client.ids and time.time()-tic < 5:
68 while eid in self.client.ids and time.time()-tic < 5:
68 time.sleep(.01)
69 time.sleep(.01)
69 self.client.spin()
70 self.client.spin()
70 self.assertFalse(eid in self.client.ids, "Engine should have died")
71 self.assertFalse(eid in self.client.ids, "Engine should have died")
71
72
72 def test_push_pull(self):
73 def test_push_pull(self):
73 """test pushing and pulling"""
74 """test pushing and pulling"""
74 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
75 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
75 t = self.client.ids[-1]
76 t = self.client.ids[-1]
76 v = self.client[t]
77 v = self.client[t]
77 push = v.push
78 push = v.push
78 pull = v.pull
79 pull = v.pull
79 v.block=True
80 v.block=True
80 nengines = len(self.client)
81 nengines = len(self.client)
81 push({'data':data})
82 push({'data':data})
82 d = pull('data')
83 d = pull('data')
83 self.assertEqual(d, data)
84 self.assertEqual(d, data)
84 self.client[:].push({'data':data})
85 self.client[:].push({'data':data})
85 d = self.client[:].pull('data', block=True)
86 d = self.client[:].pull('data', block=True)
86 self.assertEqual(d, nengines*[data])
87 self.assertEqual(d, nengines*[data])
87 ar = push({'data':data}, block=False)
88 ar = push({'data':data}, block=False)
88 self.assertTrue(isinstance(ar, AsyncResult))
89 self.assertTrue(isinstance(ar, AsyncResult))
89 r = ar.get()
90 r = ar.get()
90 ar = self.client[:].pull('data', block=False)
91 ar = self.client[:].pull('data', block=False)
91 self.assertTrue(isinstance(ar, AsyncResult))
92 self.assertTrue(isinstance(ar, AsyncResult))
92 r = ar.get()
93 r = ar.get()
93 self.assertEqual(r, nengines*[data])
94 self.assertEqual(r, nengines*[data])
94 self.client[:].push(dict(a=10,b=20))
95 self.client[:].push(dict(a=10,b=20))
95 r = self.client[:].pull(('a','b'), block=True)
96 r = self.client[:].pull(('a','b'), block=True)
96 self.assertEqual(r, nengines*[[10,20]])
97 self.assertEqual(r, nengines*[[10,20]])
97
98
98 def test_push_pull_function(self):
99 def test_push_pull_function(self):
99 "test pushing and pulling functions"
100 "test pushing and pulling functions"
100 def testf(x):
101 def testf(x):
101 return 2.0*x
102 return 2.0*x
102
103
103 t = self.client.ids[-1]
104 t = self.client.ids[-1]
104 v = self.client[t]
105 v = self.client[t]
105 v.block=True
106 v.block=True
106 push = v.push
107 push = v.push
107 pull = v.pull
108 pull = v.pull
108 execute = v.execute
109 execute = v.execute
109 push({'testf':testf})
110 push({'testf':testf})
110 r = pull('testf')
111 r = pull('testf')
111 self.assertEqual(r(1.0), testf(1.0))
112 self.assertEqual(r(1.0), testf(1.0))
112 execute('r = testf(10)')
113 execute('r = testf(10)')
113 r = pull('r')
114 r = pull('r')
114 self.assertEqual(r, testf(10))
115 self.assertEqual(r, testf(10))
115 ar = self.client[:].push({'testf':testf}, block=False)
116 ar = self.client[:].push({'testf':testf}, block=False)
116 ar.get()
117 ar.get()
117 ar = self.client[:].pull('testf', block=False)
118 ar = self.client[:].pull('testf', block=False)
118 rlist = ar.get()
119 rlist = ar.get()
119 for r in rlist:
120 for r in rlist:
120 self.assertEqual(r(1.0), testf(1.0))
121 self.assertEqual(r(1.0), testf(1.0))
121 execute("def g(x): return x*x")
122 execute("def g(x): return x*x")
122 r = pull(('testf','g'))
123 r = pull(('testf','g'))
123 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
124 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
124
125
125 def test_push_function_globals(self):
126 def test_push_function_globals(self):
126 """test that pushed functions have access to globals"""
127 """test that pushed functions have access to globals"""
127 @interactive
128 @interactive
128 def geta():
129 def geta():
129 return a
130 return a
130 # self.add_engines(1)
131 # self.add_engines(1)
131 v = self.client[-1]
132 v = self.client[-1]
132 v.block=True
133 v.block=True
133 v['f'] = geta
134 v['f'] = geta
134 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
135 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
135 v.execute('a=5')
136 v.execute('a=5')
136 v.execute('b=f()')
137 v.execute('b=f()')
137 self.assertEqual(v['b'], 5)
138 self.assertEqual(v['b'], 5)
138
139
139 def test_push_function_defaults(self):
140 def test_push_function_defaults(self):
140 """test that pushed functions preserve default args"""
141 """test that pushed functions preserve default args"""
141 def echo(a=10):
142 def echo(a=10):
142 return a
143 return a
143 v = self.client[-1]
144 v = self.client[-1]
144 v.block=True
145 v.block=True
145 v['f'] = echo
146 v['f'] = echo
146 v.execute('b=f()')
147 v.execute('b=f()')
147 self.assertEqual(v['b'], 10)
148 self.assertEqual(v['b'], 10)
148
149
149 def test_get_result(self):
150 def test_get_result(self):
150 """test getting results from the Hub."""
151 """test getting results from the Hub."""
151 c = pmod.Client(profile='iptest')
152 c = pmod.Client(profile='iptest')
152 # self.add_engines(1)
153 # self.add_engines(1)
153 t = c.ids[-1]
154 t = c.ids[-1]
154 v = c[t]
155 v = c[t]
155 v2 = self.client[t]
156 v2 = self.client[t]
156 ar = v.apply_async(wait, 1)
157 ar = v.apply_async(wait, 1)
157 # give the monitor time to notice the message
158 # give the monitor time to notice the message
158 time.sleep(.25)
159 time.sleep(.25)
159 ahr = v2.get_result(ar.msg_ids[0])
160 ahr = v2.get_result(ar.msg_ids[0])
160 self.assertTrue(isinstance(ahr, AsyncHubResult))
161 self.assertTrue(isinstance(ahr, AsyncHubResult))
161 self.assertEqual(ahr.get(), ar.get())
162 self.assertEqual(ahr.get(), ar.get())
162 ar2 = v2.get_result(ar.msg_ids[0])
163 ar2 = v2.get_result(ar.msg_ids[0])
163 self.assertFalse(isinstance(ar2, AsyncHubResult))
164 self.assertFalse(isinstance(ar2, AsyncHubResult))
164 c.spin()
165 c.spin()
165 c.close()
166 c.close()
166
167
167 def test_run_newline(self):
168 def test_run_newline(self):
168 """test that run appends newline to files"""
169 """test that run appends newline to files"""
169 tmpfile = mktemp()
170 tmpfile = mktemp()
170 with open(tmpfile, 'w') as f:
171 with open(tmpfile, 'w') as f:
171 f.write("""def g():
172 f.write("""def g():
172 return 5
173 return 5
173 """)
174 """)
174 v = self.client[-1]
175 v = self.client[-1]
175 v.run(tmpfile, block=True)
176 v.run(tmpfile, block=True)
176 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
177 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
177
178
178 def test_apply_tracked(self):
179 def test_apply_tracked(self):
179 """test tracking for apply"""
180 """test tracking for apply"""
180 # self.add_engines(1)
181 # self.add_engines(1)
181 t = self.client.ids[-1]
182 t = self.client.ids[-1]
182 v = self.client[t]
183 v = self.client[t]
183 v.block=False
184 v.block=False
184 def echo(n=1024*1024, **kwargs):
185 def echo(n=1024*1024, **kwargs):
185 with v.temp_flags(**kwargs):
186 with v.temp_flags(**kwargs):
186 return v.apply(lambda x: x, 'x'*n)
187 return v.apply(lambda x: x, 'x'*n)
187 ar = echo(1, track=False)
188 ar = echo(1, track=False)
188 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
189 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
189 self.assertTrue(ar.sent)
190 self.assertTrue(ar.sent)
190 ar = echo(track=True)
191 ar = echo(track=True)
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 self.assertEqual(ar.sent, ar._tracker.done)
193 self.assertEqual(ar.sent, ar._tracker.done)
193 ar._tracker.wait()
194 ar._tracker.wait()
194 self.assertTrue(ar.sent)
195 self.assertTrue(ar.sent)
195
196
196 def test_push_tracked(self):
197 def test_push_tracked(self):
197 t = self.client.ids[-1]
198 t = self.client.ids[-1]
198 ns = dict(x='x'*1024*1024)
199 ns = dict(x='x'*1024*1024)
199 v = self.client[t]
200 v = self.client[t]
200 ar = v.push(ns, block=False, track=False)
201 ar = v.push(ns, block=False, track=False)
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 self.assertTrue(ar.sent)
203 self.assertTrue(ar.sent)
203
204
204 ar = v.push(ns, block=False, track=True)
205 ar = v.push(ns, block=False, track=True)
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 ar._tracker.wait()
207 ar._tracker.wait()
207 self.assertEqual(ar.sent, ar._tracker.done)
208 self.assertEqual(ar.sent, ar._tracker.done)
208 self.assertTrue(ar.sent)
209 self.assertTrue(ar.sent)
209 ar.get()
210 ar.get()
210
211
211 def test_scatter_tracked(self):
212 def test_scatter_tracked(self):
212 t = self.client.ids
213 t = self.client.ids
213 x='x'*1024*1024
214 x='x'*1024*1024
214 ar = self.client[t].scatter('x', x, block=False, track=False)
215 ar = self.client[t].scatter('x', x, block=False, track=False)
215 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
216 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
216 self.assertTrue(ar.sent)
217 self.assertTrue(ar.sent)
217
218
218 ar = self.client[t].scatter('x', x, block=False, track=True)
219 ar = self.client[t].scatter('x', x, block=False, track=True)
219 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
220 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
220 self.assertEqual(ar.sent, ar._tracker.done)
221 self.assertEqual(ar.sent, ar._tracker.done)
221 ar._tracker.wait()
222 ar._tracker.wait()
222 self.assertTrue(ar.sent)
223 self.assertTrue(ar.sent)
223 ar.get()
224 ar.get()
224
225
225 def test_remote_reference(self):
226 def test_remote_reference(self):
226 v = self.client[-1]
227 v = self.client[-1]
227 v['a'] = 123
228 v['a'] = 123
228 ra = pmod.Reference('a')
229 ra = pmod.Reference('a')
229 b = v.apply_sync(lambda x: x, ra)
230 b = v.apply_sync(lambda x: x, ra)
230 self.assertEqual(b, 123)
231 self.assertEqual(b, 123)
231
232
232
233
233 def test_scatter_gather(self):
234 def test_scatter_gather(self):
234 view = self.client[:]
235 view = self.client[:]
235 seq1 = range(16)
236 seq1 = range(16)
236 view.scatter('a', seq1)
237 view.scatter('a', seq1)
237 seq2 = view.gather('a', block=True)
238 seq2 = view.gather('a', block=True)
238 self.assertEqual(seq2, seq1)
239 self.assertEqual(seq2, seq1)
239 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
240 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
240
241
241 @skip_without('numpy')
242 @skip_without('numpy')
242 def test_scatter_gather_numpy(self):
243 def test_scatter_gather_numpy(self):
243 import numpy
244 import numpy
244 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
245 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
245 view = self.client[:]
246 view = self.client[:]
246 a = numpy.arange(64)
247 a = numpy.arange(64)
247 view.scatter('a', a, block=True)
248 view.scatter('a', a, block=True)
248 b = view.gather('a', block=True)
249 b = view.gather('a', block=True)
249 assert_array_equal(b, a)
250 assert_array_equal(b, a)
250
251
251 def test_scatter_gather_lazy(self):
252 def test_scatter_gather_lazy(self):
252 """scatter/gather with targets='all'"""
253 """scatter/gather with targets='all'"""
253 view = self.client.direct_view(targets='all')
254 view = self.client.direct_view(targets='all')
254 x = range(64)
255 x = range(64)
255 view.scatter('x', x)
256 view.scatter('x', x)
256 gathered = view.gather('x', block=True)
257 gathered = view.gather('x', block=True)
257 self.assertEqual(gathered, x)
258 self.assertEqual(gathered, x)
258
259
259
260
260 @dec.known_failure_py3
261 @dec.known_failure_py3
261 @skip_without('numpy')
262 @skip_without('numpy')
262 def test_push_numpy_nocopy(self):
263 def test_push_numpy_nocopy(self):
263 import numpy
264 import numpy
264 view = self.client[:]
265 view = self.client[:]
265 a = numpy.arange(64)
266 a = numpy.arange(64)
266 view['A'] = a
267 view['A'] = a
267 @interactive
268 @interactive
268 def check_writeable(x):
269 def check_writeable(x):
269 return x.flags.writeable
270 return x.flags.writeable
270
271
271 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
272 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
272 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
273 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
273
274
274 view.push(dict(B=a))
275 view.push(dict(B=a))
275 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
276 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
276 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
277 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
277
278
278 @skip_without('numpy')
279 @skip_without('numpy')
279 def test_apply_numpy(self):
280 def test_apply_numpy(self):
280 """view.apply(f, ndarray)"""
281 """view.apply(f, ndarray)"""
281 import numpy
282 import numpy
282 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
283 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
283
284
284 A = numpy.random.random((100,100))
285 A = numpy.random.random((100,100))
285 view = self.client[-1]
286 view = self.client[-1]
286 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
287 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
287 B = A.astype(dt)
288 B = A.astype(dt)
288 C = view.apply_sync(lambda x:x, B)
289 C = view.apply_sync(lambda x:x, B)
289 assert_array_equal(B,C)
290 assert_array_equal(B,C)
290
291
291 @skip_without('numpy')
292 @skip_without('numpy')
292 def test_push_pull_recarray(self):
293 def test_push_pull_recarray(self):
293 """push/pull recarrays"""
294 """push/pull recarrays"""
294 import numpy
295 import numpy
295 from numpy.testing.utils import assert_array_equal
296 from numpy.testing.utils import assert_array_equal
296
297
297 view = self.client[-1]
298 view = self.client[-1]
298
299
299 R = numpy.array([
300 R = numpy.array([
300 (1, 'hi', 0.),
301 (1, 'hi', 0.),
301 (2**30, 'there', 2.5),
302 (2**30, 'there', 2.5),
302 (-99999, 'world', -12345.6789),
303 (-99999, 'world', -12345.6789),
303 ], [('n', int), ('s', '|S10'), ('f', float)])
304 ], [('n', int), ('s', '|S10'), ('f', float)])
304
305
305 view['RR'] = R
306 view['RR'] = R
306 R2 = view['RR']
307 R2 = view['RR']
307
308
308 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
309 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
309 self.assertEqual(r_dtype, R.dtype)
310 self.assertEqual(r_dtype, R.dtype)
310 self.assertEqual(r_shape, R.shape)
311 self.assertEqual(r_shape, R.shape)
311 self.assertEqual(R2.dtype, R.dtype)
312 self.assertEqual(R2.dtype, R.dtype)
312 self.assertEqual(R2.shape, R.shape)
313 self.assertEqual(R2.shape, R.shape)
313 assert_array_equal(R2, R)
314 assert_array_equal(R2, R)
314
315
315 @skip_without('pandas')
316 @skip_without('pandas')
316 def test_push_pull_timeseries(self):
317 def test_push_pull_timeseries(self):
317 """push/pull pandas.TimeSeries"""
318 """push/pull pandas.TimeSeries"""
318 import pandas
319 import pandas
319
320
320 ts = pandas.TimeSeries(range(10))
321 ts = pandas.TimeSeries(range(10))
321
322
322 view = self.client[-1]
323 view = self.client[-1]
323
324
324 view.push(dict(ts=ts), block=True)
325 view.push(dict(ts=ts), block=True)
325 rts = view['ts']
326 rts = view['ts']
326
327
327 self.assertEqual(type(rts), type(ts))
328 self.assertEqual(type(rts), type(ts))
328 self.assertTrue((ts == rts).all())
329 self.assertTrue((ts == rts).all())
329
330
330 def test_map(self):
331 def test_map(self):
331 view = self.client[:]
332 view = self.client[:]
332 def f(x):
333 def f(x):
333 return x**2
334 return x**2
334 data = range(16)
335 data = range(16)
335 r = view.map_sync(f, data)
336 r = view.map_sync(f, data)
336 self.assertEqual(r, map(f, data))
337 self.assertEqual(r, map(f, data))
337
338
338 def test_map_iterable(self):
339 def test_map_iterable(self):
339 """test map on iterables (direct)"""
340 """test map on iterables (direct)"""
340 view = self.client[:]
341 view = self.client[:]
341 # 101 is prime, so it won't be evenly distributed
342 # 101 is prime, so it won't be evenly distributed
342 arr = range(101)
343 arr = range(101)
343 # ensure it will be an iterator, even in Python 3
344 # ensure it will be an iterator, even in Python 3
344 it = iter(arr)
345 it = iter(arr)
345 r = view.map_sync(lambda x: x, it)
346 r = view.map_sync(lambda x: x, it)
346 self.assertEqual(r, list(arr))
347 self.assertEqual(r, list(arr))
347
348
348 @skip_without('numpy')
349 @skip_without('numpy')
349 def test_map_numpy(self):
350 def test_map_numpy(self):
350 """test map on numpy arrays (direct)"""
351 """test map on numpy arrays (direct)"""
351 import numpy
352 import numpy
352 from numpy.testing.utils import assert_array_equal
353 from numpy.testing.utils import assert_array_equal
353
354
354 view = self.client[:]
355 view = self.client[:]
355 # 101 is prime, so it won't be evenly distributed
356 # 101 is prime, so it won't be evenly distributed
356 arr = numpy.arange(101)
357 arr = numpy.arange(101)
357 r = view.map_sync(lambda x: x, arr)
358 r = view.map_sync(lambda x: x, arr)
358 assert_array_equal(r, arr)
359 assert_array_equal(r, arr)
359
360
360 def test_scatter_gather_nonblocking(self):
361 def test_scatter_gather_nonblocking(self):
361 data = range(16)
362 data = range(16)
362 view = self.client[:]
363 view = self.client[:]
363 view.scatter('a', data, block=False)
364 view.scatter('a', data, block=False)
364 ar = view.gather('a', block=False)
365 ar = view.gather('a', block=False)
365 self.assertEqual(ar.get(), data)
366 self.assertEqual(ar.get(), data)
366
367
367 @skip_without('numpy')
368 @skip_without('numpy')
368 def test_scatter_gather_numpy_nonblocking(self):
369 def test_scatter_gather_numpy_nonblocking(self):
369 import numpy
370 import numpy
370 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
371 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
371 a = numpy.arange(64)
372 a = numpy.arange(64)
372 view = self.client[:]
373 view = self.client[:]
373 ar = view.scatter('a', a, block=False)
374 ar = view.scatter('a', a, block=False)
374 self.assertTrue(isinstance(ar, AsyncResult))
375 self.assertTrue(isinstance(ar, AsyncResult))
375 amr = view.gather('a', block=False)
376 amr = view.gather('a', block=False)
376 self.assertTrue(isinstance(amr, AsyncMapResult))
377 self.assertTrue(isinstance(amr, AsyncMapResult))
377 assert_array_equal(amr.get(), a)
378 assert_array_equal(amr.get(), a)
378
379
379 def test_execute(self):
380 def test_execute(self):
380 view = self.client[:]
381 view = self.client[:]
381 # self.client.debug=True
382 # self.client.debug=True
382 execute = view.execute
383 execute = view.execute
383 ar = execute('c=30', block=False)
384 ar = execute('c=30', block=False)
384 self.assertTrue(isinstance(ar, AsyncResult))
385 self.assertTrue(isinstance(ar, AsyncResult))
385 ar = execute('d=[0,1,2]', block=False)
386 ar = execute('d=[0,1,2]', block=False)
386 self.client.wait(ar, 1)
387 self.client.wait(ar, 1)
387 self.assertEqual(len(ar.get()), len(self.client))
388 self.assertEqual(len(ar.get()), len(self.client))
388 for c in view['c']:
389 for c in view['c']:
389 self.assertEqual(c, 30)
390 self.assertEqual(c, 30)
390
391
391 def test_abort(self):
392 def test_abort(self):
392 view = self.client[-1]
393 view = self.client[-1]
393 ar = view.execute('import time; time.sleep(1)', block=False)
394 ar = view.execute('import time; time.sleep(1)', block=False)
394 ar2 = view.apply_async(lambda : 2)
395 ar2 = view.apply_async(lambda : 2)
395 ar3 = view.apply_async(lambda : 3)
396 ar3 = view.apply_async(lambda : 3)
396 view.abort(ar2)
397 view.abort(ar2)
397 view.abort(ar3.msg_ids)
398 view.abort(ar3.msg_ids)
398 self.assertRaises(error.TaskAborted, ar2.get)
399 self.assertRaises(error.TaskAborted, ar2.get)
399 self.assertRaises(error.TaskAborted, ar3.get)
400 self.assertRaises(error.TaskAborted, ar3.get)
400
401
401 def test_abort_all(self):
402 def test_abort_all(self):
402 """view.abort() aborts all outstanding tasks"""
403 """view.abort() aborts all outstanding tasks"""
403 view = self.client[-1]
404 view = self.client[-1]
404 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
405 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
405 view.abort()
406 view.abort()
406 view.wait(timeout=5)
407 view.wait(timeout=5)
407 for ar in ars[5:]:
408 for ar in ars[5:]:
408 self.assertRaises(error.TaskAborted, ar.get)
409 self.assertRaises(error.TaskAborted, ar.get)
409
410
410 def test_temp_flags(self):
411 def test_temp_flags(self):
411 view = self.client[-1]
412 view = self.client[-1]
412 view.block=True
413 view.block=True
413 with view.temp_flags(block=False):
414 with view.temp_flags(block=False):
414 self.assertFalse(view.block)
415 self.assertFalse(view.block)
415 self.assertTrue(view.block)
416 self.assertTrue(view.block)
416
417
417 @dec.known_failure_py3
418 @dec.known_failure_py3
418 def test_importer(self):
419 def test_importer(self):
419 view = self.client[-1]
420 view = self.client[-1]
420 view.clear(block=True)
421 view.clear(block=True)
421 with view.importer:
422 with view.importer:
422 import re
423 import re
423
424
424 @interactive
425 @interactive
425 def findall(pat, s):
426 def findall(pat, s):
426 # this globals() step isn't necessary in real code
427 # this globals() step isn't necessary in real code
427 # only to prevent a closure in the test
428 # only to prevent a closure in the test
428 re = globals()['re']
429 re = globals()['re']
429 return re.findall(pat, s)
430 return re.findall(pat, s)
430
431
431 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
432 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
432
433
433 def test_unicode_execute(self):
434 def test_unicode_execute(self):
434 """test executing unicode strings"""
435 """test executing unicode strings"""
435 v = self.client[-1]
436 v = self.client[-1]
436 v.block=True
437 v.block=True
437 if sys.version_info[0] >= 3:
438 if sys.version_info[0] >= 3:
438 code="a='é'"
439 code="a='é'"
439 else:
440 else:
440 code=u"a=u'é'"
441 code=u"a=u'é'"
441 v.execute(code)
442 v.execute(code)
442 self.assertEqual(v['a'], u'é')
443 self.assertEqual(v['a'], u'é')
443
444
444 def test_unicode_apply_result(self):
445 def test_unicode_apply_result(self):
445 """test unicode apply results"""
446 """test unicode apply results"""
446 v = self.client[-1]
447 v = self.client[-1]
447 r = v.apply_sync(lambda : u'é')
448 r = v.apply_sync(lambda : u'é')
448 self.assertEqual(r, u'é')
449 self.assertEqual(r, u'é')
449
450
450 def test_unicode_apply_arg(self):
451 def test_unicode_apply_arg(self):
451 """test passing unicode arguments to apply"""
452 """test passing unicode arguments to apply"""
452 v = self.client[-1]
453 v = self.client[-1]
453
454
454 @interactive
455 @interactive
455 def check_unicode(a, check):
456 def check_unicode(a, check):
456 assert isinstance(a, unicode), "%r is not unicode"%a
457 assert isinstance(a, unicode), "%r is not unicode"%a
457 assert isinstance(check, bytes), "%r is not bytes"%check
458 assert isinstance(check, bytes), "%r is not bytes"%check
458 assert a.encode('utf8') == check, "%s != %s"%(a,check)
459 assert a.encode('utf8') == check, "%s != %s"%(a,check)
459
460
460 for s in [ u'é', u'ßø®∫',u'asdf' ]:
461 for s in [ u'é', u'ßø®∫',u'asdf' ]:
461 try:
462 try:
462 v.apply_sync(check_unicode, s, s.encode('utf8'))
463 v.apply_sync(check_unicode, s, s.encode('utf8'))
463 except error.RemoteError as e:
464 except error.RemoteError as e:
464 if e.ename == 'AssertionError':
465 if e.ename == 'AssertionError':
465 self.fail(e.evalue)
466 self.fail(e.evalue)
466 else:
467 else:
467 raise e
468 raise e
468
469
469 def test_map_reference(self):
470 def test_map_reference(self):
470 """view.map(<Reference>, *seqs) should work"""
471 """view.map(<Reference>, *seqs) should work"""
471 v = self.client[:]
472 v = self.client[:]
472 v.scatter('n', self.client.ids, flatten=True)
473 v.scatter('n', self.client.ids, flatten=True)
473 v.execute("f = lambda x,y: x*y")
474 v.execute("f = lambda x,y: x*y")
474 rf = pmod.Reference('f')
475 rf = pmod.Reference('f')
475 nlist = list(range(10))
476 nlist = list(range(10))
476 mlist = nlist[::-1]
477 mlist = nlist[::-1]
477 expected = [ m*n for m,n in zip(mlist, nlist) ]
478 expected = [ m*n for m,n in zip(mlist, nlist) ]
478 result = v.map_sync(rf, mlist, nlist)
479 result = v.map_sync(rf, mlist, nlist)
479 self.assertEqual(result, expected)
480 self.assertEqual(result, expected)
480
481
481 def test_apply_reference(self):
482 def test_apply_reference(self):
482 """view.apply(<Reference>, *args) should work"""
483 """view.apply(<Reference>, *args) should work"""
483 v = self.client[:]
484 v = self.client[:]
484 v.scatter('n', self.client.ids, flatten=True)
485 v.scatter('n', self.client.ids, flatten=True)
485 v.execute("f = lambda x: n*x")
486 v.execute("f = lambda x: n*x")
486 rf = pmod.Reference('f')
487 rf = pmod.Reference('f')
487 result = v.apply_sync(rf, 5)
488 result = v.apply_sync(rf, 5)
488 expected = [ 5*id for id in self.client.ids ]
489 expected = [ 5*id for id in self.client.ids ]
489 self.assertEqual(result, expected)
490 self.assertEqual(result, expected)
490
491
491 def test_eval_reference(self):
492 def test_eval_reference(self):
492 v = self.client[self.client.ids[0]]
493 v = self.client[self.client.ids[0]]
493 v['g'] = range(5)
494 v['g'] = range(5)
494 rg = pmod.Reference('g[0]')
495 rg = pmod.Reference('g[0]')
495 echo = lambda x:x
496 echo = lambda x:x
496 self.assertEqual(v.apply_sync(echo, rg), 0)
497 self.assertEqual(v.apply_sync(echo, rg), 0)
497
498
498 def test_reference_nameerror(self):
499 def test_reference_nameerror(self):
499 v = self.client[self.client.ids[0]]
500 v = self.client[self.client.ids[0]]
500 r = pmod.Reference('elvis_has_left')
501 r = pmod.Reference('elvis_has_left')
501 echo = lambda x:x
502 echo = lambda x:x
502 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
503 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
503
504
504 def test_single_engine_map(self):
505 def test_single_engine_map(self):
505 e0 = self.client[self.client.ids[0]]
506 e0 = self.client[self.client.ids[0]]
506 r = range(5)
507 r = range(5)
507 check = [ -1*i for i in r ]
508 check = [ -1*i for i in r ]
508 result = e0.map_sync(lambda x: -1*x, r)
509 result = e0.map_sync(lambda x: -1*x, r)
509 self.assertEqual(result, check)
510 self.assertEqual(result, check)
510
511
511 def test_len(self):
512 def test_len(self):
512 """len(view) makes sense"""
513 """len(view) makes sense"""
513 e0 = self.client[self.client.ids[0]]
514 e0 = self.client[self.client.ids[0]]
514 yield self.assertEqual(len(e0), 1)
515 yield self.assertEqual(len(e0), 1)
515 v = self.client[:]
516 v = self.client[:]
516 yield self.assertEqual(len(v), len(self.client.ids))
517 yield self.assertEqual(len(v), len(self.client.ids))
517 v = self.client.direct_view('all')
518 v = self.client.direct_view('all')
518 yield self.assertEqual(len(v), len(self.client.ids))
519 yield self.assertEqual(len(v), len(self.client.ids))
519 v = self.client[:2]
520 v = self.client[:2]
520 yield self.assertEqual(len(v), 2)
521 yield self.assertEqual(len(v), 2)
521 v = self.client[:1]
522 v = self.client[:1]
522 yield self.assertEqual(len(v), 1)
523 yield self.assertEqual(len(v), 1)
523 v = self.client.load_balanced_view()
524 v = self.client.load_balanced_view()
524 yield self.assertEqual(len(v), len(self.client.ids))
525 yield self.assertEqual(len(v), len(self.client.ids))
525 # parametric tests seem to require manual closing?
526 # parametric tests seem to require manual closing?
526 self.client.close()
527 self.client.close()
527
528
528
529
529 # begin execute tests
530 # begin execute tests
530
531
531 def test_execute_reply(self):
532 def test_execute_reply(self):
532 e0 = self.client[self.client.ids[0]]
533 e0 = self.client[self.client.ids[0]]
533 e0.block = True
534 e0.block = True
534 ar = e0.execute("5", silent=False)
535 ar = e0.execute("5", silent=False)
535 er = ar.get()
536 er = ar.get()
536 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
537 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
537 self.assertEqual(er.pyout['data']['text/plain'], '5')
538 self.assertEqual(er.pyout['data']['text/plain'], '5')
538
539
540 def test_execute_reply_rich(self):
541 e0 = self.client[self.client.ids[0]]
542 e0.block = True
543 e0.execute("from IPython.display import Image, HTML")
544 ar = e0.execute("Image(data=b'garbage', format='png', width=10)", silent=False)
545 er = ar.get()
546 b64data = base64.encodestring(b'garbage').decode('ascii')
547 self.assertEqual(er._repr_png_(), (b64data, dict(width=10)))
548 ar = e0.execute("HTML('<b>bold</b>')", silent=False)
549 er = ar.get()
550 self.assertEqual(er._repr_html_(), "<b>bold</b>")
551
539 def test_execute_reply_stdout(self):
552 def test_execute_reply_stdout(self):
540 e0 = self.client[self.client.ids[0]]
553 e0 = self.client[self.client.ids[0]]
541 e0.block = True
554 e0.block = True
542 ar = e0.execute("print (5)", silent=False)
555 ar = e0.execute("print (5)", silent=False)
543 er = ar.get()
556 er = ar.get()
544 self.assertEqual(er.stdout.strip(), '5')
557 self.assertEqual(er.stdout.strip(), '5')
545
558
546 def test_execute_pyout(self):
559 def test_execute_pyout(self):
547 """execute triggers pyout with silent=False"""
560 """execute triggers pyout with silent=False"""
548 view = self.client[:]
561 view = self.client[:]
549 ar = view.execute("5", silent=False, block=True)
562 ar = view.execute("5", silent=False, block=True)
550
563
551 expected = [{'text/plain' : '5'}] * len(view)
564 expected = [{'text/plain' : '5'}] * len(view)
552 mimes = [ out['data'] for out in ar.pyout ]
565 mimes = [ out['data'] for out in ar.pyout ]
553 self.assertEqual(mimes, expected)
566 self.assertEqual(mimes, expected)
554
567
555 def test_execute_silent(self):
568 def test_execute_silent(self):
556 """execute does not trigger pyout with silent=True"""
569 """execute does not trigger pyout with silent=True"""
557 view = self.client[:]
570 view = self.client[:]
558 ar = view.execute("5", block=True)
571 ar = view.execute("5", block=True)
559 expected = [None] * len(view)
572 expected = [None] * len(view)
560 self.assertEqual(ar.pyout, expected)
573 self.assertEqual(ar.pyout, expected)
561
574
562 def test_execute_magic(self):
575 def test_execute_magic(self):
563 """execute accepts IPython commands"""
576 """execute accepts IPython commands"""
564 view = self.client[:]
577 view = self.client[:]
565 view.execute("a = 5")
578 view.execute("a = 5")
566 ar = view.execute("%whos", block=True)
579 ar = view.execute("%whos", block=True)
567 # this will raise, if that failed
580 # this will raise, if that failed
568 ar.get(5)
581 ar.get(5)
569 for stdout in ar.stdout:
582 for stdout in ar.stdout:
570 lines = stdout.splitlines()
583 lines = stdout.splitlines()
571 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
584 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
572 found = False
585 found = False
573 for line in lines[2:]:
586 for line in lines[2:]:
574 split = line.split()
587 split = line.split()
575 if split == ['a', 'int', '5']:
588 if split == ['a', 'int', '5']:
576 found = True
589 found = True
577 break
590 break
578 self.assertTrue(found, "whos output wrong: %s" % stdout)
591 self.assertTrue(found, "whos output wrong: %s" % stdout)
579
592
580 def test_execute_displaypub(self):
593 def test_execute_displaypub(self):
581 """execute tracks display_pub output"""
594 """execute tracks display_pub output"""
582 view = self.client[:]
595 view = self.client[:]
583 view.execute("from IPython.core.display import *")
596 view.execute("from IPython.core.display import *")
584 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
597 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
585
598
586 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
599 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
587 for outputs in ar.outputs:
600 for outputs in ar.outputs:
588 mimes = [ out['data'] for out in outputs ]
601 mimes = [ out['data'] for out in outputs ]
589 self.assertEqual(mimes, expected)
602 self.assertEqual(mimes, expected)
590
603
591 def test_apply_displaypub(self):
604 def test_apply_displaypub(self):
592 """apply tracks display_pub output"""
605 """apply tracks display_pub output"""
593 view = self.client[:]
606 view = self.client[:]
594 view.execute("from IPython.core.display import *")
607 view.execute("from IPython.core.display import *")
595
608
596 @interactive
609 @interactive
597 def publish():
610 def publish():
598 [ display(i) for i in range(5) ]
611 [ display(i) for i in range(5) ]
599
612
600 ar = view.apply_async(publish)
613 ar = view.apply_async(publish)
601 ar.get(5)
614 ar.get(5)
602 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
615 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
603 for outputs in ar.outputs:
616 for outputs in ar.outputs:
604 mimes = [ out['data'] for out in outputs ]
617 mimes = [ out['data'] for out in outputs ]
605 self.assertEqual(mimes, expected)
618 self.assertEqual(mimes, expected)
606
619
607 def test_execute_raises(self):
620 def test_execute_raises(self):
608 """exceptions in execute requests raise appropriately"""
621 """exceptions in execute requests raise appropriately"""
609 view = self.client[-1]
622 view = self.client[-1]
610 ar = view.execute("1/0")
623 ar = view.execute("1/0")
611 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
624 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
612
625
613 def test_remoteerror_render_exception(self):
626 def test_remoteerror_render_exception(self):
614 """RemoteErrors get nice tracebacks"""
627 """RemoteErrors get nice tracebacks"""
615 view = self.client[-1]
628 view = self.client[-1]
616 ar = view.execute("1/0")
629 ar = view.execute("1/0")
617 ip = get_ipython()
630 ip = get_ipython()
618 ip.user_ns['ar'] = ar
631 ip.user_ns['ar'] = ar
619 with capture_output() as io:
632 with capture_output() as io:
620 ip.run_cell("ar.get(2)")
633 ip.run_cell("ar.get(2)")
621
634
622 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
635 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
623
636
624 def test_compositeerror_render_exception(self):
637 def test_compositeerror_render_exception(self):
625 """CompositeErrors get nice tracebacks"""
638 """CompositeErrors get nice tracebacks"""
626 view = self.client[:]
639 view = self.client[:]
627 ar = view.execute("1/0")
640 ar = view.execute("1/0")
628 ip = get_ipython()
641 ip = get_ipython()
629 ip.user_ns['ar'] = ar
642 ip.user_ns['ar'] = ar
630
643
631 with capture_output() as io:
644 with capture_output() as io:
632 ip.run_cell("ar.get(2)")
645 ip.run_cell("ar.get(2)")
633
646
634 count = min(error.CompositeError.tb_limit, len(view))
647 count = min(error.CompositeError.tb_limit, len(view))
635
648
636 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
649 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
637 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
650 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
638 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
651 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
639
652
640 def test_compositeerror_truncate(self):
653 def test_compositeerror_truncate(self):
641 """Truncate CompositeErrors with many exceptions"""
654 """Truncate CompositeErrors with many exceptions"""
642 view = self.client[:]
655 view = self.client[:]
643 msg_ids = []
656 msg_ids = []
644 for i in range(10):
657 for i in range(10):
645 ar = view.execute("1/0")
658 ar = view.execute("1/0")
646 msg_ids.extend(ar.msg_ids)
659 msg_ids.extend(ar.msg_ids)
647
660
648 ar = self.client.get_result(msg_ids)
661 ar = self.client.get_result(msg_ids)
649 try:
662 try:
650 ar.get()
663 ar.get()
651 except error.CompositeError as _e:
664 except error.CompositeError as _e:
652 e = _e
665 e = _e
653 else:
666 else:
654 self.fail("Should have raised CompositeError")
667 self.fail("Should have raised CompositeError")
655
668
656 lines = e.render_traceback()
669 lines = e.render_traceback()
657 with capture_output() as io:
670 with capture_output() as io:
658 e.print_traceback()
671 e.print_traceback()
659
672
660 self.assertTrue("more exceptions" in lines[-1])
673 self.assertTrue("more exceptions" in lines[-1])
661 count = e.tb_limit
674 count = e.tb_limit
662
675
663 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
676 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
664 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
677 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
665 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
678 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
666
679
667 @dec.skipif_not_matplotlib
680 @dec.skipif_not_matplotlib
668 def test_magic_pylab(self):
681 def test_magic_pylab(self):
669 """%pylab works on engines"""
682 """%pylab works on engines"""
670 view = self.client[-1]
683 view = self.client[-1]
671 ar = view.execute("%pylab inline")
684 ar = view.execute("%pylab inline")
672 # at least check if this raised:
685 # at least check if this raised:
673 reply = ar.get(5)
686 reply = ar.get(5)
674 # include imports, in case user config
687 # include imports, in case user config
675 ar = view.execute("plot(rand(100))", silent=False)
688 ar = view.execute("plot(rand(100))", silent=False)
676 reply = ar.get(5)
689 reply = ar.get(5)
677 self.assertEqual(len(reply.outputs), 1)
690 self.assertEqual(len(reply.outputs), 1)
678 output = reply.outputs[0]
691 output = reply.outputs[0]
679 self.assertTrue("data" in output)
692 self.assertTrue("data" in output)
680 data = output['data']
693 data = output['data']
681 self.assertTrue("image/png" in data)
694 self.assertTrue("image/png" in data)
682
695
683 def test_func_default_func(self):
696 def test_func_default_func(self):
684 """interactively defined function as apply func default"""
697 """interactively defined function as apply func default"""
685 def foo():
698 def foo():
686 return 'foo'
699 return 'foo'
687
700
688 def bar(f=foo):
701 def bar(f=foo):
689 return f()
702 return f()
690
703
691 view = self.client[-1]
704 view = self.client[-1]
692 ar = view.apply_async(bar)
705 ar = view.apply_async(bar)
693 r = ar.get(10)
706 r = ar.get(10)
694 self.assertEqual(r, 'foo')
707 self.assertEqual(r, 'foo')
695 def test_data_pub_single(self):
708 def test_data_pub_single(self):
696 view = self.client[-1]
709 view = self.client[-1]
697 ar = view.execute('\n'.join([
710 ar = view.execute('\n'.join([
698 'from IPython.kernel.zmq.datapub import publish_data',
711 'from IPython.kernel.zmq.datapub import publish_data',
699 'for i in range(5):',
712 'for i in range(5):',
700 ' publish_data(dict(i=i))'
713 ' publish_data(dict(i=i))'
701 ]), block=False)
714 ]), block=False)
702 self.assertTrue(isinstance(ar.data, dict))
715 self.assertTrue(isinstance(ar.data, dict))
703 ar.get(5)
716 ar.get(5)
704 self.assertEqual(ar.data, dict(i=4))
717 self.assertEqual(ar.data, dict(i=4))
705
718
706 def test_data_pub(self):
719 def test_data_pub(self):
707 view = self.client[:]
720 view = self.client[:]
708 ar = view.execute('\n'.join([
721 ar = view.execute('\n'.join([
709 'from IPython.kernel.zmq.datapub import publish_data',
722 'from IPython.kernel.zmq.datapub import publish_data',
710 'for i in range(5):',
723 'for i in range(5):',
711 ' publish_data(dict(i=i))'
724 ' publish_data(dict(i=i))'
712 ]), block=False)
725 ]), block=False)
713 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
726 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
714 ar.get(5)
727 ar.get(5)
715 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
728 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
716
729
717 def test_can_list_arg(self):
730 def test_can_list_arg(self):
718 """args in lists are canned"""
731 """args in lists are canned"""
719 view = self.client[-1]
732 view = self.client[-1]
720 view['a'] = 128
733 view['a'] = 128
721 rA = pmod.Reference('a')
734 rA = pmod.Reference('a')
722 ar = view.apply_async(lambda x: x, [rA])
735 ar = view.apply_async(lambda x: x, [rA])
723 r = ar.get(5)
736 r = ar.get(5)
724 self.assertEqual(r, [128])
737 self.assertEqual(r, [128])
725
738
726 def test_can_dict_arg(self):
739 def test_can_dict_arg(self):
727 """args in dicts are canned"""
740 """args in dicts are canned"""
728 view = self.client[-1]
741 view = self.client[-1]
729 view['a'] = 128
742 view['a'] = 128
730 rA = pmod.Reference('a')
743 rA = pmod.Reference('a')
731 ar = view.apply_async(lambda x: x, dict(foo=rA))
744 ar = view.apply_async(lambda x: x, dict(foo=rA))
732 r = ar.get(5)
745 r = ar.get(5)
733 self.assertEqual(r, dict(foo=128))
746 self.assertEqual(r, dict(foo=128))
734
747
735 def test_can_list_kwarg(self):
748 def test_can_list_kwarg(self):
736 """kwargs in lists are canned"""
749 """kwargs in lists are canned"""
737 view = self.client[-1]
750 view = self.client[-1]
738 view['a'] = 128
751 view['a'] = 128
739 rA = pmod.Reference('a')
752 rA = pmod.Reference('a')
740 ar = view.apply_async(lambda x=5: x, x=[rA])
753 ar = view.apply_async(lambda x=5: x, x=[rA])
741 r = ar.get(5)
754 r = ar.get(5)
742 self.assertEqual(r, [128])
755 self.assertEqual(r, [128])
743
756
744 def test_can_dict_kwarg(self):
757 def test_can_dict_kwarg(self):
745 """kwargs in dicts are canned"""
758 """kwargs in dicts are canned"""
746 view = self.client[-1]
759 view = self.client[-1]
747 view['a'] = 128
760 view['a'] = 128
748 rA = pmod.Reference('a')
761 rA = pmod.Reference('a')
749 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
762 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
750 r = ar.get(5)
763 r = ar.get(5)
751 self.assertEqual(r, dict(foo=128))
764 self.assertEqual(r, dict(foo=128))
752
765
753 def test_map_ref(self):
766 def test_map_ref(self):
754 """view.map works with references"""
767 """view.map works with references"""
755 view = self.client[:]
768 view = self.client[:]
756 ranks = sorted(self.client.ids)
769 ranks = sorted(self.client.ids)
757 view.scatter('rank', ranks, flatten=True)
770 view.scatter('rank', ranks, flatten=True)
758 rrank = pmod.Reference('rank')
771 rrank = pmod.Reference('rank')
759
772
760 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
773 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
761 drank = amr.get(5)
774 drank = amr.get(5)
762 self.assertEqual(drank, [ r*2 for r in ranks ])
775 self.assertEqual(drank, [ r*2 for r in ranks ])
763
776
764 def test_nested_getitem_setitem(self):
777 def test_nested_getitem_setitem(self):
765 """get and set with view['a.b']"""
778 """get and set with view['a.b']"""
766 view = self.client[-1]
779 view = self.client[-1]
767 view.execute('\n'.join([
780 view.execute('\n'.join([
768 'class A(object): pass',
781 'class A(object): pass',
769 'a = A()',
782 'a = A()',
770 'a.b = 128',
783 'a.b = 128',
771 ]), block=True)
784 ]), block=True)
772 ra = pmod.Reference('a')
785 ra = pmod.Reference('a')
773
786
774 r = view.apply_sync(lambda x: x.b, ra)
787 r = view.apply_sync(lambda x: x.b, ra)
775 self.assertEqual(r, 128)
788 self.assertEqual(r, 128)
776 self.assertEqual(view['a.b'], 128)
789 self.assertEqual(view['a.b'], 128)
777
790
778 view['a.b'] = 0
791 view['a.b'] = 0
779
792
780 r = view.apply_sync(lambda x: x.b, ra)
793 r = view.apply_sync(lambda x: x.b, ra)
781 self.assertEqual(r, 0)
794 self.assertEqual(r, 0)
782 self.assertEqual(view['a.b'], 0)
795 self.assertEqual(view['a.b'], 0)
783
796
784 def test_return_namedtuple(self):
797 def test_return_namedtuple(self):
785 def namedtuplify(x, y):
798 def namedtuplify(x, y):
786 from IPython.parallel.tests.test_view import point
799 from IPython.parallel.tests.test_view import point
787 return point(x, y)
800 return point(x, y)
788
801
789 view = self.client[-1]
802 view = self.client[-1]
790 p = view.apply_sync(namedtuplify, 1, 2)
803 p = view.apply_sync(namedtuplify, 1, 2)
791 self.assertEqual(p.x, 1)
804 self.assertEqual(p.x, 1)
792 self.assertEqual(p.y, 2)
805 self.assertEqual(p.y, 2)
793
806
794 def test_apply_namedtuple(self):
807 def test_apply_namedtuple(self):
795 def echoxy(p):
808 def echoxy(p):
796 return p.y, p.x
809 return p.y, p.x
797
810
798 view = self.client[-1]
811 view = self.client[-1]
799 tup = view.apply_sync(echoxy, point(1, 2))
812 tup = view.apply_sync(echoxy, point(1, 2))
800 self.assertEqual(tup, (2,1))
813 self.assertEqual(tup, (2,1))
801
814
General Comments 0
You need to be logged in to leave comments. Login now